refac: remove the verify_token and use get-current user for auth+user

This commit is contained in:
Anuraag Jain 2024-01-01 10:55:50 +02:00
parent 2d323b31e1
commit 77323d9b25
5 changed files with 12 additions and 41 deletions

View file

@ -3,7 +3,6 @@ from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from apps.web.routers import auths, users, chats, modelfiles, utils from apps.web.routers import auths, users, chats, modelfiles, utils
from config import WEBUI_VERSION, WEBUI_AUTH from config import WEBUI_VERSION, WEBUI_AUTH
from utils.utils import verify_auth_token
app = FastAPI() app = FastAPI()
@ -19,24 +18,9 @@ app.add_middleware(
app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router( app.include_router(users.router, prefix="/users", tags=["users"])
users.router, app.include_router(chats.router, prefix="/chats", tags=["chats"])
prefix="/users", app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
tags=["users"],
dependencies=[Depends(verify_auth_token)],
)
app.include_router(
chats.router,
prefix="/chats",
tags=["chats"],
dependencies=[Depends(verify_auth_token)],
)
app.include_router(
modelfiles.router,
prefix="/modelfiles",
tags=["modelfiles"],
dependencies=[Depends(verify_auth_token)],
)
app.include_router(utils.router, prefix="/utils", tags=["utils"]) app.include_router(utils.router, prefix="/utils", tags=["utils"])

View file

@ -19,12 +19,7 @@ from apps.web.models.auths import (
from apps.web.models.users import Users from apps.web.models.users import Users
from utils.utils import ( from utils.utils import get_password_hash, get_current_user, create_token
get_password_hash,
get_current_user,
create_token,
verify_auth_token,
)
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -36,7 +31,7 @@ router = APIRouter()
############################ ############################
@router.get("/", response_model=UserResponse, dependencies=[Depends(verify_auth_token)]) @router.get("/", response_model=UserResponse)
async def get_session_user(user=Depends(get_current_user)): async def get_session_user(user=Depends(get_current_user)):
return { return {
"id": user.id, "id": user.id,
@ -52,9 +47,7 @@ async def get_session_user(user=Depends(get_current_user)):
############################ ############################
@router.post( @router.post("/update/password", response_model=bool)
"/update/password", response_model=bool, dependencies=[Depends(verify_auth_token)]
)
async def update_password( async def update_password(
form_data: UpdatePasswordForm, session_user=Depends(get_current_user) form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
): ):

View file

@ -108,6 +108,7 @@ async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
result = Chats.delete_chat_by_id_and_user_id(id, user.id) result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result return result
############################ ############################
# DeleteAllChats # DeleteAllChats
############################ ############################

View file

@ -5,8 +5,6 @@ from typing import List, Union, Optional
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.web.models.users import Users
from apps.web.models.modelfiles import ( from apps.web.models.modelfiles import (
Modelfiles, Modelfiles,
ModelfileForm, ModelfileForm,
@ -15,7 +13,7 @@ from apps.web.models.modelfiles import (
ModelfileResponse, ModelfileResponse,
) )
from utils.utils import bearer_scheme, get_current_user from utils.utils import get_current_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
@ -26,7 +24,7 @@ router = APIRouter()
@router.get("/", response_model=List[ModelfileResponse]) @router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
return Modelfiles.get_modelfiles(skip, limit) return Modelfiles.get_modelfiles(skip, limit)
@ -67,7 +65,7 @@ async def create_new_modelfile(
@router.post("/", response_model=Optional[ModelfileResponse]) @router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm): async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depends(get_current_user)):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile: if modelfile:

View file

@ -55,7 +55,7 @@ def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :] return auth_header[len("Bearer ") :]
def verify_auth_token(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
data = decode_token(auth_token.credentials) data = decode_token(auth_token.credentials)
if data != None and "email" in data: if data != None and "email" in data:
user = Users.get_user_by_email(data["email"]) user = Users.get_user_by_email(data["email"])
@ -64,14 +64,9 @@ def verify_auth_token(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBea
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.INVALID_TOKEN,
) )
return return user
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
data = decode_token(auth_token.credentials)
return Users.get_user_by_email(data["email"])