feat(auth): add auth middleware

- refactored chat routes to use request.user instead of doing authentication in every route
This commit is contained in:
Anuraag Jain 2023-12-28 22:15:54 +02:00
parent 8370465796
commit a01b112f7f
5 changed files with 63 additions and 89 deletions

1
backend/.gitignore vendored
View file

@ -5,3 +5,4 @@ uploads
.ipynb_checkpoints
*.db
_test
Pipfile

View file

@ -1,8 +1,9 @@
from fastapi import FastAPI, Request, Depends, HTTPException
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware
from apps.web.routers import auths, users, chats, modelfiles, utils
from config import WEBUI_VERSION, WEBUI_AUTH
from apps.web.middlewares.auth import BearerTokenAuthBackend, on_auth_error
app = FastAPI()
@ -18,11 +19,12 @@ app.add_middleware(
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.add_middleware(AuthenticationMiddleware, backend=BearerTokenAuthBackend(), on_error=on_auth_error)
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])

View file

@ -0,0 +1,27 @@
from apps.web.models.users import Users
from fastapi import Request, status
from starlette.authentication import (
AuthCredentials, AuthenticationBackend, AuthenticationError,
)
from starlette.requests import HTTPConnection
from utils.utils import verify_token
from starlette.responses import JSONResponse
from constants import ERROR_MESSAGES
class BearerTokenAuthBackend(AuthenticationBackend):
async def authenticate(self, conn: HTTPConnection):
if "Authorization" not in conn.headers:
return
data = verify_token(conn)
if data != None and 'email' in data:
user = Users.get_user_by_email(data['email'])
if user is None:
raise AuthenticationError('Invalid credentials')
return AuthCredentials([user.role]), user
else:
raise AuthenticationError('Invalid credentials')
def on_auth_error(request: Request, exc: Exception):
print('Authentication failed: ', exc)
return JSONResponse({"detail": ERROR_MESSAGES.INVALID_TOKEN}, status_code=status.HTTP_401_UNAUTHORIZED)

View file

@ -1,5 +1,5 @@
from fastapi import Response
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi import Depends, Request, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
@ -30,17 +30,8 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
async def get_user_chats(request:Request, skip: int = 0, limit: int = 50):
return Chats.get_chat_lists_by_user_id(request.user.id, skip, limit)
############################
@ -49,20 +40,11 @@ async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
@router.get("/all", response_model=List[ChatResponse])
async def get_all_user_chats(cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
async def get_all_user_chats(request:Request,):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats_by_user_id(user.id)
for chat in Chats.get_all_chats_by_user_id(request.user.id)
]
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################
@ -71,18 +53,9 @@ async def get_all_user_chats(cred=Depends(bearer_scheme)):
@router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
chat = Chats.insert_new_chat(user.id, form_data)
async def create_new_chat(form_data: ChatForm,request:Request):
chat = Chats.insert_new_chat(request.user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################
@ -91,25 +64,14 @@ async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
@router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
async def get_chat_by_id(id: str, request:Request):
chat = Chats.get_chat_by_id_and_user_id(id, request.user.id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND)
############################
@ -118,12 +80,8 @@ async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
@router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
async def update_chat_by_id(id: str, form_data: ChatForm, request:Request):
chat = Chats.get_chat_by_id_and_user_id(id, request.user.id)
if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat}
@ -134,11 +92,6 @@ async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_sc
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################
@ -147,15 +100,6 @@ async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_sc
@router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
result = Chats.delete_chat_by_id_and_user_id(id, user.id)
async def delete_chat_by_id(id: str, request: Request):
result = Chats.delete_chat_by_id_and_user_id(id, request.user.id)
return result
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)

View file

@ -55,13 +55,13 @@ def extract_token_from_auth_header(auth_header: str):
def verify_token(request):
try:
bearer = request.headers["authorization"]
if bearer:
token = bearer[len("Bearer ") :]
decoded = jwt.decode(
authorization = request.headers["authorization"]
if authorization:
_, token = authorization.split()
decoded_token = jwt.decode(
token, JWT_SECRET_KEY, options={"verify_signature": False}
)
return decoded
return decoded_token
else:
return None
except Exception as e: