Merge pull request #682 from explorigin/simplify-endpoint-code

Simplify endpoint role checking
This commit is contained in:
Timothy Jaeryang Baek 2024-02-09 14:26:56 -08:00 committed by GitHub
commit 9f3346a6ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 127 additions and 251 deletions

View file

@ -1,4 +1,4 @@
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
@ -10,7 +10,7 @@ from pydantic import BaseModel
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from utils.utils import decode_token, get_current_user, get_admin_user
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
app = FastAPI()
@ -31,11 +31,8 @@ REQUEST_POOL = []
@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
if user and user.role == "admin":
async def get_ollama_api_url(user=Depends(get_admin_user)):
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
class UrlUpdateForm(BaseModel):
@ -44,13 +41,10 @@ class UrlUpdateForm(BaseModel):
@app.post("/url/update")
async def update_ollama_api_url(
form_data: UrlUpdateForm, user=Depends(get_current_user)
form_data: UrlUpdateForm, user=Depends(get_admin_user)
):
if user and user.role == "admin":
app.state.OLLAMA_API_BASE_URL = form_data.url
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.get("/cancel/{request_id}")
@ -74,10 +68,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
if path in ["pull", "delete", "push", "copy", "create"]:
if user.role != "admin":
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
headers.pop("host", None)
headers.pop("authorization", None)

View file

@ -9,7 +9,7 @@ from pydantic import BaseModel
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from utils.utils import decode_token, get_current_user, get_verified_user, get_admin_user
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR
import hashlib
@ -37,45 +37,32 @@ class KeyUpdateForm(BaseModel):
@app.get("/url")
async def get_openai_url(user=Depends(get_current_user)):
if user and user.role == "admin":
async def get_openai_url(user=Depends(get_admin_user)):
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)):
if user and user.role == "admin":
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
app.state.OPENAI_API_BASE_URL = form_data.url
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.get("/key")
async def get_openai_key(user=Depends(get_current_user)):
if user and user.role == "admin":
async def get_openai_key(user=Depends(get_admin_user)):
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.post("/key/update")
async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)):
if user and user.role == "admin":
async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_admin_user)):
app.state.OPENAI_API_KEY = form_data.key
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_current_user)):
async def speech(request: Request, user=Depends(get_verified_user)):
target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"
if user.role not in ["user", "admin"]:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
if app.state.OPENAI_API_KEY == "":
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
@ -133,12 +120,10 @@ async def speech(request: Request, user=Depends(get_current_user)):
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
print(target_url, app.state.OPENAI_API_KEY)
if user.role not in ["user", "admin"]:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
if app.state.OPENAI_API_KEY == "":
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

View file

@ -39,7 +39,7 @@ import uuid
import time
from utils.misc import calculate_sha256, calculate_sha256_string
from utils.utils import get_current_user
from utils.utils import get_current_user, get_admin_user
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from constants import ERROR_MESSAGES
@ -354,19 +354,12 @@ def store_doc(
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_current_user)):
if user.role == "admin":
def reset_vector_db(user=Depends(get_admin_user)):
CHROMA_CLIENT.reset()
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
@app.get("/reset")
def reset(user=Depends(get_current_user)) -> bool:
if user.role == "admin":
def reset(user=Depends(get_admin_user)) -> bool:
folder = f"{UPLOAD_DIR}"
for filename in os.listdir(folder):
file_path = os.path.join(folder, filename)
@ -384,8 +377,3 @@ def reset(user=Depends(get_current_user)) -> bool:
print(e)
return True
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)

View file

@ -3,7 +3,7 @@ from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union
from fastapi import APIRouter
from fastapi import APIRouter, status
from pydantic import BaseModel
import time
import uuid
@ -19,7 +19,7 @@ from apps.web.models.auths import (
)
from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, create_token
from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES
@ -116,10 +116,10 @@ async def signin(form_data: SigninForm):
@router.post("/signup", response_model=SigninResponse)
async def signup(request: Request, form_data: SignupForm):
if not request.app.state.ENABLE_SIGNUP:
raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
raise HTTPException(status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
if not validate_email_format(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
if Users.get_user_by_email(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
@ -156,23 +156,11 @@ async def signup(request: Request, form_data: SignupForm):
@router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_current_user)):
if user.role == "admin":
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
return request.app.state.ENABLE_SIGNUP
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
@router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_current_user)):
if user.role == "admin":
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
return request.app.state.ENABLE_SIGNUP
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)

View file

@ -1,7 +1,7 @@
from fastapi import Depends, Request, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from utils.utils import get_current_user
from utils.utils import get_current_user, get_admin_user
from fastapi import APIRouter
from pydantic import BaseModel
import json
@ -60,17 +60,11 @@ async def get_all_user_chats(user=Depends(get_current_user)):
@router.get("/all/db", response_model=List[ChatResponse])
async def get_all_user_chats_in_db(user=Depends(get_current_user)):
if user.role == "admin":
async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats()
]
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
############################

View file

@ -10,7 +10,7 @@ import uuid
from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, create_token
from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES
@ -37,30 +37,19 @@ class SetDefaultSuggestionsForm(BaseModel):
@router.post("/default/models", response_model=str)
async def set_global_default_models(
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user)
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
):
if user.role == "admin":
request.app.state.DEFAULT_MODELS = form_data.models
return request.app.state.DEFAULT_MODELS
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
@router.post("/default/suggestions", response_model=List[PromptSuggestion])
async def set_global_default_suggestions(
request: Request,
form_data: SetDefaultSuggestionsForm,
user=Depends(get_current_user),
user=Depends(get_admin_user),
):
if user.role == "admin":
data = form_data.model_dump()
request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
return request.app.state.DEFAULT_PROMPT_SUGGESTIONS
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)

View file

@ -14,7 +14,7 @@ from apps.web.models.documents import (
DocumentResponse,
)
from utils.utils import get_current_user
from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
@ -44,13 +44,7 @@ async def get_documents(user=Depends(get_current_user)):
@router.post("/create", response_model=Optional[DocumentResponse])
async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
doc = Documents.get_doc_by_name(form_data.name)
if doc == None:
doc = Documents.insert_new_doc(user.id, form_data)
@ -132,14 +126,8 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u
@router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_current_user)
name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
doc = Documents.update_doc_by_name(name, form_data)
if doc:
return DocumentResponse(
@ -161,12 +149,6 @@ async def update_doc_by_name(
@router.delete("/name/{name}/delete", response_model=bool)
async def delete_doc_by_name(name: str, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
result = Documents.delete_doc_by_name(name)
return result

View file

@ -13,7 +13,7 @@ from apps.web.models.modelfiles import (
ModelfileResponse,
)
from utils.utils import get_current_user
from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
@ -37,13 +37,7 @@ async def get_modelfiles(skip: int = 0,
@router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
user=Depends(get_admin_user)):
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile:
@ -91,12 +85,7 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
@router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
user=Depends(get_admin_user)):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
updated_modelfile = {
@ -127,12 +116,6 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
@router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
user=Depends(get_admin_user)):
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result

View file

@ -8,7 +8,7 @@ import json
from apps.web.models.prompts import Prompts, PromptForm, PromptModel
from utils.utils import get_current_user
from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
@ -29,25 +29,17 @@ async def get_prompts(user=Depends(get_current_user)):
@router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
prompt = Prompts.get_prompt_by_command(form_data.command)
if prompt == None:
prompt = Prompts.insert_new_prompt(user.id, form_data)
if prompt:
return prompt
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.COMMAND_TAKEN,
@ -79,14 +71,8 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
@router.post("/command/{command}/update", response_model=Optional[PromptModel])
async def update_prompt_by_command(
command: str, form_data: PromptForm, user=Depends(get_current_user)
command: str, form_data: PromptForm, user=Depends(get_admin_user)
):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt:
return prompt
@ -103,12 +89,6 @@ async def update_prompt_by_command(
@router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command(command: str, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
result = Prompts.delete_prompt_by_command(f"/{command}")
return result

View file

@ -11,7 +11,7 @@ import uuid
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths
from utils.utils import get_current_user, get_password_hash
from utils.utils import get_current_user, get_password_hash, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
@ -22,12 +22,7 @@ router = APIRouter()
@router.get("/", response_model=List[UserModel])
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)):
return Users.get_users(skip, limit)
@ -38,17 +33,11 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_use
@router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(
form_data: UserRoleUpdateForm, user=Depends(get_current_user)
form_data: UserRoleUpdateForm, user=Depends(get_admin_user)
):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if user.id != form_data.id:
return Users.update_user_role_by_id(form_data.id, form_data.role)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
@ -62,14 +51,8 @@ async def update_user_role(
@router.post("/{user_id}/update", response_model=Optional[UserModel])
async def update_user_by_id(
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_current_user)
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user)
):
if session_user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
user = Users.get_user_by_id(user_id)
if user:
@ -98,13 +81,12 @@ async def update_user_by_id(
if updated_user:
return updated_user
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
@ -117,25 +99,20 @@ async def update_user_by_id(
@router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(user_id: str, user=Depends(get_current_user)):
if user.role == "admin":
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
if user.id != user_id:
result = Auths.delete_auth_by_id(user_id)
if result:
return True
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DELETE_USER_ERROR,
)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)

View file

@ -73,3 +73,19 @@ def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(bearer_s
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
def get_verified_user(user: Users = Depends(get_current_user)):
if user.role not in {"user", "admin"}:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
def get_admin_user(user: Users = Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)