feat(auth): add auth middleware

- refactored chat routes to use request.user instead of doing authentication in every route
This commit is contained in:
Anuraag Jain 2023-12-28 22:15:54 +02:00
parent 8370465796
commit a01b112f7f
5 changed files with 63 additions and 89 deletions

1
backend/.gitignore vendored
View file

@ -5,3 +5,4 @@ uploads
.ipynb_checkpoints .ipynb_checkpoints
*.db *.db
_test _test
Pipfile

View file

@ -1,8 +1,9 @@
from fastapi import FastAPI, Request, Depends, HTTPException from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware
from apps.web.routers import auths, users, chats, modelfiles, utils from apps.web.routers import auths, users, chats, modelfiles, utils
from config import WEBUI_VERSION, WEBUI_AUTH from config import WEBUI_VERSION, WEBUI_AUTH
from apps.web.middlewares.auth import BearerTokenAuthBackend, on_auth_error
app = FastAPI() app = FastAPI()
@ -18,11 +19,12 @@ app.add_middleware(
app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.add_middleware(AuthenticationMiddleware, backend=BearerTokenAuthBackend(), on_error=on_auth_error)
app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
app.include_router(utils.router, prefix="/utils", tags=["utils"]) app.include_router(utils.router, prefix="/utils", tags=["utils"])

View file

@ -0,0 +1,27 @@
from apps.web.models.users import Users
from fastapi import Request, status
from starlette.authentication import (
AuthCredentials, AuthenticationBackend, AuthenticationError,
)
from starlette.requests import HTTPConnection
from utils.utils import verify_token
from starlette.responses import JSONResponse
from constants import ERROR_MESSAGES
class BearerTokenAuthBackend(AuthenticationBackend):
async def authenticate(self, conn: HTTPConnection):
if "Authorization" not in conn.headers:
return
data = verify_token(conn)
if data != None and 'email' in data:
user = Users.get_user_by_email(data['email'])
if user is None:
raise AuthenticationError('Invalid credentials')
return AuthCredentials([user.role]), user
else:
raise AuthenticationError('Invalid credentials')
def on_auth_error(request: Request, exc: Exception):
print('Authentication failed: ', exc)
return JSONResponse({"detail": ERROR_MESSAGES.INVALID_TOKEN}, status_code=status.HTTP_401_UNAUTHORIZED)

View file

@ -1,5 +1,5 @@
from fastapi import Response
from fastapi import Depends, FastAPI, HTTPException, status from fastapi import Depends, Request, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
@ -30,17 +30,8 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse]) @router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): async def get_user_chats(request:Request, skip: int = 0, limit: int = 50):
token = cred.credentials return Chats.get_chat_lists_by_user_id(request.user.id, skip, limit)
user = Users.get_user_by_token(token)
if user:
return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -49,20 +40,11 @@ async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
@router.get("/all", response_model=List[ChatResponse]) @router.get("/all", response_model=List[ChatResponse])
async def get_all_user_chats(cred=Depends(bearer_scheme)): async def get_all_user_chats(request:Request,):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats_by_user_id(user.id) for chat in Chats.get_all_chats_by_user_id(request.user.id)
] ]
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -71,18 +53,9 @@ async def get_all_user_chats(cred=Depends(bearer_scheme)):
@router.post("/new", response_model=Optional[ChatResponse]) @router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)): async def create_new_chat(form_data: ChatForm,request:Request):
token = cred.credentials chat = Chats.insert_new_chat(request.user.id, form_data)
user = Users.get_user_by_token(token)
if user:
chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -91,25 +64,14 @@ async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
@router.get("/{id}", response_model=Optional[ChatResponse]) @router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)): async def get_chat_by_id(id: str, request:Request):
token = cred.credentials chat = Chats.get_chat_by_id_and_user_id(id, request.user.id)
user = Users.get_user_by_token(token)
if user:
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else: else:
raise HTTPException( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND)
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -118,12 +80,8 @@ async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
@router.post("/{id}", response_model=Optional[ChatResponse]) @router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_scheme)): async def update_chat_by_id(id: str, form_data: ChatForm, request:Request):
token = cred.credentials chat = Chats.get_chat_by_id_and_user_id(id, request.user.id)
user = Users.get_user_by_token(token)
if user:
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat} updated_chat = {**json.loads(chat.chat), **form_data.chat}
@ -134,11 +92,6 @@ async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_sc
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -147,15 +100,6 @@ async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_sc
@router.delete("/{id}", response_model=bool) @router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)): async def delete_chat_by_id(id: str, request: Request):
token = cred.credentials result = Chats.delete_chat_by_id_and_user_id(id, request.user.id)
user = Users.get_user_by_token(token)
if user:
result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result return result
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)

View file

@ -55,13 +55,13 @@ def extract_token_from_auth_header(auth_header: str):
def verify_token(request): def verify_token(request):
try: try:
bearer = request.headers["authorization"] authorization = request.headers["authorization"]
if bearer: if authorization:
token = bearer[len("Bearer ") :] _, token = authorization.split()
decoded = jwt.decode( decoded_token = jwt.decode(
token, JWT_SECRET_KEY, options={"verify_signature": False} token, JWT_SECRET_KEY, options={"verify_signature": False}
) )
return decoded return decoded_token
else: else:
return None return None
except Exception as e: except Exception as e: