Merge pull request #1070 from open-webui/multi-openai

feat: multiple openai apis
This commit is contained in:
Timothy Jaeryang Baek 2024-03-06 19:14:03 -05:00 committed by GitHub
commit e17e2e7c17
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 289 additions and 142 deletions

View file

@ -3,7 +3,10 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
import requests import requests
import aiohttp
import asyncio
import json import json
from pydantic import BaseModel from pydantic import BaseModel
@ -15,7 +18,9 @@ from utils.utils import (
get_verified_user, get_verified_user,
get_admin_user, get_admin_user,
) )
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR from config import OPENAI_API_BASE_URLS, OPENAI_API_KEYS, CACHE_DIR
from typing import List, Optional
import hashlib import hashlib
from pathlib import Path from pathlib import Path
@ -29,49 +34,59 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEY = OPENAI_API_KEY app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.MODELS = {}
class UrlUpdateForm(BaseModel): @app.middleware("http")
url: str async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
await get_all_models()
else:
pass
response = await call_next(request)
return response
class KeyUpdateForm(BaseModel): class UrlsUpdateForm(BaseModel):
key: str urls: List[str]
@app.get("/url") class KeysUpdateForm(BaseModel):
async def get_openai_url(user=Depends(get_admin_user)): keys: List[str]
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
@app.post("/url/update") @app.get("/urls")
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): async def get_openai_urls(user=Depends(get_admin_user)):
app.state.OPENAI_API_BASE_URL = form_data.url return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
@app.get("/key") @app.post("/urls/update")
async def get_openai_key(user=Depends(get_admin_user)): async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} app.state.OPENAI_API_BASE_URLS = form_data.urls
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
@app.post("/key/update") @app.get("/keys")
async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_admin_user)): async def get_openai_keys(user=Depends(get_admin_user)):
app.state.OPENAI_API_KEY = form_data.key return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
app.state.OPENAI_API_KEYS = form_data.keys
return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
@app.post("/audio/speech") @app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)): async def speech(request: Request, user=Depends(get_verified_user)):
target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech" idx = None
try:
if app.state.OPENAI_API_KEY == "": idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
body = await request.body() body = await request.body()
name = hashlib.sha256(body).hexdigest() name = hashlib.sha256(body).hexdigest()
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
@ -84,13 +99,12 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path) return FileResponse(file_path)
headers = {} headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}"
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
try: try:
print("openai")
r = requests.post( r = requests.post(
url=target_url, url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech",
data=body, data=body,
headers=headers, headers=headers,
stream=True, stream=True,
@ -122,23 +136,106 @@ async def speech(request: Request, user=Depends(get_verified_user)):
raise HTTPException(status_code=r.status_code, detail=error_detail) raise HTTPException(status_code=r.status_code, detail=error_detail)
except ValueError:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
async def fetch_url(url, key):
try:
headers = {"Authorization": f"Bearer {key}"}
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
return await response.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
return None
def merge_models_lists(model_lists):
merged_list = []
for idx, models in enumerate(model_lists):
merged_list.extend(
[
{**model, "urlIdx": idx}
for model in models
if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx]
or "gpt" in model["id"]
]
)
return merged_list
async def get_all_models():
print("get_all_models")
tasks = [
fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
]
responses = await asyncio.gather(*tasks)
responses = list(filter(lambda x: x is not None, responses))
models = {
"data": merge_models_lists(
list(map(lambda response: response["data"], responses))
)
}
app.state.MODELS = {model["id"]: model for model in models["data"]}
return models
# , user=Depends(get_current_user)
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None):
if url_idx == None:
return await get_all_models()
else:
url = app.state.OPENAI_API_BASE_URLS[url_idx]
try:
r = requests.request(method="GET", url=f"{url}/models")
r.raise_for_status()
response_data = r.json()
if "api.openai.com" in url:
response_data["data"] = list(
filter(lambda model: "gpt" in model["id"], response_data["data"])
)
return response_data
except Exception as e:
print(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)): async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}" idx = 0
print(target_url, app.state.OPENAI_API_KEY)
if app.state.OPENAI_API_KEY == "":
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
body = await request.body() body = await request.body()
# TODO: Remove below after gpt-4-vision fix from Open AI # TODO: Remove below after gpt-4-vision fix from Open AI
# Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision) # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
try: try:
body = body.decode("utf-8") body = body.decode("utf-8")
body = json.loads(body) body = json.loads(body)
idx = app.state.MODELS[body.get("model")]["urlIdx"]
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model # This is a workaround until OpenAI fixes the issue with this model
if body.get("model") == "gpt-4-vision-preview": if body.get("model") == "gpt-4-vision-preview":
@ -158,8 +255,16 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
print("Error loading request body into a dictionary:", e) print("Error loading request body into a dictionary:", e)
url = app.state.OPENAI_API_BASE_URLS[idx]
key = app.state.OPENAI_API_KEYS[idx]
target_url = f"{url}/{path}"
if key == "":
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
headers = {} headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
try: try:
@ -181,21 +286,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
headers=dict(r.headers), headers=dict(r.headers),
) )
else: else:
# For non-SSE, read the response and return it
# response_data = (
# r.json()
# if r.headers.get("Content-Type", "")
# == "application/json"
# else r.text
# )
response_data = r.json() response_data = r.json()
if "api.openai.com" in app.state.OPENAI_API_BASE_URL and path == "models":
response_data["data"] = list(
filter(lambda model: "gpt" in model["id"], response_data["data"])
)
return response_data return response_data
except Exception as e: except Exception as e:
print(e) print(e)

View file

@ -237,6 +237,19 @@ OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
if OPENAI_API_BASE_URL == "": if OPENAI_API_BASE_URL == "":
OPENAI_API_BASE_URL = "https://api.openai.com/v1" OPENAI_API_BASE_URL = "https://api.openai.com/v1"
OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY
OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")]
OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
OPENAI_API_BASE_URLS = (
OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL
)
OPENAI_API_BASE_URLS = [url.strip() for url in OPENAI_API_BASE_URL.split(";")]
#################################### ####################################
# WEBUI # WEBUI

View file

@ -41,6 +41,7 @@ class ERROR_MESSAGES(str, Enum):
NOT_FOUND = "We could not find what you're looking for :/" NOT_FOUND = "We could not find what you're looking for :/"
USER_NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/"
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature." API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
MALICIOUS = "Unusual activities detected, please try again in a few minutes." MALICIOUS = "Unusual activities detected, please try again in a few minutes."
PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance." PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance."
@ -50,3 +51,4 @@ class ERROR_MESSAGES(str, Enum):
RATE_LIMIT_EXCEEDED = "API rate limit exceeded" RATE_LIMIT_EXCEEDED = "API rate limit exceeded"
MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found" MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found"
OPENAI_NOT_FOUND = lambda name="": f"OpenAI API was not found"

View file

