diff --git a/src/lib/apis/images/index.ts b/src/lib/apis/images/index.ts index f05ce0b7..220fb7d9 100644 --- a/src/lib/apis/images/index.ts +++ b/src/lib/apis/images/index.ts @@ -1,9 +1,9 @@ import { IMAGES_API_BASE_URL } from '$lib/constants'; -export const getImageGenerationEnabledStatus = async (token: string = '') => { +export const getImageGenerationConfig = async (token: string = '') => { let error = null; - const res = await fetch(`${IMAGES_API_BASE_URL}/enabled`, { + const res = await fetch(`${IMAGES_API_BASE_URL}/config`, { method: 'GET', headers: { Accept: 'application/json', @@ -32,16 +32,24 @@ export const getImageGenerationEnabledStatus = async (token: string = '') => { return res; }; -export const toggleImageGenerationEnabledStatus = async (token: string = '') => { +export const updateImageGenerationConfig = async ( + token: string = '', + engine: string, + enabled: boolean +) => { let error = null; - const res = await fetch(`${IMAGES_API_BASE_URL}/enabled/toggle`, { - method: 'GET', + const res = await fetch(`${IMAGES_API_BASE_URL}/config/update`, { + method: 'POST', headers: { Accept: 'application/json', 'Content-Type': 'application/json', ...(token && { authorization: `Bearer ${token}` }) - } + }, + body: JSON.stringify({ + engine, + enabled + }) }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -263,7 +271,7 @@ export const updateImageSteps = async (token: string = '', steps: number) => { return res.IMAGE_STEPS; }; -export const getDiffusionModels = async (token: string = '') => { +export const getImageGenerationModels = async (token: string = '') => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/models`, { @@ -295,7 +303,7 @@ export const getDiffusionModels = async (token: string = '') => { return res; }; -export const getDefaultDiffusionModel = async (token: string = '') => { +export const getDefaultImageGenerationModel = async (token: string = '') => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/models/default`, { @@ -327,7 +335,7 @@ export const getDefaultDiffusionModel = async (token: string = '') => { return res.model; }; -export const updateDefaultDiffusionModel = async (token: string = '', model: string) => { +export const updateDefaultImageGenerationModel = async (token: string = '', model: string) => { let error = null; const res = await fetch(`${IMAGES_API_BASE_URL}/models/default/update`, { diff --git a/src/lib/components/chat/Settings/Images.svelte b/src/lib/components/chat/Settings/Images.svelte index b36eae12..cced3c43 100644 --- a/src/lib/components/chat/Settings/Images.svelte +++ b/src/lib/components/chat/Settings/Images.svelte @@ -5,13 +5,13 @@ import { config, user } from '$lib/stores'; import { getAUTOMATIC1111Url, - getDefaultDiffusionModel, - getDiffusionModels, - getImageGenerationEnabledStatus, + getImageGenerationModels, + getDefaultImageGenerationModel, + updateDefaultImageGenerationModel, getImageSize, - toggleImageGenerationEnabledStatus, + getImageGenerationConfig, + updateImageGenerationConfig, updateAUTOMATIC1111Url, - updateDefaultDiffusionModel, updateImageSize, getImageSteps, updateImageSteps @@ -23,7 +23,9 @@ let loading = false; + let imageGenerationEngine = ''; let enableImageGeneration = false; + let AUTOMATIC1111_BASE_URL = ''; let selectedModel = ''; @@ -33,11 +35,11 @@ let steps = 50; const getModels = async () => { - models = await getDiffusionModels(localStorage.token).catch((error) => { + models = await getImageGenerationModels(localStorage.token).catch((error) => { toast.error(error); return null; }); - selectedModel = await getDefaultDiffusionModel(localStorage.token).catch((error) => { + selectedModel = await getDefaultImageGenerationModel(localStorage.token).catch((error) => { return ''; }); }; @@ -62,33 +64,44 @@ AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token); } }; - const toggleImageGeneration = async () => { - if (AUTOMATIC1111_BASE_URL) { - enableImageGeneration = await toggleImageGenerationEnabledStatus(localStorage.token).catch( - (error) => { - toast.error(error); - return false; - } - ); + const updateImageGeneration = async () => { + const res = await updateImageGenerationConfig( + localStorage.token, + imageGenerationEngine, + enableImageGeneration + ).catch((error) => { + toast.error(error); + return null; + }); - if (enableImageGeneration) { - config.set(await getBackendConfig(localStorage.token)); - getModels(); - } - } else { - enableImageGeneration = false; - toast.error('AUTOMATIC1111_BASE_URL not provided'); + if (res) { + imageGenerationEngine = res.engine; + enableImageGeneration = res.enabled; + } + + if (enableImageGeneration) { + config.set(await getBackendConfig(localStorage.token)); + getModels(); } }; onMount(async () => { if ($user.role === 'admin') { - enableImageGeneration = await getImageGenerationEnabledStatus(localStorage.token); + const res = await getImageGenerationConfig(localStorage.token).catch((error) => { + toast.error(error); + return null; + }); + + if (res) { + imageGenerationEngine = res.engine; + enableImageGeneration = res.enabled; + } AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token); - if (enableImageGeneration && AUTOMATIC1111_BASE_URL) { - imageSize = await getImageSize(localStorage.token); - steps = await getImageSteps(localStorage.token); + imageSize = await getImageSize(localStorage.token); + steps = await getImageSteps(localStorage.token); + + if (enableImageGeneration) { getModels(); } } @@ -99,7 +112,7 @@ class="flex flex-col h-full justify-between space-y-3 text-sm" on:submit|preventDefault={async () => { loading = true; - await updateDefaultDiffusionModel(localStorage.token, selectedModel); + await updateDefaultImageGenerationModel(localStorage.token, selectedModel); await updateImageSize(localStorage.token, imageSize).catch((error) => { toast.error(error); return null; @@ -117,6 +130,23 @@