forked from open-webui/open-webui
feat: secure litellm api
This commit is contained in:
parent
af388dfe62
commit
b5bd07a06a
2 changed files with 30 additions and 2 deletions
|
@ -4,9 +4,10 @@ import markdown
|
|||
import time
|
||||
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi import FastAPI, Request, Depends
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.wsgi import WSGIMiddleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
@ -19,10 +20,11 @@ from apps.openai.main import app as openai_app
|
|||
from apps.audio.main import app as audio_app
|
||||
from apps.images.main import app as images_app
|
||||
from apps.rag.main import app as rag_app
|
||||
|
||||
from apps.web.main import app as webui_app
|
||||
|
||||
|
||||
from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
|
||||
from utils.utils import get_http_authorization_cred, get_current_user
|
||||
|
||||
|
||||
class SPAStaticFiles(StaticFiles):
|
||||
|
@ -59,6 +61,21 @@ async def check_url(request: Request, call_next):
|
|||
return response
|
||||
|
||||
|
||||
@litellm_app.middleware("http")
|
||||
async def auth_middleware(request: Request, call_next):
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
|
||||
if ENV != "dev":
|
||||
try:
|
||||
user = get_current_user(get_http_authorization_cred(auth_header))
|
||||
print(user)
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"detail": str(e)})
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
app.mount("/api/v1", webui_app)
|
||||
app.mount("/litellm/api", litellm_app)
|
||||
|
||||
|
|
|
@ -58,6 +58,17 @@ def extract_token_from_auth_header(auth_header: str):
|
|||
return auth_header[len("Bearer ") :]
|
||||
|
||||
|
||||
def get_http_authorization_cred(auth_header: str):
|
||||
try:
|
||||
scheme, credentials = auth_header.split(" ")
|
||||
return {
|
||||
"scheme": scheme,
|
||||
"credentials": credentials,
|
||||
}
|
||||
except:
|
||||
raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
|
||||
|
||||
|
||||
def get_current_user(
|
||||
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
|
||||
):
|
||||
|
|
Loading…
Reference in a new issue