feat: multi-user support w/ RBAC

This commit is contained in:
Timothy J. Baek 2023-11-18 16:47:12 -08:00
parent 31e38df0a5
commit 921eef03b3
21 changed files with 1815 additions and 66 deletions

View file

@ -1,4 +1,4 @@
from flask import Flask, request, Response
from flask import Flask, request, Response, jsonify
from flask_cors import CORS
@ -6,7 +6,10 @@ import requests
import json
from config import OLLAMA_API_BASE_URL
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils import extract_token_from_auth_header
from config import OLLAMA_API_BASE_URL, OLLAMA_WEBUI_AUTH
app = Flask(__name__)
CORS(
@ -28,6 +31,21 @@ def proxy(path):
data = request.get_data()
headers = dict(request.headers)
if OLLAMA_WEBUI_AUTH:
if "Authorization" in headers:
token = extract_token_from_auth_header(headers["Authorization"])
user = Users.get_user_by_token(token)
if user:
print(user)
pass
else:
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
else:
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
else:
pass
# Make a request to the target server
target_response = requests.request(
method=request.method,

25
backend/apps/web/main.py Normal file
View file

@ -0,0 +1,25 @@
from fastapi import FastAPI, Request, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from apps.web.routers import auths
from config import OLLAMA_WEBUI_VERSION, OLLAMA_WEBUI_AUTH
app = FastAPI()
origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(auths.router, prefix="/auths", tags=["auths"])
@app.get("/")
async def get_status():
return {"status": True, "version": OLLAMA_WEBUI_VERSION, "auth": OLLAMA_WEBUI_AUTH}

View file

@ -0,0 +1,102 @@
from pydantic import BaseModel
from typing import List, Union, Optional
import time
import uuid
from apps.web.models.users import UserModel, Users
from utils import (
verify_password,
get_password_hash,
bearer_scheme,
create_token,
)
import config
DB = config.DB
####################
# DB MODEL
####################
class AuthModel(BaseModel):
id: str
email: str
password: str
active: bool = True
####################
# Forms
####################
class Token(BaseModel):
token: str
token_type: str
class UserResponse(BaseModel):
id: str
email: str
name: str
role: str
class SigninResponse(Token, UserResponse):
pass
class SigninForm(BaseModel):
email: str
password: str
class SignupForm(BaseModel):
name: str
email: str
password: str
class AuthsTable:
def __init__(self, db):
self.db = db
self.table = db.auths
def insert_new_auth(
self, email: str, password: str, name: str, role: str = "user"
) -> Optional[UserModel]:
print("insert_new_auth")
id = str(uuid.uuid4())
auth = AuthModel(
**{"id": id, "email": email, "password": password, "active": True}
)
result = self.table.insert_one(auth.model_dump())
user = Users.insert_new_user(id, name, email, role)
print(result, user)
if result and user:
return user
else:
return None
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
print("authenticate_user")
auth = self.table.find_one({"email": email, "active": True})
if auth:
if verify_password(password, auth["password"]):
user = self.db.users.find_one({"id": auth["id"]})
return UserModel(**user)
else:
return None
else:
return None
Auths = AuthsTable(DB)

View file

@ -0,0 +1,76 @@
from pydantic import BaseModel
from typing import List, Union, Optional
from pymongo import ReturnDocument
import time
from utils import decode_token
from config import DB
####################
# User DB Schema
####################
class UserModel(BaseModel):
id: str
name: str
email: str
role: str = "user"
created_at: int # timestamp in epoch
####################
# Forms
####################
class UsersTable:
def __init__(self, db):
self.db = db
self.table = db.users
def insert_new_user(
self, id: str, name: str, email: str, role: str = "user"
) -> Optional[UserModel]:
user = UserModel(
**{
"id": id,
"name": name,
"email": email,
"role": role,
"created_at": int(time.time()),
}
)
result = self.table.insert_one(user.model_dump())
if result:
return user
else:
return None
def get_user_by_email(self, email: str) -> Optional[UserModel]:
user = self.table.find_one({"email": email}, {"_id": False})
if user:
return UserModel(**user)
else:
return None
def get_user_by_token(self, token: str) -> Optional[UserModel]:
data = decode_token(token)
if data != None and "email" in data:
return self.get_user_by_email(data["email"])
else:
return None
def get_users(self, skip: int = 0, limit: int = 50) -> Optional[UserModel]:
return [
UserModel(**user)
for user in list(self.table.find({}, {"_id": False}))
.skip(skip)
.limit(limit)
]
Users = UsersTable(DB)

View file

@ -0,0 +1,107 @@
from fastapi import Response
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union
from fastapi import APIRouter
from pydantic import BaseModel
import time
import uuid
from constants import ERROR_MESSAGES
from utils import (
get_password_hash,
bearer_scheme,
create_token,
)
from apps.web.models.auths import (
SigninForm,
SignupForm,
UserResponse,
SigninResponse,
Auths,
)
from apps.web.models.users import Users
import config
router = APIRouter()
DB = config.DB
############################
# GetSessionUser
############################
@router.get("/", response_model=UserResponse)
async def get_session_user(cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
return {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
)
############################
# SignIn
############################
@router.post("/signin", response_model=SigninResponse)
async def signin(form_data: SigninForm):
user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
if user:
token = create_token(data={"email": user.email})
return {
"token": token,
"token_type": "Bearer",
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
else:
raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT())
############################
# SignUp
############################
@router.post("/signup", response_model=SigninResponse)
async def signup(form_data: SignupForm):
if not Users.get_user_by_email(form_data.email.lower()):
try:
hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth(form_data.email, hashed, form_data.name)
if user:
token = create_token(data={"email": user.email})
# response.set_cookie(key='token', value=token, httponly=True)
return {
"token": token,
"token_type": "Bearer",
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
else:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
else:
raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT())