Endpoint role-checking was redundantly applied but FastAPI provides a nice abstraction mechanic...so I applied it. There should be no logical changes in this code; only simpler, cleaner ways for doing the same thing.

This commit is contained in:
Tim Farrell 2024-02-08 18:05:01 -06:00
parent 46d0eff218
commit 08e8e922fd
11 changed files with 127 additions and 251 deletions

View file

@ -3,7 +3,7 @@ from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union
from fastapi import APIRouter
from fastapi import APIRouter, status
from pydantic import BaseModel
import time
import uuid
@ -19,7 +19,7 @@ from apps.web.models.auths import (
)
from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, create_token
from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES
@ -116,10 +116,10 @@ async def signin(form_data: SigninForm):
@router.post("/signup", response_model=SigninResponse)
async def signup(request: Request, form_data: SignupForm):
if not request.app.state.ENABLE_SIGNUP:
raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
raise HTTPException(status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
if not validate_email_format(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
if Users.get_user_by_email(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
@ -156,23 +156,11 @@ async def signup(request: Request, form_data: SignupForm):
@router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_current_user)):
if user.role == "admin":
return request.app.state.ENABLE_SIGNUP
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
return request.app.state.ENABLE_SIGNUP
@router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_current_user)):
if user.role == "admin":
request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
return request.app.state.ENABLE_SIGNUP
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
return request.app.state.ENABLE_SIGNUP

View file

@ -1,7 +1,7 @@
from fastapi import Depends, Request, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from utils.utils import get_current_user
from utils.utils import get_current_user, get_admin_user
from fastapi import APIRouter
from pydantic import BaseModel
import json
@ -60,17 +60,11 @@ async def get_all_user_chats(user=Depends(get_current_user)):
@router.get("/all/db", response_model=List[ChatResponse])
async def get_all_user_chats_in_db(user=Depends(get_current_user)):
if user.role == "admin":
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats()
]
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats()
]
############################

View file

@ -10,7 +10,7 @@ import uuid
from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, create_token
from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES
@ -37,30 +37,19 @@ class SetDefaultSuggestionsForm(BaseModel):
@router.post("/default/models", response_model=str)
async def set_global_default_models(
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user)
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
):
if user.role == "admin":
request.app.state.DEFAULT_MODELS = form_data.models
return request.app.state.DEFAULT_MODELS
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
request.app.state.DEFAULT_MODELS = form_data.models
return request.app.state.DEFAULT_MODELS
@router.post("/default/suggestions", response_model=List[PromptSuggestion])
async def set_global_default_suggestions(
request: Request,
form_data: SetDefaultSuggestionsForm,
user=Depends(get_current_user),
user=Depends(get_admin_user),
):
if user.role == "admin":
data = form_data.model_dump()
request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
return request.app.state.DEFAULT_PROMPT_SUGGESTIONS
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
data = form_data.model_dump()
request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
return request.app.state.DEFAULT_PROMPT_SUGGESTIONS

View file

@ -14,7 +14,7 @@ from apps.web.models.documents import (
DocumentResponse,
)
from utils.utils import get_current_user
from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
@ -44,13 +44,7 @@ async def get_documents(user=Depends(get_current_user)):
@router.post("/create", response_model=Optional[DocumentResponse])
async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
doc = Documents.get_doc_by_name(form_data.name)
if doc == None:
doc = Documents.insert_new_doc(user.id, form_data)
@ -132,14 +126,8 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u
@router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_current_user)
name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
doc = Documents.update_doc_by_name(name, form_data)
if doc:
return DocumentResponse(
@ -161,12 +149,6 @@ async def update_doc_by_name(
@router.delete("/name/{name}/delete", response_model=bool)
async def delete_doc_by_name(name: str, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
result = Documents.delete_doc_by_name(name)
return result

View file

@ -13,7 +13,7 @@ from apps.web.models.modelfiles import (
ModelfileResponse,
)
from utils.utils import get_current_user
from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
@ -37,13 +37,7 @@ async def get_modelfiles(skip: int = 0,
@router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
user=Depends(get_admin_user)):
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile:
@ -91,12 +85,7 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
@router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
user=Depends(get_admin_user)):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
updated_modelfile = {
@ -127,12 +116,6 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
@router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
user=Depends(get_admin_user)):
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result

View file

@ -8,7 +8,7 @@ import json
from apps.web.models.prompts import Prompts, PromptForm, PromptModel
from utils.utils import get_current_user
from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
@ -29,29 +29,21 @@ async def get_prompts(user=Depends(get_current_user)):
@router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
prompt = Prompts.get_prompt_by_command(form_data.command)
if prompt == None:
prompt = Prompts.insert_new_prompt(user.id, form_data)
if prompt:
return prompt
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.COMMAND_TAKEN,
detail=ERROR_MESSAGES.DEFAULT(),
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.COMMAND_TAKEN,
)
############################
@ -79,14 +71,8 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
@router.post("/command/{command}/update", response_model=Optional[PromptModel])
async def update_prompt_by_command(
command: str, form_data: PromptForm, user=Depends(get_current_user)
command: str, form_data: PromptForm, user=Depends(get_admin_user)
):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt:
return prompt
@ -103,12 +89,6 @@ async def update_prompt_by_command(
@router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command(command: str, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
result = Prompts.delete_prompt_by_command(f"/{command}")
return result

View file

@ -11,7 +11,7 @@ import uuid
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths
from utils.utils import get_current_user, get_password_hash
from utils.utils import get_current_user, get_password_hash, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
@ -22,12 +22,7 @@ router = APIRouter()
@router.get("/", response_model=List[UserModel])
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)):
return Users.get_users(skip, limit)
@ -38,21 +33,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_use
@router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(
form_data: UserRoleUpdateForm, user=Depends(get_current_user)
form_data: UserRoleUpdateForm, user=Depends(get_admin_user)
):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if user.id != form_data.id:
return Users.update_user_role_by_id(form_data.id, form_data.role)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
############################
@ -62,14 +51,8 @@ async def update_user_role(
@router.post("/{user_id}/update", response_model=Optional[UserModel])
async def update_user_by_id(
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_current_user)
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user)
):
if session_user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
user = Users.get_user_by_id(user_id)
if user:
@ -98,18 +81,17 @@ async def update_user_by_id(
if updated_user:
return updated_user
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
detail=ERROR_MESSAGES.DEFAULT(),
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
############################
# DeleteUserById
@ -117,25 +99,20 @@ async def update_user_by_id(
@router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(user_id: str, user=Depends(get_current_user)):
if user.role == "admin":
if user.id != user_id:
result = Auths.delete_auth_by_id(user_id)
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
if user.id != user_id:
result = Auths.delete_auth_by_id(user_id)
if result:
return True
if result:
return True
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DELETE_USER_ERROR,
)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DELETE_USER_ERROR,
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)