diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index dfa1f187..87ecc292 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -21,7 +21,16 @@ from utils.utils import ( from utils.misc import calculate_sha256 from typing import Optional from pydantic import BaseModel -from config import AUTOMATIC1111_BASE_URL +from pathlib import Path +import uuid +import base64 +import json + +from config import CACHE_DIR, AUTOMATIC1111_BASE_URL + + +IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/") +IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) app = FastAPI() app.add_middleware( @@ -32,25 +41,34 @@ app.add_middleware( allow_headers=["*"], ) +app.state.ENGINE = "" +app.state.ENABLED = False + +app.state.OPENAI_API_KEY = "" +app.state.MODEL = "" + + app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL -app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != "" + app.state.IMAGE_SIZE = "512x512" app.state.IMAGE_STEPS = 50 -@app.get("/enabled", response_model=bool) -async def get_enable_status(request: Request, user=Depends(get_admin_user)): - return app.state.ENABLED +@app.get("/config") +async def get_config(request: Request, user=Depends(get_admin_user)): + return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} -@app.get("/enabled/toggle", response_model=bool) -async def toggle_enabled(request: Request, user=Depends(get_admin_user)): - try: - r = requests.head(app.state.AUTOMATIC1111_BASE_URL) - app.state.ENABLED = not app.state.ENABLED - return app.state.ENABLED - except Exception as e: - raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) +class ConfigUpdateForm(BaseModel): + engine: str + enabled: bool + + +@app.post("/config/update") +async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): + app.state.ENGINE = form_data.engine + app.state.ENABLED = form_data.enabled + return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} class UrlUpdateForm(BaseModel): @@ -58,17 +76,24 @@ class UrlUpdateForm(BaseModel): @app.get("/url") -async def get_openai_url(user=Depends(get_admin_user)): +async def get_automatic1111_url(user=Depends(get_admin_user)): return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL} @app.post("/url/update") -async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): +async def update_automatic1111_url( + form_data: UrlUpdateForm, user=Depends(get_admin_user) +): if form_data.url == "": app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL else: - app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/") + url = form_data.url.strip("/") + try: + r = requests.head(url) + app.state.AUTOMATIC1111_BASE_URL = url + except Exception as e: + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) return { "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, @@ -76,6 +101,30 @@ async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_use } +class OpenAIKeyUpdateForm(BaseModel): + key: str + + +@app.get("/key") +async def get_openai_key(user=Depends(get_admin_user)): + return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} + + +@app.post("/key/update") +async def update_openai_key( + form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user) +): + + if form_data.key == "": + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + + app.state.OPENAI_API_KEY = form_data.key + return { + "OPENAI_API_KEY": app.state.OPENAI_API_KEY, + "status": True, + } + + class ImageSizeUpdateForm(BaseModel): size: str @@ -132,9 +181,22 @@ async def update_image_size( @app.get("/models") def get_models(user=Depends(get_current_user)): try: - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models") - models = r.json() - return models + if app.state.ENGINE == "openai": + return [ + {"id": "dall-e-2", "name": "DALL·E 2"}, + {"id": "dall-e-3", "name": "DALL·E 3"}, + ] + else: + r = requests.get( + url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" + ) + models = r.json() + return list( + map( + lambda model: {"id": model["title"], "name": model["model_name"]}, + models, + ) + ) except Exception as e: app.state.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -143,10 +205,12 @@ def get_models(user=Depends(get_current_user)): @app.get("/models/default") async def get_default_model(user=Depends(get_admin_user)): try: - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") - options = r.json() - - return {"model": options["sd_model_checkpoint"]} + if app.state.ENGINE == "openai": + return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} + else: + r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") + options = r.json() + return {"model": options["sd_model_checkpoint"]} except Exception as e: app.state.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -157,16 +221,21 @@ class UpdateModelForm(BaseModel): def set_model_handler(model: str): - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") - options = r.json() - if model != options["sd_model_checkpoint"]: - options["sd_model_checkpoint"] = model - r = requests.post( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options - ) + if app.state.ENGINE == "openai": + app.state.MODEL = model + return app.state.MODEL + else: + r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") + options = r.json() - return options + if model != options["sd_model_checkpoint"]: + options["sd_model_checkpoint"] = model + r = requests.post( + url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options + ) + + return options @app.post("/models/default/update") @@ -185,6 +254,24 @@ class GenerateImageForm(BaseModel): negative_prompt: Optional[str] = None +def save_b64_image(b64_str): + image_id = str(uuid.uuid4()) + file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png") + + try: + # Split the base64 string to get the actual image data + img_data = base64.b64decode(b64_str) + + # Write the image data to a file + with open(file_path, "wb") as f: + f.write(img_data) + + return image_id + except Exception as e: + print(f"Error saving image: {e}") + return None + + @app.post("/generations") def generate_image( form_data: GenerateImageForm, @@ -194,32 +281,82 @@ def generate_image( print(form_data) try: - if form_data.model: - set_model_handler(form_data.model) + if app.state.ENGINE == "openai": - width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) + headers = {} + headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" + headers["Content-Type"] = "application/json" - data = { - "prompt": form_data.prompt, - "batch_size": form_data.n, - "width": width, - "height": height, - } + data = { + "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2", + "prompt": form_data.prompt, + "n": form_data.n, + "size": form_data.size, + "response_format": "b64_json", + } - if app.state.IMAGE_STEPS != None: - data["steps"] = app.state.IMAGE_STEPS + r = requests.post( + url=f"https://api.openai.com/v1/images/generations", + json=data, + headers=headers, + ) - if form_data.negative_prompt != None: - data["negative_prompt"] = form_data.negative_prompt + r.raise_for_status() - print(data) + res = r.json() - r = requests.post( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", - json=data, - ) + images = [] + + for image in res["data"]: + image_id = save_b64_image(image["b64_json"]) + images.append({"url": f"/cache/image/generations/{image_id}.png"}) + file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") + + with open(file_body_path, "w") as f: + json.dump(data, f) + + return images + + else: + if form_data.model: + set_model_handler(form_data.model) + + width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) + + data = { + "prompt": form_data.prompt, + "batch_size": form_data.n, + "width": width, + "height": height, + } + + if app.state.IMAGE_STEPS != None: + data["steps"] = app.state.IMAGE_STEPS + + if form_data.negative_prompt != None: + data["negative_prompt"] = form_data.negative_prompt + + r = requests.post( + url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", + json=data, + ) + + res = r.json() + + print(res) + + images = [] + + for image in res["images"]: + image_id = save_b64_image(image) + images.append({"url": f"/cache/image/generations/{image_id}.png"}) + file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") + + with open(file_body_path, "w") as f: + json.dump({**data, "info": res["info"]}, f) + + return images - return r.json() except Exception as e: print(e) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) diff --git a/backend/main.py b/backend/main.py index 9e04ee48..afa974ca 100644 --- a/backend/main.py +++ b/backend/main.py @@ -121,6 +121,7 @@ async def get_app_latest_release_version(): app.mount("/static", StaticFiles(directory="static"), name="static") +app.mount("/cache", StaticFiles(directory="data/cache"), name="cache") app.mount( diff --git a/src/lib/apis/images/index.ts b/src/lib/apis/images/index.ts index 220fb7d9..1fb004a3 100644 --- a/src/lib/apis/images/index.ts +++ b/src/lib/apis/images/index.ts @@ -72,6 +72,73 @@ export const updateImageGenerationConfig = async ( return res; }; +export const getOpenAIKey = async (token: string = '') => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/key`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.OPENAI_API_KEY; +}; + +export const updateOpenAIKey = async (token: string = '', key: string) => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/key/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + key: key + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.OPENAI_API_KEY; +}; + export const getAUTOMATIC1111Url = async (token: string = '') => { let error = null; diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte index 3d94609b..3f6c7739 100644 --- a/src/lib/components/chat/Messages/ResponseMessage.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage.svelte @@ -277,13 +277,15 @@ const generateImage = async (message) => { generatingImage = true; - const res = await imageGenerations(localStorage.token, message.content); + const res = await imageGenerations(localStorage.token, message.content).catch((error) => { + toast.error(error); + }); console.log(res); if (res) { - message.files = res.images.map((image) => ({ + message.files = res.map((image) => ({ type: 'image', - url: `data:image/png;base64,${image}` + url: `${image.url}` })); dispatch('save', message); diff --git a/src/lib/components/chat/Settings/Images.svelte b/src/lib/components/chat/Settings/Images.svelte index cced3c43..44430f11 100644 --- a/src/lib/components/chat/Settings/Images.svelte +++ b/src/lib/components/chat/Settings/Images.svelte @@ -14,7 +14,9 @@ updateAUTOMATIC1111Url, updateImageSize, getImageSteps, - updateImageSteps + updateImageSteps, + getOpenAIKey, + updateOpenAIKey } from '$lib/apis/images'; import { getBackendConfig } from '$lib/apis'; const dispatch = createEventDispatcher(); @@ -27,6 +29,7 @@ let enableImageGeneration = false; let AUTOMATIC1111_BASE_URL = ''; + let OPENAI_API_KEY = ''; let selectedModel = ''; let models = null; @@ -97,6 +100,7 @@ enableImageGeneration = res.enabled; } AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token); + OPENAI_API_KEY = await getOpenAIKey(localStorage.token); imageSize = await getImageSize(localStorage.token); steps = await getImageSteps(localStorage.token); @@ -112,6 +116,10 @@ class="flex flex-col h-full justify-between space-y-3 text-sm" on:submit|preventDefault={async () => { loading = true; + await updateOpenAIKey(localStorage.token, OPENAI_API_KEY); + + await updateDefaultImageGenerationModel(localStorage.token, selectedModel); + await updateDefaultImageGenerationModel(localStorage.token, selectedModel); await updateImageSize(localStorage.token, imageSize).catch((error) => { toast.error(error); @@ -156,10 +164,12 @@ on:click={() => { if (imageGenerationEngine === '' && AUTOMATIC1111_BASE_URL === '') { toast.error('AUTOMATIC1111 Base URL is required.'); + enableImageGeneration = false; } else { enableImageGeneration = !enableImageGeneration; - updateImageGeneration(); } + + updateImageGeneration(); }} type="button" > @@ -172,21 +182,20 @@ +
{#if imageGenerationEngine === ''} -
-
AUTOMATIC1111 Base URL
+ {:else if imageGenerationEngine === 'openai'} +
OpenAI API Key
+
+
+ +
+
{/if} {#if enableImageGeneration} @@ -229,7 +249,7 @@
@@ -262,7 +282,7 @@
diff --git a/src/lib/components/common/Image.svelte b/src/lib/components/common/Image.svelte index 566ebb5b..e69f0e29 100644 --- a/src/lib/components/common/Image.svelte +++ b/src/lib/components/common/Image.svelte @@ -1,18 +1,23 @@ - +