From b5bd07a06a7f62e548c5dd18d23ba3c3eb1f19c2 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 23 Feb 2024 22:44:56 -0800 Subject: [PATCH] feat: secure litellm api --- backend/main.py | 21 +++++++++++++++++++-- backend/utils/utils.py | 11 +++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/backend/main.py b/backend/main.py index b9370a18..94145a97 100644 --- a/backend/main.py +++ b/backend/main.py @@ -4,9 +4,10 @@ import markdown import time -from fastapi import FastAPI, Request +from fastapi import FastAPI, Request, Depends from fastapi.staticfiles import StaticFiles from fastapi import HTTPException +from fastapi.responses import JSONResponse from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.cors import CORSMiddleware from starlette.exceptions import HTTPException as StarletteHTTPException @@ -19,10 +20,11 @@ from apps.openai.main import app as openai_app from apps.audio.main import app as audio_app from apps.images.main import app as images_app from apps.rag.main import app as rag_app - from apps.web.main import app as webui_app + from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR +from utils.utils import get_http_authorization_cred, get_current_user class SPAStaticFiles(StaticFiles): @@ -59,6 +61,21 @@ async def check_url(request: Request, call_next): return response +@litellm_app.middleware("http") +async def auth_middleware(request: Request, call_next): + auth_header = request.headers.get("Authorization", "") + + if ENV != "dev": + try: + user = get_current_user(get_http_authorization_cred(auth_header)) + print(user) + except Exception as e: + return JSONResponse(status_code=400, content={"detail": str(e)}) + + response = await call_next(request) + return response + + app.mount("/api/v1", webui_app) app.mount("/litellm/api", litellm_app) diff --git a/backend/utils/utils.py b/backend/utils/utils.py index c6d01814..7ae7e694 100644 --- a/backend/utils/utils.py +++ b/backend/utils/utils.py @@ -58,6 +58,17 @@ def extract_token_from_auth_header(auth_header: str): return auth_header[len("Bearer ") :] +def get_http_authorization_cred(auth_header: str): + try: + scheme, credentials = auth_header.split(" ") + return { + "scheme": scheme, + "credentials": credentials, + } + except: + raise ValueError(ERROR_MESSAGES.INVALID_TOKEN) + + def get_current_user( auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), ):