feat: sd frontend integration

This commit is contained in:
Timothy J. Baek 2024-02-21 18:36:40 -08:00
parent 733e963c44
commit cc50cc10e6
7 changed files with 218 additions and 29 deletions

View file

@ -5,6 +5,8 @@ OLLAMA_API_BASE_URL='http://localhost:11434/api'
OPENAI_API_BASE_URL=''
OPENAI_API_KEY=''
# AUTOMATIC1111_BASE_URL="http://localhost:7860"
# DO NOT TRACK
SCARF_NO_ANALYTICS=true
DO_NOT_TRACK=true

View file

@ -283,7 +283,7 @@ git clone https://github.com/open-webui/open-webui.git
cd open-webui/
# Copying required .env file
cp -RPp example.env .env
cp -RPp .env.example .env
# Building Frontend Using Node
npm i

View file

@ -33,7 +33,7 @@ app.add_middleware(
)
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.ENABLED = False
app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
@app.get("/enabled", response_model=bool)
@ -129,20 +129,33 @@ def generate_image(
form_data: GenerateImageForm,
user=Depends(get_current_user),
):
if form_data.model:
set_model_handler(form_data.model)
width, height = tuple(map(int, form_data.size.split("x")))
print(form_data)
r = requests.get(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json={
try:
if form_data.model:
set_model_handler(form_data.model)
width, height = tuple(map(int, form_data.size.split("x")))
data = {
"prompt": form_data.prompt,
"negative_prompt": form_data.negative_prompt,
"batch_size": form_data.n,
"width": width,
"height": height,
},
)
}
return r.json()
if form_data.negative_prompt != None:
data["negative_prompt"] = form_data.negative_prompt
print(data)
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data,
)
return r.json()
except Exception as e:
print(e)
raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))

View file

@ -1,5 +1,69 @@
import { IMAGES_API_BASE_URL } from '$lib/constants';
export const getImageGenerationEnabledStatus = async (token: string = '') => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/enabled`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res;
};
export const toggleImageGenerationEnabledStatus = async (token: string = '') => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/enabled/toggle`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res;
};
export const getAUTOMATIC1111Url = async (token: string = '') => {
let error = null;
@ -165,3 +229,38 @@ export const updateDefaultDiffusionModel = async (token: string = '', model: str
return res.model;
};
export const imageGenerations = async (token: string = '', prompt: string) => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/generations`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify({
prompt: prompt
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res;
};

View file

@ -11,6 +11,7 @@
import ResponseMessage from './Messages/ResponseMessage.svelte';
import Placeholder from './Messages/Placeholder.svelte';
import Spinner from '../common/Spinner.svelte';
import { imageGenerations } from '$lib/apis/images';
export let chatId = '';
export let sendPrompt: Function;

View file

@ -16,6 +16,7 @@
import { synthesizeOpenAISpeech } from '$lib/apis/openai';
import { extractSentences } from '$lib/utils';
import { imageGenerations } from '$lib/apis/images';
export let modelfiles = [];
export let message;
@ -43,6 +44,8 @@
let loadingSpeech = false;
let generatingImage = false;
$: tokens = marked.lexer(message.content);
const renderer = new marked.Renderer();
@ -267,6 +270,21 @@
renderStyling();
};
const generateImage = async (message) => {
generatingImage = true;
const res = await imageGenerations(localStorage.token, message.content);
console.log(res);
if (res) {
message.files = res.images.map((image) => ({
type: 'image',
url: `data:image/png;base64,${image}`
}));
}
generatingImage = false;
};
onMount(async () => {
await tick();
renderStyling();
@ -295,6 +313,18 @@
{#if message.content === ''}
<Skeleton />
{:else}
{#if message.files}
<div class="my-2.5 w-full flex overflow-x-auto gap-2 flex-wrap">
{#each message.files as file}
<div>
{#if file.type === 'image'}
<img src={file.url} alt="input" class=" max-h-96 rounded-lg" draggable="false" />
{/if}
</div>
{/each}
</div>
{/if}
<div
class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-headings:my-0 prose-p:m-0 prose-p:-mb-6 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-img:my-0 prose-ul:-my-4 prose-ol:-my-4 prose-li:-my-3 prose-ul:-mb-6 prose-ol:-mb-8 prose-li:-mb-4 whitespace-pre-line"
>
@ -601,23 +631,62 @@
? 'visible'
: 'invisible group-hover:visible'} p-1 rounded dark:hover:text-white hover:text-black transition"
on:click={() => {
// generateImage
if (!generatingImage) {
generateImage(message);
}
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="w-4 h-4"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="m2.25 15.75 5.159-5.159a2.25 2.25 0 0 1 3.182 0l5.159 5.159m-1.5-1.5 1.409-1.409a2.25 2.25 0 0 1 3.182 0l2.909 2.909m-18 3.75h16.5a1.5 1.5 0 0 0 1.5-1.5V6a1.5 1.5 0 0 0-1.5-1.5H3.75A1.5 1.5 0 0 0 2.25 6v12a1.5 1.5 0 0 0 1.5 1.5Zm10.5-11.25h.008v.008h-.008V8.25Zm.375 0a.375.375 0 1 1-.75 0 .375.375 0 0 1 .75 0Z"
/>
</svg>
{#if generatingImage}
<svg
class=" w-4 h-4"
fill="currentColor"
viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg"
><style>
.spinner_S1WN {
animation: spinner_MGfb 0.8s linear infinite;
animation-delay: -0.8s;
}
.spinner_Km9P {
animation-delay: -0.65s;
}
.spinner_JApP {
animation-delay: -0.5s;
}
@keyframes spinner_MGfb {
93.75%,
100% {
opacity: 0.2;
}
}
</style><circle class="spinner_S1WN" cx="4" cy="12" r="3" /><circle
class="spinner_S1WN spinner_Km9P"
cx="12"
cy="12"
r="3"
/><circle
class="spinner_S1WN spinner_JApP"
cx="20"
cy="12"
r="3"
/></svg
>
{:else}
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="w-4 h-4"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="m2.25 15.75 5.159-5.159a2.25 2.25 0 0 1 3.182 0l5.159 5.159m-1.5-1.5 1.409-1.409a2.25 2.25 0 0 1 3.182 0l2.909 2.909m-18 3.75h16.5a1.5 1.5 0 0 0 1.5-1.5V6a1.5 1.5 0 0 0-1.5-1.5H3.75A1.5 1.5 0 0 0 2.25 6v12a1.5 1.5 0 0 0 1.5 1.5Zm10.5-11.25h.008v.008h-.008V8.25Zm.375 0a.375.375 0 1 1-.75 0 .375.375 0 0 1 .75 0Z"
/>
</svg>
{/if}
</button>
{/if}

View file

@ -2,14 +2,17 @@
import toast from 'svelte-french-toast';
import { createEventDispatcher, onMount } from 'svelte';
import { user } from '$lib/stores';
import { config, user } from '$lib/stores';
import {
getAUTOMATIC1111Url,
getDefaultDiffusionModel,
getDiffusionModels,
getImageGenerationEnabledStatus,
toggleImageGenerationEnabledStatus,
updateAUTOMATIC1111Url,
updateDefaultDiffusionModel
} from '$lib/apis/images';
import { getBackendConfig } from '$lib/apis';
const dispatch = createEventDispatcher();
export let saveSettings: Function;
@ -42,11 +45,13 @@
};
const toggleImageGeneration = async () => {
enableImageGeneration = !enableImageGeneration;
enableImageGeneration = await toggleImageGenerationEnabledStatus(localStorage.token);
config.set(await getBackendConfig(localStorage.token));
};
onMount(async () => {
if ($user.role === 'admin') {
enableImageGeneration = await getImageGenerationEnabledStatus(localStorage.token);
AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
if (AUTOMATIC1111_BASE_URL) {