diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index e14b0f6a..b829b049 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -18,6 +18,8 @@ from utils.utils import ( get_current_user, get_admin_user, ) + +from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image from utils.misc import calculate_sha256 from typing import Optional from pydantic import BaseModel @@ -26,7 +28,7 @@ import uuid import base64 import json -from config import CACHE_DIR, AUTOMATIC1111_BASE_URL +from config import CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/") @@ -49,6 +51,8 @@ app.state.MODEL = "" app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL +app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL + app.state.IMAGE_SIZE = "512x512" app.state.IMAGE_STEPS = 50 @@ -71,32 +75,48 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} -class UrlUpdateForm(BaseModel): - url: str +class EngineUrlUpdateForm(BaseModel): + AUTOMATIC1111_BASE_URL: Optional[str] = None + COMFYUI_BASE_URL: Optional[str] = None @app.get("/url") -async def get_automatic1111_url(user=Depends(get_admin_user)): - return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL} +async def get_engine_url(user=Depends(get_admin_user)): + return { + "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, + "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, + } @app.post("/url/update") -async def update_automatic1111_url( - form_data: UrlUpdateForm, user=Depends(get_admin_user) +async def update_engine_url( + form_data: EngineUrlUpdateForm, user=Depends(get_admin_user) ): - if form_data.url == "": + if form_data.AUTOMATIC1111_BASE_URL == None: app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL else: - url = form_data.url.strip("/") + url = form_data.AUTOMATIC1111_BASE_URL.strip("/") try: r = requests.head(url) app.state.AUTOMATIC1111_BASE_URL = url except Exception as e: raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) + if form_data.COMFYUI_BASE_URL == None: + app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL + else: + url = form_data.COMFYUI_BASE_URL.strip("/") + + try: + r = requests.head(url) + app.state.COMFYUI_BASE_URL = url + except Exception as e: + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) + return { "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, + "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, "status": True, } @@ -186,6 +206,18 @@ def get_models(user=Depends(get_current_user)): {"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-3", "name": "DALL·E 3"}, ] + elif app.state.ENGINE == "comfyui": + + r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info") + info = r.json() + + return list( + map( + lambda model: {"id": model, "name": model}, + info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0], + ) + ) + else: r = requests.get( url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" @@ -207,6 +239,8 @@ async def get_default_model(user=Depends(get_admin_user)): try: if app.state.ENGINE == "openai": return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} + elif app.state.ENGINE == "comfyui": + return {"model": app.state.MODEL if app.state.MODEL else ""} else: r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") options = r.json() @@ -221,10 +255,12 @@ class UpdateModelForm(BaseModel): def set_model_handler(model: str): - if app.state.ENGINE == "openai": app.state.MODEL = model return app.state.MODEL + if app.state.ENGINE == "comfyui": + app.state.MODEL = model + return app.state.MODEL else: r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") options = r.json() @@ -272,12 +308,31 @@ def save_b64_image(b64_str): return None +def save_url_image(url): + image_id = str(uuid.uuid4()) + file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png") + + try: + r = requests.get(url) + r.raise_for_status() + + with open(file_path, "wb") as image_file: + image_file.write(r.content) + + return image_id + except Exception as e: + print(f"Error saving image: {e}") + return None + + @app.post("/generations") def generate_image( form_data: GenerateImageForm, user=Depends(get_current_user), ): + width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) + r = None try: if app.state.ENGINE == "openai": @@ -315,12 +370,47 @@ def generate_image( return images + elif app.state.ENGINE == "comfyui": + + data = { + "prompt": form_data.prompt, + "width": width, + "height": height, + "n": form_data.n, + } + + if app.state.IMAGE_STEPS != None: + data["steps"] = app.state.IMAGE_STEPS + + if form_data.negative_prompt != None: + data["negative_prompt"] = form_data.negative_prompt + + data = ImageGenerationPayload(**data) + + res = comfyui_generate_image( + app.state.MODEL, + data, + user.id, + app.state.COMFYUI_BASE_URL, + ) + print(res) + + images = [] + + for image in res["data"]: + image_id = save_url_image(image["url"]) + images.append({"url": f"/cache/image/generations/{image_id}.png"}) + file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") + + with open(file_body_path, "w") as f: + json.dump(data.model_dump(exclude_none=True), f) + + print(images) + return images else: if form_data.model: set_model_handler(form_data.model) - width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) - data = { "prompt": form_data.prompt, "batch_size": form_data.n, diff --git a/backend/apps/images/utils/comfyui.py b/backend/apps/images/utils/comfyui.py new file mode 100644 index 00000000..6a9fef35 --- /dev/null +++ b/backend/apps/images/utils/comfyui.py @@ -0,0 +1,228 @@ +import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +import uuid +import json +import urllib.request +import urllib.parse +import random + +from pydantic import BaseModel + +from typing import Optional + +COMFYUI_DEFAULT_PROMPT = """ +{ + "3": { + "inputs": { + "seed": 0, + "steps": 20, + "cfg": 8, + "sampler_name": "euler", + "scheduler": "normal", + "denoise": 1, + "model": [ + "4", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "4": { + "inputs": { + "ckpt_name": "model.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "5": { + "inputs": { + "width": 512, + "height": 512, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "6": { + "inputs": { + "text": "Prompt", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "7": { + "inputs": { + "text": "Negative Prompt", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "8": { + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "9": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + } +} +""" + + +def queue_prompt(prompt, client_id, base_url): + print("queue_prompt") + p = {"prompt": prompt, "client_id": client_id} + data = json.dumps(p).encode("utf-8") + req = urllib.request.Request(f"{base_url}/prompt", data=data) + return json.loads(urllib.request.urlopen(req).read()) + + +def get_image(filename, subfolder, folder_type, base_url): + print("get_image") + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response: + return response.read() + + +def get_image_url(filename, subfolder, folder_type, base_url): + print("get_image") + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + return f"{base_url}/view?{url_values}" + + +def get_history(prompt_id, base_url): + print("get_history") + with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response: + return json.loads(response.read()) + + +def get_images(ws, prompt, client_id, base_url): + prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"] + output_images = [] + while True: + out = ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message["type"] == "executing": + data = message["data"] + if data["node"] is None and data["prompt_id"] == prompt_id: + break # Execution is done + else: + continue # previews are binary data + + history = get_history(prompt_id, base_url)[prompt_id] + for o in history["outputs"]: + for node_id in history["outputs"]: + node_output = history["outputs"][node_id] + if "images" in node_output: + for image in node_output["images"]: + url = get_image_url( + image["filename"], image["subfolder"], image["type"], base_url + ) + output_images.append({"url": url}) + return {"data": output_images} + + +class ImageGenerationPayload(BaseModel): + prompt: str + negative_prompt: Optional[str] = "" + steps: Optional[int] = None + seed: Optional[int] = None + width: int + height: int + n: int = 1 + + +def comfyui_generate_image( + model: str, payload: ImageGenerationPayload, client_id, base_url +): + host = base_url.replace("http://", "").replace("https://", "") + + comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT) + + comfyui_prompt["4"]["inputs"]["ckpt_name"] = model + comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n + comfyui_prompt["5"]["inputs"]["width"] = payload.width + comfyui_prompt["5"]["inputs"]["height"] = payload.height + + # set the text prompt for our positive CLIPTextEncode + comfyui_prompt["6"]["inputs"]["text"] = payload.prompt + comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt + + if payload.steps: + comfyui_prompt["3"]["inputs"]["steps"] = payload.steps + + comfyui_prompt["3"]["inputs"]["seed"] = ( + payload.seed if payload.seed else random.randint(0, 18446744073709551614) + ) + + try: + ws = websocket.WebSocket() + ws.connect(f"ws://{host}/ws?clientId={client_id}") + print("WebSocket connection established.") + except Exception as e: + print(f"Failed to connect to WebSocket server: {e}") + return None + + try: + images = get_images(ws, comfyui_prompt, client_id, base_url) + except Exception as e: + print(f"Error while receiving images: {e}") + images = None + + ws.close() + + return images diff --git a/backend/config.py b/backend/config.py index 9236e8a8..67edd3f4 100644 --- a/backend/config.py +++ b/backend/config.py @@ -376,3 +376,4 @@ WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models" #################################### AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "") +COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "") diff --git a/src/lib/apis/images/index.ts b/src/lib/apis/images/index.ts index 1fb004a3..aadfafd1 100644 --- a/src/lib/apis/images/index.ts +++ b/src/lib/apis/images/index.ts @@ -139,7 +139,7 @@ export const updateOpenAIKey = async (token: string = '', key: string) => { return res.OPENAI_API_KEY; }; -export const getAUTOMATIC1111Url = async (token: string = '') => { +export const getImageGenerationEngineUrls = async (token: string = '') => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/url`, { @@ -168,10 +168,10 @@ export const getAUTOMATIC1111Url = async (token: string = '') => { throw error; } - return res.AUTOMATIC1111_BASE_URL; + return res; }; -export const updateAUTOMATIC1111Url = async (token: string = '', url: string) => { +export const updateImageGenerationEngineUrls = async (token: string = '', urls: object = {}) => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/url/update`, { @@ -182,7 +182,7 @@ export const updateAUTOMATIC1111Url = async (token: string = '', url: string) => ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - url: url + ...urls }) }) .then(async (res) => { @@ -203,7 +203,7 @@ export const updateAUTOMATIC1111Url = async (token: string = '', url: string) => throw error; } - return res.AUTOMATIC1111_BASE_URL; + return res; }; export const getImageSize = async (token: string = '') => { diff --git a/src/lib/components/chat/Settings/Images.svelte b/src/lib/components/chat/Settings/Images.svelte index 5ba046f1..7282c184 100644 --- a/src/lib/components/chat/Settings/Images.svelte +++ b/src/lib/components/chat/Settings/Images.svelte @@ -4,14 +4,14 @@ import { createEventDispatcher, onMount, getContext } from 'svelte'; import { config, user } from '$lib/stores'; import { - getAUTOMATIC1111Url, getImageGenerationModels, getDefaultImageGenerationModel, updateDefaultImageGenerationModel, getImageSize, getImageGenerationConfig, updateImageGenerationConfig, - updateAUTOMATIC1111Url, + getImageGenerationEngineUrls, + updateImageGenerationEngineUrls, updateImageSize, getImageSteps, updateImageSteps, @@ -31,6 +31,8 @@ let enableImageGeneration = false; let AUTOMATIC1111_BASE_URL = ''; + let COMFYUI_BASE_URL = ''; + let OPENAI_API_KEY = ''; let selectedModel = ''; @@ -49,24 +51,47 @@ }); }; - const updateAUTOMATIC1111UrlHandler = async () => { - const res = await updateAUTOMATIC1111Url(localStorage.token, AUTOMATIC1111_BASE_URL).catch( - (error) => { + const updateUrlHandler = async () => { + if (imageGenerationEngine === 'comfyui') { + const res = await updateImageGenerationEngineUrls(localStorage.token, { + COMFYUI_BASE_URL: COMFYUI_BASE_URL + }).catch((error) => { toast.error(error); + + console.log(error); return null; - } - ); + }); - if (res) { - AUTOMATIC1111_BASE_URL = res; + if (res) { + COMFYUI_BASE_URL = res.COMFYUI_BASE_URL; - await getModels(); + await getModels(); - if (models) { - toast.success($i18n.t('Server connection verified')); + if (models) { + toast.success($i18n.t('Server connection verified')); + } + } else { + ({ COMFYUI_BASE_URL } = await getImageGenerationEngineUrls(localStorage.token)); } } else { - AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token); + const res = await updateImageGenerationEngineUrls(localStorage.token, { + AUTOMATIC1111_BASE_URL: AUTOMATIC1111_BASE_URL + }).catch((error) => { + toast.error(error); + return null; + }); + + if (res) { + AUTOMATIC1111_BASE_URL = res.AUTOMATIC1111_BASE_URL; + + await getModels(); + + if (models) { + toast.success($i18n.t('Server connection verified')); + } + } else { + ({ AUTOMATIC1111_BASE_URL } = await getImageGenerationEngineUrls(localStorage.token)); + } } }; const updateImageGeneration = async () => { @@ -101,7 +126,11 @@ imageGenerationEngine = res.engine; enableImageGeneration = res.enabled; } - AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token); + const URLS = await getImageGenerationEngineUrls(localStorage.token); + + AUTOMATIC1111_BASE_URL = URLS.AUTOMATIC1111_BASE_URL; + COMFYUI_BASE_URL = URLS.COMFYUI_BASE_URL; + OPENAI_API_KEY = await getOpenAIKey(localStorage.token); imageSize = await getImageSize(localStorage.token); @@ -154,6 +183,7 @@ }} > + @@ -171,6 +201,9 @@ if (imageGenerationEngine === '' && AUTOMATIC1111_BASE_URL === '') { toast.error($i18n.t('AUTOMATIC1111 Base URL is required.')); enableImageGeneration = false; + } else if (imageGenerationEngine === 'comfyui' && COMFYUI_BASE_URL === '') { + toast.error($i18n.t('ComfyUI Base URL is required.')); + enableImageGeneration = false; } else if (imageGenerationEngine === 'openai' && OPENAI_API_KEY === '') { toast.error($i18n.t('OpenAI API Key is required.')); enableImageGeneration = false; @@ -204,12 +237,10 @@ /> + {:else if imageGenerationEngine === 'openai'}
{$i18n.t('OpenAI API Key')}
@@ -261,6 +323,7 @@ class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none" bind:value={selectedModel} placeholder={$i18n.t('Select a model')} + required > {#if !selectedModel} diff --git a/src/lib/components/common/ImagePreview.svelte b/src/lib/components/common/ImagePreview.svelte index cf69327f..badabebd 100644 --- a/src/lib/components/common/ImagePreview.svelte +++ b/src/lib/components/common/ImagePreview.svelte @@ -2,6 +2,22 @@ export let show = false; export let src = ''; export let alt = ''; + + const downloadImage = (url, filename) => { + fetch(url) + .then((response) => response.blob()) + .then((blob) => { + const objectUrl = window.URL.createObjectURL(blob); + const link = document.createElement('a'); + link.href = objectUrl; + link.download = filename; + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); + window.URL.revokeObjectURL(objectUrl); + }) + .catch((error) => console.error('Error downloading image:', error)); + }; {#if show} @@ -35,10 +51,7 @@