import asyncio
import json
import logging
import subprocess

from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse

from utils.utils import get_http_authorization_cred, get_current_user
from config import SRC_LOG_LEVELS, ENV, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])

app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


async def run_background_process(command):
    # Spawn the command as a child process without blocking the event loop.
    process = await asyncio.create_subprocess_exec(
        *command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    return process


async def start_litellm_background():
    # Command to run in the background
    command = "litellm --config ./data/litellm/config.yaml"
    await run_background_process(command)


@app.on_event("startup")
async def startup_event():
    # Launch the litellm proxy alongside this app; scheduled as a task so
    # startup is not blocked while the subprocess runs.
    asyncio.create_task(start_litellm_background())


app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST


@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    # Resolve the requesting user from the Authorization header and stash it
    # on request.state so downstream middleware can read it.
    auth_header = request.headers.get("Authorization", "")
    request.state.user = None

    try:
        user = get_current_user(get_http_authorization_cred(auth_header))
        log.debug(f"user: {user}")
        request.state.user = user
    except Exception as e:
        return JSONResponse(status_code=400, content={"detail": str(e)})

    response = await call_next(request)
    return response


@app.get("/")
async def get_status():
    return {"status": True}


class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        response = await call_next(request)
        user = request.state.user

        if "/models" in request.url.path:
            if isinstance(response, StreamingResponse):
                # Read the content of the streaming response
                body = b""
                async for chunk in response.body_iterator:
                    body += chunk

                data = json.loads(body.decode("utf-8"))

                if app.state.MODEL_FILTER_ENABLED:
                    # Non-admin users only see models on the allowlist.
                    if user and user.role == "user":
                        data["data"] = list(
                            filter(
                                lambda model: model["id"]
                                in app.state.MODEL_FILTER_LIST,
                                data["data"],
                            )
                        )

                # Modified Flag
                data["modified"] = True
                return JSONResponse(content=data)

        return response


app.add_middleware(ModifyModelsResponseMiddleware)


# from litellm.proxy.proxy_server import ProxyConfig, initialize
# from litellm.proxy.proxy_server import app

# proxy_config = ProxyConfig()

# async def config():
#     router, model_list, general_settings = await proxy_config.load_config(
#         router=None, config_file_path="./data/litellm/config.yaml"
#     )
#     await initialize(config="./data/litellm/config.yaml", telemetry=False)

# async def startup():
#     await config()

# @app.on_event("startup")
# async def on_startup():
#     await startup()
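

# The block below is a convenience for running this module directly; it is a
# minimal sketch, assuming uvicorn is installed. The host and port are
# illustrative choices, not part of the original app.
if __name__ == "__main__":
    import uvicorn

    # Serve the FastAPI app; the litellm proxy subprocess is spawned by the
    # startup hook above, so only this server needs to be started here.
    uvicorn.run(app, host="0.0.0.0", port=8080)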