Revert "Merge Updates & Dockerfile improvements" (#3)

This reverts commit 9763d885be.
This commit is contained in:
Jannik S 2024-04-02 11:28:04 +02:00 committed by GitHub
parent 9763d885be
commit 099b1d066b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
155 changed files with 4795 additions and 14501 deletions

View file

@ -1,16 +1,13 @@
from peewee import *
from config import SRC_LOG_LEVELS, DATA_DIR
from config import DATA_DIR
import os
import logging
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
# Check if the file exists
if os.path.exists(f"{DATA_DIR}/ollama.db"):
# Rename the file
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
log.info("File renamed successfully.")
print("File renamed successfully.")
else:
pass

View file

@ -19,7 +19,6 @@ from config import (
DEFAULT_USER_ROLE,
ENABLE_SIGNUP,
USER_PERMISSIONS,
WEBHOOK_URL,
)
app = FastAPI()
@ -33,7 +32,6 @@ app.state.DEFAULT_MODELS = DEFAULT_MODELS
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.USER_PERMISSIONS = USER_PERMISSIONS
app.state.WEBHOOK_URL = WEBHOOK_URL
app.add_middleware(

View file

@ -2,7 +2,6 @@ from pydantic import BaseModel
from typing import List, Union, Optional
import time
import uuid
import logging
from peewee import *
from apps.web.models.users import UserModel, Users
@ -10,11 +9,6 @@ from utils.utils import verify_password
from apps.web.internal.db import DB
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# DB MODEL
####################
@ -92,7 +86,7 @@ class AuthsTable:
def insert_new_auth(
self, email: str, password: str, name: str, role: str = "pending"
) -> Optional[UserModel]:
log.info("insert_new_auth")
print("insert_new_auth")
id = str(uuid.uuid4())
@ -109,7 +103,7 @@ class AuthsTable:
return None
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}")
print("authenticate_user", email)
try:
auth = Auth.get(Auth.email == email, Auth.active == True)
if auth:

View file

@ -95,6 +95,20 @@ class ChatTable:
except:
return None
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try:
query = Chat.update(
chat=json.dumps(chat),
title=chat["title"] if "title" in chat else "New Chat",
timestamp=int(time.time()),
).where(Chat.id == id)
query.execute()
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
except:
return None
def get_chat_lists_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]:

View file

@ -3,7 +3,6 @@ from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
import logging
from utils.utils import decode_token
from utils.misc import get_gravatar_url
@ -12,11 +11,6 @@ from apps.web.internal.db import DB
import json
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Documents DB Schema
####################
@ -124,7 +118,7 @@ class DocumentsTable:
doc = Document.get(Document.name == form_data.name)
return DocumentModel(**model_to_dict(doc))
except Exception as e:
log.exception(e)
print(e)
return None
def update_doc_content_by_name(
@ -144,7 +138,7 @@ class DocumentsTable:
doc = Document.get(Document.name == name)
return DocumentModel(**model_to_dict(doc))
except Exception as e:
log.exception(e)
print(e)
return None
def delete_doc_by_name(self, name: str) -> bool:

View file

@ -64,8 +64,8 @@ class ModelfilesTable:
self.db.create_tables([Modelfile])
def insert_new_modelfile(
self, user_id: str, form_data: ModelfileForm
) -> Optional[ModelfileModel]:
self, user_id: str,
form_data: ModelfileForm) -> Optional[ModelfileModel]:
if "tagName" in form_data.modelfile:
modelfile = ModelfileModel(
**{
@ -73,8 +73,7 @@ class ModelfilesTable:
"tag_name": form_data.modelfile["tagName"],
"modelfile": json.dumps(form_data.modelfile),
"timestamp": int(time.time()),
}
)
})
try:
result = Modelfile.create(**modelfile.model_dump())
@ -88,28 +87,29 @@ class ModelfilesTable:
else:
return None
def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]:
def get_modelfile_by_tag_name(self,
tag_name: str) -> Optional[ModelfileModel]:
try:
modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
return ModelfileModel(**model_to_dict(modelfile))
except:
return None
def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]:
def get_modelfiles(self,
skip: int = 0,
limit: int = 50) -> List[ModelfileResponse]:
return [
ModelfileResponse(
**{
**model_to_dict(modelfile),
"modelfile": json.loads(modelfile.modelfile),
}
)
for modelfile in Modelfile.select()
"modelfile":
json.loads(modelfile.modelfile),
}) for modelfile in Modelfile.select()
# .limit(limit).offset(skip)
]
def update_modelfile_by_tag_name(
self, tag_name: str, modelfile: dict
) -> Optional[ModelfileModel]:
self, tag_name: str, modelfile: dict) -> Optional[ModelfileModel]:
try:
query = Modelfile.update(
modelfile=json.dumps(modelfile),

View file

@ -52,9 +52,8 @@ class PromptsTable:
self.db = db
self.db.create_tables([Prompt])
def insert_new_prompt(
self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]:
def insert_new_prompt(self, user_id: str,
form_data: PromptForm) -> Optional[PromptModel]:
prompt = PromptModel(
**{
"user_id": user_id,
@ -62,8 +61,7 @@ class PromptsTable:
"title": form_data.title,
"content": form_data.content,
"timestamp": int(time.time()),
}
)
})
try:
result = Prompt.create(**prompt.model_dump())
@ -83,14 +81,13 @@ class PromptsTable:
def get_prompts(self) -> List[PromptModel]:
return [
PromptModel(**model_to_dict(prompt))
for prompt in Prompt.select()
PromptModel(**model_to_dict(prompt)) for prompt in Prompt.select()
# .limit(limit).offset(skip)
]
def update_prompt_by_command(
self, command: str, form_data: PromptForm
) -> Optional[PromptModel]:
self, command: str,
form_data: PromptForm) -> Optional[PromptModel]:
try:
query = Prompt.update(
title=form_data.title,

View file

@ -6,15 +6,9 @@ from playhouse.shortcuts import model_to_dict
import json
import uuid
import time
import logging
from apps.web.internal.db import DB
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tag DB Schema
####################
@ -179,7 +173,7 @@ class TagTable:
(ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)
)
res = query.execute() # Remove the rows, return number of rows removed.
log.debug(f"res: {res}")
print(res)
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
if tag_count == 0:
@ -191,7 +185,7 @@ class TagTable:
return True
except Exception as e:
log.error(f"delete_tag: {e}")
print("delete_tag", e)
return False
def delete_tag_by_tag_name_and_chat_id_and_user_id(
@ -204,7 +198,7 @@ class TagTable:
& (ChatIdTag.user_id == user_id)
)
res = query.execute() # Remove the rows, return number of rows removed.
log.debug(f"res: {res}")
print(res)
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
if tag_count == 0:
@ -216,7 +210,7 @@ class TagTable:
return True
except Exception as e:
log.error(f"delete_tag: {e}")
print("delete_tag", e)
return False
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:

View file

@ -27,8 +27,7 @@ from utils.utils import (
create_token,
)
from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from constants import ERROR_MESSAGES
router = APIRouter()
@ -156,17 +155,6 @@ async def signup(request: Request, form_data: SignupForm):
)
# response.set_cookie(key='token', value=token, httponly=True)
if request.app.state.WEBHOOK_URL:
post_webhook(
request.app.state.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
return {
"token": token,
"token_type": "Bearer",

View file

@ -5,7 +5,6 @@ from utils.utils import get_current_user, get_admin_user
from fastapi import APIRouter
from pydantic import BaseModel
import json
import logging
from apps.web.models.users import Users
from apps.web.models.chats import (
@ -28,11 +27,6 @@ from apps.web.models.tags import (
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
@ -84,7 +78,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
except Exception as e:
log.exception(e)
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
@ -101,7 +95,7 @@ async def get_all_tags(user=Depends(get_current_user)):
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)

View file

@ -10,12 +10,7 @@ import uuid
from apps.web.models.users import Users
from utils.utils import (
get_password_hash,
get_current_user,
get_admin_user,
create_token,
)
from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES
@ -48,6 +43,7 @@ async def set_global_default_models(
return request.app.state.DEFAULT_MODELS
@router.post("/default/suggestions", response_model=List[PromptSuggestion])
async def set_global_default_suggestions(
request: Request,

View file

@ -24,9 +24,9 @@ router = APIRouter()
@router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(
skip: int = 0, limit: int = 50, user=Depends(get_current_user)
):
async def get_modelfiles(skip: int = 0,
limit: int = 50,
user=Depends(get_current_user)):
return Modelfiles.get_modelfiles(skip, limit)
@ -36,16 +36,17 @@ async def get_modelfiles(
@router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_user)):
async def create_new_modelfile(form_data: ModelfileForm,
user=Depends(get_admin_user)):
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
"modelfile":
json.loads(modelfile.modelfile),
})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -59,18 +60,17 @@ async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_
@router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, user=Depends(get_current_user)
):
async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_current_user)):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
"modelfile":
json.loads(modelfile.modelfile),
})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -84,9 +84,8 @@ async def get_modelfile_by_tag_name(
@router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(
form_data: ModelfileUpdateForm, user=Depends(get_admin_user)
):
async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
user=Depends(get_admin_user)):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
updated_modelfile = {
@ -95,15 +94,14 @@ async def update_modelfile_by_tag_name(
}
modelfile = Modelfiles.update_modelfile_by_tag_name(
form_data.tag_name, updated_modelfile
)
form_data.tag_name, updated_modelfile)
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
"modelfile":
json.loads(modelfile.modelfile),
})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -117,8 +115,7 @@ async def update_modelfile_by_tag_name(
@router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, user=Depends(get_admin_user)
):
async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_admin_user)):
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result

View file

@ -7,7 +7,6 @@ from fastapi import APIRouter
from pydantic import BaseModel
import time
import uuid
import logging
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths
@ -15,11 +14,6 @@ from apps.web.models.auths import Auths
from utils.utils import get_current_user, get_password_hash, get_admin_user
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
@ -89,7 +83,7 @@ async def update_user_by_id(
if form_data.password:
hashed = get_password_hash(form_data.password)
log.debug(f"hashed: {hashed}")
print(hashed)
Auths.update_user_password_by_id(user_id, hashed)
Auths.update_email_by_id(user_id, form_data.email.lower())

View file

@ -21,6 +21,155 @@ from constants import ERROR_MESSAGES
router = APIRouter()
class UploadBlobForm(BaseModel):
filename: str
from urllib.parse import urlparse
def parse_huggingface_url(hf_url):
try:
# Parse the URL
parsed_url = urlparse(hf_url)
# Get the path and split it into components
path_components = parsed_url.path.split("/")
# Extract the desired output
user_repo = "/".join(path_components[1:3])
model_file = path_components[-1]
return model_file
except ValueError:
return None
async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
done = False
if os.path.exists(file_path):
current_size = os.path.getsize(file_path)
else:
current_size = 0
headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as response:
total_size = int(response.headers.get("content-length", 0)) + current_size
with open(file_path, "ab+") as file:
async for data in response.content.iter_chunked(chunk_size):
current_size += len(data)
file.write(data)
done = current_size == total_size
progress = round((current_size / total_size) * 100, 2)
yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
if done:
file.seek(0)
hashed = calculate_sha256(file)
file.seek(0)
url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
response = requests.post(url, data=file)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file_name,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise "Ollama: Could not create blob, Please try again."
@router.get("/download")
async def download(
url: str,
):
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
file_name = parse_huggingface_url(url)
if file_name:
file_path = f"{UPLOAD_DIR}/{file_name}"
return StreamingResponse(
download_file_stream(url, file_path, file_name),
media_type="text/event-stream",
)
else:
return None
@router.post("/upload")
def upload(file: UploadFile = File(...)):
file_path = f"{UPLOAD_DIR}/{file.filename}"
# Save file in chunks
with open(file_path, "wb+") as f:
for chunk in file.file:
f.write(chunk)
def file_process_stream():
total_size = os.path.getsize(file_path)
chunk_size = 1024 * 1024
try:
with open(file_path, "rb") as f:
total = 0
done = False
while not done:
chunk = f.read(chunk_size)
if not chunk:
done = True
continue
total += len(chunk)
progress = round((total / total_size) * 100, 2)
res = {
"progress": progress,
"total": total_size,
"completed": total,
}
yield f"data: {json.dumps(res)}\n\n"
if done:
f.seek(0)
hashed = calculate_sha256(f)
f.seek(0)
url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
response = requests.post(url, data=f)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file.filename,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise Exception(
"Ollama: Could not create blob, Please try again."
)
except Exception as e:
res = {"error": str(e)}
yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_process_stream(), media_type="text/event-stream")
@router.get("/gravatar")
async def get_gravatar(
email: str,