feat: openai compatible api support

This commit is contained in:
Timothy J. Baek 2024-01-04 18:38:03 -08:00
parent 5e4dc98f44
commit 17c66fde0f
9 changed files with 260 additions and 85 deletions

View file

@ -1,6 +1,6 @@
from fastapi import FastAPI, Request, Response, HTTPException, Depends from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse, JSONResponse
import requests import requests
import json import json
@ -69,18 +69,18 @@ async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_u
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)): async def proxy(path: str, request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}" target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
print(target_url, app.state.OPENAI_API_KEY)
body = await request.body()
headers = dict(request.headers)
if user.role not in ["user", "admin"]: if user.role not in ["user", "admin"]:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
if app.state.OPENAI_API_KEY == "":
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
headers.pop("Host", None) body = await request.body()
headers.pop("Authorization", None) # headers = dict(request.headers)
headers.pop("Origin", None) # print(headers)
headers.pop("Referer", None)
headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
try: try:
@ -94,11 +94,32 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
r.raise_for_status() r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
return StreamingResponse( return StreamingResponse(
r.iter_content(chunk_size=8192), r.iter_content(chunk_size=8192),
status_code=r.status_code, status_code=r.status_code,
headers=dict(r.headers), headers=dict(r.headers),
) )
else:
# For non-SSE, read the response and return it
# response_data = (
# r.json()
# if r.headers.get("Content-Type", "")
# == "application/json"
# else r.text
# )
response_data = r.json()
print(type(response_data))
if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
response_data["data"] = list(
filter(lambda model: "gpt" in model["id"], response_data["data"])
)
return response_data
except Exception as e: except Exception as e:
print(e) print(e)
error_detail = "Ollama WebUI: Server Connection Error" error_detail = "Ollama WebUI: Server Connection Error"

View file

@ -33,7 +33,10 @@ if ENV == "prod":
#################################### ####################################
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "https://api.openai.com/v1") OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
if OPENAI_API_BASE_URL == "":
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
#################################### ####################################
# WEBUI_VERSION # WEBUI_VERSION

View file

@ -33,4 +33,5 @@ class ERROR_MESSAGES(str, Enum):
) )
NOT_FOUND = "We could not find what you're looking for :/" NOT_FOUND = "We could not find what you're looking for :/"
USER_NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/"
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
MALICIOUS = "Unusual activities detected, please try again in a few minutes." MALICIOUS = "Unusual activities detected, please try again in a few minutes."

View file

@ -6,6 +6,8 @@ from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from apps.ollama.main import app as ollama_app from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
from apps.web.main import app as webui_app from apps.web.main import app as webui_app
import time import time
@ -46,7 +48,7 @@ async def check_url(request: Request, call_next):
app.mount("/api/v1", webui_app) app.mount("/api/v1", webui_app)
# app.mount("/ollama/api", WSGIMiddleware(ollama_app))
app.mount("/ollama/api", ollama_app) app.mount("/ollama/api", ollama_app)
app.mount("/openai/api", openai_app)
app.mount("/", SPAStaticFiles(directory="../build", html=True), name="spa-static-files") app.mount("/", SPAStaticFiles(directory="../build", html=True), name="spa-static-files")

View file

@ -1,12 +1,6 @@
# If you're serving both the frontend and backend (Recommended)
# Set the public API base URL for seamless communication
PUBLIC_API_BASE_URL='/ollama/api'
# If you're serving only the frontend (Not recommended and not fully supported)
# Comment above and Uncomment below
# You can use the default value or specify a custom path, e.g., '/api'
# PUBLIC_API_BASE_URL='http://{location.hostname}:11434/api'
# Ollama URL for the backend to connect # Ollama URL for the backend to connect
# The path '/ollama/api' will be redirected to the specified backend URL # The path '/ollama/api' will be redirected to the specified backend URL
OLLAMA_API_BASE_URL='http://localhost:11434/api' OLLAMA_API_BASE_URL='http://localhost:11434/api'
OPENAI_API_BASE_URL=''
OPENAI_API_KEY=''

View file

@ -1,4 +1,176 @@
export const getOpenAIModels = async ( import { OPENAI_API_BASE_URL } from '$lib/constants';
export const getOpenAIUrl = async (token: string = '') => {
let error = null;
const res = await fetch(`${OPENAI_API_BASE_URL}/url`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res.OPENAI_API_BASE_URL;
};
export const updateOpenAIUrl = async (token: string = '', url: string) => {
let error = null;
const res = await fetch(`${OPENAI_API_BASE_URL}/url/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify({
url: url
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res.OPENAI_API_BASE_URL;
};
export const getOpenAIKey = async (token: string = '') => {
let error = null;
const res = await fetch(`${OPENAI_API_BASE_URL}/key`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res.OPENAI_API_KEY;
};
export const updateOpenAIKey = async (token: string = '', key: string) => {
let error = null;
const res = await fetch(`${OPENAI_API_BASE_URL}/key/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify({
key: key
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res.OPENAI_API_KEY;
};
export const getOpenAIModels = async (token: string = '') => {
let error = null;
const res = await fetch(`${OPENAI_API_BASE_URL}/models`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`;
return [];
});
if (error) {
throw error;
}
const models = Array.isArray(res) ? res : res?.data ?? null;
return models
? models
.map((model) => ({ name: model.id, external: true }))
.sort((a, b) => {
return a.name.localeCompare(b.name);
})
: models;
};
export const getOpenAIModelsDirect = async (
base_url: string = 'https://api.openai.com/v1', base_url: string = 'https://api.openai.com/v1',
api_key: string = '' api_key: string = ''
) => { ) => {

View file

@ -24,6 +24,13 @@
import { updateUserPassword } from '$lib/apis/auths'; import { updateUserPassword } from '$lib/apis/auths';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import Page from '../../../routes/(app)/+page.svelte'; import Page from '../../../routes/(app)/+page.svelte';
import {
getOpenAIKey,
getOpenAIModels,
getOpenAIUrl,
updateOpenAIKey,
updateOpenAIUrl
} from '$lib/apis/openai';
export let show = false; export let show = false;
@ -153,6 +160,13 @@
} }
}; };
const updateOpenAIHandler = async () => {
OPENAI_API_BASE_URL = await updateOpenAIUrl(localStorage.token, OPENAI_API_BASE_URL);
OPENAI_API_KEY = await updateOpenAIKey(localStorage.token, OPENAI_API_KEY);
await models.set(await getModels());
};
const toggleTheme = async () => { const toggleTheme = async () => {
if (theme === 'dark') { if (theme === 'dark') {
theme = 'light'; theme = 'light';
@ -484,7 +498,7 @@
}; };
const getModels = async (type = 'all') => { const getModels = async (type = 'all') => {
let models = []; const models = [];
models.push( models.push(
...(await getOllamaModels(localStorage.token).catch((error) => { ...(await getOllamaModels(localStorage.token).catch((error) => {
toast.error(error); toast.error(error);
@ -493,43 +507,13 @@
); );
// If OpenAI API Key exists // If OpenAI API Key exists
if (type === 'all' && $settings.OPENAI_API_KEY) { if (type === 'all' && OPENAI_API_KEY) {
const OPENAI_API_BASE_URL = $settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1'; const openAIModels = await getOpenAIModels(localStorage.token).catch((error) => {
// Validate OPENAI_API_KEY
const openaiModelRes = await fetch(`${OPENAI_API_BASE_URL}/models`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${$settings.OPENAI_API_KEY}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((error) => {
console.log(error); console.log(error);
toast.error(`OpenAI: ${error?.error?.message ?? 'Network Problem'}`);
return null; return null;
}); });
const openAIModels = Array.isArray(openaiModelRes) models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : []));
? openaiModelRes
: openaiModelRes?.data ?? null;
models.push(
...(openAIModels
? [
{ name: 'hr' },
...openAIModels
.map((model) => ({ name: model.id, external: true }))
.filter((model) =>
OPENAI_API_BASE_URL.includes('openai') ? model.name.includes('gpt') : true
)
]
: [])
);
} }
return models; return models;
@ -564,6 +548,8 @@
console.log('settings', $user.role === 'admin'); console.log('settings', $user.role === 'admin');
if ($user.role === 'admin') { if ($user.role === 'admin') {
API_BASE_URL = await getOllamaAPIUrl(localStorage.token); API_BASE_URL = await getOllamaAPIUrl(localStorage.token);
OPENAI_API_BASE_URL = await getOpenAIUrl(localStorage.token);
OPENAI_API_KEY = await getOpenAIKey(localStorage.token);
} }
let settings = JSON.parse(localStorage.getItem('settings') ?? '{}'); let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
@ -584,9 +570,6 @@
options = { ...options, ...settings.options }; options = { ...options, ...settings.options };
options.stop = (settings?.options?.stop ?? []).join(','); options.stop = (settings?.options?.stop ?? []).join(',');
OPENAI_API_KEY = settings.OPENAI_API_KEY ?? '';
OPENAI_API_BASE_URL = settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1';
titleAutoGenerate = settings.titleAutoGenerate ?? true; titleAutoGenerate = settings.titleAutoGenerate ?? true;
speechAutoSend = settings.speechAutoSend ?? false; speechAutoSend = settings.speechAutoSend ?? false;
responseAutoCopy = settings.responseAutoCopy ?? false; responseAutoCopy = settings.responseAutoCopy ?? false;
@ -1415,10 +1398,12 @@
<form <form
class="flex flex-col h-full justify-between space-y-3 text-sm" class="flex flex-col h-full justify-between space-y-3 text-sm"
on:submit|preventDefault={() => { on:submit|preventDefault={() => {
saveSettings({ updateOpenAIHandler();
OPENAI_API_KEY: OPENAI_API_KEY !== '' ? OPENAI_API_KEY : undefined,
OPENAI_API_BASE_URL: OPENAI_API_BASE_URL !== '' ? OPENAI_API_BASE_URL : undefined // saveSettings({
}); // OPENAI_API_KEY: OPENAI_API_KEY !== '' ? OPENAI_API_KEY : undefined,
// OPENAI_API_BASE_URL: OPENAI_API_BASE_URL !== '' ? OPENAI_API_BASE_URL : undefined
// });
show = false; show = false;
}} }}
> >

View file

@ -1,11 +1,10 @@
import { dev } from '$app/environment'; import { dev } from '$app/environment';
export const OLLAMA_API_BASE_URL = dev
? `http://${location.hostname}:8080/ollama/api`
: '/ollama/api';
export const WEBUI_BASE_URL = dev ? `http://${location.hostname}:8080` : ``; export const WEBUI_BASE_URL = dev ? `http://${location.hostname}:8080` : ``;
export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`; export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`;
export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`;
export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`;
export const WEB_UI_VERSION = 'v1.0.0-alpha-static'; export const WEB_UI_VERSION = 'v1.0.0-alpha-static';

View file

@ -37,19 +37,17 @@
return []; return [];
})) }))
); );
// If OpenAI API Key exists
if ($settings.OPENAI_API_KEY) { // $settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1',
const openAIModels = await getOpenAIModels( // $settings.OPENAI_API_KEY
$settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1',
$settings.OPENAI_API_KEY const openAIModels = await getOpenAIModels(localStorage.token).catch((error) => {
).catch((error) => {
console.log(error); console.log(error);
toast.error(error);
return null; return null;
}); });
models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : [])); models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : []));
}
return models; return models;
}; };