From 17c66fde0f84b1859f27eee93f08ce6d48b7ff2e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 4 Jan 2024 18:38:03 -0800 Subject: [PATCH] feat: openai compatible api support --- backend/apps/openai/main.py | 47 +++-- backend/config.py | 5 +- backend/constants.py | 1 + backend/main.py | 4 +- example.env | 12 +- src/lib/apis/openai/index.ts | 174 ++++++++++++++++++- src/lib/components/chat/SettingsModal.svelte | 73 ++++---- src/lib/constants.ts | 7 +- src/routes/(app)/+layout.svelte | 22 ++- 9 files changed, 260 insertions(+), 85 deletions(-) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 0a12137a..c6d06a9e 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -1,6 +1,6 @@ from fastapi import FastAPI, Request, Response, HTTPException, Depends from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, JSONResponse import requests import json @@ -69,18 +69,18 @@ async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_u @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def proxy(path: str, request: Request, user=Depends(get_current_user)): target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}" - - body = await request.body() - headers = dict(request.headers) + print(target_url, app.state.OPENAI_API_KEY) if user.role not in ["user", "admin"]: raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + if app.state.OPENAI_API_KEY == "": + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) - headers.pop("Host", None) - headers.pop("Authorization", None) - headers.pop("Origin", None) - headers.pop("Referer", None) + body = await request.body() + # headers = dict(request.headers) + # print(headers) + headers = {} headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" try: @@ -94,11 +94,32 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): r.raise_for_status() - return StreamingResponse( - r.iter_content(chunk_size=8192), - status_code=r.status_code, - headers=dict(r.headers), - ) + # Check if response is SSE + if "text/event-stream" in r.headers.get("Content-Type", ""): + return StreamingResponse( + r.iter_content(chunk_size=8192), + status_code=r.status_code, + headers=dict(r.headers), + ) + else: + # For non-SSE, read the response and return it + # response_data = ( + # r.json() + # if r.headers.get("Content-Type", "") + # == "application/json" + # else r.text + # ) + + response_data = r.json() + + print(type(response_data)) + + if "openai" in app.state.OPENAI_API_BASE_URL and path == "models": + response_data["data"] = list( + filter(lambda model: "gpt" in model["id"], response_data["data"]) + ) + + return response_data except Exception as e: print(e) error_detail = "Ollama WebUI: Server Connection Error" diff --git a/backend/config.py b/backend/config.py index 90900b9d..4e0f9e97 100644 --- a/backend/config.py +++ b/backend/config.py @@ -33,7 +33,10 @@ if ENV == "prod": #################################### OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") -OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "https://api.openai.com/v1") +OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "") + +if OPENAI_API_BASE_URL == "": + OPENAI_API_BASE_URL = "https://api.openai.com/v1" #################################### # WEBUI_VERSION diff --git a/backend/constants.py b/backend/constants.py index 0817445b..e51ecdda 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -33,4 +33,5 @@ class ERROR_MESSAGES(str, Enum): ) NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/" + API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature." MALICIOUS = "Unusual activities detected, please try again in a few minutes." diff --git a/backend/main.py b/backend/main.py index 5e3b7e83..a97e6e85 100644 --- a/backend/main.py +++ b/backend/main.py @@ -6,6 +6,8 @@ from fastapi.middleware.cors import CORSMiddleware from starlette.exceptions import HTTPException as StarletteHTTPException from apps.ollama.main import app as ollama_app +from apps.openai.main import app as openai_app + from apps.web.main import app as webui_app import time @@ -46,7 +48,7 @@ async def check_url(request: Request, call_next): app.mount("/api/v1", webui_app) -# app.mount("/ollama/api", WSGIMiddleware(ollama_app)) app.mount("/ollama/api", ollama_app) +app.mount("/openai/api", openai_app) app.mount("/", SPAStaticFiles(directory="../build", html=True), name="spa-static-files") diff --git a/example.env b/example.env index 9c628b42..74d52223 100644 --- a/example.env +++ b/example.env @@ -1,12 +1,6 @@ -# If you're serving both the frontend and backend (Recommended) -# Set the public API base URL for seamless communication -PUBLIC_API_BASE_URL='/ollama/api' - -# If you're serving only the frontend (Not recommended and not fully supported) -# Comment above and Uncomment below -# You can use the default value or specify a custom path, e.g., '/api' -# PUBLIC_API_BASE_URL='http://{location.hostname}:11434/api' - # Ollama URL for the backend to connect # The path '/ollama/api' will be redirected to the specified backend URL OLLAMA_API_BASE_URL='http://localhost:11434/api' + +OPENAI_API_BASE_URL='' +OPENAI_API_KEY='' \ No newline at end of file diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index c144ae89..c1135fee 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -1,4 +1,176 @@ -export const getOpenAIModels = async ( +import { OPENAI_API_BASE_URL } from '$lib/constants'; + +export const getOpenAIUrl = async (token: string = '') => { + let error = null; + + const res = await fetch(`${OPENAI_API_BASE_URL}/url`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.OPENAI_API_BASE_URL; +}; + +export const updateOpenAIUrl = async (token: string = '', url: string) => { + let error = null; + + const res = await fetch(`${OPENAI_API_BASE_URL}/url/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + url: url + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.OPENAI_API_BASE_URL; +}; + +export const getOpenAIKey = async (token: string = '') => { + let error = null; + + const res = await fetch(`${OPENAI_API_BASE_URL}/key`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.OPENAI_API_KEY; +}; + +export const updateOpenAIKey = async (token: string = '', key: string) => { + let error = null; + + const res = await fetch(`${OPENAI_API_BASE_URL}/key/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + key: key + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.OPENAI_API_KEY; +}; + +export const getOpenAIModels = async (token: string = '') => { + let error = null; + + const res = await fetch(`${OPENAI_API_BASE_URL}/models`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`; + return []; + }); + + if (error) { + throw error; + } + + const models = Array.isArray(res) ? res : res?.data ?? null; + + return models + ? models + .map((model) => ({ name: model.id, external: true })) + .sort((a, b) => { + return a.name.localeCompare(b.name); + }) + : models; +}; + +export const getOpenAIModelsDirect = async ( base_url: string = 'https://api.openai.com/v1', api_key: string = '' ) => { diff --git a/src/lib/components/chat/SettingsModal.svelte b/src/lib/components/chat/SettingsModal.svelte index 174ffc17..edb3e840 100644 --- a/src/lib/components/chat/SettingsModal.svelte +++ b/src/lib/components/chat/SettingsModal.svelte @@ -24,6 +24,13 @@ import { updateUserPassword } from '$lib/apis/auths'; import { goto } from '$app/navigation'; import Page from '../../../routes/(app)/+page.svelte'; + import { + getOpenAIKey, + getOpenAIModels, + getOpenAIUrl, + updateOpenAIKey, + updateOpenAIUrl + } from '$lib/apis/openai'; export let show = false; @@ -153,6 +160,13 @@ } }; + const updateOpenAIHandler = async () => { + OPENAI_API_BASE_URL = await updateOpenAIUrl(localStorage.token, OPENAI_API_BASE_URL); + OPENAI_API_KEY = await updateOpenAIKey(localStorage.token, OPENAI_API_KEY); + + await models.set(await getModels()); + }; + const toggleTheme = async () => { if (theme === 'dark') { theme = 'light'; @@ -484,7 +498,7 @@ }; const getModels = async (type = 'all') => { - let models = []; + const models = []; models.push( ...(await getOllamaModels(localStorage.token).catch((error) => { toast.error(error); @@ -493,43 +507,13 @@ ); // If OpenAI API Key exists - if (type === 'all' && $settings.OPENAI_API_KEY) { - const OPENAI_API_BASE_URL = $settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1'; + if (type === 'all' && OPENAI_API_KEY) { + const openAIModels = await getOpenAIModels(localStorage.token).catch((error) => { + console.log(error); + return null; + }); - // Validate OPENAI_API_KEY - const openaiModelRes = await fetch(`${OPENAI_API_BASE_URL}/models`, { - method: 'GET', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${$settings.OPENAI_API_KEY}` - } - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((error) => { - console.log(error); - toast.error(`OpenAI: ${error?.error?.message ?? 'Network Problem'}`); - return null; - }); - - const openAIModels = Array.isArray(openaiModelRes) - ? openaiModelRes - : openaiModelRes?.data ?? null; - - models.push( - ...(openAIModels - ? [ - { name: 'hr' }, - ...openAIModels - .map((model) => ({ name: model.id, external: true })) - .filter((model) => - OPENAI_API_BASE_URL.includes('openai') ? model.name.includes('gpt') : true - ) - ] - : []) - ); + models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : [])); } return models; @@ -564,6 +548,8 @@ console.log('settings', $user.role === 'admin'); if ($user.role === 'admin') { API_BASE_URL = await getOllamaAPIUrl(localStorage.token); + OPENAI_API_BASE_URL = await getOpenAIUrl(localStorage.token); + OPENAI_API_KEY = await getOpenAIKey(localStorage.token); } let settings = JSON.parse(localStorage.getItem('settings') ?? '{}'); @@ -584,9 +570,6 @@ options = { ...options, ...settings.options }; options.stop = (settings?.options?.stop ?? []).join(','); - OPENAI_API_KEY = settings.OPENAI_API_KEY ?? ''; - OPENAI_API_BASE_URL = settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1'; - titleAutoGenerate = settings.titleAutoGenerate ?? true; speechAutoSend = settings.speechAutoSend ?? false; responseAutoCopy = settings.responseAutoCopy ?? false; @@ -1415,10 +1398,12 @@
{ - saveSettings({ - OPENAI_API_KEY: OPENAI_API_KEY !== '' ? OPENAI_API_KEY : undefined, - OPENAI_API_BASE_URL: OPENAI_API_BASE_URL !== '' ? OPENAI_API_BASE_URL : undefined - }); + updateOpenAIHandler(); + + // saveSettings({ + // OPENAI_API_KEY: OPENAI_API_KEY !== '' ? OPENAI_API_KEY : undefined, + // OPENAI_API_BASE_URL: OPENAI_API_BASE_URL !== '' ? OPENAI_API_BASE_URL : undefined + // }); show = false; }} > diff --git a/src/lib/constants.ts b/src/lib/constants.ts index c22ae207..27744197 100644 --- a/src/lib/constants.ts +++ b/src/lib/constants.ts @@ -1,11 +1,10 @@ import { dev } from '$app/environment'; -export const OLLAMA_API_BASE_URL = dev - ? `http://${location.hostname}:8080/ollama/api` - : '/ollama/api'; - export const WEBUI_BASE_URL = dev ? `http://${location.hostname}:8080` : ``; + export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`; +export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`; +export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`; export const WEB_UI_VERSION = 'v1.0.0-alpha-static'; diff --git a/src/routes/(app)/+layout.svelte b/src/routes/(app)/+layout.svelte index 013638cb..c264592e 100644 --- a/src/routes/(app)/+layout.svelte +++ b/src/routes/(app)/+layout.svelte @@ -37,19 +37,17 @@ return []; })) ); - // If OpenAI API Key exists - if ($settings.OPENAI_API_KEY) { - const openAIModels = await getOpenAIModels( - $settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1', - $settings.OPENAI_API_KEY - ).catch((error) => { - console.log(error); - toast.error(error); - return null; - }); - models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : [])); - } + // $settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1', + // $settings.OPENAI_API_KEY + + const openAIModels = await getOpenAIModels(localStorage.token).catch((error) => { + console.log(error); + return null; + }); + + models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : [])); + return models; };