forked from open-webui/open-webui
		
	main #3
					 4 changed files with 315 additions and 8 deletions
				
			
		|  | @ -18,6 +18,8 @@ from utils.utils import ( | ||||||
|     get_current_user, |     get_current_user, | ||||||
|     get_admin_user, |     get_admin_user, | ||||||
| ) | ) | ||||||
|  | 
 | ||||||
|  | from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image | ||||||
| from utils.misc import calculate_sha256 | from utils.misc import calculate_sha256 | ||||||
| from typing import Optional | from typing import Optional | ||||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||||
|  | @ -105,7 +107,12 @@ async def update_engine_url( | ||||||
|         app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL |         app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL | ||||||
|     else: |     else: | ||||||
|         url = form_data.COMFYUI_BASE_URL.strip("/") |         url = form_data.COMFYUI_BASE_URL.strip("/") | ||||||
|         app.state.COMFYUI_BASE_URL = url | 
 | ||||||
|  |         try: | ||||||
|  |             r = requests.head(url) | ||||||
|  |             app.state.COMFYUI_BASE_URL = url | ||||||
|  |         except Exception as e: | ||||||
|  |             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) | ||||||
| 
 | 
 | ||||||
|     return { |     return { | ||||||
|         "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, |         "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, | ||||||
|  | @ -232,6 +239,8 @@ async def get_default_model(user=Depends(get_admin_user)): | ||||||
|     try: |     try: | ||||||
|         if app.state.ENGINE == "openai": |         if app.state.ENGINE == "openai": | ||||||
|             return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} |             return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} | ||||||
|  |         elif app.state.ENGINE == "comfyui": | ||||||
|  |             return {"model": app.state.MODEL if app.state.MODEL else ""} | ||||||
|         else: |         else: | ||||||
|             r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") |             r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") | ||||||
|             options = r.json() |             options = r.json() | ||||||
|  | @ -246,10 +255,12 @@ class UpdateModelForm(BaseModel): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def set_model_handler(model: str): | def set_model_handler(model: str): | ||||||
| 
 |  | ||||||
|     if app.state.ENGINE == "openai": |     if app.state.ENGINE == "openai": | ||||||
|         app.state.MODEL = model |         app.state.MODEL = model | ||||||
|         return app.state.MODEL |         return app.state.MODEL | ||||||
|  |     if app.state.ENGINE == "comfyui": | ||||||
|  |         app.state.MODEL = model | ||||||
|  |         return app.state.MODEL | ||||||
|     else: |     else: | ||||||
|         r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") |         r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") | ||||||
|         options = r.json() |         options = r.json() | ||||||
|  | @ -297,12 +308,31 @@ def save_b64_image(b64_str): | ||||||
|         return None |         return None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def save_url_image(url): | ||||||
|  |     image_id = str(uuid.uuid4()) | ||||||
|  |     file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png") | ||||||
|  | 
 | ||||||
|  |     try: | ||||||
|  |         r = requests.get(url) | ||||||
|  |         r.raise_for_status() | ||||||
|  | 
 | ||||||
|  |         with open(file_path, "wb") as image_file: | ||||||
|  |             image_file.write(r.content) | ||||||
|  | 
 | ||||||
|  |         return image_id | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"Error saving image: {e}") | ||||||
|  |         return None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @app.post("/generations") | @app.post("/generations") | ||||||
| def generate_image( | def generate_image( | ||||||
|     form_data: GenerateImageForm, |     form_data: GenerateImageForm, | ||||||
|     user=Depends(get_current_user), |     user=Depends(get_current_user), | ||||||
| ): | ): | ||||||
| 
 | 
 | ||||||
|  |     width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) | ||||||
|  | 
 | ||||||
