Endpoint role-checking was redundantly applied but FastAPI provides a nice abstraction mechanic...so I applied it. There should be no logical changes in this code; only simpler, cleaner ways for doing the same thing.

This commit is contained in:
Tim Farrell 2024-02-08 18:05:01 -06:00
parent 46d0eff218
commit 08e8e922fd
11 changed files with 127 additions and 251 deletions

View file

@ -13,7 +13,7 @@ from apps.web.models.modelfiles import (
ModelfileResponse,
)
from utils.utils import get_current_user
from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
@ -37,13 +37,7 @@ async def get_modelfiles(skip: int = 0,
@router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
user=Depends(get_admin_user)):
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile:
@ -91,12 +85,7 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
@router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
user=Depends(get_admin_user)):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
updated_modelfile = {
@ -127,12 +116,6 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
@router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
user=Depends(get_admin_user)):
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result