forked from open-webui/open-webui
		
	Merge pull request 'main' (!3) from open-webui/open-webui:main into main
	
		
			
	
		
	
	
		
	
		
			Some checks failed
		
		
	
	
	
				
					
				
			
		
			Some checks failed
		
		
	
	
Reviewed-on: #3
This commit is contained in:
		
						commit
						921a1cf978
					
				
					 74 changed files with 7314 additions and 5536 deletions
				
			
		|  | @ -1,4 +1,5 @@ | |||
| import os | ||||
| import logging | ||||
| from fastapi import ( | ||||
|     FastAPI, | ||||
|     Request, | ||||
|  | @ -21,7 +22,10 @@ from utils.utils import ( | |||
| ) | ||||
| from utils.misc import calculate_sha256 | ||||
| 
 | ||||
| from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR | ||||
| from config import SRC_LOG_LEVELS, CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["AUDIO"]) | ||||
| 
 | ||||
| app = FastAPI() | ||||
| app.add_middleware( | ||||
|  | @ -38,7 +42,7 @@ def transcribe( | |||
|     file: UploadFile = File(...), | ||||
|     user=Depends(get_current_user), | ||||
| ): | ||||
|     print(file.content_type) | ||||
|     log.info(f"file.content_type: {file.content_type}") | ||||
| 
 | ||||
|     if file.content_type not in ["audio/mpeg", "audio/wav"]: | ||||
|         raise HTTPException( | ||||
|  | @ -62,7 +66,7 @@ def transcribe( | |||
|         ) | ||||
| 
 | ||||
|         segments, info = model.transcribe(file_path, beam_size=5) | ||||
|         print( | ||||
|         log.info( | ||||
|             "Detected language '%s' with probability %f" | ||||
|             % (info.language, info.language_probability) | ||||
|         ) | ||||
|  | @ -72,7 +76,7 @@ def transcribe( | |||
|         return {"text": transcript.strip()} | ||||
| 
 | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
| 
 | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_400_BAD_REQUEST, | ||||
|  |  | |||
|  | @ -18,6 +18,8 @@ from utils.utils import ( | |||
|     get_current_user, | ||||
|     get_admin_user, | ||||
| ) | ||||
| 
 | ||||
| from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image | ||||
| from utils.misc import calculate_sha256 | ||||
| from typing import Optional | ||||
| from pydantic import BaseModel | ||||
|  | @ -25,10 +27,14 @@ from pathlib import Path | |||
| import uuid | ||||
| import base64 | ||||
| import json | ||||
| import logging | ||||
| 
 | ||||
| from config import CACHE_DIR, AUTOMATIC1111_BASE_URL | ||||
| from config import SRC_LOG_LEVELS, CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL | ||||
| 
 | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["IMAGES"]) | ||||
| 
 | ||||
| IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/") | ||||
| IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) | ||||
| 
 | ||||
|  | @ -49,6 +55,8 @@ app.state.MODEL = "" | |||
| 
 | ||||
| 
 | ||||
| app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL | ||||
| app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL | ||||
| 
 | ||||
| 
 | ||||
| app.state.IMAGE_SIZE = "512x512" | ||||
| app.state.IMAGE_STEPS = 50 | ||||
|  | @ -71,32 +79,48 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user | |||
|     return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} | ||||
| 
 | ||||
| 
 | ||||
| class UrlUpdateForm(BaseModel): | ||||
|     url: str | ||||
| class EngineUrlUpdateForm(BaseModel): | ||||
|     AUTOMATIC1111_BASE_URL: Optional[str] = None | ||||
|     COMFYUI_BASE_URL: Optional[str] = None | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/url") | ||||
| async def get_automatic1111_url(user=Depends(get_admin_user)): | ||||
|     return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL} | ||||
| async def get_engine_url(user=Depends(get_admin_user)): | ||||
|     return { | ||||
|         "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, | ||||
|         "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/url/update") | ||||
| async def update_automatic1111_url( | ||||
|     form_data: UrlUpdateForm, user=Depends(get_admin_user) | ||||
| async def update_engine_url( | ||||
|     form_data: EngineUrlUpdateForm, user=Depends(get_admin_user) | ||||
| ): | ||||
| 
 | ||||
|     if form_data.url == "": | ||||
|     if form_data.AUTOMATIC1111_BASE_URL == None: | ||||
|         app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL | ||||
|     else: | ||||
|         url = form_data.url.strip("/") | ||||
|         url = form_data.AUTOMATIC1111_BASE_URL.strip("/") | ||||
|         try: | ||||
|             r = requests.head(url) | ||||
|             app.state.AUTOMATIC1111_BASE_URL = url | ||||
|         except Exception as e: | ||||
|             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) | ||||
| 
 | ||||
|     if form_data.COMFYUI_BASE_URL == None: | ||||
|         app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL | ||||
|     else: | ||||
|         url = form_data.COMFYUI_BASE_URL.strip("/") | ||||
| 
 | ||||
|         try: | ||||
|             r = requests.head(url) | ||||
|             app.state.COMFYUI_BASE_URL = url | ||||
|         except Exception as e: | ||||
|             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) | ||||
| 
 | ||||
|     return { | ||||
|         "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, | ||||
|         "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, | ||||
|         "status": True, | ||||
|     } | ||||
| 
 | ||||
|  | @ -186,6 +210,18 @@ def get_models(user=Depends(get_current_user)): | |||
|                 {"id": "dall-e-2", "name": "DALL·E 2"}, | ||||
|                 {"id": "dall-e-3", "name": "DALL·E 3"}, | ||||
|             ] | ||||
|         elif app.state.ENGINE == "comfyui": | ||||
| 
 | ||||
|             r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info") | ||||
|             info = r.json() | ||||
| 
 | ||||
|             return list( | ||||
|                 map( | ||||
|                     lambda model: {"id": model, "name": model}, | ||||
|                     info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0], | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|         else: | ||||
|             r = requests.get( | ||||
|                 url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" | ||||
|  | @ -207,6 +243,8 @@ async def get_default_model(user=Depends(get_admin_user)): | |||
|     try: | ||||
|         if app.state.ENGINE == "openai": | ||||
|             return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} | ||||
|         elif app.state.ENGINE == "comfyui": | ||||
|             return {"model": app.state.MODEL if app.state.MODEL else ""} | ||||
|         else: | ||||
|             r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") | ||||
|             options = r.json() | ||||
|  | @ -221,10 +259,12 @@ class UpdateModelForm(BaseModel): | |||
| 
 | ||||
| 
 | ||||
| def set_model_handler(model: str): | ||||
| 
 | ||||
|     if app.state.ENGINE == "openai": | ||||
|         app.state.MODEL = model | ||||
|         return app.state.MODEL | ||||
|     if app.state.ENGINE == "comfyui": | ||||
|         app.state.MODEL = model | ||||
|         return app.state.MODEL | ||||
|     else: | ||||
|         r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") | ||||
|         options = r.json() | ||||
|  | @ -266,6 +306,23 @@ def save_b64_image(b64_str): | |||
|         with open(file_path, "wb") as f: | ||||
|             f.write(img_data) | ||||
| 
 | ||||
|         return image_id | ||||
|     except Exception as e: | ||||
|         log.error(f"Error saving image: {e}") | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| def save_url_image(url): | ||||
|     image_id = str(uuid.uuid4()) | ||||
|     file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png") | ||||
| 
 | ||||
|     try: | ||||
|         r = requests.get(url) | ||||
|         r.raise_for_status() | ||||
| 
 | ||||
|         with open(file_path, "wb") as image_file: | ||||
|             image_file.write(r.content) | ||||
| 
 | ||||
|         return image_id | ||||
|     except Exception as e: | ||||
|         print(f"Error saving image: {e}") | ||||
|  | @ -278,6 +335,8 @@ def generate_image( | |||
|     user=Depends(get_current_user), | ||||
| ): | ||||
| 
 | ||||
|     width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) | ||||
| 
 | ||||
|     r = None | ||||
|     try: | ||||
|         if app.state.ENGINE == "openai": | ||||
|  | @ -315,12 +374,47 @@ def generate_image( | |||
| 
 | ||||
|             return images | ||||
| 
 | ||||
|         elif app.state.ENGINE == "comfyui": | ||||
| 
 | ||||
|             data = { | ||||
|                 "prompt": form_data.prompt, | ||||
|                 "width": width, | ||||
|                 "height": height, | ||||
|                 "n": form_data.n, | ||||
|             } | ||||
| 
 | ||||
|             if app.state.IMAGE_STEPS != None: | ||||
|                 data["steps"] = app.state.IMAGE_STEPS | ||||
| 
 | ||||
|             if form_data.negative_prompt != None: | ||||
|                 data["negative_prompt"] = form_data.negative_prompt | ||||
| 
 | ||||
|             data = ImageGenerationPayload(**data) | ||||
| 
 | ||||
|             res = comfyui_generate_image( | ||||
|                 app.state.MODEL, | ||||
|                 data, | ||||
|                 user.id, | ||||
|                 app.state.COMFYUI_BASE_URL, | ||||
|             ) | ||||
|             print(res) | ||||
| 
 | ||||
|             images = [] | ||||
| 
 | ||||
|             for image in res["data"]: | ||||
|                 image_id = save_url_image(image["url"]) | ||||
|                 images.append({"url": f"/cache/image/generations/{image_id}.png"}) | ||||
|                 file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") | ||||
| 
 | ||||
|                 with open(file_body_path, "w") as f: | ||||
|                     json.dump(data.model_dump(exclude_none=True), f) | ||||
| 
 | ||||
|             print(images) | ||||
|             return images | ||||
|         else: | ||||
|             if form_data.model: | ||||
|                 set_model_handler(form_data.model) | ||||
| 
 | ||||
|             width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) | ||||
| 
 | ||||
