Revert "Merge Updates & Dockerfile improvements" (#3)

This reverts commit 9763d885be.
Jannik S 2024-04-02 11:28:04 +02:00 committed by GitHub
parent 9763d885be
commit 099b1d066b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
155 changed files with 4795 additions and 14501 deletions

View file

@@ -1,5 +1,4 @@
import os
import logging
from fastapi import (
FastAPI,
Request,
@@ -22,24 +21,11 @@ from utils.utils import (
)
from utils.misc import calculate_sha256
from config import (
SRC_LOG_LEVELS,
CACHE_DIR,
UPLOAD_DIR,
WHISPER_MODEL,
WHISPER_MODEL_DIR,
DEVICE_TYPE,
)
from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR, DEVICE_TYPE
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
whisper_device_type = DEVICE_TYPE
if whisper_device_type != "cuda":
if DEVICE_TYPE != "cuda":
whisper_device_type = "cpu"
log.info(f"whisper_device_type: {whisper_device_type}")
app = FastAPI()
app.add_middleware(
@@ -56,7 +42,7 @@ def transcribe(
file: UploadFile = File(...),
user=Depends(get_current_user),
):
log.info(f"file.content_type: {file.content_type}")
print(file.content_type)
if file.content_type not in ["audio/mpeg", "audio/wav"]:
raise HTTPException(
@@ -80,7 +66,7 @@ def transcribe(
)
segments, info = model.transcribe(file_path, beam_size=5)
log.info(
print(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
@@ -90,7 +76,7 @@ def transcribe(
return {"text": transcript.strip()}
except Exception as e:
log.exception(e)
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
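
For context, this endpoint wraps faster-whisper. A minimal standalone sketch of the same transcription call (assuming the faster_whisper package; the model name and compute_type are illustrative, the app derives them from WHISPER_MODEL and DEVICE_TYPE):

from faster_whisper import WhisperModel

model = WhisperModel("base", device="cpu", compute_type="int8")  # illustrative settings
segments, info = model.transcribe("sample.wav", beam_size=5)
print("Detected language '%s' with probability %f" % (info.language, info.language_probability))
transcript = "".join(segment.text for segment in segments).strip()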

View file

@@ -18,8 +18,6 @@ from utils.utils import (
get_current_user,
get_admin_user,
)
from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
@@ -27,14 +25,10 @@ from pathlib import Path
import uuid
import base64
import json
import logging
from config import SRC_LOG_LEVELS, CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL
from config import CACHE_DIR, AUTOMATIC1111_BASE_URL
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
@@ -55,8 +49,6 @@ app.state.MODEL = ""
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.IMAGE_SIZE = "512x512"
app.state.IMAGE_STEPS = 50
@@ -79,48 +71,32 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
class EngineUrlUpdateForm(BaseModel):
AUTOMATIC1111_BASE_URL: Optional[str] = None
COMFYUI_BASE_URL: Optional[str] = None
class UrlUpdateForm(BaseModel):
url: str
@app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)):
return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
}
async def get_automatic1111_url(user=Depends(get_admin_user)):
return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
@app.post("/url/update")
async def update_engine_url(
form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
async def update_automatic1111_url(
form_data: UrlUpdateForm, user=Depends(get_admin_user)
):
if form_data.AUTOMATIC1111_BASE_URL == None:
if form_data.url == "":
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else:
url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
url = form_data.url.strip("/")
try:
r = requests.head(url)
app.state.AUTOMATIC1111_BASE_URL = url
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
if form_data.COMFYUI_BASE_URL == None:
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
else:
url = form_data.COMFYUI_BASE_URL.strip("/")
try:
r = requests.head(url)
app.state.COMFYUI_BASE_URL = url
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
"status": True,
}
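
A hedged usage sketch of the pre-revert /url/update endpoint above, which accepts both engine URLs as optional fields; the host, mount path, and token below are assumptions, not taken from this diff:

import requests

resp = requests.post(
    "http://localhost:8080/images/api/v1/url/update",  # assumed mount path
    headers={"Authorization": "Bearer <admin-token>"},  # placeholder admin token
    json={
        "AUTOMATIC1111_BASE_URL": "http://localhost:7860",
        "COMFYUI_BASE_URL": None,  # None resets to the configured default
    },
)
print(resp.json())
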
@@ -210,18 +186,6 @@ def get_models(user=Depends(get_current_user)):
{"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"},
]
elif app.state.ENGINE == "comfyui":
r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
info = r.json()
return list(
map(
lambda model: {"id": model, "name": model},
info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
)
)
else:
r = requests.get(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
@@ -243,8 +207,6 @@ async def get_default_model(user=Depends(get_admin_user)):
try:
if app.state.ENGINE == "openai":
return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
elif app.state.ENGINE == "comfyui":
return {"model": app.state.MODEL if app.state.MODEL else ""}
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
@@ -259,12 +221,10 @@ class UpdateModelForm(BaseModel):
def set_model_handler(model: str):
if app.state.ENGINE == "openai":
app.state.MODEL = model
return app.state.MODEL
if app.state.ENGINE == "comfyui":
app.state.MODEL = model
return app.state.MODEL
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
@@ -308,24 +268,7 @@ def save_b64_image(b64_str):
return image_id
except Exception as e:
log.error(f"Error saving image: {e}")
return None
def save_url_image(url):
image_id = str(uuid.uuid4())
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
try:
r = requests.get(url)
r.raise_for_status()
with open(file_path, "wb") as image_file:
image_file.write(r.content)
return image_id
except Exception as e:
log.exception(f"Error saving image: {e}")
print(f"Error saving image: {e}")
return None
@@ -335,8 +278,6 @@ def generate_image(
user=Depends(get_current_user),
):
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
r = None
try:
if app.state.ENGINE == "openai":
@@ -352,7 +293,6 @@ def generate_image(
"size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
"response_format": "b64_json",
}
r = requests.post(
url=f"https://api.openai.com/v1/images/generations",
json=data,
@@ -360,6 +300,7 @@ def generate_image(
)
r.raise_for_status()
res = r.json()
images = []
@@ -374,47 +315,12 @@ def generate_image(
return images
elif app.state.ENGINE == "comfyui":
data = {
"prompt": form_data.prompt,
"width": width,
"height": height,
"n": form_data.n,
}
if app.state.IMAGE_STEPS != None:
data["steps"] = app.state.IMAGE_STEPS
if form_data.negative_prompt != None:
data["negative_prompt"] = form_data.negative_prompt
data = ImageGenerationPayload(**data)
res = comfyui_generate_image(
app.state.MODEL,
data,
user.id,
app.state.COMFYUI_BASE_URL,
)
log.debug(f"res: {res}")
images = []
for image in res["data"]:
image_id = save_url_image(image["url"])
images.append({"url": f"/cache/image/generations/{image_id}.png"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f:
json.dump(data.model_dump(exclude_none=True), f)
log.debug(f"images: {images}")
return images
else:
if form_data.model:
set_model_handler(form_data.model)
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
data = {
"prompt": form_data.prompt,
"batch_size": form_data.n,
@@ -435,7 +341,7 @@ def generate_image(
res = r.json()
log.debug(f"res: {res}")
print(res)
images = []
@@ -450,10 +356,7 @@ def generate_image(
return images
except Exception as e:
error = e
if r != None:
data = r.json()
if "error" in data:
error = data["error"]["message"]
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
print(e)
if r:
print(r.json())
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
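
The b64_json branch above hands each image to save_b64_image, whose body is truncated in this diff. A minimal sketch of that decode-and-cache step, assuming plain base64 payloads:

import base64
import uuid
from pathlib import Path

IMAGE_CACHE_DIR = Path("./data/cache/image/generations")  # illustrative location

def save_b64_image(b64_str):
    try:
        image_id = str(uuid.uuid4())
        with open(IMAGE_CACHE_DIR.joinpath(f"{image_id}.png"), "wb") as f:
            f.write(base64.b64decode(b64_str))  # decode the b64_json payload
        return image_id
    except Exception as e:
        print(f"Error saving image: {e}")
        return None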

View file

@@ -1,234 +0,0 @@
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse
import random
import logging
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
from pydantic import BaseModel
from typing import Optional
COMFYUI_DEFAULT_PROMPT = """
{
"3": {
"inputs": {
"seed": 0,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "model.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "Prompt",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "Negative Prompt",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}
"""
def queue_prompt(prompt, client_id, base_url):
log.info("queue_prompt")
p = {"prompt": prompt, "client_id": client_id}
data = json.dumps(p).encode("utf-8")
req = urllib.request.Request(f"{base_url}/prompt", data=data)
return json.loads(urllib.request.urlopen(req).read())
def get_image(filename, subfolder, folder_type, base_url):
log.info("get_image")
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response:
return response.read()
def get_image_url(filename, subfolder, folder_type, base_url):
log.info("get_image_url")
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
return f"{base_url}/view?{url_values}"
def get_history(prompt_id, base_url):
log.info("get_history")
with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response:
return json.loads(response.read())
def get_images(ws, prompt, client_id, base_url):
prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
output_images = []
while True:
out = ws.recv()
if isinstance(out, str):
message = json.loads(out)
if message["type"] == "executing":
data = message["data"]
if data["node"] is None and data["prompt_id"] == prompt_id:
break # Execution is done
else:
continue # previews are binary data
history = get_history(prompt_id, base_url)[prompt_id]
for node_id in history["outputs"]:
node_output = history["outputs"][node_id]
if "images" in node_output:
for image in node_output["images"]:
url = get_image_url(
image["filename"], image["subfolder"], image["type"], base_url
)
output_images.append({"url": url})
return {"data": output_images}
class ImageGenerationPayload(BaseModel):
prompt: str
negative_prompt: Optional[str] = ""
steps: Optional[int] = None
seed: Optional[int] = None
width: int
height: int
n: int = 1
def comfyui_generate_image(
model: str, payload: ImageGenerationPayload, client_id, base_url
):
host = base_url.replace("http://", "").replace("https://", "")
comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
comfyui_prompt["5"]["inputs"]["width"] = payload.width
comfyui_prompt["5"]["inputs"]["height"] = payload.height
# set the text prompt for our positive CLIPTextEncode
comfyui_prompt["6"]["inputs"]["text"] = payload.prompt
comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
if payload.steps:
comfyui_prompt["3"]["inputs"]["steps"] = payload.steps
comfyui_prompt["3"]["inputs"]["seed"] = (
payload.seed if payload.seed else random.randint(0, 18446744073709551614)
)
try:
ws = websocket.WebSocket()
ws.connect(f"ws://{host}/ws?clientId={client_id}")
log.info("WebSocket connection established.")
except Exception as e:
log.exception(f"Failed to connect to WebSocket server: {e}")
return None
try:
images = get_images(ws, comfyui_prompt, client_id, base_url)
except Exception as e:
log.exception(f"Error while receiving images: {e}")
images = None
ws.close()
return images
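
A usage sketch of this module, with names exactly as defined above (checkpoint, client id, and base URL are placeholders):

payload = ImageGenerationPayload(
    prompt="a watercolor lighthouse at dawn",
    negative_prompt="blurry",
    width=512,
    height=512,
    n=1,
    steps=20,
)
res = comfyui_generate_image(
    "model.safetensors",      # patched into node "4" as ckpt_name
    payload,
    "demo-client",            # client id used for the ws connection
    "http://localhost:8188",  # ComfyUI base URL
)
# res is {"data": [{"url": ...}, ...]} on success, or None on failure.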

View file

@ -1,27 +1,10 @@
import logging
from litellm.proxy.proxy_server import ProxyConfig, initialize
from litellm.proxy.proxy_server import app
from fastapi import FastAPI, Request, Depends, status, Response
from fastapi import FastAPI, Request, Depends, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json
from utils.utils import get_http_authorization_cred, get_current_user
from config import SRC_LOG_LEVELS, ENV
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
from config import (
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
)
from config import ENV
proxy_config = ProxyConfig()
@@ -43,58 +26,16 @@ async def on_startup():
await startup()
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
auth_header = request.headers.get("Authorization", "")
request.state.user = None
try:
user = get_current_user(get_http_authorization_cred(auth_header))
log.debug(f"user: {user}")
request.state.user = user
except Exception as e:
return JSONResponse(status_code=400, content={"detail": str(e)})
if ENV != "dev":
try:
user = get_current_user(get_http_authorization_cred(auth_header))
print(user)
except Exception as e:
return JSONResponse(status_code=400, content={"detail": str(e)})
response = await call_next(request)
return response
class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
response = await call_next(request)
user = request.state.user
if "/models" in request.url.path:
if isinstance(response, StreamingResponse):
# Read the content of the streaming response
body = b""
async for chunk in response.body_iterator:
body += chunk
data = json.loads(body.decode("utf-8"))
if app.state.MODEL_FILTER_ENABLED:
if user and user.role == "user":
data["data"] = list(
filter(
lambda model: model["id"]
in app.state.MODEL_FILTER_LIST,
data["data"],
)
)
# Modified Flag
data["modified"] = True
return JSONResponse(content=data)
return response
app.add_middleware(ModifyModelsResponseMiddleware)
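
The filtering step inside ModifyModelsResponseMiddleware, reduced to a standalone sketch with hypothetical data:

def filter_models(data, allowed):
    # Keep only allow-listed model ids, mirroring the middleware above.
    data["data"] = [m for m in data["data"] if m["id"] in allowed]
    data["modified"] = True  # flag that the payload was rewritten
    return data

payload = {"data": [{"id": "gpt-4"}, {"id": "internal-model"}]}
print(filter_models(payload, ["gpt-4"]))
# -> {'data': [{'id': 'gpt-4'}], 'modified': True}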

View file

@ -1,49 +1,24 @@
from fastapi import (
FastAPI,
Request,
Response,
HTTPException,
Depends,
status,
UploadFile,
File,
BackgroundTasks,
)
from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, ConfigDict
import os
import copy
import random
import requests
import json
import uuid
import aiohttp
import asyncio
import logging
from urllib.parse import urlparse
from typing import Optional, List, Union
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user, get_admin_user
from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST
from typing import Optional, List, Union
from config import (
SRC_LOG_LEVELS,
OLLAMA_BASE_URLS,
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
UPLOAD_DIR,
)
from utils.misc import calculate_sha256
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
app = FastAPI()
app.add_middleware(
@@ -94,7 +69,7 @@ class UrlUpdateForm(BaseModel):
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
app.state.OLLAMA_BASE_URLS = form_data.urls
log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}")
print(app.state.OLLAMA_BASE_URLS)
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
@@ -115,7 +90,7 @@ async def fetch_url(url):
return await response.json()
except Exception as e:
# Handle connection error here
log.error(f"Connection error: {e}")
print(f"Connection error: {e}")
return None
@@ -123,14 +98,13 @@ def merge_models_lists(model_lists):
merged_models = {}
for idx, model_list in enumerate(model_lists):
if model_list is not None:
for model in model_list:
digest = model["digest"]
if digest not in merged_models:
model["urls"] = [idx]
merged_models[digest] = model
else:
merged_models[digest]["urls"].append(idx)
for model in model_list:
digest = model["digest"]
if digest not in merged_models:
model["urls"] = [idx]
merged_models[digest] = model
else:
merged_models[digest]["urls"].append(idx)
return list(merged_models.values())
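
For illustration, the digest-keyed merge above applied to two hypothetical hosts serving the same model:

lists = [
    [{"digest": "abc123", "name": "llama2:7b"}],  # models from base URL 0
    [{"digest": "abc123", "name": "llama2:7b"}],  # models from base URL 1
]
print(merge_models_lists(lists))
# -> [{'digest': 'abc123', 'name': 'llama2:7b', 'urls': [0, 1]}]
# so requests for that model can be routed to either base URL.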
@@ -139,16 +113,16 @@ def merge_models_lists(model_lists):
async def get_all_models():
log.info("get_all_models()")
print("get_all_models")
tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS]
responses = await asyncio.gather(*tasks)
responses = list(filter(lambda x: x is not None, responses))
models = {
"models": merge_models_lists(
map(lambda response: response["models"] if response else None, responses)
map(lambda response: response["models"], responses)
)
}
app.state.MODELS = {model["model"]: model for model in models["models"]}
return models
@@ -180,7 +154,7 @@ async def get_ollama_tags(
return r.json()
except Exception as e:
log.exception(e)
print(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
@@ -207,17 +181,11 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
responses = await asyncio.gather(*tasks)
responses = list(filter(lambda x: x is not None, responses))
if len(responses) > 0:
lowest_version = min(
responses, key=lambda x: tuple(map(int, x["version"].split(".")))
)
lowest_version = min(
responses, key=lambda x: tuple(map(int, x["version"].split(".")))
)
return {"version": lowest_version["version"]}
else:
raise HTTPException(
status_code=500,
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
)
return {"version": lowest_version["version"]}
else:
url = app.state.OLLAMA_BASE_URLS[url_idx]
try:
@@ -226,7 +194,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
return r.json()
except Exception as e:
log.exception(e)
print(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
@@ -252,33 +220,18 @@ async def pull_model(
form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
):
url = app.state.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
print(url)
r = None
def get_request():
nonlocal url
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
for chunk in r.iter_content(chunk_size=8192):
yield chunk
r = requests.request(
method="POST",
@@ -299,9 +252,8 @@ async def pull_model(
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
print(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
@@ -340,7 +292,7 @@ async def push_model(
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
log.debug(f"url: {url}")
print(url)
r = None
@@ -372,7 +324,7 @@ async def push_model(
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
print(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
@@ -400,9 +352,9 @@ class CreateModelForm(BaseModel):
async def create_model(
form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
):
log.debug(f"form_data: {form_data}")
print(form_data)
url = app.state.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
print(url)
r = None
@@ -424,7 +376,7 @@ async def create_model(
r.raise_for_status()
log.debug(f"r: {r}")
print(r)
return StreamingResponse(
stream_content(),
@@ -437,7 +389,7 @@ async def create_model(
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
print(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
@@ -475,7 +427,7 @@ async def copy_model(
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
print(url)
try:
r = requests.request(
@@ -485,11 +437,11 @@ async def copy_model(
)
r.raise_for_status()
log.debug(f"r.text: {r.text}")
print(r.text)
return True
except Exception as e:
log.exception(e)
print(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
@@ -522,7 +474,7 @@ async def delete_model(
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
print(url)
try:
r = requests.request(
@@ -532,11 +484,11 @@ async def delete_model(
)
r.raise_for_status()
log.debug(f"r.text: {r.text}")
print(r.text)
return True
except Exception as e:
log.exception(e)
print(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
@@ -562,7 +514,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
url = app.state.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
print(url)
try:
r = requests.request(
@@ -574,7 +526,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
return r.json()
except Exception as e:
log.exception(e)
print(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
@@ -614,7 +566,7 @@ async def generate_embeddings(
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
print(url)
try:
r = requests.request(
@@ -626,7 +578,7 @@ async def generate_embeddings(
return r.json()
except Exception as e:
log.exception(e)
print(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
@@ -670,11 +622,11 @@ async def generate_completion(
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
detail="error_detail",
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
print(url)
r = None
@@ -695,7 +647,7 @@ async def generate_completion(
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
print("User: canceled request")
break
finally:
if hasattr(r, "close"):
@@ -750,7 +702,7 @@ class GenerateChatCompletionForm(BaseModel):
format: Optional[str] = None
options: Optional[dict] = None
template: Optional[str] = None
stream: Optional[bool] = None
stream: Optional[bool] = True
keep_alive: Optional[Union[int, str]] = None
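
A hedged example of the request body this form validates; the model and messages fields are not shown in this hunk and are assumed from context, and note the revert flips the stream default from None back to True:

form = GenerateChatCompletionForm(
    model="llama2:7b",                                 # assumed field
    messages=[{"role": "user", "content": "Hello!"}],  # assumed field
    options={"temperature": 0.7},
    keep_alive="5m",
)
print(form.model_dump_json(exclude_none=True))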
@@ -772,15 +724,11 @@ async def generate_chat_completion(
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
print(url)
r = None
log.debug(
"form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
form_data.model_dump_json(exclude_none=True).encode()
)
)
print(form_data.model_dump_json(exclude_none=True).encode())
def get_request():
nonlocal form_data
@@ -799,7 +747,7 @@ async def generate_chat_completion(
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
print("User: canceled request")
break
finally:
if hasattr(r, "close"):
@@ -822,7 +770,7 @@ async def generate_chat_completion(
headers=dict(r.headers),
)
except Exception as e:
log.exception(e)
print(e)
raise e
try:
@@ -876,7 +824,7 @@ async def generate_openai_chat_completion(
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
print(url)
r = None
@@ -899,7 +847,7 @@ async def generate_openai_chat_completion(
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
print("User: canceled request")
break
finally:
if hasattr(r, "close"):
@@ -942,220 +890,6 @@ async def generate_openai_chat_completion(
)
class UrlForm(BaseModel):
url: str
class UploadBlobForm(BaseModel):
filename: str
def parse_huggingface_url(hf_url):
try:
# Parse the URL
parsed_url = urlparse(hf_url)
# Get the path and split it into components
path_components = parsed_url.path.split("/")
# Extract the desired output
user_repo = "/".join(path_components[1:3])
model_file = path_components[-1]
return model_file
except ValueError:
return None
async def download_file_stream(
ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
):
done = False
if os.path.exists(file_path):
current_size = os.path.getsize(file_path)
else:
current_size = 0
headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(file_url, headers=headers) as response:
total_size = int(response.headers.get("content-length", 0)) + current_size
with open(file_path, "ab+") as file:
async for data in response.content.iter_chunked(chunk_size):
current_size += len(data)
file.write(data)
done = current_size == total_size
progress = round((current_size / total_size) * 100, 2)
yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
if done:
file.seek(0)
hashed = calculate_sha256(file)
file.seek(0)
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
response = requests.post(url, data=file)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file_name,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise Exception("Ollama: Could not create blob, Please try again.")
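
Each yield above is a server-sent event line. A minimal client-side sketch for consuming the progress stream (host and route prefix are assumptions; the sample URL is the one commented out below):

import json
import requests

with requests.post(
    "http://localhost:8080/ollama/models/download",  # assumed mount path
    json={"url": "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"},
    stream=True,
) as r:
    for line in r.iter_lines():
        if line.startswith(b"data: "):
            event = json.loads(line[len(b"data: "):])
            print(event.get("progress"), event.get("done"))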
# def number_generator():
# for i in range(1, 101):
# yield f"data: {i}\n"
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
@app.post("/models/download")
@app.post("/models/download/{url_idx}")
async def download_model(
form_data: UrlForm,
url_idx: Optional[int] = None,
):
allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
if not any(form_data.url.startswith(host) for host in allowed_hosts):
raise HTTPException(
status_code=400,
detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
)
if url_idx == None:
url_idx = 0
url = app.state.OLLAMA_BASE_URLS[url_idx]
file_name = parse_huggingface_url(form_data.url)
if file_name:
file_path = f"{UPLOAD_DIR}/{file_name}"
return StreamingResponse(
download_file_stream(url, form_data.url, file_path, file_name),
)
else:
return None
@app.post("/models/upload")
@app.post("/models/upload/{url_idx}")
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
if url_idx == None:
url_idx = 0
ollama_url = app.state.OLLAMA_BASE_URLS[url_idx]
file_path = f"{UPLOAD_DIR}/{file.filename}"
# Save file in chunks
with open(file_path, "wb+") as f:
for chunk in file.file:
f.write(chunk)
def file_process_stream():
nonlocal ollama_url
total_size = os.path.getsize(file_path)
chunk_size = 1024 * 1024
try:
with open(file_path, "rb") as f:
total = 0
done = False
while not done:
chunk = f.read(chunk_size)
if not chunk:
done = True
continue
total += len(chunk)
progress = round((total / total_size) * 100, 2)
res = {
"progress": progress,
"total": total_size,
"completed": total,
}
yield f"data: {json.dumps(res)}\n\n"
if done:
f.seek(0)
hashed = calculate_sha256(f)
f.seek(0)
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
response = requests.post(url, data=f)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file.filename,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise Exception(
"Ollama: Could not create blob, Please try again."
)
except Exception as e:
res = {"error": str(e)}
yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_process_stream(), media_type="text/event-stream")
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# if url_idx == None:
# url_idx = 0
# url = app.state.OLLAMA_BASE_URLS[url_idx]
# file_location = os.path.join(UPLOAD_DIR, file.filename)
# total_size = file.size
# async def file_upload_generator(file):
# print(file)
# try:
# async with aiofiles.open(file_location, "wb") as f:
# completed_size = 0
# while True:
# chunk = await file.read(1024*1024)
# if not chunk:
# break
# await f.write(chunk)
# completed_size += len(chunk)
# progress = (completed_size / total_size) * 100
# print(progress)
# yield f'data: {json.dumps({"status": "uploading", "percentage": progress, "total": total_size, "completed": completed_size, "done": False})}\n'
# except Exception as e:
# print(e)
# yield f"data: {json.dumps({'status': 'error', 'message': str(e)})}\n"
# finally:
# await file.close()
# print("done")
# yield f'data: {json.dumps({"status": "completed", "percentage": 100, "total": total_size, "completed": completed_size, "done": True})}\n'
# return StreamingResponse(
# file_upload_generator(copy.deepcopy(file)), media_type="text/event-stream"
# )
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)):
url = app.state.OLLAMA_BASE_URLS[0]
@@ -1206,7 +940,7 @@ async def deprecated_proxy(path: str, request: Request, user=Depends(get_current
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
print("User: canceled request")
break
finally:
if hasattr(r, "close"):

View file

@@ -6,7 +6,6 @@ import requests
import aiohttp
import asyncio
import json
import logging
from pydantic import BaseModel
@@ -20,7 +19,6 @@ from utils.utils import (
get_admin_user,
)
from config import (
SRC_LOG_LEVELS,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
CACHE_DIR,
@@ -33,9 +31,6 @@ from typing import List, Optional
import hashlib
from pathlib import Path
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
app = FastAPI()
app.add_middleware(
CORSMiddleware,
@@ -116,7 +111,6 @@ async def speech(request: Request, user=Depends(get_verified_user)):
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}"
headers["Content-Type"] = "application/json"
r = None
try:
r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech",
@@ -139,7 +133,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path)
except Exception as e:
log.exception(e)
print(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
@@ -149,9 +143,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
)
raise HTTPException(status_code=r.status_code, detail=error_detail)
except ValueError:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
@@ -165,7 +157,7 @@ async def fetch_url(url, key):
return await response.json()
except Exception as e:
# Handle connection error here
log.error(f"Connection error: {e}")
print(f"Connection error: {e}")
return None
@@ -173,21 +165,20 @@ def merge_models_lists(model_lists):
merged_list = []
for idx, models in enumerate(model_lists):
if models is not None and "error" not in models:
merged_list.extend(
[
{**model, "urlIdx": idx}
for model in models
if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx]
or "gpt" in model["id"]
]
)
merged_list.extend(
[
{**model, "urlIdx": idx}
for model in models
if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx]
or "gpt" in model["id"]
]
)
return merged_list
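
The comprehension above keeps every model from non-OpenAI endpoints but, for api.openai.com, only ids containing "gpt". For illustration with hypothetical data:

# With app.state.OPENAI_API_BASE_URLS = ["https://api.openai.com/v1"]:
model_lists = [[{"id": "gpt-4"}, {"id": "whisper-1"}]]
# merge_models_lists(model_lists) keeps only the gpt model and tags its source:
# [{'id': 'gpt-4', 'urlIdx': 0}]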
async def get_all_models():
log.info("get_all_models()")
print("get_all_models")
if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "":
models = {"data": []}
@@ -196,24 +187,15 @@ async def get_all_models():
fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
]
responses = await asyncio.gather(*tasks)
responses = list(
filter(lambda x: x is not None and "error" not in x, responses)
)
models = {
"data": merge_models_lists(
list(
map(
lambda response: (
response["data"]
if response and "data" in response
else None
),
responses,
)
)
list(map(lambda response: response["data"], responses))
)
}
log.info(f"models: {models}")
app.state.MODELS = {model["id"]: model for model in models["data"]}
return models
@@ -236,9 +218,6 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
return models
else:
url = app.state.OPENAI_API_BASE_URLS[url_idx]
r = None
try:
r = requests.request(method="GET", url=f"{url}/models")
r.raise_for_status()
@@ -251,7 +230,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
return response_data
except Exception as e:
log.exception(e)
print(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
@@ -285,7 +264,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
if body.get("model") == "gpt-4-vision-preview":
if "max_tokens" not in body:
body["max_tokens"] = 4000
log.debug("Modified body_dict:", body)
print("Modified body_dict:", body)
# Fix for ChatGPT calls failing because the num_ctx key is in body
if "num_ctx" in body:
@@ -297,7 +276,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
# Convert the modified body back to JSON
body = json.dumps(body)
except json.JSONDecodeError as e:
log.error("Error loading request body into a dictionary:", e)
print("Error loading request body into a dictionary:", e)
url = app.state.OPENAI_API_BASE_URLS[idx]
key = app.state.OPENAI_API_KEYS[idx]
@@ -311,8 +290,6 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
r = None
try:
r = requests.request(
method=request.method,
@@ -335,7 +312,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
response_data = r.json()
return response_data
except Exception as e:
log.exception(e)
print(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
@@ -345,6 +322,4 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
)
raise HTTPException(status_code=r.status_code, detail=error_detail)

View file

@@ -8,7 +8,7 @@ from fastapi import (
Form,
)
from fastapi.middleware.cors import CORSMiddleware
import os, shutil, logging
import os, shutil
from pathlib import Path
from typing import List
@@ -21,7 +19,6 @@ from langchain_community.document_loaders import (
TextLoader,
PyPDFLoader,
CSVLoader,
BSHTMLLoader,
Docx2txtLoader,
UnstructuredEPubLoader,
UnstructuredWordDocumentLoader,
@@ -55,7 +54,6 @@ from utils.misc import (
)
from utils.utils import get_current_user, get_admin_user
from config import (
SRC_LOG_LEVELS,
UPLOAD_DIR,
DOCS_DIR,
RAG_EMBEDDING_MODEL,
@@ -68,9 +66,6 @@ from config import (
from constants import ERROR_MESSAGES
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
#
# if RAG_EMBEDDING_MODEL:
# sentence_transformer_ef = SentenceTransformer(
@@ -116,6 +111,39 @@ class StoreWebForm(CollectionNameForm):
url: str
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
)
docs = text_splitter.split_documents(data)
texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs]
try:
if overwrite:
for collection in CHROMA_CLIENT.list_collections():
if collection_name == collection.name:
print(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name)
collection = CHROMA_CLIENT.create_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
)
return True
except Exception as e:
print(e)
if e.__class__.__name__ == "UniqueConstraintError":
return True
return False
@app.get("/")
async def get_status():
return {
@@ -245,7 +273,7 @@ def query_doc_handler(
embedding_function=app.state.sentence_transformer_ef,
)
except Exception as e:
log.exception(e)
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@@ -289,69 +317,13 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
"filename": form_data.url,
}
except Exception as e:
log.exception(e)
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP,
add_start_index=True,
)
docs = text_splitter.split_documents(data)
if len(docs) > 0:
return store_docs_in_vector_db(docs, collection_name, overwrite), None
else:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
def store_text_in_vector_db(
text, metadata, collection_name, overwrite: bool = False
) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP,
add_start_index=True,
)
docs = text_splitter.create_documents([text], metadatas=[metadata])
return store_docs_in_vector_db(docs, collection_name, overwrite)
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs]
try:
if overwrite:
for collection in CHROMA_CLIENT.list_collections():
if collection_name == collection.name:
log.info(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name)
collection = CHROMA_CLIENT.create_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
)
return True
except Exception as e:
log.exception(e)
if e.__class__.__name__ == "UniqueConstraintError":
return True
return False
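
A minimal sketch of the splitter call these helpers share (langchain's RecursiveCharacterTextSplitter; the chunk sizes are illustrative, the app reads them from app.state):

from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(
    chunk_size=1500, chunk_overlap=100, add_start_index=True
)
docs = splitter.create_documents(
    ["some long document text ..."],
    metadatas=[{"name": "demo", "created_by": "user-id"}],  # placeholder metadata
)
# Each doc carries page_content plus metadata including a "start_index".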
def get_loader(filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower()
known_type = True
@@ -409,8 +381,6 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
loader = UnstructuredRSTLoader(file_path, mode="elements")
elif file_ext == "xml":
loader = UnstructuredXMLLoader(file_path)
elif file_ext in ["htm", "html"]:
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
elif file_ext == "md":
loader = UnstructuredMarkdownLoader(file_path)
elif file_content_type == "application/epub+zip":
@@ -429,9 +399,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
elif file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
):
loader = TextLoader(file_path, autodetect_encoding=True)
loader = TextLoader(file_path)
else:
loader = TextLoader(file_path, autodetect_encoding=True)
loader = TextLoader(file_path)
known_type = False
return loader, known_type
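
A usage sketch of get_loader as store_doc calls it below (paths are placeholders):

loader, known_type = get_loader("notes.md", "text/markdown", "/tmp/uploads/notes.md")
data = loader.load()  # a list of langchain Documents
# known_type is False when the code fell back to the plain TextLoader branch.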
@@ -445,7 +415,7 @@ def store_doc(
):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
log.info(f"file.content_type: {file.content_type}")
print(file.content_type)
try:
filename = file.filename
file_path = f"{UPLOAD_DIR}/{filename}"
@@ -461,24 +431,22 @@ def store_doc(
loader, known_type = get_loader(file.filename, file.content_type, file_path)
data = loader.load()
result = store_data_in_vector_db(data, collection_name)
try:
result = store_data_in_vector_db(data, collection_name)
if result:
return {
"status": True,
"collection_name": collection_name,
"filename": filename,
"known_type": known_type,
}
except Exception as e:
if result:
return {
"status": True,
"collection_name": collection_name,
"filename": filename,
"known_type": known_type,
}
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=e,
detail=ERROR_MESSAGES.DEFAULT(),
)
except Exception as e:
log.exception(e)
print(e)
if "No pandoc was found" in str(e):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -491,37 +459,6 @@ def store_doc(
)
class TextRAGForm(BaseModel):
name: str
content: str
collection_name: Optional[str] = None
@app.post("/text")
def store_text(
form_data: TextRAGForm,
user=Depends(get_current_user),
):
collection_name = form_data.collection_name
if collection_name == None:
collection_name = calculate_sha256_string(form_data.content)
result = store_text_in_vector_db(
form_data.content,
metadata={"name": form_data.name, "created_by": user.id},
collection_name=collection_name,
)
if result:
return {"status": True, "collection_name": collection_name}
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(),
)
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
for path in Path(DOCS_DIR).rglob("./**/*"):
@@ -540,45 +477,41 @@ def scan_docs_dir(user=Depends(get_admin_user)):
)
data = loader.load()
try:
result = store_data_in_vector_db(data, collection_name)
result = store_data_in_vector_db(data, collection_name)
if result:
sanitized_filename = sanitize_filename(filename)
doc = Documents.get_doc_by_name(sanitized_filename)
if result:
sanitized_filename = sanitize_filename(filename)
doc = Documents.get_doc_by_name(sanitized_filename)
if doc == None:
doc = Documents.insert_new_doc(
user.id,
DocumentForm(
**{
"name": sanitized_filename,
"title": filename,
"collection_name": collection_name,
"filename": filename,
"content": (
json.dumps(
{
"tags": list(
map(
lambda name: {"name": name},
tags,
)
if doc == None:
doc = Documents.insert_new_doc(
user.id,
DocumentForm(
**{
"name": sanitized_filename,
"title": filename,
"collection_name": collection_name,
"filename": filename,
"content": (
json.dumps(
{
"tags": list(
map(
lambda name: {"name": name},
tags,
)
}
)
if len(tags)
else "{}"
),
}
),
)
except Exception as e:
log.exception(e)
pass
)
}
)
if len(tags)
else "{}"
),
}
),
)
except Exception as e:
log.exception(e)
print(e)
return True
@@ -599,11 +532,11 @@ def reset(user=Depends(get_admin_user)) -> bool:
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
log.error("Failed to delete %s. Reason: %s" % (file_path, e))
print("Failed to delete %s. Reason: %s" % (file_path, e))
try:
CHROMA_CLIENT.reset()
except Exception as e:
log.exception(e)
print(e)
return True

View file

@@ -1,11 +1,7 @@
import re
import logging
from typing import List
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
from config import CHROMA_CLIENT
def query_doc(collection_name: str, query: str, k: int, embedding_function):
@@ -95,13 +91,14 @@
def rag_template(template: str, context: str, query: str):
template = template.replace("[context]", context)
template = template.replace("[query]", query)
template = re.sub(r"\[context\]", context, template)
template = re.sub(r"\[query\]", query, template)
return template
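
Both variants substitute the two placeholders; note that re.sub treats backslashes and group references in its replacement argument specially, which the str.replace form avoids. For example:

template = "Use [context] to answer: [query]"
print(rag_template(template, "Paris is the capital of France.", "What is the capital of France?"))
# -> "Use Paris is the capital of France. to answer: What is the capital of France?"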
def rag_messages(docs, messages, template, k, embedding_function):
log.debug(f"docs: {docs}")
print(docs)
last_user_message_idx = None
for i in range(len(messages) - 1, -1, -1):
@@ -141,8 +138,6 @@ def rag_messages(docs, messages, template, k, embedding_function):
k=k,
embedding_function=embedding_function,
)
elif doc["type"] == "text":
context = doc["content"]
else:
context = query_doc(
collection_name=doc["collection_name"],
@@ -151,13 +146,11 @@ def rag_messages(docs, messages, template, k, embedding_function):
embedding_function=embedding_function,
)
except Exception as e:
log.exception(e)
print(e)
context = None
relevant_contexts.append(context)
log.debug(f"relevant_contexts: {relevant_contexts}")
context_string = ""
for context in relevant_contexts:
if context:

View file

@@ -1,16 +1,13 @@
from peewee import *
from config import SRC_LOG_LEVELS, DATA_DIR
from config import DATA_DIR
import os
import logging
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
# Check if the file exists
if os.path.exists(f"{DATA_DIR}/ollama.db"):
# Rename the file
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
log.info("File renamed successfully.")
print("File renamed successfully.")
else:
pass

View file

@@ -19,7 +19,6 @@ from config import (
DEFAULT_USER_ROLE,
ENABLE_SIGNUP,
USER_PERMISSIONS,
WEBHOOK_URL,
)
app = FastAPI()
@@ -33,7 +32,6 @@ app.state.DEFAULT_MODELS = DEFAULT_MODELS
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.USER_PERMISSIONS = USER_PERMISSIONS
app.state.WEBHOOK_URL = WEBHOOK_URL
app.add_middleware(

View file

@@ -2,7 +2,6 @@ from pydantic import BaseModel
from typing import List, Union, Optional
import time
import uuid
import logging
from peewee import *
from apps.web.models.users import UserModel, Users
@@ -10,11 +9,6 @@ from utils.utils import verify_password
from apps.web.internal.db import DB
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# DB MODEL
####################
@@ -92,7 +86,7 @@ class AuthsTable:
def insert_new_auth(
self, email: str, password: str, name: str, role: str = "pending"
) -> Optional[UserModel]:
log.info("insert_new_auth")
print("insert_new_auth")
id = str(uuid.uuid4())
@@ -109,7 +103,7 @@ class AuthsTable:
return None
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}")
print("authenticate_user", email)
try:
auth = Auth.get(Auth.email == email, Auth.active == True)
if auth:

View file

@@ -95,6 +95,20 @@ class ChatTable:
except:
return None
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try:
query = Chat.update(
chat=json.dumps(chat),
title=chat["title"] if "title" in chat else "New Chat",
timestamp=int(time.time()),
).where(Chat.id == id)
query.execute()
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
except:
return None
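
A usage sketch of the restored method (ids and content are placeholders):

chat = Chats.update_chat_by_id(
    "chat-uuid",
    {"title": "Trip planning", "messages": [{"role": "user", "content": "Hi"}]},
)
# Returns the refreshed ChatModel, or None on failure; a chat dict without
# a "title" key falls back to "New Chat".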
def get_chat_lists_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]:

View file

@@ -3,7 +3,6 @@ from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
import logging
from utils.utils import decode_token
from utils.misc import get_gravatar_url
@@ -12,11 +11,6 @@ from apps.web.internal.db import DB
import json
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Documents DB Schema
####################
@@ -124,7 +118,7 @@ class DocumentsTable:
doc = Document.get(Document.name == form_data.name)
return DocumentModel(**model_to_dict(doc))
except Exception as e:
log.exception(e)
print(e)
return None
def update_doc_content_by_name(
@@ -144,7 +138,7 @@ class DocumentsTable:
doc = Document.get(Document.name == name)
return DocumentModel(**model_to_dict(doc))
except Exception as e:
log.exception(e)
print(e)
return None
def delete_doc_by_name(self, name: str) -> bool:

View file

@@ -64,8 +64,8 @@ class ModelfilesTable:
self.db.create_tables([Modelfile])
def insert_new_modelfile(
self, user_id: str, form_data: ModelfileForm
) -> Optional[ModelfileModel]:
self, user_id: str,
form_data: ModelfileForm) -> Optional[ModelfileModel]:
if "tagName" in form_data.modelfile:
modelfile = ModelfileModel(
**{
@@ -73,8 +73,7 @@ class ModelfilesTable:
"tag_name": form_data.modelfile["tagName"],
"modelfile": json.dumps(form_data.modelfile),
"timestamp": int(time.time()),
}
)
})
try:
result = Modelfile.create(**modelfile.model_dump())
@@ -88,28 +87,29 @@ class ModelfilesTable:
else:
return None
def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]:
def get_modelfile_by_tag_name(self,
tag_name: str) -> Optional[ModelfileModel]:
try:
modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
return ModelfileModel(**model_to_dict(modelfile))
except:
return None
def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]:
def get_modelfiles(self,
skip: int = 0,
limit: int = 50) -> List[ModelfileResponse]:
return [
ModelfileResponse(
**{
**model_to_dict(modelfile),
"modelfile": json.loads(modelfile.modelfile),
}
)
for modelfile in Modelfile.select()
"modelfile":
json.loads(modelfile.modelfile),
}) for modelfile in Modelfile.select()
# .limit(limit).offset(skip)
]
def update_modelfile_by_tag_name(
self, tag_name: str, modelfile: dict
) -> Optional[ModelfileModel]:
self, tag_name: str, modelfile: dict) -> Optional[ModelfileModel]:
try:
query = Modelfile.update(
modelfile=json.dumps(modelfile),

View file

@@ -52,9 +52,8 @@ class PromptsTable:
self.db = db
self.db.create_tables([Prompt])
def insert_new_prompt(
self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]:
def insert_new_prompt(self, user_id: str,
form_data: PromptForm) -> Optional[PromptModel]:
prompt = PromptModel(
**{
"user_id": user_id,
@@ -62,8 +61,7 @@ class PromptsTable:
"title": form_data.title,
"content": form_data.content,
"timestamp": int(time.time()),
}
)
})
try:
result = Prompt.create(**prompt.model_dump())
@@ -83,14 +81,13 @@ class PromptsTable:
def get_prompts(self) -> List[PromptModel]:
return [
PromptModel(**model_to_dict(prompt))
for prompt in Prompt.select()
PromptModel(**model_to_dict(prompt)) for prompt in Prompt.select()
# .limit(limit).offset(skip)
]
def update_prompt_by_command(
self, command: str, form_data: PromptForm
) -> Optional[PromptModel]:
self, command: str,
form_data: PromptForm) -> Optional[PromptModel]:
try:
query = Prompt.update(
title=form_data.title,

View file

@@ -6,15 +6,9 @@ from playhouse.shortcuts import model_to_dict
import json
import uuid
import time
import logging
from apps.web.internal.db import DB
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tag DB Schema
####################
@@ -179,7 +173,7 @@ class TagTable:
(ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)
)
res = query.execute() # Remove the rows, return number of rows removed.
log.debug(f"res: {res}")
print(res)
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
if tag_count == 0:
@@ -191,7 +185,7 @@ class TagTable:
return True
except Exception as e:
log.error(f"delete_tag: {e}")
print("delete_tag", e)
return False
def delete_tag_by_tag_name_and_chat_id_and_user_id(
@@ -204,7 +198,7 @@ class TagTable:
& (ChatIdTag.user_id == user_id)
)
res = query.execute() # Remove the rows, return number of rows removed.
log.debug(f"res: {res}")
print(res)
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
if tag_count == 0:
@@ -216,7 +210,7 @@ class TagTable:
return True
except Exception as e:
log.error(f"delete_tag: {e}")
print("delete_tag", e)
return False
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:

View file

@@ -27,8 +27,7 @@ from utils.utils import (
create_token,
)
from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from constants import ERROR_MESSAGES
router = APIRouter()
@@ -156,17 +155,6 @@ async def signup(request: Request, form_data: SignupForm):
)
# response.set_cookie(key='token', value=token, httponly=True)
if request.app.state.WEBHOOK_URL:
post_webhook(
request.app.state.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
return {
"token": token,
"token_type": "Bearer",

View file

@@ -5,7 +5,6 @@ from utils.utils import get_current_user, get_admin_user
from fastapi import APIRouter
from pydantic import BaseModel
import json
import logging
from apps.web.models.users import Users
from apps.web.models.chats import (
@@ -28,11 +27,6 @@ from apps.web.models.tags import (
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
@@ -84,7 +78,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
except Exception as e:
log.exception(e)
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
@@ -101,7 +95,7 @@ async def get_all_tags(user=Depends(get_current_user)):
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)

View file

@@ -10,12 +10,7 @@ import uuid
from apps.web.models.users import Users
from utils.utils import (
get_password_hash,
get_current_user,
get_admin_user,
create_token,
)
from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES
@@ -48,6 +43,7 @@ async def set_global_default_models(
return request.app.state.DEFAULT_MODELS
@router.post("/default/suggestions", response_model=List[PromptSuggestion])
async def set_global_default_suggestions(
request: Request,

View file

@@ -24,9 +24,9 @@ router = APIRouter()
@router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(
skip: int = 0, limit: int = 50, user=Depends(get_current_user)
):
async def get_modelfiles(skip: int = 0,
limit: int = 50,
user=Depends(get_current_user)):
return Modelfiles.get_modelfiles(skip, limit)
@@ -36,16 +36,17 @@ async def get_modelfiles(
@router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_user)):
async def create_new_modelfile(form_data: ModelfileForm,
user=Depends(get_admin_user)):
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
"modelfile":
json.loads(modelfile.modelfile),
})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -59,18 +60,17 @@ async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_
@router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, user=Depends(get_current_user)
):
async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_current_user)):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
"modelfile":
json.loads(modelfile.modelfile),
})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -84,9 +84,8 @@ async def get_modelfile_by_tag_name(
@router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(
form_data: ModelfileUpdateForm, user=Depends(get_admin_user)
):
async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
user=Depends(get_admin_user)):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
updated_modelfile = {
@@ -95,15 +94,14 @@ async def update_modelfile_by_tag_name(
}
modelfile = Modelfiles.update_modelfile_by_tag_name(
form_data.tag_name, updated_modelfile
)
form_data.tag_name, updated_modelfile)
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
"modelfile":
json.loads(modelfile.modelfile),
})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -117,8 +115,7 @@ async def update_modelfile_by_tag_name(
@router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, user=Depends(get_admin_user)
):
async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_admin_user)):
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result

View file

@@ -7,7 +7,6 @@ from fastapi import APIRouter
from pydantic import BaseModel
import time
import uuid
import logging
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths
@@ -15,11 +14,6 @@ from apps.web.models.auths import Auths
from utils.utils import get_current_user, get_password_hash, get_admin_user
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
@@ -89,7 +83,7 @@ async def update_user_by_id(
if form_data.password:
hashed = get_password_hash(form_data.password)
log.debug(f"hashed: {hashed}")
print(hashed)
Auths.update_user_password_by_id(user_id, hashed)
Auths.update_email_by_id(user_id, form_data.email.lower())

View file

@@ -21,6 +21,155 @@ from constants import ERROR_MESSAGES
router = APIRouter()
class UploadBlobForm(BaseModel):
filename: str
from urllib.parse import urlparse
def parse_huggingface_url(hf_url):
try:
# Parse the URL
parsed_url = urlparse(hf_url)
# Get the path and split it into components
path_components = parsed_url.path.split("/")
# Extract the desired output
user_repo = "/".join(path_components[1:3])
model_file = path_components[-1]
return model_file
except ValueError:
return None
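
For example, with the sample URL that appears commented out elsewhere in this diff:

url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
print(parse_huggingface_url(url))  # -> "stablelm-zephyr-3b.Q2_K.gguf"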
async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
done = False
if os.path.exists(file_path):
current_size = os.path.getsize(file_path)
else:
current_size = 0
headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as response:
total_size = int(response.headers.get("content-length", 0)) + current_size
with open(file_path, "ab+") as file:
async for data in response.content.iter_chunked(chunk_size):
current_size += len(data)
file.write(data)
done = current_size == total_size
progress = round((current_size / total_size) * 100, 2)
yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
if done:
file.seek(0)
hashed = calculate_sha256(file)
file.seek(0)
url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
response = requests.post(url, data=file)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file_name,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise Exception("Ollama: Could not create blob, Please try again.")
@router.get("/download")
async def download(
url: str,
):
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
file_name = parse_huggingface_url(url)
if file_name:
file_path = f"{UPLOAD_DIR}/{file_name}"
return StreamingResponse(
download_file_stream(url, file_path, file_name),
media_type="text/event-stream",
)
else:
return None
@router.post("/upload")
def upload(file: UploadFile = File(...)):
file_path = f"{UPLOAD_DIR}/{file.filename}"
# Save file in chunks
with open(file_path, "wb+") as f:
for chunk in file.file:
f.write(chunk)
def file_process_stream():
total_size = os.path.getsize(file_path)
chunk_size = 1024 * 1024
try:
with open(file_path, "rb") as f:
total = 0
done = False
while not done:
chunk = f.read(chunk_size)
if not chunk:
done = True
continue
total += len(chunk)
progress = round((total / total_size) * 100, 2)
res = {
"progress": progress,
"total": total_size,
"completed": total,
}
yield f"data: {json.dumps(res)}\n\n"
if done:
f.seek(0)
hashed = calculate_sha256(f)
f.seek(0)
url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
response = requests.post(url, data=f)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file.filename,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise Exception(
"Ollama: Could not create blob, Please try again."
)
except Exception as e:
res = {"error": str(e)}
yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_process_stream(), media_type="text/event-stream")
@router.get("/gravatar")
async def get_gravatar(
email: str,