From f04d60b6d954ee4e68805240f19ea044028e4dba Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 5 Mar 2024 00:59:35 -0800 Subject: [PATCH 1/5] feat: multiple ollama support --- backend/apps/ollama/main.py | 728 +++++++++++++++++- backend/main.py | 8 + src/lib/apis/ollama/index.ts | 42 +- .../chat/Settings/Connections.svelte | 115 ++- 4 files changed, 831 insertions(+), 62 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 7965ff32..5b3589e0 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -3,16 +3,23 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from fastapi.concurrency import run_in_threadpool +from pydantic import BaseModel + +import random import requests import json import uuid -from pydantic import BaseModel +import aiohttp +import asyncio from apps.web.models.users import Users from constants import ERROR_MESSAGES from utils.utils import decode_token, get_current_user, get_admin_user from config import OLLAMA_BASE_URL, WEBUI_AUTH +from typing import Optional, List, Union + + app = FastAPI() app.add_middleware( CORSMiddleware, @@ -23,26 +30,39 @@ app.add_middleware( ) app.state.OLLAMA_BASE_URL = OLLAMA_BASE_URL - -# TARGET_SERVER_URL = OLLAMA_API_BASE_URL +app.state.OLLAMA_BASE_URLS = [OLLAMA_BASE_URL] +app.state.MODELS = {} REQUEST_POOL = [] -@app.get("/url") -async def get_ollama_api_url(user=Depends(get_admin_user)): - return {"OLLAMA_BASE_URL": app.state.OLLAMA_BASE_URL} +@app.middleware("http") +async def check_url(request: Request, call_next): + if len(app.state.MODELS) == 0: + await get_all_models() + else: + pass + + response = await call_next(request) + return response + + +@app.get("/urls") +async def get_ollama_api_urls(user=Depends(get_admin_user)): + return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} class UrlUpdateForm(BaseModel): - url: str + urls: List[str] -@app.post("/url/update") +@app.post("/urls/update") async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): - app.state.OLLAMA_BASE_URL = form_data.url - return {"OLLAMA_BASE_URL": app.state.OLLAMA_BASE_URL} + app.state.OLLAMA_BASE_URLS = form_data.urls + + print(app.state.OLLAMA_BASE_URLS) + return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} @app.get("/cancel/{request_id}") @@ -55,9 +75,695 @@ async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)) raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) +async def fetch_url(url): + try: + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + return await response.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + return None + + +def merge_models_lists(model_lists): + merged_models = {} + + for idx, model_list in enumerate(model_lists): + for model in model_list: + digest = model["digest"] + if digest not in merged_models: + model["urls"] = [idx] + merged_models[digest] = model + else: + merged_models[digest]["urls"].append(idx) + + return list(merged_models.values()) + + +# user=Depends(get_current_user) + + +async def get_all_models(): + print("get_all_models") + tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS] + responses = await asyncio.gather(*tasks) + responses = list(filter(lambda x: x is not None, responses)) + + models = { + "models": merge_models_lists( + map(lambda response: response["models"], responses) + ) + } + 
app.state.MODELS = {model["model"]: model for model in models["models"]} + + return models + + +@app.get("/api/tags") +@app.get("/api/tags/{url_idx}") +async def get_ollama_tags( + url_idx: Optional[int] = None, user=Depends(get_current_user) +): + + if url_idx == None: + return await get_all_models() + else: + url = app.state.OLLAMA_BASE_URLS[url_idx] + try: + r = requests.request(method="GET", url=f"{url}/api/tags") + r.raise_for_status() + + return r.json() + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +@app.get("/api/version") +@app.get("/api/version/{url_idx}") +async def get_ollama_versions(url_idx: Optional[int] = None): + + if url_idx == None: + + # returns lowest version + tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS] + responses = await asyncio.gather(*tasks) + responses = list(filter(lambda x: x is not None, responses)) + + lowest_version = min( + responses, key=lambda x: tuple(map(int, x["version"].split("."))) + ) + + return {"version": lowest_version["version"]} + else: + url = app.state.OLLAMA_BASE_URLS[url_idx] + try: + r = requests.request(method="GET", url=f"{url}/api/version") + r.raise_for_status() + + return r.json() + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class ModelNameForm(BaseModel): + name: str + + +@app.post("/api/pull") +@app.post("/api/pull/{url_idx}") +async def pull_model( + form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) +): + url = app.state.OLLAMA_BASE_URLS[url_idx] + r = None + + def get_request(url): + nonlocal r + try: + + def stream_content(): + for chunk in r.iter_content(chunk_size=8192): + yield chunk + + r = requests.request( + method="POST", + url=f"{url}/api/pull", + data=form_data.model_dump_json(exclude_none=True), + stream=True, + ) + + r.raise_for_status() + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request(url)) + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class PushModelForm(BaseModel): + name: str + insecure: Optional[bool] = None + stream: Optional[bool] = None + + +@app.delete("/api/push") +@app.delete("/api/push/{url_idx}") +async def push_model( + form_data: PushModelForm, + url_idx: Optional[int] = None, + user=Depends(get_admin_user), +): + if url_idx == None: + if form_data.name in app.state.MODELS: + url_idx = app.state.MODELS[form_data.name]["urls"][0] + else: + raise HTTPException( + status_code=400, + detail="error_detail", + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + + r = None + + def get_request(): + nonlocal url + nonlocal r + try: + + def 
stream_content(): + for chunk in r.iter_content(chunk_size=8192): + yield chunk + + r = requests.request( + method="POST", + url=f"{url}/api/push", + data=form_data.model_dump_json(exclude_none=True), + ) + + r.raise_for_status() + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class CreateModelForm(BaseModel): + name: str + modelfile: Optional[str] = None + stream: Optional[bool] = None + path: Optional[str] = None + + +@app.post("/api/create") +@app.post("/api/create/{url_idx}") +async def create_model( + form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) +): + print(form_data) + url = app.state.OLLAMA_BASE_URLS[url_idx] + r = None + + def get_request(): + nonlocal url + nonlocal r + try: + + def stream_content(): + for chunk in r.iter_content(chunk_size=8192): + yield chunk + + r = requests.request( + method="POST", + url=f"{url}/api/create", + data=form_data.model_dump_json(exclude_none=True), + stream=True, + ) + + r.raise_for_status() + + print(r) + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class CopyModelForm(BaseModel): + source: str + destination: str + + +@app.post("/api/copy") +@app.post("/api/copy/{url_idx}") +async def copy_model( + form_data: CopyModelForm, + url_idx: Optional[int] = None, + user=Depends(get_admin_user), +): + if url_idx == None: + if form_data.source in app.state.MODELS: + url_idx = app.state.MODELS[form_data.source]["urls"][0] + else: + raise HTTPException( + status_code=400, + detail="error_detail", + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + + try: + r = requests.request( + method="POST", + url=f"{url}/api/copy", + data=form_data.model_dump_json(exclude_none=True), + ) + r.raise_for_status() + + print(r.text) + + return True + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +@app.delete("/api/delete") +@app.delete("/api/delete/{url_idx}") +async def delete_model( + form_data: ModelNameForm, + url_idx: Optional[int] = None, + user=Depends(get_admin_user), +): + if url_idx == None: + if form_data.name in app.state.MODELS: + url_idx = app.state.MODELS[form_data.name]["urls"][0] + else: + raise HTTPException( + status_code=400, + detail="error_detail", + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + + try: + r = requests.request( + method="DELETE", + 
url=f"{url}/api/delete", + data=form_data.model_dump_json(exclude_none=True), + ) + r.raise_for_status() + + print(r.text) + + return True + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +@app.post("/api/show") +async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_user)): + if form_data.name not in app.state.MODELS: + raise HTTPException( + status_code=400, + detail="error_detail", + ) + + url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) + url = app.state.OLLAMA_BASE_URLS[url_idx] + + try: + r = requests.request( + method="POST", + url=f"{url}/api/show", + data=form_data.model_dump_json(exclude_none=True), + ) + r.raise_for_status() + + return r.json() + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class GenerateEmbeddingsForm(BaseModel): + model: str + prompt: str + options: Optional[dict] = None + keep_alive: Optional[Union[int, str]] = None + + +@app.post("/api/embeddings") +@app.post("/api/embeddings/{url_idx}") +async def generate_embeddings( + form_data: GenerateEmbeddingsForm, + url_idx: Optional[int] = None, + user=Depends(get_current_user), +): + if url_idx == None: + if form_data.model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[form_data.model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail="error_detail", + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + + try: + r = requests.request( + method="POST", + url=f"{url}/api/embeddings", + data=form_data.model_dump_json(exclude_none=True), + ) + r.raise_for_status() + + return r.json() + except Exception as e: + print(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class GenerateCompletionForm(BaseModel): + model: str + prompt: str + images: Optional[List[str]] = None + format: Optional[str] = None + options: Optional[dict] = None + system: Optional[str] = None + template: Optional[str] = None + context: Optional[str] = None + stream: Optional[bool] = True + raw: Optional[bool] = None + keep_alive: Optional[Union[int, str]] = None + + +@app.post("/api/generate") +@app.post("/api/generate/{url_idx}") +async def generate_completion( + form_data: GenerateCompletionForm, + url_idx: Optional[int] = None, + user=Depends(get_current_user), +): + + if url_idx == None: + if form_data.model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[form_data.model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail="error_detail", + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + + r = None + + def get_request(): + nonlocal form_data + nonlocal r + + request_id = str(uuid.uuid4()) + try: + REQUEST_POOL.append(request_id) + + def stream_content(): + try: + if form_data.stream: + yield json.dumps({"id": 
request_id, "done": False}) + "\n" + + for chunk in r.iter_content(chunk_size=8192): + if request_id in REQUEST_POOL: + yield chunk + else: + print("User: canceled request") + break + finally: + if hasattr(r, "close"): + r.close() + if request_id in REQUEST_POOL: + REQUEST_POOL.remove(request_id) + + r = requests.request( + method="POST", + url=f"{url}/api/generate", + data=form_data.model_dump_json(exclude_none=True), + stream=True, + ) + + r.raise_for_status() + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + +class ChatMessage(BaseModel): + role: str + content: str + images: Optional[List[str]] = None + + +class GenerateChatCompletionForm(BaseModel): + model: str + messages: List[ChatMessage] + format: Optional[str] = None + options: Optional[dict] = None + template: Optional[str] = None + stream: Optional[bool] = True + keep_alive: Optional[Union[int, str]] = None + + +@app.post("/api/chat") +@app.post("/api/chat/{url_idx}") +async def generate_completion( + form_data: GenerateChatCompletionForm, + url_idx: Optional[int] = None, + user=Depends(get_current_user), +): + + if url_idx == None: + if form_data.model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[form_data.model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail="error_detail", + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + + r = None + + print(form_data.model_dump_json(exclude_none=True)) + + def get_request(): + nonlocal form_data + nonlocal r + + request_id = str(uuid.uuid4()) + try: + REQUEST_POOL.append(request_id) + + def stream_content(): + try: + if form_data.stream: + yield json.dumps({"id": request_id, "done": False}) + "\n" + + for chunk in r.iter_content(chunk_size=8192): + if request_id in REQUEST_POOL: + yield chunk + else: + print("User: canceled request") + break + finally: + if hasattr(r, "close"): + r.close() + if request_id in REQUEST_POOL: + REQUEST_POOL.remove(request_id) + + r = requests.request( + method="POST", + url=f"{url}/api/chat", + data=form_data.model_dump_json(exclude_none=True), + stream=True, + ) + + r.raise_for_status() + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def proxy(path: str, request: Request, user=Depends(get_current_user)): - target_url = f"{app.state.OLLAMA_BASE_URL}/{path}" + url = app.state.OLLAMA_BASE_URLS[0] + target_url = f"{url}/{path}" body = await request.body() headers = dict(request.headers) diff --git a/backend/main.py b/backend/main.py index 9c71091e..5f6b4441 100644 --- a/backend/main.py +++ 
b/backend/main.py @@ -125,6 +125,14 @@ async def get_app_config(): } +@app.get("/api/version") +async def get_app_config(): + + return { + "version": VERSION, + } + + @app.get("/api/changelog") async def get_app_changelog(): return CHANGELOG diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index 0c96b2ab..5887b11f 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -1,9 +1,9 @@ import { OLLAMA_API_BASE_URL } from '$lib/constants'; -export const getOllamaAPIUrl = async (token: string = '') => { +export const getOllamaUrls = async (token: string = '') => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/url`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/urls`, { method: 'GET', headers: { Accept: 'application/json', @@ -29,13 +29,13 @@ export const getOllamaAPIUrl = async (token: string = '') => { throw error; } - return res.OLLAMA_BASE_URL; + return res.OLLAMA_BASE_URLS; }; -export const updateOllamaAPIUrl = async (token: string = '', url: string) => { +export const updateOllamaUrls = async (token: string = '', urls: string[]) => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/url/update`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/urls/update`, { method: 'POST', headers: { Accept: 'application/json', @@ -43,7 +43,7 @@ export const updateOllamaAPIUrl = async (token: string = '', url: string) => { ...(token && { authorization: `Bearer ${token}` }) }, body: JSON.stringify({ - url: url + urls: urls }) }) .then(async (res) => { @@ -64,7 +64,7 @@ export const updateOllamaAPIUrl = async (token: string = '', url: string) => { throw error; } - return res.OLLAMA_BASE_URL; + return res.OLLAMA_BASE_URLS; }; export const getOllamaVersion = async (token: string = '') => { @@ -151,7 +151,8 @@ export const generateTitle = async ( const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, { method: 'POST', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ @@ -189,7 +190,8 @@ export const generatePrompt = async (token: string = '', model: string, conversa const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, { method: 'POST', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ @@ -223,7 +225,8 @@ export const generateTextCompletion = async (token: string = '', model: string, const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, { method: 'POST', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ @@ -251,7 +254,8 @@ export const generateChatCompletion = async (token: string = '', body: object) = signal: controller.signal, method: 'POST', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify(body) @@ -294,7 +298,8 @@ export const createModel = async (token: string, tagName: string, content: strin const res = await fetch(`${OLLAMA_API_BASE_URL}/api/create`, { method: 'POST', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ @@ -319,7 +324,8 @@ export const deleteModel 
= async (token: string, tagName: string) => { const res = await fetch(`${OLLAMA_API_BASE_URL}/api/delete`, { method: 'DELETE', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ @@ -336,7 +342,12 @@ export const deleteModel = async (token: string, tagName: string) => { }) .catch((err) => { console.log(err); - error = err.error; + error = err; + + if ('detail' in err) { + error = err.detail; + } + return null; }); @@ -353,7 +364,8 @@ export const pullModel = async (token: string, tagName: string) => { const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull`, { method: 'POST', headers: { - 'Content-Type': 'text/event-stream', + Accept: 'application/json', + 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ diff --git a/src/lib/components/chat/Settings/Connections.svelte b/src/lib/components/chat/Settings/Connections.svelte index 55c38358..ed9ddf4d 100644 --- a/src/lib/components/chat/Settings/Connections.svelte +++ b/src/lib/components/chat/Settings/Connections.svelte @@ -3,14 +3,15 @@ import { createEventDispatcher, onMount } from 'svelte'; const dispatch = createEventDispatcher(); - import { getOllamaAPIUrl, getOllamaVersion, updateOllamaAPIUrl } from '$lib/apis/ollama'; + import { getOllamaUrls, getOllamaVersion, updateOllamaUrls } from '$lib/apis/ollama'; import { getOpenAIKey, getOpenAIUrl, updateOpenAIKey, updateOpenAIUrl } from '$lib/apis/openai'; import { toast } from 'svelte-sonner'; export let getModels: Function; // External - let API_BASE_URL = ''; + let OLLAMA_BASE_URL = ''; + let OLLAMA_BASE_URLS = ['']; let OPENAI_API_KEY = ''; let OPENAI_API_BASE_URL = ''; @@ -25,8 +26,8 @@ await models.set(await getModels()); }; - const updateOllamaAPIUrlHandler = async () => { - API_BASE_URL = await updateOllamaAPIUrl(localStorage.token, API_BASE_URL); + const updateOllamaUrlsHandler = async () => { + OLLAMA_BASE_URLS = await updateOllamaUrls(localStorage.token, OLLAMA_BASE_URLS); const ollamaVersion = await getOllamaVersion(localStorage.token).catch((error) => { toast.error(error); @@ -41,7 +42,7 @@ onMount(async () => { if ($user.role === 'admin') { - API_BASE_URL = await getOllamaAPIUrl(localStorage.token); + OLLAMA_BASE_URLS = await getOllamaUrls(localStorage.token); OPENAI_API_BASE_URL = await getOpenAIUrl(localStorage.token); OPENAI_API_KEY = await getOpenAIKey(localStorage.token); } @@ -53,11 +54,6 @@ on:submit|preventDefault={() => { updateOpenAIHandler(); dispatch('save'); - - // saveSettings({ - // OPENAI_API_KEY: OPENAI_API_KEY !== '' ? OPENAI_API_KEY : undefined, - // OPENAI_API_BASE_URL: OPENAI_API_BASE_URL !== '' ? OPENAI_API_BASE_URL : undefined - // }); }} >
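The settings handlers above now drive a list of URLs rather than a single one. A minimal sketch of the reshaped admin API they call: GET /urls reads the current list and POST /urls/update replaces it wholesale, so the UI always sends the full array. The /ollama mount point, host, and admin token below are illustrative assumptions.

    import requests

    BASE = "http://localhost:8080/ollama"                # assumed Open WebUI mount point
    HEADERS = {"Authorization": "Bearer <admin-token>"}  # both endpoints require an admin user

    # Read the current list of Ollama backends.
    urls = requests.get(f"{BASE}/urls", headers=HEADERS).json()["OLLAMA_BASE_URLS"]

    # Append a hypothetical second backend and replace the list wholesale.
    urls.append("http://gpu-box:11434")
    res = requests.post(f"{BASE}/urls/update", json={"urls": urls}, headers=HEADERS)
    print(res.json()["OLLAMA_BASE_URLS"])

The corresponding markup changes follow in the next hunk.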
@@ -115,34 +111,81 @@
Ollama Base URL
					[Markup for this hunk was garbled in extraction. Recoverable
					structure: the single-URL input block is removed and replaced by an
					{#each OLLAMA_BASE_URLS as url, idx} loop with one URL text input
					per row, an add-URL button on the first row ({#if idx === 0}), and
					a remove-URL button on the others ({:else}), closed by {/each}; a
					further control presumably triggers updateOllamaUrlsHandler(). The
					exact Svelte markup is not recoverable.]
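Taken together, the heart of this first patch is the fan-out/merge in get_all_models: every backend's /api/tags is fetched concurrently and the results are deduplicated by digest, remembering which backend indexes serve each model. Re-stated standalone below, with hypothetical sample data:

    def merge_models_lists(model_lists):
        merged_models = {}
        # model_lists[idx] is the "models" array from backend idx's /api/tags.
        for idx, model_list in enumerate(model_lists):
            for model in model_list:
                digest = model["digest"]
                if digest not in merged_models:
                    # First sighting: remember which backend serves it.
                    model["urls"] = [idx]
                    merged_models[digest] = model
                else:
                    # Same digest on another backend: record the extra index.
                    merged_models[digest]["urls"].append(idx)
        return list(merged_models.values())

    tags_a = [{"model": "llama2:latest", "digest": "abc123"}]
    tags_b = [
        {"model": "llama2:latest", "digest": "abc123"},
        {"model": "mistral:latest", "digest": "def456"},
    ]

    for m in merge_models_lists([tags_a, tags_b]):
        print(m["model"], "->", m["urls"])
    # llama2:latest -> [0, 1]
    # mistral:latest -> [1]

get_all_models then keys app.state.MODELS by model name, and the generate, chat, and embeddings routes pick one of a model's "urls" with random.choice.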
From 8fd7bc342d142ff673e96bdf6bc87e43f591c0b8 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 5 Mar 2024 01:07:59 -0800 Subject: [PATCH 2/5] refac: error message --- backend/apps/ollama/main.py | 12 ++++++------ backend/constants.py | 2 ++ src/routes/(app)/modelfiles/+page.svelte | 5 ++++- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 5b3589e0..5bd48462 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -267,7 +267,7 @@ async def push_model( else: raise HTTPException( status_code=400, - detail="error_detail", + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) url = app.state.OLLAMA_BASE_URLS[url_idx] @@ -399,7 +399,7 @@ async def copy_model( else: raise HTTPException( status_code=400, - detail="error_detail", + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), ) url = app.state.OLLAMA_BASE_URLS[url_idx] @@ -445,7 +445,7 @@ async def delete_model( else: raise HTTPException( status_code=400, - detail="error_detail", + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) url = app.state.OLLAMA_BASE_URLS[url_idx] @@ -483,7 +483,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use if form_data.name not in app.state.MODELS: raise HTTPException( status_code=400, - detail="error_detail", + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) @@ -535,7 +535,7 @@ async def generate_embeddings( else: raise HTTPException( status_code=400, - detail="error_detail", + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) url = app.state.OLLAMA_BASE_URLS[url_idx] @@ -691,7 +691,7 @@ async def generate_completion( else: raise HTTPException( status_code=400, - detail="error_detail", + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) url = app.state.OLLAMA_BASE_URLS[url_idx] diff --git a/backend/constants.py b/backend/constants.py index 006fa7bb..b2bbe9aa 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -48,3 +48,5 @@ class ERROR_MESSAGES(str, Enum): lambda err="": f"Invalid format. Please use the correct format{err if err else ''}" ) RATE_LIMIT_EXCEEDED = "API rate limit exceeded" + + MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found" diff --git a/src/routes/(app)/modelfiles/+page.svelte b/src/routes/(app)/modelfiles/+page.svelte index e3ee91a9..a80b6014 100644 --- a/src/routes/(app)/modelfiles/+page.svelte +++ b/src/routes/(app)/modelfiles/+page.svelte @@ -20,7 +20,10 @@ const deleteModelHandler = async (tagName) => { let success = null; - success = await deleteModel(localStorage.token, tagName); + success = await deleteModel(localStorage.token, tagName).catch((err) => { + toast.error(err); + return null; + }); if (success) { toast.success(`Deleted ${tagName}`); From b2dd2f191dbecd6b77db2348f6d13f36e48b771e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 5 Mar 2024 01:21:50 -0800 Subject: [PATCH 3/5] refac: comment --- backend/apps/ollama/main.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 5bd48462..fbaf622b 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -37,6 +37,11 @@ app.state.MODELS = {} REQUEST_POOL = [] +# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. 
+# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin, +# least connections, or least response time for better resource utilization and performance optimization. + + @app.middleware("http") async def check_url(request: Request, call_next): if len(app.state.MODELS) == 0: @@ -761,7 +766,7 @@ async def generate_completion( @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) -async def proxy(path: str, request: Request, user=Depends(get_current_user)): +async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)): url = app.state.OLLAMA_BASE_URLS[0] target_url = f"{url}/{path}" From 8d34324d1211715d94338ade5d7694662ade67be Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 5 Mar 2024 01:41:22 -0800 Subject: [PATCH 4/5] feat: cancel request Resolves #1006 --- backend/apps/ollama/main.py | 102 ++++++++++++++++++++++- src/routes/(app)/playground/+page.svelte | 87 ++++++------------- 2 files changed, 126 insertions(+), 63 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index fbaf622b..988d4bf4 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -3,7 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from fastapi.concurrency import run_in_threadpool -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict import random import requests @@ -684,7 +684,7 @@ class GenerateChatCompletionForm(BaseModel): @app.post("/api/chat") @app.post("/api/chat/{url_idx}") -async def generate_completion( +async def generate_chat_completion( form_data: GenerateChatCompletionForm, url_idx: Optional[int] = None, user=Depends(get_current_user), @@ -765,6 +765,104 @@ async def generate_completion( ) +# TODO: we should update this part once Ollama supports other types +class OpenAIChatMessage(BaseModel): + role: str + content: str + + model_config = ConfigDict(extra="allow") + + +class OpenAIChatCompletionForm(BaseModel): + model: str + messages: List[OpenAIChatMessage] + + model_config = ConfigDict(extra="allow") + + +@app.post("/v1/chat/completions") +@app.post("/v1/chat/completions/{url_idx}") +async def generate_openai_chat_completion( + form_data: OpenAIChatCompletionForm, + url_idx: Optional[int] = None, + user=Depends(get_current_user), +): + + if url_idx == None: + if form_data.model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[form_data.model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + + r = None + + def get_request(): + nonlocal form_data + nonlocal r + + request_id = str(uuid.uuid4()) + try: + REQUEST_POOL.append(request_id) + + def stream_content(): + try: + if form_data.stream: + yield json.dumps( + {"request_id": request_id, "done": False} + ) + "\n" + + for chunk in r.iter_content(chunk_size=8192): + if request_id in REQUEST_POOL: + yield chunk + else: + print("User: canceled request") + break + finally: + if hasattr(r, "close"): + r.close() + if request_id in REQUEST_POOL: + REQUEST_POOL.remove(request_id) + + r = requests.request( + method="POST", + url=f"{url}/v1/chat/completions", + data=form_data.model_dump_json(exclude_none=True), + stream=True, + ) + + r.raise_for_status() + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + 
except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)): url = app.state.OLLAMA_BASE_URLS[0] diff --git a/src/routes/(app)/playground/+page.svelte b/src/routes/(app)/playground/+page.svelte index 81c3032b..36f4a2bc 100644 --- a/src/routes/(app)/playground/+page.svelte +++ b/src/routes/(app)/playground/+page.svelte @@ -26,7 +26,7 @@ let selectedModelId = ''; let loading = false; - let currentRequestId; + let currentRequestId = null; let stopResponseFlag = false; let messagesContainerElement: HTMLDivElement; @@ -92,6 +92,10 @@ while (true) { const { value, done } = await reader.read(); if (done || stopResponseFlag) { + if (stopResponseFlag) { + await cancelChatCompletion(localStorage.token, currentRequestId); + } + currentRequestId = null; break; } @@ -108,7 +112,11 @@ let data = JSON.parse(line.replace(/^data: /, '')); console.log(data); - text += data.choices[0].delta.content ?? ''; + if ('request_id' in data) { + currentRequestId = data.request_id; + } else { + text += data.choices[0].delta.content ?? ''; + } } } } @@ -146,16 +154,6 @@ : `${OLLAMA_API_BASE_URL}/v1` ); - // const [res, controller] = await generateChatCompletion(localStorage.token, { - // model: selectedModelId, - // messages: [ - // { - // role: 'assistant', - // content: text - // } - // ] - // }); - let responseMessage; if (messages.at(-1)?.role === 'assistant') { responseMessage = messages.at(-1); @@ -180,6 +178,11 @@ while (true) { const { value, done } = await reader.read(); if (done || stopResponseFlag) { + if (stopResponseFlag) { + await cancelChatCompletion(localStorage.token, currentRequestId); + } + + currentRequestId = null; break; } @@ -196,17 +199,21 @@ let data = JSON.parse(line.replace(/^data: /, '')); console.log(data); - if (responseMessage.content == '' && data.choices[0].delta.content == '\n') { - continue; + if ('request_id' in data) { + currentRequestId = data.request_id; } else { - textareaElement.style.height = textareaElement.scrollHeight + 'px'; + if (responseMessage.content == '' && data.choices[0].delta.content == '\n') { + continue; + } else { + textareaElement.style.height = textareaElement.scrollHeight + 'px'; - responseMessage.content += data.choices[0].delta.content ?? ''; - messages = messages; + responseMessage.content += data.choices[0].delta.content ?? 
''; + messages = messages; - textareaElement.style.height = textareaElement.scrollHeight + 'px'; + textareaElement.style.height = textareaElement.scrollHeight + 'px'; - await tick(); + await tick(); + } } } } @@ -217,48 +224,6 @@ scrollToBottom(); } - - // while (true) { - // const { value, done } = await reader.read(); - // if (done || stopResponseFlag) { - // if (stopResponseFlag) { - // await cancelChatCompletion(localStorage.token, currentRequestId); - // } - - // currentRequestId = null; - // break; - // } - - // try { - // let lines = value.split('\n'); - - // for (const line of lines) { - // if (line !== '') { - // console.log(line); - // let data = JSON.parse(line); - - // if ('detail' in data) { - // throw data; - // } - - // if ('id' in data) { - // console.log(data); - // currentRequestId = data.id; - // } else { - // if (data.done == false) { - // text += data.message.content; - // } else { - // console.log('done'); - // } - // } - // } - // } - // } catch (error) { - // console.log(error); - // } - - // scrollToBottom(); - // } } }; From 8626ae60dd224d86bc776ec625eca46db99ed716 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 5 Mar 2024 01:43:33 -0800 Subject: [PATCH 5/5] chore: old_main removed --- backend/apps/ollama/old_main.py | 127 -------------------------------- 1 file changed, 127 deletions(-) delete mode 100644 backend/apps/ollama/old_main.py diff --git a/backend/apps/ollama/old_main.py b/backend/apps/ollama/old_main.py deleted file mode 100644 index 5e5b8811..00000000 --- a/backend/apps/ollama/old_main.py +++ /dev/null @@ -1,127 +0,0 @@ -from fastapi import FastAPI, Request, Response, HTTPException, Depends -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse - -import requests -import json -from pydantic import BaseModel - -from apps.web.models.users import Users -from constants import ERROR_MESSAGES -from utils.utils import decode_token, get_current_user -from config import OLLAMA_API_BASE_URL, WEBUI_AUTH - -import aiohttp - -app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL - -# TARGET_SERVER_URL = OLLAMA_API_BASE_URL - - -@app.get("/url") -async def get_ollama_api_url(user=Depends(get_current_user)): - if user and user.role == "admin": - return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} - else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) - - -class UrlUpdateForm(BaseModel): - url: str - - -@app.post("/url/update") -async def update_ollama_api_url( - form_data: UrlUpdateForm, user=Depends(get_current_user) -): - if user and user.role == "admin": - app.state.OLLAMA_API_BASE_URL = form_data.url - return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} - else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) - - -# async def fetch_sse(method, target_url, body, headers): -# async with aiohttp.ClientSession() as session: -# try: -# async with session.request( -# method, target_url, data=body, headers=headers -# ) as response: -# print(response.status) -# async for line in response.content: -# yield line -# except Exception as e: -# print(e) -# error_detail = "Open WebUI: Server Connection Error" -# yield json.dumps({"error": error_detail, "message": str(e)}).encode() - - -@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) -async def proxy(path: str, 
request: Request, user=Depends(get_current_user)): - target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" - print(target_url) - - body = await request.body() - headers = dict(request.headers) - - if user.role in ["user", "admin"]: - if path in ["pull", "delete", "push", "copy", "create"]: - if user.role != "admin": - raise HTTPException( - status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED - ) - else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) - - headers.pop("Host", None) - headers.pop("Authorization", None) - headers.pop("Origin", None) - headers.pop("Referer", None) - - session = aiohttp.ClientSession() - response = None - try: - response = await session.request( - request.method, target_url, data=body, headers=headers - ) - - print(response) - if not response.ok: - data = await response.json() - print(data) - response.raise_for_status() - - async def generate(): - async for line in response.content: - print(line) - yield line - await session.close() - - return StreamingResponse(generate(), response.status) - - except Exception as e: - print(e) - error_detail = "Open WebUI: Server Connection Error" - - if response is not None: - try: - res = await response.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - except: - error_detail = f"Ollama: {e}" - - await session.close() - raise HTTPException( - status_code=response.status if response else 500, - detail=error_detail, - )
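Across patches 1 and 4, a cancellable stream follows a small protocol: when streaming, the proxy's first line is a JSON object carrying the request id ("id" on /api/generate and /api/chat, "request_id" on /v1/chat/completions); upstream chunks are then relayed until the id is removed from REQUEST_POOL via GET /cancel/{request_id}. A hypothetical client-side sketch; host, mount point, token, and model name are illustrative assumptions.

    import json
    import requests

    BASE = "http://localhost:8080/ollama"          # assumed Open WebUI mount point
    HEADERS = {"Authorization": "Bearer <token>"}  # assumed bearer token

    with requests.post(
        f"{BASE}/v1/chat/completions",
        json={
            "model": "llama2:latest",              # hypothetical model name
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": True,
        },
        headers=HEADERS,
        stream=True,
    ) as resp:
        request_id = None
        received = 0
        for line in resp.iter_lines():
            if not line:
                continue
            if request_id is None:
                # First line injected by the proxy: {"request_id": ..., "done": false}
                request_id = json.loads(line)["request_id"]
                continue
            print(line.decode())                   # raw SSE lines from the backend
            received += 1
            if received >= 5:                      # pretend the user pressed "stop"
                requests.get(f"{BASE}/cancel/{request_id}", headers=HEADERS)
                break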