# open-webui/backend/apps/litellm/main.py
# 2024-04-21 00:52:27 -05:00
#
# 143 lines
# 3.6 KiB
# Python

from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
import logging
from fastapi import FastAPI, Request, Depends, status, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json
from utils.utils import get_http_authorization_cred, get_current_user
from config import SRC_LOG_LEVELS, ENV
# Module-level logger; its level is taken from the "LITELLM" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
from config import (
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
)
import asyncio
import subprocess
# FastAPI application object; the routes and middleware in this module attach to it.
app = FastAPI()

# Origin list handed to the CORS middleware (wildcard: every origin).
origins = ["*"]

# NOTE(review): every origin, method and header is accepted and credentials
# are allowed as well - confirm this permissive policy is intended.
_cors_options = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
async def run_background_process(command):
    """Spawn ``command`` as a child process without waiting for it to finish.

    Args:
        command: Either a shell-style command string, which is tokenised with
            ``shlex.split`` so that quoted arguments containing spaces stay
            intact, or an already-split sequence of arguments, which is used
            as given.

    Returns:
        The ``asyncio.subprocess.Process`` handle of the started child. Its
        stdout and stderr are pipes, so the caller should read them (or call
        ``communicate()``); a child that writes more than the pipe buffer
        holds would otherwise block.

    Raises:
        ValueError: If ``command`` contains no arguments.
    """
    import shlex  # local import: only this function needs it

    if isinstance(command, str):
        # A plain str.split() would tear quoted arguments apart at spaces.
        args = shlex.split(command)
    else:
        args = list(command)

    if not args:
        raise ValueError("command must not be empty")

    return await asyncio.create_subprocess_exec(
        *args, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
async def start_litellm_background():
    """Launch the LiteLLM CLI as a child process and relay its output to the log.

    The child's stdout and stderr are pipes. They are drained here for as
    long as the child runs, because a child whose pipes are never read can
    block once a pipe buffer is full. As a consequence this coroutine only
    finishes when the child has closed both streams and exited, so it is
    meant to be scheduled as a background task rather than awaited inline.
    """
    # Command to run in the background
    command = "litellm --config ./data/litellm/config.yaml"
    process = await run_background_process(command)

    async def _relay(stream, label):
        # Forward one output stream of the child to the module logger.
        if stream is None:
            return
        while True:
            # Fixed-size reads: an overly long line cannot overflow the
            # reader the way a line-based read could.
            chunk = await stream.read(4096)
            if not chunk:
                break
            text = chunk.decode("utf-8", errors="replace").rstrip()
            if text:
                log.info(f"litellm {label}: {text}")

    await asyncio.gather(
        _relay(process.stdout, "stdout"),
        _relay(process.stderr, "stderr"),
    )
    exit_code = await process.wait()
    log.info(f"litellm exited with code {exit_code}")
@app.on_event("startup")
async def startup_event():
    """Start LiteLLM in the background and publish the model-filter settings.

    The filter settings from ``config`` are copied onto ``app.state`` so
    request-time code can read them from there.
    """
    # The event loop keeps only weak references to tasks; hold a strong one
    # so the background task cannot be garbage-collected while it is running.
    app.state.litellm_task = asyncio.create_task(start_litellm_background())
    app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
    app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    """Resolve the calling user from the ``Authorization`` header.

    On success the user is stored on ``request.state.user`` and the request
    is passed on. On failure the request is answered directly with a JSON
    body of the form ``{"detail": ...}``.

    If the raised exception carries an integer ``status_code`` (and
    optionally ``detail`` / ``headers``), as HTTP-style exceptions do, those
    are used for the response. Any other exception keeps the previous
    behaviour: status 400 with ``str(e)`` as the detail.
    """
    auth_header = request.headers.get("Authorization", "")
    request.state.user = None

    try:
        # Keep the try body to the calls that are expected to raise.
        user = get_current_user(get_http_authorization_cred(auth_header))
    except Exception as e:
        status_code = getattr(e, "status_code", None)
        if not isinstance(status_code, int):
            status_code = 400

        detail = getattr(e, "detail", None)
        if not isinstance(detail, str) or not detail:
            detail = str(e)

        log.debug(f"authentication failed: {detail}")
        return JSONResponse(
            status_code=status_code,
            content={"detail": detail},
            headers=getattr(e, "headers", None),
        )

    log.debug(f"user: {user}")
    request.state.user = user
    return await call_next(request)
@app.get("/")
async def get_status():
    """Answer ``GET /`` with the constant payload ``{"status": True}``."""
    payload = {"status": True}
    return payload
class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
    """Rewrite JSON responses of requests whose path contains ``/models``.

    The downstream response body is read in full. If it is a JSON object,
    it is tagged with ``"modified": True``; when model filtering is enabled
    and the requesting user has the role ``"user"``, its ``"data"`` list is
    first reduced to the models whose ``"id"`` is in
    ``app.state.MODEL_FILTER_LIST``. All other responses pass through
    unchanged.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Run the downstream app and post-process ``/models`` responses.

        Args:
            request: The incoming request.
            call_next: Callable that invokes the rest of the middleware chain.

        Returns:
            The downstream response, or a ``JSONResponse`` carrying the
            rewritten payload and the downstream status code.
        """
        response = await call_next(request)

        if "/models" not in request.url.path:
            return response
        if not isinstance(response, StreamingResponse):
            return response

        # Read the content of the streaming response. Chunks are collected
        # and joined once instead of growing a bytes object chunk by chunk.
        chunks = []
        async for chunk in response.body_iterator:
            chunks.append(chunk)
        body = b"".join(chunks)

        try:
            data = json.loads(body.decode("utf-8"))
        except ValueError:
            # Covers both undecodable bytes and invalid JSON.
            data = None

        if not isinstance(data, dict):
            # Nothing to rewrite. The stream is already consumed, so replay
            # the buffered bytes through the original response, which keeps
            # its status code and headers.
            async def _replay():
                yield body

            response.body_iterator = _replay()
            return response

        user = getattr(request.state, "user", None)
        models = data.get("data")
        if (
            app.state.MODEL_FILTER_ENABLED
            and user
            and user.role == "user"
            and isinstance(models, list)
        ):
            allowed = app.state.MODEL_FILTER_LIST
            data["data"] = [
                model
                for model in models
                if isinstance(model, dict) and model.get("id") in allowed
            ]

        # Modified Flag
        data["modified"] = True

        # Keep the downstream status code: an error payload must not be
        # turned into a 200 response.
        return JSONResponse(content=data, status_code=response.status_code)
# Registered after auth_middleware; its dispatch reads request.state.user,
# which auth_middleware sets.
app.add_middleware(ModifyModelsResponseMiddleware)
# NOTE: Disabled alternative, kept for reference. It would initialise the
# LiteLLM proxy in-process from the same config file instead of spawning the
# `litellm` command as a child process.
# from litellm.proxy.proxy_server import ProxyConfig, initialize
# from litellm.proxy.proxy_server import app
# proxy_config = ProxyConfig()
# async def config():
#     router, model_list, general_settings = await proxy_config.load_config(
#         router=None, config_file_path="./data/litellm/config.yaml"
#     )
#     await initialize(config="./data/litellm/config.yaml", telemetry=False)
# async def startup():
#     await config()
# @app.on_event("startup")
# async def on_startup():
#     await startup()