diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index dfa1f187..87ecc292 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -21,7 +21,16 @@ from utils.utils import ( from utils.misc import calculate_sha256 from typing import Optional from pydantic import BaseModel -from config import AUTOMATIC1111_BASE_URL +from pathlib import Path +import uuid +import base64 +import json + +from config import CACHE_DIR, AUTOMATIC1111_BASE_URL + + +IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/") +IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) app = FastAPI() app.add_middleware( @@ -32,25 +41,34 @@ app.add_middleware( allow_headers=["*"], ) +app.state.ENGINE = "" +app.state.ENABLED = False + +app.state.OPENAI_API_KEY = "" +app.state.MODEL = "" + + app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL -app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != "" + app.state.IMAGE_SIZE = "512x512" app.state.IMAGE_STEPS = 50 -@app.get("/enabled", response_model=bool) -async def get_enable_status(request: Request, user=Depends(get_admin_user)): - return app.state.ENABLED +@app.get("/config") +async def get_config(request: Request, user=Depends(get_admin_user)): + return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} -@app.get("/enabled/toggle", response_model=bool) -async def toggle_enabled(request: Request, user=Depends(get_admin_user)): - try: - r = requests.head(app.state.AUTOMATIC1111_BASE_URL) - app.state.ENABLED = not app.state.ENABLED - return app.state.ENABLED - except Exception as e: - raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) +class ConfigUpdateForm(BaseModel): + engine: str + enabled: bool + + +@app.post("/config/update") +async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): + app.state.ENGINE = form_data.engine + app.state.ENABLED = form_data.enabled + return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} class UrlUpdateForm(BaseModel): @@ -58,17 +76,24 @@ class UrlUpdateForm(BaseModel): @app.get("/url") -async def get_openai_url(user=Depends(get_admin_user)): +async def get_automatic1111_url(user=Depends(get_admin_user)): return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL} @app.post("/url/update") -async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): +async def update_automatic1111_url( + form_data: UrlUpdateForm, user=Depends(get_admin_user) +): if form_data.url == "": app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL else: - app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/") + url = form_data.url.strip("/") + try: + r = requests.head(url) + app.state.AUTOMATIC1111_BASE_URL = url + except Exception as e: + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) return { "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, @@ -76,6 +101,30 @@ async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_use } +class OpenAIKeyUpdateForm(BaseModel): + key: str + + +@app.get("/key") +async def get_openai_key(user=Depends(get_admin_user)): + return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} + + +@app.post("/key/update") +async def update_openai_key( + form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user) +): + + if form_data.key == "": + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + + app.state.OPENAI_API_KEY = form_data.key + return { + "OPENAI_API_KEY": app.state.OPENAI_API_KEY, + "status": True, + } + + class ImageSizeUpdateForm(BaseModel): size: str @@ -132,9 +181,22 @@ async def update_image_size( @app.get("/models") def get_models(user=Depends(get_current_user)): try: - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models") - models = r.json() - return models + if app.state.ENGINE == "openai": + return [ + {"id": "dall-e-2", "name": "DALL·E 2"}, + {"id": "dall-e-3", "name": "DALL·E 3"}, + ] + else: + r = requests.get( + url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" + ) + models = r.json() + return list( + map( + lambda model: {"id": model["title"], "name": model["model_name"]}, + models, + ) + ) except Exception as e: app.state.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -143,10 +205,12 @@ def get_models(user=Depends(get_current_user)): @app.get("/models/default") async def get_default_model(user=Depends(get_admin_user)): try: - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") - options = r.json() - - return {"model": options["sd_model_checkpoint"]} + if app.state.ENGINE == "openai": + return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} + else: + r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") + options = r.json() + return {"model": options["sd_model_checkpoint"]} except Exception as e: app.state.ENABLED = False raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -157,16 +221,21 @@ class UpdateModelForm(BaseModel): def set_model_handler(model: str): - r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") - options = r.json() - if model != options["sd_model_checkpoint"]: - options["sd_model_checkpoint"] = model - r = requests.post( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options - ) + if app.state.ENGINE == "openai": + app.state.MODEL = model + return app.state.MODEL + else: + r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") + options = r.json() - return options + if model != options["sd_model_checkpoint"]: + options["sd_model_checkpoint"] = model + r = requests.post( + url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options + ) + + return options @app.post("/models/default/update") @@ -185,6 +254,24 @@ class GenerateImageForm(BaseModel): negative_prompt: Optional[str] = None +def save_b64_image(b64_str): + image_id = str(uuid.uuid4()) + file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png") + + try: + # Split the base64 string to get the actual image data + img_data = base64.b64decode(b64_str) + + # Write the image data to a file + with open(file_path, "wb") as f: + f.write(img_data) + + return image_id + except Exception as e: + print(f"Error saving image: {e}") + return None + + @app.post("/generations") def generate_image( form_data: GenerateImageForm, @@ -194,32 +281,82 @@ def generate_image( print(form_data) try: - if form_data.model: - set_model_handler(form_data.model) + if app.state.ENGINE == "openai": - width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) + headers = {} + headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" + headers["Content-Type"] = "application/json" - data = { - "prompt": form_data.prompt, - "batch_size": form_data.n, - "width": width, - "height": height, - } + data = { + "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2", + "prompt": form_data.prompt, + "n": form_data.n, + "size": form_data.size, + "response_format": "b64_json", + } - if app.state.IMAGE_STEPS != None: - data["steps"] = app.state.IMAGE_STEPS + r = requests.post( + url=f"https://api.openai.com/v1/images/generations", + json=data, + headers=headers, + ) - if form_data.negative_prompt != None: - data["negative_prompt"] = form_data.negative_prompt + r.raise_for_status() - print(data) + res = r.json() - r = requests.post( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", - json=data, - ) + images = [] + + for image in res["data"]: + image_id = save_b64_image(image["b64_json"]) + images.append({"url": f"/cache/image/generations/{image_id}.png"}) + file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") + + with open(file_body_path, "w") as f: + json.dump(data, f) + + return images + + else: + if form_data.model: + set_model_handler(form_data.model) + + width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) + + data = { + "prompt": form_data.prompt, + "batch_size": form_data.n, + "width": width, + "height": height, + } + + if app.state.IMAGE_STEPS != None: + data["steps"] = app.state.IMAGE_STEPS + + if form_data.negative_prompt != None: + data["negative_prompt"] = form_data.negative_prompt + + r = requests.post( + url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", + json=data, + ) + + res = r.json() + + print(res) + + images = [] + + for image in res["images"]: + image_id = save_b64_image(image) + images.append({"url": f"/cache/image/generations/{image_id}.png"}) + file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") + + with open(file_body_path, "w") as f: + json.dump({**data, "info": res["info"]}, f) + + return images - return r.json() except Exception as e: print(e) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) diff --git a/backend/main.py b/backend/main.py index 9e04ee48..afa974ca 100644 --- a/backend/main.py +++ b/backend/main.py @@ -121,6 +121,7 @@ async def get_app_latest_release_version(): app.mount("/static", StaticFiles(directory="static"), name="static") +app.mount("/cache", StaticFiles(directory="data/cache"), name="cache") app.mount( diff --git a/src/lib/apis/images/index.ts b/src/lib/apis/images/index.ts index f05ce0b7..1fb004a3 100644 --- a/src/lib/apis/images/index.ts +++ b/src/lib/apis/images/index.ts @@ -1,9 +1,9 @@ import { IMAGES_API_BASE_URL } from '$lib/constants'; -export const getImageGenerationEnabledStatus = async (token: string = '') => { +export const getImageGenerationConfig = async (token: string = '') => { let error = null; - const res = await fetch(`${IMAGES_API_BASE_URL}/enabled`, { + const res = await fetch(`${IMAGES_API_BASE_URL}/config`, { method: 'GET', headers: { Accept: 'application/json', @@ -32,10 +32,50 @@ export const getImageGenerationEnabledStatus = async (token: string = '') => { return res; }; -export const toggleImageGenerationEnabledStatus = async (token: string = '') => { +export const updateImageGenerationConfig = async ( + token: string = '', + engine: string, + enabled: boolean +) => { let error = null; - const res = await fetch(`${IMAGES_API_BASE_URL}/enabled/toggle`, { + const res = await fetch(`${IMAGES_API_BASE_URL}/config/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + engine, + enabled + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getOpenAIKey = async (token: string = '') => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/key`, { method: 'GET', headers: { Accept: 'application/json', @@ -61,7 +101,42 @@ export const toggleImageGenerationEnabledStatus = async (token: string = '') => throw error; } - return res; + return res.OPENAI_API_KEY; +}; + +export const updateOpenAIKey = async (token: string = '', key: string) => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/key/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + key: key + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.OPENAI_API_KEY; }; export const getAUTOMATIC1111Url = async (token: string = '') => { @@ -263,7 +338,7 @@ export const updateImageSteps = async (token: string = '', steps: number) => { return res.IMAGE_STEPS; }; -export const getDiffusionModels = async (token: string = '') => { +export const getImageGenerationModels = async (token: string = '') => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/models`, { @@ -295,7 +370,7 @@ export const getDiffusionModels = async (token: string = '') => { return res; }; -export const getDefaultDiffusionModel = async (token: string = '') => { +export const getDefaultImageGenerationModel = async (token: string = '') => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/models/default`, { @@ -327,7 +402,7 @@ export const getDefaultDiffusionModel = async (token: string = '') => { return res.model; }; -export const updateDefaultDiffusionModel = async (token: string = '', model: string) => { +export const updateDefaultImageGenerationModel = async (token: string = '', model: string) => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/models/default/update`, { diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte index 3d94609b..3f6c7739 100644 --- a/src/lib/components/chat/Messages/ResponseMessage.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage.svelte @@ -277,13 +277,15 @@ const generateImage = async (message) => { generatingImage = true; - const res = await imageGenerations(localStorage.token, message.content); + const res = await imageGenerations(localStorage.token, message.content).catch((error) => { + toast.error(error); + }); console.log(res); if (res) { - message.files = res.images.map((image) => ({ + message.files = res.map((image) => ({ type: 'image', - url: `data:image/png;base64,${image}` + url: `${image.url}` })); dispatch('save', message); diff --git a/src/lib/components/chat/Settings/Images.svelte b/src/lib/components/chat/Settings/Images.svelte index b36eae12..38d365a4 100644 --- a/src/lib/components/chat/Settings/Images.svelte +++ b/src/lib/components/chat/Settings/Images.svelte @@ -5,16 +5,18 @@ import { config, user } from '$lib/stores'; import { getAUTOMATIC1111Url, - getDefaultDiffusionModel, - getDiffusionModels, - getImageGenerationEnabledStatus, + getImageGenerationModels, + getDefaultImageGenerationModel, + updateDefaultImageGenerationModel, getImageSize, - toggleImageGenerationEnabledStatus, + getImageGenerationConfig, + updateImageGenerationConfig, updateAUTOMATIC1111Url, - updateDefaultDiffusionModel, updateImageSize, getImageSteps, - updateImageSteps + updateImageSteps, + getOpenAIKey, + updateOpenAIKey } from '$lib/apis/images'; import { getBackendConfig } from '$lib/apis'; const dispatch = createEventDispatcher(); @@ -23,8 +25,11 @@ let loading = false; + let imageGenerationEngine = ''; let enableImageGeneration = false; + let AUTOMATIC1111_BASE_URL = ''; + let OPENAI_API_KEY = ''; let selectedModel = ''; let models = null; @@ -33,11 +38,11 @@ let steps = 50; const getModels = async () => { - models = await getDiffusionModels(localStorage.token).catch((error) => { + models = await getImageGenerationModels(localStorage.token).catch((error) => { toast.error(error); return null; }); - selectedModel = await getDefaultDiffusionModel(localStorage.token).catch((error) => { + selectedModel = await getDefaultImageGenerationModel(localStorage.token).catch((error) => { return ''; }); }; @@ -62,33 +67,45 @@ AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token); } }; - const toggleImageGeneration = async () => { - if (AUTOMATIC1111_BASE_URL) { - enableImageGeneration = await toggleImageGenerationEnabledStatus(localStorage.token).catch( - (error) => { - toast.error(error); - return false; - } - ); + const updateImageGeneration = async () => { + const res = await updateImageGenerationConfig( + localStorage.token, + imageGenerationEngine, + enableImageGeneration + ).catch((error) => { + toast.error(error); + return null; + }); - if (enableImageGeneration) { - config.set(await getBackendConfig(localStorage.token)); - getModels(); - } - } else { - enableImageGeneration = false; - toast.error('AUTOMATIC1111_BASE_URL not provided'); + if (res) { + imageGenerationEngine = res.engine; + enableImageGeneration = res.enabled; + } + + if (enableImageGeneration) { + config.set(await getBackendConfig(localStorage.token)); + getModels(); } }; onMount(async () => { if ($user.role === 'admin') { - enableImageGeneration = await getImageGenerationEnabledStatus(localStorage.token); - AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token); + const res = await getImageGenerationConfig(localStorage.token).catch((error) => { + toast.error(error); + return null; + }); - if (enableImageGeneration && AUTOMATIC1111_BASE_URL) { - imageSize = await getImageSize(localStorage.token); - steps = await getImageSteps(localStorage.token); + if (res) { + imageGenerationEngine = res.engine; + enableImageGeneration = res.enabled; + } + AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token); + OPENAI_API_KEY = await getOpenAIKey(localStorage.token); + + imageSize = await getImageSize(localStorage.token); + steps = await getImageSteps(localStorage.token); + + if (enableImageGeneration) { getModels(); } } @@ -99,7 +116,11 @@ class="flex flex-col h-full justify-between space-y-3 text-sm" on:submit|preventDefault={async () => { loading = true; - await updateDefaultDiffusionModel(localStorage.token, selectedModel); + await updateOpenAIKey(localStorage.token, OPENAI_API_KEY); + + await updateDefaultImageGenerationModel(localStorage.token, selectedModel); + + await updateDefaultImageGenerationModel(localStorage.token, selectedModel); await updateImageSize(localStorage.token, imageSize).catch((error) => { toast.error(error); return null; @@ -117,6 +138,23 @@