feat: comfyui integration

This commit is contained in:
Timothy J. Baek 2024-03-23 17:01:13 -07:00
parent 862c96fcef
commit 98624a406f
4 changed files with 315 additions and 8 deletions

View file

@ -18,6 +18,8 @@ from utils.utils import (
get_current_user,
get_admin_user,
)
from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
@ -105,7 +107,12 @@ async def update_engine_url(
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
else:
url = form_data.COMFYUI_BASE_URL.strip("/")
app.state.COMFYUI_BASE_URL = url
try:
r = requests.head(url)
app.state.COMFYUI_BASE_URL = url
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
@ -232,6 +239,8 @@ async def get_default_model(user=Depends(get_admin_user)):
try:
if app.state.ENGINE == "openai":
return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
elif app.state.ENGINE == "comfyui":
return {"model": app.state.MODEL if app.state.MODEL else ""}
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
@ -246,10 +255,12 @@ class UpdateModelForm(BaseModel):
def set_model_handler(model: str):
if app.state.ENGINE == "openai":
app.state.MODEL = model
return app.state.MODEL
if app.state.ENGINE == "comfyui":
app.state.MODEL = model
return app.state.MODEL
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
@ -297,12 +308,31 @@ def save_b64_image(b64_str):
return None
def save_url_image(url):
image_id = str(uuid.uuid4())
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
try:
r = requests.get(url)
r.raise_for_status()
with open(file_path, "wb") as image_file:
image_file.write(r.content)
return image_id
except Exception as e:
print(f"Error saving image: {e}")
return None
@app.post("/generations")
def generate_image(
form_data: GenerateImageForm,
user=Depends(get_current_user),
):
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
r = None
try:
if app.state.ENGINE == "openai":
@ -340,12 +370,47 @@ def generate_image(
return images
elif app.state.ENGINE == "comfyui":
data = {
"prompt": form_data.prompt,
"width": width,
"height": height,
"n": form_data.n,
}
if app.state.IMAGE_STEPS != None:
data["steps"] = app.state.IMAGE_STEPS
if form_data.negative_prompt != None:
data["negative_prompt"] = form_data.negative_prompt
data = ImageGenerationPayload(**data)
res = comfyui_generate_image(
app.state.MODEL,
data,
user.id,
app.state.COMFYUI_BASE_URL,
)
print(res)
images = []
for image in res["data"]:
image_id = save_url_image(image["url"])
images.append({"url": f"/cache/image/generations/{image_id}.png"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f:
json.dump(data.model_dump(exclude_none=True), f)
print(images)
return images
else:
if form_data.model:
set_model_handler(form_data.model)
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
data = {
"prompt": form_data.prompt,
"batch_size": form_data.n,

View file

@ -0,0 +1,228 @@
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse
import random
from pydantic import BaseModel
from typing import Optional
COMFYUI_DEFAULT_PROMPT = """
{
"3": {
"inputs": {
"seed": 0,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "model.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "Prompt",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "Negative Prompt",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}
"""
def queue_prompt(prompt, client_id, base_url):
print("queue_prompt")
p = {"prompt": prompt, "client_id": client_id}
data = json.dumps(p).encode("utf-8")
req = urllib.request.Request(f"{base_url}/prompt", data=data)
return json.loads(urllib.request.urlopen(req).read())
def get_image(filename, subfolder, folder_type, base_url):
print("get_image")
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response:
return response.read()
def get_image_url(filename, subfolder, folder_type, base_url):
print("get_image")
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
return f"{base_url}/view?{url_values}"
def get_history(prompt_id, base_url):
print("get_history")
with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response:
return json.loads(response.read())
def get_images(ws, prompt, client_id, base_url):
prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
output_images = []
while True:
out = ws.recv()
if isinstance(out, str):
message = json.loads(out)
if message["type"] == "executing":
data = message["data"]
if data["node"] is None and data["prompt_id"] == prompt_id:
break # Execution is done
else:
continue # previews are binary data
history = get_history(prompt_id, base_url)[prompt_id]
for o in history["outputs"]:
for node_id in history["outputs"]:
node_output = history["outputs"][node_id]
if "images" in node_output:
for image in node_output["images"]:
url = get_image_url(
image["filename"], image["subfolder"], image["type"], base_url
)
output_images.append({"url": url})
return {"data": output_images}
class ImageGenerationPayload(BaseModel):
prompt: str
negative_prompt: Optional[str] = ""
steps: Optional[int] = None
seed: Optional[int] = None
width: int
height: int
n: int = 1
def comfyui_generate_image(
model: str, payload: ImageGenerationPayload, client_id, base_url
):
host = base_url.replace("http://", "").replace("https://", "")
comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
comfyui_prompt["5"]["inputs"]["width"] = payload.width
comfyui_prompt["5"]["inputs"]["height"] = payload.height
# set the text prompt for our positive CLIPTextEncode
comfyui_prompt["6"]["inputs"]["text"] = payload.prompt
comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
if payload.steps:
comfyui_prompt["3"]["inputs"]["steps"] = payload.steps
comfyui_prompt["3"]["inputs"]["seed"] = (
payload.seed if payload.seed else random.randint(0, 18446744073709551614)
)
try:
ws = websocket.WebSocket()
ws.connect(f"ws://{host}/ws?clientId={client_id}")
print("WebSocket connection established.")
except Exception as e:
print(f"Failed to connect to WebSocket server: {e}")
return None
try:
images = get_images(ws, comfyui_prompt, client_id, base_url)
except Exception as e:
print(f"Error while receiving images: {e}")
images = None
ws.close()
return images

View file

@ -323,6 +323,7 @@
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={selectedModel}
placeholder={$i18n.t('Select a model')}
required
>
{#if !selectedModel}
<option value="" disabled selected>{$i18n.t('Select a model')}</option>

View file

@ -2,6 +2,22 @@
export let show = false;
export let src = '';
export let alt = '';
const downloadImage = (url, filename) => {
fetch(url)
.then((response) => response.blob())
.then((blob) => {
const objectUrl = window.URL.createObjectURL(blob);
const link = document.createElement('a');
link.href = objectUrl;
link.download = filename;
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
window.URL.revokeObjectURL(objectUrl);
})
.catch((error) => console.error('Error downloading image:', error));
};
</script>
{#if show}
@ -35,10 +51,7 @@
<button
class=" p-5"
on:click={() => {
const a = document.createElement('a');
a.href = src;
a.download = 'Image.png';
a.click();
downloadImage(src, 'Image.png');
}}
>
<svg