|             data = { | ||||
|                 "prompt": form_data.prompt, | ||||
|                 "batch_size": form_data.n, | ||||
|  | @ -341,7 +435,7 @@ def generate_image( | |||
| 
 | ||||
|             res = r.json() | ||||
| 
 | ||||
|             print(res) | ||||
|             log.debug(f"res: {res}") | ||||
| 
 | ||||
|             images = [] | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										228
									
								
								backend/apps/images/utils/comfyui.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										228
									
								
								backend/apps/images/utils/comfyui.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,228 @@ | |||
| import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) | ||||
| import uuid | ||||
| import json | ||||
| import urllib.request | ||||
| import urllib.parse | ||||
| import random | ||||
| 
 | ||||
| from pydantic import BaseModel | ||||
| 
 | ||||
| from typing import Optional | ||||
| 
 | ||||
| COMFYUI_DEFAULT_PROMPT = """ | ||||
| { | ||||
|   "3": { | ||||
|     "inputs": { | ||||
|       "seed": 0, | ||||
|       "steps": 20, | ||||
|       "cfg": 8, | ||||
|       "sampler_name": "euler", | ||||
|       "scheduler": "normal", | ||||
|       "denoise": 1, | ||||
|       "model": [ | ||||
|         "4", | ||||
|         0 | ||||
|       ], | ||||
|       "positive": [ | ||||
|         "6", | ||||
|         0 | ||||
|       ], | ||||
|       "negative": [ | ||||
|         "7", | ||||
|         0 | ||||
|       ], | ||||
|       "latent_image": [ | ||||
|         "5", | ||||
|         0 | ||||
|       ] | ||||
|     }, | ||||
|     "class_type": "KSampler", | ||||
|     "_meta": { | ||||
|       "title": "KSampler" | ||||
|     } | ||||
|   }, | ||||
|   "4": { | ||||
|     "inputs": { | ||||
|       "ckpt_name": "model.safetensors" | ||||
|     }, | ||||
|     "class_type": "CheckpointLoaderSimple", | ||||
|     "_meta": { | ||||
|       "title": "Load Checkpoint" | ||||
|     } | ||||
|   }, | ||||
|   "5": { | ||||
|     "inputs": { | ||||
|       "width": 512, | ||||
|       "height": 512, | ||||
|       "batch_size": 1 | ||||
|     }, | ||||
|     "class_type": "EmptyLatentImage", | ||||
|     "_meta": { | ||||
|       "title": "Empty Latent Image" | ||||
|     } | ||||
|   }, | ||||
|   "6": { | ||||
|     "inputs": { | ||||
|       "text": "Prompt", | ||||
|       "clip": [ | ||||
|         "4", | ||||
|         1 | ||||
|       ] | ||||
|     }, | ||||
|     "class_type": "CLIPTextEncode", | ||||
|     "_meta": { | ||||
|       "title": "CLIP Text Encode (Prompt)" | ||||
|     } | ||||
|   }, | ||||
|   "7": { | ||||
|     "inputs": { | ||||
|       "text": "Negative Prompt", | ||||
|       "clip": [ | ||||
|         "4", | ||||
|         1 | ||||
|       ] | ||||
|     }, | ||||
|     "class_type": "CLIPTextEncode", | ||||
|     "_meta": { | ||||
|       "title": "CLIP Text Encode (Prompt)" | ||||
|     } | ||||
|   }, | ||||
|   "8": { | ||||
|     "inputs": { | ||||
|       "samples": [ | ||||
|         "3", | ||||
|         0 | ||||
|       ], | ||||
|       "vae": [ | ||||
|         "4", | ||||
|         2 | ||||
|       ] | ||||
|     }, | ||||
|     "class_type": "VAEDecode", | ||||
|     "_meta": { | ||||
|       "title": "VAE Decode" | ||||
|     } | ||||
|   }, | ||||
|   "9": { | ||||
|     "inputs": { | ||||
|       "filename_prefix": "ComfyUI", | ||||
|       "images": [ | ||||
|         "8", | ||||
|         0 | ||||
|       ] | ||||
|     }, | ||||
|     "class_type": "SaveImage", | ||||
|     "_meta": { | ||||
|       "title": "Save Image" | ||||
|     } | ||||
|   } | ||||
| } | ||||
| """ | ||||
| 
 | ||||
| 
 | ||||
| def queue_prompt(prompt, client_id, base_url): | ||||
|     print("queue_prompt") | ||||
|     p = {"prompt": prompt, "client_id": client_id} | ||||
|     data = json.dumps(p).encode("utf-8") | ||||
|     req = urllib.request.Request(f"{base_url}/prompt", data=data) | ||||
|     return json.loads(urllib.request.urlopen(req).read()) | ||||
| 
 | ||||
| 
 | ||||
| def get_image(filename, subfolder, folder_type, base_url): | ||||
|     print("get_image") | ||||
|     data = {"filename": filename, "subfolder": subfolder, "type": folder_type} | ||||
|     url_values = urllib.parse.urlencode(data) | ||||
|     with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response: | ||||
|         return response.read() | ||||
| 
 | ||||
| 
 | ||||
| def get_image_url(filename, subfolder, folder_type, base_url): | ||||
|     print("get_image") | ||||
|     data = {"filename": filename, "subfolder": subfolder, "type": folder_type} | ||||
|     url_values = urllib.parse.urlencode(data) | ||||
|     return f"{base_url}/view?{url_values}" | ||||
| 
 | ||||
| 
 | ||||
| def get_history(prompt_id, base_url): | ||||
|     print("get_history") | ||||
|     with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response: | ||||
|         return json.loads(response.read()) | ||||
| 
 | ||||
| 
 | ||||
| def get_images(ws, prompt, client_id, base_url): | ||||
|     prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"] | ||||
|     output_images = [] | ||||
|     while True: | ||||
|         out = ws.recv() | ||||
|         if isinstance(out, str): | ||||
|             message = json.loads(out) | ||||
|             if message["type"] == "executing": | ||||
|                 data = message["data"] | ||||
|                 if data["node"] is None and data["prompt_id"] == prompt_id: | ||||
|                     break  # Execution is done | ||||
|         else: | ||||
|             continue  # previews are binary data | ||||
| 
 | ||||
|     history = get_history(prompt_id, base_url)[prompt_id] | ||||
|     for o in history["outputs"]: | ||||
|         for node_id in history["outputs"]: | ||||
|             node_output = history["outputs"][node_id] | ||||
|             if "images" in node_output: | ||||
|                 for image in node_output["images"]: | ||||
|                     url = get_image_url( | ||||
|                         image["filename"], image["subfolder"], image["type"], base_url | ||||
|                     ) | ||||
|                     output_images.append({"url": url}) | ||||
|     return {"data": output_images} | ||||
| 
 | ||||
| 
 | ||||
| class ImageGenerationPayload(BaseModel): | ||||
|     prompt: str | ||||
|     negative_prompt: Optional[str] = "" | ||||
|     steps: Optional[int] = None | ||||
|     seed: Optional[int] = None | ||||
|     width: int | ||||
|     height: int | ||||
|     n: int = 1 | ||||
| 
 | ||||
| 
 | ||||
| def comfyui_generate_image( | ||||
|     model: str, payload: ImageGenerationPayload, client_id, base_url | ||||
| ): | ||||
|     host = base_url.replace("http://", "").replace("https://", "") | ||||
| 
 | ||||
|     comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT) | ||||
| 
 | ||||
|     comfyui_prompt["4"]["inputs"]["ckpt_name"] = model | ||||
|     comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n | ||||
|     comfyui_prompt["5"]["inputs"]["width"] = payload.width | ||||
|     comfyui_prompt["5"]["inputs"]["height"] = payload.height | ||||
| 
 | ||||
|     # set the text prompt for our positive CLIPTextEncode | ||||
|     comfyui_prompt["6"]["inputs"]["text"] = payload.prompt | ||||
|     comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt | ||||
| 
 | ||||
|     if payload.steps: | ||||
|         comfyui_prompt["3"]["inputs"]["steps"] = payload.steps | ||||
| 
 | ||||
|     comfyui_prompt["3"]["inputs"]["seed"] = ( | ||||
|         payload.seed if payload.seed else random.randint(0, 18446744073709551614) | ||||
|     ) | ||||
| 
 | ||||
|     try: | ||||
|         ws = websocket.WebSocket() | ||||
|         ws.connect(f"ws://{host}/ws?clientId={client_id}") | ||||
|         print("WebSocket connection established.") | ||||
|     except Exception as e: | ||||
|         print(f"Failed to connect to WebSocket server: {e}") | ||||
|         return None | ||||
| 
 | ||||
|     try: | ||||
|         images = get_images(ws, comfyui_prompt, client_id, base_url) | ||||
|     except Exception as e: | ||||
|         print(f"Error while receiving images: {e}") | ||||
|         images = None | ||||
| 
 | ||||
|     ws.close() | ||||
| 
 | ||||
|     return images | ||||
|  | @ -1,10 +1,27 @@ | |||
| import logging | ||||
| 
 | ||||
| from litellm.proxy.proxy_server import ProxyConfig, initialize | ||||
| from litellm.proxy.proxy_server import app | ||||
| 
 | ||||
| from fastapi import FastAPI, Request, Depends, status | ||||
| from fastapi import FastAPI, Request, Depends, status, Response | ||||
| from fastapi.responses import JSONResponse | ||||
| 
 | ||||
| from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint | ||||
| from starlette.responses import StreamingResponse | ||||
| import json | ||||
| 
 | ||||
| from utils.utils import get_http_authorization_cred, get_current_user | ||||
| from config import ENV | ||||
| from config import SRC_LOG_LEVELS, ENV | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["LITELLM"]) | ||||
| 
 | ||||
| 
 | ||||
