forked from open-webui/open-webui
		
	Endpoint role-checking was redundantly applied but FastAPI provides a nice abstraction mechanic...so I applied it. There should be no logical changes in this code; only simpler, cleaner ways for doing the same thing.
This commit is contained in:
		
							parent
							
								
									46d0eff218
								
							
						
					
					
						commit
						08e8e922fd
					
				
					 11 changed files with 127 additions and 251 deletions
				
			
		|  | @ -1,4 +1,4 @@ | |||
| from fastapi import FastAPI, Request, Response, HTTPException, Depends | ||||
| from fastapi import FastAPI, Request, Response, HTTPException, Depends, status | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| from fastapi.responses import StreamingResponse | ||||
| from fastapi.concurrency import run_in_threadpool | ||||
|  | @ -10,7 +10,7 @@ from pydantic import BaseModel | |||
| 
 | ||||
| from apps.web.models.users import Users | ||||
| from constants import ERROR_MESSAGES | ||||
| from utils.utils import decode_token, get_current_user | ||||
| from utils.utils import decode_token, get_current_user, get_admin_user | ||||
| from config import OLLAMA_API_BASE_URL, WEBUI_AUTH | ||||
| 
 | ||||
| app = FastAPI() | ||||
|  | @ -31,11 +31,8 @@ REQUEST_POOL = [] | |||
| 
 | ||||
| 
 | ||||
| @app.get("/url") | ||||
| async def get_ollama_api_url(user=Depends(get_current_user)): | ||||
|     if user and user.role == "admin": | ||||
| async def get_ollama_api_url(user=Depends(get_admin_user)): | ||||
|     return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} | ||||
|     else: | ||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||
| 
 | ||||
| 
 | ||||
| class UrlUpdateForm(BaseModel): | ||||
|  | @ -44,13 +41,10 @@ class UrlUpdateForm(BaseModel): | |||
| 
 | ||||
| @app.post("/url/update") | ||||
| async def update_ollama_api_url( | ||||
|     form_data: UrlUpdateForm, user=Depends(get_current_user) | ||||
|     form_data: UrlUpdateForm, user=Depends(get_admin_user) | ||||
| ): | ||||
|     if user and user.role == "admin": | ||||
|     app.state.OLLAMA_API_BASE_URL = form_data.url | ||||
|     return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} | ||||
|     else: | ||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/cancel/{request_id}") | ||||
|  | @ -74,10 +68,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): | |||
|         if path in ["pull", "delete", "push", "copy", "create"]: | ||||
|             if user.role != "admin": | ||||
|                 raise HTTPException( | ||||
|                     status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED | ||||
|                     status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED | ||||
|                 ) | ||||
|     else: | ||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||
|         raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||
| 
 | ||||
|     headers.pop("host", None) | ||||
|     headers.pop("authorization", None) | ||||
|  |  | |||
|  | @ -9,7 +9,7 @@ from pydantic import BaseModel | |||
| 
 | ||||
| from apps.web.models.users import Users | ||||
| from constants import ERROR_MESSAGES | ||||
| from utils.utils import decode_token, get_current_user | ||||
| from utils.utils import decode_token, get_current_user, get_verified_user, get_admin_user | ||||
| from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR | ||||
| 
 | ||||
| import hashlib | ||||
|  | @ -37,45 +37,32 @@ class KeyUpdateForm(BaseModel): | |||
| 
 | ||||
| 
 | ||||
| @app.get("/url") | ||||
| async def get_openai_url(user=Depends(get_current_user)): | ||||
|     if user and user.role == "admin": | ||||
| async def get_openai_url(user=Depends(get_admin_user)): | ||||
|     return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} | ||||
|     else: | ||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/url/update") | ||||
| async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)): | ||||
|     if user and user.role == "admin": | ||||
| async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): | ||||
|     app.state.OPENAI_API_BASE_URL = form_data.url | ||||
|     return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} | ||||
|     else: | ||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/key") | ||||
| async def get_openai_key(user=Depends(get_current_user)): | ||||
|     if user and user.role == "admin": | ||||
| async def get_openai_key(user=Depends(get_admin_user)): | ||||
|     return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} | ||||
|     else: | ||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/key/update") | ||||
| async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)): | ||||
|     if user and user.role == "admin": | ||||
| async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_admin_user)): | ||||
|     app.state.OPENAI_API_KEY = form_data.key | ||||
|     return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} | ||||
|     else: | ||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/audio/speech") | ||||
| async def speech(request: Request, user=Depends(get_current_user)): | ||||
| async def speech(request: Request, user=Depends(get_verified_user)): | ||||
|     target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech" | ||||
| 
 | ||||
|     if user.role not in ["user", "admin"]: | ||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||
|     if app.state.OPENAI_API_KEY == "": | ||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) | ||||
| 
 | ||||
|  | @ -133,12 +120,10 @@ async def speech(request: Request, user=Depends(get_current_user)): | |||
| 
 | ||||
| 
 | ||||
| @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | ||||
| async def proxy(path: str, request: Request, user=Depends(get_current_user)): | ||||
| async def proxy(path: str, request: Request, user=Depends(get_verified_user)): | ||||
|     target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}" | ||||
|     print(target_url, app.state.OPENAI_API_KEY) | ||||
| 
 | ||||
|     if user.role not in ["user", "admin"]: | ||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||
|     if app.state.OPENAI_API_KEY == "": | ||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) | ||||
| 
 | ||||
|  |  | |||
|  | @ -39,7 +39,7 @@ import uuid | |||
| import time | ||||
| 
 | ||||
| from utils.misc import calculate_sha256, calculate_sha256_string | ||||
| from utils.utils import get_current_user | ||||
| from utils.utils import get_current_user, get_admin_user | ||||
| from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
|  | @ -354,19 +354,12 @@ def store_doc( | |||
| 
 | ||||
| 
 | ||||
| @app.get("/reset/db") | ||||
| def reset_vector_db(user=Depends(get_current_user)): | ||||
|     if user.role == "admin": | ||||
| def reset_vector_db(user=Depends(get_admin_user)): | ||||
|     CHROMA_CLIENT.reset() | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_403_FORBIDDEN, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/reset") | ||||
| def reset(user=Depends(get_current_user)) -> bool: | ||||
|     if user.role == "admin": | ||||
| def reset(user=Depends(get_admin_user)) -> bool: | ||||
|     folder = f"{UPLOAD_DIR}" | ||||
|     for filename in os.listdir(folder): | ||||
|         file_path = os.path.join(folder, filename) | ||||
|  | @ -384,8 +377,3 @@ def reset(user=Depends(get_current_user)) -> bool: | |||
|         print(e) | ||||
| 
 | ||||
|     return True | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_403_FORBIDDEN, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
|  |  | |||
|  | @ -3,7 +3,7 @@ from fastapi import Depends, FastAPI, HTTPException, status | |||
| from datetime import datetime, timedelta | ||||
| from typing import List, Union | ||||
| 
 | ||||
| from fastapi import APIRouter | ||||
| from fastapi import APIRouter, status | ||||
| from pydantic import BaseModel | ||||
| import time | ||||
| import uuid | ||||
|  | @ -19,7 +19,7 @@ from apps.web.models.auths import ( | |||
| ) | ||||
| from apps.web.models.users import Users | ||||
| 
 | ||||
| from utils.utils import get_password_hash, get_current_user, create_token | ||||
| from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token | ||||
| from utils.misc import get_gravatar_url, validate_email_format | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
|  | @ -116,10 +116,10 @@ async def signin(form_data: SigninForm): | |||
| @router.post("/signup", response_model=SigninResponse) | ||||
| async def signup(request: Request, form_data: SignupForm): | ||||
|     if not request.app.state.ENABLE_SIGNUP: | ||||
|         raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||
|         raise HTTPException(status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||
| 
 | ||||
|     if not validate_email_format(form_data.email.lower()): | ||||
|         raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT) | ||||
|         raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT) | ||||
| 
 | ||||
|     if Users.get_user_by_email(form_data.email.lower()): | ||||
|         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) | ||||
|  | @ -156,23 +156,11 @@ async def signup(request: Request, form_data: SignupForm): | |||
| 
 | ||||
| 
 | ||||
| @router.get("/signup/enabled", response_model=bool) | ||||
| async def get_sign_up_status(request: Request, user=Depends(get_current_user)): | ||||
|     if user.role == "admin": | ||||
| async def get_sign_up_status(request: Request, user=Depends(get_admin_user)): | ||||
|     return request.app.state.ENABLE_SIGNUP | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_403_FORBIDDEN, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| @router.get("/signup/enabled/toggle", response_model=bool) | ||||
| async def toggle_sign_up(request: Request, user=Depends(get_current_user)): | ||||
|     if user.role == "admin": | ||||
| async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): | ||||
|     request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP | ||||
|     return request.app.state.ENABLE_SIGNUP | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_403_FORBIDDEN, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
|  |  | |||
|  | @ -1,7 +1,7 @@ | |||
| from fastapi import Depends, Request, HTTPException, status | ||||
| from datetime import datetime, timedelta | ||||
| from typing import List, Union, Optional | ||||
| from utils.utils import get_current_user | ||||
| from utils.utils import get_current_user, get_admin_user | ||||
| from fastapi import APIRouter | ||||
| from pydantic import BaseModel | ||||
| import json | ||||
|  | @ -60,17 +60,11 @@ async def get_all_user_chats(user=Depends(get_current_user)): | |||
| 
 | ||||
| 
 | ||||
| @router.get("/all/db", response_model=List[ChatResponse]) | ||||
| async def get_all_user_chats_in_db(user=Depends(get_current_user)): | ||||
|     if user.role == "admin": | ||||
| async def get_all_user_chats_in_db(user=Depends(get_admin_user)): | ||||
|     return [ | ||||
|         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) | ||||
|         for chat in Chats.get_all_chats() | ||||
|     ] | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_403_FORBIDDEN, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| ############################ | ||||
|  |  | |||
|  | @ -10,7 +10,7 @@ import uuid | |||
| 
 | ||||
| from apps.web.models.users import Users | ||||
| 
 | ||||
| from utils.utils import get_password_hash, get_current_user, create_token | ||||
| from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token | ||||
| from utils.misc import get_gravatar_url, validate_email_format | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
|  | @ -37,30 +37,19 @@ class SetDefaultSuggestionsForm(BaseModel): | |||
| 
 | ||||
| @router.post("/default/models", response_model=str) | ||||
| async def set_global_default_models( | ||||
|     request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user) | ||||
|     request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user) | ||||
| ): | ||||
|     if user.role == "admin": | ||||
|     request.app.state.DEFAULT_MODELS = form_data.models | ||||
|     return request.app.state.DEFAULT_MODELS | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_403_FORBIDDEN, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| @router.post("/default/suggestions", response_model=List[PromptSuggestion]) | ||||
| async def set_global_default_suggestions( | ||||
|     request: Request, | ||||
|     form_data: SetDefaultSuggestionsForm, | ||||
|     user=Depends(get_current_user), | ||||
|     user=Depends(get_admin_user), | ||||
| ): | ||||
|     if user.role == "admin": | ||||
|     data = form_data.model_dump() | ||||
|     request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] | ||||
|     return request.app.state.DEFAULT_PROMPT_SUGGESTIONS | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_403_FORBIDDEN, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
|  |  | |||
|  | @ -14,7 +14,7 @@ from apps.web.models.documents import ( | |||
|     DocumentResponse, | ||||
| ) | ||||
| 
 | ||||
| from utils.utils import get_current_user | ||||
| from utils.utils import get_current_user, get_admin_user | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| router = APIRouter() | ||||
|  | @ -44,13 +44,7 @@ async def get_documents(user=Depends(get_current_user)): | |||
| 
 | ||||
| 
 | ||||
| @router.post("/create", response_model=Optional[DocumentResponse]) | ||||
| async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)): | ||||
|     if user.role != "admin": | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
| async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)): | ||||
|     doc = Documents.get_doc_by_name(form_data.name) | ||||
|     if doc == None: | ||||
|         doc = Documents.insert_new_doc(user.id, form_data) | ||||
|  | @ -132,14 +126,8 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u | |||
| 
 | ||||
| @router.post("/name/{name}/update", response_model=Optional[DocumentResponse]) | ||||
| async def update_doc_by_name( | ||||
|     name: str, form_data: DocumentUpdateForm, user=Depends(get_current_user) | ||||
|     name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user) | ||||
| ): | ||||
|     if user.role != "admin": | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
|     doc = Documents.update_doc_by_name(name, form_data) | ||||
|     if doc: | ||||
|         return DocumentResponse( | ||||
|  | @ -161,12 +149,6 @@ async def update_doc_by_name( | |||
| 
 | ||||
| 
 | ||||
| @router.delete("/name/{name}/delete", response_model=bool) | ||||
| async def delete_doc_by_name(name: str, user=Depends(get_current_user)): | ||||
|     if user.role != "admin": | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
| async def delete_doc_by_name(name: str, user=Depends(get_admin_user)): | ||||
|     result = Documents.delete_doc_by_name(name) | ||||
|     return result | ||||
|  |  | |||
|  | @ -13,7 +13,7 @@ from apps.web.models.modelfiles import ( | |||
|     ModelfileResponse, | ||||
| ) | ||||
| 
 | ||||
| from utils.utils import get_current_user | ||||
| from utils.utils import get_current_user, get_admin_user | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| router = APIRouter() | ||||
|  | @ -37,13 +37,7 @@ async def get_modelfiles(skip: int = 0, | |||
| 
 | ||||
| @router.post("/create", response_model=Optional[ModelfileResponse]) | ||||
| async def create_new_modelfile(form_data: ModelfileForm, | ||||
|                                user=Depends(get_current_user)): | ||||
|     if user.role != "admin": | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
|                                user=Depends(get_admin_user)): | ||||
|     modelfile = Modelfiles.insert_new_modelfile(user.id, form_data) | ||||
| 
 | ||||
|     if modelfile: | ||||
|  | @ -91,12 +85,7 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, | |||
| 
 | ||||
| @router.post("/update", response_model=Optional[ModelfileResponse]) | ||||
| async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm, | ||||
|                                        user=Depends(get_current_user)): | ||||
|     if user.role != "admin": | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
|                                        user=Depends(get_admin_user)): | ||||
|     modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) | ||||
|     if modelfile: | ||||
|         updated_modelfile = { | ||||
|  | @ -127,12 +116,6 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm, | |||
| 
 | ||||
| @router.delete("/delete", response_model=bool) | ||||
| async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm, | ||||
|                                        user=Depends(get_current_user)): | ||||
|     if user.role != "admin": | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
|                                        user=Depends(get_admin_user)): | ||||
|     result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name) | ||||
|     return result | ||||
|  |  | |||
|  | @ -8,7 +8,7 @@ import json | |||
| 
 | ||||
| from apps.web.models.prompts import Prompts, PromptForm, PromptModel | ||||
| 
 | ||||
| from utils.utils import get_current_user | ||||
| from utils.utils import get_current_user, get_admin_user | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| router = APIRouter() | ||||
|  | @ -29,25 +29,17 @@ async def get_prompts(user=Depends(get_current_user)): | |||
| 
 | ||||
| 
 | ||||
| @router.post("/create", response_model=Optional[PromptModel]) | ||||
| async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)): | ||||
|     if user.role != "admin": | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
| async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)): | ||||
|     prompt = Prompts.get_prompt_by_command(form_data.command) | ||||
|     if prompt == None: | ||||
|         prompt = Prompts.insert_new_prompt(user.id, form_data) | ||||
| 
 | ||||
|         if prompt: | ||||
|             return prompt | ||||
|         else: | ||||
|         raise HTTPException( | ||||
|                 status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             status_code=status.HTTP_400_BAD_REQUEST, | ||||
|             detail=ERROR_MESSAGES.DEFAULT(), | ||||
|         ) | ||||
|     else: | ||||
|     raise HTTPException( | ||||
|         status_code=status.HTTP_400_BAD_REQUEST, | ||||
|         detail=ERROR_MESSAGES.COMMAND_TAKEN, | ||||
|  | @ -79,14 +71,8 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)): | |||
| 
 | ||||
| @router.post("/command/{command}/update", response_model=Optional[PromptModel]) | ||||
| async def update_prompt_by_command( | ||||
|     command: str, form_data: PromptForm, user=Depends(get_current_user) | ||||
|     command: str, form_data: PromptForm, user=Depends(get_admin_user) | ||||
| ): | ||||
|     if user.role != "admin": | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
|     prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) | ||||
|     if prompt: | ||||
|         return prompt | ||||
|  | @ -103,12 +89,6 @@ async def update_prompt_by_command( | |||
| 
 | ||||
| 
 | ||||
| @router.delete("/command/{command}/delete", response_model=bool) | ||||
| async def delete_prompt_by_command(command: str, user=Depends(get_current_user)): | ||||
|     if user.role != "admin": | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
| async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)): | ||||
|     result = Prompts.delete_prompt_by_command(f"/{command}") | ||||
|     return result | ||||
|  |  | |||
|  | @ -11,7 +11,7 @@ import uuid | |||
| from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users | ||||
| from apps.web.models.auths import Auths | ||||
| 
 | ||||
| from utils.utils import get_current_user, get_password_hash | ||||
| from utils.utils import get_current_user, get_password_hash, get_admin_user | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| router = APIRouter() | ||||
|  | @ -22,12 +22,7 @@ router = APIRouter() | |||
| 
 | ||||
| 
 | ||||
| @router.get("/", response_model=List[UserModel]) | ||||
| async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)): | ||||
|     if user.role != "admin": | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_403_FORBIDDEN, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)): | ||||
|     return Users.get_users(skip, limit) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -38,17 +33,11 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_use | |||
| 
 | ||||
| @router.post("/update/role", response_model=Optional[UserModel]) | ||||
| async def update_user_role( | ||||
|     form_data: UserRoleUpdateForm, user=Depends(get_current_user) | ||||
|     form_data: UserRoleUpdateForm, user=Depends(get_admin_user) | ||||
| ): | ||||
|     if user.role != "admin": | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_403_FORBIDDEN, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
|     if user.id != form_data.id: | ||||
|         return Users.update_user_role_by_id(form_data.id, form_data.role) | ||||
|     else: | ||||
| 
 | ||||
|     raise HTTPException( | ||||
|         status_code=status.HTTP_403_FORBIDDEN, | ||||
|         detail=ERROR_MESSAGES.ACTION_PROHIBITED, | ||||
|  | @ -62,14 +51,8 @@ async def update_user_role( | |||
| 
 | ||||
| @router.post("/{user_id}/update", response_model=Optional[UserModel]) | ||||
| async def update_user_by_id( | ||||
|     user_id: str, form_data: UserUpdateForm, session_user=Depends(get_current_user) | ||||
|     user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user) | ||||
| ): | ||||
|     if session_user.role != "admin": | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_403_FORBIDDEN, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
|     user = Users.get_user_by_id(user_id) | ||||
| 
 | ||||
|     if user: | ||||
|  | @ -98,13 +81,12 @@ async def update_user_by_id( | |||
| 
 | ||||
|         if updated_user: | ||||
|             return updated_user | ||||
|         else: | ||||
| 
 | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_400_BAD_REQUEST, | ||||
|             detail=ERROR_MESSAGES.DEFAULT(), | ||||
|         ) | ||||
| 
 | ||||
|     else: | ||||
|     raise HTTPException( | ||||
|         status_code=status.HTTP_400_BAD_REQUEST, | ||||
|         detail=ERROR_MESSAGES.USER_NOT_FOUND, | ||||
|  | @ -117,25 +99,20 @@ async def update_user_by_id( | |||
| 
 | ||||
| 
 | ||||
| @router.delete("/{user_id}", response_model=bool) | ||||
| async def delete_user_by_id(user_id: str, user=Depends(get_current_user)): | ||||
|     if user.role == "admin": | ||||
| async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): | ||||
|     if user.id != user_id: | ||||
|         result = Auths.delete_auth_by_id(user_id) | ||||
| 
 | ||||
|         if result: | ||||
|             return True | ||||
|             else: | ||||
| 
 | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||||
|             detail=ERROR_MESSAGES.DELETE_USER_ERROR, | ||||
|         ) | ||||
|         else: | ||||
| 
 | ||||
|     raise HTTPException( | ||||
|         status_code=status.HTTP_403_FORBIDDEN, | ||||
|         detail=ERROR_MESSAGES.ACTION_PROHIBITED, | ||||
|     ) | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_403_FORBIDDEN, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
|  |  | |||
|  | @ -73,3 +73,19 @@ def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(bearer_s | |||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.UNAUTHORIZED, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def get_verified_user(user: Users = Depends(get_current_user)): | ||||
|     if user.role not in {"user", "admin"}: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def get_admin_user(user: Users = Depends(get_current_user)): | ||||
|     if user.role != "admin": | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Tim Farrell
						Tim Farrell