feat: sd backend integration

This commit is contained in:
Timothy J. Baek 2024-02-21 18:12:01 -08:00
parent 7a730c3f0f
commit 733e963c44
11 changed files with 611 additions and 33 deletions

148
backend/apps/images/main.py Normal file
View file

@ -0,0 +1,148 @@
import os
import requests
from fastapi import (
FastAPI,
Request,
Depends,
HTTPException,
status,
UploadFile,
File,
Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from constants import ERROR_MESSAGES
from utils.utils import (
get_current_user,
get_admin_user,
)
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
from config import AUTOMATIC1111_BASE_URL
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.ENABLED = False
@app.get("/enabled", response_model=bool)
async def get_enable_status(request: Request, user=Depends(get_admin_user)):
return app.state.ENABLED
@app.get("/enabled/toggle", response_model=bool)
async def toggle_enabled(request: Request, user=Depends(get_admin_user)):
app.state.ENABLED = not app.state.ENABLED
return app.state.ENABLED
class UrlUpdateForm(BaseModel):
url: str
@app.get("/url")
async def get_openai_url(user=Depends(get_admin_user)):
return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
@app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
try:
r = requests.head(form_data.url)
if r.ok:
app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/")
return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
"status": True,
}
except Exception as e:
raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
@app.get("/models")
def get_models(user=Depends(get_current_user)):
try:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models")
models = r.json()
return models
except Exception as e:
raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
try:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
return {"model": options["sd_model_checkpoint"]}
except Exception as e:
raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
class UpdateModelForm(BaseModel):
model: str
def set_model_handler(model: str):
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
)
return options
@app.post("/models/default/update")
def update_default_model(
form_data: UpdateModelForm,
user=Depends(get_current_user),
):
return set_model_handler(form_data.model)
class GenerateImageForm(BaseModel):
model: Optional[str] = None
prompt: str
n: int = 1
size: str = "512x512"
negative_prompt: Optional[str] = None
@app.post("/generations")
def generate_image(
form_data: GenerateImageForm,
user=Depends(get_current_user),
):
if form_data.model:
set_model_handler(form_data.model)
width, height = tuple(map(int, form_data.size.split("x")))
r = requests.get(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json={
"prompt": form_data.prompt,
"negative_prompt": form_data.negative_prompt,
"batch_size": form_data.n,
"width": width,
"height": height,
},
)
return r.json()

View file