@ -1,9 +1,9 @@
import { OPENAI_API_BASE_URL } from '$lib/constants'; import { OPENAI_API_BASE_URL } from '$lib/constants';
export const getOpenAIUrl = async (token: string = '') => { export const getOpenAIUrls = async (token: string = '') => {
let error = null; let error = null;
const res = await fetch(`${OPENAI_API_BASE_URL}/url`, { const res = await fetch(`${OPENAI_API_BASE_URL}/urls`, {
method: 'GET', method: 'GET',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -29,13 +29,13 @@ export const getOpenAIUrl = async (token: string = '') => {
throw error; throw error;
} }
return res.OPENAI_API_BASE_URL; return res.OPENAI_API_BASE_URLS;
}; };
export const updateOpenAIUrl = async (token: string = '', url: string) => { export const updateOpenAIUrls = async (token: string = '', urls: string[]) => {
let error = null; let error = null;
const res = await fetch(`${OPENAI_API_BASE_URL}/url/update`, { const res = await fetch(`${OPENAI_API_BASE_URL}/urls/update`, {
method: 'POST', method: 'POST',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -43,7 +43,7 @@ export const updateOpenAIUrl = async (token: string = '', url: string) => {
...(token && { authorization: `Bearer ${token}` }) ...(token && { authorization: `Bearer ${token}` })
}, },
body: JSON.stringify({ body: JSON.stringify({
url: url urls: urls
}) })
}) })
.then(async (res) => { .then(async (res) => {
@ -64,13 +64,13 @@ export const updateOpenAIUrl = async (token: string = '', url: string) => {
throw error; throw error;
} }
return res.OPENAI_API_BASE_URL; return res.OPENAI_API_BASE_URLS;
}; };
export const getOpenAIKey = async (token: string = '') => { export const getOpenAIKeys = async (token: string = '') => {
let error = null; let error = null;
const res = await fetch(`${OPENAI_API_BASE_URL}/key`, { const res = await fetch(`${OPENAI_API_BASE_URL}/keys`, {
method: 'GET', method: 'GET',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -96,13 +96,13 @@ export const getOpenAIKey = async (token: string = '') => {
throw error; throw error;
} }
return res.OPENAI_API_KEY; return res.OPENAI_API_KEYS;
}; };
export const updateOpenAIKey = async (token: string = '', key: string) => { export const updateOpenAIKeys = async (token: string = '', keys: string[]) => {
let error = null; let error = null;
const res = await fetch(`${OPENAI_API_BASE_URL}/key/update`, { const res = await fetch(`${OPENAI_API_BASE_URL}/keys/update`, {
method: 'POST', method: 'POST',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
@ -110,7 +110,7 @@ export const updateOpenAIKey = async (token: string = '', key: string) => {
...(token && { authorization: `Bearer ${token}` }) ...(token && { authorization: `Bearer ${token}` })
}, },
body: JSON.stringify({ body: JSON.stringify({
key: key keys: keys
}) })
}) })
.then(async (res) => { .then(async (res) => {
@ -131,7 +131,7 @@ export const updateOpenAIKey = async (token: string = '', key: string) => {
throw error; throw error;
} }
return res.OPENAI_API_KEY; return res.OPENAI_API_KEYS;
}; };
export const getOpenAIModels = async (token: string = '') => { export const getOpenAIModels = async (token: string = '') => {

View file

@ -4,7 +4,12 @@
const dispatch = createEventDispatcher(); const dispatch = createEventDispatcher();
import { getOllamaUrls, getOllamaVersion, updateOllamaUrls } from '$lib/apis/ollama'; import { getOllamaUrls, getOllamaVersion, updateOllamaUrls } from '$lib/apis/ollama';
import { getOpenAIKey, getOpenAIUrl, updateOpenAIKey, updateOpenAIUrl } from '$lib/apis/openai'; import {
getOpenAIKeys,
getOpenAIUrls,
updateOpenAIKeys,
updateOpenAIUrls
} from '$lib/apis/openai';
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
export let getModels: Function; export let getModels: Function;
@ -16,12 +21,14 @@
let OPENAI_API_KEY = ''; let OPENAI_API_KEY = '';
let OPENAI_API_BASE_URL = ''; let OPENAI_API_BASE_URL = '';
let OPENAI_API_KEYS = [''];
let OPENAI_API_BASE_URLS = [''];
let showOpenAI = false; let showOpenAI = false;
let showLiteLLM = false;
const updateOpenAIHandler = async () => { const updateOpenAIHandler = async () => {
OPENAI_API_BASE_URL = await updateOpenAIUrl(localStorage.token, OPENAI_API_BASE_URL); OPENAI_API_BASE_URLS = await updateOpenAIUrls(localStorage.token, OPENAI_API_BASE_URLS);
OPENAI_API_KEY = await updateOpenAIKey(localStorage.token, OPENAI_API_KEY); OPENAI_API_KEYS = await updateOpenAIKeys(localStorage.token, OPENAI_API_KEYS);
await models.set(await getModels()); await models.set(await getModels());
}; };
@ -43,8 +50,8 @@
onMount(async () => { onMount(async () => {
if ($user.role === 'admin') { if ($user.role === 'admin') {
OLLAMA_BASE_URLS = await getOllamaUrls(localStorage.token); OLLAMA_BASE_URLS = await getOllamaUrls(localStorage.token);
OPENAI_API_BASE_URL = await getOpenAIUrl(localStorage.token); OPENAI_API_BASE_URLS = await getOpenAIUrls(localStorage.token);
OPENAI_API_KEY = await getOpenAIKey(localStorage.token); OPENAI_API_KEYS = await getOpenAIKeys(localStorage.token);
} }
}); });
</script> </script>
@ -71,38 +78,75 @@
</div> </div>
{#if showOpenAI} {#if showOpenAI}
<div> <div class="flex flex-col gap-1">
<div class=" mb-2.5 text-sm font-medium">API Key</div> {#each OPENAI_API_BASE_URLS as url, idx}
<div class="flex w-full"> <div class="flex w-full gap-2">
<div class="flex-1"> <div class="flex-1">
<input <input
class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none" class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder="Enter OpenAI API Key" placeholder="API Base URL"
bind:value={OPENAI_API_KEY} bind:value={url}
autocomplete="off" autocomplete="off"
/> />
</div> </div>
</div>
</div>
<div>
<div class=" mb-2.5 text-sm font-medium">API Base URL</div>
<div class="flex w-full">
<div class="flex-1"> <div class="flex-1">
<input <input
class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none" class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder="Enter OpenAI API Base URL" placeholder="API Key"
bind:value={OPENAI_API_BASE_URL} bind:value={OPENAI_API_KEYS[idx]}
autocomplete="off" autocomplete="off"
/> />
</div> </div>
</div> <div class="self-center flex items-center">
<div class="mt-2 text-xs text-gray-400 dark:text-gray-500"> {#if idx === 0}
WebUI will make requests to <span class=" text-gray-200" <button
>'{OPENAI_API_BASE_URL}/chat'</span class="px-1"
on:click={() => {
OPENAI_API_BASE_URLS = [...OPENAI_API_BASE_URLS, ''];
OPENAI_API_KEYS = [...OPENAI_API_KEYS, ''];
}}
type="button"
> >
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
/>
</svg>
</button>
{:else}
<button
class="px-1"
on:click={() => {
OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS.filter(
(url, urlIdx) => idx !== urlIdx
);
OPENAI_API_KEYS = OPENAI_API_KEYS.filter((key, keyIdx) => idx !== keyIdx);
}}
type="button"
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path d="M3.75 7.25a.75.75 0 0 0 0 1.5h8.5a.75.75 0 0 0 0-1.5h-8.5Z" />
</svg>
</button>
{/if}
</div> </div>
</div> </div>
<div class=" mb-1 text-xs text-gray-400 dark:text-gray-500">
WebUI will make requests to <span class=" text-gray-200">'{url}/models'</span>
</div>
{/each}
</div>
{/if} {/if}
</div> </div>
</div> </div>

View file

@ -97,14 +97,11 @@
if (localDBChats.length === 0) { if (localDBChats.length === 0) {
await deleteDB('Chats'); await deleteDB('Chats');
} }
console.log('localdb', localDBChats);
} }
console.log(DB); console.log(DB);
} catch (error) { } catch (error) {
// IndexedDB Not Found // IndexedDB Not Found
console.log('IDB Not Found');
} }
console.log(); console.log();