forked from open-webui/open-webui
		
	Merge branch 'dev' into debug_print
This commit is contained in:
		
						commit
						371dfc1143
					
				
					 42 changed files with 4942 additions and 5335 deletions
				
			
		|  | @ -18,6 +18,8 @@ from utils.utils import ( | |||
|     get_current_user, | ||||
|     get_admin_user, | ||||
| ) | ||||
| 
 | ||||
| from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image | ||||
| from utils.misc import calculate_sha256 | ||||
| from typing import Optional | ||||
| from pydantic import BaseModel | ||||
|  | @ -27,7 +29,8 @@ import base64 | |||
| import json | ||||
| import logging | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS, CACHE_DIR, AUTOMATIC1111_BASE_URL | ||||
| from config import SRC_LOG_LEVELS, CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL | ||||
| 
 | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["IMAGES"]) | ||||
|  | @ -52,6 +55,8 @@ app.state.MODEL = "" | |||
| 
 | ||||
| 
 | ||||
| app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL | ||||
| app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL | ||||
| 
 | ||||
| 
 | ||||
| app.state.IMAGE_SIZE = "512x512" | ||||
| app.state.IMAGE_STEPS = 50 | ||||
|  | @ -74,32 +79,48 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user | |||
|     return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} | ||||
| 
 | ||||
| 
 | ||||
| class UrlUpdateForm(BaseModel): | ||||
|     url: str | ||||
| class EngineUrlUpdateForm(BaseModel): | ||||
|     AUTOMATIC1111_BASE_URL: Optional[str] = None | ||||
|     COMFYUI_BASE_URL: Optional[str] = None | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/url") | ||||
| async def get_automatic1111_url(user=Depends(get_admin_user)): | ||||
|     return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL} | ||||
| async def get_engine_url(user=Depends(get_admin_user)): | ||||
|     return { | ||||
|         "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, | ||||
|         "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/url/update") | ||||
| async def update_automatic1111_url( | ||||
|     form_data: UrlUpdateForm, user=Depends(get_admin_user) | ||||
| async def update_engine_url( | ||||
|     form_data: EngineUrlUpdateForm, user=Depends(get_admin_user) | ||||
| ): | ||||
| 
 | ||||
|     if form_data.url == "": | ||||
|     if form_data.AUTOMATIC1111_BASE_URL == None: | ||||
|         app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL | ||||
|     else: | ||||
|         url = form_data.url.strip("/") | ||||
|         url = form_data.AUTOMATIC1111_BASE_URL.strip("/") | ||||
|         try: | ||||
|             r = requests.head(url) | ||||
|             app.state.AUTOMATIC1111_BASE_URL = url | ||||
|         except Exception as e: | ||||
|             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) | ||||
| 
 | ||||
|     if form_data.COMFYUI_BASE_URL == None: | ||||
|         app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL | ||||
|     else: | ||||
|         url = form_data.COMFYUI_BASE_URL.strip("/") | ||||
| 
 | ||||
|         try: | ||||
|             r = requests.head(url) | ||||
|             app.state.COMFYUI_BASE_URL = url | ||||
|         except Exception as e: | ||||
|             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) | ||||
| 
 | ||||
|     return { | ||||
|         "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, | ||||
|         "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, | ||||
|         "status": True, | ||||
|     } | ||||
| 
 | ||||
|  | @ -189,6 +210,18 @@ def get_models(user=Depends(get_current_user)): | |||
|                 {"id": "dall-e-2", "name": "DALL·E 2"}, | ||||
|                 {"id": "dall-e-3", "name": "DALL·E 3"}, | ||||
|             ] | ||||
|         elif app.state.ENGINE == "comfyui": | ||||
| 
 | ||||
|             r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info") | ||||
|             info = r.json() | ||||
| 
 | ||||
|             return list( | ||||
|                 map( | ||||
|                     lambda model: {"id": model, "name": model}, | ||||
|                     info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0], | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|         else: | ||||
|             r = requests.get( | ||||
|                 url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" | ||||
|  | @ -210,6 +243,8 @@ async def get_default_model(user=Depends(get_admin_user)): | |||
|     try: | ||||
|         if app.state.ENGINE == "openai": | ||||
|             return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} | ||||
|         elif app.state.ENGINE == "comfyui": | ||||
|             return {"model": app.state.MODEL if app.state.MODEL else ""} | ||||
|         else: | ||||
|             r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") | ||||
|             options = r.json() | ||||
|  | @ -224,10 +259,12 @@ class UpdateModelForm(BaseModel): | |||
| 
 | ||||
| 
 | ||||
| def set_model_handler(model: str): | ||||
| 
 | ||||
|     if app.state.ENGINE == "openai": | ||||
|         app.state.MODEL = model | ||||
|         return app.state.MODEL | ||||
|     if app.state.ENGINE == "comfyui": | ||||
|         app.state.MODEL = model | ||||
|         return app.state.MODEL | ||||
|     else: | ||||
|         r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") | ||||
|         options = r.json() | ||||
|  | @ -275,12 +312,31 @@ def save_b64_image(b64_str): | |||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| def save_url_image(url): | ||||
|     image_id = str(uuid.uuid4()) | ||||
|     file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png") | ||||
| 
 | ||||
|     try: | ||||
|         r = requests.get(url) | ||||
|         r.raise_for_status() | ||||
| 
 | ||||
|         with open(file_path, "wb") as image_file: | ||||
|             image_file.write(r.content) | ||||
| 
 | ||||
|         return image_id | ||||
|     except Exception as e: | ||||
|         print(f"Error saving image: {e}") | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/generations") | ||||
| def generate_image( | ||||
|     form_data: GenerateImageForm, | ||||
|     user=Depends(get_current_user), | ||||
| ): | ||||
| 
 | ||||
|     width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) | ||||
| 
 | ||||