@ -185,3 +185,10 @@ Query: [query]"""
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base") WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models") WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
####################################
# Images
####################################
AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")

View file

@ -11,10 +11,10 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from apps.ollama.main import app as ollama_app from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app from apps.openai.main import app as openai_app
from apps.audio.main import app as audio_app from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.web.main import app as webui_app from apps.web.main import app as webui_app
from apps.rag.main import app as rag_app
from config import ENV, FRONTEND_BUILD_DIR from config import ENV, FRONTEND_BUILD_DIR
@ -58,10 +58,21 @@ app.mount("/api/v1", webui_app)
app.mount("/ollama/api", ollama_app) app.mount("/ollama/api", ollama_app)
app.mount("/openai/api", openai_app) app.mount("/openai/api", openai_app)
app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app) app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app) app.mount("/rag/api/v1", rag_app)
@app.get("/api/config")
async def get_app_config():
return {
"status": True,
"images": images_app.state.ENABLED,
"default_models": webui_app.state.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS,
}
app.mount( app.mount(
"/", "/",
SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True), SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),

View file

@ -0,0 +1,167 @@
import { IMAGES_API_BASE_URL } from '$lib/constants';
export const getAUTOMATIC1111Url = async (token: string = '') => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/url`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res.AUTOMATIC1111_BASE_URL;
};
export const updateAUTOMATIC1111Url = async (token: string = '', url: string) => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/url/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify({
url: url
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res.AUTOMATIC1111_BASE_URL;
};
export const getDiffusionModels = async (token: string = '') => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/models`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res;
};
export const getDefaultDiffusionModel = async (token: string = '') => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/models/default`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res.model;
};
export const updateDefaultDiffusionModel = async (token: string = '', model: string) => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/models/default/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify({
model: model
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res.model;
};

View file

@ -1,9 +1,9 @@
import { WEBUI_API_BASE_URL } from '$lib/constants'; import { WEBUI_BASE_URL } from '$lib/constants';
export const getBackendConfig = async () => { export const getBackendConfig = async () => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/`, { const res = await fetch(`${WEBUI_BASE_URL}/api/config`, {
method: 'GET', method: 'GET',
headers: { headers: {
'Content-Type': 'application/json' 'Content-Type': 'application/json'

View file

@ -2,7 +2,7 @@
import toast from 'svelte-french-toast'; import toast from 'svelte-french-toast';
import dayjs from 'dayjs'; import dayjs from 'dayjs';
import { marked } from 'marked'; import { marked } from 'marked';
import { settings } from '$lib/stores'; import { config, settings } from '$lib/stores';
import tippy from 'tippy.js'; import tippy from 'tippy.js';
import auto_render from 'katex/dist/contrib/auto-render.mjs'; import auto_render from 'katex/dist/contrib/auto-render.mjs';
import 'katex/dist/katex.min.css'; import 'katex/dist/katex.min.css';
@ -595,6 +595,32 @@
{/if} {/if}
</button> </button>
{#if $config.images}
<button
class="{isLastMessage
? 'visible'
: 'invisible group-hover:visible'} p-1 rounded dark:hover:text-white hover:text-black transition"
on:click={() => {
// generateImage
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="w-4 h-4"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="m2.25 15.75 5.159-5.159a2.25 2.25 0 0 1 3.182 0l5.159 5.159m-1.5-1.5 1.409-1.409a2.25 2.25 0 0 1 3.182 0l2.909 2.909m-18 3.75h16.5a1.5 1.5 0 0 0 1.5-1.5V6a1.5 1.5 0 0 0-1.5-1.5H3.75A1.5 1.5 0 0 0 2.25 6v12a1.5 1.5 0 0 0 1.5 1.5Zm10.5-11.25h.008v.008h-.008V8.25Zm.375 0a.375.375 0 1 1-.75 0 .375.375 0 0 1 .75 0Z"
/>
</svg>
</button>
{/if}
{#if message.info} {#if message.info}
<button <button
class=" {isLastMessage class=" {isLastMessage

View file

@ -18,7 +18,7 @@
<div class=" mb-2.5 text-sm font-medium">{WEBUI_NAME} Version</div> <div class=" mb-2.5 text-sm font-medium">{WEBUI_NAME} Version</div>
<div class="flex w-full"> <div class="flex w-full">
<div class="flex-1 text-xs text-gray-700 dark:text-gray-200"> <div class="flex-1 text-xs text-gray-700 dark:text-gray-200">
{$config && $config.version ? $config.version : WEB_UI_VERSION} {WEB_UI_VERSION}
</div> </div>
</div> </div>
</div> </div>

View file

@ -68,6 +68,7 @@
on:click={() => { on:click={() => {
updateOllamaAPIUrlHandler(); updateOllamaAPIUrlHandler();
}} }}
type="button"
> >
<svg <svg
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"

View file

@ -0,0 +1,206 @@
<script lang="ts">
import toast from 'svelte-french-toast';
import { createEventDispatcher, onMount } from 'svelte';
import { user } from '$lib/stores';
import {
getAUTOMATIC1111Url,
getDefaultDiffusionModel,
getDiffusionModels,
updateAUTOMATIC1111Url,
updateDefaultDiffusionModel
} from '$lib/apis/images';
const dispatch = createEventDispatcher();
export let saveSettings: Function;
let loading = false;
let enableImageGeneration = true;
let AUTOMATIC1111_BASE_URL = '';
let selectedModel = '';
let models = [];
const updateAUTOMATIC1111UrlHandler = async () => {
const res = await updateAUTOMATIC1111Url(localStorage.token, AUTOMATIC1111_BASE_URL).catch(
(error) => {
toast.error(error);
return null;
}
);
if (res) {
toast.success('Server connection verified');
AUTOMATIC1111_BASE_URL = res;
models = await getDiffusionModels(localStorage.token);
selectedModel = await getDefaultDiffusionModel(localStorage.token);
} else {
AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
}
};
const toggleImageGeneration = async () => {
enableImageGeneration = !enableImageGeneration;
};
onMount(async () => {
if ($user.role === 'admin') {
AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
if (AUTOMATIC1111_BASE_URL) {
models = await getDiffusionModels(localStorage.token);
selectedModel = await getDefaultDiffusionModel(localStorage.token);
}
}
});
</script>
<form
class="flex flex-col h-full justify-between space-y-3 text-sm"
on:submit|preventDefault={async () => {
loading = true;
const res = await updateDefaultDiffusionModel(localStorage.token, selectedModel);
dispatch('save');
loading = false;
}}
>
<div class=" space-y-3 pr-1.5 overflow-y-scroll max-h-80">
<div>
<div class=" mb-1 text-sm font-medium">Image Settings</div>
<div>
<div class=" py-0.5 flex w-full justify-between">
<div class=" self-center text-xs font-medium">Image Generation (Experimental)</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
on:click={() => {
toggleImageGeneration();
}}
type="button"
>
{#if enableImageGeneration === true}
<span class="ml-2 self-center">On</span>
{:else}
<span class="ml-2 self-center">Off</span>
{/if}
</button>
</div>
</div>
</div>
<hr class=" dark:border-gray-700" />
<div class=" mb-2.5 text-sm font-medium">AUTOMATIC1111 Base URL</div>
<div class="flex w-full">
<div class="flex-1 mr-2">
<input
class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
placeholder="Enter URL (e.g. http://127.0.0.1:7860/)"
bind:value={AUTOMATIC1111_BASE_URL}
/>
</div>
<button
class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-600 dark:hover:bg-gray-700 rounded transition"
type="button"
on:click={() => {
// updateOllamaAPIUrlHandler();
updateAUTOMATIC1111UrlHandler();
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M15.312 11.424a5.5 5.5 0 01-9.201 2.466l-.312-.311h2.433a.75.75 0 000-1.5H3.989a.75.75 0 00-.75.75v4.242a.75.75 0 001.5 0v-2.43l.31.31a7 7 0 0011.712-3.138.75.75 0 00-1.449-.39zm1.23-3.723a.75.75 0 00.219-.53V2.929a.75.75 0 00-1.5 0V5.36l-.31-.31A7 7 0 003.239 8.188a.75.75 0 101.448.389A5.5 5.5 0 0113.89 6.11l.311.31h-2.432a.75.75 0 000 1.5h4.243a.75.75 0 00.53-.219z"
clip-rule="evenodd"
/>
</svg>
</button>
</div>
<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
Include `--api` flag when running stable-diffusion-webui
<a
class=" text-gray-300 font-medium"
href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/3734"
target="_blank"
>
(e.g. `sh webui.sh --api`)
</a>
</div>
{#if enableImageGeneration}
<hr class=" dark:border-gray-700" />
<div>
<div class=" mb-2.5 text-sm font-medium">Set default model</div>
<div class="flex w-full">
<div class="flex-1 mr-2">
<select
class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
bind:value={selectedModel}
placeholder="Select a model"
>
{#if !selectedModel}
<option value="" disabled selected>Select a model</option>
{/if}
{#each models as model}
<option value={model.title} class="bg-gray-100 dark:bg-gray-700"
>{model.model_name}</option
>
{/each}
</select>
</div>
</div>
</div>
{/if}
</div>
<div class="flex justify-end pt-3 text-sm font-medium">
<button
class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded flex flex-row space-x-1 items-center {loading
? ' cursor-not-allowed'
: ''}"
type="submit"
disabled={loading}
>
Save
{#if loading}
<div class="ml-2 self-center">
<svg
class=" w-4 h-4"
viewBox="0 0 24 24"
fill="currentColor"
xmlns="http://www.w3.org/2000/svg"
><style>
.spinner_ajPY {
transform-origin: center;
animation: spinner_AtaB 0.75s infinite linear;
}
@keyframes spinner_AtaB {
100% {
transform: rotate(360deg);
}
}
</style><path
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
opacity=".25"
/><path
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
class="spinner_ajPY"
/></svg
>
</div>
{/if}
</button>
</div>
</form>

View file

@ -14,6 +14,7 @@
import Audio from './Settings/Audio.svelte'; import Audio from './Settings/Audio.svelte';
import Chats from './Settings/Chats.svelte'; import Chats from './Settings/Chats.svelte';
import Connections from './Settings/Connections.svelte'; import Connections from './Settings/Connections.svelte';
import Images from './Settings/Images.svelte';
export let show = false; export let show = false;
@ -206,13 +207,14 @@
<div class=" self-center">Audio</div> <div class=" self-center">Audio</div>
</button> </button>
{#if $user.role === 'admin'}
<button <button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab === class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'image' 'images'
? 'bg-gray-200 dark:bg-gray-700' ? 'bg-gray-200 dark:bg-gray-700'
: ' hover:bg-gray-300 dark:hover:bg-gray-800'}" : ' hover:bg-gray-300 dark:hover:bg-gray-800'}"
on:click={() => { on:click={() => {
selectedTab = 'image'; selectedTab = 'images';
}} }}
> >
<div class=" self-center mr-2"> <div class=" self-center mr-2">
@ -229,8 +231,9 @@
/> />
</svg> </svg>
</div> </div>
<div class=" self-center">Image</div> <div class=" self-center">Images</div>
</button> </button>
{/if}
<button <button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab === class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
@ -342,6 +345,13 @@
show = false; show = false;
}} }}
/> />
{:else if selectedTab === 'images'}
<Images
{saveSettings}
on:save={() => {
show = false;
}}
/>
{:else if selectedTab === 'chats'} {:else if selectedTab === 'chats'}
<Chats {saveSettings} /> <Chats {saveSettings} />
{:else if selectedTab === 'account'} {:else if selectedTab === 'account'}

View file

@ -1,4 +1,5 @@
import { dev } from '$app/environment'; import { dev } from '$app/environment';
// import { version } from '../../package.json';
export const WEBUI_NAME = 'Open WebUI'; export const WEBUI_NAME = 'Open WebUI';
export const WEBUI_BASE_URL = dev ? `http://${location.hostname}:8080` : ``; export const WEBUI_BASE_URL = dev ? `http://${location.hostname}:8080` : ``;
@ -6,10 +7,11 @@ export const WEBUI_BASE_URL = dev ? `http://${location.hostname}:8080` : ``;
export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`; export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`;
export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`; export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`;
export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`; export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`;
export const RAG_API_BASE_URL = `${WEBUI_BASE_URL}/rag/api/v1`;
export const AUDIO_API_BASE_URL = `${WEBUI_BASE_URL}/audio/api/v1`; export const AUDIO_API_BASE_URL = `${WEBUI_BASE_URL}/audio/api/v1`;
export const IMAGES_API_BASE_URL = `${WEBUI_BASE_URL}/images/api/v1`;
export const RAG_API_BASE_URL = `${WEBUI_BASE_URL}/rag/api/v1`;
export const WEB_UI_VERSION = 'v1.0.0-alpha-static'; export const WEB_UI_VERSION = APP_VERSION;
export const REQUIRED_OLLAMA_VERSION = '0.1.16'; export const REQUIRED_OLLAMA_VERSION = '0.1.16';