forked from open-webui/open-webui
		
	main #3
					 4 changed files with 315 additions and 8 deletions
				
			
		|  | @ -18,6 +18,8 @@ from utils.utils import ( | |||
|     get_current_user, | ||||
|     get_admin_user, | ||||
| ) | ||||
| 
 | ||||
| from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image | ||||
| from utils.misc import calculate_sha256 | ||||
| from typing import Optional | ||||
| from pydantic import BaseModel | ||||
|  | @ -105,7 +107,12 @@ async def update_engine_url( | |||
|         app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL | ||||
|     else: | ||||
|         url = form_data.COMFYUI_BASE_URL.strip("/") | ||||
|         app.state.COMFYUI_BASE_URL = url | ||||
| 
 | ||||
|         try: | ||||
|             r = requests.head(url) | ||||
|             app.state.COMFYUI_BASE_URL = url | ||||
|         except Exception as e: | ||||
|             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) | ||||
| 
 | ||||
|     return { | ||||
|         "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, | ||||
|  | @ -232,6 +239,8 @@ async def get_default_model(user=Depends(get_admin_user)): | |||
|     try: | ||||
|         if app.state.ENGINE == "openai": | ||||
|             return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} | ||||
|         elif app.state.ENGINE == "comfyui": | ||||
|             return {"model": app.state.MODEL if app.state.MODEL else ""} | ||||
|         else: | ||||
|             r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") | ||||
|             options = r.json() | ||||
|  | @ -246,10 +255,12 @@ class UpdateModelForm(BaseModel): | |||
| 
 | ||||
| 
 | ||||
| def set_model_handler(model: str): | ||||
| 
 | ||||
|     if app.state.ENGINE == "openai": | ||||
|         app.state.MODEL = model | ||||
|         return app.state.MODEL | ||||
|     if app.state.ENGINE == "comfyui": | ||||
|         app.state.MODEL = model | ||||
|         return app.state.MODEL | ||||
|     else: | ||||
|         r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") | ||||
|         options = r.json() | ||||
|  | @ -297,12 +308,31 @@ def save_b64_image(b64_str): | |||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| def save_url_image(url): | ||||
|     image_id = str(uuid.uuid4()) | ||||
|     file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png") | ||||
| 
 | ||||
|     try: | ||||
|         r = requests.get(url) | ||||
|         r.raise_for_status() | ||||
| 
 | ||||
|         with open(file_path, "wb") as image_file: | ||||
|             image_file.write(r.content) | ||||
| 
 | ||||
|         return image_id | ||||
|     except Exception as e: | ||||
|         print(f"Error saving image: {e}") | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/generations") | ||||
| def generate_image( | ||||
|     form_data: GenerateImageForm, | ||||
|     user=Depends(get_current_user), | ||||
| ): | ||||
| 
 | ||||
|     width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) | ||||
| 
 | ||||
|     r = None | ||||
|     try: | ||||
|         if app.state.ENGINE == "openai": | ||||
|  | @ -340,12 +370,47 @@ def generate_image( | |||
| 
 | ||||
|             return images | ||||
| 
 | ||||
|         elif app.state.ENGINE == "comfyui": | ||||
| 
 | ||||
|             data = { | ||||
|                 "prompt": form_data.prompt, | ||||
|                 "width": width, | ||||
|                 "height": height, | ||||
|                 "n": form_data.n, | ||||
|             } | ||||
| 
 | ||||
|             if app.state.IMAGE_STEPS != None: | ||||
|                 data["steps"] = app.state.IMAGE_STEPS | ||||
| 
 | ||||
|             if form_data.negative_prompt != None: | ||||
|                 data["negative_prompt"] = form_data.negative_prompt | ||||
| 
 | ||||
|             data = ImageGenerationPayload(**data) | ||||
| 
 | ||||
|             res = comfyui_generate_image( | ||||
|                 app.state.MODEL, | ||||
|                 data, | ||||
|                 user.id, | ||||
|                 app.state.COMFYUI_BASE_URL, | ||||
|             ) | ||||
|             print(res) | ||||
| 
 | ||||
|             images = [] | ||||
| 
 | ||||
|             for image in res["data"]: | ||||
|                 image_id = save_url_image(image["url"]) | ||||
|                 images.append({"url": f"/cache/image/generations/{image_id}.png"}) | ||||
|                 file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") | ||||
| 
 | ||||
|                 with open(file_body_path, "w") as f: | ||||
|                     json.dump(data.model_dump(exclude_none=True), f) | ||||
| 
 | ||||
|             print(images) | ||||
|             return images | ||||
|         else: | ||||
|             if form_data.model: | ||||
|                 set_model_handler(form_data.model) | ||||
| 
 | ||||
|             width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) | ||||
| 
 | ||||
|             data = { | ||||
|                 "prompt": form_data.prompt, | ||||
|                 "batch_size": form_data.n, | ||||
|  |  | |||
							
								
								
									
										228
									
								
								backend/apps/images/utils/comfyui.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										228
									
								
								backend/apps/images/utils/comfyui.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,228 @@ | |||
| import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) | ||||
| import uuid | ||||
| import json | ||||
| import urllib.request | ||||
| import urllib.parse | ||||
| import random | ||||
| 
 | ||||
| from pydantic import BaseModel | ||||
| 
 | ||||
| from typing import Optional | ||||
| 
 | ||||
| COMFYUI_DEFAULT_PROMPT = """ | ||||
| { | ||||
|   "3": { | ||||
|     "inputs": { | ||||
|       "seed": 0, | ||||
|       "steps": 20, | ||||
|       "cfg": 8, | ||||
|       "sampler_name": "euler", | ||||
|       "scheduler": "normal", | ||||
|       "denoise": 1, | ||||
|       "model": [ | ||||
|         "4", | ||||
|         0 | ||||
|       ], | ||||
|       "positive": [ | ||||
|         "6", | ||||
|         0 | ||||
|       ], | ||||
|       "negative": [ | ||||
|         "7", | ||||
|         0 | ||||
|       ], | ||||
|       "latent_image": [ | ||||
|         "5", | ||||
|         0 | ||||
|       ] | ||||
|     }, | ||||
|     "class_type": "KSampler", | ||||
|     "_meta": { | ||||
|       "title": "KSampler" | ||||
|     } | ||||
|   }, | ||||
|   "4": { | ||||
|     "inputs": { | ||||
|       "ckpt_name": "model.safetensors" | ||||
|     }, | ||||
|     "class_type": "CheckpointLoaderSimple", | ||||
|     "_meta": { | ||||
|       "title": "Load Checkpoint" | ||||
|     } | ||||
|   }, | ||||
|   "5": { | ||||
|     "inputs": { | ||||
|       "width": 512, | ||||
|       "height": 512, | ||||
|       "batch_size": 1 | ||||
|     }, | ||||
|     "class_type": "EmptyLatentImage", | ||||
|     "_meta": { | ||||
|       "title": "Empty Latent Image" | ||||
|     } | ||||
|   }, | ||||
|   "6": { | ||||
|     "inputs": { | ||||
|       "text": "Prompt", | ||||
|       "clip": [ | ||||
|         "4", | ||||
|         1 | ||||
|       ] | ||||
|     }, | ||||
|     "class_type": "CLIPTextEncode", | ||||
|     "_meta": { | ||||
|       "title": "CLIP Text Encode (Prompt)" | ||||
|     } | ||||
|   }, | ||||
|   "7": { | ||||
|     "inputs": { | ||||
|       "text": "Negative Prompt", | ||||
|       "clip": [ | ||||
|         "4", | ||||
|         1 | ||||
|       ] | ||||
|     }, | ||||
|     "class_type": "CLIPTextEncode", | ||||
|     "_meta": { | ||||
|       "title": "CLIP Text Encode (Prompt)" | ||||
|     } | ||||
|   }, | ||||
|   "8": { | ||||
|     "inputs": { | ||||
|       "samples": [ | ||||
|         "3", | ||||
|         0 | ||||
|       ], | ||||
|       "vae": [ | ||||
|         "4", | ||||
|         2 | ||||
|       ] | ||||
|     }, | ||||
|     "class_type": "VAEDecode", | ||||
|     "_meta": { | ||||
|       "title": "VAE Decode" | ||||
|     } | ||||
|   }, | ||||
|   "9": { | ||||
|     "inputs": { | ||||
|       "filename_prefix": "ComfyUI", | ||||
|       "images": [ | ||||
|         "8", | ||||
|         0 | ||||
|       ] | ||||
|     }, | ||||
|     "class_type": "SaveImage", | ||||
|     "_meta": { | ||||
|       "title": "Save Image" | ||||
|     } | ||||
|   } | ||||
| } | ||||
| """ | ||||
| 
 | ||||
| 
 | ||||
| def queue_prompt(prompt, client_id, base_url): | ||||
|     print("queue_prompt") | ||||
|     p = {"prompt": prompt, "client_id": client_id} | ||||
|     data = json.dumps(p).encode("utf-8") | ||||
|     req = urllib.request.Request(f"{base_url}/prompt", data=data) | ||||
|     return json.loads(urllib.request.urlopen(req).read()) | ||||
| 
 | ||||
| 
 | ||||
| def get_image(filename, subfolder, folder_type, base_url): | ||||
|     print("get_image") | ||||
|     data = {"filename": filename, "subfolder": subfolder, "type": folder_type} | ||||
|     url_values = urllib.parse.urlencode(data) | ||||
|     with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response: | ||||
|         return response.read() | ||||
| 
 | ||||
| 
 | ||||
| def get_image_url(filename, subfolder, folder_type, base_url): | ||||
|     print("get_image") | ||||
|     data = {"filename": filename, "subfolder": subfolder, "type": folder_type} | ||||
|     url_values = urllib.parse.urlencode(data) | ||||
|     return f"{base_url}/view?{url_values}" | ||||
| 
 | ||||
| 
 | ||||
| def get_history(prompt_id, base_url): | ||||
|     print("get_history") | ||||
|     with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response: | ||||
|         return json.loads(response.read()) | ||||
| 
 | ||||
| 
 | ||||
| def get_images(ws, prompt, client_id, base_url): | ||||
|     prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"] | ||||
|     output_images = [] | ||||
|     while True: | ||||
|         out = ws.recv() | ||||
|         if isinstance(out, str): | ||||
|             message = json.loads(out) | ||||
|             if message["type"] == "executing": | ||||
|                 data = message["data"] | ||||
|                 if data["node"] is None and data["prompt_id"] == prompt_id: | ||||
|                     break  # Execution is done | ||||
|         else: | ||||
|             continue  # previews are binary data | ||||
| 
 | ||||
|     history = get_history(prompt_id, base_url)[prompt_id] | ||||
|     for o in history["outputs"]: | ||||
|         for node_id in history["outputs"]: | ||||
|             node_output = history["outputs"][node_id] | ||||
|             if "images" in node_output: | ||||
|                 for image in node_output["images"]: | ||||
|                     url = get_image_url( | ||||
|                         image["filename"], image["subfolder"], image["type"], base_url | ||||
|                     ) | ||||
|                     output_images.append({"url": url}) | ||||
|     return {"data": output_images} | ||||
| 
 | ||||
| 
 | ||||
| class ImageGenerationPayload(BaseModel): | ||||
|     prompt: str | ||||
|     negative_prompt: Optional[str] = "" | ||||
|     steps: Optional[int] = None | ||||
|     seed: Optional[int] = None | ||||
|     width: int | ||||
|     height: int | ||||
|     n: int = 1 | ||||
| 
 | ||||
| 
 | ||||
| def comfyui_generate_image( | ||||
|     model: str, payload: ImageGenerationPayload, client_id, base_url | ||||
| ): | ||||
|     host = base_url.replace("http://", "").replace("https://", "") | ||||
| 
 | ||||
|     comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT) | ||||
| 
 | ||||
|     comfyui_prompt["4"]["inputs"]["ckpt_name"] = model | ||||
|     comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n | ||||
|     comfyui_prompt["5"]["inputs"]["width"] = payload.width | ||||
|     comfyui_prompt["5"]["inputs"]["height"] = payload.height | ||||
| 
 | ||||
|     # set the text prompt for our positive CLIPTextEncode | ||||
|     comfyui_prompt["6"]["inputs"]["text"] = payload.prompt | ||||
|     comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt | ||||
| 
 | ||||
|     if payload.steps: | ||||
|         comfyui_prompt["3"]["inputs"]["steps"] = payload.steps | ||||
| 
 | ||||
|     comfyui_prompt["3"]["inputs"]["seed"] = ( | ||||
|         payload.seed if payload.seed else random.randint(0, 18446744073709551614) | ||||
|     ) | ||||
| 
 | ||||
|     try: | ||||
|         ws = websocket.WebSocket() | ||||
|         ws.connect(f"ws://{host}/ws?clientId={client_id}") | ||||
|         print("WebSocket connection established.") | ||||
|     except Exception as e: | ||||
|         print(f"Failed to connect to WebSocket server: {e}") | ||||
|         return None | ||||
| 
 | ||||
|     try: | ||||
|         images = get_images(ws, comfyui_prompt, client_id, base_url) | ||||
|     except Exception as e: | ||||
|         print(f"Error while receiving images: {e}") | ||||
|         images = None | ||||
| 
 | ||||
|     ws.close() | ||||
| 
 | ||||
|     return images | ||||
|  | @ -323,6 +323,7 @@ | |||
| 							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none" | ||||
| 							bind:value={selectedModel} | ||||
| 							placeholder={$i18n.t('Select a model')} | ||||
| 							required | ||||
| 						> | ||||
| 							{#if !selectedModel} | ||||
| 								<option value="" disabled selected>{$i18n.t('Select a model')}</option> | ||||
|  |  | |||
|  | @ -2,6 +2,22 @@ | |||
| 	export let show = false; | ||||
| 	export let src = ''; | ||||
| 	export let alt = ''; | ||||
| 
 | ||||
| 	const downloadImage = (url, filename) => { | ||||
| 		fetch(url) | ||||
| 			.then((response) => response.blob()) | ||||
| 			.then((blob) => { | ||||
| 				const objectUrl = window.URL.createObjectURL(blob); | ||||
| 				const link = document.createElement('a'); | ||||
| 				link.href = objectUrl; | ||||
| 				link.download = filename; | ||||
| 				document.body.appendChild(link); | ||||
| 				link.click(); | ||||
| 				document.body.removeChild(link); | ||||
| 				window.URL.revokeObjectURL(objectUrl); | ||||
| 			}) | ||||
| 			.catch((error) => console.error('Error downloading image:', error)); | ||||
| 	}; | ||||
| </script> | ||||
| 
 | ||||
| {#if show} | ||||
|  | @ -35,10 +51,7 @@ | |||
| 				<button | ||||
| 					class=" p-5" | ||||
| 					on:click={() => { | ||||
| 						const a = document.createElement('a'); | ||||
| 						a.href = src; | ||||
| 						a.download = 'Image.png'; | ||||
| 						a.click(); | ||||
| 						downloadImage(src, 'Image.png'); | ||||
| 					}} | ||||
| 				> | ||||
| 					<svg | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue