forked from open-webui/open-webui
		
	Endpoint role-checking was redundantly applied but FastAPI provides a nice abstraction mechanic...so I applied it. There should be no logical changes in this code; only simpler, cleaner ways for doing the same thing.
This commit is contained in:
		
							parent
							
								
									46d0eff218
								
							
						
					
					
						commit
						08e8e922fd
					
				
					 11 changed files with 127 additions and 251 deletions
				
			
		|  | @ -1,4 +1,4 @@ | ||||||
| from fastapi import FastAPI, Request, Response, HTTPException, Depends | from fastapi import FastAPI, Request, Response, HTTPException, Depends, status | ||||||
| from fastapi.middleware.cors import CORSMiddleware | from fastapi.middleware.cors import CORSMiddleware | ||||||
| from fastapi.responses import StreamingResponse | from fastapi.responses import StreamingResponse | ||||||
| from fastapi.concurrency import run_in_threadpool | from fastapi.concurrency import run_in_threadpool | ||||||
|  | @ -10,7 +10,7 @@ from pydantic import BaseModel | ||||||
| 
 | 
 | ||||||
| from apps.web.models.users import Users | from apps.web.models.users import Users | ||||||
| from constants import ERROR_MESSAGES | from constants import ERROR_MESSAGES | ||||||
| from utils.utils import decode_token, get_current_user | from utils.utils import decode_token, get_current_user, get_admin_user | ||||||
| from config import OLLAMA_API_BASE_URL, WEBUI_AUTH | from config import OLLAMA_API_BASE_URL, WEBUI_AUTH | ||||||
| 
 | 
 | ||||||
| app = FastAPI() | app = FastAPI() | ||||||
|  | @ -31,11 +31,8 @@ REQUEST_POOL = [] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.get("/url") | @app.get("/url") | ||||||
| async def get_ollama_api_url(user=Depends(get_current_user)): | async def get_ollama_api_url(user=Depends(get_admin_user)): | ||||||
|     if user and user.role == "admin": |     return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} | ||||||
|         return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} |  | ||||||
|     else: |  | ||||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class UrlUpdateForm(BaseModel): | class UrlUpdateForm(BaseModel): | ||||||
|  | @ -44,13 +41,10 @@ class UrlUpdateForm(BaseModel): | ||||||
| 
 | 
 | ||||||
| @app.post("/url/update") | @app.post("/url/update") | ||||||
| async def update_ollama_api_url( | async def update_ollama_api_url( | ||||||
|     form_data: UrlUpdateForm, user=Depends(get_current_user) |     form_data: UrlUpdateForm, user=Depends(get_admin_user) | ||||||
| ): | ): | ||||||
|     if user and user.role == "admin": |     app.state.OLLAMA_API_BASE_URL = form_data.url | ||||||
|         app.state.OLLAMA_API_BASE_URL = form_data.url |     return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} | ||||||
|         return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} |  | ||||||
|     else: |  | ||||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.get("/cancel/{request_id}") | @app.get("/cancel/{request_id}") | ||||||
|  | @ -74,10 +68,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): | ||||||
|         if path in ["pull", "delete", "push", "copy", "create"]: |         if path in ["pull", "delete", "push", "copy", "create"]: | ||||||
|             if user.role != "admin": |             if user.role != "admin": | ||||||
|                 raise HTTPException( |                 raise HTTPException( | ||||||
|                     status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED |                     status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED | ||||||
|                 ) |                 ) | ||||||
|     else: |     else: | ||||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) |         raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||||
| 
 | 
 | ||||||
|     headers.pop("host", None) |     headers.pop("host", None) | ||||||
|     headers.pop("authorization", None) |     headers.pop("authorization", None) | ||||||
|  |  | ||||||
|  | @ -9,7 +9,7 @@ from pydantic import BaseModel | ||||||
| 
 | 
 | ||||||
| from apps.web.models.users import Users | from apps.web.models.users import Users | ||||||
| from constants import ERROR_MESSAGES | from constants import ERROR_MESSAGES | ||||||
| from utils.utils import decode_token, get_current_user | from utils.utils import decode_token, get_current_user, get_verified_user, get_admin_user | ||||||
| from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR | from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR | ||||||
| 
 | 
 | ||||||
| import hashlib | import hashlib | ||||||
|  | @ -37,45 +37,32 @@ class KeyUpdateForm(BaseModel): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.get("/url") | @app.get("/url") | ||||||
| async def get_openai_url(user=Depends(get_current_user)): | async def get_openai_url(user=Depends(get_admin_user)): | ||||||
|     if user and user.role == "admin": |     return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} | ||||||
|         return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} |  | ||||||
|     else: |  | ||||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.post("/url/update") | @app.post("/url/update") | ||||||
| async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)): | async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): | ||||||
|     if user and user.role == "admin": |     app.state.OPENAI_API_BASE_URL = form_data.url | ||||||
|         app.state.OPENAI_API_BASE_URL = form_data.url |     return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} | ||||||
|         return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} | 
 | ||||||
|     else: |  | ||||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.get("/key") | @app.get("/key") | ||||||
| async def get_openai_key(user=Depends(get_current_user)): | async def get_openai_key(user=Depends(get_admin_user)): | ||||||
|     if user and user.role == "admin": |     return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} | ||||||
|         return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} |  | ||||||
|     else: |  | ||||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.post("/key/update") | @app.post("/key/update") | ||||||
| async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)): | async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_admin_user)): | ||||||
|     if user and user.role == "admin": |     app.state.OPENAI_API_KEY = form_data.key | ||||||
|         app.state.OPENAI_API_KEY = form_data.key |     return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} | ||||||
|         return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} |  | ||||||
|     else: |  | ||||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.post("/audio/speech") | @app.post("/audio/speech") | ||||||
| async def speech(request: Request, user=Depends(get_current_user)): | async def speech(request: Request, user=Depends(get_verified_user)): | ||||||
|     target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech" |     target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech" | ||||||
| 
 | 
 | ||||||
|     if user.role not in ["user", "admin"]: |  | ||||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) |  | ||||||
|     if app.state.OPENAI_API_KEY == "": |     if app.state.OPENAI_API_KEY == "": | ||||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) |         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) | ||||||
| 
 | 
 | ||||||
|  | @ -133,12 +120,10 @@ async def speech(request: Request, user=Depends(get_current_user)): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | ||||||
| async def proxy(path: str, request: Request, user=Depends(get_current_user)): | async def proxy(path: str, request: Request, user=Depends(get_verified_user)): | ||||||
|     target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}" |     target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}" | ||||||
|     print(target_url, app.state.OPENAI_API_KEY) |     print(target_url, app.state.OPENAI_API_KEY) | ||||||
| 
 | 
 | ||||||
|     if user.role not in ["user", "admin"]: |  | ||||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) |  | ||||||
|     if app.state.OPENAI_API_KEY == "": |     if app.state.OPENAI_API_KEY == "": | ||||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) |         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -39,7 +39,7 @@ import uuid | ||||||
| import time | import time | ||||||
| 
 | 
 | ||||||
| from utils.misc import calculate_sha256, calculate_sha256_string | from utils.misc import calculate_sha256, calculate_sha256_string | ||||||
| from utils.utils import get_current_user | from utils.utils import get_current_user, get_admin_user | ||||||
| from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP | from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP | ||||||
| from constants import ERROR_MESSAGES | from constants import ERROR_MESSAGES | ||||||
| 
 | 
 | ||||||