|     r = None |     r = None | ||||||
|     try: |     try: | ||||||
|         if app.state.ENGINE == "openai": |         if app.state.ENGINE == "openai": | ||||||
|  | @ -340,12 +370,47 @@ def generate_image( | ||||||
| 
 | 
 | ||||||
|             return images |             return images | ||||||
| 
 | 
 | ||||||
|  |         elif app.state.ENGINE == "comfyui": | ||||||
|  | 
 | ||||||
|  |             data = { | ||||||
|  |                 "prompt": form_data.prompt, | ||||||
|  |                 "width": width, | ||||||
|  |                 "height": height, | ||||||
|  |                 "n": form_data.n, | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             if app.state.IMAGE_STEPS != None: | ||||||
|  |                 data["steps"] = app.state.IMAGE_STEPS | ||||||
|  | 
 | ||||||
|  |             if form_data.negative_prompt != None: | ||||||
|  |                 data["negative_prompt"] = form_data.negative_prompt | ||||||
|  | 
 | ||||||
|  |             data = ImageGenerationPayload(**data) | ||||||
|  | 
 | ||||||
|  |             res = comfyui_generate_image( | ||||||
|  |                 app.state.MODEL, | ||||||
|  |                 data, | ||||||
|  |                 user.id, | ||||||
|  |                 app.state.COMFYUI_BASE_URL, | ||||||
|  |             ) | ||||||
|  |             print(res) | ||||||
|  | 
 | ||||||
|  |             images = [] | ||||||
|  | 
 | ||||||
|  |             for image in res["data"]: | ||||||
|  |                 image_id = save_url_image(image["url"]) | ||||||
|  |                 images.append({"url": f"/cache/image/generations/{image_id}.png"}) | ||||||
|  |                 file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") | ||||||
|  | 
 | ||||||
|  |                 with open(file_body_path, "w") as f: | ||||||
|  |                     json.dump(data.model_dump(exclude_none=True), f) | ||||||
|  | 
 | ||||||
|  |             print(images) | ||||||
|  |             return images | ||||||
|         else: |         else: | ||||||
|             if form_data.model: |             if form_data.model: | ||||||
|                 set_model_handler(form_data.model) |                 set_model_handler(form_data.model) | ||||||
| 
 | 
 | ||||||
|             width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) |  | ||||||
| 
 |  | ||||||
|             data = { |             data = { | ||||||
|                 "prompt": form_data.prompt, |                 "prompt": form_data.prompt, | ||||||
|                 "batch_size": form_data.n, |                 "batch_size": form_data.n, | ||||||
|  |  | ||||||
							
								
								
									
										228
									
								
								backend/apps/images/utils/comfyui.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										228
									
								
								backend/apps/images/utils/comfyui.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,228 @@ | ||||||
|  | import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) | ||||||
|  | import uuid | ||||||
|  | import json | ||||||
|  | import urllib.request | ||||||
|  | import urllib.parse | ||||||
|  | import random | ||||||
|  | 
 | ||||||
|  | from pydantic import BaseModel | ||||||
|  | 
 | ||||||
|  | from typing import Optional | ||||||
|  | 
 | ||||||
|  | COMFYUI_DEFAULT_PROMPT = """ | ||||||
|  | { | ||||||
|  |   "3": { | ||||||
|  |     "inputs": { | ||||||
|  |       "seed": 0, | ||||||
|  |       "steps": 20, | ||||||
|  |       "cfg": 8, | ||||||
|  |       "sampler_name": "euler", | ||||||
|  |       "scheduler": "normal", | ||||||
|  |       "denoise": 1, | ||||||
|  |       "model": [ | ||||||
|  |         "4", | ||||||
|  |         0 | ||||||
|  |       ], | ||||||
|  |       "positive": [ | ||||||
|  |         "6", | ||||||
|  |         0 | ||||||
|  |       ], | ||||||
|  |       "negative": [ | ||||||
|  |         "7", | ||||||
|  |         0 | ||||||
|  |       ], | ||||||
|  |       "latent_image": [ | ||||||
|  |         "5", | ||||||
|  |         0 | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     "class_type": "KSampler", | ||||||
|  |     "_meta": { | ||||||
|  |       "title": "KSampler" | ||||||
|  |     } | ||||||
|  |   }, | ||||||
|  |   "4": { | ||||||
|  |     "inputs": { | ||||||
|  |       "ckpt_name": "model.safetensors" | ||||||
|  |     }, | ||||||
|  |     "class_type": "CheckpointLoaderSimple", | ||||||
|  |     "_meta": { | ||||||
|  |       "title": "Load Checkpoint" | ||||||
|  |     } | ||||||
|  |   }, | ||||||
|  |   "5": { | ||||||
|  |     "inputs": { | ||||||
|  |       "width": 512, | ||||||
|  |       "height": 512, | ||||||
|  |       "batch_size": 1 | ||||||
|  |     }, | ||||||
|  |     "class_type": "EmptyLatentImage", | ||||||
|  |     "_meta": { | ||||||
|  |       "title": "Empty Latent Image" | ||||||
|  |     } | ||||||
|  |   }, | ||||||
|  |   "6": { | ||||||
|  |     "inputs": { | ||||||
|  |       "text": "Prompt", | ||||||
|  |       "clip": [ | ||||||
|  |         "4", | ||||||
|  |         1 | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     "class_type": "CLIPTextEncode", | ||||||
|  |     "_meta": { | ||||||
|  |       "title": "CLIP Text Encode (Prompt)" | ||||||
|  |     } | ||||||
|  |   }, | ||||||
|  |   "7": { | ||||||
|  |     "inputs": { | ||||||
|  |       "text": "Negative Prompt", | ||||||
|  |       "clip": [ | ||||||
|  |         "4", | ||||||
|  |         1 | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     "class_type": "CLIPTextEncode", | ||||||
|  |     "_meta": { | ||||||
|  |       "title": "CLIP Text Encode (Prompt)" | ||||||
|  |     } | ||||||
|  |   }, | ||||||
|  |   "8": { | ||||||
|  |     "inputs": { | ||||||
|  |       "samples": [ | ||||||
|  |         "3", | ||||||
|  |         0 | ||||||
|  |       ], | ||||||
|  |       "vae": [ | ||||||
|  |         "4", | ||||||
|  |         2 | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     "class_type": "VAEDecode", | ||||||
|  |     "_meta": { | ||||||
|  |       "title": "VAE Decode" | ||||||
|  |     } | ||||||
|  |   }, | ||||||
|  |   "9": { | ||||||
|  |     "inputs": { | ||||||
|  |       "filename_prefix": "ComfyUI", | ||||||
|  |       "images": [ | ||||||
|  |         "8", | ||||||
|  |         0 | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     "class_type": "SaveImage", | ||||||
|  |     "_meta": { | ||||||
|  |       "title": "Save Image" | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | """ | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def queue_prompt(prompt, client_id, base_url): | ||||||
|  |     print("queue_prompt") | ||||||
|  |     p = {"prompt": prompt, "client_id": client_id} | ||||||
|  |     data = json.dumps(p).encode("utf-8") | ||||||
|  |     req = urllib.request.Request(f"{base_url}/prompt", data=data) | ||||||
|  |     return json.loads(urllib.request.urlopen(req).read()) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_image(filename, subfolder, folder_type, base_url): | ||||||
|  |     print("get_image") | ||||||
|  |     data = {"filename": filename, "subfolder": subfolder, "type": folder_type} | ||||||
|  |     url_values = urllib.parse.urlencode(data) | ||||||
|  |     with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response: | ||||||
|  |         return response.read() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_image_url(filename, subfolder, folder_type, base_url): | ||||||
|  |     print("get_image") | ||||||
|  |     data = {"filename": filename, "subfolder": subfolder, "type": folder_type} | ||||||
|  |     url_values = urllib.parse.urlencode(data) | ||||||
|  |     return f"{base_url}/view?{url_values}" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_history(prompt_id, base_url): | ||||||
|  |     print("get_history") | ||||||
|  |     with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response: | ||||||
|  |         return json.loads(response.read()) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_images(ws, prompt, client_id, base_url): | ||||||
|  |     prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"] | ||||||
|  |     output_images = [] | ||||||
|  |     while True: | ||||||
|  |         out = ws.recv() | ||||||
|  |         if isinstance(out, str): | ||||||
|  |             message = json.loads(out) | ||||||
|  |             if message["type"] == "executing": | ||||||
|  |                 data = message["data"] | ||||||
|  |                 if data["node"] is None and data["prompt_id"] == prompt_id: | ||||||
|  |                     break  # Execution is done | ||||||
|  |         else: | ||||||
|  |             continue  # previews are binary data | ||||||
|  | 
 | ||||||
|  |     history = get_history(prompt_id, base_url)[prompt_id] | ||||||
|  |     for o in history["outputs"]: | ||||||
|  |         for node_id in history["outputs"]: | ||||||
|  |             node_output = history["outputs"][node_id] | ||||||
|  |             if "images" in node_output: | ||||||
|  |                 for image in node_output["images"]: | ||||||
|  |                     url = get_image_url( | ||||||
|  |                         image["filename"], image["subfolder"], image["type"], base_url | ||||||
|  |                     ) | ||||||
|  |                     output_images.append({"url": url}) | ||||||
|  |     return {"data": output_images} | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ImageGenerationPayload(BaseModel): | ||||||
|  |     prompt: str | ||||||
|  |     negative_prompt: Optional[str] = "" | ||||||
|  |     steps: Optional[int] = None | ||||||
|  |     seed: Optional[int] = None | ||||||
|  |     width: int | ||||||
|  |     height: int | ||||||
|  |     n: int = 1 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def comfyui_generate_image( | ||||||
|  |     model: str, payload: ImageGenerationPayload, client_id, base_url | ||||||
|  | ): | ||||||
|  |     host = base_url.replace("http://", "").replace("https://", "") | ||||||
|  | 
 | ||||||
|  |     comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT) | ||||||
|  | 
 | ||||||
|  |     comfyui_prompt["4"]["inputs"]["ckpt_name"] = model | ||||||
|  |     comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n | ||||||
|  |     comfyui_prompt["5"]["inputs"]["width"] = payload.width | ||||||
|  |     comfyui_prompt["5"]["inputs"]["height"] = payload.height | ||||||
|  | 
 | ||||||
|  |     # set the text prompt for our positive CLIPTextEncode | ||||||
|  |     comfyui_prompt["6"]["inputs"]["text"] = payload.prompt | ||||||
|  |     comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt | ||||||
|  | 
 | ||||||
|  |     if payload.steps: | ||||||
|  |         comfyui_prompt["3"]["inputs"]["steps"] = payload.steps | ||||||
|  | 
 | ||||||
|  |     comfyui_prompt["3"]["inputs"]["seed"] = ( | ||||||
|  |         payload.seed if payload.seed else random.randint(0, 18446744073709551614) | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     try: | ||||||
|  |         ws = websocket.WebSocket() | ||||||
|  |         ws.connect(f"ws://{host}/ws?clientId={client_id}") | ||||||
|  |         print("WebSocket connection established.") | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"Failed to connect to WebSocket server: {e}") | ||||||
|  |         return None | ||||||
|  | 
 | ||||||
|  |     try: | ||||||
|  |         images = get_images(ws, comfyui_prompt, client_id, base_url) | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"Error while receiving images: {e}") | ||||||
|  |         images = None | ||||||
|  | 
 | ||||||
|  |     ws.close() | ||||||
|  | 
 | ||||||
|  |     return images | ||||||
|  | @ -323,6 +323,7 @@ | ||||||
| 							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none" | 							class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none" | ||||||
| 							bind:value={selectedModel} | 							bind:value={selectedModel} | ||||||
| 							placeholder={$i18n.t('Select a model')} | 							placeholder={$i18n.t('Select a model')} | ||||||
|  | 							required | ||||||
| 						> | 						> | ||||||
| 							{#if !selectedModel} | 							{#if !selectedModel} | ||||||
| 								<option value="" disabled selected>{$i18n.t('Select a model')}</option> | 								<option value="" disabled selected>{$i18n.t('Select a model')}</option> | ||||||
|  |  | ||||||
|  | @ -2,6 +2,22 @@ | ||||||
| 	export let show = false; | 	export let show = false; | ||||||
| 	export let src = ''; | 	export let src = ''; | ||||||
| 	export let alt = ''; | 	export let alt = ''; | ||||||
|  | 
 | ||||||
|  | 	const downloadImage = (url, filename) => { | ||||||
|  | 		fetch(url) | ||||||
|  | 			.then((response) => response.blob()) | ||||||
|  | 			.then((blob) => { | ||||||
|  | 				const objectUrl = window.URL.createObjectURL(blob); | ||||||
|  | 				const link = document.createElement('a'); | ||||||
|  | 				link.href = objectUrl; | ||||||
|  | 				link.download = filename; | ||||||
|  | 				document.body.appendChild(link); | ||||||
|  | 				link.click(); | ||||||
|  | 				document.body.removeChild(link); | ||||||
|  | 				window.URL.revokeObjectURL(objectUrl); | ||||||
|  | 			}) | ||||||
|  | 			.catch((error) => console.error('Error downloading image:', error)); | ||||||
|  | 	}; | ||||||
| </script> | </script> | ||||||
| 
 | 
 | ||||||
| {#if show} | {#if show} | ||||||
|  | @ -35,10 +51,7 @@ | ||||||
| 				<button | 				<button | ||||||
| 					class=" p-5" | 					class=" p-5" | ||||||
| 					on:click={() => { | 					on:click={() => { | ||||||
| 						const a = document.createElement('a'); | 						downloadImage(src, 'Image.png'); | ||||||
| 						a.href = src; |  | ||||||
| 						a.download = 'Image.png'; |  | ||||||
| 						a.click(); |  | ||||||
| 					}} | 					}} | ||||||
| 				> | 				> | ||||||
| 					<svg | 					<svg | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue