refac: use dependencies to verify token

- feat: added new util to get the current user when needed. Middleware was adding authentication logic to all the routes. let's revisit if we can move the non-auth endpoints to a separate route.
- refac: update the routes to use new helpers for verification and retrieving user
- chore: added black for local formatting of py code
This commit is contained in:
Anuraag Jain 2023-12-30 12:53:33 +02:00
parent a01b112f7f
commit bdd153d8f5
10 changed files with 167 additions and 251 deletions

View file

@ -8,7 +8,7 @@ import json
from apps.web.models.users import Users from apps.web.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import extract_token_from_auth_header from utils.utils import decode_token
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
app = Flask(__name__) app = Flask(__name__)
@ -34,8 +34,12 @@ def proxy(path):
# Basic RBAC support # Basic RBAC support
if WEBUI_AUTH: if WEBUI_AUTH:
if "Authorization" in headers: if "Authorization" in headers:
token = extract_token_from_auth_header(headers["Authorization"]) _, credentials = headers["Authorization"].split()
user = Users.get_user_by_token(token) token_data = decode_token(credentials)
if token_data is None or "email" not in token_data:
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
user = Users.get_user_by_email(token_data["email"])
if user: if user:
# Only user and admin roles can access # Only user and admin roles can access
if user.role in ["user", "admin"]: if user.role in ["user", "admin"]:

View file

@ -1,9 +1,10 @@
from fastapi import FastAPI from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.authentication import AuthenticationMiddleware
from apps.web.routers import auths, users, chats, modelfiles, utils from apps.web.routers import auths, users, chats, modelfiles, utils
from config import WEBUI_VERSION, WEBUI_AUTH from config import WEBUI_VERSION, WEBUI_AUTH
from apps.web.middlewares.auth import BearerTokenAuthBackend, on_auth_error from utils.utils import verify_auth_token
app = FastAPI() app = FastAPI()
@ -17,14 +18,26 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.add_middleware(AuthenticationMiddleware, backend=BearerTokenAuthBackend(), on_error=on_auth_error) app.include_router(
users.router,
app.include_router(users.router, prefix="/users", tags=["users"]) prefix="/users",
app.include_router(chats.router, prefix="/chats", tags=["chats"]) tags=["users"],
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) dependencies=[Depends(verify_auth_token)],
)
app.include_router(
chats.router,
prefix="/chats",
tags=["chats"],
dependencies=[Depends(verify_auth_token)],
)
app.include_router(
modelfiles.router,
prefix="/modelfiles",
tags=["modelfiles"],
dependencies=[Depends(verify_auth_token)],
)
app.include_router(utils.router, prefix="/utils", tags=["utils"]) app.include_router(utils.router, prefix="/utils", tags=["utils"])

View file

@ -1,27 +0,0 @@
from apps.web.models.users import Users
from fastapi import Request, status
from starlette.authentication import (
AuthCredentials, AuthenticationBackend, AuthenticationError,
)
from starlette.requests import HTTPConnection
from utils.utils import verify_token
from starlette.responses import JSONResponse
from constants import ERROR_MESSAGES
class BearerTokenAuthBackend(AuthenticationBackend):
async def authenticate(self, conn: HTTPConnection):
if "Authorization" not in conn.headers:
return
data = verify_token(conn)
if data != None and 'email' in data:
user = Users.get_user_by_email(data['email'])
if user is None:
raise AuthenticationError('Invalid credentials')
return AuthCredentials([user.role]), user
else:
raise AuthenticationError('Invalid credentials')
def on_auth_error(request: Request, exc: Exception):
print('Authentication failed: ', exc)
return JSONResponse({"detail": ERROR_MESSAGES.INVALID_TOKEN}, status_code=status.HTTP_401_UNAUTHORIZED)

View file

@ -3,8 +3,6 @@ from peewee import *
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional from typing import List, Union, Optional
import time import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from apps.web.internal.db import DB from apps.web.internal.db import DB
@ -83,14 +81,6 @@ class UsersTable:
except: except:
return None return None
def get_user_by_token(self, token: str) -> Optional[UserModel]:
data = decode_token(token)
if data != None and "email" in data:
return self.get_user_by_email(data["email"])
else:
return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [ return [
UserModel(**model_to_dict(user)) UserModel(**model_to_dict(user))

View file

@ -20,7 +20,7 @@ from apps.web.models.users import Users
from utils.utils import ( from utils.utils import (
get_password_hash, get_password_hash,
bearer_scheme, get_current_user,
create_token, create_token,
) )
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
@ -35,10 +35,7 @@ router = APIRouter()
@router.get("/", response_model=UserResponse) @router.get("/", response_model=UserResponse)
async def get_session_user(cred=Depends(bearer_scheme)): async def get_session_user(user=Depends(get_current_user)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
return { return {
"id": user.id, "id": user.id,
"email": user.email, "email": user.email,
@ -46,11 +43,6 @@ async def get_session_user(cred=Depends(bearer_scheme)):
"role": user.role, "role": user.role,
"profile_image_url": user.profile_image_url, "profile_image_url": user.profile_image_url,
} }
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################

View file

@ -1,8 +1,7 @@
from fastapi import Depends, Request, HTTPException, status from fastapi import Depends, Request, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
from utils.utils import get_current_user
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
@ -30,8 +29,10 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse]) @router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats(request:Request, skip: int = 0, limit: int = 50): async def get_user_chats(
return Chats.get_chat_lists_by_user_id(request.user.id, skip, limit) user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
############################ ############################
@ -40,10 +41,10 @@ async def get_user_chats(request:Request, skip: int = 0, limit: int = 50):
@router.get("/all", response_model=List[ChatResponse]) @router.get("/all", response_model=List[ChatResponse])
async def get_all_user_chats(request:Request,): async def get_all_user_chats(user=Depends(get_current_user)):
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats_by_user_id(request.user.id) for chat in Chats.get_all_chats_by_user_id(user.id)
] ]
@ -53,8 +54,8 @@ async def get_all_user_chats(request:Request,):
@router.post("/new", response_model=Optional[ChatResponse]) @router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm,request:Request): async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
chat = Chats.insert_new_chat(request.user.id, form_data) chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
@ -64,14 +65,15 @@ async def create_new_chat(form_data: ChatForm,request:Request):
@router.get("/{id}", response_model=Optional[ChatResponse]) @router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, request:Request): async def get_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, request.user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else: else:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, raise HTTPException(
detail=ERROR_MESSAGES.NOT_FOUND) status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
############################ ############################
@ -80,8 +82,10 @@ async def get_chat_by_id(id: str, request:Request):
@router.post("/{id}", response_model=Optional[ChatResponse]) @router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(id: str, form_data: ChatForm, request:Request): async def update_chat_by_id(
chat = Chats.get_chat_by_id_and_user_id(id, request.user.id) id: str, form_data: ChatForm, user=Depends(get_current_user)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat} updated_chat = {**json.loads(chat.chat), **form_data.chat}
@ -100,6 +104,6 @@ async def update_chat_by_id(id: str, form_data: ChatForm, request:Request):
@router.delete("/{id}", response_model=bool) @router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(id: str, request: Request): async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
result = Chats.delete_chat_by_id_and_user_id(id, request.user.id) result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result return result

View file

@ -1,4 +1,3 @@
from fastapi import Response
from fastapi import Depends, FastAPI, HTTPException, status from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
@ -16,9 +15,7 @@ from apps.web.models.modelfiles import (
ModelfileResponse, ModelfileResponse,
) )
from utils.utils import ( from utils.utils import bearer_scheme, get_current_user
bearer_scheme,
)
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
@ -30,16 +27,7 @@ router = APIRouter()
@router.get("/", response_model=List[ModelfileResponse]) @router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
return Modelfiles.get_modelfiles(skip, limit) return Modelfiles.get_modelfiles(skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -48,13 +36,15 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
@router.post("/create", response_model=Optional[ModelfileResponse]) @router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_scheme)): async def create_new_modelfile(
token = cred.credentials form_data: ModelfileForm, user=Depends(get_current_user)
user = Users.get_user_by_token(token) ):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if user:
# Admin Only
if user.role == "admin":
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data) modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile: if modelfile:
@ -69,16 +59,6 @@ async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_sch
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(), detail=ERROR_MESSAGES.DEFAULT(),
) )
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -87,13 +67,7 @@ async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_sch
@router.post("/", response_model=Optional[ModelfileResponse]) @router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name( async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm):
form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile: if modelfile:
@ -108,11 +82,6 @@ async def get_modelfile_by_tag_name(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND,
) )
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -122,13 +91,13 @@ async def get_modelfile_by_tag_name(
@router.post("/update", response_model=Optional[ModelfileResponse]) @router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name( async def update_modelfile_by_tag_name(
form_data: ModelfileUpdateForm, cred=Depends(bearer_scheme) form_data: ModelfileUpdateForm, user=Depends(get_current_user)
): ):
token = cred.credentials if user.role != "admin":
user = Users.get_user_by_token(token) raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
if user: detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
if user.role == "admin": )
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile: if modelfile:
updated_modelfile = { updated_modelfile = {
@ -151,16 +120,6 @@ async def update_modelfile_by_tag_name(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -170,22 +129,13 @@ async def update_modelfile_by_tag_name(
@router.delete("/delete", response_model=bool) @router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name( async def delete_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme) form_data: ModelfileTagNameForm, user=Depends(get_current_user)
): ):
token = cred.credentials if user.role != "admin":
user = Users.get_user_by_token(token)
if user:
if user.role == "admin":
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
else:
raise HTTPException( result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
status_code=status.HTTP_401_UNAUTHORIZED, return result
detail=ERROR_MESSAGES.INVALID_TOKEN,
)

View file

@ -10,11 +10,7 @@ import uuid
from apps.web.models.users import UserModel, UserRoleUpdateForm, Users from apps.web.models.users import UserModel, UserRoleUpdateForm, Users
from utils.utils import ( from utils.utils import get_current_user
get_password_hash,
bearer_scheme,
create_token,
)
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
@ -25,23 +21,13 @@ router = APIRouter()
@router.get("/", response_model=List[UserModel]) @router.get("/", response_model=List[UserModel])
async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
token = cred.credentials if user.role != "admin":
user = Users.get_user_by_token(token)
if user:
if user.role == "admin":
return Users.get_users(skip, limit)
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
else: return Users.get_users(skip, limit)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -50,12 +36,15 @@ async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme))
@router.post("/update/role", response_model=Optional[UserModel]) @router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_scheme)): async def update_user_role(
token = cred.credentials form_data: UserRoleUpdateForm, user=Depends(get_current_user)
user = Users.get_user_by_token(token) ):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if user:
if user.role == "admin":
if user.id != form_data.id: if user.id != form_data.id:
return Users.update_user_role_by_id(form_data.id, form_data.role) return Users.update_user_role_by_id(form_data.id, form_data.role)
else: else:
@ -63,13 +52,3 @@ async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_sc
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED, detail=ERROR_MESSAGES.ACTION_PROHIBITED,
) )
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)

View file

@ -18,3 +18,5 @@ bcrypt
PyJWT PyJWT
pyjwt[crypto] pyjwt[crypto]
black

View file

@ -1,7 +1,9 @@
from fastapi.security import HTTPBasicCredentials, HTTPBearer from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends
from apps.web.models.users import Users
from pydantic import BaseModel from pydantic import BaseModel
from typing import Union, Optional from typing import Union, Optional
from constants import ERROR_MESSAGES
from passlib.context import CryptContext from passlib.context import CryptContext
from datetime import datetime, timedelta from datetime import datetime, timedelta
import requests import requests
@ -53,16 +55,23 @@ def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :] return auth_header[len("Bearer ") :]
def verify_token(request): def verify_auth_token(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
try: data = decode_token(auth_token.credentials)
authorization = request.headers["authorization"] if data != None and "email" in data:
if authorization: user = Users.get_user_by_email(data["email"])
_, token = authorization.split() if user is None:
decoded_token = jwt.decode( raise HTTPException(
token, JWT_SECRET_KEY, options={"verify_signature": False} status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
) )
return decoded_token return
else: else:
return None raise HTTPException(
except Exception as e: status_code=status.HTTP_401_UNAUTHORIZED,
return None detail=ERROR_MESSAGES.UNAUTHORIZED,
)
def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
data = decode_token(auth_token.credentials)
return Users.get_user_by_email(data["email"])