|  | @ -354,38 +354,26 @@ def store_doc( | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.get("/reset/db") | @app.get("/reset/db") | ||||||
| def reset_vector_db(user=Depends(get_current_user)): | def reset_vector_db(user=Depends(get_admin_user)): | ||||||
|     if user.role == "admin": |     CHROMA_CLIENT.reset() | ||||||
|         CHROMA_CLIENT.reset() |  | ||||||
|     else: |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_403_FORBIDDEN, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.get("/reset") | @app.get("/reset") | ||||||
| def reset(user=Depends(get_current_user)) -> bool: | def reset(user=Depends(get_admin_user)) -> bool: | ||||||
|     if user.role == "admin": |     folder = f"{UPLOAD_DIR}" | ||||||
|         folder = f"{UPLOAD_DIR}" |     for filename in os.listdir(folder): | ||||||
|         for filename in os.listdir(folder): |         file_path = os.path.join(folder, filename) | ||||||
|             file_path = os.path.join(folder, filename) |  | ||||||
|             try: |  | ||||||
|                 if os.path.isfile(file_path) or os.path.islink(file_path): |  | ||||||
|                     os.unlink(file_path) |  | ||||||
|                 elif os.path.isdir(file_path): |  | ||||||
|                     shutil.rmtree(file_path) |  | ||||||
|             except Exception as e: |  | ||||||
|                 print("Failed to delete %s. Reason: %s" % (file_path, e)) |  | ||||||
| 
 |  | ||||||
|         try: |         try: | ||||||
|             CHROMA_CLIENT.reset() |             if os.path.isfile(file_path) or os.path.islink(file_path): | ||||||
|  |                 os.unlink(file_path) | ||||||
|  |             elif os.path.isdir(file_path): | ||||||
|  |                 shutil.rmtree(file_path) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             print(e) |             print("Failed to delete %s. Reason: %s" % (file_path, e)) | ||||||
| 
 | 
 | ||||||
|         return True |     try: | ||||||
|     else: |         CHROMA_CLIENT.reset() | ||||||
|         raise HTTPException( |     except Exception as e: | ||||||
|             status_code=status.HTTP_403_FORBIDDEN, |         print(e) | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | 
 | ||||||
|         ) |     return True | ||||||
|  |  | ||||||
|  | @ -3,7 +3,7 @@ from fastapi import Depends, FastAPI, HTTPException, status | ||||||
| from datetime import datetime, timedelta | from datetime import datetime, timedelta | ||||||
| from typing import List, Union | from typing import List, Union | ||||||
| 
 | 
 | ||||||
| from fastapi import APIRouter | from fastapi import APIRouter, status | ||||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||||
| import time | import time | ||||||
| import uuid | import uuid | ||||||
|  | @ -19,7 +19,7 @@ from apps.web.models.auths import ( | ||||||
| ) | ) | ||||||
| from apps.web.models.users import Users | from apps.web.models.users import Users | ||||||
| 
 | 
 | ||||||
| from utils.utils import get_password_hash, get_current_user, create_token | from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token | ||||||
| from utils.misc import get_gravatar_url, validate_email_format | from utils.misc import get_gravatar_url, validate_email_format | ||||||
| from constants import ERROR_MESSAGES | from constants import ERROR_MESSAGES | ||||||
| 
 | 
 | ||||||
|  | @ -116,10 +116,10 @@ async def signin(form_data: SigninForm): | ||||||
| @router.post("/signup", response_model=SigninResponse) | @router.post("/signup", response_model=SigninResponse) | ||||||
| async def signup(request: Request, form_data: SignupForm): | async def signup(request: Request, form_data: SignupForm): | ||||||
|     if not request.app.state.ENABLE_SIGNUP: |     if not request.app.state.ENABLE_SIGNUP: | ||||||
|         raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) |         raise HTTPException(status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||||
| 
 | 
 | ||||||
|     if not validate_email_format(form_data.email.lower()): |     if not validate_email_format(form_data.email.lower()): | ||||||
|         raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT) |         raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT) | ||||||
| 
 | 
 | ||||||
|     if Users.get_user_by_email(form_data.email.lower()): |     if Users.get_user_by_email(form_data.email.lower()): | ||||||
|         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) |         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) | ||||||
|  | @ -156,23 +156,11 @@ async def signup(request: Request, form_data: SignupForm): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @router.get("/signup/enabled", response_model=bool) | @router.get("/signup/enabled", response_model=bool) | ||||||
| async def get_sign_up_status(request: Request, user=Depends(get_current_user)): | async def get_sign_up_status(request: Request, user=Depends(get_admin_user)): | ||||||
|     if user.role == "admin": |     return request.app.state.ENABLE_SIGNUP | ||||||
|         return request.app.state.ENABLE_SIGNUP |  | ||||||
|     else: |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_403_FORBIDDEN, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @router.get("/signup/enabled/toggle", response_model=bool) | @router.get("/signup/enabled/toggle", response_model=bool) | ||||||
| async def toggle_sign_up(request: Request, user=Depends(get_current_user)): | async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): | ||||||
|     if user.role == "admin": |     request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP | ||||||
|         request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP |     return request.app.state.ENABLE_SIGNUP | ||||||
|         return request.app.state.ENABLE_SIGNUP |  | ||||||
|     else: |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_403_FORBIDDEN, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  | @ -1,7 +1,7 @@ | ||||||
| from fastapi import Depends, Request, HTTPException, status | from fastapi import Depends, Request, HTTPException, status | ||||||
| from datetime import datetime, timedelta | from datetime import datetime, timedelta | ||||||
| from typing import List, Union, Optional | from typing import List, Union, Optional | ||||||
| from utils.utils import get_current_user | from utils.utils import get_current_user, get_admin_user | ||||||
| from fastapi import APIRouter | from fastapi import APIRouter | ||||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||||
| import json | import json | ||||||
|  | @ -60,17 +60,11 @@ async def get_all_user_chats(user=Depends(get_current_user)): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @router.get("/all/db", response_model=List[ChatResponse]) | @router.get("/all/db", response_model=List[ChatResponse]) | ||||||
| async def get_all_user_chats_in_db(user=Depends(get_current_user)): | async def get_all_user_chats_in_db(user=Depends(get_admin_user)): | ||||||
|     if user.role == "admin": |     return [ | ||||||
|         return [ |         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) | ||||||
|             ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) |         for chat in Chats.get_all_chats() | ||||||
|             for chat in Chats.get_all_chats() |     ] | ||||||
|         ] |  | ||||||
|     else: |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_403_FORBIDDEN, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| ############################ | ############################ | ||||||
|  |  | ||||||
|  | @ -10,7 +10,7 @@ import uuid | ||||||
| 
 | 
 | ||||||
| from apps.web.models.users import Users | from apps.web.models.users import Users | ||||||
| 
 | 
 | ||||||
| from utils.utils import get_password_hash, get_current_user, create_token | from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token | ||||||
| from utils.misc import get_gravatar_url, validate_email_format | from utils.misc import get_gravatar_url, validate_email_format | ||||||
| from constants import ERROR_MESSAGES | from constants import ERROR_MESSAGES | ||||||
| 
 | 
 | ||||||
|  | @ -37,30 +37,19 @@ class SetDefaultSuggestionsForm(BaseModel): | ||||||
| 
 | 
 | ||||||
| @router.post("/default/models", response_model=str) | @router.post("/default/models", response_model=str) | ||||||
| async def set_global_default_models( | async def set_global_default_models( | ||||||
|     request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user) |     request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user) | ||||||
| ): | ): | ||||||
|     if user.role == "admin": |     request.app.state.DEFAULT_MODELS = form_data.models | ||||||
|         request.app.state.DEFAULT_MODELS = form_data.models |     return request.app.state.DEFAULT_MODELS | ||||||
|         return request.app.state.DEFAULT_MODELS | 
 | ||||||
|     else: |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_403_FORBIDDEN, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @router.post("/default/suggestions", response_model=List[PromptSuggestion]) | @router.post("/default/suggestions", response_model=List[PromptSuggestion]) | ||||||
| async def set_global_default_suggestions( | async def set_global_default_suggestions( | ||||||
|     request: Request, |     request: Request, | ||||||
|     form_data: SetDefaultSuggestionsForm, |     form_data: SetDefaultSuggestionsForm, | ||||||
|     user=Depends(get_current_user), |     user=Depends(get_admin_user), | ||||||
| ): | ): | ||||||
|     if user.role == "admin": |     data = form_data.model_dump() | ||||||
|         data = form_data.model_dump() |     request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] | ||||||
|         request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] |     return request.app.state.DEFAULT_PROMPT_SUGGESTIONS | ||||||
|         return request.app.state.DEFAULT_PROMPT_SUGGESTIONS |  | ||||||
|     else: |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_403_FORBIDDEN, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  | @ -14,7 +14,7 @@ from apps.web.models.documents import ( | ||||||
|     DocumentResponse, |     DocumentResponse, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| from utils.utils import get_current_user | from utils.utils import get_current_user, get_admin_user | ||||||
| from constants import ERROR_MESSAGES | from constants import ERROR_MESSAGES | ||||||
| 
 | 
 | ||||||
| router = APIRouter() | router = APIRouter() | ||||||
|  | @ -44,13 +44,7 @@ async def get_documents(user=Depends(get_current_user)): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @router.post("/create", response_model=Optional[DocumentResponse]) | @router.post("/create", response_model=Optional[DocumentResponse]) | ||||||
| async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)): | async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)): | ||||||
|     if user.role != "admin": |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_401_UNAUTHORIZED, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|     doc = Documents.get_doc_by_name(form_data.name) |     doc = Documents.get_doc_by_name(form_data.name) | ||||||
|     if doc == None: |     if doc == None: | ||||||
|         doc = Documents.insert_new_doc(user.id, form_data) |         doc = Documents.insert_new_doc(user.id, form_data) | ||||||
|  | @ -132,14 +126,8 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u | ||||||
| 
 | 
 | ||||||
| @router.post("/name/{name}/update", response_model=Optional[DocumentResponse]) | @router.post("/name/{name}/update", response_model=Optional[DocumentResponse]) | ||||||
| async def update_doc_by_name( | async def update_doc_by_name( | ||||||
|     name: str, form_data: DocumentUpdateForm, user=Depends(get_current_user) |     name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user) | ||||||
| ): | ): | ||||||
|     if user.role != "admin": |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_401_UNAUTHORIZED, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|     doc = Documents.update_doc_by_name(name, form_data) |     doc = Documents.update_doc_by_name(name, form_data) | ||||||
|     if doc: |     if doc: | ||||||
|         return DocumentResponse( |         return DocumentResponse( | ||||||
|  | @ -161,12 +149,6 @@ async def update_doc_by_name( | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @router.delete("/name/{name}/delete", response_model=bool) | @router.delete("/name/{name}/delete", response_model=bool) | ||||||
| async def delete_doc_by_name(name: str, user=Depends(get_current_user)): | async def delete_doc_by_name(name: str, user=Depends(get_admin_user)): | ||||||
|     if user.role != "admin": |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_401_UNAUTHORIZED, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|     result = Documents.delete_doc_by_name(name) |     result = Documents.delete_doc_by_name(name) | ||||||
|     return result |     return result | ||||||
|  |  | ||||||
|  | @ -13,7 +13,7 @@ from apps.web.models.modelfiles import ( | ||||||
|     ModelfileResponse, |     ModelfileResponse, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| from utils.utils import get_current_user | from utils.utils import get_current_user, get_admin_user | ||||||
| from constants import ERROR_MESSAGES | from constants import ERROR_MESSAGES | ||||||
| 
 | 
 | ||||||
| router = APIRouter() | router = APIRouter() | ||||||
|  | @ -37,13 +37,7 @@ async def get_modelfiles(skip: int = 0, | ||||||
| 
 | 
 | ||||||
| @router.post("/create", response_model=Optional[ModelfileResponse]) | @router.post("/create", response_model=Optional[ModelfileResponse]) | ||||||
| async def create_new_modelfile(form_data: ModelfileForm, | async def create_new_modelfile(form_data: ModelfileForm, | ||||||
|                                user=Depends(get_current_user)): |                                user=Depends(get_admin_user)): | ||||||
|     if user.role != "admin": |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_401_UNAUTHORIZED, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|     modelfile = Modelfiles.insert_new_modelfile(user.id, form_data) |     modelfile = Modelfiles.insert_new_modelfile(user.id, form_data) | ||||||
| 
 | 
 | ||||||
|     if modelfile: |     if modelfile: | ||||||
|  | @ -91,12 +85,7 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, | ||||||
| 
 | 
 | ||||||
| @router.post("/update", response_model=Optional[ModelfileResponse]) | @router.post("/update", response_model=Optional[ModelfileResponse]) | ||||||
| async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm, | async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm, | ||||||
|                                        user=Depends(get_current_user)): |                                        user=Depends(get_admin_user)): | ||||||
|     if user.role != "admin": |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_401_UNAUTHORIZED, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
|     modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) |     modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) | ||||||
|     if modelfile: |     if modelfile: | ||||||
|         updated_modelfile = { |         updated_modelfile = { | ||||||
|  | @ -127,12 +116,6 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm, | ||||||
| 
 | 
 | ||||||
| @router.delete("/delete", response_model=bool) | @router.delete("/delete", response_model=bool) | ||||||
| async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm, | async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm, | ||||||
|                                        user=Depends(get_current_user)): |                                        user=Depends(get_admin_user)): | ||||||
|     if user.role != "admin": |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_401_UNAUTHORIZED, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|     result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name) |     result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name) | ||||||
|     return result |     return result | ||||||
|  |  | ||||||
|  | @ -8,7 +8,7 @@ import json | ||||||
| 
 | 
 | ||||||
| from apps.web.models.prompts import Prompts, PromptForm, PromptModel | from apps.web.models.prompts import Prompts, PromptForm, PromptModel | ||||||
| 
 | 
 | ||||||
| from utils.utils import get_current_user | from utils.utils import get_current_user, get_admin_user | ||||||
| from constants import ERROR_MESSAGES | from constants import ERROR_MESSAGES | ||||||
| 
 | 
 | ||||||
| router = APIRouter() | router = APIRouter() | ||||||
|  | @ -29,29 +29,21 @@ async def get_prompts(user=Depends(get_current_user)): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @router.post("/create", response_model=Optional[PromptModel]) | @router.post("/create", response_model=Optional[PromptModel]) | ||||||
| async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)): | async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)): | ||||||
|     if user.role != "admin": |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_401_UNAUTHORIZED, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|     prompt = Prompts.get_prompt_by_command(form_data.command) |     prompt = Prompts.get_prompt_by_command(form_data.command) | ||||||
|     if prompt == None: |     if prompt == None: | ||||||
|         prompt = Prompts.insert_new_prompt(user.id, form_data) |         prompt = Prompts.insert_new_prompt(user.id, form_data) | ||||||
| 
 | 
 | ||||||
|         if prompt: |         if prompt: | ||||||
|             return prompt |             return prompt | ||||||
|         else: |  | ||||||
|             raise HTTPException( |  | ||||||
|                 status_code=status.HTTP_401_UNAUTHORIZED, |  | ||||||
|                 detail=ERROR_MESSAGES.DEFAULT(), |  | ||||||
|             ) |  | ||||||
|     else: |  | ||||||
|         raise HTTPException( |         raise HTTPException( | ||||||
|             status_code=status.HTTP_400_BAD_REQUEST, |             status_code=status.HTTP_400_BAD_REQUEST, | ||||||
|             detail=ERROR_MESSAGES.COMMAND_TAKEN, |             detail=ERROR_MESSAGES.DEFAULT(), | ||||||
|         ) |         ) | ||||||
|  |     raise HTTPException( | ||||||
|  |         status_code=status.HTTP_400_BAD_REQUEST, | ||||||
|  |         detail=ERROR_MESSAGES.COMMAND_TAKEN, | ||||||
|  |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| ############################ | ############################ | ||||||
|  | @ -79,14 +71,8 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)): | ||||||
| 
 | 
 | ||||||
| @router.post("/command/{command}/update", response_model=Optional[PromptModel]) | @router.post("/command/{command}/update", response_model=Optional[PromptModel]) | ||||||
| async def update_prompt_by_command( | async def update_prompt_by_command( | ||||||
|     command: str, form_data: PromptForm, user=Depends(get_current_user) |     command: str, form_data: PromptForm, user=Depends(get_admin_user) | ||||||
| ): | ): | ||||||
|     if user.role != "admin": |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_401_UNAUTHORIZED, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|     prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) |     prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) | ||||||
|     if prompt: |     if prompt: | ||||||
|         return prompt |         return prompt | ||||||
|  | @ -103,12 +89,6 @@ async def update_prompt_by_command( | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @router.delete("/command/{command}/delete", response_model=bool) | @router.delete("/command/{command}/delete", response_model=bool) | ||||||
| async def delete_prompt_by_command(command: str, user=Depends(get_current_user)): | async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)): | ||||||
|     if user.role != "admin": |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_401_UNAUTHORIZED, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|     result = Prompts.delete_prompt_by_command(f"/{command}") |     result = Prompts.delete_prompt_by_command(f"/{command}") | ||||||
|     return result |     return result | ||||||
|  |  | ||||||
|  | @ -11,7 +11,7 @@ import uuid | ||||||
| from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users | from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users | ||||||
| from apps.web.models.auths import Auths | from apps.web.models.auths import Auths | ||||||
| 
 | 
 | ||||||
| from utils.utils import get_current_user, get_password_hash | from utils.utils import get_current_user, get_password_hash, get_admin_user | ||||||
| from constants import ERROR_MESSAGES | from constants import ERROR_MESSAGES | ||||||
| 
 | 
 | ||||||
| router = APIRouter() | router = APIRouter() | ||||||
|  | @ -22,12 +22,7 @@ router = APIRouter() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @router.get("/", response_model=List[UserModel]) | @router.get("/", response_model=List[UserModel]) | ||||||
| async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)): | async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)): | ||||||
|     if user.role != "admin": |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_403_FORBIDDEN, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
|     return Users.get_users(skip, limit) |     return Users.get_users(skip, limit) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -38,21 +33,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_use | ||||||
| 
 | 
 | ||||||
| @router.post("/update/role", response_model=Optional[UserModel]) | @router.post("/update/role", response_model=Optional[UserModel]) | ||||||
| async def update_user_role( | async def update_user_role( | ||||||
|     form_data: UserRoleUpdateForm, user=Depends(get_current_user) |     form_data: UserRoleUpdateForm, user=Depends(get_admin_user) | ||||||
| ): | ): | ||||||
|     if user.role != "admin": |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_403_FORBIDDEN, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|     if user.id != form_data.id: |     if user.id != form_data.id: | ||||||
|         return Users.update_user_role_by_id(form_data.id, form_data.role) |         return Users.update_user_role_by_id(form_data.id, form_data.role) | ||||||
|     else: | 
 | ||||||
|         raise HTTPException( |     raise HTTPException( | ||||||
|             status_code=status.HTTP_403_FORBIDDEN, |         status_code=status.HTTP_403_FORBIDDEN, | ||||||
|             detail=ERROR_MESSAGES.ACTION_PROHIBITED, |         detail=ERROR_MESSAGES.ACTION_PROHIBITED, | ||||||
|         ) |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| ############################ | ############################ | ||||||
|  | @ -62,14 +51,8 @@ async def update_user_role( | ||||||
| 
 | 
 | ||||||
| @router.post("/{user_id}/update", response_model=Optional[UserModel]) | @router.post("/{user_id}/update", response_model=Optional[UserModel]) | ||||||
| async def update_user_by_id( | async def update_user_by_id( | ||||||
|     user_id: str, form_data: UserUpdateForm, session_user=Depends(get_current_user) |     user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user) | ||||||
| ): | ): | ||||||
|     if session_user.role != "admin": |  | ||||||
|         raise HTTPException( |  | ||||||
|             status_code=status.HTTP_403_FORBIDDEN, |  | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|     user = Users.get_user_by_id(user_id) |     user = Users.get_user_by_id(user_id) | ||||||
| 
 | 
 | ||||||
|     if user: |     if user: | ||||||
|  | @ -98,18 +81,17 @@ async def update_user_by_id( | ||||||
| 
 | 
 | ||||||
|         if updated_user: |         if updated_user: | ||||||
|             return updated_user |             return updated_user | ||||||
|         else: |  | ||||||
|             raise HTTPException( |  | ||||||
|                 status_code=status.HTTP_400_BAD_REQUEST, |  | ||||||
|                 detail=ERROR_MESSAGES.DEFAULT(), |  | ||||||
|             ) |  | ||||||
| 
 | 
 | ||||||
|     else: |  | ||||||
|         raise HTTPException( |         raise HTTPException( | ||||||
|             status_code=status.HTTP_400_BAD_REQUEST, |             status_code=status.HTTP_400_BAD_REQUEST, | ||||||
|             detail=ERROR_MESSAGES.USER_NOT_FOUND, |             detail=ERROR_MESSAGES.DEFAULT(), | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |     raise HTTPException( | ||||||
|  |         status_code=status.HTTP_400_BAD_REQUEST, | ||||||
|  |         detail=ERROR_MESSAGES.USER_NOT_FOUND, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| ############################ | ############################ | ||||||
| # DeleteUserById | # DeleteUserById | ||||||
|  | @ -117,25 +99,20 @@ async def update_user_by_id( | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @router.delete("/{user_id}", response_model=bool) | @router.delete("/{user_id}", response_model=bool) | ||||||
| async def delete_user_by_id(user_id: str, user=Depends(get_current_user)): | async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): | ||||||
|     if user.role == "admin": |     if user.id != user_id: | ||||||
|         if user.id != user_id: |         result = Auths.delete_auth_by_id(user_id) | ||||||
|             result = Auths.delete_auth_by_id(user_id) | 
 | ||||||
|  |         if result: | ||||||
|  |             return True | ||||||
| 
 | 
 | ||||||
|             if result: |  | ||||||
|                 return True |  | ||||||
|             else: |  | ||||||
|                 raise HTTPException( |  | ||||||
|                     status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |  | ||||||
|                     detail=ERROR_MESSAGES.DELETE_USER_ERROR, |  | ||||||
|                 ) |  | ||||||
|         else: |  | ||||||
|             raise HTTPException( |  | ||||||
|                 status_code=status.HTTP_403_FORBIDDEN, |  | ||||||
|                 detail=ERROR_MESSAGES.ACTION_PROHIBITED, |  | ||||||
|             ) |  | ||||||
|     else: |  | ||||||
|         raise HTTPException( |         raise HTTPException( | ||||||
|             status_code=status.HTTP_403_FORBIDDEN, |             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, |             detail=ERROR_MESSAGES.DELETE_USER_ERROR, | ||||||
|         ) |         ) | ||||||
|  | 
 | ||||||
|  |     raise HTTPException( | ||||||
|  |         status_code=status.HTTP_403_FORBIDDEN, | ||||||
|  |         detail=ERROR_MESSAGES.ACTION_PROHIBITED, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |  | ||||||
|  | @ -73,3 +73,19 @@ def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(bearer_s | ||||||
|             status_code=status.HTTP_401_UNAUTHORIZED, |             status_code=status.HTTP_401_UNAUTHORIZED, | ||||||
|             detail=ERROR_MESSAGES.UNAUTHORIZED, |             detail=ERROR_MESSAGES.UNAUTHORIZED, | ||||||
|         ) |         ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_verified_user(user: Users = Depends(get_current_user)): | ||||||
|  |     if user.role not in {"user", "admin"}: | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=status.HTTP_401_UNAUTHORIZED, | ||||||
|  |             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_admin_user(user: Users = Depends(get_current_user)): | ||||||
|  |     if user.role != "admin": | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=status.HTTP_401_UNAUTHORIZED, | ||||||
|  |             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||||
|  |         ) | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Tim Farrell
						Tim Farrell