From 08e8e922fd71a0650dfd2c05d1dcf883f1be0357 Mon Sep 17 00:00:00 2001 From: Tim Farrell Date: Thu, 8 Feb 2024 18:05:01 -0600 Subject: [PATCH] Endpoint role-checking was redundantly applied but FastAPI provides a nice abstraction mechanic...so I applied it. There should be no logical changes in this code; only simpler, cleaner ways for doing the same thing. --- backend/apps/ollama/main.py | 24 +++----- backend/apps/openai/main.py | 43 +++++--------- backend/apps/rag/main.py | 48 ++++++--------- backend/apps/web/routers/auths.py | 30 +++------- backend/apps/web/routers/chats.py | 18 ++---- backend/apps/web/routers/configs.py | 29 +++------ backend/apps/web/routers/documents.py | 26 ++------- backend/apps/web/routers/modelfiles.py | 25 ++------ backend/apps/web/routers/prompts.py | 38 +++--------- backend/apps/web/routers/users.py | 81 +++++++++----------------- backend/utils/utils.py | 16 +++++ 11 files changed, 127 insertions(+), 251 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index a94dc37f..5a1d0891 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -1,4 +1,4 @@ -from fastapi import FastAPI, Request, Response, HTTPException, Depends +from fastapi import FastAPI, Request, Response, HTTPException, Depends, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from fastapi.concurrency import run_in_threadpool @@ -10,7 +10,7 @@ from pydantic import BaseModel from apps.web.models.users import Users from constants import ERROR_MESSAGES -from utils.utils import decode_token, get_current_user +from utils.utils import decode_token, get_current_user, get_admin_user from config import OLLAMA_API_BASE_URL, WEBUI_AUTH app = FastAPI() @@ -31,11 +31,8 @@ REQUEST_POOL = [] @app.get("/url") -async def get_ollama_api_url(user=Depends(get_current_user)): - if user and user.role == "admin": - return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} - else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) +async def get_ollama_api_url(user=Depends(get_admin_user)): + return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} class UrlUpdateForm(BaseModel): @@ -44,13 +41,10 @@ class UrlUpdateForm(BaseModel): @app.post("/url/update") async def update_ollama_api_url( - form_data: UrlUpdateForm, user=Depends(get_current_user) + form_data: UrlUpdateForm, user=Depends(get_admin_user) ): - if user and user.role == "admin": - app.state.OLLAMA_API_BASE_URL = form_data.url - return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} - else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + app.state.OLLAMA_API_BASE_URL = form_data.url + return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} @app.get("/cancel/{request_id}") @@ -74,10 +68,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): if path in ["pull", "delete", "push", "copy", "create"]: if user.role != "admin": raise HTTPException( - status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED ) else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) headers.pop("host", None) headers.pop("authorization", None) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index bed9181a..2e2d377f 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -9,7 +9,7 @@ from pydantic import BaseModel from apps.web.models.users import Users from constants import ERROR_MESSAGES -from utils.utils import decode_token, get_current_user +from utils.utils import decode_token, get_current_user, get_verified_user, get_admin_user from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR import hashlib @@ -37,45 +37,32 @@ class KeyUpdateForm(BaseModel): @app.get("/url") -async def get_openai_url(user=Depends(get_current_user)): - if user and user.role == "admin": - return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} - else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) +async def get_openai_url(user=Depends(get_admin_user)): + return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} @app.post("/url/update") -async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)): - if user and user.role == "admin": - app.state.OPENAI_API_BASE_URL = form_data.url - return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} - else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) +async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): + app.state.OPENAI_API_BASE_URL = form_data.url + return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} + @app.get("/key") -async def get_openai_key(user=Depends(get_current_user)): - if user and user.role == "admin": - return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} - else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) +async def get_openai_key(user=Depends(get_admin_user)): + return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} @app.post("/key/update") -async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)): - if user and user.role == "admin": - app.state.OPENAI_API_KEY = form_data.key - return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} - else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) +async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_admin_user)): + app.state.OPENAI_API_KEY = form_data.key + return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} @app.post("/audio/speech") -async def speech(request: Request, user=Depends(get_current_user)): +async def speech(request: Request, user=Depends(get_verified_user)): target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech" - if user.role not in ["user", "admin"]: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) if app.state.OPENAI_API_KEY == "": raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) @@ -133,12 +120,10 @@ async def speech(request: Request, user=Depends(get_current_user)): @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) -async def proxy(path: str, request: Request, user=Depends(get_current_user)): +async def proxy(path: str, request: Request, user=Depends(get_verified_user)): target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}" print(target_url, app.state.OPENAI_API_KEY) - if user.role not in ["user", "admin"]: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) if app.state.OPENAI_API_KEY == "": raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 95535274..07a30ade 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -39,7 +39,7 @@ import uuid import time from utils.misc import calculate_sha256, calculate_sha256_string -from utils.utils import get_current_user +from utils.utils import get_current_user, get_admin_user from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP from constants import ERROR_MESSAGES @@ -354,38 +354,26 @@ def store_doc( @app.get("/reset/db") -def reset_vector_db(user=Depends(get_current_user)): - if user.role == "admin": - CHROMA_CLIENT.reset() - else: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) +def reset_vector_db(user=Depends(get_admin_user)): + CHROMA_CLIENT.reset() @app.get("/reset") -def reset(user=Depends(get_current_user)) -> bool: - if user.role == "admin": - folder = f"{UPLOAD_DIR}" - for filename in os.listdir(folder): - file_path = os.path.join(folder, filename) - try: - if os.path.isfile(file_path) or os.path.islink(file_path): - os.unlink(file_path) - elif os.path.isdir(file_path): - shutil.rmtree(file_path) - except Exception as e: - print("Failed to delete %s. Reason: %s" % (file_path, e)) - +def reset(user=Depends(get_admin_user)) -> bool: + folder = f"{UPLOAD_DIR}" + for filename in os.listdir(folder): + file_path = os.path.join(folder, filename) try: - CHROMA_CLIENT.reset() + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) except Exception as e: - print(e) + print("Failed to delete %s. Reason: %s" % (file_path, e)) - return True - else: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) + try: + CHROMA_CLIENT.reset() + except Exception as e: + print(e) + + return True diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py index d06539f8..58da3512 100644 --- a/backend/apps/web/routers/auths.py +++ b/backend/apps/web/routers/auths.py @@ -3,7 +3,7 @@ from fastapi import Depends, FastAPI, HTTPException, status from datetime import datetime, timedelta from typing import List, Union -from fastapi import APIRouter +from fastapi import APIRouter, status from pydantic import BaseModel import time import uuid @@ -19,7 +19,7 @@ from apps.web.models.auths import ( ) from apps.web.models.users import Users -from utils.utils import get_password_hash, get_current_user, create_token +from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token from utils.misc import get_gravatar_url, validate_email_format from constants import ERROR_MESSAGES @@ -116,10 +116,10 @@ async def signin(form_data: SigninForm): @router.post("/signup", response_model=SigninResponse) async def signup(request: Request, form_data: SignupForm): if not request.app.state.ENABLE_SIGNUP: - raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + raise HTTPException(status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) if not validate_email_format(form_data.email.lower()): - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT) + raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT) if Users.get_user_by_email(form_data.email.lower()): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) @@ -156,23 +156,11 @@ async def signup(request: Request, form_data: SignupForm): @router.get("/signup/enabled", response_model=bool) -async def get_sign_up_status(request: Request, user=Depends(get_current_user)): - if user.role == "admin": - return request.app.state.ENABLE_SIGNUP - else: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) +async def get_sign_up_status(request: Request, user=Depends(get_admin_user)): + return request.app.state.ENABLE_SIGNUP @router.get("/signup/enabled/toggle", response_model=bool) -async def toggle_sign_up(request: Request, user=Depends(get_current_user)): - if user.role == "admin": - request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP - return request.app.state.ENABLE_SIGNUP - else: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) +async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): + request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP + return request.app.state.ENABLE_SIGNUP diff --git a/backend/apps/web/routers/chats.py b/backend/apps/web/routers/chats.py index aa725409..1150234a 100644 --- a/backend/apps/web/routers/chats.py +++ b/backend/apps/web/routers/chats.py @@ -1,7 +1,7 @@ from fastapi import Depends, Request, HTTPException, status from datetime import datetime, timedelta from typing import List, Union, Optional -from utils.utils import get_current_user +from utils.utils import get_current_user, get_admin_user from fastapi import APIRouter from pydantic import BaseModel import json @@ -60,17 +60,11 @@ async def get_all_user_chats(user=Depends(get_current_user)): @router.get("/all/db", response_model=List[ChatResponse]) -async def get_all_user_chats_in_db(user=Depends(get_current_user)): - if user.role == "admin": - return [ - ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - for chat in Chats.get_all_chats() - ] - else: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) +async def get_all_user_chats_in_db(user=Depends(get_admin_user)): + return [ + ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + for chat in Chats.get_all_chats() + ] ############################ diff --git a/backend/apps/web/routers/configs.py b/backend/apps/web/routers/configs.py index 376686e0..b293a398 100644 --- a/backend/apps/web/routers/configs.py +++ b/backend/apps/web/routers/configs.py @@ -10,7 +10,7 @@ import uuid from apps.web.models.users import Users -from utils.utils import get_password_hash, get_current_user, create_token +from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token from utils.misc import get_gravatar_url, validate_email_format from constants import ERROR_MESSAGES @@ -37,30 +37,19 @@ class SetDefaultSuggestionsForm(BaseModel): @router.post("/default/models", response_model=str) async def set_global_default_models( - request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user) + request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user) ): - if user.role == "admin": - request.app.state.DEFAULT_MODELS = form_data.models - return request.app.state.DEFAULT_MODELS - else: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) + request.app.state.DEFAULT_MODELS = form_data.models + return request.app.state.DEFAULT_MODELS + @router.post("/default/suggestions", response_model=List[PromptSuggestion]) async def set_global_default_suggestions( request: Request, form_data: SetDefaultSuggestionsForm, - user=Depends(get_current_user), + user=Depends(get_admin_user), ): - if user.role == "admin": - data = form_data.model_dump() - request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] - return request.app.state.DEFAULT_PROMPT_SUGGESTIONS - else: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) + data = form_data.model_dump() + request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] + return request.app.state.DEFAULT_PROMPT_SUGGESTIONS diff --git a/backend/apps/web/routers/documents.py b/backend/apps/web/routers/documents.py index 3b6434d1..5bc473fa 100644 --- a/backend/apps/web/routers/documents.py +++ b/backend/apps/web/routers/documents.py @@ -14,7 +14,7 @@ from apps.web.models.documents import ( DocumentResponse, ) -from utils.utils import get_current_user +from utils.utils import get_current_user, get_admin_user from constants import ERROR_MESSAGES router = APIRouter() @@ -44,13 +44,7 @@ async def get_documents(user=Depends(get_current_user)): @router.post("/create", response_model=Optional[DocumentResponse]) -async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)): - if user.role != "admin": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - +async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)): doc = Documents.get_doc_by_name(form_data.name) if doc == None: doc = Documents.insert_new_doc(user.id, form_data) @@ -132,14 +126,8 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u @router.post("/name/{name}/update", response_model=Optional[DocumentResponse]) async def update_doc_by_name( - name: str, form_data: DocumentUpdateForm, user=Depends(get_current_user) + name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user) ): - if user.role != "admin": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - doc = Documents.update_doc_by_name(name, form_data) if doc: return DocumentResponse( @@ -161,12 +149,6 @@ async def update_doc_by_name( @router.delete("/name/{name}/delete", response_model=bool) -async def delete_doc_by_name(name: str, user=Depends(get_current_user)): - if user.role != "admin": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - +async def delete_doc_by_name(name: str, user=Depends(get_admin_user)): result = Documents.delete_doc_by_name(name) return result diff --git a/backend/apps/web/routers/modelfiles.py b/backend/apps/web/routers/modelfiles.py index 0af9ca0f..0c5c1216 100644 --- a/backend/apps/web/routers/modelfiles.py +++ b/backend/apps/web/routers/modelfiles.py @@ -13,7 +13,7 @@ from apps.web.models.modelfiles import ( ModelfileResponse, ) -from utils.utils import get_current_user +from utils.utils import get_current_user, get_admin_user from constants import ERROR_MESSAGES router = APIRouter() @@ -37,13 +37,7 @@ async def get_modelfiles(skip: int = 0, @router.post("/create", response_model=Optional[ModelfileResponse]) async def create_new_modelfile(form_data: ModelfileForm, - user=Depends(get_current_user)): - if user.role != "admin": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - + user=Depends(get_admin_user)): modelfile = Modelfiles.insert_new_modelfile(user.id, form_data) if modelfile: @@ -91,12 +85,7 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, @router.post("/update", response_model=Optional[ModelfileResponse]) async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm, - user=Depends(get_current_user)): - if user.role != "admin": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) + user=Depends(get_admin_user)): modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) if modelfile: updated_modelfile = { @@ -127,12 +116,6 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm, @router.delete("/delete", response_model=bool) async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm, - user=Depends(get_current_user)): - if user.role != "admin": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - + user=Depends(get_admin_user)): result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name) return result diff --git a/backend/apps/web/routers/prompts.py b/backend/apps/web/routers/prompts.py index 23825dbb..db761967 100644 --- a/backend/apps/web/routers/prompts.py +++ b/backend/apps/web/routers/prompts.py @@ -8,7 +8,7 @@ import json from apps.web.models.prompts import Prompts, PromptForm, PromptModel -from utils.utils import get_current_user +from utils.utils import get_current_user, get_admin_user from constants import ERROR_MESSAGES router = APIRouter() @@ -29,29 +29,21 @@ async def get_prompts(user=Depends(get_current_user)): @router.post("/create", response_model=Optional[PromptModel]) -async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)): - if user.role != "admin": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - +async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)): prompt = Prompts.get_prompt_by_command(form_data.command) if prompt == None: prompt = Prompts.insert_new_prompt(user.id, form_data) if prompt: return prompt - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.DEFAULT(), - ) - else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.COMMAND_TAKEN, + detail=ERROR_MESSAGES.DEFAULT(), ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.COMMAND_TAKEN, + ) ############################ @@ -79,14 +71,8 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)): @router.post("/command/{command}/update", response_model=Optional[PromptModel]) async def update_prompt_by_command( - command: str, form_data: PromptForm, user=Depends(get_current_user) + command: str, form_data: PromptForm, user=Depends(get_admin_user) ): - if user.role != "admin": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) if prompt: return prompt @@ -103,12 +89,6 @@ async def update_prompt_by_command( @router.delete("/command/{command}/delete", response_model=bool) -async def delete_prompt_by_command(command: str, user=Depends(get_current_user)): - if user.role != "admin": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - +async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)): result = Prompts.delete_prompt_by_command(f"/{command}") return result diff --git a/backend/apps/web/routers/users.py b/backend/apps/web/routers/users.py index 32d1e67f..ce4cac77 100644 --- a/backend/apps/web/routers/users.py +++ b/backend/apps/web/routers/users.py @@ -11,7 +11,7 @@ import uuid from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users from apps.web.models.auths import Auths -from utils.utils import get_current_user, get_password_hash +from utils.utils import get_current_user, get_password_hash, get_admin_user from constants import ERROR_MESSAGES router = APIRouter() @@ -22,12 +22,7 @@ router = APIRouter() @router.get("/", response_model=List[UserModel]) -async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)): - if user.role != "admin": - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) +async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)): return Users.get_users(skip, limit) @@ -38,21 +33,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_use @router.post("/update/role", response_model=Optional[UserModel]) async def update_user_role( - form_data: UserRoleUpdateForm, user=Depends(get_current_user) + form_data: UserRoleUpdateForm, user=Depends(get_admin_user) ): - if user.role != "admin": - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - if user.id != form_data.id: return Users.update_user_role_by_id(form_data.id, form_data.role) - else: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACTION_PROHIBITED, - ) + + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACTION_PROHIBITED, + ) ############################ @@ -62,14 +51,8 @@ async def update_user_role( @router.post("/{user_id}/update", response_model=Optional[UserModel]) async def update_user_by_id( - user_id: str, form_data: UserUpdateForm, session_user=Depends(get_current_user) + user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user) ): - if session_user.role != "admin": - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - user = Users.get_user_by_id(user_id) if user: @@ -98,18 +81,17 @@ async def update_user_by_id( if updated_user: return updated_user - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(), - ) - else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.USER_NOT_FOUND, + detail=ERROR_MESSAGES.DEFAULT(), ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.USER_NOT_FOUND, + ) + ############################ # DeleteUserById @@ -117,25 +99,20 @@ async def update_user_by_id( @router.delete("/{user_id}", response_model=bool) -async def delete_user_by_id(user_id: str, user=Depends(get_current_user)): - if user.role == "admin": - if user.id != user_id: - result = Auths.delete_auth_by_id(user_id) +async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): + if user.id != user_id: + result = Auths.delete_auth_by_id(user_id) + + if result: + return True - if result: - return True - else: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=ERROR_MESSAGES.DELETE_USER_ERROR, - ) - else: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACTION_PROHIBITED, - ) - else: raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=ERROR_MESSAGES.DELETE_USER_ERROR, ) + + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACTION_PROHIBITED, + ) + diff --git a/backend/utils/utils.py b/backend/utils/utils.py index 2795a613..97b4afb2 100644 --- a/backend/utils/utils.py +++ b/backend/utils/utils.py @@ -73,3 +73,19 @@ def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(bearer_s status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) + + +def get_verified_user(user: Users = Depends(get_current_user)): + if user.role not in {"user", "admin"}: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + +def get_admin_user(user: Users = Depends(get_current_user)): + if user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + )