Merge Updates & Dockerfile improvements

This commit is contained in:
lainedfles 2024-04-02 03:25:20 -06:00 committed by GitHub
parent fdef2abdfb
commit 9763d885be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
155 changed files with 14509 additions and 4803 deletions

View file

@ -1,10 +1,27 @@
import logging
from litellm.proxy.proxy_server import ProxyConfig, initialize
from litellm.proxy.proxy_server import app
from fastapi import FastAPI, Request, Depends, status
from fastapi import FastAPI, Request, Depends, status, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json
from utils.utils import get_http_authorization_cred, get_current_user
from config import ENV
from config import SRC_LOG_LEVELS, ENV
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
from config import (
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
)
proxy_config = ProxyConfig()
@ -26,16 +43,58 @@ async def on_startup():
await startup()
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
auth_header = request.headers.get("Authorization", "")
request.state.user = None
if ENV != "dev":
try:
user = get_current_user(get_http_authorization_cred(auth_header))
print(user)
except Exception as e:
return JSONResponse(status_code=400, content={"detail": str(e)})
try:
user = get_current_user(get_http_authorization_cred(auth_header))
log.debug(f"user: {user}")
request.state.user = user
except Exception as e:
return JSONResponse(status_code=400, content={"detail": str(e)})
response = await call_next(request)
return response
class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
response = await call_next(request)
user = request.state.user
if "/models" in request.url.path:
if isinstance(response, StreamingResponse):
# Read the content of the streaming response
body = b""
async for chunk in response.body_iterator:
body += chunk
data = json.loads(body.decode("utf-8"))
if app.state.MODEL_FILTER_ENABLED:
if user and user.role == "user":
data["data"] = list(
filter(
lambda model: model["id"]
in app.state.MODEL_FILTER_LIST,
data["data"],
)
)
# Modified Flag
data["modified"] = True
return JSONResponse(content=data)
return response
app.add_middleware(ModifyModelsResponseMiddleware)