main #3

Merged
tdpeuter merged 116 commits from open-webui/open-webui:main into main 2024-03-26 12:22:23 +01:00
2 changed files with 65 additions and 2 deletions
Showing only changes of commit 93c90dc186 - Show all commits

View file

@ -1,11 +1,23 @@
from litellm.proxy.proxy_server import ProxyConfig, initialize from litellm.proxy.proxy_server import ProxyConfig, initialize
from litellm.proxy.proxy_server import app from litellm.proxy.proxy_server import app
from fastapi import FastAPI, Request, Depends, status from fastapi import FastAPI, Request, Depends, status, Response
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json
from utils.utils import get_http_authorization_cred, get_current_user from utils.utils import get_http_authorization_cred, get_current_user
from config import ENV from config import ENV
from config import (
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
)
proxy_config = ProxyConfig() proxy_config = ProxyConfig()
@ -26,16 +38,67 @@ async def on_startup():
await startup() await startup()
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
@app.middleware("http") @app.middleware("http")
async def auth_middleware(request: Request, call_next): async def auth_middleware(request: Request, call_next):
auth_header = request.headers.get("Authorization", "") auth_header = request.headers.get("Authorization", "")
request.state.user = None
if ENV != "dev": if ENV != "dev":
try: try:
user = get_current_user(get_http_authorization_cred(auth_header)) user = get_current_user(get_http_authorization_cred(auth_header))
print(user) print(user)
request.state.user = user
except Exception as e: except Exception as e:
return JSONResponse(status_code=400, content={"detail": str(e)}) return JSONResponse(status_code=400, content={"detail": str(e)})
response = await call_next(request) response = await call_next(request)
return response return response
class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
response = await call_next(request)
user = request.state.user
# Check if the request is for the `/models` route
if "/models" in request.url.path:
# Ensure the response is a StreamingResponse
if isinstance(response, StreamingResponse):
# Read the content of the streaming response
body = b""
async for chunk in response.body_iterator:
body += chunk
# Modify the content as needed
data = json.loads(body.decode("utf-8"))
print(data)
if app.state.MODEL_FILTER_ENABLED:
if user and user.role == "user":
data["data"] = list(
filter(
lambda model: model["id"]
in app.state.MODEL_FILTER_LIST,
data["data"],
)
)
# Example modification: Add a new key-value pair
data["modified"] = True
# Return a new JSON response with the modified content
return JSONResponse(content=data)
return response
# Add the middleware to the app
app.add_middleware(ModifyModelsResponseMiddleware)

View file

@ -298,7 +298,7 @@ USER_PERMISSIONS_CHAT_DELETION = (
USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}} USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}
MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", False) MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", "False").lower() == "true"
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")] MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]