| from config import ( | ||||
|     MODEL_FILTER_ENABLED, | ||||
|     MODEL_FILTER_LIST, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| proxy_config = ProxyConfig() | ||||
| 
 | ||||
|  | @ -26,16 +43,58 @@ async def on_startup(): | |||
|     await startup() | ||||
| 
 | ||||
| 
 | ||||
| app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED | ||||
| app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST | ||||
| 
 | ||||
| 
 | ||||
| @app.middleware("http") | ||||
| async def auth_middleware(request: Request, call_next): | ||||
|     auth_header = request.headers.get("Authorization", "") | ||||
|     request.state.user = None | ||||
| 
 | ||||
|     if ENV != "dev": | ||||
|         try: | ||||
|             user = get_current_user(get_http_authorization_cred(auth_header)) | ||||
|             print(user) | ||||
|         except Exception as e: | ||||
|             return JSONResponse(status_code=400, content={"detail": str(e)}) | ||||
|     try: | ||||
|         user = get_current_user(get_http_authorization_cred(auth_header)) | ||||
|         log.debug(f"user: {user}") | ||||
|         request.state.user = user | ||||
|     except Exception as e: | ||||
|         return JSONResponse(status_code=400, content={"detail": str(e)}) | ||||
| 
 | ||||
|     response = await call_next(request) | ||||
|     return response | ||||
| 
 | ||||
| 
 | ||||
| class ModifyModelsResponseMiddleware(BaseHTTPMiddleware): | ||||
|     async def dispatch( | ||||
|         self, request: Request, call_next: RequestResponseEndpoint | ||||
|     ) -> Response: | ||||
| 
 | ||||
|         response = await call_next(request) | ||||
|         user = request.state.user | ||||
| 
 | ||||
|         if "/models" in request.url.path: | ||||
|             if isinstance(response, StreamingResponse): | ||||
|                 # Read the content of the streaming response | ||||
|                 body = b"" | ||||
|                 async for chunk in response.body_iterator: | ||||
|                     body += chunk | ||||
| 
 | ||||
|                 data = json.loads(body.decode("utf-8")) | ||||
| 
 | ||||
|                 if app.state.MODEL_FILTER_ENABLED: | ||||
|                     if user and user.role == "user": | ||||
|                         data["data"] = list( | ||||
|                             filter( | ||||
|                                 lambda model: model["id"] | ||||
|                                 in app.state.MODEL_FILTER_LIST, | ||||
|                                 data["data"], | ||||
|                             ) | ||||
|                         ) | ||||
| 
 | ||||
|                 # Modified Flag | ||||
|                 data["modified"] = True | ||||
|                 return JSONResponse(content=data) | ||||
| 
 | ||||
|         return response | ||||
| 
 | ||||
| 
 | ||||
| app.add_middleware(ModifyModelsResponseMiddleware) | ||||
|  |  | |||
|  | @ -1,24 +1,43 @@ | |||
| from fastapi import FastAPI, Request, Response, HTTPException, Depends, status | ||||
| from fastapi import ( | ||||
|     FastAPI, | ||||
|     Request, | ||||
|     Response, | ||||
|     HTTPException, | ||||
|     Depends, | ||||
|     status, | ||||
|     UploadFile, | ||||
|     File, | ||||
|     BackgroundTasks, | ||||
| ) | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| from fastapi.responses import StreamingResponse | ||||
| from fastapi.concurrency import run_in_threadpool | ||||
| 
 | ||||
| from pydantic import BaseModel, ConfigDict | ||||
| 
 | ||||
| import os | ||||
| import copy | ||||
| import random | ||||
| import requests | ||||
| import json | ||||
| import uuid | ||||
| import aiohttp | ||||
| import asyncio | ||||
| import logging | ||||
| from urllib.parse import urlparse | ||||
| from typing import Optional, List, Union | ||||
| 
 | ||||
| 
 | ||||
| from apps.web.models.users import Users | ||||
| from constants import ERROR_MESSAGES | ||||
| from utils.utils import decode_token, get_current_user, get_admin_user | ||||
| from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST | ||||
| 
 | ||||
| from typing import Optional, List, Union | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS, OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, UPLOAD_DIR | ||||
| from utils.misc import calculate_sha256 | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) | ||||
| 
 | ||||
| app = FastAPI() | ||||
| app.add_middleware( | ||||
|  | @ -69,7 +88,7 @@ class UrlUpdateForm(BaseModel): | |||
| async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): | ||||
|     app.state.OLLAMA_BASE_URLS = form_data.urls | ||||
| 
 | ||||
|     print(app.state.OLLAMA_BASE_URLS) | ||||
|     log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}") | ||||
|     return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} | ||||
| 
 | ||||
| 
 | ||||
|  | @ -90,7 +109,7 @@ async def fetch_url(url): | |||
|                 return await response.json() | ||||
|     except Exception as e: | ||||
|         # Handle connection error here | ||||
|         print(f"Connection error: {e}") | ||||
|         log.error(f"Connection error: {e}") | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
|  | @ -114,7 +133,7 @@ def merge_models_lists(model_lists): | |||
| 
 | ||||
| 
 | ||||
| async def get_all_models(): | ||||
|     print("get_all_models") | ||||
|     log.info("get_all_models()") | ||||
|     tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS] | ||||
|     responses = await asyncio.gather(*tasks) | ||||
| 
 | ||||
|  | @ -155,7 +174,7 @@ async def get_ollama_tags( | |||
| 
 | ||||
|             return r.json() | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
|             log.exception(e) | ||||
|             error_detail = "Open WebUI: Server Connection Error" | ||||
|             if r is not None: | ||||
|                 try: | ||||
|  | @ -201,7 +220,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): | |||
| 
 | ||||
|             return r.json() | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
|             log.exception(e) | ||||
|             error_detail = "Open WebUI: Server Connection Error" | ||||
|             if r is not None: | ||||
|                 try: | ||||
|  | @ -227,18 +246,33 @@ async def pull_model( | |||
|     form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) | ||||
| ): | ||||
|     url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
|     print(url) | ||||
|     log.info(f"url: {url}") | ||||
| 
 | ||||
|     r = None | ||||
| 
 | ||||
|     def get_request(): | ||||
|         nonlocal url | ||||
|         nonlocal r | ||||
| 
 | ||||
|         request_id = str(uuid.uuid4()) | ||||
|         try: | ||||
|             REQUEST_POOL.append(request_id) | ||||
| 
 | ||||
|             def stream_content(): | ||||
|                 for chunk in r.iter_content(chunk_size=8192): | ||||
|                     yield chunk | ||||
|                 try: | ||||
|                     yield json.dumps({"id": request_id, "done": False}) + "\n" | ||||
| 
 | ||||
|                     for chunk in r.iter_content(chunk_size=8192): | ||||
|                         if request_id in REQUEST_POOL: | ||||
|                             yield chunk | ||||
|                         else: | ||||
|                             print("User: canceled request") | ||||
|                             break | ||||
|                 finally: | ||||
|                     if hasattr(r, "close"): | ||||
|                         r.close() | ||||
|                         if request_id in REQUEST_POOL: | ||||
|                             REQUEST_POOL.remove(request_id) | ||||
| 
 | ||||