|     r = None | ||||
|     try: | ||||
|         if app.state.ENGINE == "openai": | ||||
|  | @ -318,12 +374,47 @@ def generate_image( | |||
| 
 | ||||
|             return images | ||||
| 
 | ||||
|         elif app.state.ENGINE == "comfyui": | ||||
| 
 | ||||
|             data = { | ||||
|                 "prompt": form_data.prompt, | ||||
|                 "width": width, | ||||
|                 "height": height, | ||||
|                 "n": form_data.n, | ||||
|             } | ||||
| 
 | ||||
|             if app.state.IMAGE_STEPS != None: | ||||
|                 data["steps"] = app.state.IMAGE_STEPS | ||||
| 
 | ||||
|             if form_data.negative_prompt != None: | ||||
|                 data["negative_prompt"] = form_data.negative_prompt | ||||
| 
 | ||||
|             data = ImageGenerationPayload(**data) | ||||
| 
 | ||||
|             res = comfyui_generate_image( | ||||
|                 app.state.MODEL, | ||||
|                 data, | ||||
|                 user.id, | ||||
|                 app.state.COMFYUI_BASE_URL, | ||||
|             ) | ||||
|             print(res) | ||||
| 
 | ||||
|             images = [] | ||||
| 
 | ||||
|             for image in res["data"]: | ||||
|                 image_id = save_url_image(image["url"]) | ||||
|                 images.append({"url": f"/cache/image/generations/{image_id}.png"}) | ||||
|                 file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") | ||||
| 
 | ||||
|                 with open(file_body_path, "w") as f: | ||||
|                     json.dump(data.model_dump(exclude_none=True), f) | ||||
| 
 | ||||
|             print(images) | ||||
|             return images | ||||
|         else: | ||||
|             if form_data.model: | ||||
|                 set_model_handler(form_data.model) | ||||
| 
 | ||||
|             width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) | ||||
| 
 | ||||
|             data = { | ||||
|                 "prompt": form_data.prompt, | ||||
|                 "batch_size": form_data.n, | ||||
|  |  | |||
							
								
								
									
										228
									
								
								backend/apps/images/utils/comfyui.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										228
									
								
								backend/apps/images/utils/comfyui.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,228 @@ | |||
| import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) | ||||
| import uuid | ||||
| import json | ||||
| import urllib.request | ||||
| import urllib.parse | ||||
| import random | ||||
| 
 | ||||
| from pydantic import BaseModel | ||||
| 
 | ||||
| from typing import Optional | ||||
| 
 | ||||
| COMFYUI_DEFAULT_PROMPT = """ | ||||
| { | ||||
|   "3": { | ||||
|     "inputs": { | ||||
|       "seed": 0, | ||||
|       "steps": 20, | ||||
|       "cfg": 8, | ||||
|       "sampler_name": "euler", | ||||
|       "scheduler": "normal", | ||||
|       "denoise": 1, | ||||
|       "model": [ | ||||
|         "4", | ||||
|         0 | ||||
|       ], | ||||
|       "positive": [ | ||||
|         "6", | ||||
|         0 | ||||
|       ], | ||||
|       "negative": [ | ||||
|         "7", | ||||
|         0 | ||||
|       ], | ||||
|       "latent_image": [ | ||||
|         "5", | ||||
|         0 | ||||
|       ] | ||||
|     }, | ||||
|     "class_type": "KSampler", | ||||
|     "_meta": { | ||||
|       "title": "KSampler" | ||||
|     } | ||||
|   }, | ||||
|   "4": { | ||||
|     "inputs": { | ||||
|       "ckpt_name": "model.safetensors" | ||||
|     }, | ||||
|     "class_type": "CheckpointLoaderSimple", | ||||
|     "_meta": { | ||||
|       "title": "Load Checkpoint" | ||||
|     } | ||||
|   }, | ||||
|   "5": { | ||||
|     "inputs": { | ||||
|       "width": 512, | ||||
|       "height": 512, | ||||
|       "batch_size": 1 | ||||
|     }, | ||||
|     "class_type": "EmptyLatentImage", | ||||
|     "_meta": { | ||||
|       "title": "Empty Latent Image" | ||||
|     } | ||||
|   }, | ||||
|   "6": { | ||||
|     "inputs": { | ||||
|       "text": "Prompt", | ||||
|       "clip": [ | ||||
|         "4", | ||||
|         1 | ||||
|       ] | ||||
|     }, | ||||
|     "class_type": "CLIPTextEncode", | ||||
|     "_meta": { | ||||
|       "title": "CLIP Text Encode (Prompt)" | ||||
|     } | ||||
|   }, | ||||
|   "7": { | ||||
|     "inputs": { | ||||
|       "text": "Negative Prompt", | ||||
|       "clip": [ | ||||
|         "4", | ||||
|         1 | ||||
|       ] | ||||
|     }, | ||||
|     "class_type": "CLIPTextEncode", | ||||
|     "_meta": { | ||||
|       "title": "CLIP Text Encode (Prompt)" | ||||
|     } | ||||
|   }, | ||||
|   "8": { | ||||
|     "inputs": { | ||||
|       "samples": [ | ||||
|         "3", | ||||
|         0 | ||||
|       ], | ||||
|       "vae": [ | ||||
|         "4", | ||||
|         2 | ||||
|       ] | ||||
|     }, | ||||
|     "class_type": "VAEDecode", | ||||
|     "_meta": { | ||||
|       "title": "VAE Decode" | ||||
|     } | ||||
|   }, | ||||
|   "9": { | ||||
|     "inputs": { | ||||
|       "filename_prefix": "ComfyUI", | ||||
|       "images": [ | ||||
|         "8", | ||||
|         0 | ||||
|       ] | ||||
|     }, | ||||
|     "class_type": "SaveImage", | ||||
|     "_meta": { | ||||
|       "title": "Save Image" | ||||
|     } | ||||
|   } | ||||
| } | ||||
| """ | ||||
| 
 | ||||
| 
 | ||||
| def queue_prompt(prompt, client_id, base_url): | ||||
|     print("queue_prompt") | ||||
|     p = {"prompt": prompt, "client_id": client_id} | ||||
|     data = json.dumps(p).encode("utf-8") | ||||
|     req = urllib.request.Request(f"{base_url}/prompt", data=data) | ||||
|     return json.loads(urllib.request.urlopen(req).read()) | ||||
| 
 | ||||
| 
 | ||||
| def get_image(filename, subfolder, folder_type, base_url): | ||||
|     print("get_image") | ||||
|     data = {"filename": filename, "subfolder": subfolder, "type": folder_type} | ||||
|     url_values = urllib.parse.urlencode(data) | ||||
|     with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response: | ||||
|         return response.read() | ||||
| 
 | ||||
| 
 | ||||
| def get_image_url(filename, subfolder, folder_type, base_url): | ||||
|     print("get_image") | ||||
|     data = {"filename": filename, "subfolder": subfolder, "type": folder_type} | ||||
|     url_values = urllib.parse.urlencode(data) | ||||
|     return f"{base_url}/view?{url_values}" | ||||
| 
 | ||||
| 
 | ||||
| def get_history(prompt_id, base_url): | ||||
|     print("get_history") | ||||
|     with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response: | ||||
|         return json.loads(response.read()) | ||||
| 
 | ||||
| 
 | ||||
| def get_images(ws, prompt, client_id, base_url): | ||||
|     prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"] | ||||
|     output_images = [] | ||||
|     while True: | ||||
|         out = ws.recv() | ||||
|         if isinstance(out, str): | ||||
|             message = json.loads(out) | ||||
|             if message["type"] == "executing": | ||||
|                 data = message["data"] | ||||
|                 if data["node"] is None and data["prompt_id"] == prompt_id: | ||||
|                     break  # Execution is done | ||||
|         else: | ||||
|             continue  # previews are binary data | ||||
| 
 | ||||
|     history = get_history(prompt_id, base_url)[prompt_id] | ||||
|     for o in history["outputs"]: | ||||
|         for node_id in history["outputs"]: | ||||
|             node_output = history["outputs"][node_id] | ||||
|             if "images" in node_output: | ||||
|                 for image in node_output["images"]: | ||||
|                     url = get_image_url( | ||||
|                         image["filename"], image["subfolder"], image["type"], base_url | ||||
|                     ) | ||||
|                     output_images.append({"url": url}) | ||||
|     return {"data": output_images} | ||||
| 
 | ||||
| 
 | ||||
| class ImageGenerationPayload(BaseModel): | ||||
|     prompt: str | ||||
|     negative_prompt: Optional[str] = "" | ||||
|     steps: Optional[int] = None | ||||
|     seed: Optional[int] = None | ||||
|     width: int | ||||
|     height: int | ||||
|     n: int = 1 | ||||
| 
 | ||||
| 
 | ||||
| def comfyui_generate_image( | ||||
|     model: str, payload: ImageGenerationPayload, client_id, base_url | ||||
| ): | ||||
|     host = base_url.replace("http://", "").replace("https://", "") | ||||
| 
 | ||||
|     comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT) | ||||
| 
 | ||||
|     comfyui_prompt["4"]["inputs"]["ckpt_name"] = model | ||||
|     comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n | ||||
|     comfyui_prompt["5"]["inputs"]["width"] = payload.width | ||||
|     comfyui_prompt["5"]["inputs"]["height"] = payload.height | ||||
| 
 | ||||
