forked from open-webui/open-webui
Merge pull request 'main' (!3) from open-webui/open-webui:main into main
Some checks failed
Some checks failed
Reviewed-on: #3
This commit is contained in:
commit
921a1cf978
74 changed files with 7314 additions and 5536 deletions
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import logging
|
||||
from fastapi import (
|
||||
FastAPI,
|
||||
Request,
|
||||
|
|
@ -21,7 +22,10 @@ from utils.utils import (
|
|||
)
|
||||
from utils.misc import calculate_sha256
|
||||
|
||||
from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR
|
||||
from config import SRC_LOG_LEVELS, CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
|
|
@ -38,7 +42,7 @@ def transcribe(
|
|||
file: UploadFile = File(...),
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
print(file.content_type)
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
|
||||
if file.content_type not in ["audio/mpeg", "audio/wav"]:
|
||||
raise HTTPException(
|
||||
|
|
@ -62,7 +66,7 @@ def transcribe(
|
|||
)
|
||||
|
||||
segments, info = model.transcribe(file_path, beam_size=5)
|
||||
print(
|
||||
log.info(
|
||||
"Detected language '%s' with probability %f"
|
||||
% (info.language, info.language_probability)
|
||||
)
|
||||
|
|
@ -72,7 +76,7 @@ def transcribe(
|
|||
return {"text": transcript.strip()}
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ from utils.utils import (
|
|||
get_current_user,
|
||||
get_admin_user,
|
||||
)
|
||||
|
||||
from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
|
||||
from utils.misc import calculate_sha256
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -25,10 +27,14 @@ from pathlib import Path
|
|||
import uuid
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
|
||||
from config import CACHE_DIR, AUTOMATIC1111_BASE_URL
|
||||
from config import SRC_LOG_LEVELS, CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
|
||||
|
||||
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
|
||||
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
|
@ -49,6 +55,8 @@ app.state.MODEL = ""
|
|||
|
||||
|
||||
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
|
||||
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
|
||||
|
||||
|
||||
app.state.IMAGE_SIZE = "512x512"
|
||||
app.state.IMAGE_STEPS = 50
|
||||
|
|
@ -71,32 +79,48 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user
|
|||
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
|
||||
|
||||
|
||||
class UrlUpdateForm(BaseModel):
|
||||
url: str
|
||||
class EngineUrlUpdateForm(BaseModel):
|
||||
AUTOMATIC1111_BASE_URL: Optional[str] = None
|
||||
COMFYUI_BASE_URL: Optional[str] = None
|
||||
|
||||
|
||||
@app.get("/url")
|
||||
async def get_automatic1111_url(user=Depends(get_admin_user)):
|
||||
return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
|
||||
async def get_engine_url(user=Depends(get_admin_user)):
|
||||
return {
|
||||
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
|
||||
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/url/update")
|
||||
async def update_automatic1111_url(
|
||||
form_data: UrlUpdateForm, user=Depends(get_admin_user)
|
||||
async def update_engine_url(
|
||||
form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
|
||||
):
|
||||
|
||||
if form_data.url == "":
|
||||
if form_data.AUTOMATIC1111_BASE_URL == None:
|
||||
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
|
||||
else:
|
||||
url = form_data.url.strip("/")
|
||||
url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
|
||||
try:
|
||||
r = requests.head(url)
|
||||
app.state.AUTOMATIC1111_BASE_URL = url
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
||||
|
||||
if form_data.COMFYUI_BASE_URL == None:
|
||||
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
|
||||
else:
|
||||
url = form_data.COMFYUI_BASE_URL.strip("/")
|
||||
|
||||
try:
|
||||
r = requests.head(url)
|
||||
app.state.COMFYUI_BASE_URL = url
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
||||
|
||||
return {
|
||||
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
|
||||
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
|
||||
"status": True,
|
||||
}
|
||||
|
||||
|
|
@ -186,6 +210,18 @@ def get_models(user=Depends(get_current_user)):
|
|||
{"id": "dall-e-2", "name": "DALL·E 2"},
|
||||
{"id": "dall-e-3", "name": "DALL·E 3"},
|
||||
]
|
||||
elif app.state.ENGINE == "comfyui":
|
||||
|
||||
r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
|
||||
info = r.json()
|
||||
|
||||
return list(
|
||||
map(
|
||||
lambda model: {"id": model, "name": model},
|
||||
info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
r = requests.get(
|
||||
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
|
||||
|
|
@ -207,6 +243,8 @@ async def get_default_model(user=Depends(get_admin_user)):
|
|||
try:
|
||||
if app.state.ENGINE == "openai":
|
||||
return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
|
||||
elif app.state.ENGINE == "comfyui":
|
||||
return {"model": app.state.MODEL if app.state.MODEL else ""}
|
||||
else:
|
||||
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
|
||||
options = r.json()
|
||||
|
|
@ -221,10 +259,12 @@ class UpdateModelForm(BaseModel):
|
|||
|
||||
|
||||
def set_model_handler(model: str):
|
||||
|
||||
if app.state.ENGINE == "openai":
|
||||
app.state.MODEL = model
|
||||
return app.state.MODEL
|
||||
if app.state.ENGINE == "comfyui":
|
||||
app.state.MODEL = model
|
||||
return app.state.MODEL
|
||||
else:
|
||||
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
|
||||
options = r.json()
|
||||
|
|
@ -266,6 +306,23 @@ def save_b64_image(b64_str):
|
|||
with open(file_path, "wb") as f:
|
||||
f.write(img_data)
|
||||
|
||||
return image_id
|
||||
except Exception as e:
|
||||
log.error(f"Error saving image: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def save_url_image(url):
|
||||
image_id = str(uuid.uuid4())
|
||||
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
|
||||
|
||||
try:
|
||||
r = requests.get(url)
|
||||
r.raise_for_status()
|
||||
|
||||
with open(file_path, "wb") as image_file:
|
||||
image_file.write(r.content)
|
||||
|
||||
return image_id
|
||||
except Exception as e:
|
||||
print(f"Error saving image: {e}")
|
||||
|
|
@ -278,6 +335,8 @@ def generate_image(
|
|||
user=Depends(get_current_user),
|
||||
):
|
||||
|
||||
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
|
||||
|
||||
r = None
|
||||
try:
|
||||
if app.state.ENGINE == "openai":
|
||||
|
|
@ -315,12 +374,47 @@ def generate_image(
|
|||
|
||||
return images
|
||||
|
||||
elif app.state.ENGINE == "comfyui":
|
||||
|
||||
data = {
|
||||
"prompt": form_data.prompt,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"n": form_data.n,
|
||||
}
|
||||
|
||||
if app.state.IMAGE_STEPS != None:
|
||||
data["steps"] = app.state.IMAGE_STEPS
|
||||
|
||||
if form_data.negative_prompt != None:
|
||||
data["negative_prompt"] = form_data.negative_prompt
|
||||
|
||||
data = ImageGenerationPayload(**data)
|
||||
|
||||
res = comfyui_generate_image(
|
||||
app.state.MODEL,
|
||||
data,
|
||||
user.id,
|
||||
app.state.COMFYUI_BASE_URL,
|
||||
)
|
||||
print(res)
|
||||
|
||||
images = []
|
||||
|
||||
for image in res["data"]:
|
||||
image_id = save_url_image(image["url"])
|
||||
images.append({"url": f"/cache/image/generations/{image_id}.png"})
|
||||
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
|
||||
|
||||
with open(file_body_path, "w") as f:
|
||||
json.dump(data.model_dump(exclude_none=True), f)
|
||||
|
||||
print(images)
|
||||
return images
|
||||
else:
|
||||
if form_data.model:
|
||||
set_model_handler(form_data.model)
|
||||
|
||||
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
|
||||
|
||||
data = {
|
||||
"prompt": form_data.prompt,
|
||||
"batch_size": form_data.n,
|
||||
|
|
@ -341,7 +435,7 @@ def generate_image(
|
|||
|
||||
res = r.json()
|
||||
|
||||
print(res)
|
||||
log.debug(f"res: {res}")
|
||||
|
||||
images = []
|
||||
|
||||
|
|
|
|||
228
backend/apps/images/utils/comfyui.py
Normal file
228
backend/apps/images/utils/comfyui.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
|
||||
import uuid
|
||||
import json
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
import random
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from typing import Optional
|
||||
|
||||
COMFYUI_DEFAULT_PROMPT = """
|
||||
{
|
||||
"3": {
|
||||
"inputs": {
|
||||
"seed": 0,
|
||||
"steps": 20,
|
||||
"cfg": 8,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "normal",
|
||||
"denoise": 1,
|
||||
"model": [
|
||||
"4",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"6",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"7",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"5",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSampler",
|
||||
"_meta": {
|
||||
"title": "KSampler"
|
||||
}
|
||||
},
|
||||
"4": {
|
||||
"inputs": {
|
||||
"ckpt_name": "model.safetensors"
|
||||
},
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"_meta": {
|
||||
"title": "Load Checkpoint"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
"inputs": {
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyLatentImage",
|
||||
"_meta": {
|
||||
"title": "Empty Latent Image"
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "Prompt",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"7": {
|
||||
"inputs": {
|
||||
"text": "Negative Prompt",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"3",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"4",
|
||||
2
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": [
|
||||
"8",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def queue_prompt(prompt, client_id, base_url):
|
||||
print("queue_prompt")
|
||||
p = {"prompt": prompt, "client_id": client_id}
|
||||
data = json.dumps(p).encode("utf-8")
|
||||
req = urllib.request.Request(f"{base_url}/prompt", data=data)
|
||||
return json.loads(urllib.request.urlopen(req).read())
|
||||
|
||||
|
||||
def get_image(filename, subfolder, folder_type, base_url):
|
||||
print("get_image")
|
||||
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||||
url_values = urllib.parse.urlencode(data)
|
||||
with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response:
|
||||
return response.read()
|
||||
|
||||
|
||||
def get_image_url(filename, subfolder, folder_type, base_url):
|
||||
print("get_image")
|
||||
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||||
url_values = urllib.parse.urlencode(data)
|
||||
return f"{base_url}/view?{url_values}"
|
||||
|
||||
|
||||
def get_history(prompt_id, base_url):
|
||||
print("get_history")
|
||||
with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response:
|
||||
return json.loads(response.read())
|
||||
|
||||
|
||||
def get_images(ws, prompt, client_id, base_url):
|
||||
prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
|
||||
output_images = []
|
||||
while True:
|
||||
out = ws.recv()
|
||||
if isinstance(out, str):
|
||||
message = json.loads(out)
|
||||
if message["type"] == "executing":
|
||||
data = message["data"]
|
||||
if data["node"] is None and data["prompt_id"] == prompt_id:
|
||||
break # Execution is done
|
||||
else:
|
||||
continue # previews are binary data
|
||||
|
||||
history = get_history(prompt_id, base_url)[prompt_id]
|
||||
for o in history["outputs"]:
|
||||
for node_id in history["outputs"]:
|
||||
node_output = history["outputs"][node_id]
|
||||
if "images" in node_output:
|
||||
for image in node_output["images"]:
|
||||
url = get_image_url(
|
||||
image["filename"], image["subfolder"], image["type"], base_url
|
||||
)
|
||||
output_images.append({"url": url})
|
||||
return {"data": output_images}
|
||||
|
||||
|
||||
class ImageGenerationPayload(BaseModel):
|
||||
prompt: str
|
||||
negative_prompt: Optional[str] = ""
|
||||
steps: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
width: int
|
||||
height: int
|
||||
n: int = 1
|
||||
|
||||
|
||||
def comfyui_generate_image(
|
||||
model: str, payload: ImageGenerationPayload, client_id, base_url
|
||||
):
|
||||
host = base_url.replace("http://", "").replace("https://", "")
|
||||
|
||||
comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
|
||||
|
||||
comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
|
||||
comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
|
||||
comfyui_prompt["5"]["inputs"]["width"] = payload.width
|
||||
comfyui_prompt["5"]["inputs"]["height"] = payload.height
|
||||
|
||||
# set the text prompt for our positive CLIPTextEncode
|
||||
comfyui_prompt["6"]["inputs"]["text"] = payload.prompt
|
||||
comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
|
||||
|
||||
if payload.steps:
|
||||
comfyui_prompt["3"]["inputs"]["steps"] = payload.steps
|
||||
|
||||
comfyui_prompt["3"]["inputs"]["seed"] = (
|
||||
payload.seed if payload.seed else random.randint(0, 18446744073709551614)
|
||||
)
|
||||
|
||||
try:
|
||||
ws = websocket.WebSocket()
|
||||
ws.connect(f"ws://{host}/ws?clientId={client_id}")
|
||||
print("WebSocket connection established.")
|
||||
except Exception as e:
|
||||
print(f"Failed to connect to WebSocket server: {e}")
|
||||
return None
|
||||
|
||||
try:
|
||||
images = get_images(ws, comfyui_prompt, client_id, base_url)
|
||||
except Exception as e:
|
||||
print(f"Error while receiving images: {e}")
|
||||
images = None
|
||||
|
||||
ws.close()
|
||||
|
||||
return images
|
||||
|
|
@ -1,10 +1,27 @@
|
|||
import logging
|
||||
|
||||
from litellm.proxy.proxy_server import ProxyConfig, initialize
|
||||
from litellm.proxy.proxy_server import app
|
||||
|
||||
from fastapi import FastAPI, Request, Depends, status
|
||||
from fastapi import FastAPI, Request, Depends, status, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.responses import StreamingResponse
|
||||
import json
|
||||
|
||||
from utils.utils import get_http_authorization_cred, get_current_user
|
||||
from config import ENV
|
||||
from config import SRC_LOG_LEVELS, ENV
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
|
||||
|
||||
|
||||
from config import (
|
||||
MODEL_FILTER_ENABLED,
|
||||
MODEL_FILTER_LIST,
|
||||
)
|
||||
|
||||
|
||||
proxy_config = ProxyConfig()
|
||||
|
||||
|
|
@ -26,16 +43,58 @@ async def on_startup():
|
|||
await startup()
|
||||
|
||||
|
||||
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
|
||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def auth_middleware(request: Request, call_next):
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
request.state.user = None
|
||||
|
||||
if ENV != "dev":
|
||||
try:
|
||||
user = get_current_user(get_http_authorization_cred(auth_header))
|
||||
print(user)
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"detail": str(e)})
|
||||
try:
|
||||
user = get_current_user(get_http_authorization_cred(auth_header))
|
||||
log.debug(f"user: {user}")
|
||||
request.state.user = user
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"detail": str(e)})
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
|
||||
response = await call_next(request)
|
||||
user = request.state.user
|
||||
|
||||
if "/models" in request.url.path:
|
||||
if isinstance(response, StreamingResponse):
|
||||
# Read the content of the streaming response
|
||||
body = b""
|
||||
async for chunk in response.body_iterator:
|
||||
body += chunk
|
||||
|
||||
data = json.loads(body.decode("utf-8"))
|
||||
|
||||
if app.state.MODEL_FILTER_ENABLED:
|
||||
if user and user.role == "user":
|
||||
data["data"] = list(
|
||||
filter(
|
||||
lambda model: model["id"]
|
||||
in app.state.MODEL_FILTER_LIST,
|
||||
data["data"],
|
||||
)
|
||||
)
|
||||
|
||||
# Modified Flag
|
||||
data["modified"] = True
|
||||
return JSONResponse(content=data)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
app.add_middleware(ModifyModelsResponseMiddleware)
|
||||
|
|
|
|||
|
|
@ -1,24 +1,43 @@
|
|||
from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
|
||||
from fastapi import (
|
||||
FastAPI,
|
||||
Request,
|
||||
Response,
|
||||
HTTPException,
|
||||
Depends,
|
||||
status,
|
||||
UploadFile,
|
||||
File,
|
||||
BackgroundTasks,
|
||||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
import os
|
||||
import copy
|
||||
import random
|
||||
import requests
|
||||
import json
|
||||
import uuid
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
from typing import Optional, List, Union
|
||||
|
||||
|
||||
from apps.web.models.users import Users
|
||||
from constants import ERROR_MESSAGES
|
||||
from utils.utils import decode_token, get_current_user, get_admin_user
|
||||
from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST
|
||||
|
||||
from typing import Optional, List, Union
|
||||
|
||||
from config import SRC_LOG_LEVELS, OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, UPLOAD_DIR
|
||||
from utils.misc import calculate_sha256
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
|
|
@ -69,7 +88,7 @@ class UrlUpdateForm(BaseModel):
|
|||
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
|
||||
app.state.OLLAMA_BASE_URLS = form_data.urls
|
||||
|
||||
print(app.state.OLLAMA_BASE_URLS)
|
||||
log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}")
|
||||
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
|
||||
|
||||
|
||||
|
|
@ -90,7 +109,7 @@ async def fetch_url(url):
|
|||
return await response.json()
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.error(f"Connection error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -114,7 +133,7 @@ def merge_models_lists(model_lists):
|
|||
|
||||
|
||||
async def get_all_models():
|
||||
print("get_all_models")
|
||||
log.info("get_all_models()")
|
||||
tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS]
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
|
|
@ -155,7 +174,7 @@ async def get_ollama_tags(
|
|||
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
|
|
@ -201,7 +220,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
|
|||
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
|
|
@ -227,18 +246,33 @@ async def pull_model(
|
|||
form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
|
||||
):
|
||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
||||
print(url)
|
||||
log.info(f"url: {url}")
|
||||
|
||||
r = None
|
||||
|
||||
def get_request():
|
||||
nonlocal url
|
||||
nonlocal r
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
try:
|
||||
REQUEST_POOL.append(request_id)
|
||||
|
||||
def stream_content():
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
yield chunk
|
||||
try:
|
||||
yield json.dumps({"id": request_id, "done": False}) + "\n"
|
||||
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
if request_id in REQUEST_POOL:
|
||||
yield chunk
|
||||
else:
|
||||
print("User: canceled request")
|
||||
break
|
||||
finally:
|
||||
if hasattr(r, "close"):
|
||||
r.close()
|
||||
if request_id in REQUEST_POOL:
|
||||
REQUEST_POOL.remove(request_id)
|
||||
|
||||
r = requests.request(
|
||||
method="POST",
|
||||
|
|
@ -259,8 +293,9 @@ async def pull_model(
|
|||
|
||||
try:
|
||||
return await run_in_threadpool(get_request)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
|
|
@ -299,7 +334,7 @@ async def push_model(
|
|||
)
|
||||
|
||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
||||
print(url)
|
||||
log.debug(f"url: {url}")
|
||||
|
||||
r = None
|
||||
|
||||
|
|
@ -331,7 +366,7 @@ async def push_model(
|
|||
try:
|
||||
return await run_in_threadpool(get_request)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
|
|
@ -359,9 +394,9 @@ class CreateModelForm(BaseModel):
|
|||
async def create_model(
|
||||
form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
|
||||
):
|
||||
print(form_data)
|
||||
log.debug(f"form_data: {form_data}")
|
||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
||||
print(url)
|
||||
log.info(f"url: {url}")
|
||||
|
||||
r = None
|
||||
|
||||
|
|
@ -383,7 +418,7 @@ async def create_model(
|
|||
|
||||
r.raise_for_status()
|
||||
|
||||
print(r)
|
||||
log.debug(f"r: {r}")
|
||||
|
||||
return StreamingResponse(
|
||||
stream_content(),
|
||||
|
|
@ -396,7 +431,7 @@ async def create_model(
|
|||
try:
|
||||
return await run_in_threadpool(get_request)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
|
|
@ -434,7 +469,7 @@ async def copy_model(
|
|||
)
|
||||
|
||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
||||
print(url)
|
||||
log.info(f"url: {url}")
|
||||
|
||||
try:
|
||||
r = requests.request(
|
||||
|
|
@ -444,11 +479,11 @@ async def copy_model(
|
|||
)
|
||||
r.raise_for_status()
|
||||
|
||||
print(r.text)
|
||||
log.debug(f"r.text: {r.text}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
|
|
@ -481,7 +516,7 @@ async def delete_model(
|
|||
)
|
||||
|
||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
||||
print(url)
|
||||
log.info(f"url: {url}")
|
||||
|
||||
try:
|
||||
r = requests.request(
|
||||
|
|
@ -491,11 +526,11 @@ async def delete_model(
|
|||
)
|
||||
r.raise_for_status()
|
||||
|
||||
print(r.text)
|
||||
log.debug(f"r.text: {r.text}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
|
|
@ -521,7 +556,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
|
|||
|
||||
url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
|
||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
||||
print(url)
|
||||
log.info(f"url: {url}")
|
||||
|
||||
try:
|
||||
r = requests.request(
|
||||
|
|
@ -533,7 +568,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
|
|||
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
|
|
@ -573,7 +608,7 @@ async def generate_embeddings(
|
|||
)
|
||||
|
||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
||||
print(url)
|
||||
log.info(f"url: {url}")
|
||||
|
||||
try:
|
||||
r = requests.request(
|
||||
|
|
@ -585,7 +620,7 @@ async def generate_embeddings(
|
|||
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
|
|
@ -633,7 +668,7 @@ async def generate_completion(
|
|||
)
|
||||
|
||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
||||
print(url)
|
||||
log.info(f"url: {url}")
|
||||
|
||||
r = None
|
||||
|
||||
|
|
@ -654,7 +689,7 @@ async def generate_completion(
|
|||
if request_id in REQUEST_POOL:
|
||||
yield chunk
|
||||
else:
|
||||
print("User: canceled request")
|
||||
log.warning("User: canceled request")
|
||||
break
|
||||
finally:
|
||||
if hasattr(r, "close"):
|
||||
|
|
@ -709,7 +744,7 @@ class GenerateChatCompletionForm(BaseModel):
|
|||
format: Optional[str] = None
|
||||
options: Optional[dict] = None
|
||||
template: Optional[str] = None
|
||||
stream: Optional[bool] = True
|
||||
stream: Optional[bool] = None
|
||||
keep_alive: Optional[Union[int, str]] = None
|
||||
|
||||
|
||||
|
|
@ -731,11 +766,11 @@ async def generate_chat_completion(
|
|||
)
|
||||
|
||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
||||
print(url)
|
||||
log.info(f"url: {url}")
|
||||
|
||||
r = None
|
||||
|
||||
print(form_data.model_dump_json(exclude_none=True).encode())
|
||||
log.debug("form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(form_data.model_dump_json(exclude_none=True).encode()))
|
||||
|
||||
def get_request():
|
||||
nonlocal form_data
|
||||
|
|
@ -754,7 +789,7 @@ async def generate_chat_completion(
|
|||
if request_id in REQUEST_POOL:
|
||||
yield chunk
|
||||
else:
|
||||
print("User: canceled request")
|
||||
log.warning("User: canceled request")
|
||||
break
|
||||
finally:
|
||||
if hasattr(r, "close"):
|
||||
|
|
@ -777,7 +812,7 @@ async def generate_chat_completion(
|
|||
headers=dict(r.headers),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
raise e
|
||||
|
||||
try:
|
||||
|
|
@ -831,7 +866,7 @@ async def generate_openai_chat_completion(
|
|||
)
|
||||
|
||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
||||
print(url)
|
||||
log.info(f"url: {url}")
|
||||
|
||||
r = None
|
||||
|
||||
|
|
@ -854,7 +889,7 @@ async def generate_openai_chat_completion(
|
|||
if request_id in REQUEST_POOL:
|
||||
yield chunk
|
||||
else:
|
||||
print("User: canceled request")
|
||||
log.warning("User: canceled request")
|
||||
break
|
||||
finally:
|
||||
if hasattr(r, "close"):
|
||||
|
|
@ -897,6 +932,211 @@ async def generate_openai_chat_completion(
|
|||
)
|
||||
|
||||
|
||||
class UrlForm(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class UploadBlobForm(BaseModel):
|
||||
filename: str
|
||||
|
||||
|
||||
def parse_huggingface_url(hf_url):
|
||||
try:
|
||||
# Parse the URL
|
||||
parsed_url = urlparse(hf_url)
|
||||
|
||||
# Get the path and split it into components
|
||||
path_components = parsed_url.path.split("/")
|
||||
|
||||
# Extract the desired output
|
||||
user_repo = "/".join(path_components[1:3])
|
||||
model_file = path_components[-1]
|
||||
|
||||
return model_file
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
async def download_file_stream(
|
||||
ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
|
||||
):
|
||||
done = False
|
||||
|
||||
if os.path.exists(file_path):
|
||||
current_size = os.path.getsize(file_path)
|
||||
else:
|
||||
current_size = 0
|
||||
|
||||
headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(file_url, headers=headers) as response:
|
||||
total_size = int(response.headers.get("content-length", 0)) + current_size
|
||||
|
||||
with open(file_path, "ab+") as file:
|
||||
async for data in response.content.iter_chunked(chunk_size):
|
||||
current_size += len(data)
|
||||
file.write(data)
|
||||
|
||||
done = current_size == total_size
|
||||
progress = round((current_size / total_size) * 100, 2)
|
||||
|
||||
yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
|
||||
|
||||
if done:
|
||||
file.seek(0)
|
||||
hashed = calculate_sha256(file)
|
||||
file.seek(0)
|
||||
|
||||
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
|
||||
response = requests.post(url, data=file)
|
||||
|
||||
if response.ok:
|
||||
res = {
|
||||
"done": done,
|
||||
"blob": f"sha256:{hashed}",
|
||||
"name": file_name,
|
||||
}
|
||||
os.remove(file_path)
|
||||
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
else:
|
||||
raise "Ollama: Could not create blob, Please try again."
|
||||
|
||||
|
||||
# def number_generator():
|
||||
# for i in range(1, 101):
|
||||
# yield f"data: {i}\n"
|
||||
|
||||
|
||||
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
|
||||
@app.post("/models/download")
|
||||
@app.post("/models/download/{url_idx}")
|
||||
async def download_model(
|
||||
form_data: UrlForm,
|
||||
url_idx: Optional[int] = None,
|
||||
):
|
||||
|
||||
if url_idx == None:
|
||||
url_idx = 0
|
||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
||||
|
||||
file_name = parse_huggingface_url(form_data.url)
|
||||
|
||||
if file_name:
|
||||
file_path = f"{UPLOAD_DIR}/{file_name}"
|
||||
return StreamingResponse(
|
||||
download_file_stream(url, form_data.url, file_path, file_name),
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@app.post("/models/upload")
|
||||
@app.post("/models/upload/{url_idx}")
|
||||
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
|
||||
if url_idx == None:
|
||||
url_idx = 0
|
||||
ollama_url = app.state.OLLAMA_BASE_URLS[url_idx]
|
||||
|
||||
file_path = f"{UPLOAD_DIR}/{file.filename}"
|
||||
|
||||
# Save file in chunks
|
||||
with open(file_path, "wb+") as f:
|
||||
for chunk in file.file:
|
||||
f.write(chunk)
|
||||
|
||||
def file_process_stream():
|
||||
nonlocal ollama_url
|
||||
total_size = os.path.getsize(file_path)
|
||||
chunk_size = 1024 * 1024
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
total = 0
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
chunk = f.read(chunk_size)
|
||||
if not chunk:
|
||||
done = True
|
||||
continue
|
||||
|
||||
total += len(chunk)
|
||||
progress = round((total / total_size) * 100, 2)
|
||||
|
||||
res = {
|
||||
"progress": progress,
|
||||
"total": total_size,
|
||||
"completed": total,
|
||||
}
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
|
||||
if done:
|
||||
f.seek(0)
|
||||
hashed = calculate_sha256(f)
|
||||
f.seek(0)
|
||||
|
||||
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
|
||||
response = requests.post(url, data=f)
|
||||
|
||||
if response.ok:
|
||||
res = {
|
||||
"done": done,
|
||||
"blob": f"sha256:{hashed}",
|
||||
"name": file.filename,
|
||||
}
|
||||
os.remove(file_path)
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
else:
|
||||
raise Exception(
|
||||
"Ollama: Could not create blob, Please try again."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
res = {"error": str(e)}
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
|
||||
return StreamingResponse(file_process_stream(), media_type="text/event-stream")
|
||||
|
||||
|
||||
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
|
||||
# if url_idx == None:
|
||||
# url_idx = 0
|
||||
# url = app.state.OLLAMA_BASE_URLS[url_idx]
|
||||
|
||||
# file_location = os.path.join(UPLOAD_DIR, file.filename)
|
||||
# total_size = file.size
|
||||
|
||||
# async def file_upload_generator(file):
|
||||
# print(file)
|
||||
# try:
|
||||
# async with aiofiles.open(file_location, "wb") as f:
|
||||
# completed_size = 0
|
||||
# while True:
|
||||
# chunk = await file.read(1024*1024)
|
||||
# if not chunk:
|
||||
# break
|
||||
# await f.write(chunk)
|
||||
# completed_size += len(chunk)
|
||||
# progress = (completed_size / total_size) * 100
|
||||
|
||||
# print(progress)
|
||||
# yield f'data: {json.dumps({"status": "uploading", "percentage": progress, "total": total_size, "completed": completed_size, "done": False})}\n'
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# yield f"data: {json.dumps({'status': 'error', 'message': str(e)})}\n"
|
||||
# finally:
|
||||
# await file.close()
|
||||
# print("done")
|
||||
# yield f'data: {json.dumps({"status": "completed", "percentage": 100, "total": total_size, "completed": completed_size, "done": True})}\n'
|
||||
|
||||
# return StreamingResponse(
|
||||
# file_upload_generator(copy.deepcopy(file)), media_type="text/event-stream"
|
||||
# )
|
||||
|
||||
|
||||
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||
async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)):
|
||||
url = app.state.OLLAMA_BASE_URLS[0]
|
||||
|
|
@ -947,7 +1187,7 @@ async def deprecated_proxy(path: str, request: Request, user=Depends(get_current
|
|||
if request_id in REQUEST_POOL:
|
||||
yield chunk
|
||||
else:
|
||||
print("User: canceled request")
|
||||
log.warning("User: canceled request")
|
||||
break
|
||||
finally:
|
||||
if hasattr(r, "close"):
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import requests
|
|||
import aiohttp
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -19,6 +20,7 @@ from utils.utils import (
|
|||
get_admin_user,
|
||||
)
|
||||
from config import (
|
||||
SRC_LOG_LEVELS,
|
||||
OPENAI_API_BASE_URLS,
|
||||
OPENAI_API_KEYS,
|
||||
CACHE_DIR,
|
||||
|
|
@ -31,6 +33,9 @@ from typing import List, Optional
|
|||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
|
|
@ -134,7 +139,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
|||
return FileResponse(file_path)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
|
|
@ -160,7 +165,7 @@ async def fetch_url(url, key):
|
|||
return await response.json()
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
log.error(f"Connection error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -182,7 +187,7 @@ def merge_models_lists(model_lists):
|
|||
|
||||
|
||||
async def get_all_models():
|
||||
print("get_all_models")
|
||||
log.info("get_all_models()")
|
||||
|
||||
if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "":
|
||||
models = {"data": []}
|
||||
|
|
@ -208,7 +213,7 @@ async def get_all_models():
|
|||
)
|
||||
}
|
||||
|
||||
print(models)
|
||||
log.info(f"models: {models}")
|
||||
app.state.MODELS = {model["id"]: model for model in models["data"]}
|
||||
|
||||
return models
|
||||
|
|
@ -246,7 +251,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
|
|||
|
||||
return response_data
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
|
|
@ -280,7 +285,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
|||
if body.get("model") == "gpt-4-vision-preview":
|
||||
if "max_tokens" not in body:
|
||||
body["max_tokens"] = 4000
|
||||
print("Modified body_dict:", body)
|
||||
log.debug("Modified body_dict:", body)
|
||||
|
||||
# Fix for ChatGPT calls failing because the num_ctx key is in body
|
||||
if "num_ctx" in body:
|
||||
|
|
@ -292,7 +297,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
|||
# Convert the modified body back to JSON
|
||||
body = json.dumps(body)
|
||||
except json.JSONDecodeError as e:
|
||||
print("Error loading request body into a dictionary:", e)
|
||||
log.error("Error loading request body into a dictionary:", e)
|
||||
|
||||
url = app.state.OPENAI_API_BASE_URLS[idx]
|
||||
key = app.state.OPENAI_API_KEYS[idx]
|
||||
|
|
@ -330,7 +335,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
|||
response_data = r.json()
|
||||
return response_data
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from fastapi import (
|
|||
Form,
|
||||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import os, shutil
|
||||
import os, shutil, logging
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
|
@ -54,6 +54,7 @@ from utils.misc import (
|
|||
)
|
||||
from utils.utils import get_current_user, get_admin_user
|
||||
from config import (
|
||||
SRC_LOG_LEVELS,
|
||||
UPLOAD_DIR,
|
||||
DOCS_DIR,
|
||||
RAG_EMBEDDING_MODEL,
|
||||
|
|
@ -66,6 +67,9 @@ from config import (
|
|||
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
#
|
||||
# if RAG_EMBEDDING_MODEL:
|
||||
# sentence_transformer_ef = SentenceTransformer(
|
||||
|
|
@ -110,40 +114,6 @@ class CollectionNameForm(BaseModel):
|
|||
class StoreWebForm(CollectionNameForm):
|
||||
url: str
|
||||
|
||||
|
||||
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
|
||||
)
|
||||
docs = text_splitter.split_documents(data)
|
||||
|
||||
texts = [doc.page_content for doc in docs]
|
||||
metadatas = [doc.metadata for doc in docs]
|
||||
|
||||
try:
|
||||
if overwrite:
|
||||
for collection in CHROMA_CLIENT.list_collections():
|
||||
if collection_name == collection.name:
|
||||
print(f"deleting existing collection {collection_name}")
|
||||
CHROMA_CLIENT.delete_collection(name=collection_name)
|
||||
|
||||
collection = CHROMA_CLIENT.create_collection(
|
||||
name=collection_name,
|
||||
embedding_function=app.state.sentence_transformer_ef,
|
||||
)
|
||||
|
||||
collection.add(
|
||||
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if e.__class__.__name__ == "UniqueConstraintError":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def get_status():
|
||||
return {
|
||||
|
|
@ -274,7 +244,7 @@ def query_doc_handler(
|
|||
embedding_function=app.state.sentence_transformer_ef,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
|
|
@ -318,13 +288,63 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
|
|||
"filename": form_data.url,
|
||||
}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=app.state.CHUNK_SIZE,
|
||||
chunk_overlap=app.state.CHUNK_OVERLAP,
|
||||
add_start_index=True,
|
||||
)
|
||||
docs = text_splitter.split_documents(data)
|
||||
return store_docs_in_vector_db(docs, collection_name, overwrite)
|
||||
|
||||
|
||||
def store_text_in_vector_db(
|
||||
text, metadata, collection_name, overwrite: bool = False
|
||||
) -> bool:
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=app.state.CHUNK_SIZE,
|
||||
chunk_overlap=app.state.CHUNK_OVERLAP,
|
||||
add_start_index=True,
|
||||
)
|
||||
docs = text_splitter.create_documents([text], metadatas=[metadata])
|
||||
return store_docs_in_vector_db(docs, collection_name, overwrite)
|
||||
|
||||
|
||||
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
|
||||
texts = [doc.page_content for doc in docs]
|
||||
metadatas = [doc.metadata for doc in docs]
|
||||
|
||||
try:
|
||||
if overwrite:
|
||||
for collection in CHROMA_CLIENT.list_collections():
|
||||
if collection_name == collection.name:
|
||||
print(f"deleting existing collection {collection_name}")
|
||||
CHROMA_CLIENT.delete_collection(name=collection_name)
|
||||
|
||||
collection = CHROMA_CLIENT.create_collection(
|
||||
name=collection_name,
|
||||
embedding_function=app.state.sentence_transformer_ef,
|
||||
)
|
||||
|
||||
collection.add(
|
||||
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if e.__class__.__name__ == "UniqueConstraintError":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_loader(filename: str, file_content_type: str, file_path: str):
|
||||
file_ext = filename.split(".")[-1].lower()
|
||||
known_type = True
|
||||
|
|
@ -416,7 +436,7 @@ def store_doc(
|
|||
):
|
||||
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
||||
|
||||
print(file.content_type)
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
try:
|
||||
filename = file.filename
|
||||
file_path = f"{UPLOAD_DIR}/{filename}"
|
||||
|
|
@ -447,7 +467,7 @@ def store_doc(
|
|||
detail=ERROR_MESSAGES.DEFAULT(),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
if "No pandoc was found" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -460,6 +480,37 @@ def store_doc(
|
|||
)
|
||||
|
||||
|
||||
class TextRAGForm(BaseModel):
|
||||
name: str
|
||||
content: str
|
||||
collection_name: Optional[str] = None
|
||||
|
||||
|
||||
@app.post("/text")
|
||||
def store_text(
|
||||
form_data: TextRAGForm,
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
|
||||
collection_name = form_data.collection_name
|
||||
if collection_name == None:
|
||||
collection_name = calculate_sha256_string(form_data.content)
|
||||
|
||||
result = store_text_in_vector_db(
|
||||
form_data.content,
|
||||
metadata={"name": form_data.name, "created_by": user.id},
|
||||
collection_name=collection_name,
|
||||
)
|
||||
|
||||
if result:
|
||||
return {"status": True, "collection_name": collection_name}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=ERROR_MESSAGES.DEFAULT(),
|
||||
)
|
||||
|
||||
|
||||
@app.get("/scan")
|
||||
def scan_docs_dir(user=Depends(get_admin_user)):
|
||||
for path in Path(DOCS_DIR).rglob("./**/*"):
|
||||
|
|
@ -512,7 +563,7 @@ def scan_docs_dir(user=Depends(get_admin_user)):
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
|
||||
return True
|
||||
|
||||
|
|
@ -533,11 +584,11 @@ def reset(user=Depends(get_admin_user)) -> bool:
|
|||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path)
|
||||
except Exception as e:
|
||||
print("Failed to delete %s. Reason: %s" % (file_path, e))
|
||||
log.error("Failed to delete %s. Reason: %s" % (file_path, e))
|
||||
|
||||
try:
|
||||
CHROMA_CLIENT.reset()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
import re
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from config import CHROMA_CLIENT
|
||||
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def query_doc(collection_name: str, query: str, k: int, embedding_function):
|
||||
|
|
@ -97,7 +101,7 @@ def rag_template(template: str, context: str, query: str):
|
|||
|
||||
|
||||
def rag_messages(docs, messages, template, k, embedding_function):
|
||||
print(docs)
|
||||
log.debug(f"docs: {docs}")
|
||||
|
||||
last_user_message_idx = None
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
|
|
@ -137,6 +141,8 @@ def rag_messages(docs, messages, template, k, embedding_function):
|
|||
k=k,
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
elif doc["type"] == "text":
|
||||
context = doc["content"]
|
||||
else:
|
||||
context = query_doc(
|
||||
collection_name=doc["collection_name"],
|
||||
|
|
@ -145,7 +151,7 @@ def rag_messages(docs, messages, template, k, embedding_function):
|
|||
embedding_function=embedding_function,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
context = None
|
||||
|
||||
relevant_contexts.append(context)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,16 @@
|
|||
from peewee import *
|
||||
from config import DATA_DIR
|
||||
from config import SRC_LOG_LEVELS, DATA_DIR
|
||||
import os
|
||||
import logging
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["DB"])
|
||||
|
||||
# Check if the file exists
|
||||
if os.path.exists(f"{DATA_DIR}/ollama.db"):
|
||||
# Rename the file
|
||||
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
|
||||
print("File renamed successfully.")
|
||||
log.info("File renamed successfully.")
|
||||
else:
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from config import (
|
|||
DEFAULT_USER_ROLE,
|
||||
ENABLE_SIGNUP,
|
||||
USER_PERMISSIONS,
|
||||
WEBHOOK_URL,
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
|
|
@ -32,6 +33,7 @@ app.state.DEFAULT_MODELS = DEFAULT_MODELS
|
|||
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
|
||||
app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
|
||||
app.state.USER_PERMISSIONS = USER_PERMISSIONS
|
||||
app.state.WEBHOOK_URL = WEBHOOK_URL
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from pydantic import BaseModel
|
|||
from typing import List, Union, Optional
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
from peewee import *
|
||||
|
||||
from apps.web.models.users import UserModel, Users
|
||||
|
|
@ -9,6 +10,10 @@ from utils.utils import verify_password
|
|||
|
||||
from apps.web.internal.db import DB
|
||||
|
||||
from config import SRC_LOG_LEVELS
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
####################
|
||||
# DB MODEL
|
||||
####################
|
||||
|
|
@ -86,7 +91,7 @@ class AuthsTable:
|
|||
def insert_new_auth(
|
||||
self, email: str, password: str, name: str, role: str = "pending"
|
||||
) -> Optional[UserModel]:
|
||||
print("insert_new_auth")
|
||||
log.info("insert_new_auth")
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
|
|
@ -103,7 +108,7 @@ class AuthsTable:
|
|||
return None
|
||||
|
||||
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
|
||||
print("authenticate_user", email)
|
||||
log.info(f"authenticate_user: {email}")
|
||||
try:
|
||||
auth = Auth.get(Auth.email == email, Auth.active == True)
|
||||
if auth:
|
||||
|
|
|
|||
|
|
@ -95,20 +95,6 @@ class ChatTable:
|
|||
except:
|
||||
return None
|
||||
|
||||
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
|
||||
try:
|
||||
query = Chat.update(
|
||||
chat=json.dumps(chat),
|
||||
title=chat["title"] if "title" in chat else "New Chat",
|
||||
timestamp=int(time.time()),
|
||||
).where(Chat.id == id)
|
||||
query.execute()
|
||||
|
||||
chat = Chat.get(Chat.id == id)
|
||||
return ChatModel(**model_to_dict(chat))
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_chat_lists_by_user_id(
|
||||
self, user_id: str, skip: int = 0, limit: int = 50
|
||||
) -> List[ChatModel]:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from peewee import *
|
|||
from playhouse.shortcuts import model_to_dict
|
||||
from typing import List, Union, Optional
|
||||
import time
|
||||
import logging
|
||||
|
||||
from utils.utils import decode_token
|
||||
from utils.misc import get_gravatar_url
|
||||
|
|
@ -11,6 +12,10 @@ from apps.web.internal.db import DB
|
|||
|
||||
import json
|
||||
|
||||
from config import SRC_LOG_LEVELS
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
####################
|
||||
# Documents DB Schema
|
||||
####################
|
||||
|
|
@ -118,7 +123,7 @@ class DocumentsTable:
|
|||
doc = Document.get(Document.name == form_data.name)
|
||||
return DocumentModel(**model_to_dict(doc))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
return None
|
||||
|
||||
def update_doc_content_by_name(
|
||||
|
|
@ -138,7 +143,7 @@ class DocumentsTable:
|
|||
doc = Document.get(Document.name == name)
|
||||
return DocumentModel(**model_to_dict(doc))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
return None
|
||||
|
||||
def delete_doc_by_name(self, name: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -6,9 +6,14 @@ from playhouse.shortcuts import model_to_dict
|
|||
import json
|
||||
import uuid
|
||||
import time
|
||||
import logging
|
||||
|
||||
from apps.web.internal.db import DB
|
||||
|
||||
from config import SRC_LOG_LEVELS
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
####################
|
||||
# Tag DB Schema
|
||||
####################
|
||||
|
|
@ -173,7 +178,7 @@ class TagTable:
|
|||
(ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)
|
||||
)
|
||||
res = query.execute() # Remove the rows, return number of rows removed.
|
||||
print(res)
|
||||
log.debug(f"res: {res}")
|
||||
|
||||
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
|
||||
if tag_count == 0:
|
||||
|
|
@ -185,7 +190,7 @@ class TagTable:
|
|||
|
||||
return True
|
||||
except Exception as e:
|
||||
print("delete_tag", e)
|
||||
log.error(f"delete_tag: {e}")
|
||||
return False
|
||||
|
||||
def delete_tag_by_tag_name_and_chat_id_and_user_id(
|
||||
|
|
@ -198,7 +203,7 @@ class TagTable:
|
|||
& (ChatIdTag.user_id == user_id)
|
||||
)
|
||||
res = query.execute() # Remove the rows, return number of rows removed.
|
||||
print(res)
|
||||
log.debug(f"res: {res}")
|
||||
|
||||
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
|
||||
if tag_count == 0:
|
||||
|
|
@ -210,7 +215,7 @@ class TagTable:
|
|||
|
||||
return True
|
||||
except Exception as e:
|
||||
print("delete_tag", e)
|
||||
log.error(f"delete_tag: {e}")
|
||||
return False
|
||||
|
||||
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -27,7 +27,8 @@ from utils.utils import (
|
|||
create_token,
|
||||
)
|
||||
from utils.misc import parse_duration, validate_email_format
|
||||
from constants import ERROR_MESSAGES
|
||||
from utils.webhook import post_webhook
|
||||
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -155,6 +156,17 @@ async def signup(request: Request, form_data: SignupForm):
|
|||
)
|
||||
# response.set_cookie(key='token', value=token, httponly=True)
|
||||
|
||||
if request.app.state.WEBHOOK_URL:
|
||||
post_webhook(
|
||||
request.app.state.WEBHOOK_URL,
|
||||
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
{
|
||||
"action": "signup",
|
||||
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
"user": user.model_dump_json(exclude_none=True),
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"token": token,
|
||||
"token_type": "Bearer",
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from utils.utils import get_current_user, get_admin_user
|
|||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
import logging
|
||||
|
||||
from apps.web.models.users import Users
|
||||
from apps.web.models.chats import (
|
||||
|
|
@ -27,6 +28,10 @@ from apps.web.models.tags import (
|
|||
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
from config import SRC_LOG_LEVELS
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################
|
||||
|
|
@ -78,7 +83,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
|
|||
chat = Chats.insert_new_chat(user.id, form_data)
|
||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
|
@ -95,7 +100,7 @@ async def get_all_tags(user=Depends(get_current_user)):
|
|||
tags = Tags.get_tags_by_user_id(user.id)
|
||||
return tags
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from fastapi import APIRouter
|
|||
from pydantic import BaseModel
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
|
||||
from apps.web.models.auths import Auths
|
||||
|
|
@ -14,6 +15,10 @@ from apps.web.models.auths import Auths
|
|||
from utils.utils import get_current_user, get_password_hash, get_admin_user
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
from config import SRC_LOG_LEVELS
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################
|
||||
|
|
@ -83,7 +88,7 @@ async def update_user_by_id(
|
|||
|
||||
if form_data.password:
|
||||
hashed = get_password_hash(form_data.password)
|
||||
print(hashed)
|
||||
log.debug(f"hashed: {hashed}")
|
||||
Auths.update_user_password_by_id(user_id, hashed)
|
||||
|
||||
Auths.update_email_by_id(user_id, form_data.email.lower())
|
||||
|
|
|
|||
|
|
@ -21,155 +21,6 @@ from constants import ERROR_MESSAGES
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
class UploadBlobForm(BaseModel):
|
||||
filename: str
|
||||
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
def parse_huggingface_url(hf_url):
|
||||
try:
|
||||
# Parse the URL
|
||||
parsed_url = urlparse(hf_url)
|
||||
|
||||
# Get the path and split it into components
|
||||
path_components = parsed_url.path.split("/")
|
||||
|
||||
# Extract the desired output
|
||||
user_repo = "/".join(path_components[1:3])
|
||||
model_file = path_components[-1]
|
||||
|
||||
return model_file
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
|
||||
done = False
|
||||
|
||||
if os.path.exists(file_path):
|
||||
current_size = os.path.getsize(file_path)
|
||||
else:
|
||||
current_size = 0
|
||||
|
||||
headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
total_size = int(response.headers.get("content-length", 0)) + current_size
|
||||
|
||||
with open(file_path, "ab+") as file:
|
||||
async for data in response.content.iter_chunked(chunk_size):
|
||||
current_size += len(data)
|
||||
file.write(data)
|
||||
|
||||
done = current_size == total_size
|
||||
progress = round((current_size / total_size) * 100, 2)
|
||||
yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
|
||||
|
||||
if done:
|
||||
file.seek(0)
|
||||
hashed = calculate_sha256(file)
|
||||
file.seek(0)
|
||||
|
||||
url = f"{OLLAMA_BASE_URLS[0]}/api/blobs/sha256:{hashed}"
|
||||
response = requests.post(url, data=file)
|
||||
|
||||
if response.ok:
|
||||
res = {
|
||||
"done": done,
|
||||
"blob": f"sha256:{hashed}",
|
||||
"name": file_name,
|
||||
}
|
||||
os.remove(file_path)
|
||||
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
else:
|
||||
raise "Ollama: Could not create blob, Please try again."
|
||||
|
||||
|
||||
@router.get("/download")
|
||||
async def download(
|
||||
url: str,
|
||||
):
|
||||
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
|
||||
file_name = parse_huggingface_url(url)
|
||||
|
||||
if file_name:
|
||||
file_path = f"{UPLOAD_DIR}/{file_name}"
|
||||
|
||||
return StreamingResponse(
|
||||
download_file_stream(url, file_path, file_name),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
def upload(file: UploadFile = File(...)):
|
||||
file_path = f"{UPLOAD_DIR}/{file.filename}"
|
||||
|
||||
# Save file in chunks
|
||||
with open(file_path, "wb+") as f:
|
||||
for chunk in file.file:
|
||||
f.write(chunk)
|
||||
|
||||
def file_process_stream():
|
||||
total_size = os.path.getsize(file_path)
|
||||
chunk_size = 1024 * 1024
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
total = 0
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
chunk = f.read(chunk_size)
|
||||
if not chunk:
|
||||
done = True
|
||||
continue
|
||||
|
||||
total += len(chunk)
|
||||
progress = round((total / total_size) * 100, 2)
|
||||
|
||||
res = {
|
||||
"progress": progress,
|
||||
"total": total_size,
|
||||
"completed": total,
|
||||
}
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
|
||||
if done:
|
||||
f.seek(0)
|
||||
hashed = calculate_sha256(f)
|
||||
f.seek(0)
|
||||
|
||||
url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
|
||||
response = requests.post(url, data=f)
|
||||
|
||||
if response.ok:
|
||||
res = {
|
||||
"done": done,
|
||||
"blob": f"sha256:{hashed}",
|
||||
"name": file.filename,
|
||||
}
|
||||
os.remove(file_path)
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
else:
|
||||
raise Exception(
|
||||
"Ollama: Could not create blob, Please try again."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
res = {"error": str(e)}
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
|
||||
return StreamingResponse(file_process_stream(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/gravatar")
|
||||
async def get_gravatar(
|
||||
email: str,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import os
|
||||
import sys
|
||||
import logging
|
||||
import chromadb
|
||||
from chromadb import Settings
|
||||
from base64 import b64encode
|
||||
|
|
@ -21,7 +23,7 @@ try:
|
|||
|
||||
load_dotenv(find_dotenv("../.env"))
|
||||
except ImportError:
|
||||
print("dotenv not installed, skipping...")
|
||||
log.warning("dotenv not installed, skipping...")
|
||||
|
||||
WEBUI_NAME = "Aura"
|
||||
shutil.copyfile("../build/favicon.png", "./static/favicon.png")
|
||||
|
|
@ -100,6 +102,34 @@ for version in soup.find_all("h2"):
|
|||
CHANGELOG = changelog_json
|
||||
|
||||
|
||||
####################################
|
||||
# LOGGING
|
||||
####################################
|
||||
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
|
||||
|
||||
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
|
||||
if GLOBAL_LOG_LEVEL in log_levels:
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
|
||||
else:
|
||||
GLOBAL_LOG_LEVEL = "INFO"
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
|
||||
|
||||
log_sources = ["AUDIO", "CONFIG", "DB", "IMAGES", "LITELLM", "MAIN", "MODELS", "OLLAMA", "OPENAI", "RAG"]
|
||||
|
||||
SRC_LOG_LEVELS = {}
|
||||
|
||||
for source in log_sources:
|
||||
log_env_var = source + "_LOG_LEVEL"
|
||||
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
|
||||
if SRC_LOG_LEVELS[source] not in log_levels:
|
||||
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
|
||||
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
|
||||
|
||||
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
|
||||
|
||||
|
||||
####################################
|
||||
# CUSTOM_NAME
|
||||
####################################
|
||||
|
|
@ -125,7 +155,7 @@ if CUSTOM_NAME:
|
|||
|
||||
WEBUI_NAME = data["name"]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
log.exception(e)
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -194,9 +224,9 @@ def create_config_file(file_path):
|
|||
LITELLM_CONFIG_PATH = f"{DATA_DIR}/litellm/config.yaml"
|
||||
|
||||
if not os.path.exists(LITELLM_CONFIG_PATH):
|
||||
print("Config file doesn't exist. Creating...")
|
||||
log.info("Config file doesn't exist. Creating...")
|
||||
create_config_file(LITELLM_CONFIG_PATH)
|
||||
print("Config file created successfully.")
|
||||
log.info("Config file created successfully.")
|
||||
|
||||
|
||||
####################################
|
||||
|
|
@ -290,13 +320,19 @@ DEFAULT_PROMPT_SUGGESTIONS = (
|
|||
|
||||
|
||||
DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending")
|
||||
USER_PERMISSIONS = {"chat": {"deletion": True}}
|
||||
|
||||
USER_PERMISSIONS_CHAT_DELETION = (
|
||||
os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
|
||||
)
|
||||
|
||||
USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}
|
||||
|
||||
|
||||
MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", False)
|
||||
MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", "False").lower() == "true"
|
||||
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
|
||||
MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
|
||||
|
||||
WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "")
|
||||
|
||||
####################################
|
||||
# WEBUI_VERSION
|
||||
|
|
@ -370,3 +406,4 @@ WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models"
|
|||
####################################
|
||||
|
||||
AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
|
||||
COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
|
||||
|
|
|
|||
|
|
@ -5,6 +5,13 @@ class MESSAGES(str, Enum):
|
|||
DEFAULT = lambda msg="": f"{msg if msg else ''}"
|
||||
|
||||
|
||||
class WEBHOOK_MESSAGES(str, Enum):
|
||||
DEFAULT = lambda msg="": f"{msg if msg else ''}"
|
||||
USER_SIGNUP = lambda username="": (
|
||||
f"New user signed up: {username}" if username else "New user signed up"
|
||||
)
|
||||
|
||||
|
||||
class ERROR_MESSAGES(str, Enum):
|
||||
def __str__(self) -> str:
|
||||
return super().__str__()
|
||||
|
|
@ -46,7 +53,7 @@ class ERROR_MESSAGES(str, Enum):
|
|||
|
||||
PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance."
|
||||
INCORRECT_FORMAT = (
|
||||
lambda err="": f"Invalid format. Please use the correct format{err if err else ''}"
|
||||
lambda err="": f"Invalid format. Please use the correct format{err}"
|
||||
)
|
||||
RATE_LIMIT_EXCEEDED = "API rate limit exceeded"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
{
|
||||
"version": "0.0.1",
|
||||
"version": 0,
|
||||
"ui": {
|
||||
"prompt_suggestions": [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import markdown
|
|||
import time
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import requests
|
||||
|
||||
from fastapi import FastAPI, Request, Depends, status
|
||||
|
|
@ -38,9 +39,15 @@ from config import (
|
|||
FRONTEND_BUILD_DIR,
|
||||
MODEL_FILTER_ENABLED,
|
||||
MODEL_FILTER_LIST,
|
||||
GLOBAL_LOG_LEVEL,
|
||||
SRC_LOG_LEVELS,
|
||||
WEBHOOK_URL,
|
||||
)
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
class SPAStaticFiles(StaticFiles):
|
||||
async def get_response(self, path: str, scope):
|
||||
|
|
@ -58,6 +65,9 @@ app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
|
|||
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
|
||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
|
||||
app.state.WEBHOOK_URL = WEBHOOK_URL
|
||||
|
||||
|
||||
origins = ["*"]
|
||||
|
||||
|
||||
|
|
@ -66,7 +76,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|||
if request.method == "POST" and (
|
||||
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
|
||||
):
|
||||
print(request.url.path)
|
||||
log.debug(f"request.url.path: {request.url.path}")
|
||||
|
||||
# Read the original request body
|
||||
body = await request.body()
|
||||
|
|
@ -89,7 +99,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|||
)
|
||||
del data["docs"]
|
||||
|
||||
print(data["messages"])
|
||||
log.debug(f"data['messages']: {data['messages']}")
|
||||
|
||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||
|
||||
|
|
@ -178,7 +188,7 @@ class ModelFilterConfigForm(BaseModel):
|
|||
|
||||
|
||||
@app.post("/api/config/model/filter")
|
||||
async def get_model_filter_config(
|
||||
async def update_model_filter_config(
|
||||
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
|
||||
):
|
||||
|
||||
|
|
@ -197,6 +207,28 @@ async def get_model_filter_config(
|
|||
}
|
||||
|
||||
|
||||
@app.get("/api/webhook")
|
||||
async def get_webhook_url(user=Depends(get_admin_user)):
|
||||
return {
|
||||
"url": app.state.WEBHOOK_URL,
|
||||
}
|
||||
|
||||
|
||||
class UrlForm(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
@app.post("/api/webhook")
|
||||
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
|
||||
app.state.WEBHOOK_URL = form_data.url
|
||||
|
||||
webui_app.state.WEBHOOK_URL = app.state.WEBHOOK_URL
|
||||
|
||||
return {
|
||||
"url": app.state.WEBHOOK_URL,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/version")
|
||||
async def get_app_config():
|
||||
|
||||
|
|
|
|||
|
|
@ -45,3 +45,4 @@ PyJWT
|
|||
pyjwt[crypto]
|
||||
|
||||
black
|
||||
langfuse
|
||||
|
|
|
|||
20
backend/utils/webhook.py
Normal file
20
backend/utils/webhook.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
import requests
|
||||
|
||||
|
||||
def post_webhook(url: str, message: str, event_data: dict) -> bool:
|
||||
try:
|
||||
payload = {}
|
||||
|
||||
if "https://hooks.slack.com" in url:
|
||||
payload["text"] = message
|
||||
elif "https://discord.com/api/webhooks" in url:
|
||||
payload["content"] = message
|
||||
else:
|
||||
payload = {**event_data}
|
||||
|
||||
r = requests.post(url, json=payload)
|
||||
r.raise_for_status()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return False
|
||||
Loading…
Add table
Add a link
Reference in a new issue