|             r = requests.request( | ||||
|                 method="POST", | ||||
|  | @ -259,8 +293,9 @@ async def pull_model( | |||
| 
 | ||||
|     try: | ||||
|         return await run_in_threadpool(get_request) | ||||
| 
 | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         error_detail = "Open WebUI: Server Connection Error" | ||||
|         if r is not None: | ||||
|             try: | ||||
|  | @ -299,7 +334,7 @@ async def push_model( | |||
|             ) | ||||
| 
 | ||||
|     url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
|     print(url) | ||||
|     log.debug(f"url: {url}") | ||||
| 
 | ||||
|     r = None | ||||
| 
 | ||||
|  | @ -331,7 +366,7 @@ async def push_model( | |||
|     try: | ||||
|         return await run_in_threadpool(get_request) | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         error_detail = "Open WebUI: Server Connection Error" | ||||
|         if r is not None: | ||||
|             try: | ||||
|  | @ -359,9 +394,9 @@ class CreateModelForm(BaseModel): | |||
| async def create_model( | ||||
|     form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) | ||||
| ): | ||||
|     print(form_data) | ||||
|     log.debug(f"form_data: {form_data}") | ||||
|     url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
|     print(url) | ||||
|     log.info(f"url: {url}") | ||||
| 
 | ||||
|     r = None | ||||
| 
 | ||||
|  | @ -383,7 +418,7 @@ async def create_model( | |||
| 
 | ||||
|             r.raise_for_status() | ||||
| 
 | ||||
|             print(r) | ||||
|             log.debug(f"r: {r}") | ||||
| 
 | ||||
|             return StreamingResponse( | ||||
|                 stream_content(), | ||||
|  | @ -396,7 +431,7 @@ async def create_model( | |||
|     try: | ||||
|         return await run_in_threadpool(get_request) | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         error_detail = "Open WebUI: Server Connection Error" | ||||
|         if r is not None: | ||||
|             try: | ||||
|  | @ -434,7 +469,7 @@ async def copy_model( | |||
|             ) | ||||
| 
 | ||||
|     url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
|     print(url) | ||||
|     log.info(f"url: {url}") | ||||
| 
 | ||||
|     try: | ||||
|         r = requests.request( | ||||
|  | @ -444,11 +479,11 @@ async def copy_model( | |||
|         ) | ||||
|         r.raise_for_status() | ||||
| 
 | ||||
|         print(r.text) | ||||
|         log.debug(f"r.text: {r.text}") | ||||
| 
 | ||||
|         return True | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         error_detail = "Open WebUI: Server Connection Error" | ||||
|         if r is not None: | ||||
|             try: | ||||
|  | @ -481,7 +516,7 @@ async def delete_model( | |||
|             ) | ||||
| 
 | ||||
|     url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
|     print(url) | ||||
|     log.info(f"url: {url}") | ||||
| 
 | ||||
|     try: | ||||
|         r = requests.request( | ||||
|  | @ -491,11 +526,11 @@ async def delete_model( | |||
|         ) | ||||
|         r.raise_for_status() | ||||
| 
 | ||||
|         print(r.text) | ||||
|         log.debug(f"r.text: {r.text}") | ||||
| 
 | ||||
|         return True | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         error_detail = "Open WebUI: Server Connection Error" | ||||
|         if r is not None: | ||||
|             try: | ||||
|  | @ -521,7 +556,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use | |||
| 
 | ||||
|     url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) | ||||
|     url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
|     print(url) | ||||
|     log.info(f"url: {url}") | ||||
| 
 | ||||
|     try: | ||||
|         r = requests.request( | ||||
|  | @ -533,7 +568,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use | |||
| 
 | ||||
|         return r.json() | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         error_detail = "Open WebUI: Server Connection Error" | ||||
|         if r is not None: | ||||
|             try: | ||||
|  | @ -573,7 +608,7 @@ async def generate_embeddings( | |||
|             ) | ||||
| 
 | ||||
|     url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
|     print(url) | ||||
|     log.info(f"url: {url}") | ||||
| 
 | ||||
|     try: | ||||
|         r = requests.request( | ||||
|  | @ -585,7 +620,7 @@ async def generate_embeddings( | |||
| 
 | ||||
|         return r.json() | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         error_detail = "Open WebUI: Server Connection Error" | ||||
|         if r is not None: | ||||
|             try: | ||||
|  | @ -633,7 +668,7 @@ async def generate_completion( | |||
|             ) | ||||
| 
 | ||||
|     url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
|     print(url) | ||||
|     log.info(f"url: {url}") | ||||
| 
 | ||||
|     r = None | ||||
| 
 | ||||
|  | @ -654,7 +689,7 @@ async def generate_completion( | |||
|                         if request_id in REQUEST_POOL: | ||||
|                             yield chunk | ||||
|                         else: | ||||
|                             print("User: canceled request") | ||||
|                             log.warning("User: canceled request") | ||||
|                             break | ||||
|                 finally: | ||||
|                     if hasattr(r, "close"): | ||||
|  | @ -709,7 +744,7 @@ class GenerateChatCompletionForm(BaseModel): | |||
|     format: Optional[str] = None | ||||
|     options: Optional[dict] = None | ||||
|     template: Optional[str] = None | ||||
|     stream: Optional[bool] = True | ||||
|     stream: Optional[bool] = None | ||||
|     keep_alive: Optional[Union[int, str]] = None | ||||
| 
 | ||||
| 
 | ||||
|  | @ -731,11 +766,11 @@ async def generate_chat_completion( | |||
|             ) | ||||
| 
 | ||||
|     url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
|     print(url) | ||||
|     log.info(f"url: {url}") | ||||
| 
 | ||||
|     r = None | ||||
| 
 | ||||
|     print(form_data.model_dump_json(exclude_none=True).encode()) | ||||
|     log.debug("form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(form_data.model_dump_json(exclude_none=True).encode())) | ||||
| 
 | ||||
|     def get_request(): | ||||
|         nonlocal form_data | ||||
|  | @ -754,7 +789,7 @@ async def generate_chat_completion( | |||
|                         if request_id in REQUEST_POOL: | ||||
|                             yield chunk | ||||
|                         else: | ||||
|                             print("User: canceled request") | ||||
|                             log.warning("User: canceled request") | ||||
|                             break | ||||
|                 finally: | ||||
|                     if hasattr(r, "close"): | ||||
|  | @ -777,7 +812,7 @@ async def generate_chat_completion( | |||
|                 headers=dict(r.headers), | ||||
|             ) | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
|             log.exception(e) | ||||
|             raise e | ||||
| 
 | ||||
|     try: | ||||
|  | @ -831,7 +866,7 @@ async def generate_openai_chat_completion( | |||
|             ) | ||||
| 
 | ||||
|     url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
|     print(url) | ||||
|     log.info(f"url: {url}") | ||||
| 
 | ||||
|     r = None | ||||
| 
 | ||||
|  | @ -854,7 +889,7 @@ async def generate_openai_chat_completion( | |||
|                         if request_id in REQUEST_POOL: | ||||
|                             yield chunk | ||||
|                         else: | ||||
|                             print("User: canceled request") | ||||
|                             log.warning("User: canceled request") | ||||
|                             break | ||||
|                 finally: | ||||
|                     if hasattr(r, "close"): | ||||
|  | @ -897,6 +932,211 @@ async def generate_openai_chat_completion( | |||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class UrlForm(BaseModel): | ||||
|     url: str | ||||
| 
 | ||||
| 
 | ||||
| class UploadBlobForm(BaseModel): | ||||
|     filename: str | ||||
| 
 | ||||
| 
 | ||||
| def parse_huggingface_url(hf_url): | ||||
|     try: | ||||
|         # Parse the URL | ||||
|         parsed_url = urlparse(hf_url) | ||||
| 
 | ||||
|         # Get the path and split it into components | ||||
|         path_components = parsed_url.path.split("/") | ||||
| 
 | ||||
|         # Extract the desired output | ||||
|         user_repo = "/".join(path_components[1:3]) | ||||
|         model_file = path_components[-1] | ||||
| 
 | ||||
|         return model_file | ||||
|     except ValueError: | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| async def download_file_stream( | ||||
|     ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024 | ||||
| ): | ||||
|     done = False | ||||
| 
 | ||||
|     if os.path.exists(file_path): | ||||
|         current_size = os.path.getsize(file_path) | ||||
|     else: | ||||
|         current_size = 0 | ||||
| 
 | ||||
|     headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {} | ||||
| 
 | ||||
|     timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout | ||||
| 
 | ||||
|     async with aiohttp.ClientSession(timeout=timeout) as session: | ||||
|         async with session.get(file_url, headers=headers) as response: | ||||
|             total_size = int(response.headers.get("content-length", 0)) + current_size | ||||
| 
 | ||||
|             with open(file_path, "ab+") as file: | ||||
|                 async for data in response.content.iter_chunked(chunk_size): | ||||
|                     current_size += len(data) | ||||
|                     file.write(data) | ||||
| 
 | ||||
|                     done = current_size == total_size | ||||
|                     progress = round((current_size / total_size) * 100, 2) | ||||
| 
 | ||||
|                     yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n' | ||||
| 
 | ||||
|                 if done: | ||||
|                     file.seek(0) | ||||
|                     hashed = calculate_sha256(file) | ||||
|                     file.seek(0) | ||||
| 
 | ||||
|                     url = f"{ollama_url}/api/blobs/sha256:{hashed}" | ||||
|                     response = requests.post(url, data=file) | ||||
| 
 | ||||
|                     if response.ok: | ||||
|                         res = { | ||||
|                             "done": done, | ||||
|                             "blob": f"sha256:{hashed}", | ||||
|                             "name": file_name, | ||||
|                         } | ||||
|                         os.remove(file_path) | ||||
| 
 | ||||
|                         yield f"data: {json.dumps(res)}\n\n" | ||||
|                     else: | ||||
|                         raise "Ollama: Could not create blob, Please try again." | ||||
| 
 | ||||
| 
 | ||||
| # def number_generator(): | ||||
| #     for i in range(1, 101): | ||||
| #         yield f"data: {i}\n" | ||||
| 
 | ||||
| 
 | ||||
| # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" | ||||
| @app.post("/models/download") | ||||
| @app.post("/models/download/{url_idx}") | ||||
| async def download_model( | ||||
|     form_data: UrlForm, | ||||
|     url_idx: Optional[int] = None, | ||||
| ): | ||||
| 
 | ||||
|     if url_idx == None: | ||||
|         url_idx = 0 | ||||
|     url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
| 
 | ||||
|     file_name = parse_huggingface_url(form_data.url) | ||||
| 
 | ||||
|     if file_name: | ||||
|         file_path = f"{UPLOAD_DIR}/{file_name}" | ||||
|         return StreamingResponse( | ||||
|             download_file_stream(url, form_data.url, file_path, file_name), | ||||
|         ) | ||||
|     else: | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/models/upload") | ||||
| @app.post("/models/upload/{url_idx}") | ||||
| def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): | ||||
|     if url_idx == None: | ||||
|         url_idx = 0 | ||||
|     ollama_url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
| 
 | ||||
|     file_path = f"{UPLOAD_DIR}/{file.filename}" | ||||
| 
 | ||||
|     # Save file in chunks | ||||
|     with open(file_path, "wb+") as f: | ||||
|         for chunk in file.file: | ||||
|             f.write(chunk) | ||||
| 
 | ||||
|     def file_process_stream(): | ||||
|         nonlocal ollama_url | ||||
|         total_size = os.path.getsize(file_path) | ||||
|         chunk_size = 1024 * 1024 | ||||
|         try: | ||||
|             with open(file_path, "rb") as f: | ||||
|                 total = 0 | ||||
|                 done = False | ||||
| 
 | ||||
|                 while not done: | ||||
|                     chunk = f.read(chunk_size) | ||||
|                     if not chunk: | ||||
|                         done = True | ||||
|                         continue | ||||
| 
 | ||||
|                     total += len(chunk) | ||||
|                     progress = round((total / total_size) * 100, 2) | ||||
| 
 | ||||
|                     res = { | ||||
|                         "progress": progress, | ||||
|                         "total": total_size, | ||||
|                         "completed": total, | ||||
|                     } | ||||
|                     yield f"data: {json.dumps(res)}\n\n" | ||||
| 
 | ||||
|                 if done: | ||||
|                     f.seek(0) | ||||
|                     hashed = calculate_sha256(f) | ||||
|                     f.seek(0) | ||||
| 
 | ||||
|                     url = f"{ollama_url}/api/blobs/sha256:{hashed}" | ||||
|                     response = requests.post(url, data=f) | ||||
| 
 | ||||
|                     if response.ok: | ||||
|                         res = { | ||||
|                             "done": done, | ||||
|                             "blob": f"sha256:{hashed}", | ||||
|                             "name": file.filename, | ||||
|                         } | ||||
|                         os.remove(file_path) | ||||
|                         yield f"data: {json.dumps(res)}\n\n" | ||||
|                     else: | ||||
|                         raise Exception( | ||||
|                             "Ollama: Could not create blob, Please try again." | ||||
|                         ) | ||||
| 
 | ||||
|         except Exception as e: | ||||
|             res = {"error": str(e)} | ||||
|             yield f"data: {json.dumps(res)}\n\n" | ||||
| 
 | ||||
|     return StreamingResponse(file_process_stream(), media_type="text/event-stream") | ||||
| 
 | ||||
| 
 | ||||
| # async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None): | ||||
| #     if url_idx == None: | ||||
| #         url_idx = 0 | ||||
| #     url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
| 
 | ||||
| #     file_location = os.path.join(UPLOAD_DIR, file.filename) | ||||
| #     total_size = file.size | ||||
| 
 | ||||
| #     async def file_upload_generator(file): | ||||
| #         print(file) | ||||
| #         try: | ||||
| #             async with aiofiles.open(file_location, "wb") as f: | ||||
| #                 completed_size = 0 | ||||
| #                 while True: | ||||
| #                     chunk = await file.read(1024*1024) | ||||
| #                     if not chunk: | ||||
| #                         break | ||||
| #                     await f.write(chunk) | ||||
| #                     completed_size += len(chunk) | ||||
| #                     progress = (completed_size / total_size) * 100 | ||||
| 
 | ||||
| #                     print(progress) | ||||
| #                     yield f'data: {json.dumps({"status": "uploading", "percentage": progress, "total": total_size, "completed": completed_size, "done": False})}\n' | ||||
| #         except Exception as e: | ||||
| #             print(e) | ||||
| #             yield f"data: {json.dumps({'status': 'error', 'message': str(e)})}\n" | ||||
| #         finally: | ||||
| #             await file.close() | ||||
| #             print("done") | ||||
| #             yield f'data: {json.dumps({"status": "completed", "percentage": 100, "total": total_size, "completed": completed_size, "done": True})}\n' | ||||
| 
 | ||||
| #     return StreamingResponse( | ||||
| #         file_upload_generator(copy.deepcopy(file)), media_type="text/event-stream" | ||||
| #     ) | ||||
| 
 | ||||
| 
 | ||||
| @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | ||||
| async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)): | ||||
|     url = app.state.OLLAMA_BASE_URLS[0] | ||||
|  | @ -947,7 +1187,7 @@ async def deprecated_proxy(path: str, request: Request, user=Depends(get_current | |||
|                         if request_id in REQUEST_POOL: | ||||
|                             yield chunk | ||||
|                         else: | ||||
|                             print("User: canceled request") | ||||
|                             log.warning("User: canceled request") | ||||
|                             break | ||||
|                 finally: | ||||
|                     if hasattr(r, "close"): | ||||
|  |  | |||
|  | @ -6,6 +6,7 @@ import requests | |||
| import aiohttp | ||||
| import asyncio | ||||
| import json | ||||
| import logging | ||||
| 
 | ||||
| from pydantic import BaseModel | ||||
| 
 | ||||
|  | @ -19,6 +20,7 @@ from utils.utils import ( | |||
|     get_admin_user, | ||||
| ) | ||||
| from config import ( | ||||
|     SRC_LOG_LEVELS, | ||||
|     OPENAI_API_BASE_URLS, | ||||
|     OPENAI_API_KEYS, | ||||
|     CACHE_DIR, | ||||
|  | @ -31,6 +33,9 @@ from typing import List, Optional | |||
| import hashlib | ||||
| from pathlib import Path | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["OPENAI"]) | ||||
| 
 | ||||
| app = FastAPI() | ||||
| app.add_middleware( | ||||
|     CORSMiddleware, | ||||
|  | @ -134,7 +139,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): | |||
|             return FileResponse(file_path) | ||||
| 
 | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
|             log.exception(e) | ||||
|             error_detail = "Open WebUI: Server Connection Error" | ||||
|             if r is not None: | ||||
|                 try: | ||||
|  | @ -160,7 +165,7 @@ async def fetch_url(url, key): | |||
|                 return await response.json() | ||||
|     except Exception as e: | ||||
|         # Handle connection error here | ||||
|         print(f"Connection error: {e}") | ||||
|         log.error(f"Connection error: {e}") | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
|  | @ -182,7 +187,7 @@ def merge_models_lists(model_lists): | |||
| 
 | ||||
| 
 | ||||
| async def get_all_models(): | ||||
|     print("get_all_models") | ||||
|     log.info("get_all_models()") | ||||
| 
 | ||||
|     if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "": | ||||
|         models = {"data": []} | ||||
|  | @ -208,7 +213,7 @@ async def get_all_models(): | |||
|             ) | ||||
|         } | ||||
| 
 | ||||
|         print(models) | ||||
|         log.info(f"models: {models}") | ||||
|         app.state.MODELS = {model["id"]: model for model in models["data"]} | ||||
| 
 | ||||
|         return models | ||||
|  | @ -246,7 +251,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use | |||
| 
 | ||||
|             return response_data | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
|             log.exception(e) | ||||
|             error_detail = "Open WebUI: Server Connection Error" | ||||
|             if r is not None: | ||||
|                 try: | ||||
|  | @ -280,7 +285,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): | |||
|         if body.get("model") == "gpt-4-vision-preview": | ||||
|             if "max_tokens" not in body: | ||||
|                 body["max_tokens"] = 4000 | ||||
|             print("Modified body_dict:", body) | ||||
|             log.debug("Modified body_dict:", body) | ||||
| 
 | ||||
|         # Fix for ChatGPT calls failing because the num_ctx key is in body | ||||
|         if "num_ctx" in body: | ||||
|  | @ -292,7 +297,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): | |||
|         # Convert the modified body back to JSON | ||||
|         body = json.dumps(body) | ||||
|     except json.JSONDecodeError as e: | ||||
|         print("Error loading request body into a dictionary:", e) | ||||
|         log.error("Error loading request body into a dictionary:", e) | ||||
| 
 | ||||
|     url = app.state.OPENAI_API_BASE_URLS[idx] | ||||
|     key = app.state.OPENAI_API_KEYS[idx] | ||||
|  | @ -330,7 +335,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): | |||
|             response_data = r.json() | ||||
|             return response_data | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         error_detail = "Open WebUI: Server Connection Error" | ||||
|         if r is not None: | ||||
|             try: | ||||
|  |  | |||
|  | @ -8,7 +8,7 @@ from fastapi import ( | |||
|     Form, | ||||
| ) | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| import os, shutil | ||||
| import os, shutil, logging | ||||
| 
 | ||||
| from pathlib import Path | ||||
| from typing import List | ||||
|  | @ -54,6 +54,7 @@ from utils.misc import ( | |||
| ) | ||||
| from utils.utils import get_current_user, get_admin_user | ||||
| from config import ( | ||||
|     SRC_LOG_LEVELS, | ||||
|     UPLOAD_DIR, | ||||
|     DOCS_DIR, | ||||
|     RAG_EMBEDDING_MODEL, | ||||
|  | @ -66,6 +67,9 @@ from config import ( | |||
| 
 | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["RAG"]) | ||||
| 
 | ||||
| # | ||||
| # if RAG_EMBEDDING_MODEL: | ||||
| #    sentence_transformer_ef = SentenceTransformer( | ||||
|  | @ -110,40 +114,6 @@ class CollectionNameForm(BaseModel): | |||
| class StoreWebForm(CollectionNameForm): | ||||
|     url: str | ||||
| 
 | ||||
| 
 | ||||
| def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: | ||||
|     text_splitter = RecursiveCharacterTextSplitter( | ||||
|         chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP | ||||
|     ) | ||||
|     docs = text_splitter.split_documents(data) | ||||
| 
 | ||||
|     texts = [doc.page_content for doc in docs] | ||||
|     metadatas = [doc.metadata for doc in docs] | ||||
| 
 | ||||
|     try: | ||||
|         if overwrite: | ||||
|             for collection in CHROMA_CLIENT.list_collections(): | ||||
|                 if collection_name == collection.name: | ||||
|                     print(f"deleting existing collection {collection_name}") | ||||
|                     CHROMA_CLIENT.delete_collection(name=collection_name) | ||||
| 
 | ||||
|         collection = CHROMA_CLIENT.create_collection( | ||||
|             name=collection_name, | ||||
|             embedding_function=app.state.sentence_transformer_ef, | ||||
|         ) | ||||
| 
 | ||||
|         collection.add( | ||||
|             documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] | ||||
|         ) | ||||
|         return True | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         if e.__class__.__name__ == "UniqueConstraintError": | ||||
|             return True | ||||
| 
 | ||||
|         return False | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/") | ||||
| async def get_status(): | ||||
|     return { | ||||
|  | @ -274,7 +244,7 @@ def query_doc_handler( | |||
|             embedding_function=app.state.sentence_transformer_ef, | ||||
|         ) | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_400_BAD_REQUEST, | ||||
|             detail=ERROR_MESSAGES.DEFAULT(e), | ||||
|  | @ -318,13 +288,63 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): | |||
|             "filename": form_data.url, | ||||
|         } | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_400_BAD_REQUEST, | ||||
|             detail=ERROR_MESSAGES.DEFAULT(e), | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: | ||||
|     text_splitter = RecursiveCharacterTextSplitter( | ||||
|         chunk_size=app.state.CHUNK_SIZE, | ||||
|         chunk_overlap=app.state.CHUNK_OVERLAP, | ||||
|         add_start_index=True, | ||||
|     ) | ||||
|     docs = text_splitter.split_documents(data) | ||||
|     return store_docs_in_vector_db(docs, collection_name, overwrite) | ||||
| 
 | ||||
| 
 | ||||
| def store_text_in_vector_db( | ||||
|     text, metadata, collection_name, overwrite: bool = False | ||||
| ) -> bool: | ||||
|     text_splitter = RecursiveCharacterTextSplitter( | ||||
|         chunk_size=app.state.CHUNK_SIZE, | ||||
|         chunk_overlap=app.state.CHUNK_OVERLAP, | ||||
|         add_start_index=True, | ||||
|     ) | ||||
|     docs = text_splitter.create_documents([text], metadatas=[metadata]) | ||||
|     return store_docs_in_vector_db(docs, collection_name, overwrite) | ||||
| 
 | ||||
| 
 | ||||
| def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool: | ||||
|     texts = [doc.page_content for doc in docs] | ||||
|     metadatas = [doc.metadata for doc in docs] | ||||
| 
 | ||||
|     try: | ||||
|         if overwrite: | ||||
|             for collection in CHROMA_CLIENT.list_collections(): | ||||
|                 if collection_name == collection.name: | ||||
|                     print(f"deleting existing collection {collection_name}") | ||||
|                     CHROMA_CLIENT.delete_collection(name=collection_name) | ||||
| 
 | ||||
|         collection = CHROMA_CLIENT.create_collection( | ||||
|             name=collection_name, | ||||
|             embedding_function=app.state.sentence_transformer_ef, | ||||
|         ) | ||||
| 
 | ||||
|         collection.add( | ||||
|             documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] | ||||
|         ) | ||||
|         return True | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         if e.__class__.__name__ == "UniqueConstraintError": | ||||
|             return True | ||||
| 
 | ||||
|         return False | ||||
| 
 | ||||
| 
 | ||||
| def get_loader(filename: str, file_content_type: str, file_path: str): | ||||
|     file_ext = filename.split(".")[-1].lower() | ||||
|     known_type = True | ||||
|  | @ -416,7 +436,7 @@ def store_doc( | |||
| ): | ||||
|     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" | ||||
| 
 | ||||
|     print(file.content_type) | ||||
|     log.info(f"file.content_type: {file.content_type}") | ||||
|     try: | ||||
|         filename = file.filename | ||||
|         file_path = f"{UPLOAD_DIR}/{filename}" | ||||
|  | @ -447,7 +467,7 @@ def store_doc( | |||
|                 detail=ERROR_MESSAGES.DEFAULT(), | ||||
|             ) | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         if "No pandoc was found" in str(e): | ||||
|             raise HTTPException( | ||||
|                 status_code=status.HTTP_400_BAD_REQUEST, | ||||
|  | @ -460,6 +480,37 @@ def store_doc( | |||
|             ) | ||||
| 
 | ||||
| 
 | ||||
| class TextRAGForm(BaseModel): | ||||
|     name: str | ||||
|     content: str | ||||
|     collection_name: Optional[str] = None | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/text") | ||||
| def store_text( | ||||
|     form_data: TextRAGForm, | ||||
|     user=Depends(get_current_user), | ||||
| ): | ||||
| 
 | ||||
|     collection_name = form_data.collection_name | ||||
|     if collection_name == None: | ||||
|         collection_name = calculate_sha256_string(form_data.content) | ||||
| 
 | ||||
|     result = store_text_in_vector_db( | ||||
|         form_data.content, | ||||
|         metadata={"name": form_data.name, "created_by": user.id}, | ||||
|         collection_name=collection_name, | ||||
|     ) | ||||
| 
 | ||||
|     if result: | ||||
|         return {"status": True, "collection_name": collection_name} | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||||
|             detail=ERROR_MESSAGES.DEFAULT(), | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/scan") | ||||
| def scan_docs_dir(user=Depends(get_admin_user)): | ||||
|     for path in Path(DOCS_DIR).rglob("./**/*"): | ||||
|  | @ -512,7 +563,7 @@ def scan_docs_dir(user=Depends(get_admin_user)): | |||
|                         ) | ||||
| 
 | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
|             log.exception(e) | ||||
| 
 | ||||
|     return True | ||||
| 
 | ||||
|  | @ -533,11 +584,11 @@ def reset(user=Depends(get_admin_user)) -> bool: | |||
|             elif os.path.isdir(file_path): | ||||
|                 shutil.rmtree(file_path) | ||||
|         except Exception as e: | ||||
|             print("Failed to delete %s. Reason: %s" % (file_path, e)) | ||||
|             log.error("Failed to delete %s. Reason: %s" % (file_path, e)) | ||||
| 
 | ||||
|     try: | ||||
|         CHROMA_CLIENT.reset() | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
| 
 | ||||
|     return True | ||||
|  |  | |||
|  | @ -1,7 +1,11 @@ | |||
| import re | ||||
| import logging | ||||
| from typing import List | ||||
| 
 | ||||
| from config import CHROMA_CLIENT | ||||
| from config import SRC_LOG_LEVELS, CHROMA_CLIENT | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["RAG"]) | ||||
| 
 | ||||
| 
 | ||||
| def query_doc(collection_name: str, query: str, k: int, embedding_function): | ||||
|  | @ -97,7 +101,7 @@ def rag_template(template: str, context: str, query: str): | |||
| 
 | ||||
| 
 | ||||
| def rag_messages(docs, messages, template, k, embedding_function): | ||||
|     print(docs) | ||||
|     log.debug(f"docs: {docs}") | ||||
| 
 | ||||
|     last_user_message_idx = None | ||||
|     for i in range(len(messages) - 1, -1, -1): | ||||
|  | @ -137,6 +141,8 @@ def rag_messages(docs, messages, template, k, embedding_function): | |||
|                     k=k, | ||||
|                     embedding_function=embedding_function, | ||||
|                 ) | ||||
|             elif doc["type"] == "text": | ||||
|                 context = doc["content"] | ||||
|             else: | ||||
|                 context = query_doc( | ||||
|                     collection_name=doc["collection_name"], | ||||
|  | @ -145,7 +151,7 @@ def rag_messages(docs, messages, template, k, embedding_function): | |||
|                     embedding_function=embedding_function, | ||||
|                 ) | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
|             log.exception(e) | ||||
|             context = None | ||||
| 
 | ||||
|         relevant_contexts.append(context) | ||||
|  |  | |||
|  | @ -1,13 +1,16 @@ | |||
| from peewee import * | ||||
| from config import DATA_DIR | ||||
| from config import SRC_LOG_LEVELS, DATA_DIR | ||||
| import os | ||||
| import logging | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["DB"]) | ||||
| 
 | ||||
| # Check if the file exists | ||||
| if os.path.exists(f"{DATA_DIR}/ollama.db"): | ||||
|     # Rename the file | ||||
|     os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db") | ||||
|     print("File renamed successfully.") | ||||
|     log.info("File renamed successfully.") | ||||
| else: | ||||
|     pass | ||||
| 
 | ||||
|  |  | |||
|  | @ -19,6 +19,7 @@ from config import ( | |||
|     DEFAULT_USER_ROLE, | ||||
|     ENABLE_SIGNUP, | ||||
|     USER_PERMISSIONS, | ||||
|     WEBHOOK_URL, | ||||
| ) | ||||
| 
 | ||||
| app = FastAPI() | ||||
|  | @ -32,6 +33,7 @@ app.state.DEFAULT_MODELS = DEFAULT_MODELS | |||
| app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS | ||||
| app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE | ||||
| app.state.USER_PERMISSIONS = USER_PERMISSIONS | ||||
| app.state.WEBHOOK_URL = WEBHOOK_URL | ||||
| 
 | ||||
| 
 | ||||
| app.add_middleware( | ||||
|  |  | |||
|  | @ -2,6 +2,7 @@ from pydantic import BaseModel | |||
| from typing import List, Union, Optional | ||||
| import time | ||||
| import uuid | ||||
| import logging | ||||
| from peewee import * | ||||
| 
 | ||||
| from apps.web.models.users import UserModel, Users | ||||
|  | @ -9,6 +10,10 @@ from utils.utils import verify_password | |||
| 
 | ||||
| from apps.web.internal.db import DB | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["MODELS"]) | ||||
| 
 | ||||
| #################### | ||||
| # DB MODEL | ||||
| #################### | ||||
|  | @ -86,7 +91,7 @@ class AuthsTable: | |||
|     def insert_new_auth( | ||||
|         self, email: str, password: str, name: str, role: str = "pending" | ||||
|     ) -> Optional[UserModel]: | ||||
|         print("insert_new_auth") | ||||
|         log.info("insert_new_auth") | ||||
| 
 | ||||
|         id = str(uuid.uuid4()) | ||||
| 
 | ||||
|  | @ -103,7 +108,7 @@ class AuthsTable: | |||
|             return None | ||||
| 
 | ||||
|     def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: | ||||
|         print("authenticate_user", email) | ||||
|         log.info(f"authenticate_user: {email}") | ||||
|         try: | ||||
|             auth = Auth.get(Auth.email == email, Auth.active == True) | ||||
|             if auth: | ||||
|  |  | |||
|  | @ -95,20 +95,6 @@ class ChatTable: | |||
|         except: | ||||
|             return None | ||||
| 
 | ||||
|     def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: | ||||
|         try: | ||||
|             query = Chat.update( | ||||
|                 chat=json.dumps(chat), | ||||
|                 title=chat["title"] if "title" in chat else "New Chat", | ||||
|                 timestamp=int(time.time()), | ||||
|             ).where(Chat.id == id) | ||||
|             query.execute() | ||||
| 
 | ||||
|             chat = Chat.get(Chat.id == id) | ||||
|             return ChatModel(**model_to_dict(chat)) | ||||
|         except: | ||||
|             return None | ||||
| 
 | ||||
|     def get_chat_lists_by_user_id( | ||||
|         self, user_id: str, skip: int = 0, limit: int = 50 | ||||
|     ) -> List[ChatModel]: | ||||
|  |  | |||
|  | @ -3,6 +3,7 @@ from peewee import * | |||
| from playhouse.shortcuts import model_to_dict | ||||
| from typing import List, Union, Optional | ||||
| import time | ||||
| import logging | ||||
| 
 | ||||
| from utils.utils import decode_token | ||||
| from utils.misc import get_gravatar_url | ||||
|  | @ -11,6 +12,10 @@ from apps.web.internal.db import DB | |||
| 
 | ||||
| import json | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["MODELS"]) | ||||
| 
 | ||||
| #################### | ||||
| # Documents DB Schema | ||||
| #################### | ||||
|  | @ -118,7 +123,7 @@ class DocumentsTable: | |||
|             doc = Document.get(Document.name == form_data.name) | ||||
|             return DocumentModel(**model_to_dict(doc)) | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
|             log.exception(e) | ||||
|             return None | ||||
| 
 | ||||
|     def update_doc_content_by_name( | ||||
|  | @ -138,7 +143,7 @@ class DocumentsTable: | |||
|             doc = Document.get(Document.name == name) | ||||
|             return DocumentModel(**model_to_dict(doc)) | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
|             log.exception(e) | ||||
|             return None | ||||
| 
 | ||||
|     def delete_doc_by_name(self, name: str) -> bool: | ||||
|  |  | |||
|  | @ -6,9 +6,14 @@ from playhouse.shortcuts import model_to_dict | |||
| import json | ||||
| import uuid | ||||
| import time | ||||
| import logging | ||||
| 
 | ||||
| from apps.web.internal.db import DB | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["MODELS"]) | ||||
| 
 | ||||
| #################### | ||||
| # Tag DB Schema | ||||
| #################### | ||||
|  | @ -173,7 +178,7 @@ class TagTable: | |||
|                 (ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id) | ||||
|             ) | ||||
|             res = query.execute()  # Remove the rows, return number of rows removed. | ||||
|             print(res) | ||||
|             log.debug(f"res: {res}") | ||||
| 
 | ||||
|             tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) | ||||
|             if tag_count == 0: | ||||
|  | @ -185,7 +190,7 @@ class TagTable: | |||
| 
 | ||||
|             return True | ||||
|         except Exception as e: | ||||
|             print("delete_tag", e) | ||||
|             log.error(f"delete_tag: {e}") | ||||
|             return False | ||||
| 
 | ||||
|     def delete_tag_by_tag_name_and_chat_id_and_user_id( | ||||
|  | @ -198,7 +203,7 @@ class TagTable: | |||
|                 & (ChatIdTag.user_id == user_id) | ||||
|             ) | ||||
|             res = query.execute()  # Remove the rows, return number of rows removed. | ||||
|             print(res) | ||||
|             log.debug(f"res: {res}") | ||||
| 
 | ||||
|             tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) | ||||
|             if tag_count == 0: | ||||
|  | @ -210,7 +215,7 @@ class TagTable: | |||
| 
 | ||||
|             return True | ||||
|         except Exception as e: | ||||
|             print("delete_tag", e) | ||||
|             log.error(f"delete_tag: {e}") | ||||
|             return False | ||||
| 
 | ||||
|     def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool: | ||||
|  |  | |||
|  | @ -27,7 +27,8 @@ from utils.utils import ( | |||
|     create_token, | ||||
| ) | ||||
| from utils.misc import parse_duration, validate_email_format | ||||
| from constants import ERROR_MESSAGES | ||||
| from utils.webhook import post_webhook | ||||
| from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES | ||||
| 
 | ||||
| router = APIRouter() | ||||
| 
 | ||||
|  | @ -155,6 +156,17 @@ async def signup(request: Request, form_data: SignupForm): | |||
|             ) | ||||
|             # response.set_cookie(key='token', value=token, httponly=True) | ||||
| 
 | ||||
|             if request.app.state.WEBHOOK_URL: | ||||
|                 post_webhook( | ||||
|                     request.app.state.WEBHOOK_URL, | ||||
|                     WEBHOOK_MESSAGES.USER_SIGNUP(user.name), | ||||
|                     { | ||||
|                         "action": "signup", | ||||
|                         "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), | ||||
|                         "user": user.model_dump_json(exclude_none=True), | ||||
|                     }, | ||||
|                 ) | ||||
| 
 | ||||
|             return { | ||||
|                 "token": token, | ||||
|                 "token_type": "Bearer", | ||||
|  |  | |||
|  | @ -5,6 +5,7 @@ from utils.utils import get_current_user, get_admin_user | |||
| from fastapi import APIRouter | ||||
| from pydantic import BaseModel | ||||
| import json | ||||
| import logging | ||||
| 
 | ||||
| from apps.web.models.users import Users | ||||
| from apps.web.models.chats import ( | ||||
|  | @ -27,6 +28,10 @@ from apps.web.models.tags import ( | |||
| 
 | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["MODELS"]) | ||||
| 
 | ||||
| router = APIRouter() | ||||
| 
 | ||||
| ############################ | ||||
|  | @ -78,7 +83,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): | |||
|         chat = Chats.insert_new_chat(user.id, form_data) | ||||
|         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() | ||||
|         ) | ||||
|  | @ -95,7 +100,7 @@ async def get_all_tags(user=Depends(get_current_user)): | |||
|         tags = Tags.get_tags_by_user_id(user.id) | ||||
|         return tags | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() | ||||
|         ) | ||||
|  |  | |||
|  | @ -7,6 +7,7 @@ from fastapi import APIRouter | |||
| from pydantic import BaseModel | ||||
| import time | ||||
| import uuid | ||||
| import logging | ||||
| 
 | ||||
| from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users | ||||
| from apps.web.models.auths import Auths | ||||
|  | @ -14,6 +15,10 @@ from apps.web.models.auths import Auths | |||
| from utils.utils import get_current_user, get_password_hash, get_admin_user | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["MODELS"]) | ||||
| 
 | ||||
| router = APIRouter() | ||||
| 
 | ||||
| ############################ | ||||
|  | @ -83,7 +88,7 @@ async def update_user_by_id( | |||
| 
 | ||||
|         if form_data.password: | ||||
|             hashed = get_password_hash(form_data.password) | ||||
|             print(hashed) | ||||
|             log.debug(f"hashed: {hashed}") | ||||
|             Auths.update_user_password_by_id(user_id, hashed) | ||||
| 
 | ||||
|         Auths.update_email_by_id(user_id, form_data.email.lower()) | ||||
|  |  | |||
|  | @ -21,155 +21,6 @@ from constants import ERROR_MESSAGES | |||
| router = APIRouter() | ||||
| 
 | ||||
| 
 | ||||
| class UploadBlobForm(BaseModel): | ||||
|     filename: str | ||||
| 
 | ||||
| 
 | ||||
| from urllib.parse import urlparse | ||||
| 
 | ||||
| 
 | ||||
| def parse_huggingface_url(hf_url): | ||||
|     try: | ||||
|         # Parse the URL | ||||
|         parsed_url = urlparse(hf_url) | ||||
| 
 | ||||
|         # Get the path and split it into components | ||||
|         path_components = parsed_url.path.split("/") | ||||
| 
 | ||||
|         # Extract the desired output | ||||
|         user_repo = "/".join(path_components[1:3]) | ||||
|         model_file = path_components[-1] | ||||
| 
 | ||||
|         return model_file | ||||
|     except ValueError: | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024): | ||||
|     done = False | ||||
| 
 | ||||
|     if os.path.exists(file_path): | ||||
|         current_size = os.path.getsize(file_path) | ||||
|     else: | ||||
|         current_size = 0 | ||||
| 
 | ||||
|     headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {} | ||||
| 
 | ||||
|     timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout | ||||
| 
 | ||||
|     async with aiohttp.ClientSession(timeout=timeout) as session: | ||||
|         async with session.get(url, headers=headers) as response: | ||||
|             total_size = int(response.headers.get("content-length", 0)) + current_size | ||||
| 
 | ||||
|             with open(file_path, "ab+") as file: | ||||
|                 async for data in response.content.iter_chunked(chunk_size): | ||||
|                     current_size += len(data) | ||||
|                     file.write(data) | ||||
| 
 | ||||
|                     done = current_size == total_size | ||||
|                     progress = round((current_size / total_size) * 100, 2) | ||||
|                     yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n' | ||||
| 
 | ||||
|                 if done: | ||||
|                     file.seek(0) | ||||
|                     hashed = calculate_sha256(file) | ||||
|                     file.seek(0) | ||||
| 
 | ||||
|                     url = f"{OLLAMA_BASE_URLS[0]}/api/blobs/sha256:{hashed}" | ||||
|                     response = requests.post(url, data=file) | ||||
| 
 | ||||
|                     if response.ok: | ||||
|                         res = { | ||||
|                             "done": done, | ||||
|                             "blob": f"sha256:{hashed}", | ||||
|                             "name": file_name, | ||||
|                         } | ||||
|                         os.remove(file_path) | ||||
| 
 | ||||
|                         yield f"data: {json.dumps(res)}\n\n" | ||||
|                     else: | ||||
|                         raise "Ollama: Could not create blob, Please try again." | ||||
| 
 | ||||
| 
 | ||||
| @router.get("/download") | ||||
| async def download( | ||||
|     url: str, | ||||
| ): | ||||
|     # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" | ||||
|     file_name = parse_huggingface_url(url) | ||||
| 
 | ||||
|     if file_name: | ||||
|         file_path = f"{UPLOAD_DIR}/{file_name}" | ||||
| 
 | ||||
|         return StreamingResponse( | ||||
|             download_file_stream(url, file_path, file_name), | ||||
|             media_type="text/event-stream", | ||||
|         ) | ||||
|     else: | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| @router.post("/upload") | ||||
| def upload(file: UploadFile = File(...)): | ||||
|     file_path = f"{UPLOAD_DIR}/{file.filename}" | ||||
| 
 | ||||
|     # Save file in chunks | ||||
|     with open(file_path, "wb+") as f: | ||||
|         for chunk in file.file: | ||||
|             f.write(chunk) | ||||
| 
 | ||||
|     def file_process_stream(): | ||||
|         total_size = os.path.getsize(file_path) | ||||
|         chunk_size = 1024 * 1024 | ||||
|         try: | ||||
|             with open(file_path, "rb") as f: | ||||
|                 total = 0 | ||||
|                 done = False | ||||
| 
 | ||||
|                 while not done: | ||||
|                     chunk = f.read(chunk_size) | ||||
|                     if not chunk: | ||||
|                         done = True | ||||
|                         continue | ||||
| 
 | ||||
|                     total += len(chunk) | ||||
|                     progress = round((total / total_size) * 100, 2) | ||||
| 
 | ||||
|                     res = { | ||||
|                         "progress": progress, | ||||
|                         "total": total_size, | ||||
|                         "completed": total, | ||||
|                     } | ||||
|                     yield f"data: {json.dumps(res)}\n\n" | ||||
| 
 | ||||
|                 if done: | ||||
|                     f.seek(0) | ||||
|                     hashed = calculate_sha256(f) | ||||
|                     f.seek(0) | ||||
| 
 | ||||
|                     url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}" | ||||
|                     response = requests.post(url, data=f) | ||||
| 
 | ||||
|                     if response.ok: | ||||
|                         res = { | ||||
|                             "done": done, | ||||
|                             "blob": f"sha256:{hashed}", | ||||
|                             "name": file.filename, | ||||
|                         } | ||||
|                         os.remove(file_path) | ||||
|                         yield f"data: {json.dumps(res)}\n\n" | ||||
|                     else: | ||||
|                         raise Exception( | ||||
|                             "Ollama: Could not create blob, Please try again." | ||||
|                         ) | ||||
| 
 | ||||
|         except Exception as e: | ||||
|             res = {"error": str(e)} | ||||
|             yield f"data: {json.dumps(res)}\n\n" | ||||
| 
 | ||||
|     return StreamingResponse(file_process_stream(), media_type="text/event-stream") | ||||
| 
 | ||||
| 
 | ||||
| @router.get("/gravatar") | ||||
| async def get_gravatar( | ||||
|     email: str, | ||||
|  |  | |||
|  | @ -1,4 +1,6 @@ | |||
| import os | ||||
| import sys | ||||
| import logging | ||||
| import chromadb | ||||
| from chromadb import Settings | ||||
| from base64 import b64encode | ||||
|  | @ -21,7 +23,7 @@ try: | |||
| 
 | ||||
|     load_dotenv(find_dotenv("../.env")) | ||||
| except ImportError: | ||||
|     print("dotenv not installed, skipping...") | ||||
|     log.warning("dotenv not installed, skipping...") | ||||
| 
 | ||||
| WEBUI_NAME = "Aura" | ||||
| shutil.copyfile("../build/favicon.png", "./static/favicon.png") | ||||
|  | @ -100,6 +102,34 @@ for version in soup.find_all("h2"): | |||
| CHANGELOG = changelog_json | ||||
| 
 | ||||
| 
 | ||||
| #################################### | ||||
| # LOGGING | ||||
| #################################### | ||||
| log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"] | ||||
| 
 | ||||
| GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper() | ||||
| if GLOBAL_LOG_LEVEL in log_levels: | ||||
|     logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True) | ||||
| else: | ||||
|     GLOBAL_LOG_LEVEL = "INFO" | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}") | ||||
| 
 | ||||
| log_sources = ["AUDIO", "CONFIG", "DB", "IMAGES", "LITELLM", "MAIN", "MODELS", "OLLAMA", "OPENAI", "RAG"] | ||||
| 
 | ||||
| SRC_LOG_LEVELS = {} | ||||
| 
 | ||||
| for source in log_sources: | ||||
|     log_env_var = source + "_LOG_LEVEL" | ||||
|     SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper() | ||||
|     if SRC_LOG_LEVELS[source] not in log_levels: | ||||
|         SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL | ||||
|     log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}") | ||||
| 
 | ||||
| log.setLevel(SRC_LOG_LEVELS["CONFIG"]) | ||||
| 
 | ||||
| 
 | ||||
| #################################### | ||||
| # CUSTOM_NAME | ||||
| #################################### | ||||
|  | @ -125,7 +155,7 @@ if CUSTOM_NAME: | |||
| 
 | ||||
|             WEBUI_NAME = data["name"] | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         pass | ||||
| 
 | ||||
| 
 | ||||
|  | @ -194,9 +224,9 @@ def create_config_file(file_path): | |||
| LITELLM_CONFIG_PATH = f"{DATA_DIR}/litellm/config.yaml" | ||||
| 
 | ||||
| if not os.path.exists(LITELLM_CONFIG_PATH): | ||||
|     print("Config file doesn't exist. Creating...") | ||||
|     log.info("Config file doesn't exist. Creating...") | ||||
|     create_config_file(LITELLM_CONFIG_PATH) | ||||
|     print("Config file created successfully.") | ||||
|     log.info("Config file created successfully.") | ||||
| 
 | ||||
| 
 | ||||
| #################################### | ||||
|  | @ -290,13 +320,19 @@ DEFAULT_PROMPT_SUGGESTIONS = ( | |||
| 
 | ||||
| 
 | ||||
| DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending") | ||||
| USER_PERMISSIONS = {"chat": {"deletion": True}} | ||||
| 
 | ||||
| USER_PERMISSIONS_CHAT_DELETION = ( | ||||
|     os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true" | ||||
| ) | ||||
| 
 | ||||
| USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}} | ||||
| 
 | ||||
| 
 | ||||
| MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", False) | ||||
| MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", "False").lower() == "true" | ||||
| MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") | ||||
| MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")] | ||||
| 
 | ||||
| WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "") | ||||
| 
 | ||||
| #################################### | ||||
| # WEBUI_VERSION | ||||
|  | @ -370,3 +406,4 @@ WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models" | |||
| #################################### | ||||
| 
 | ||||
| AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "") | ||||
| COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "") | ||||
|  |  | |||
|  | @ -5,6 +5,13 @@ class MESSAGES(str, Enum): | |||
|     DEFAULT = lambda msg="": f"{msg if msg else ''}" | ||||
| 
 | ||||
| 
 | ||||
| class WEBHOOK_MESSAGES(str, Enum): | ||||
|     DEFAULT = lambda msg="": f"{msg if msg else ''}" | ||||
|     USER_SIGNUP = lambda username="": ( | ||||
|         f"New user signed up: {username}" if username else "New user signed up" | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| class ERROR_MESSAGES(str, Enum): | ||||
|     def __str__(self) -> str: | ||||
|         return super().__str__() | ||||
|  | @ -46,7 +53,7 @@ class ERROR_MESSAGES(str, Enum): | |||
| 
 | ||||
|     PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance." | ||||
|     INCORRECT_FORMAT = ( | ||||
|         lambda err="": f"Invalid format. Please use the correct format{err if err else ''}" | ||||
|         lambda err="": f"Invalid format. Please use the correct format{err}" | ||||
|     ) | ||||
|     RATE_LIMIT_EXCEEDED = "API rate limit exceeded" | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,5 +1,5 @@ | |||
| { | ||||
|     "version": "0.0.1", | ||||
|     "version": 0, | ||||
|     "ui": { | ||||
|         "prompt_suggestions": [ | ||||
|             { | ||||
|  |  | |||
|  | @ -4,6 +4,7 @@ import markdown | |||
| import time | ||||
| import os | ||||
| import sys | ||||
| import logging | ||||
| import requests | ||||
| 
 | ||||
| from fastapi import FastAPI, Request, Depends, status | ||||
|  | @ -38,9 +39,15 @@ from config import ( | |||
|     FRONTEND_BUILD_DIR, | ||||
|     MODEL_FILTER_ENABLED, | ||||
|     MODEL_FILTER_LIST, | ||||
|     GLOBAL_LOG_LEVEL, | ||||
|     SRC_LOG_LEVELS, | ||||
|     WEBHOOK_URL, | ||||
| ) | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["MAIN"]) | ||||
| 
 | ||||
| class SPAStaticFiles(StaticFiles): | ||||
|     async def get_response(self, path: str, scope): | ||||
|  | @ -58,6 +65,9 @@ app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) | |||
| app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED | ||||
| app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST | ||||
| 
 | ||||
| app.state.WEBHOOK_URL = WEBHOOK_URL | ||||
| 
 | ||||
| 
 | ||||
| origins = ["*"] | ||||
| 
 | ||||
| 
 | ||||
|  | @ -66,7 +76,7 @@ class RAGMiddleware(BaseHTTPMiddleware): | |||
|         if request.method == "POST" and ( | ||||
|             "/api/chat" in request.url.path or "/chat/completions" in request.url.path | ||||
|         ): | ||||
|             print(request.url.path) | ||||
|             log.debug(f"request.url.path: {request.url.path}") | ||||
| 
 | ||||
|             # Read the original request body | ||||
|             body = await request.body() | ||||
|  | @ -89,7 +99,7 @@ class RAGMiddleware(BaseHTTPMiddleware): | |||
|                 ) | ||||
|                 del data["docs"] | ||||
| 
 | ||||
|                 print(data["messages"]) | ||||
|                 log.debug(f"data['messages']: {data['messages']}") | ||||
| 
 | ||||
|             modified_body_bytes = json.dumps(data).encode("utf-8") | ||||
| 
 | ||||
|  | @ -178,7 +188,7 @@ class ModelFilterConfigForm(BaseModel): | |||
| 
 | ||||
| 
 | ||||
| @app.post("/api/config/model/filter") | ||||
| async def get_model_filter_config( | ||||
| async def update_model_filter_config( | ||||
|     form_data: ModelFilterConfigForm, user=Depends(get_admin_user) | ||||
| ): | ||||
| 
 | ||||
|  | @ -197,6 +207,28 @@ async def get_model_filter_config( | |||
|     } | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/api/webhook") | ||||
| async def get_webhook_url(user=Depends(get_admin_user)): | ||||
|     return { | ||||
|         "url": app.state.WEBHOOK_URL, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| class UrlForm(BaseModel): | ||||
|     url: str | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/api/webhook") | ||||
| async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): | ||||
|     app.state.WEBHOOK_URL = form_data.url | ||||
| 
 | ||||
|     webui_app.state.WEBHOOK_URL = app.state.WEBHOOK_URL | ||||
| 
 | ||||
|     return { | ||||
|         "url": app.state.WEBHOOK_URL, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/api/version") | ||||
| async def get_app_config(): | ||||
| 
 | ||||
|  |  | |||
|  | @ -45,3 +45,4 @@ PyJWT | |||
| pyjwt[crypto] | ||||
| 
 | ||||
| black | ||||
| langfuse | ||||
|  |  | |||
							
								
								
									
										20
									
								
								backend/utils/webhook.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								backend/utils/webhook.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,20 @@ | |||
| import requests | ||||
| 
 | ||||
| 
 | ||||
| def post_webhook(url: str, message: str, event_data: dict) -> bool: | ||||
|     try: | ||||
|         payload = {} | ||||
| 
 | ||||
|         if "https://hooks.slack.com" in url: | ||||
|             payload["text"] = message | ||||
|         elif "https://discord.com/api/webhooks" in url: | ||||
|             payload["content"] = message | ||||
|         else: | ||||
|             payload = {**event_data} | ||||
| 
 | ||||
|         r = requests.post(url, json=payload) | ||||
|         r.raise_for_status() | ||||
|         return True | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         return False | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue