feat: litellm model filter support

This commit is contained in:
Timothy J. Baek 2024-03-20 19:28:33 -07:00
parent 8cb7127fbc
commit 93c90dc186
2 changed files with 65 additions and 2 deletions

View file

@ -1,11 +1,23 @@
from litellm.proxy.proxy_server import ProxyConfig, initialize
from litellm.proxy.proxy_server import app
from fastapi import FastAPI, Request, Depends, status
from fastapi import FastAPI, Request, Depends, status, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json
from utils.utils import get_http_authorization_cred, get_current_user
from config import ENV
from config import (
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
)
proxy_config = ProxyConfig()
@ -26,16 +38,67 @@ async def on_startup():
await startup()
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
auth_header = request.headers.get("Authorization", "")
request.state.user = None
if ENV != "dev":
try:
user = get_current_user(get_http_authorization_cred(auth_header))
print(user)
request.state.user = user
except Exception as e:
return JSONResponse(status_code=400, content={"detail": str(e)})
response = await call_next(request)
return response
class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
response = await call_next(request)
user = request.state.user
# Check if the request is for the `/models` route
if "/models" in request.url.path:
# Ensure the response is a StreamingResponse
if isinstance(response, StreamingResponse):
# Read the content of the streaming response
body = b""
async for chunk in response.body_iterator:
body += chunk
# Modify the content as needed
data = json.loads(body.decode("utf-8"))
print(data)
if app.state.MODEL_FILTER_ENABLED:
if user and user.role == "user":
data["data"] = list(
filter(
lambda model: model["id"]
in app.state.MODEL_FILTER_LIST,
data["data"],
)
)
# Example modification: Add a new key-value pair
data["modified"] = True
# Return a new JSON response with the modified content
return JSONResponse(content=data)
return response
# Add the middleware to the app
app.add_middleware(ModifyModelsResponseMiddleware)

View file

@ -298,7 +298,7 @@ USER_PERMISSIONS_CHAT_DELETION = (
USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}
MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", False)
MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", "False").lower() == "true"
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]