feat: chat playground backend integration

This commit is contained in:
Timothy J. Baek 2024-03-02 18:16:02 -08:00
parent 656f8dab05
commit 901e7a33fa
5 changed files with 259 additions and 81 deletions

View file

@ -11,7 +11,7 @@ from pydantic import BaseModel
from apps.web.models.users import Users from apps.web.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user, get_admin_user from utils.utils import decode_token, get_current_user, get_admin_user
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH from config import OLLAMA_BASE_URL, WEBUI_AUTH
app = FastAPI() app = FastAPI()
app.add_middleware( app.add_middleware(
@ -22,7 +22,7 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL app.state.OLLAMA_BASE_URL = OLLAMA_BASE_URL
# TARGET_SERVER_URL = OLLAMA_API_BASE_URL # TARGET_SERVER_URL = OLLAMA_API_BASE_URL
@ -32,7 +32,7 @@ REQUEST_POOL = []
@app.get("/url") @app.get("/url")
async def get_ollama_api_url(user=Depends(get_admin_user)): async def get_ollama_api_url(user=Depends(get_admin_user)):
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} return {"OLLAMA_BASE_URL": app.state.OLLAMA_BASE_URL}
class UrlUpdateForm(BaseModel): class UrlUpdateForm(BaseModel):
@ -41,8 +41,8 @@ class UrlUpdateForm(BaseModel):
@app.post("/url/update") @app.post("/url/update")
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
app.state.OLLAMA_API_BASE_URL = form_data.url app.state.OLLAMA_BASE_URL = form_data.url
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} return {"OLLAMA_BASE_URL": app.state.OLLAMA_BASE_URL}
@app.get("/cancel/{request_id}") @app.get("/cancel/{request_id}")
@ -57,7 +57,7 @@ async def cancel_ollama_request(request_id: str, user=Depends(get_current_user))
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)): async def proxy(path: str, request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" target_url = f"{app.state.OLLAMA_BASE_URL}/{path}"
body = await request.body() body = await request.body()
headers = dict(request.headers) headers = dict(request.headers)

View file

@ -211,6 +211,17 @@ if ENV == "prod":
if OLLAMA_API_BASE_URL == "/ollama/api": if OLLAMA_API_BASE_URL == "/ollama/api":
OLLAMA_API_BASE_URL = "http://host.docker.internal:11434/api" OLLAMA_API_BASE_URL = "http://host.docker.internal:11434/api"
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
if OLLAMA_BASE_URL == "":
OLLAMA_BASE_URL = (
OLLAMA_API_BASE_URL[:-4]
if OLLAMA_API_BASE_URL.endswith("/api")
else OLLAMA_API_BASE_URL
)
#################################### ####################################
# OPENAI_API # OPENAI_API
#################################### ####################################

View file

@ -29,7 +29,7 @@ export const getOllamaAPIUrl = async (token: string = '') => {
throw error; throw error;
} }
return res.OLLAMA_API_BASE_URL; return res.OLLAMA_BASE_URL;
}; };
export const updateOllamaAPIUrl = async (token: string = '', url: string) => { export const updateOllamaAPIUrl = async (token: string = '', url: string) => {
@ -64,13 +64,13 @@ export const updateOllamaAPIUrl = async (token: string = '', url: string) => {
throw error; throw error;
} }
return res.OLLAMA_API_BASE_URL; return res.OLLAMA_BASE_URL;
}; };
export const getOllamaVersion = async (token: string = '') => { export const getOllamaVersion = async (token: string = '') => {
let error = null; let error = null;
const res = await fetch(`${OLLAMA_API_BASE_URL}/version`, { const res = await fetch(`${OLLAMA_API_BASE_URL}/api/version`, {
method: 'GET', method: 'GET',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -102,7 +102,7 @@ export const getOllamaVersion = async (token: string = '') => {
export const getOllamaModels = async (token: string = '') => { export const getOllamaModels = async (token: string = '') => {
let error = null; let error = null;
const res = await fetch(`${OLLAMA_API_BASE_URL}/tags`, { const res = await fetch(`${OLLAMA_API_BASE_URL}/api/tags`, {
method: 'GET', method: 'GET',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -148,7 +148,7 @@ export const generateTitle = async (
console.log(template); console.log(template);
const res = await fetch(`${OLLAMA_API_BASE_URL}/generate`, { const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',
@ -186,7 +186,7 @@ export const generatePrompt = async (token: string = '', model: string, conversa
conversation = '[no existing conversation]'; conversation = '[no existing conversation]';
} }
const res = await fetch(`${OLLAMA_API_BASE_URL}/generate`, { const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',
@ -220,7 +220,7 @@ export const generatePrompt = async (token: string = '', model: string, conversa
export const generateTextCompletion = async (token: string = '', model: string, text: string) => { export const generateTextCompletion = async (token: string = '', model: string, text: string) => {
let error = null; let error = null;
const res = await fetch(`${OLLAMA_API_BASE_URL}/generate`, { const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',
@ -247,7 +247,7 @@ export const generateChatCompletion = async (token: string = '', body: object) =
let controller = new AbortController(); let controller = new AbortController();
let error = null; let error = null;
const res = await fetch(`${OLLAMA_API_BASE_URL}/chat`, { const res = await fetch(`${OLLAMA_API_BASE_URL}/api/chat`, {
signal: controller.signal, signal: controller.signal,
method: 'POST', method: 'POST',
headers: { headers: {
@ -291,7 +291,7 @@ export const cancelChatCompletion = async (token: string = '', requestId: string
export const createModel = async (token: string, tagName: string, content: string) => { export const createModel = async (token: string, tagName: string, content: string) => {
let error = null; let error = null;
const res = await fetch(`${OLLAMA_API_BASE_URL}/create`, { const res = await fetch(`${OLLAMA_API_BASE_URL}/api/create`, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',
@ -316,7 +316,7 @@ export const createModel = async (token: string, tagName: string, content: strin
export const deleteModel = async (token: string, tagName: string) => { export const deleteModel = async (token: string, tagName: string) => {
let error = null; let error = null;
const res = await fetch(`${OLLAMA_API_BASE_URL}/delete`, { const res = await fetch(`${OLLAMA_API_BASE_URL}/api/delete`, {
method: 'DELETE', method: 'DELETE',
headers: { headers: {
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',
@ -350,7 +350,7 @@ export const deleteModel = async (token: string, tagName: string) => {
export const pullModel = async (token: string, tagName: string) => { export const pullModel = async (token: string, tagName: string) => {
let error = null; let error = null;
const res = await fetch(`${OLLAMA_API_BASE_URL}/pull`, { const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull`, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',

View file

@ -114,12 +114,12 @@
<hr class=" dark:border-gray-700" /> <hr class=" dark:border-gray-700" />
<div> <div>
<div class=" mb-2.5 text-sm font-medium">Ollama API URL</div> <div class=" mb-2.5 text-sm font-medium">Ollama Base URL</div>
<div class="flex w-full"> <div class="flex w-full">
<div class="flex-1 mr-2"> <div class="flex-1 mr-2">
<input <input
class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none" class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
placeholder="Enter URL (e.g. http://localhost:11434/api)" placeholder="Enter URL (e.g. http://localhost:11434)"
bind:value={API_BASE_URL} bind:value={API_BASE_URL}
/> />
</div> </div>

View file

@ -1,14 +1,21 @@
<script> <script>
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { onMount } from 'svelte'; import { onMount, tick } from 'svelte';
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
import { WEBUI_API_BASE_URL } from '$lib/constants'; import {
LITELLM_API_BASE_URL,
OLLAMA_API_BASE_URL,
OPENAI_API_BASE_URL,
WEBUI_API_BASE_URL
} from '$lib/constants';
import { WEBUI_NAME, config, user, models, settings } from '$lib/stores'; import { WEBUI_NAME, config, user, models, settings } from '$lib/stores';
import { cancelChatCompletion, generateChatCompletion } from '$lib/apis/ollama'; import { cancelChatCompletion, generateChatCompletion } from '$lib/apis/ollama';
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
import { splitStream } from '$lib/utils'; import { splitStream } from '$lib/utils';
let mode = 'chat'; let mode = 'chat';
@ -16,18 +23,28 @@
let text = ''; let text = '';
let selectedModel = ''; let selectedModelId = '';
let loading = false; let loading = false;
let currentRequestId; let currentRequestId;
let stopResponseFlag = false; let stopResponseFlag = false;
let system = ''; let system = '';
let messages = []; let messages = [
{
role: 'user',
content: ''
}
];
const scrollToBottom = () => { const scrollToBottom = () => {
const element = document.getElementById('text-completion-textarea'); // const element = document.getElementById('text-completion-textarea');
element.scrollTop = element.scrollHeight;
const element = document.getElementById('messages-container');
if (element) {
element.scrollTop = element?.scrollHeight;
}
}; };
// const cancelHandler = async () => { // const cancelHandler = async () => {
@ -43,67 +60,216 @@
console.log('stopResponse'); console.log('stopResponse');
}; };
const submitHandler = async () => { const textCompletionHandler = async () => {
if (selectedModel) { const [res, controller] = await generateChatCompletion(localStorage.token, {
loading = true; model: selectedModelId,
messages: [
{
role: 'assistant',
content: text
}
]
});
const [res, controller] = await generateChatCompletion(localStorage.token, { if (res && res.ok) {
model: selectedModel, const reader = res.body
messages: [ .pipeThrough(new TextDecoderStream())
{ .pipeThrough(splitStream('\n'))
role: 'assistant', .getReader();
content: text
}
]
});
if (res && res.ok) { while (true) {
const reader = res.body const { value, done } = await reader.read();
.pipeThrough(new TextDecoderStream()) if (done || stopResponseFlag) {
.pipeThrough(splitStream('\n')) if (stopResponseFlag) {
.getReader(); await cancelChatCompletion(localStorage.token, currentRequestId);
while (true) {
const { value, done } = await reader.read();
if (done || stopResponseFlag) {
if (stopResponseFlag) {
await cancelChatCompletion(localStorage.token, currentRequestId);
}
currentRequestId = null;
break;
} }
try { currentRequestId = null;
let lines = value.split('\n'); break;
}
for (const line of lines) { try {
if (line !== '') { let lines = value.split('\n');
console.log(line);
let data = JSON.parse(line);
if ('detail' in data) { for (const line of lines) {
throw data; if (line !== '') {
} console.log(line);
let data = JSON.parse(line);
if ('id' in data) { if ('detail' in data) {
console.log(data); throw data;
currentRequestId = data.id; }
if ('id' in data) {
console.log(data);
currentRequestId = data.id;
} else {
if (data.done == false) {
text += data.message.content;
} else { } else {
if (data.done == false) { console.log('done');
text += data.message.content;
} else {
console.log('done');
}
} }
} }
} }
} catch (error) {
console.log(error);
} }
} catch (error) {
scrollToBottom(); console.log(error);
} }
scrollToBottom();
}
}
};
const chatCompletionHandler = async () => {
const model = $models.find((model) => model.id === selectedModelId);
const res = await generateOpenAIChatCompletion(
localStorage.token,
{
model: model.id,
stream: true,
messages: [
system
? {
role: 'system',
content: system
}
: undefined,
...messages
].filter((message) => message)
},
model.external
? model.source === 'litellm'
? `${LITELLM_API_BASE_URL}/v1`
: `${OPENAI_API_BASE_URL}`
: `${OLLAMA_API_BASE_URL}/v1`
);
// const [res, controller] = await generateChatCompletion(localStorage.token, {
// model: selectedModelId,
// messages: [
// {
// role: 'assistant',
// content: text
// }
// ]
// });
let responseMessage;
if (messages.at(-1)?.role === 'assistant') {
responseMessage = messages.at(-1);
} else {
responseMessage = {
role: 'assistant',
content: ''
};
messages.push(responseMessage);
messages = messages;
}
await tick();
const textareaElement = document.getElementById(`assistant-${messages.length - 1}-textarea`);
if (res && res.ok) {
const reader = res.body
.pipeThrough(new TextDecoderStream())
.pipeThrough(splitStream('\n'))
.getReader();
while (true) {
const { value, done } = await reader.read();
if (done || stopResponseFlag) {
break;
}
try {
let lines = value.split('\n');
for (const line of lines) {
if (line !== '') {
console.log(line);
if (line === 'data: [DONE]') {
// responseMessage.done = true;
messages = messages;
} else {
let data = JSON.parse(line.replace(/^data: /, ''));
console.log(data);
if (responseMessage.content == '' && data.choices[0].delta.content == '\n') {
continue;
} else {
textareaElement.style.height = textareaElement.scrollHeight + 'px';
responseMessage.content += data.choices[0].delta.content ?? '';
messages = messages;
textareaElement.style.height = textareaElement.scrollHeight + 'px';
await tick();
}
}
}
}
} catch (error) {
console.log(error);
}
scrollToBottom();
}
// while (true) {
// const { value, done } = await reader.read();
// if (done || stopResponseFlag) {
// if (stopResponseFlag) {
// await cancelChatCompletion(localStorage.token, currentRequestId);
// }
// currentRequestId = null;
// break;
// }
// try {
// let lines = value.split('\n');
// for (const line of lines) {
// if (line !== '') {
// console.log(line);
// let data = JSON.parse(line);
// if ('detail' in data) {
// throw data;
// }
// if ('id' in data) {
// console.log(data);
// currentRequestId = data.id;
// } else {
// if (data.done == false) {
// text += data.message.content;
// } else {
// console.log('done');
// }
// }
// }
// }
// } catch (error) {
// console.log(error);
// }
// scrollToBottom();
// }
}
};
const submitHandler = async () => {
if (selectedModelId) {
loading = true;
if (mode === 'complete') {
await textCompletionHandler();
} else if (mode === 'chat') {
await chatCompletionHandler();
} }
loading = false; loading = false;
@ -118,11 +284,11 @@
} }
if ($settings?.models) { if ($settings?.models) {
selectedModel = $settings?.models[0]; selectedModelId = $settings?.models[0];
} else if ($config?.default_models) { } else if ($config?.default_models) {
selectedModel = $config?.default_models.split(',')[0]; selectedModelId = $config?.default_models.split(',')[0];
} else { } else {
selectedModel = ''; selectedModelId = '';
} }
loaded = true; loaded = true;
}); });
@ -185,7 +351,7 @@
<select <select
id="models" id="models"
class="outline-none bg-transparent text-sm font-medium rounded-lg w-full placeholder-gray-400" class="outline-none bg-transparent text-sm font-medium rounded-lg w-full placeholder-gray-400"
bind:value={selectedModel} bind:value={selectedModelId}
> >
<option class=" text-gray-800" value="" selected disabled>Select a model</option> <option class=" text-gray-800" value="" selected disabled>Select a model</option>
@ -234,10 +400,11 @@
<div class="p-3 outline outline-1 outline-gray-200 dark:outline-gray-800 rounded-lg"> <div class="p-3 outline outline-1 outline-gray-200 dark:outline-gray-800 rounded-lg">
<div class=" text-sm font-medium">System</div> <div class=" text-sm font-medium">System</div>
<textarea <textarea
id="text-completion-textarea" id="system-textarea"
class="w-full h-full bg-transparent resize-none outline-none text-sm" class="w-full h-full bg-transparent resize-none outline-none text-sm"
bind:value={system} bind:value={system}
placeholder="You're a helpful assistant." placeholder="You're a helpful assistant."
rows="4"
/> />
</div> </div>
</div> </div>
@ -271,8 +438,8 @@
<div class="flex-1"> <div class="flex-1">
<textarea <textarea
id="text-completion-textarea" id="{message.role}-{idx}-textarea"
class="w-full bg-transparent outline-none rounded-lg p-2 text-sm resize-none" class="w-full bg-transparent outline-none rounded-lg p-2 text-sm resize-none overflow-hidden"
placeholder="Enter {message.role === 'user' placeholder="Enter {message.role === 'user'
? 'a user' ? 'a user'
: 'an assistant'} message here" : 'an assistant'} message here"
@ -320,7 +487,7 @@
{/each} {/each}
<button <button
class="flex items-center gap-2" class="flex items-center gap-2 px-2 py-1"
on:click={() => { on:click={() => {
console.log(messages.at(-1)); console.log(messages.at(-1));
messages.push({ messages.push({