|     # set the text prompt for our positive CLIPTextEncode | ||||
|     comfyui_prompt["6"]["inputs"]["text"] = payload.prompt | ||||
|     comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt | ||||
| 
 | ||||
|     if payload.steps: | ||||
|         comfyui_prompt["3"]["inputs"]["steps"] = payload.steps | ||||
| 
 | ||||
|     comfyui_prompt["3"]["inputs"]["seed"] = ( | ||||
|         payload.seed if payload.seed else random.randint(0, 18446744073709551614) | ||||
|     ) | ||||
| 
 | ||||
|     try: | ||||
|         ws = websocket.WebSocket() | ||||
|         ws.connect(f"ws://{host}/ws?clientId={client_id}") | ||||
|         print("WebSocket connection established.") | ||||
|     except Exception as e: | ||||
|         print(f"Failed to connect to WebSocket server: {e}") | ||||
|         return None | ||||
| 
 | ||||
|     try: | ||||
|         images = get_images(ws, comfyui_prompt, client_id, base_url) | ||||
|     except Exception as e: | ||||
|         print(f"Error while receiving images: {e}") | ||||
|         images = None | ||||
| 
 | ||||
|     ws.close() | ||||
| 
 | ||||
|     return images | ||||
|  | @ -1,10 +1,22 @@ | |||
| from fastapi import FastAPI, Request, Response, HTTPException, Depends, status | ||||
| from fastapi import ( | ||||
|     FastAPI, | ||||
|     Request, | ||||
|     Response, | ||||
|     HTTPException, | ||||
|     Depends, | ||||
|     status, | ||||
|     UploadFile, | ||||
|     File, | ||||
|     BackgroundTasks, | ||||
| ) | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| from fastapi.responses import StreamingResponse | ||||
| from fastapi.concurrency import run_in_threadpool | ||||
| 
 | ||||
| from pydantic import BaseModel, ConfigDict | ||||
| 
 | ||||
| import os | ||||
| import copy | ||||
| import random | ||||
| import requests | ||||
| import json | ||||
|  | @ -12,13 +24,17 @@ import uuid | |||
| import aiohttp | ||||
| import asyncio | ||||
| import logging | ||||
| from urllib.parse import urlparse | ||||
| from typing import Optional, List, Union | ||||
| 
 | ||||
| 
 | ||||
| from apps.web.models.users import Users | ||||
| from constants import ERROR_MESSAGES | ||||
| from utils.utils import decode_token, get_current_user, get_admin_user | ||||
| from config import SRC_LOG_LEVELS, OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST | ||||
| 
 | ||||
| from typing import Optional, List, Union | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS, OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, UPLOAD_DIR | ||||
| from utils.misc import calculate_sha256 | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) | ||||
|  | @ -237,11 +253,26 @@ async def pull_model( | |||
|     def get_request(): | ||||
|         nonlocal url | ||||
|         nonlocal r | ||||
| 
 | ||||
|         request_id = str(uuid.uuid4()) | ||||
|         try: | ||||
|             REQUEST_POOL.append(request_id) | ||||
| 
 | ||||
|             def stream_content(): | ||||
|                 for chunk in r.iter_content(chunk_size=8192): | ||||
|                     yield chunk | ||||
|                 try: | ||||
|                     yield json.dumps({"id": request_id, "done": False}) + "\n" | ||||
| 
 | ||||
|                     for chunk in r.iter_content(chunk_size=8192): | ||||
|                         if request_id in REQUEST_POOL: | ||||
|                             yield chunk | ||||
|                         else: | ||||
|                             print("User: canceled request") | ||||
|                             break | ||||
|                 finally: | ||||
|                     if hasattr(r, "close"): | ||||
|                         r.close() | ||||
|                         if request_id in REQUEST_POOL: | ||||
|                             REQUEST_POOL.remove(request_id) | ||||
| 
 | ||||
|             r = requests.request( | ||||
|                 method="POST", | ||||
|  | @ -262,6 +293,7 @@ async def pull_model( | |||
| 
 | ||||
|     try: | ||||
|         return await run_in_threadpool(get_request) | ||||
| 
 | ||||
|     except Exception as e: | ||||
|         log.exception(e) | ||||
|         error_detail = "Open WebUI: Server Connection Error" | ||||
|  | @ -900,6 +932,211 @@ async def generate_openai_chat_completion( | |||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class UrlForm(BaseModel): | ||||
|     url: str | ||||
| 
 | ||||
| 
 | ||||
| class UploadBlobForm(BaseModel): | ||||
|     filename: str | ||||
| 
 | ||||
| 
 | ||||
| def parse_huggingface_url(hf_url): | ||||
|     try: | ||||
|         # Parse the URL | ||||
|         parsed_url = urlparse(hf_url) | ||||
| 
 | ||||
|         # Get the path and split it into components | ||||
|         path_components = parsed_url.path.split("/") | ||||
| 
 | ||||
|         # Extract the desired output | ||||
|         user_repo = "/".join(path_components[1:3]) | ||||
|         model_file = path_components[-1] | ||||
| 
 | ||||
|         return model_file | ||||
|     except ValueError: | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| async def download_file_stream( | ||||
|     ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024 | ||||
| ): | ||||
|     done = False | ||||
| 
 | ||||
|     if os.path.exists(file_path): | ||||
|         current_size = os.path.getsize(file_path) | ||||
|     else: | ||||
|         current_size = 0 | ||||
| 
 | ||||
|     headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {} | ||||
| 
 | ||||
|     timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout | ||||
| 
 | ||||
|     async with aiohttp.ClientSession(timeout=timeout) as session: | ||||
|         async with session.get(file_url, headers=headers) as response: | ||||
|             total_size = int(response.headers.get("content-length", 0)) + current_size | ||||
| 
 | ||||
|             with open(file_path, "ab+") as file: | ||||
|                 async for data in response.content.iter_chunked(chunk_size): | ||||
|                     current_size += len(data) | ||||
|                     file.write(data) | ||||
| 
 | ||||
|                     done = current_size == total_size | ||||
|                     progress = round((current_size / total_size) * 100, 2) | ||||
| 
 | ||||
|                     yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n' | ||||
| 
 | ||||
|                 if done: | ||||
|                     file.seek(0) | ||||
|                     hashed = calculate_sha256(file) | ||||
|                     file.seek(0) | ||||
| 
 | ||||
|                     url = f"{ollama_url}/api/blobs/sha256:{hashed}" | ||||
|                     response = requests.post(url, data=file) | ||||
| 
 | ||||
|                     if response.ok: | ||||
|                         res = { | ||||
|                             "done": done, | ||||
|                             "blob": f"sha256:{hashed}", | ||||
|                             "name": file_name, | ||||
|                         } | ||||
|                         os.remove(file_path) | ||||
| 
 | ||||
|                         yield f"data: {json.dumps(res)}\n\n" | ||||
|                     else: | ||||
|                         raise "Ollama: Could not create blob, Please try again." | ||||
| 
 | ||||
| 
 | ||||
| # def number_generator(): | ||||
| #     for i in range(1, 101): | ||||
| #         yield f"data: {i}\n" | ||||
| 
 | ||||
| 
 | ||||
| # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" | ||||
| @app.post("/models/download") | ||||
| @app.post("/models/download/{url_idx}") | ||||
| async def download_model( | ||||
|     form_data: UrlForm, | ||||
|     url_idx: Optional[int] = None, | ||||
| ): | ||||
| 
 | ||||
|     if url_idx == None: | ||||
|         url_idx = 0 | ||||
|     url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
| 
 | ||||
|     file_name = parse_huggingface_url(form_data.url) | ||||
| 
 | ||||
|     if file_name: | ||||
|         file_path = f"{UPLOAD_DIR}/{file_name}" | ||||
|         return StreamingResponse( | ||||
|             download_file_stream(url, form_data.url, file_path, file_name), | ||||
|         ) | ||||
|     else: | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/models/upload") | ||||
| @app.post("/models/upload/{url_idx}") | ||||
| def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): | ||||
|     if url_idx == None: | ||||
|         url_idx = 0 | ||||
|     ollama_url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
| 
 | ||||
|     file_path = f"{UPLOAD_DIR}/{file.filename}" | ||||
| 
 | ||||
|     # Save file in chunks | ||||
|     with open(file_path, "wb+") as f: | ||||
|         for chunk in file.file: | ||||
|             f.write(chunk) | ||||
| 
 | ||||
|     def file_process_stream(): | ||||
|         nonlocal ollama_url | ||||
|         total_size = os.path.getsize(file_path) | ||||
|         chunk_size = 1024 * 1024 | ||||
|         try: | ||||
|             with open(file_path, "rb") as f: | ||||
|                 total = 0 | ||||
|                 done = False | ||||
| 
 | ||||
|                 while not done: | ||||
|                     chunk = f.read(chunk_size) | ||||
|                     if not chunk: | ||||
|                         done = True | ||||
|                         continue | ||||
| 
 | ||||
|                     total += len(chunk) | ||||
|                     progress = round((total / total_size) * 100, 2) | ||||
| 
 | ||||
|                     res = { | ||||
|                         "progress": progress, | ||||
|                         "total": total_size, | ||||
|                         "completed": total, | ||||
|                     } | ||||
|                     yield f"data: {json.dumps(res)}\n\n" | ||||
| 
 | ||||
|                 if done: | ||||
|                     f.seek(0) | ||||
|                     hashed = calculate_sha256(f) | ||||
|                     f.seek(0) | ||||
| 
 | ||||
|                     url = f"{ollama_url}/api/blobs/sha256:{hashed}" | ||||
|                     response = requests.post(url, data=f) | ||||
| 
 | ||||
|                     if response.ok: | ||||
|                         res = { | ||||
|                             "done": done, | ||||
|                             "blob": f"sha256:{hashed}", | ||||
|                             "name": file.filename, | ||||
|                         } | ||||
|                         os.remove(file_path) | ||||
|                         yield f"data: {json.dumps(res)}\n\n" | ||||
|                     else: | ||||
|                         raise Exception( | ||||
|                             "Ollama: Could not create blob, Please try again." | ||||
|                         ) | ||||
| 
 | ||||
|         except Exception as e: | ||||
|             res = {"error": str(e)} | ||||
|             yield f"data: {json.dumps(res)}\n\n" | ||||
| 
 | ||||
|     return StreamingResponse(file_process_stream(), media_type="text/event-stream") | ||||
| 
 | ||||
| 
 | ||||
| # async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None): | ||||
| #     if url_idx == None: | ||||
| #         url_idx = 0 | ||||
| #     url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
| 
 | ||||
| #     file_location = os.path.join(UPLOAD_DIR, file.filename) | ||||
| #     total_size = file.size | ||||
| 
 | ||||
| #     async def file_upload_generator(file): | ||||
| #         print(file) | ||||
| #         try: | ||||
| #             async with aiofiles.open(file_location, "wb") as f: | ||||
| #                 completed_size = 0 | ||||
| #                 while True: | ||||
| #                     chunk = await file.read(1024*1024) | ||||
| #                     if not chunk: | ||||
| #                         break | ||||
| #                     await f.write(chunk) | ||||
| #                     completed_size += len(chunk) | ||||
| #                     progress = (completed_size / total_size) * 100 | ||||
| 
 | ||||
| #                     print(progress) | ||||
| #                     yield f'data: {json.dumps({"status": "uploading", "percentage": progress, "total": total_size, "completed": completed_size, "done": False})}\n' | ||||
| #         except Exception as e: | ||||
| #             print(e) | ||||
| #             yield f"data: {json.dumps({'status': 'error', 'message': str(e)})}\n" | ||||
| #         finally: | ||||
| #             await file.close() | ||||
| #             print("done") | ||||
| #             yield f'data: {json.dumps({"status": "completed", "percentage": 100, "total": total_size, "completed": completed_size, "done": True})}\n' | ||||
| 
 | ||||
| #     return StreamingResponse( | ||||
| #         file_upload_generator(copy.deepcopy(file)), media_type="text/event-stream" | ||||
| #     ) | ||||
| 
 | ||||
| 
 | ||||
| @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | ||||
| async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)): | ||||
|     url = app.state.OLLAMA_BASE_URLS[0] | ||||
|  |  | |||
|  | @ -114,40 +114,6 @@ class CollectionNameForm(BaseModel): | |||
| class StoreWebForm(CollectionNameForm): | ||||
|     url: str | ||||
| 
 | ||||
| 
 | ||||
| def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: | ||||
|     text_splitter = RecursiveCharacterTextSplitter( | ||||
|         chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP | ||||
|     ) | ||||
|     docs = text_splitter.split_documents(data) | ||||
| 
 | ||||
|     texts = [doc.page_content for doc in docs] | ||||
|     metadatas = [doc.metadata for doc in docs] | ||||
| 
 | ||||
|     try: | ||||
|         if overwrite: | ||||
|             for collection in CHROMA_CLIENT.list_collections(): | ||||
|                 if collection_name == collection.name: | ||||
|                     log.info(f"deleting existing collection {collection_name}") | ||||
|                     CHROMA_CLIENT.delete_collection(name=collection_name) | ||||
| 
 | ||||
|         collection = CHROMA_CLIENT.create_collection( | ||||
|             name=collection_name, | ||||
|             embedding_function=app.state.sentence_transformer_ef, | ||||
|         ) | ||||
| 
 | ||||
|         collection.add( | ||||
|             documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] | ||||
|         ) | ||||
|         return True | ||||
|     except Exception as e: | ||||
|         log.exception(e) | ||||
|         if e.__class__.__name__ == "UniqueConstraintError": | ||||
|             return True | ||||
| 
 | ||||
|         return False | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/") | ||||
| async def get_status(): | ||||
|     return { | ||||
|  | @ -329,6 +295,56 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): | |||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: | ||||
|     text_splitter = RecursiveCharacterTextSplitter( | ||||
|         chunk_size=app.state.CHUNK_SIZE, | ||||
|         chunk_overlap=app.state.CHUNK_OVERLAP, | ||||
|         add_start_index=True, | ||||
|     ) | ||||
|     docs = text_splitter.split_documents(data) | ||||
|     return store_docs_in_vector_db(docs, collection_name, overwrite) | ||||
| 
 | ||||
| 
 | ||||
| def store_text_in_vector_db( | ||||
|     text, metadata, collection_name, overwrite: bool = False | ||||
| ) -> bool: | ||||
|     text_splitter = RecursiveCharacterTextSplitter( | ||||
|         chunk_size=app.state.CHUNK_SIZE, | ||||
|         chunk_overlap=app.state.CHUNK_OVERLAP, | ||||
|         add_start_index=True, | ||||
|     ) | ||||
|     docs = text_splitter.create_documents([text], metadatas=[metadata]) | ||||
|     return store_docs_in_vector_db(docs, collection_name, overwrite) | ||||
| 
 | ||||
| 
 | ||||
| def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool: | ||||
|     texts = [doc.page_content for doc in docs] | ||||
|     metadatas = [doc.metadata for doc in docs] | ||||
| 
 | ||||
|     try: | ||||
|         if overwrite: | ||||
|             for collection in CHROMA_CLIENT.list_collections(): | ||||
|                 if collection_name == collection.name: | ||||
|                     print(f"deleting existing collection {collection_name}") | ||||
|                     CHROMA_CLIENT.delete_collection(name=collection_name) | ||||
| 
 | ||||
|         collection = CHROMA_CLIENT.create_collection( | ||||
|             name=collection_name, | ||||
|             embedding_function=app.state.sentence_transformer_ef, | ||||
|         ) | ||||
| 
 | ||||
|         collection.add( | ||||
|             documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] | ||||
|         ) | ||||
|         return True | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         if e.__class__.__name__ == "UniqueConstraintError": | ||||
|             return True | ||||
| 
 | ||||
|         return False | ||||
| 
 | ||||
| 
 | ||||
| def get_loader(filename: str, file_content_type: str, file_path: str): | ||||
|     file_ext = filename.split(".")[-1].lower() | ||||
|     known_type = True | ||||
|  | @ -464,6 +480,37 @@ def store_doc( | |||
|             ) | ||||
| 
 | ||||
| 
 | ||||
| class TextRAGForm(BaseModel): | ||||
|     name: str | ||||
|     content: str | ||||
|     collection_name: Optional[str] = None | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/text") | ||||
| def store_text( | ||||
|     form_data: TextRAGForm, | ||||
|     user=Depends(get_current_user), | ||||
| ): | ||||
| 
 | ||||
|     collection_name = form_data.collection_name | ||||
|     if collection_name == None: | ||||
|         collection_name = calculate_sha256_string(form_data.content) | ||||
| 
 | ||||
|     result = store_text_in_vector_db( | ||||
|         form_data.content, | ||||
|         metadata={"name": form_data.name, "created_by": user.id}, | ||||
|         collection_name=collection_name, | ||||
|     ) | ||||
| 
 | ||||
|     if result: | ||||
|         return {"status": True, "collection_name": collection_name} | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||||
|             detail=ERROR_MESSAGES.DEFAULT(), | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/scan") | ||||
| def scan_docs_dir(user=Depends(get_admin_user)): | ||||
|     for path in Path(DOCS_DIR).rglob("./**/*"): | ||||
|  |  | |||
|  | @ -141,6 +141,8 @@ def rag_messages(docs, messages, template, k, embedding_function): | |||
|                     k=k, | ||||
|                     embedding_function=embedding_function, | ||||
|                 ) | ||||
|             elif doc["type"] == "text": | ||||
|                 context = doc["content"] | ||||
|             else: | ||||
|                 context = query_doc( | ||||
|                     collection_name=doc["collection_name"], | ||||
|  |  | |||
|  | @ -95,20 +95,6 @@ class ChatTable: | |||
|         except: | ||||
|             return None | ||||
| 
 | ||||
|     def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: | ||||
|         try: | ||||
|             query = Chat.update( | ||||
|                 chat=json.dumps(chat), | ||||
|                 title=chat["title"] if "title" in chat else "New Chat", | ||||
|                 timestamp=int(time.time()), | ||||
|             ).where(Chat.id == id) | ||||
|             query.execute() | ||||
| 
 | ||||
|             chat = Chat.get(Chat.id == id) | ||||
|             return ChatModel(**model_to_dict(chat)) | ||||
|         except: | ||||
|             return None | ||||
| 
 | ||||
|     def get_chat_lists_by_user_id( | ||||
|         self, user_id: str, skip: int = 0, limit: int = 50 | ||||
|     ) -> List[ChatModel]: | ||||
|  |  | |||
|  | @ -21,155 +21,6 @@ from constants import ERROR_MESSAGES | |||
| router = APIRouter() | ||||
| 
 | ||||
| 
 | ||||
| class UploadBlobForm(BaseModel): | ||||
|     filename: str | ||||
| 
 | ||||
| 
 | ||||
| from urllib.parse import urlparse | ||||
| 
 | ||||
| 
 | ||||
| def parse_huggingface_url(hf_url): | ||||
|     try: | ||||
|         # Parse the URL | ||||
|         parsed_url = urlparse(hf_url) | ||||
| 
 | ||||
|         # Get the path and split it into components | ||||
|         path_components = parsed_url.path.split("/") | ||||
| 
 | ||||
|         # Extract the desired output | ||||
|         user_repo = "/".join(path_components[1:3]) | ||||
|         model_file = path_components[-1] | ||||
| 
 | ||||
|         return model_file | ||||
|     except ValueError: | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024): | ||||
|     done = False | ||||
| 
 | ||||
|     if os.path.exists(file_path): | ||||
|         current_size = os.path.getsize(file_path) | ||||
|     else: | ||||
|         current_size = 0 | ||||
| 
 | ||||
