refac: use dependencies to verify token

- feat: added new util to get the current user when needed. Middleware was adding authentication logic to all the routes. let's revisit if we can move the non-auth endpoints to a separate route.
- refac: update the routes to use new helpers for verification and retrieving user
- chore: added black for local formatting of py code
This commit is contained in:
Anuraag Jain 2023-12-30 12:53:33 +02:00
parent a01b112f7f
commit bdd153d8f5
10 changed files with 167 additions and 251 deletions

View file

@ -8,7 +8,7 @@ import json
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import extract_token_from_auth_header
from utils.utils import decode_token
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
app = Flask(__name__)
@ -34,8 +34,12 @@ def proxy(path):
# Basic RBAC support
if WEBUI_AUTH:
if "Authorization" in headers:
token = extract_token_from_auth_header(headers["Authorization"])
user = Users.get_user_by_token(token)
_, credentials = headers["Authorization"].split()
token_data = decode_token(credentials)
if token_data is None or "email" not in token_data:
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
user = Users.get_user_by_email(token_data["email"])
if user:
# Only user and admin roles can access
if user.role in ["user", "admin"]:

View file

@ -1,9 +1,10 @@
from fastapi import FastAPI
from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware
from apps.web.routers import auths, users, chats, modelfiles, utils
from config import WEBUI_VERSION, WEBUI_AUTH
from apps.web.middlewares.auth import BearerTokenAuthBackend, on_auth_error
from utils.utils import verify_auth_token
app = FastAPI()
@ -17,14 +18,26 @@ app.add_middleware(
allow_headers=["*"],
)
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.add_middleware(AuthenticationMiddleware, backend=BearerTokenAuthBackend(), on_error=on_auth_error)
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
app.include_router(
users.router,
prefix="/users",
tags=["users"],
dependencies=[Depends(verify_auth_token)],
)
app.include_router(
chats.router,
prefix="/chats",
tags=["chats"],
dependencies=[Depends(verify_auth_token)],
)
app.include_router(
modelfiles.router,
prefix="/modelfiles",
tags=["modelfiles"],
dependencies=[Depends(verify_auth_token)],
)
app.include_router(utils.router, prefix="/utils", tags=["utils"])

View file

@ -1,27 +0,0 @@
from apps.web.models.users import Users
from fastapi import Request, status
from starlette.authentication import (
AuthCredentials, AuthenticationBackend, AuthenticationError,
)
from starlette.requests import HTTPConnection
from utils.utils import verify_token
from starlette.responses import JSONResponse
from constants import ERROR_MESSAGES
class BearerTokenAuthBackend(AuthenticationBackend):
async def authenticate(self, conn: HTTPConnection):
if "Authorization" not in conn.headers:
return
data = verify_token(conn)
if data != None and 'email' in data:
user = Users.get_user_by_email(data['email'])
if user is None:
raise AuthenticationError('Invalid credentials')
return AuthCredentials([user.role]), user
else:
raise AuthenticationError('Invalid credentials')
def on_auth_error(request: Request, exc: Exception):
print('Authentication failed: ', exc)
return JSONResponse({"detail": ERROR_MESSAGES.INVALID_TOKEN}, status_code=status.HTTP_401_UNAUTHORIZED)

View file

@ -3,8 +3,6 @@ from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from apps.web.internal.db import DB
@ -83,14 +81,6 @@ class UsersTable:
except:
return None
def get_user_by_token(self, token: str) -> Optional[UserModel]:
data = decode_token(token)
if data != None and "email" in data:
return self.get_user_by_email(data["email"])
else:
return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [
UserModel(**model_to_dict(user))

View file

@ -20,7 +20,7 @@ from apps.web.models.users import Users
from utils.utils import (
get_password_hash,
bearer_scheme,
get_current_user,
create_token,
)
from utils.misc import get_gravatar_url
@ -35,22 +35,14 @@ router = APIRouter()
@router.get("/", response_model=UserResponse)
async def get_session_user(cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
return {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
"profile_image_url": user.profile_image_url,
}
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
async def get_session_user(user=Depends(get_current_user)):
return {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
"profile_image_url": user.profile_image_url,
}
############################

View file

@ -1,8 +1,7 @@
from fastapi import Depends, Request, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from utils.utils import get_current_user
from fastapi import APIRouter
from pydantic import BaseModel
import json
@ -30,8 +29,10 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats(request:Request, skip: int = 0, limit: int = 50):
return Chats.get_chat_lists_by_user_id(request.user.id, skip, limit)
async def get_user_chats(
user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
############################
@ -40,11 +41,11 @@ async def get_user_chats(request:Request, skip: int = 0, limit: int = 50):
@router.get("/all", response_model=List[ChatResponse])
async def get_all_user_chats(request:Request,):
async def get_all_user_chats(user=Depends(get_current_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats_by_user_id(request.user.id)
]
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats_by_user_id(user.id)
]
############################
@ -53,8 +54,8 @@ async def get_all_user_chats(request:Request,):
@router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm,request:Request):
chat = Chats.insert_new_chat(request.user.id, form_data)
async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
@ -64,14 +65,15 @@ async def create_new_chat(form_data: ChatForm,request:Request):
@router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, request:Request):
chat = Chats.get_chat_by_id_and_user_id(id, request.user.id)
async def get_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
############################
@ -80,18 +82,20 @@ async def get_chat_by_id(id: str, request:Request):
@router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(id: str, form_data: ChatForm, request:Request):
chat = Chats.get_chat_by_id_and_user_id(id, request.user.id)
async def update_chat_by_id(
id: str, form_data: ChatForm, user=Depends(get_current_user)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat}
updated_chat = {**json.loads(chat.chat), **form_data.chat}
chat = Chats.update_chat_by_id(id, updated_chat)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
chat = Chats.update_chat_by_id(id, updated_chat)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
############################
@ -100,6 +104,6 @@ async def update_chat_by_id(id: str, form_data: ChatForm, request:Request):
@router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(id: str, request: Request):
result = Chats.delete_chat_by_id_and_user_id(id, request.user.id)
async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result

View file

@ -1,4 +1,3 @@
from fastapi import Response
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
@ -16,9 +15,7 @@ from apps.web.models.modelfiles import (
ModelfileResponse,
)
from utils.utils import (
bearer_scheme,
)
from utils.utils import bearer_scheme, get_current_user
from constants import ERROR_MESSAGES
router = APIRouter()
@ -30,16 +27,7 @@ router = APIRouter()
@router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
return Modelfiles.get_modelfiles(skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
return Modelfiles.get_modelfiles(skip, limit)
############################
@ -48,36 +36,28 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
@router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
async def create_new_modelfile(
form_data: ModelfileForm, user=Depends(get_current_user)
):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if user:
# Admin Only
if user.role == "admin":
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
detail=ERROR_MESSAGES.DEFAULT(),
)
@ -87,31 +67,20 @@ async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_sch
@router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
):
token = cred.credentials
user = Users.get_user_by_token(token)
async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if user:
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@ -122,44 +91,34 @@ async def get_modelfile_by_tag_name(
@router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(
form_data: ModelfileUpdateForm, cred=Depends(bearer_scheme)
form_data: ModelfileUpdateForm, user=Depends(get_current_user)
):
token = cred.credentials
user = Users.get_user_by_token(token)
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
updated_modelfile = {
**json.loads(modelfile.modelfile),
**form_data.modelfile,
}
if user:
if user.role == "admin":
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
updated_modelfile = {
**json.loads(modelfile.modelfile),
**form_data.modelfile,
}
modelfile = Modelfiles.update_modelfile_by_tag_name(
form_data.tag_name, updated_modelfile
)
modelfile = Modelfiles.update_modelfile_by_tag_name(
form_data.tag_name, updated_modelfile
)
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
@ -170,22 +129,13 @@ async def update_modelfile_by_tag_name(
@router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
form_data: ModelfileTagNameForm, user=Depends(get_current_user)
):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
if user.role == "admin":
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result

View file

@ -10,11 +10,7 @@ import uuid
from apps.web.models.users import UserModel, UserRoleUpdateForm, Users
from utils.utils import (
get_password_hash,
bearer_scheme,
create_token,
)
from utils.utils import get_current_user
from constants import ERROR_MESSAGES
router = APIRouter()
@ -25,23 +21,13 @@ router = APIRouter()
@router.get("/", response_model=List[UserModel])
async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
if user.role == "admin":
return Users.get_users(skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
return Users.get_users(skip, limit)
############################
@ -50,26 +36,19 @@ async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme))
@router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
async def update_user_role(
form_data: UserRoleUpdateForm, user=Depends(get_current_user)
):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if user:
if user.role == "admin":
if user.id != form_data.id:
return Users.update_user_role_by_id(form_data.id, form_data.role)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if user.id != form_data.id:
return Users.update_user_role_by_id(form_data.id, form_data.role)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)

View file

@ -18,3 +18,5 @@ bcrypt
PyJWT
pyjwt[crypto]
black

View file

@ -1,7 +1,9 @@
from fastapi.security import HTTPBasicCredentials, HTTPBearer
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends
from apps.web.models.users import Users
from pydantic import BaseModel
from typing import Union, Optional
from constants import ERROR_MESSAGES
from passlib.context import CryptContext
from datetime import datetime, timedelta
import requests
@ -53,16 +55,23 @@ def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :]
def verify_token(request):
try:
authorization = request.headers["authorization"]
if authorization:
_, token = authorization.split()
decoded_token = jwt.decode(
token, JWT_SECRET_KEY, options={"verify_signature": False}
def verify_auth_token(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
data = decode_token(auth_token.credentials)
if data != None and "email" in data:
user = Users.get_user_by_email(data["email"])
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
return decoded_token
else:
return None
except Exception as e:
return None
return
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
data = decode_token(auth_token.credentials)
return Users.get_user_by_email(data["email"])