From 91743310255c7aac0db37ccb4790909582d7f6a9 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 25 Dec 2023 21:44:28 -0800 Subject: [PATCH] feat: db migration to sqlite --- backend/.gitignore | 5 +- backend/apps/ollama/main.py | 2 +- backend/apps/web/internal/db.py | 4 ++ backend/apps/web/main.py | 3 +- backend/apps/web/models/auths.py | 36 +++++++--- backend/apps/web/models/chats.py | 108 ++++++++++++++++++++++++++++++ backend/apps/web/models/users.py | 68 ++++++++++++------- backend/apps/web/routers/auths.py | 4 +- backend/apps/web/routers/chats.py | 100 +++++++++++++++++++++++++++ backend/config.py | 20 +----- backend/constants.py | 6 ++ backend/requirements.txt | 4 +- src/routes/auth/+page.svelte | 2 +- 13 files changed, 302 insertions(+), 60 deletions(-) create mode 100644 backend/apps/web/internal/db.py create mode 100644 backend/apps/web/models/chats.py create mode 100644 backend/apps/web/routers/chats.py diff --git a/backend/.gitignore b/backend/.gitignore index bbb8ba18..11f9256f 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -1,4 +1,7 @@ __pycache__ .env _old -uploads \ No newline at end of file +uploads +.ipynb_checkpoints +*.db +_test \ No newline at end of file diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index a2a09fc7..c7961ea7 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -25,7 +25,7 @@ TARGET_SERVER_URL = OLLAMA_API_BASE_URL def proxy(path): # Combine the base URL of the target server with the requested path target_url = f"{TARGET_SERVER_URL}/{path}" - print(path) + print(target_url) # Get data from the original request data = request.get_data() diff --git a/backend/apps/web/internal/db.py b/backend/apps/web/internal/db.py new file mode 100644 index 00000000..d2f7db95 --- /dev/null +++ b/backend/apps/web/internal/db.py @@ -0,0 +1,4 @@ +from peewee import * + +DB = SqliteDatabase("./ollama.db") +DB.connect() diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py index 854f1626..52238aae 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/web/main.py @@ -1,7 +1,7 @@ from fastapi import FastAPI, Request, Depends, HTTPException from fastapi.middleware.cors import CORSMiddleware -from apps.web.routers import auths, users, utils +from apps.web.routers import auths, users, chats, utils from config import WEBUI_VERSION, WEBUI_AUTH app = FastAPI() @@ -20,6 +20,7 @@ app.add_middleware( app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(utils.router, prefix="/utils", tags=["utils"]) +app.include_router(chats.router, prefix="/chats", tags=["chats"]) @app.get("/") diff --git a/backend/apps/web/models/auths.py b/backend/apps/web/models/auths.py index 41c82efd..aef80e2f 100644 --- a/backend/apps/web/models/auths.py +++ b/backend/apps/web/models/auths.py @@ -2,6 +2,7 @@ from pydantic import BaseModel from typing import List, Union, Optional import time import uuid +from peewee import * from apps.web.models.users import UserModel, Users @@ -12,15 +13,23 @@ from utils.utils import ( create_token, ) -import config - -DB = config.DB +from apps.web.internal.db import DB #################### # DB MODEL #################### +class Auth(Model): + id = CharField(unique=True) + email = CharField() + password = CharField() + active = BooleanField() + + class Meta: + database = DB + + class AuthModel(BaseModel): id: str email: str @@ -64,7 +73,7 @@ class SignupForm(BaseModel): class AuthsTable: def __init__(self, db): self.db = db - self.table = db.auths + self.db.create_tables([Auth]) def insert_new_auth( self, email: str, password: str, name: str, role: str = "pending" @@ -76,7 +85,9 @@ class AuthsTable: auth = AuthModel( **{"id": id, "email": email, "password": password, "active": True} ) - result = self.table.insert_one(auth.model_dump()) + result = Auth.create(**auth.model_dump()) + print(result) + user = Users.insert_new_user(id, name, email, role) print(result, user) @@ -86,14 +97,19 @@ class AuthsTable: return None def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: - print("authenticate_user") + print("authenticate_user", email) - auth = self.table.find_one({"email": email, "active": True}) + auth = Auth.get(Auth.email == email, Auth.active == True) + print(auth.email) if auth: - if verify_password(password, auth["password"]): - user = self.db.users.find_one({"id": auth["id"]}) - return UserModel(**user) + print(password, str(auth.password)) + print(verify_password(password, str(auth.password))) + if verify_password(password, auth.password): + user = Users.get_user_by_id(auth.id) + + print(user) + return user else: return None else: diff --git a/backend/apps/web/models/chats.py b/backend/apps/web/models/chats.py new file mode 100644 index 00000000..1105452f --- /dev/null +++ b/backend/apps/web/models/chats.py @@ -0,0 +1,108 @@ +from pydantic import BaseModel +from typing import List, Union, Optional +from peewee import * +from playhouse.shortcuts import model_to_dict + + +import json +import uuid +import time + +from apps.web.internal.db import DB + + +#################### +# Chat DB Schema +#################### + + +class Chat(Model): + id = CharField(unique=True) + user_id: CharField() + title = CharField() + chat = TextField() # Save Chat JSON as Text + timestamp = DateField() + + class Meta: + database = DB + + +class ChatModel(BaseModel): + id: str + user_id: str + title: str + chat: dict + timestamp: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class ChatForm(BaseModel): + chat: dict + + +class ChatUpdateForm(ChatForm): + id: str + + +class ChatTitleIdResponse(BaseModel): + id: str + title: str + + +class ChatTable: + def __init__(self, db): + self.db = db + db.create_tables([Chat]) + + def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: + id = str(uuid.uuid4()) + chat = ChatModel( + **{ + "id": id, + "user_id": user_id, + "title": form_data.chat["title"], + "chat": json.dump(form_data.chat), + "timestamp": int(time.time()), + } + ) + + result = Chat.create(**chat.model_dump()) + return chat if result else None + + def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: + try: + query = Chat.update(chat=json.dump(chat)).where(Chat.id == id) + query.execute() + + chat = Chat.get(Chat.id == id) + return ChatModel(**model_to_dict(chat)) + except: + return None + + def get_chat_lists_by_user_id( + self, user_id: str, skip: int = 0, limit: int = 50 + ) -> List[ChatModel]: + return [ + ChatModel(**model_to_dict(chat)) + for chat in Chat.select(Chat.user_id == user_id).limit(limit).offset(skip) + ] + + def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: + try: + chat = Chat.get(Chat.id == id, Chat.user_id == user_id) + return ChatModel(**model_to_dict(chat)) + except: + return None + + def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]: + return [ + ChatModel(**model_to_dict(chat)) + for chat in Chat.select().limit(limit).offset(skip) + ] + + +Chats = ChatTable(DB) diff --git a/backend/apps/web/models/users.py b/backend/apps/web/models/users.py index 4dc3fc7a..88414999 100644 --- a/backend/apps/web/models/users.py +++ b/backend/apps/web/models/users.py @@ -1,25 +1,41 @@ from pydantic import BaseModel +from peewee import * +from playhouse.shortcuts import model_to_dict from typing import List, Union, Optional -from pymongo import ReturnDocument import time from utils.utils import decode_token from utils.misc import get_gravatar_url -from config import DB +from apps.web.internal.db import DB #################### # User DB Schema #################### +class User(Model): + id = CharField(unique=True) + name = CharField() + email = CharField() + role = CharField() + profile_image_url = CharField() + timestamp = DateField() + + class Meta: + database = DB + + class UserModel(BaseModel): + class Config: + orm_mode = True + id: str name: str email: str role: str = "pending" profile_image_url: str = "/user.png" - created_at: int # timestamp in epoch + timestamp: int # timestamp in epoch #################### @@ -35,7 +51,7 @@ class UserRoleUpdateForm(BaseModel): class UsersTable: def __init__(self, db): self.db = db - self.table = db.users + self.db.create_tables([User]) def insert_new_user( self, id: str, name: str, email: str, role: str = "pending" @@ -47,22 +63,27 @@ class UsersTable: "email": email, "role": role, "profile_image_url": get_gravatar_url(email), - "created_at": int(time.time()), + "timestamp": int(time.time()), } ) - result = self.table.insert_one(user.model_dump()) - + result = User.create(**user.model_dump()) if result: return user else: return None - def get_user_by_email(self, email: str) -> Optional[UserModel]: - user = self.table.find_one({"email": email}, {"_id": False}) + def get_user_by_id(self, id: str) -> Optional[UserModel]: + try: + user = User.get(User.id == id) + return UserModel(**model_to_dict(user)) + except: + return None - if user: - return UserModel(**user) - else: + def get_user_by_email(self, email: str) -> Optional[UserModel]: + try: + user = User.get(User.email == email) + return UserModel(**model_to_dict(user)) + except: return None def get_user_by_token(self, token: str) -> Optional[UserModel]: @@ -75,23 +96,22 @@ class UsersTable: def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: return [ - UserModel(**user) - for user in list( - self.table.find({}, {"_id": False}).skip(skip).limit(limit) - ) + UserModel(**model_to_dict(user)) + for user in User.select().limit(limit).offset(skip) ] def get_num_users(self) -> Optional[int]: - return self.table.count_documents({}) - - def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: - user = self.table.find_one_and_update( - {"id": id}, {"$set": updated}, return_document=ReturnDocument.AFTER - ) - return UserModel(**user) + return User.select().count() def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: - return self.update_user_by_id(id, {"role": role}) + try: + query = User.update(role=role).where(User.id == id) + query.execute() + + user = User.get(User.id == id) + return UserModel(**model_to_dict(user)) + except: + return None Users = UsersTable(DB) diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py index 023e1914..0fb34d47 100644 --- a/backend/apps/web/routers/auths.py +++ b/backend/apps/web/routers/auths.py @@ -104,8 +104,8 @@ async def signup(form_data: SignupForm): "profile_image_url": user.profile_image_url, } else: - raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) + raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT()) except Exception as err: raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) else: - raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT()) + raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) diff --git a/backend/apps/web/routers/chats.py b/backend/apps/web/routers/chats.py new file mode 100644 index 00000000..a29f2da5 --- /dev/null +++ b/backend/apps/web/routers/chats.py @@ -0,0 +1,100 @@ +from fastapi import Response +from fastapi import Depends, FastAPI, HTTPException, status +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel + +from apps.web.models.users import Users +from apps.web.models.chats import ( + ChatModel, + ChatForm, + ChatUpdateForm, + ChatTitleIdResponse, + Chats, +) + +from utils.utils import ( + bearer_scheme, +) +from constants import ERROR_MESSAGES + +router = APIRouter() + +############################ +# GetChats +############################ + + +@router.get("/", response_model=List[ChatTitleIdResponse]) +async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): + token = cred.credentials + user = Users.get_user_by_token(token) + + if user: + return Chats.get_chat_titles_and_ids_by_user_id(user.id, skip, limit) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.INVALID_TOKEN, + ) + + +############################ +# CreateNewChat +############################ + + +@router.post("/new", response_model=Optional[ChatModel]) +async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)): + token = cred.credentials + user = Users.get_user_by_token(token) + + if user: + return Chats.insert_new_chat(user.id, form_data) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.INVALID_TOKEN, + ) + + +############################ +# GetChatById +############################ + + +@router.get("/{id}", response_model=Optional[ChatModel]) +async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)): + token = cred.credentials + user = Users.get_user_by_token(token) + + if user: + return Chats.get_chat_by_id_and_user_id(id, user.id) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.INVALID_TOKEN, + ) + + +############################ +# UpdateChatById +############################ + + +@router.post("/{id}", response_model=Optional[ChatModel]) +async def update_chat_by_id( + id: str, form_data: ChatUpdateForm, cred=Depends(bearer_scheme) +): + token = cred.credentials + user = Users.get_user_by_token(token) + + if user: + return Chats.update_chat_by_id_and_user_id(id, user.id, form_data.chat) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.INVALID_TOKEN, + ) diff --git a/backend/config.py b/backend/config.py index cf9eae02..e0014bd8 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1,9 +1,10 @@ from dotenv import load_dotenv, find_dotenv -from pymongo import MongoClient + from constants import ERROR_MESSAGES from secrets import token_bytes from base64 import b64encode + import os load_dotenv(find_dotenv("../.env")) @@ -36,25 +37,8 @@ WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.40") # WEBUI_AUTH #################################### - WEBUI_AUTH = True if os.environ.get("WEBUI_AUTH", "FALSE") == "TRUE" else False - -#################################### -# WEBUI_DB (Deprecated, Should be removed) -#################################### - - -WEBUI_DB_URL = os.environ.get("WEBUI_DB_URL", "mongodb://root:root@localhost:27017/") - -if WEBUI_AUTH and WEBUI_DB_URL == "": - raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) - - -DB_CLIENT = MongoClient(f"{WEBUI_DB_URL}?authSource=admin") -DB = DB_CLIENT["ollama-webui"] - - #################################### # WEBUI_JWT_SECRET_KEY #################################### diff --git a/backend/constants.py b/backend/constants.py index b383957b..8301ea0b 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -11,6 +11,12 @@ class ERROR_MESSAGES(str, Enum): DEFAULT = lambda err="": f"Something went wrong :/\n{err if err else ''}" ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now." + + EMAIL_TAKEN = "Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew." + USERNAME_TAKEN = ( + "Uh-oh! This username is already registered. Please choose another username." + ) + INVALID_TOKEN = ( "Your session has expired or the token is invalid. Please sign in again." ) diff --git a/backend/requirements.txt b/backend/requirements.txt index 1568f7bb..2644d559 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -13,8 +13,8 @@ uuid requests aiohttp -pymongo +peewee bcrypt PyJWT -pyjwt[crypto] \ No newline at end of file +pyjwt[crypto] diff --git a/src/routes/auth/+page.svelte b/src/routes/auth/+page.svelte index a3d33f2f..9ec0b16f 100644 --- a/src/routes/auth/+page.svelte +++ b/src/routes/auth/+page.svelte @@ -66,7 +66,7 @@ if (res) { console.log(res); - toast.success(`Account creation successful."`); + toast.success(`Account creation successful.`); localStorage.token = res.token; await user.set(res); goto('/');