|     headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {} | ||||
| 
 | ||||
|     timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout | ||||
| 
 | ||||
|     async with aiohttp.ClientSession(timeout=timeout) as session: | ||||
|         async with session.get(url, headers=headers) as response: | ||||
|             total_size = int(response.headers.get("content-length", 0)) + current_size | ||||
| 
 | ||||
|             with open(file_path, "ab+") as file: | ||||
|                 async for data in response.content.iter_chunked(chunk_size): | ||||
|                     current_size += len(data) | ||||
|                     file.write(data) | ||||
| 
 | ||||
|                     done = current_size == total_size | ||||
|                     progress = round((current_size / total_size) * 100, 2) | ||||
|                     yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n' | ||||
| 
 | ||||
|                 if done: | ||||
|                     file.seek(0) | ||||
|                     hashed = calculate_sha256(file) | ||||
|                     file.seek(0) | ||||
| 
 | ||||
|                     url = f"{OLLAMA_BASE_URLS[0]}/api/blobs/sha256:{hashed}" | ||||
|                     response = requests.post(url, data=file) | ||||
| 
 | ||||
|                     if response.ok: | ||||
|                         res = { | ||||
|                             "done": done, | ||||
|                             "blob": f"sha256:{hashed}", | ||||
|                             "name": file_name, | ||||
|                         } | ||||
|                         os.remove(file_path) | ||||
| 
 | ||||
|                         yield f"data: {json.dumps(res)}\n\n" | ||||
|                     else: | ||||
|                         raise "Ollama: Could not create blob, Please try again." | ||||
| 
 | ||||
| 
 | ||||
| @router.get("/download") | ||||
| async def download( | ||||
|     url: str, | ||||
| ): | ||||
|     # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" | ||||
|     file_name = parse_huggingface_url(url) | ||||
| 
 | ||||
|     if file_name: | ||||
|         file_path = f"{UPLOAD_DIR}/{file_name}" | ||||
| 
 | ||||
|         return StreamingResponse( | ||||
|             download_file_stream(url, file_path, file_name), | ||||
|             media_type="text/event-stream", | ||||
|         ) | ||||
|     else: | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| @router.post("/upload") | ||||
| def upload(file: UploadFile = File(...)): | ||||
|     file_path = f"{UPLOAD_DIR}/{file.filename}" | ||||
| 
 | ||||
|     # Save file in chunks | ||||
|     with open(file_path, "wb+") as f: | ||||
|         for chunk in file.file: | ||||
|             f.write(chunk) | ||||
| 
 | ||||
|     def file_process_stream(): | ||||
|         total_size = os.path.getsize(file_path) | ||||
|         chunk_size = 1024 * 1024 | ||||
|         try: | ||||
|             with open(file_path, "rb") as f: | ||||
|                 total = 0 | ||||
|                 done = False | ||||
| 
 | ||||
|                 while not done: | ||||
|                     chunk = f.read(chunk_size) | ||||
|                     if not chunk: | ||||
|                         done = True | ||||
|                         continue | ||||
| 
 | ||||
|                     total += len(chunk) | ||||
|                     progress = round((total / total_size) * 100, 2) | ||||
| 
 | ||||
|                     res = { | ||||
|                         "progress": progress, | ||||
|                         "total": total_size, | ||||
|                         "completed": total, | ||||
|                     } | ||||
|                     yield f"data: {json.dumps(res)}\n\n" | ||||
| 
 | ||||
|                 if done: | ||||
|                     f.seek(0) | ||||
|                     hashed = calculate_sha256(f) | ||||
|                     f.seek(0) | ||||
| 
 | ||||
|                     url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}" | ||||
|                     response = requests.post(url, data=f) | ||||
| 
 | ||||
|                     if response.ok: | ||||
|                         res = { | ||||
|                             "done": done, | ||||
|                             "blob": f"sha256:{hashed}", | ||||
|                             "name": file.filename, | ||||
|                         } | ||||
|                         os.remove(file_path) | ||||
|                         yield f"data: {json.dumps(res)}\n\n" | ||||
|                     else: | ||||
|                         raise Exception( | ||||
|                             "Ollama: Could not create blob, Please try again." | ||||
|                         ) | ||||
| 
 | ||||
|         except Exception as e: | ||||
|             res = {"error": str(e)} | ||||
|             yield f"data: {json.dumps(res)}\n\n" | ||||
| 
 | ||||
|     return StreamingResponse(file_process_stream(), media_type="text/event-stream") | ||||
| 
 | ||||
| 
 | ||||
| @router.get("/gravatar") | ||||
| async def get_gravatar( | ||||
|     email: str, | ||||
|  |  | |||
|  | @ -406,3 +406,4 @@ WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models" | |||
| #################################### | ||||
| 
 | ||||
| AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "") | ||||
| COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "") | ||||
|  |  | |||
|  | @ -45,3 +45,4 @@ PyJWT | |||
| pyjwt[crypto] | ||||
| 
 | ||||
| black | ||||
| langfuse | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy Jaeryang Baek
						Timothy Jaeryang Baek