Endpoint role-checking was redundantly applied but FastAPI provides a nice abstraction mechanic...so I applied it. There should be no logical changes in this code; only simpler, cleaner ways for doing the same thing.

This commit is contained in:
Tim Farrell 2024-02-08 18:05:01 -06:00
parent 46d0eff218
commit 08e8e922fd
11 changed files with 127 additions and 251 deletions

View file

@ -1,4 +1,4 @@
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
@ -10,7 +10,7 @@ from pydantic import BaseModel
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from utils.utils import decode_token, get_current_user, get_admin_user
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
app = FastAPI()
@ -31,11 +31,8 @@ REQUEST_POOL = []
@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
async def get_ollama_api_url(user=Depends(get_admin_user)):
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
class UrlUpdateForm(BaseModel):
@ -44,13 +41,10 @@ class UrlUpdateForm(BaseModel):
@app.post("/url/update")
async def update_ollama_api_url(
form_data: UrlUpdateForm, user=Depends(get_current_user)
form_data: UrlUpdateForm, user=Depends(get_admin_user)
):
if user and user.role == "admin":
app.state.OLLAMA_API_BASE_URL = form_data.url
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
app.state.OLLAMA_API_BASE_URL = form_data.url
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
@app.get("/cancel/{request_id}")
@ -74,10 +68,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
if path in ["pull", "delete", "push", "copy", "create"]:
if user.role != "admin":
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
headers.pop("host", None)
headers.pop("authorization", None)