From dd3a4b38895b2ac807c9f871dc0a64903d82941e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 8 Mar 2024 16:22:19 -0800 Subject: [PATCH 1/3] refac: image generation --- src/lib/apis/images/index.ts | 26 ++- .../components/chat/Settings/Images.svelte | 180 +++++++++++------- 2 files changed, 125 insertions(+), 81 deletions(-) diff --git a/src/lib/apis/images/index.ts b/src/lib/apis/images/index.ts index f05ce0b7..220fb7d9 100644 --- a/src/lib/apis/images/index.ts +++ b/src/lib/apis/images/index.ts @@ -1,9 +1,9 @@ import { IMAGES_API_BASE_URL } from '$lib/constants'; -export const getImageGenerationEnabledStatus = async (token: string = '') => { +export const getImageGenerationConfig = async (token: string = '') => { let error = null; - const res = await fetch(`${IMAGES_API_BASE_URL}/enabled`, { + const res = await fetch(`${IMAGES_API_BASE_URL}/config`, { method: 'GET', headers: { Accept: 'application/json', @@ -32,16 +32,24 @@ export const getImageGenerationEnabledStatus = async (token: string = '') => { return res; }; -export const toggleImageGenerationEnabledStatus = async (token: string = '') => { +export const updateImageGenerationConfig = async ( + token: string = '', + engine: string, + enabled: boolean +) => { let error = null; - const res = await fetch(`${IMAGES_API_BASE_URL}/enabled/toggle`, { - method: 'GET', + const res = await fetch(`${IMAGES_API_BASE_URL}/config/update`, { + method: 'POST', headers: { Accept: 'application/json', 'Content-Type': 'application/json', ...(token && { authorization: `Bearer ${token}` }) - } + }, + body: JSON.stringify({ + engine, + enabled + }) }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -263,7 +271,7 @@ export const updateImageSteps = async (token: string = '', steps: number) => { return res.IMAGE_STEPS; }; -export const getDiffusionModels = async (token: string = '') => { +export const getImageGenerationModels = async (token: string = '') => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/models`, { @@ -295,7 +303,7 @@ export const getDiffusionModels = async (token: string = '') => { return res; }; -export const getDefaultDiffusionModel = async (token: string = '') => { +export const getDefaultImageGenerationModel = async (token: string = '') => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/models/default`, { @@ -327,7 +335,7 @@ export const getDefaultDiffusionModel = async (token: string = '') => { return res.model; }; -export const updateDefaultDiffusionModel = async (token: string = '', model: string) => { +export const updateDefaultImageGenerationModel = async (token: string = '', model: string) => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/models/default/update`, { diff --git a/src/lib/components/chat/Settings/Images.svelte b/src/lib/components/chat/Settings/Images.svelte index b36eae12..cced3c43 100644 --- a/src/lib/components/chat/Settings/Images.svelte +++ b/src/lib/components/chat/Settings/Images.svelte @@ -5,13 +5,13 @@ import { config, user } from '$lib/stores'; import { getAUTOMATIC1111Url, - getDefaultDiffusionModel, - getDiffusionModels, - getImageGenerationEnabledStatus, + getImageGenerationModels, + getDefaultImageGenerationModel, + updateDefaultImageGenerationModel, getImageSize, - toggleImageGenerationEnabledStatus, + getImageGenerationConfig, + updateImageGenerationConfig, updateAUTOMATIC1111Url, - updateDefaultDiffusionModel, updateImageSize, getImageSteps, updateImageSteps @@ -23,7 +23,9 @@ let loading = false; + let imageGenerationEngine = ''; let enableImageGeneration = false; + let AUTOMATIC1111_BASE_URL = ''; let selectedModel = ''; @@ -33,11 +35,11 @@ let steps = 50; const getModels = async () => { - models = await getDiffusionModels(localStorage.token).catch((error) => { + models = await getImageGenerationModels(localStorage.token).catch((error) => { toast.error(error); return null; }); - selectedModel = await getDefaultDiffusionModel(localStorage.token).catch((error) => { + selectedModel = await getDefaultImageGenerationModel(localStorage.token).catch((error) => { return ''; }); }; @@ -62,33 +64,44 @@ AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token); } }; - const toggleImageGeneration = async () => { - if (AUTOMATIC1111_BASE_URL) { - enableImageGeneration = await toggleImageGenerationEnabledStatus(localStorage.token).catch( - (error) => { - toast.error(error); - return false; - } - ); + const updateImageGeneration = async () => { + const res = await updateImageGenerationConfig( + localStorage.token, + imageGenerationEngine, + enableImageGeneration + ).catch((error) => { + toast.error(error); + return null; + }); - if (enableImageGeneration) { - config.set(await getBackendConfig(localStorage.token)); - getModels(); - } - } else { - enableImageGeneration = false; - toast.error('AUTOMATIC1111_BASE_URL not provided'); + if (res) { + imageGenerationEngine = res.engine; + enableImageGeneration = res.enabled; + } + + if (enableImageGeneration) { + config.set(await getBackendConfig(localStorage.token)); + getModels(); } }; onMount(async () => { if ($user.role === 'admin') { - enableImageGeneration = await getImageGenerationEnabledStatus(localStorage.token); + const res = await getImageGenerationConfig(localStorage.token).catch((error) => { + toast.error(error); + return null; + }); + + if (res) { + imageGenerationEngine = res.engine; + enableImageGeneration = res.enabled; + } AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token); - if (enableImageGeneration && AUTOMATIC1111_BASE_URL) { - imageSize = await getImageSize(localStorage.token); - steps = await getImageSteps(localStorage.token); + imageSize = await getImageSize(localStorage.token); + steps = await getImageSteps(localStorage.token); + + if (enableImageGeneration) { getModels(); } } @@ -99,7 +112,7 @@ class="flex flex-col h-full justify-between space-y-3 text-sm" on:submit|preventDefault={async () => { loading = true; - await updateDefaultDiffusionModel(localStorage.token, selectedModel); + await updateDefaultImageGenerationModel(localStorage.token, selectedModel); await updateImageSize(localStorage.token, imageSize).catch((error) => { toast.error(error); return null; @@ -117,6 +130,23 @@
Image Settings
+
+
Image Generation Engine
+
+ +
+
+
Image Generation (Experimental)
@@ -124,7 +154,12 @@
-
-
AUTOMATIC1111 Base URL
-
-
- -
-
+
+
+ - - -
+
+ + + +
+ Include `--api` flag when running stable-diffusion-webui + + (e.g. `sh webui.sh --api`) + +
+ {/if} {#if enableImageGeneration}
@@ -199,9 +237,7 @@ {/if} {#each models ?? [] as model} - + {/each} From 0221acd163e99fafab17313f693cd80679387c96 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 8 Mar 2024 17:38:10 -0800 Subject: [PATCH 2/3] feat: dall-e integration --- backend/apps/images/main.py | 237 ++++++++++++++---- backend/main.py | 1 + src/lib/apis/images/index.ts | 67 +++++ .../chat/Messages/ResponseMessage.svelte | 8 +- .../components/chat/Settings/Images.svelte | 38 ++- src/lib/components/common/Image.svelte | 9 +- 6 files changed, 296 insertions(+), 64 deletions(-) diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index dfa1f187..87ecc292 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -21,7 +21,16 @@ from utils.utils import ( from utils.misc import calculate_sha256 from typing import Optional from pydantic import BaseModel -from config import AUTOMATIC1111_BASE_URL +from pathlib import Path +import uuid +import base64 +import json + +from config import CACHE_DIR, AUTOMATIC1111_BASE_URL + + +IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/") +IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) app = FastAPI() app.add_middleware( @@ -32,25 +41,34 @@ app.add_middleware( allow_headers=["*"], ) +app.state.ENGINE = "" +app.state.ENABLED = False + +app.state.OPENAI_API_KEY = "" +app.state.MODEL = "" + + app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL -app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != "" + app.state.IMAGE_SIZE = "512x512" app.state.IMAGE_STEPS = 50 -@app.get("/enabled", response_model=bool) -async def get_enable_status(request: Request, user=Depends(get_admin_user)): - return app.state.ENABLED +@app.get("/config") +async def get_config(request: Request, user=Depends(get_admin_user)): + return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} -@app.get("/enabled/toggle", response_model=bool) -async def toggle_enabled(request: Request, user=Depends(get_admin_user)): - try: - r = requests.head(app.state.AUTOMATIC1111_BASE_URL) - app.state.ENABLED = not app.state.ENABLED - return app.state.ENABLED - except Exception as e: - raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) +class ConfigUpdateForm(BaseModel): + engine: str + enabled: bool + + +@app.post("/config/update") +async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): + app.state.ENGINE = form_data.engine + app.state.ENABLED = form_data.enabled + return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} class UrlUpdateForm(BaseModel): @@ -58,17 +76,24 @@ class UrlUpdateForm(BaseModel): @app.get("/url") -async def get_openai_url(user=Depends(get_admin_user)): +async def get_automatic1111_url(user=Depends(get_admin_user)): return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL} @app.post("/url/update") -async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): +async def update_automatic1111_url( + form_data: UrlUpdateForm, user=Depends(get_admin_user) +): if form_data.url == "": app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL else: - app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/") + url = form_data.url.strip("/") + try: + r = requests.head(url) + app.state.AUTOMATIC1111_BASE_URL = url + except Exception as e: + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) return { "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, @@ -76,6 +101,30 @@ async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_use } +class OpenAIKeyUpdateForm(BaseModel): + key: str + + +@app.get("/key") +async def get_openai_key(user=Depends(get_admin_user)): + return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} + + +@app.post("/key/update") +async def update_openai_key( + form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user) +): + + if form_data.key == "": + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + + app.state.OPENAI_API_KEY = form_data.key + return { + "OPENAI_API_KEY": app.state.OPENAI_API_KEY, + "status": True, + } + + class ImageSizeUpdateForm(BaseModel): size: str @@ -132,9 +181,22 @@ async def update_image_size( @app.get("/models") def get_models(user=Depends(get_current_user)): try: - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models") - models = r.json() - return models + if app.state.ENGINE == "openai": + return [ + {"id": "dall-e-2", "name": "DALL·E 2"}, + {"id": "dall-e-3", "name": "DALL·E 3"}, + ] + else: + r = requests.get( + url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" + ) + models = r.json() + return list( + map( + lambda model: {"id": model["title"], "name": model["model_name"]}, + models, + ) + ) except Exception as e: app.state.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -143,10 +205,12 @@ def get_models(user=Depends(get_current_user)): @app.get("/models/default") async def get_default_model(user=Depends(get_admin_user)): try: - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") - options = r.json() - - return {"model": options["sd_model_checkpoint"]} + if app.state.ENGINE == "openai": + return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} + else: + r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") + options = r.json() + return {"model": options["sd_model_checkpoint"]} except Exception as e: app.state.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -157,16 +221,21 @@ class UpdateModelForm(BaseModel): def set_model_handler(model: str): - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") - options = r.json() - if model != options["sd_model_checkpoint"]: - options["sd_model_checkpoint"] = model - r = requests.post( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options - ) + if app.state.ENGINE == "openai": + app.state.MODEL = model + return app.state.MODEL + else: + r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") + options = r.json() - return options + if model != options["sd_model_checkpoint"]: + options["sd_model_checkpoint"] = model + r = requests.post( + url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options + ) + + return options @app.post("/models/default/update") @@ -185,6 +254,24 @@ class GenerateImageForm(BaseModel): negative_prompt: Optional[str] = None +def save_b64_image(b64_str): + image_id = str(uuid.uuid4()) + file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png") + + try: + # Split the base64 string to get the actual image data + img_data = base64.b64decode(b64_str) + + # Write the image data to a file + with open(file_path, "wb") as f: + f.write(img_data) + + return image_id + except Exception as e: + print(f"Error saving image: {e}") + return None + + @app.post("/generations") def generate_image( form_data: GenerateImageForm, @@ -194,32 +281,82 @@ def generate_image( print(form_data) try: - if form_data.model: - set_model_handler(form_data.model) + if app.state.ENGINE == "openai": - width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) + headers = {} + headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" + headers["Content-Type"] = "application/json" - data = { - "prompt": form_data.prompt, - "batch_size": form_data.n, - "width": width, - "height": height, - } + data = { + "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2", + "prompt": form_data.prompt, + "n": form_data.n, + "size": form_data.size, + "response_format": "b64_json", + } - if app.state.IMAGE_STEPS != None: - data["steps"] = app.state.IMAGE_STEPS + r = requests.post( + url=f"https://api.openai.com/v1/images/generations", + json=data, + headers=headers, + ) - if form_data.negative_prompt != None: - data["negative_prompt"] = form_data.negative_prompt + r.raise_for_status() - print(data) + res = r.json() - r = requests.post( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", - json=data, - ) + images = [] + + for image in res["data"]: + image_id = save_b64_image(image["b64_json"]) + images.append({"url": f"/cache/image/generations/{image_id}.png"}) + file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") + + with open(file_body_path, "w") as f: + json.dump(data, f) + + return images + + else: + if form_data.model: + set_model_handler(form_data.model) + + width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) + + data = { + "prompt": form_data.prompt, + "batch_size": form_data.n, + "width": width, + "height": height, + } + + if app.state.IMAGE_STEPS != None: + data["steps"] = app.state.IMAGE_STEPS + + if form_data.negative_prompt != None: + data["negative_prompt"] = form_data.negative_prompt + + r = requests.post( + url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", + json=data, + ) + + res = r.json() + + print(res) + + images = [] + + for image in res["images"]: + image_id = save_b64_image(image) + images.append({"url": f"/cache/image/generations/{image_id}.png"}) + file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") + + with open(file_body_path, "w") as f: + json.dump({**data, "info": res["info"]}, f) + + return images - return r.json() except Exception as e: print(e) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) diff --git a/backend/main.py b/backend/main.py index 9e04ee48..afa974ca 100644 --- a/backend/main.py +++ b/backend/main.py @@ -121,6 +121,7 @@ async def get_app_latest_release_version(): app.mount("/static", StaticFiles(directory="static"), name="static") +app.mount("/cache", StaticFiles(directory="data/cache"), name="cache") app.mount( diff --git a/src/lib/apis/images/index.ts b/src/lib/apis/images/index.ts index 220fb7d9..1fb004a3 100644 --- a/src/lib/apis/images/index.ts +++ b/src/lib/apis/images/index.ts @@ -72,6 +72,73 @@ export const updateImageGenerationConfig = async ( return res; }; +export const getOpenAIKey = async (token: string = '') => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/key`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.OPENAI_API_KEY; +}; + +export const updateOpenAIKey = async (token: string = '', key: string) => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/key/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + key: key + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.OPENAI_API_KEY; +}; + export const getAUTOMATIC1111Url = async (token: string = '') => { let error = null; diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte index 3d94609b..3f6c7739 100644 --- a/src/lib/components/chat/Messages/ResponseMessage.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage.svelte @@ -277,13 +277,15 @@ const generateImage = async (message) => { generatingImage = true; - const res = await imageGenerations(localStorage.token, message.content); + const res = await imageGenerations(localStorage.token, message.content).catch((error) => { + toast.error(error); + }); console.log(res); if (res) { - message.files = res.images.map((image) => ({ + message.files = res.map((image) => ({ type: 'image', - url: `data:image/png;base64,${image}` + url: `${image.url}` })); dispatch('save', message); diff --git a/src/lib/components/chat/Settings/Images.svelte b/src/lib/components/chat/Settings/Images.svelte index cced3c43..44430f11 100644 --- a/src/lib/components/chat/Settings/Images.svelte +++ b/src/lib/components/chat/Settings/Images.svelte @@ -14,7 +14,9 @@ updateAUTOMATIC1111Url, updateImageSize, getImageSteps, - updateImageSteps + updateImageSteps, + getOpenAIKey, + updateOpenAIKey } from '$lib/apis/images'; import { getBackendConfig } from '$lib/apis'; const dispatch = createEventDispatcher(); @@ -27,6 +29,7 @@ let enableImageGeneration = false; let AUTOMATIC1111_BASE_URL = ''; + let OPENAI_API_KEY = ''; let selectedModel = ''; let models = null; @@ -97,6 +100,7 @@ enableImageGeneration = res.enabled; } AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token); + OPENAI_API_KEY = await getOpenAIKey(localStorage.token); imageSize = await getImageSize(localStorage.token); steps = await getImageSteps(localStorage.token); @@ -112,6 +116,10 @@ class="flex flex-col h-full justify-between space-y-3 text-sm" on:submit|preventDefault={async () => { loading = true; + await updateOpenAIKey(localStorage.token, OPENAI_API_KEY); + + await updateDefaultImageGenerationModel(localStorage.token, selectedModel); + await updateDefaultImageGenerationModel(localStorage.token, selectedModel); await updateImageSize(localStorage.token, imageSize).catch((error) => { toast.error(error); @@ -156,10 +164,12 @@ on:click={() => { if (imageGenerationEngine === '' && AUTOMATIC1111_BASE_URL === '') { toast.error('AUTOMATIC1111 Base URL is required.'); + enableImageGeneration = false; } else { enableImageGeneration = !enableImageGeneration; - updateImageGeneration(); } + + updateImageGeneration(); }} type="button" > @@ -172,21 +182,20 @@ +
{#if imageGenerationEngine === ''} -
-
AUTOMATIC1111 Base URL
+ {:else if imageGenerationEngine === 'openai'} +
OpenAI API Key
+
+
+ +
+
{/if} {#if enableImageGeneration} @@ -229,7 +249,7 @@
@@ -262,7 +282,7 @@
diff --git a/src/lib/components/common/Image.svelte b/src/lib/components/common/Image.svelte index 566ebb5b..e69f0e29 100644 --- a/src/lib/components/common/Image.svelte +++ b/src/lib/components/common/Image.svelte @@ -1,18 +1,23 @@ - + From fe7610d380a38f5bc4d660b6ec6ec05d05c04bb9 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 8 Mar 2024 17:40:30 -0800 Subject: [PATCH 3/3] fix: disable dall-e image generation w/o key --- src/lib/components/chat/Settings/Images.svelte | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lib/components/chat/Settings/Images.svelte b/src/lib/components/chat/Settings/Images.svelte index 44430f11..38d365a4 100644 --- a/src/lib/components/chat/Settings/Images.svelte +++ b/src/lib/components/chat/Settings/Images.svelte @@ -165,6 +165,9 @@ if (imageGenerationEngine === '' && AUTOMATIC1111_BASE_URL === '') { toast.error('AUTOMATIC1111 Base URL is required.'); enableImageGeneration = false; + } else if (imageGenerationEngine === 'openai' && OPENAI_API_KEY === '') { + toast.error('OpenAI API Key is required.'); + enableImageGeneration = false; } else { enableImageGeneration = !enableImageGeneration; }