Merge Updates & Dockerfile improvements

lainedfles 2024-04-02 03:25:20 -06:00 committed by GitHub
parent fdef2abdfb
commit 9763d885be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
155 changed files with 14509 additions and 4803 deletions

View file

@@ -1,13 +1,16 @@
from peewee import *
from config import DATA_DIR
from config import SRC_LOG_LEVELS, DATA_DIR
import os
import logging
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
# Check if the file exists
if os.path.exists(f"{DATA_DIR}/ollama.db"):
# Rename the file
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
print("File renamed successfully.")
log.info("File renamed successfully.")
else:
pass
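
Throughout this commit the same logging pattern replaces bare print() calls: each module gets a logger via logging.getLogger(__name__) and sets its level from SRC_LOG_LEVELS in config. A minimal sketch of how SRC_LOG_LEVELS might be assembled (the environment-variable names and subsystem list below are assumptions, not the actual config.py):

import logging
import os

# Hypothetical sketch; the real config.py may name and populate these differently.
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "INFO").upper()
if not hasattr(logging, GLOBAL_LOG_LEVEL):
    GLOBAL_LOG_LEVEL = "INFO"
logging.basicConfig(level=GLOBAL_LOG_LEVEL)

# One entry per subsystem; each can be overridden individually, e.g. DB_LOG_LEVEL=DEBUG.
SRC_LOG_LEVELS = {}
for source in ["DB", "MODELS", "MAIN", "RAG", "OLLAMA", "WEBHOOK"]:
    level = os.environ.get(f"{source}_LOG_LEVEL", GLOBAL_LOG_LEVEL).upper()
    SRC_LOG_LEVELS[source] = level if hasattr(logging, level) else GLOBAL_LOG_LEVEL

Modules then pick their level with log.setLevel(SRC_LOG_LEVELS["DB"]) or SRC_LOG_LEVELS["MODELS"], as the hunks below show.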

View file

@@ -19,6 +19,7 @@ from config import (
DEFAULT_USER_ROLE,
ENABLE_SIGNUP,
USER_PERMISSIONS,
WEBHOOK_URL,
)
app = FastAPI()
@@ -32,6 +33,7 @@ app.state.DEFAULT_MODELS = DEFAULT_MODELS
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.USER_PERMISSIONS = USER_PERMISSIONS
app.state.WEBHOOK_URL = WEBHOOK_URL
app.add_middleware(

View file

@@ -2,6 +2,7 @@ from pydantic import BaseModel
from typing import List, Union, Optional
import time
import uuid
import logging
from peewee import *
from apps.web.models.users import UserModel, Users
@@ -9,6 +10,11 @@ from utils.utils import verify_password
from apps.web.internal.db import DB
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# DB MODEL
####################
@@ -86,7 +92,7 @@ class AuthsTable:
def insert_new_auth(
self, email: str, password: str, name: str, role: str = "pending"
) -> Optional[UserModel]:
print("insert_new_auth")
log.info("insert_new_auth")
id = str(uuid.uuid4())
@@ -103,7 +109,7 @@ class AuthsTable:
return None
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
print("authenticate_user", email)
log.info(f"authenticate_user: {email}")
try:
auth = Auth.get(Auth.email == email, Auth.active == True)
if auth:

View file

@@ -95,20 +95,6 @@ class ChatTable:
except:
return None
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try:
query = Chat.update(
chat=json.dumps(chat),
title=chat["title"] if "title" in chat else "New Chat",
timestamp=int(time.time()),
).where(Chat.id == id)
query.execute()
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
except:
return None
def get_chat_lists_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]:

View file

@@ -3,6 +3,7 @@ from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
import logging
from utils.utils import decode_token
from utils.misc import get_gravatar_url
@@ -11,6 +12,11 @@ from apps.web.internal.db import DB
import json
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Documents DB Schema
####################
@@ -118,7 +124,7 @@ class DocumentsTable:
doc = Document.get(Document.name == form_data.name)
return DocumentModel(**model_to_dict(doc))
except Exception as e:
print(e)
log.exception(e)
return None
def update_doc_content_by_name(
@@ -138,7 +144,7 @@ class DocumentsTable:
doc = Document.get(Document.name == name)
return DocumentModel(**model_to_dict(doc))
except Exception as e:
print(e)
log.exception(e)
return None
def delete_doc_by_name(self, name: str) -> bool:

View file

@@ -64,8 +64,8 @@ class ModelfilesTable:
self.db.create_tables([Modelfile])
def insert_new_modelfile(
self, user_id: str,
form_data: ModelfileForm) -> Optional[ModelfileModel]:
self, user_id: str, form_data: ModelfileForm
) -> Optional[ModelfileModel]:
if "tagName" in form_data.modelfile:
modelfile = ModelfileModel(
**{
@@ -73,7 +73,8 @@ class ModelfilesTable:
"tag_name": form_data.modelfile["tagName"],
"modelfile": json.dumps(form_data.modelfile),
"timestamp": int(time.time()),
})
}
)
try:
result = Modelfile.create(**modelfile.model_dump())
@@ -87,29 +88,28 @@ class ModelfilesTable:
else:
return None
def get_modelfile_by_tag_name(self,
tag_name: str) -> Optional[ModelfileModel]:
def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]:
try:
modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
return ModelfileModel(**model_to_dict(modelfile))
except:
return None
def get_modelfiles(self,
skip: int = 0,
limit: int = 50) -> List[ModelfileResponse]:
def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]:
return [
ModelfileResponse(
**{
**model_to_dict(modelfile),
"modelfile":
json.loads(modelfile.modelfile),
}) for modelfile in Modelfile.select()
"modelfile": json.loads(modelfile.modelfile),
}
)
for modelfile in Modelfile.select()
# .limit(limit).offset(skip)
]
def update_modelfile_by_tag_name(
self, tag_name: str, modelfile: dict) -> Optional[ModelfileModel]:
self, tag_name: str, modelfile: dict
) -> Optional[ModelfileModel]:
try:
query = Modelfile.update(
modelfile=json.dumps(modelfile),

View file

@@ -52,8 +52,9 @@ class PromptsTable:
self.db = db
self.db.create_tables([Prompt])
def insert_new_prompt(self, user_id: str,
form_data: PromptForm) -> Optional[PromptModel]:
def insert_new_prompt(
self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]:
prompt = PromptModel(
**{
"user_id": user_id,
@@ -61,7 +62,8 @@ class PromptsTable:
"title": form_data.title,
"content": form_data.content,
"timestamp": int(time.time()),
})
}
)
try:
result = Prompt.create(**prompt.model_dump())
@@ -81,13 +83,14 @@ class PromptsTable:
def get_prompts(self) -> List[PromptModel]:
return [
PromptModel(**model_to_dict(prompt)) for prompt in Prompt.select()
PromptModel(**model_to_dict(prompt))
for prompt in Prompt.select()
# .limit(limit).offset(skip)
]
def update_prompt_by_command(
self, command: str,
form_data: PromptForm) -> Optional[PromptModel]:
self, command: str, form_data: PromptForm
) -> Optional[PromptModel]:
try:
query = Prompt.update(
title=form_data.title,

View file

@@ -6,9 +6,15 @@ from playhouse.shortcuts import model_to_dict
import json
import uuid
import time
import logging
from apps.web.internal.db import DB
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tag DB Schema
####################
@@ -173,7 +179,7 @@ class TagTable:
(ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)
)
res = query.execute() # Remove the rows, return number of rows removed.
print(res)
log.debug(f"res: {res}")
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
if tag_count == 0:
@@ -185,7 +191,7 @@ class TagTable:
return True
except Exception as e:
print("delete_tag", e)
log.error(f"delete_tag: {e}")
return False
def delete_tag_by_tag_name_and_chat_id_and_user_id(
@@ -198,7 +204,7 @@ class TagTable:
& (ChatIdTag.user_id == user_id)
)
res = query.execute() # Remove the rows, return number of rows removed.
print(res)
log.debug(f"res: {res}")
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
if tag_count == 0:
@@ -210,7 +216,7 @@ class TagTable:
return True
except Exception as e:
print("delete_tag", e)
log.error(f"delete_tag: {e}")
return False
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:

View file

@@ -27,7 +27,8 @@ from utils.utils import (
create_token,
)
from utils.misc import parse_duration, validate_email_format
from constants import ERROR_MESSAGES
from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
router = APIRouter()
@@ -155,6 +156,17 @@ async def signup(request: Request, form_data: SignupForm):
)
# response.set_cookie(key='token', value=token, httponly=True)
if request.app.state.WEBHOOK_URL:
post_webhook(
request.app.state.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
return {
"token": token,
"token_type": "Bearer",

View file

@@ -5,6 +5,7 @@ from utils.utils import get_current_user, get_admin_user
from fastapi import APIRouter
from pydantic import BaseModel
import json
import logging
from apps.web.models.users import Users
from apps.web.models.chats import (
@@ -27,6 +28,11 @@ from apps.web.models.tags import (
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
@@ -78,7 +84,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
except Exception as e:
print(e)
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
@@ -95,7 +101,7 @@ async def get_all_tags(user=Depends(get_current_user)):
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
print(e)
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)

View file

@@ -10,7 +10,12 @@ import uuid
from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
from utils.utils import (
get_password_hash,
get_current_user,
get_admin_user,
create_token,
)
from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES
@@ -43,7 +48,6 @@ async def set_global_default_models(
return request.app.state.DEFAULT_MODELS
@router.post("/default/suggestions", response_model=List[PromptSuggestion])
async def set_global_default_suggestions(
request: Request,

View file

@@ -24,9 +24,9 @@ router = APIRouter()
@router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(skip: int = 0,
limit: int = 50,
user=Depends(get_current_user)):
async def get_modelfiles(
skip: int = 0, limit: int = 50, user=Depends(get_current_user)
):
return Modelfiles.get_modelfiles(skip, limit)
@@ -36,17 +36,16 @@ async def get_modelfiles(skip: int = 0,
@router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm,
user=Depends(get_admin_user)):
async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_user)):
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile":
json.loads(modelfile.modelfile),
})
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -60,17 +59,18 @@ async def create_new_modelfile(form_data: ModelfileForm,
@router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_current_user)):
async def get_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, user=Depends(get_current_user)
):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile":
json.loads(modelfile.modelfile),
})
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -84,8 +84,9 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
@router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
user=Depends(get_admin_user)):
async def update_modelfile_by_tag_name(
form_data: ModelfileUpdateForm, user=Depends(get_admin_user)
):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
updated_modelfile = {
@@ -94,14 +95,15 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
}
modelfile = Modelfiles.update_modelfile_by_tag_name(
form_data.tag_name, updated_modelfile)
form_data.tag_name, updated_modelfile
)
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile":
json.loads(modelfile.modelfile),
})
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -115,7 +117,8 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
@router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_admin_user)):
async def delete_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, user=Depends(get_admin_user)
):
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result

View file

@@ -7,6 +7,7 @@ from fastapi import APIRouter
from pydantic import BaseModel
import time
import uuid
import logging
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths
@@ -14,6 +15,11 @@ from apps.web.models.auths import Auths
from utils.utils import get_current_user, get_password_hash, get_admin_user
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
@@ -83,7 +89,7 @@ async def update_user_by_id(
if form_data.password:
hashed = get_password_hash(form_data.password)
print(hashed)
log.debug(f"hashed: {hashed}")
Auths.update_user_password_by_id(user_id, hashed)
Auths.update_email_by_id(user_id, form_data.email.lower())

View file

@@ -21,155 +21,6 @@ from constants import ERROR_MESSAGES
router = APIRouter()
class UploadBlobForm(BaseModel):
filename: str
from urllib.parse import urlparse
def parse_huggingface_url(hf_url):
try:
# Parse the URL
parsed_url = urlparse(hf_url)
# Get the path and split it into components
path_components = parsed_url.path.split("/")
# Extract the desired output
user_repo = "/".join(path_components[1:3])
model_file = path_components[-1]
return model_file
except ValueError:
return None
async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
done = False
if os.path.exists(file_path):
current_size = os.path.getsize(file_path)
else:
current_size = 0
headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as response:
total_size = int(response.headers.get("content-length", 0)) + current_size
with open(file_path, "ab+") as file:
async for data in response.content.iter_chunked(chunk_size):
current_size += len(data)
file.write(data)
done = current_size == total_size
progress = round((current_size / total_size) * 100, 2)
yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
if done:
file.seek(0)
hashed = calculate_sha256(file)
file.seek(0)
url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
response = requests.post(url, data=file)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file_name,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise "Ollama: Could not create blob, Please try again."
@router.get("/download")
async def download(
url: str,
):
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
file_name = parse_huggingface_url(url)
if file_name:
file_path = f"{UPLOAD_DIR}/{file_name}"
return StreamingResponse(
download_file_stream(url, file_path, file_name),
media_type="text/event-stream",
)
else:
return None
@router.post("/upload")
def upload(file: UploadFile = File(...)):
file_path = f"{UPLOAD_DIR}/{file.filename}"
# Save file in chunks
with open(file_path, "wb+") as f:
for chunk in file.file:
f.write(chunk)
def file_process_stream():
total_size = os.path.getsize(file_path)
chunk_size = 1024 * 1024
try:
with open(file_path, "rb") as f:
total = 0
done = False
while not done:
chunk = f.read(chunk_size)
if not chunk:
done = True
continue
total += len(chunk)
progress = round((total / total_size) * 100, 2)
res = {
"progress": progress,
"total": total_size,
"completed": total,
}
yield f"data: {json.dumps(res)}\n\n"
if done:
f.seek(0)
hashed = calculate_sha256(f)
f.seek(0)
url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
response = requests.post(url, data=f)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file.filename,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise Exception(
"Ollama: Could not create blob, Please try again."
)
except Exception as e:
res = {"error": str(e)}
yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_process_stream(), media_type="text/event-stream")
@router.get("/gravatar")
async def get_gravatar(
email: str,