diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py new file mode 100644 index 00000000..c80903c1 --- /dev/null +++ b/backend/apps/images/main.py @@ -0,0 +1,148 @@ +import os +import requests +from fastapi import ( + FastAPI, + Request, + Depends, + HTTPException, + status, + UploadFile, + File, + Form, +) +from fastapi.middleware.cors import CORSMiddleware +from faster_whisper import WhisperModel + +from constants import ERROR_MESSAGES +from utils.utils import ( + get_current_user, + get_admin_user, +) +from utils.misc import calculate_sha256 +from typing import Optional +from pydantic import BaseModel +from config import AUTOMATIC1111_BASE_URL + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL +app.state.ENABLED = False + + +@app.get("/enabled", response_model=bool) +async def get_enable_status(request: Request, user=Depends(get_admin_user)): + return app.state.ENABLED + + +@app.get("/enabled/toggle", response_model=bool) +async def toggle_enabled(request: Request, user=Depends(get_admin_user)): + app.state.ENABLED = not app.state.ENABLED + return app.state.ENABLED + + +class UrlUpdateForm(BaseModel): + url: str + + +@app.get("/url") +async def get_openai_url(user=Depends(get_admin_user)): + return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL} + + +@app.post("/url/update") +async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): + try: + r = requests.head(form_data.url) + if r.ok: + app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/") + return { + "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, + "status": True, + } + except Exception as e: + raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e)) + + +@app.get("/models") +def get_models(user=Depends(get_current_user)): + try: + r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models") + models = r.json() + return models + except Exception as e: + raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e)) + + +@app.get("/models/default") +async def get_default_model(user=Depends(get_admin_user)): + try: + r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") + options = r.json() + + return {"model": options["sd_model_checkpoint"]} + except Exception as e: + raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e)) + + +class UpdateModelForm(BaseModel): + model: str + + +def set_model_handler(model: str): + r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") + options = r.json() + + if model != options["sd_model_checkpoint"]: + options["sd_model_checkpoint"] = model + r = requests.post( + url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options + ) + + return options + + +@app.post("/models/default/update") +def update_default_model( + form_data: UpdateModelForm, + user=Depends(get_current_user), +): + return set_model_handler(form_data.model) + + +class GenerateImageForm(BaseModel): + model: Optional[str] = None + prompt: str + n: int = 1 + size: str = "512x512" + negative_prompt: Optional[str] = None + + +@app.post("/generations") +def generate_image( + form_data: GenerateImageForm, + user=Depends(get_current_user), +): + if form_data.model: + set_model_handler(form_data.model) + + width, height = tuple(map(int, form_data.size.split("x"))) + + r = requests.get( + url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", + json={ + "prompt": form_data.prompt, + "negative_prompt": form_data.negative_prompt, + "batch_size": form_data.n, + "width": width, + "height": height, + }, + ) + + return r.json() diff --git a/backend/config.py b/backend/config.py index 8167d4f1..caf2cc45 100644 --- a/backend/config.py +++ b/backend/config.py @@ -185,3 +185,10 @@ Query: [query]""" WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base") WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models") + + +#################################### +# Images +#################################### + +AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "") diff --git a/backend/main.py b/backend/main.py index 3a28670e..d1fb0c20 100644 --- a/backend/main.py +++ b/backend/main.py @@ -11,10 +11,10 @@ from starlette.exceptions import HTTPException as StarletteHTTPException from apps.ollama.main import app as ollama_app from apps.openai.main import app as openai_app from apps.audio.main import app as audio_app - +from apps.images.main import app as images_app +from apps.rag.main import app as rag_app from apps.web.main import app as webui_app -from apps.rag.main import app as rag_app from config import ENV, FRONTEND_BUILD_DIR @@ -58,10 +58,21 @@ app.mount("/api/v1", webui_app) app.mount("/ollama/api", ollama_app) app.mount("/openai/api", openai_app) +app.mount("/images/api/v1", images_app) app.mount("/audio/api/v1", audio_app) app.mount("/rag/api/v1", rag_app) +@app.get("/api/config") +async def get_app_config(): + return { + "status": True, + "images": images_app.state.ENABLED, + "default_models": webui_app.state.DEFAULT_MODELS, + "default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS, + } + + app.mount( "/", SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True), diff --git a/src/lib/apis/images/index.ts b/src/lib/apis/images/index.ts new file mode 100644 index 00000000..63bc04a9 --- /dev/null +++ b/src/lib/apis/images/index.ts @@ -0,0 +1,167 @@ +import { IMAGES_API_BASE_URL } from '$lib/constants'; + +export const getAUTOMATIC1111Url = async (token: string = '') => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/url`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.AUTOMATIC1111_BASE_URL; +}; + +export const updateAUTOMATIC1111Url = async (token: string = '', url: string) => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/url/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + url: url + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.AUTOMATIC1111_BASE_URL; +}; + +export const getDiffusionModels = async (token: string = '') => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/models`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getDefaultDiffusionModel = async (token: string = '') => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/models/default`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.model; +}; + +export const updateDefaultDiffusionModel = async (token: string = '', model: string) => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/models/default/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + model: model + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.model; +}; diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 91512166..c20107ce 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -1,9 +1,9 @@ -import { WEBUI_API_BASE_URL } from '$lib/constants'; +import { WEBUI_BASE_URL } from '$lib/constants'; export const getBackendConfig = async () => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/`, { + const res = await fetch(`${WEBUI_BASE_URL}/api/config`, { method: 'GET', headers: { 'Content-Type': 'application/json' diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte index ca64575d..182a66b3 100644 --- a/src/lib/components/chat/Messages/ResponseMessage.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage.svelte @@ -2,7 +2,7 @@ import toast from 'svelte-french-toast'; import dayjs from 'dayjs'; import { marked } from 'marked'; - import { settings } from '$lib/stores'; + import { config, settings } from '$lib/stores'; import tippy from 'tippy.js'; import auto_render from 'katex/dist/contrib/auto-render.mjs'; import 'katex/dist/katex.min.css'; @@ -595,6 +595,32 @@ {/if} + {#if $config.images} + + {/if} + {#if message.info} + + + +
+ +
AUTOMATIC1111 Base URL
+
+
+ +
+ +
+ +
+ Include `--api` flag when running stable-diffusion-webui + + (e.g. `sh webui.sh --api`) + +
+ + {#if enableImageGeneration} +
+ +
+
Set default model
+
+
+ +
+
+
+ {/if} + + +
+ +
+ diff --git a/src/lib/components/chat/SettingsModal.svelte b/src/lib/components/chat/SettingsModal.svelte index 66ea4784..9d631f16 100644 --- a/src/lib/components/chat/SettingsModal.svelte +++ b/src/lib/components/chat/SettingsModal.svelte @@ -14,6 +14,7 @@ import Audio from './Settings/Audio.svelte'; import Chats from './Settings/Chats.svelte'; import Connections from './Settings/Connections.svelte'; + import Images from './Settings/Images.svelte'; export let show = false; @@ -206,31 +207,33 @@
Audio
- + {#if $user.role === 'admin'} + + {/if}