Merge pull request #311 from anuraagdjain/refac/auth-middleware

feat(auth): add auth middleware
This commit is contained in:
Timothy Jaeryang Baek 2024-01-01 14:06:26 -05:00 committed by GitHub
commit c5386d05ab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 165 additions and 322 deletions

3
backend/.gitignore vendored
View file

@ -4,4 +4,5 @@ _old
uploads uploads
.ipynb_checkpoints .ipynb_checkpoints
*.db *.db
_test _test
Pipfile

View file

@ -8,7 +8,7 @@ import json
from apps.web.models.users import Users from apps.web.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import extract_token_from_auth_header from utils.utils import decode_token
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
app = Flask(__name__) app = Flask(__name__)
@ -34,8 +34,12 @@ def proxy(path):
# Basic RBAC support # Basic RBAC support
if WEBUI_AUTH: if WEBUI_AUTH:
if "Authorization" in headers: if "Authorization" in headers:
token = extract_token_from_auth_header(headers["Authorization"]) _, credentials = headers["Authorization"].split()
user = Users.get_user_by_token(token) token_data = decode_token(credentials)
if token_data is None or "email" not in token_data:
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
user = Users.get_user_by_email(token_data["email"])
if user: if user:
# Only user and admin roles can access # Only user and admin roles can access
if user.role in ["user", "admin"]: if user.role in ["user", "admin"]:

View file

@ -1,6 +1,6 @@
from fastapi import FastAPI, Request, Depends, HTTPException from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from apps.web.routers import auths, users, chats, modelfiles, utils from apps.web.routers import auths, users, chats, modelfiles, utils
from config import WEBUI_VERSION, WEBUI_AUTH from config import WEBUI_VERSION, WEBUI_AUTH
@ -16,13 +16,11 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
app.include_router(utils.router, prefix="/utils", tags=["utils"]) app.include_router(utils.router, prefix="/utils", tags=["utils"])

View file

@ -3,8 +3,6 @@ from peewee import *
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional from typing import List, Union, Optional
import time import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from apps.web.internal.db import DB from apps.web.internal.db import DB
@ -85,14 +83,6 @@ class UsersTable:
except: except:
return None return None
def get_user_by_token(self, token: str) -> Optional[UserModel]:
data = decode_token(token)
if data != None and "email" in data:
return self.get_user_by_email(data["email"])
else:
return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [ return [
UserModel(**model_to_dict(user)) UserModel(**model_to_dict(user))

View file

@ -19,11 +19,7 @@ from apps.web.models.auths import (
from apps.web.models.users import Users from apps.web.models.users import Users
from utils.utils import ( from utils.utils import get_password_hash, get_current_user, create_token
get_password_hash,
bearer_scheme,
create_token,
)
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -36,22 +32,14 @@ router = APIRouter()
@router.get("/", response_model=UserResponse) @router.get("/", response_model=UserResponse)
async def get_session_user(cred=Depends(bearer_scheme)): async def get_session_user(user=Depends(get_current_user)):
token = cred.credentials return {
user = Users.get_user_by_token(token) "id": user.id,
if user: "email": user.email,
return { "name": user.name,
"id": user.id, "role": user.role,
"email": user.email, "profile_image_url": user.profile_image_url,
"name": user.name, }
"role": user.role,
"profile_image_url": user.profile_image_url,
}
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -60,10 +48,9 @@ async def get_session_user(cred=Depends(bearer_scheme)):
@router.post("/update/password", response_model=bool) @router.post("/update/password", response_model=bool)
async def update_password(form_data: UpdatePasswordForm, cred=Depends(bearer_scheme)): async def update_password(
token = cred.credentials form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
session_user = Users.get_user_by_token(token) ):
if session_user: if session_user:
user = Auths.authenticate_user(session_user.email, form_data.password) user = Auths.authenticate_user(session_user.email, form_data.password)

View file

@ -1,8 +1,7 @@
from fastapi import Response from fastapi import Depends, Request, HTTPException, status
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
from utils.utils import get_current_user
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
@ -30,17 +29,10 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse]) @router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): async def get_user_chats(
token = cred.credentials user=Depends(get_current_user), skip: int = 0, limit: int = 50
user = Users.get_user_by_token(token) ):
return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
if user:
return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -49,20 +41,11 @@ async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
@router.get("/all", response_model=List[ChatResponse]) @router.get("/all", response_model=List[ChatResponse])
async def get_all_user_chats(cred=Depends(bearer_scheme)): async def get_all_user_chats(user=Depends(get_current_user)):
token = cred.credentials return [
user = Users.get_user_by_token(token) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats_by_user_id(user.id)
if user: ]
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats_by_user_id(user.id)
]
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -71,18 +54,9 @@ async def get_all_user_chats(cred=Depends(bearer_scheme)):
@router.post("/new", response_model=Optional[ChatResponse]) @router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)): async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
token = cred.credentials chat = Chats.insert_new_chat(user.id, form_data)
user = Users.get_user_by_token(token) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
if user:
chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -91,24 +65,14 @@ async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
@router.get("/{id}", response_model=Optional[ChatResponse]) @router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)): async def get_chat_by_id(id: str, user=Depends(get_current_user)):
token = cred.credentials chat = Chats.get_chat_by_id_and_user_id(id, user.id)
user = Users.get_user_by_token(token)
if user: if chat:
chat = Chats.get_chat_by_id_and_user_id(id, user.id) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
detail=ERROR_MESSAGES.INVALID_TOKEN,
) )
@ -118,26 +82,19 @@ async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
@router.post("/{id}", response_model=Optional[ChatResponse]) @router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_scheme)): async def update_chat_by_id(
token = cred.credentials id: str, form_data: ChatForm, user=Depends(get_current_user)
user = Users.get_user_by_token(token) ):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat}
if user: chat = Chats.update_chat_by_id(id, updated_chat)
chat = Chats.get_chat_by_id_and_user_id(id, user.id) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat}
chat = Chats.update_chat_by_id(id, updated_chat)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
@ -147,18 +104,9 @@ async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_sc
@router.delete("/{id}", response_model=bool) @router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)): async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
token = cred.credentials result = Chats.delete_chat_by_id_and_user_id(id, user.id)
user = Users.get_user_by_token(token) return result
if user:
result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -167,15 +115,6 @@ async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)):
@router.delete("/", response_model=bool) @router.delete("/", response_model=bool)
async def delete_all_user_chats(cred=Depends(bearer_scheme)): async def delete_all_user_chats(user=Depends(get_current_user)):
token = cred.credentials result = Chats.delete_chats_by_user_id(user.id)
user = Users.get_user_by_token(token) return result
if user:
result = Chats.delete_chats_by_user_id(user.id)
return result
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)

View file

@ -1,4 +1,3 @@
from fastapi import Response
from fastapi import Depends, FastAPI, HTTPException, status from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
@ -6,8 +5,6 @@ from typing import List, Union, Optional
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.web.models.users import Users
from apps.web.models.modelfiles import ( from apps.web.models.modelfiles import (
Modelfiles, Modelfiles,
ModelfileForm, ModelfileForm,
@ -16,9 +13,7 @@ from apps.web.models.modelfiles import (
ModelfileResponse, ModelfileResponse,
) )
from utils.utils import ( from utils.utils import get_current_user
bearer_scheme,
)
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
@ -29,17 +24,8 @@ router = APIRouter()
@router.get("/", response_model=List[ModelfileResponse]) @router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
token = cred.credentials return Modelfiles.get_modelfiles(skip, limit)
user = Users.get_user_by_token(token)
if user:
return Modelfiles.get_modelfiles(skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -48,36 +34,28 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
@router.post("/create", response_model=Optional[ModelfileResponse]) @router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_scheme)): async def create_new_modelfile(
token = cred.credentials form_data: ModelfileForm, user=Depends(get_current_user)
user = Users.get_user_by_token(token) ):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if user: modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
# Admin Only
if user.role == "admin":
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile: if modelfile:
return ModelfileResponse( return ModelfileResponse(
**{ **{
**modelfile.model_dump(), **modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile), "modelfile": json.loads(modelfile.modelfile),
} }
) )
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.DEFAULT(),
) )
@ -87,31 +65,20 @@ async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_sch
@router.post("/", response_model=Optional[ModelfileResponse]) @router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name( async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depends(get_current_user)):
form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme) modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
):
token = cred.credentials
user = Users.get_user_by_token(token)
if user: if modelfile:
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) return ModelfileResponse(
**{
if modelfile: **modelfile.model_dump(),
return ModelfileResponse( "modelfile": json.loads(modelfile.modelfile),
**{ }
**modelfile.model_dump(), )
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.NOT_FOUND,
) )
@ -122,44 +89,34 @@ async def get_modelfile_by_tag_name(
@router.post("/update", response_model=Optional[ModelfileResponse]) @router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name( async def update_modelfile_by_tag_name(
form_data: ModelfileUpdateForm, cred=Depends(bearer_scheme) form_data: ModelfileUpdateForm, user=Depends(get_current_user)
): ):
token = cred.credentials if user.role != "admin":
user = Users.get_user_by_token(token) raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
updated_modelfile = {
**json.loads(modelfile.modelfile),
**form_data.modelfile,
}
if user: modelfile = Modelfiles.update_modelfile_by_tag_name(
if user.role == "admin": form_data.tag_name, updated_modelfile
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) )
if modelfile:
updated_modelfile = {
**json.loads(modelfile.modelfile),
**form_data.modelfile,
}
modelfile = Modelfiles.update_modelfile_by_tag_name( return ModelfileResponse(
form_data.tag_name, updated_modelfile **{
) **modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
return ModelfileResponse( }
**{ )
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
@ -170,22 +127,13 @@ async def update_modelfile_by_tag_name(
@router.delete("/delete", response_model=bool) @router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name( async def delete_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme) form_data: ModelfileTagNameForm, user=Depends(get_current_user)
): ):
token = cred.credentials if user.role != "admin":
user = Users.get_user_by_token(token)
if user:
if user.role == "admin":
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result

View file

@ -12,11 +12,7 @@ from apps.web.models.users import UserModel, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths from apps.web.models.auths import Auths
from utils.utils import ( from utils.utils import get_current_user
get_password_hash,
bearer_scheme,
create_token,
)
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
@ -27,23 +23,13 @@ router = APIRouter()
@router.get("/", response_model=List[UserModel]) @router.get("/", response_model=List[UserModel])
async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
token = cred.credentials if user.role != "admin":
user = Users.get_user_by_token(token)
if user:
if user.role == "admin":
return Users.get_users(skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
return Users.get_users(skip, limit)
############################ ############################
@ -52,28 +38,21 @@ async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme))
@router.post("/update/role", response_model=Optional[UserModel]) @router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_scheme)): async def update_user_role(
token = cred.credentials form_data: UserRoleUpdateForm, user=Depends(get_current_user)
user = Users.get_user_by_token(token) ):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if user: if user.id != form_data.id:
if user.role == "admin": return Users.update_user_role_by_id(form_data.id, form_data.role)
if user.id != form_data.id:
return Users.update_user_role_by_id(form_data.id, form_data.role)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.ACTION_PROHIBITED,
) )
@ -83,34 +62,25 @@ async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_sc
@router.delete("/{user_id}", response_model=bool) @router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(user_id: str, cred=Depends(bearer_scheme)): async def delete_user_by_id(user_id: str, user=Depends(get_current_user)):
token = cred.credentials if user.role == "admin":
user = Users.get_user_by_token(token) if user.id != user_id:
result = Auths.delete_auth_by_id(user_id)
if user: if result:
if user.role == "admin": return True
if user.id != user_id:
result = Auths.delete_auth_by_id(user_id)
if result:
return True
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DELETE_USER_ERROR,
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.ACTION_PROHIBITED, detail=ERROR_MESSAGES.DELETE_USER_ERROR,
) )
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACTION_PROHIBITED,
) )
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )

View file

@ -18,3 +18,5 @@ bcrypt
PyJWT PyJWT
pyjwt[crypto] pyjwt[crypto]
black

View file

@ -1,7 +1,9 @@
from fastapi.security import HTTPBasicCredentials, HTTPBearer from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends
from apps.web.models.users import Users
from pydantic import BaseModel from pydantic import BaseModel
from typing import Union, Optional from typing import Union, Optional
from constants import ERROR_MESSAGES
from passlib.context import CryptContext from passlib.context import CryptContext
from datetime import datetime, timedelta from datetime import datetime, timedelta
import requests import requests
@ -53,16 +55,18 @@ def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :] return auth_header[len("Bearer ") :]
def verify_token(request): def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
try: data = decode_token(auth_token.credentials)
bearer = request.headers["authorization"] if data != None and "email" in data:
if bearer: user = Users.get_user_by_email(data["email"])
token = bearer[len("Bearer ") :] if user is None:
decoded = jwt.decode( raise HTTPException(
token, JWT_SECRET_KEY, options={"verify_signature": False} status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
) )
return decoded return user
else: else:
return None raise HTTPException(
except Exception as e: status_code=status.HTTP_401_UNAUTHORIZED,
return None detail=ERROR_MESSAGES.UNAUTHORIZED,
)