Merge branch 'dev' into feat/trusted-email-header

This commit is contained in:
Jun Siang Cheah 2024-03-31 22:08:26 +01:00
commit 562e40a7bd
58 changed files with 2915 additions and 2152 deletions

View file

@ -22,7 +22,13 @@ from utils.utils import (
)
from utils.misc import calculate_sha256
from config import SRC_LOG_LEVELS, CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR
from config import (
SRC_LOG_LEVELS,
CACHE_DIR,
UPLOAD_DIR,
WHISPER_MODEL,
WHISPER_MODEL_DIR,
)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])

View file

@ -325,7 +325,7 @@ def save_url_image(url):
return image_id
except Exception as e:
print(f"Error saving image: {e}")
log.exception(f"Error saving image: {e}")
return None
@ -397,7 +397,7 @@ def generate_image(
user.id,
app.state.COMFYUI_BASE_URL,
)
print(res)
log.debug(f"res: {res}")
images = []
@ -409,7 +409,7 @@ def generate_image(
with open(file_body_path, "w") as f:
json.dump(data.model_dump(exclude_none=True), f)
print(images)
log.debug(f"images: {images}")
return images
else:
if form_data.model:

View file

@ -4,6 +4,12 @@ import json
import urllib.request
import urllib.parse
import random
import logging
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
from pydantic import BaseModel
@ -121,7 +127,7 @@ COMFYUI_DEFAULT_PROMPT = """
def queue_prompt(prompt, client_id, base_url):
print("queue_prompt")
log.info("queue_prompt")
p = {"prompt": prompt, "client_id": client_id}
data = json.dumps(p).encode("utf-8")
req = urllib.request.Request(f"{base_url}/prompt", data=data)
@ -129,7 +135,7 @@ def queue_prompt(prompt, client_id, base_url):
def get_image(filename, subfolder, folder_type, base_url):
print("get_image")
log.info("get_image")
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response:
@ -137,14 +143,14 @@ def get_image(filename, subfolder, folder_type, base_url):
def get_image_url(filename, subfolder, folder_type, base_url):
print("get_image")
log.info("get_image")
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
return f"{base_url}/view?{url_values}"
def get_history(prompt_id, base_url):
print("get_history")
log.info("get_history")
with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response:
return json.loads(response.read())
@ -212,15 +218,15 @@ def comfyui_generate_image(
try:
ws = websocket.WebSocket()
ws.connect(f"ws://{host}/ws?clientId={client_id}")
print("WebSocket connection established.")
log.info("WebSocket connection established.")
except Exception as e:
print(f"Failed to connect to WebSocket server: {e}")
log.exception(f"Failed to connect to WebSocket server: {e}")
return None
try:
images = get_images(ws, comfyui_prompt, client_id, base_url)
except Exception as e:
print(f"Error while receiving images: {e}")
log.exception(f"Error while receiving images: {e}")
images = None
ws.close()

View file

@ -33,7 +33,13 @@ from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user, get_admin_user
from config import SRC_LOG_LEVELS, OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, UPLOAD_DIR
from config import (
SRC_LOG_LEVELS,
OLLAMA_BASE_URLS,
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
UPLOAD_DIR,
)
from utils.misc import calculate_sha256
log = logging.getLogger(__name__)
@ -266,7 +272,7 @@ async def pull_model(
if request_id in REQUEST_POOL:
yield chunk
else:
print("User: canceled request")
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
@ -664,7 +670,7 @@ async def generate_completion(
else:
raise HTTPException(
status_code=400,
detail="error_detail",
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
@ -770,7 +776,11 @@ async def generate_chat_completion(
r = None
log.debug("form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(form_data.model_dump_json(exclude_none=True).encode()))
log.debug(
"form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
form_data.model_dump_json(exclude_none=True).encode()
)
)
def get_request():
nonlocal form_data

View file

@ -333,7 +333,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
if overwrite:
for collection in CHROMA_CLIENT.list_collections():
if collection_name == collection.name:
print(f"deleting existing collection {collection_name}")
log.info(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name)
collection = CHROMA_CLIENT.create_collection(
@ -346,7 +346,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
)
return True
except Exception as e:
print(e)
log.exception(e)
if e.__class__.__name__ == "UniqueConstraintError":
return True
@ -575,7 +575,7 @@ def scan_docs_dir(user=Depends(get_admin_user)):
),
)
except Exception as e:
print(e)
log.exception(e)
pass
except Exception as e:

View file

@ -11,6 +11,7 @@ from utils.utils import verify_password
from apps.web.internal.db import DB
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

View file

@ -13,6 +13,7 @@ from apps.web.internal.db import DB
import json
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

View file

@ -64,8 +64,8 @@ class ModelfilesTable:
self.db.create_tables([Modelfile])
def insert_new_modelfile(
self, user_id: str,
form_data: ModelfileForm) -> Optional[ModelfileModel]:
self, user_id: str, form_data: ModelfileForm
) -> Optional[ModelfileModel]:
if "tagName" in form_data.modelfile:
modelfile = ModelfileModel(
**{
@ -73,7 +73,8 @@ class ModelfilesTable:
"tag_name": form_data.modelfile["tagName"],
"modelfile": json.dumps(form_data.modelfile),
"timestamp": int(time.time()),
})
}
)
try:
result = Modelfile.create(**modelfile.model_dump())
@ -87,29 +88,28 @@ class ModelfilesTable:
else:
return None
def get_modelfile_by_tag_name(self,
tag_name: str) -> Optional[ModelfileModel]:
def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]:
try:
modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
return ModelfileModel(**model_to_dict(modelfile))
except:
return None
def get_modelfiles(self,
skip: int = 0,
limit: int = 50) -> List[ModelfileResponse]:
def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]:
return [
ModelfileResponse(
**{
**model_to_dict(modelfile),
"modelfile":
json.loads(modelfile.modelfile),
}) for modelfile in Modelfile.select()
"modelfile": json.loads(modelfile.modelfile),
}
)
for modelfile in Modelfile.select()
# .limit(limit).offset(skip)
]
def update_modelfile_by_tag_name(
self, tag_name: str, modelfile: dict) -> Optional[ModelfileModel]:
self, tag_name: str, modelfile: dict
) -> Optional[ModelfileModel]:
try:
query = Modelfile.update(
modelfile=json.dumps(modelfile),

View file

@ -52,8 +52,9 @@ class PromptsTable:
self.db = db
self.db.create_tables([Prompt])
def insert_new_prompt(self, user_id: str,
form_data: PromptForm) -> Optional[PromptModel]:
def insert_new_prompt(
self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]:
prompt = PromptModel(
**{
"user_id": user_id,
@ -61,7 +62,8 @@ class PromptsTable:
"title": form_data.title,
"content": form_data.content,
"timestamp": int(time.time()),
})
}
)
try:
result = Prompt.create(**prompt.model_dump())
@ -81,13 +83,14 @@ class PromptsTable:
def get_prompts(self) -> List[PromptModel]:
return [
PromptModel(**model_to_dict(prompt)) for prompt in Prompt.select()
PromptModel(**model_to_dict(prompt))
for prompt in Prompt.select()
# .limit(limit).offset(skip)
]
def update_prompt_by_command(
self, command: str,
form_data: PromptForm) -> Optional[PromptModel]:
self, command: str, form_data: PromptForm
) -> Optional[PromptModel]:
try:
query = Prompt.update(
title=form_data.title,

View file

@ -11,6 +11,7 @@ import logging
from apps.web.internal.db import DB
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

View file

@ -29,6 +29,7 @@ from apps.web.models.tags import (
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

View file

@ -10,7 +10,12 @@ import uuid
from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
from utils.utils import (
get_password_hash,
get_current_user,
get_admin_user,
create_token,
)
from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES
@ -43,7 +48,6 @@ async def set_global_default_models(
return request.app.state.DEFAULT_MODELS
@router.post("/default/suggestions", response_model=List[PromptSuggestion])
async def set_global_default_suggestions(
request: Request,

View file

@ -24,9 +24,9 @@ router = APIRouter()
@router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(skip: int = 0,
limit: int = 50,
user=Depends(get_current_user)):
async def get_modelfiles(
skip: int = 0, limit: int = 50, user=Depends(get_current_user)
):
return Modelfiles.get_modelfiles(skip, limit)
@ -36,17 +36,16 @@ async def get_modelfiles(skip: int = 0,
@router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm,
user=Depends(get_admin_user)):
async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_user)):
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile":
json.loads(modelfile.modelfile),
})
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -60,17 +59,18 @@ async def create_new_modelfile(form_data: ModelfileForm,
@router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_current_user)):
async def get_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, user=Depends(get_current_user)
):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile":
json.loads(modelfile.modelfile),
})
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -84,8 +84,9 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
@router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
user=Depends(get_admin_user)):
async def update_modelfile_by_tag_name(
form_data: ModelfileUpdateForm, user=Depends(get_admin_user)
):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
updated_modelfile = {
@ -94,14 +95,15 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
}
modelfile = Modelfiles.update_modelfile_by_tag_name(
form_data.tag_name, updated_modelfile)
form_data.tag_name, updated_modelfile
)
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile":
json.loads(modelfile.modelfile),
})
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -115,7 +117,8 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
@router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_admin_user)):
async def delete_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, user=Depends(get_admin_user)
):
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result

View file

@ -16,6 +16,7 @@ from utils.utils import get_current_user, get_password_hash, get_admin_user
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])