Endpoint role-checking was redundantly applied but FastAPI provides a nice abstraction mechanic...so I applied it. There should be no logical changes in this code; only simpler, cleaner ways for doing the same thing.

This commit is contained in:
Tim Farrell 2024-02-08 18:05:01 -06:00
parent 46d0eff218
commit 08e8e922fd
11 changed files with 127 additions and 251 deletions

View file

@ -1,4 +1,4 @@
from fastapi import FastAPI, Request, Response, HTTPException, Depends from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool from fastapi.concurrency import run_in_threadpool
@ -10,7 +10,7 @@ from pydantic import BaseModel
from apps.web.models.users import Users from apps.web.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user from utils.utils import decode_token, get_current_user, get_admin_user
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
app = FastAPI() app = FastAPI()
@ -31,11 +31,8 @@ REQUEST_POOL = []
@app.get("/url") @app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)): async def get_ollama_api_url(user=Depends(get_admin_user)):
if user and user.role == "admin":
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
class UrlUpdateForm(BaseModel): class UrlUpdateForm(BaseModel):
@ -44,13 +41,10 @@ class UrlUpdateForm(BaseModel):
@app.post("/url/update") @app.post("/url/update")
async def update_ollama_api_url( async def update_ollama_api_url(
form_data: UrlUpdateForm, user=Depends(get_current_user) form_data: UrlUpdateForm, user=Depends(get_admin_user)
): ):
if user and user.role == "admin":
app.state.OLLAMA_API_BASE_URL = form_data.url app.state.OLLAMA_API_BASE_URL = form_data.url
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.get("/cancel/{request_id}") @app.get("/cancel/{request_id}")
@ -74,10 +68,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
if path in ["pull", "delete", "push", "copy", "create"]: if path in ["pull", "delete", "push", "copy", "create"]:
if user.role != "admin": if user.role != "admin":
raise HTTPException( raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
) )
else: else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
headers.pop("host", None) headers.pop("host", None)
headers.pop("authorization", None) headers.pop("authorization", None)

View file

@ -9,7 +9,7 @@ from pydantic import BaseModel
from apps.web.models.users import Users from apps.web.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user from utils.utils import decode_token, get_current_user, get_verified_user, get_admin_user
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR
import hashlib import hashlib
@ -37,45 +37,32 @@ class KeyUpdateForm(BaseModel):
@app.get("/url") @app.get("/url")
async def get_openai_url(user=Depends(get_current_user)): async def get_openai_url(user=Depends(get_admin_user)):
if user and user.role == "admin":
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.post("/url/update") @app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)): async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
if user and user.role == "admin":
app.state.OPENAI_API_BASE_URL = form_data.url app.state.OPENAI_API_BASE_URL = form_data.url
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.get("/key") @app.get("/key")
async def get_openai_key(user=Depends(get_current_user)): async def get_openai_key(user=Depends(get_admin_user)):
if user and user.role == "admin":
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.post("/key/update") @app.post("/key/update")
async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)): async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_admin_user)):
if user and user.role == "admin":
app.state.OPENAI_API_KEY = form_data.key app.state.OPENAI_API_KEY = form_data.key
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.post("/audio/speech") @app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_current_user)): async def speech(request: Request, user=Depends(get_verified_user)):
target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech" target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"
if user.role not in ["user", "admin"]:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
if app.state.OPENAI_API_KEY == "": if app.state.OPENAI_API_KEY == "":
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
@ -133,12 +120,10 @@ async def speech(request: Request, user=Depends(get_current_user)):
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)): async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}" target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
print(target_url, app.state.OPENAI_API_KEY) print(target_url, app.state.OPENAI_API_KEY)
if user.role not in ["user", "admin"]:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
if app.state.OPENAI_API_KEY == "": if app.state.OPENAI_API_KEY == "":
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

View file

@ -39,7 +39,7 @@ import uuid
import time import time
from utils.misc import calculate_sha256, calculate_sha256_string from utils.misc import calculate_sha256, calculate_sha256_string
from utils.utils import get_current_user from utils.utils import get_current_user, get_admin_user
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -354,19 +354,12 @@ def store_doc(
@app.get("/reset/db") @app.get("/reset/db")
def reset_vector_db(user=Depends(get_current_user)): def reset_vector_db(user=Depends(get_admin_user)):
if user.role == "admin":
CHROMA_CLIENT.reset() CHROMA_CLIENT.reset()
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
@app.get("/reset") @app.get("/reset")
def reset(user=Depends(get_current_user)) -> bool: def reset(user=Depends(get_admin_user)) -> bool:
if user.role == "admin":
folder = f"{UPLOAD_DIR}" folder = f"{UPLOAD_DIR}"
for filename in os.listdir(folder): for filename in os.listdir(folder):
file_path = os.path.join(folder, filename) file_path = os.path.join(folder, filename)
@ -384,8 +377,3 @@ def reset(user=Depends(get_current_user)) -> bool:
print(e) print(e)
return True return True
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)

View file

@ -3,7 +3,7 @@ from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union from typing import List, Union
from fastapi import APIRouter from fastapi import APIRouter, status
from pydantic import BaseModel from pydantic import BaseModel
import time import time
import uuid import uuid
@ -19,7 +19,7 @@ from apps.web.models.auths import (
) )
from apps.web.models.users import Users from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, create_token from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
from utils.misc import get_gravatar_url, validate_email_format from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -116,10 +116,10 @@ async def signin(form_data: SigninForm):
@router.post("/signup", response_model=SigninResponse) @router.post("/signup", response_model=SigninResponse)
async def signup(request: Request, form_data: SignupForm): async def signup(request: Request, form_data: SignupForm):
if not request.app.state.ENABLE_SIGNUP: if not request.app.state.ENABLE_SIGNUP:
raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) raise HTTPException(status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
if not validate_email_format(form_data.email.lower()): if not validate_email_format(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT) raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
if Users.get_user_by_email(form_data.email.lower()): if Users.get_user_by_email(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
@ -156,23 +156,11 @@ async def signup(request: Request, form_data: SignupForm):
@router.get("/signup/enabled", response_model=bool) @router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_current_user)): async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
if user.role == "admin":
return request.app.state.ENABLE_SIGNUP return request.app.state.ENABLE_SIGNUP
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
@router.get("/signup/enabled/toggle", response_model=bool) @router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_current_user)): async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
if user.role == "admin":
request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
return request.app.state.ENABLE_SIGNUP return request.app.state.ENABLE_SIGNUP
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)

View file

@ -1,7 +1,7 @@
from fastapi import Depends, Request, HTTPException, status from fastapi import Depends, Request, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
from utils.utils import get_current_user from utils.utils import get_current_user, get_admin_user
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
@ -60,17 +60,11 @@ async def get_all_user_chats(user=Depends(get_current_user)):
@router.get("/all/db", response_model=List[ChatResponse]) @router.get("/all/db", response_model=List[ChatResponse])
async def get_all_user_chats_in_db(user=Depends(get_current_user)): async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
if user.role == "admin":
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats() for chat in Chats.get_all_chats()
] ]
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
############################ ############################

View file

@ -10,7 +10,7 @@ import uuid
from apps.web.models.users import Users from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, create_token from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
from utils.misc import get_gravatar_url, validate_email_format from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -37,30 +37,19 @@ class SetDefaultSuggestionsForm(BaseModel):
@router.post("/default/models", response_model=str) @router.post("/default/models", response_model=str)
async def set_global_default_models( async def set_global_default_models(
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user) request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
): ):
if user.role == "admin":
request.app.state.DEFAULT_MODELS = form_data.models request.app.state.DEFAULT_MODELS = form_data.models
return request.app.state.DEFAULT_MODELS return request.app.state.DEFAULT_MODELS
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
@router.post("/default/suggestions", response_model=List[PromptSuggestion]) @router.post("/default/suggestions", response_model=List[PromptSuggestion])
async def set_global_default_suggestions( async def set_global_default_suggestions(
request: Request, request: Request,
form_data: SetDefaultSuggestionsForm, form_data: SetDefaultSuggestionsForm,
user=Depends(get_current_user), user=Depends(get_admin_user),
): ):
if user.role == "admin":
data = form_data.model_dump() data = form_data.model_dump()
request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"] request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
return request.app.state.DEFAULT_PROMPT_SUGGESTIONS return request.app.state.DEFAULT_PROMPT_SUGGESTIONS
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)

View file

@ -14,7 +14,7 @@ from apps.web.models.documents import (
DocumentResponse, DocumentResponse,
) )
from utils.utils import get_current_user from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
@ -44,13 +44,7 @@ async def get_documents(user=Depends(get_current_user)):
@router.post("/create", response_model=Optional[DocumentResponse]) @router.post("/create", response_model=Optional[DocumentResponse])
async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)): async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
doc = Documents.get_doc_by_name(form_data.name) doc = Documents.get_doc_by_name(form_data.name)
if doc == None: if doc == None:
doc = Documents.insert_new_doc(user.id, form_data) doc = Documents.insert_new_doc(user.id, form_data)
@ -132,14 +126,8 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u
@router.post("/name/{name}/update", response_model=Optional[DocumentResponse]) @router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name( async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_current_user) name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
): ):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
doc = Documents.update_doc_by_name(name, form_data) doc = Documents.update_doc_by_name(name, form_data)
if doc: if doc:
return DocumentResponse( return DocumentResponse(
@ -161,12 +149,6 @@ async def update_doc_by_name(
@router.delete("/name/{name}/delete", response_model=bool) @router.delete("/name/{name}/delete", response_model=bool)
async def delete_doc_by_name(name: str, user=Depends(get_current_user)): async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Documents.delete_doc_by_name(name) result = Documents.delete_doc_by_name(name)
return result return result

View file

@ -13,7 +13,7 @@ from apps.web.models.modelfiles import (
ModelfileResponse, ModelfileResponse,
) )
from utils.utils import get_current_user from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
@ -37,13 +37,7 @@ async def get_modelfiles(skip: int = 0,
@router.post("/create", response_model=Optional[ModelfileResponse]) @router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm, async def create_new_modelfile(form_data: ModelfileForm,
user=Depends(get_current_user)): user=Depends(get_admin_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data) modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile: if modelfile:
@ -91,12 +85,7 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
@router.post("/update", response_model=Optional[ModelfileResponse]) @router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm, async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
user=Depends(get_current_user)): user=Depends(get_admin_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile: if modelfile:
updated_modelfile = { updated_modelfile = {
@ -127,12 +116,6 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
@router.delete("/delete", response_model=bool) @router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm, async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_current_user)): user=Depends(get_admin_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name) result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result return result

View file

@ -8,7 +8,7 @@ import json
from apps.web.models.prompts import Prompts, PromptForm, PromptModel from apps.web.models.prompts import Prompts, PromptForm, PromptModel
from utils.utils import get_current_user from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
@ -29,25 +29,17 @@ async def get_prompts(user=Depends(get_current_user)):
@router.post("/create", response_model=Optional[PromptModel]) @router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)): async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
prompt = Prompts.get_prompt_by_command(form_data.command) prompt = Prompts.get_prompt_by_command(form_data.command)
if prompt == None: if prompt == None:
prompt = Prompts.insert_new_prompt(user.id, form_data) prompt = Prompts.insert_new_prompt(user.id, form_data)
if prompt: if prompt:
return prompt return prompt
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(), detail=ERROR_MESSAGES.DEFAULT(),
) )
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.COMMAND_TAKEN, detail=ERROR_MESSAGES.COMMAND_TAKEN,
@ -79,14 +71,8 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
@router.post("/command/{command}/update", response_model=Optional[PromptModel]) @router.post("/command/{command}/update", response_model=Optional[PromptModel])
async def update_prompt_by_command( async def update_prompt_by_command(
command: str, form_data: PromptForm, user=Depends(get_current_user) command: str, form_data: PromptForm, user=Depends(get_admin_user)
): ):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt: if prompt:
return prompt return prompt
@ -103,12 +89,6 @@ async def update_prompt_by_command(
@router.delete("/command/{command}/delete", response_model=bool) @router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command(command: str, user=Depends(get_current_user)): async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Prompts.delete_prompt_by_command(f"/{command}") result = Prompts.delete_prompt_by_command(f"/{command}")
return result return result

View file

@ -11,7 +11,7 @@ import uuid
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths from apps.web.models.auths import Auths
from utils.utils import get_current_user, get_password_hash from utils.utils import get_current_user, get_password_hash, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
@ -22,12 +22,7 @@ router = APIRouter()
@router.get("/", response_model=List[UserModel]) @router.get("/", response_model=List[UserModel])
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)): async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
return Users.get_users(skip, limit) return Users.get_users(skip, limit)
@ -38,17 +33,11 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_use
@router.post("/update/role", response_model=Optional[UserModel]) @router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role( async def update_user_role(
form_data: UserRoleUpdateForm, user=Depends(get_current_user) form_data: UserRoleUpdateForm, user=Depends(get_admin_user)
): ):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if user.id != form_data.id: if user.id != form_data.id:
return Users.update_user_role_by_id(form_data.id, form_data.role) return Users.update_user_role_by_id(form_data.id, form_data.role)
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED, detail=ERROR_MESSAGES.ACTION_PROHIBITED,
@ -62,14 +51,8 @@ async def update_user_role(
@router.post("/{user_id}/update", response_model=Optional[UserModel]) @router.post("/{user_id}/update", response_model=Optional[UserModel])
async def update_user_by_id( async def update_user_by_id(
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_current_user) user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user)
): ):
if session_user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
if user: if user:
@ -98,13 +81,12 @@ async def update_user_by_id(
if updated_user: if updated_user:
return updated_user return updated_user
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(), detail=ERROR_MESSAGES.DEFAULT(),
) )
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND, detail=ERROR_MESSAGES.USER_NOT_FOUND,
@ -117,25 +99,20 @@ async def update_user_by_id(
@router.delete("/{user_id}", response_model=bool) @router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(user_id: str, user=Depends(get_current_user)): async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
if user.role == "admin":
if user.id != user_id: if user.id != user_id:
result = Auths.delete_auth_by_id(user_id) result = Auths.delete_auth_by_id(user_id)
if result: if result:
return True return True
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DELETE_USER_ERROR, detail=ERROR_MESSAGES.DELETE_USER_ERROR,
) )
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED, detail=ERROR_MESSAGES.ACTION_PROHIBITED,
) )
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)

View file

@ -73,3 +73,19 @@ def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(bearer_s
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
def get_verified_user(user: Users = Depends(get_current_user)):
if user.role not in {"user", "admin"}:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
def get_admin_user(user: Users = Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)