diff --git a/Dockerfile b/Dockerfile index 7080d73b..504cfff6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,6 +16,10 @@ ARG OLLAMA_API_BASE_URL='/ollama/api' ENV ENV=prod ENV OLLAMA_API_BASE_URL $OLLAMA_API_BASE_URL + +ENV OPENAI_API_BASE_URL "" +ENV OPENAI_API_KEY "" + ENV WEBUI_JWT_SECRET_KEY "SECRET_KEY" WORKDIR /app diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py new file mode 100644 index 00000000..03d4621b --- /dev/null +++ b/backend/apps/openai/main.py @@ -0,0 +1,135 @@ +from fastapi import FastAPI, Request, Response, HTTPException, Depends +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse, JSONResponse + +import requests +import json +from pydantic import BaseModel + +from apps.web.models.users import Users +from constants import ERROR_MESSAGES +from utils.utils import decode_token, get_current_user +from config import OPENAI_API_BASE_URL, OPENAI_API_KEY + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL +app.state.OPENAI_API_KEY = OPENAI_API_KEY + + +class UrlUpdateForm(BaseModel): + url: str + + +class KeyUpdateForm(BaseModel): + key: str + + +@app.get("/url") +async def get_openai_url(user=Depends(get_current_user)): + if user and user.role == "admin": + return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} + else: + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + + +@app.post("/url/update") +async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)): + if user and user.role == "admin": + app.state.OPENAI_API_BASE_URL = form_data.url + return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} + else: + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + + +@app.get("/key") +async def get_openai_key(user=Depends(get_current_user)): + if user and user.role == "admin": + return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} + else: + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + + +@app.post("/key/update") +async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)): + if user and user.role == "admin": + app.state.OPENAI_API_KEY = form_data.key + return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} + else: + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + + +@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) +async def proxy(path: str, request: Request, user=Depends(get_current_user)): + target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}" + print(target_url, app.state.OPENAI_API_KEY) + + if user.role not in ["user", "admin"]: + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + if app.state.OPENAI_API_KEY == "": + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + + body = await request.body() + # headers = dict(request.headers) + # print(headers) + + headers = {} + headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" + headers["Content-Type"] = "application/json" + + try: + r = requests.request( + method=request.method, + url=target_url, + data=body, + headers=headers, + stream=True, + ) + + r.raise_for_status() + + # Check if response is SSE + if "text/event-stream" in r.headers.get("Content-Type", ""): + return StreamingResponse( + r.iter_content(chunk_size=8192), + status_code=r.status_code, + headers=dict(r.headers), + ) + else: + # For non-SSE, read the response and return it + # response_data = ( + # r.json() + # if r.headers.get("Content-Type", "") + # == "application/json" + # else r.text + # ) + + response_data = r.json() + + print(type(response_data)) + + if "openai" in app.state.OPENAI_API_BASE_URL and path == "models": + response_data["data"] = list( + filter(lambda model: "gpt" in model["id"], response_data["data"]) + ) + + return response_data + except Exception as e: + print(e) + error_detail = "Ollama WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"External: {res['error']}" + except: + error_detail = f"External: {e}" + + raise HTTPException(status_code=r.status_code, detail=error_detail) diff --git a/backend/config.py b/backend/config.py index 8e100fe5..16e7eeb2 100644 --- a/backend/config.py +++ b/backend/config.py @@ -27,11 +27,22 @@ if ENV == "prod": if OLLAMA_API_BASE_URL == "/ollama/api": OLLAMA_API_BASE_URL = "http://host.docker.internal:11434/api" + +#################################### +# OPENAI_API +#################################### + +OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") +OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "") + +if OPENAI_API_BASE_URL == "": + OPENAI_API_BASE_URL = "https://api.openai.com/v1" + #################################### # WEBUI_VERSION #################################### -WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.42") +WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.50") #################################### # WEBUI_AUTH (Required for security) diff --git a/backend/constants.py b/backend/constants.py index 0817445b..e51ecdda 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -33,4 +33,5 @@ class ERROR_MESSAGES(str, Enum): ) NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/" + API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature." MALICIOUS = "Unusual activities detected, please try again in a few minutes." diff --git a/backend/main.py b/backend/main.py index 5e3b7e83..a97e6e85 100644 --- a/backend/main.py +++ b/backend/main.py @@ -6,6 +6,8 @@ from fastapi.middleware.cors import CORSMiddleware from starlette.exceptions import HTTPException as StarletteHTTPException from apps.ollama.main import app as ollama_app +from apps.openai.main import app as openai_app + from apps.web.main import app as webui_app import time @@ -46,7 +48,7 @@ async def check_url(request: Request, call_next): app.mount("/api/v1", webui_app) -# app.mount("/ollama/api", WSGIMiddleware(ollama_app)) app.mount("/ollama/api", ollama_app) +app.mount("/openai/api", openai_app) app.mount("/", SPAStaticFiles(directory="../build", html=True), name="spa-static-files") diff --git a/example.env b/example.env index 9c628b42..74d52223 100644 --- a/example.env +++ b/example.env @@ -1,12 +1,6 @@ -# If you're serving both the frontend and backend (Recommended) -# Set the public API base URL for seamless communication -PUBLIC_API_BASE_URL='/ollama/api' - -# If you're serving only the frontend (Not recommended and not fully supported) -# Comment above and Uncomment below -# You can use the default value or specify a custom path, e.g., '/api' -# PUBLIC_API_BASE_URL='http://{location.hostname}:11434/api' - # Ollama URL for the backend to connect # The path '/ollama/api' will be redirected to the specified backend URL OLLAMA_API_BASE_URL='http://localhost:11434/api' + +OPENAI_API_BASE_URL='' +OPENAI_API_KEY='' \ No newline at end of file diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index c144ae89..dcd92710 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -1,4 +1,176 @@ -export const getOpenAIModels = async ( +import { OPENAI_API_BASE_URL } from '$lib/constants'; + +export const getOpenAIUrl = async (token: string = '') => { + let error = null; + + const res = await fetch(`${OPENAI_API_BASE_URL}/url`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.OPENAI_API_BASE_URL; +}; + +export const updateOpenAIUrl = async (token: string = '', url: string) => { + let error = null; + + const res = await fetch(`${OPENAI_API_BASE_URL}/url/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + url: url + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.OPENAI_API_BASE_URL; +}; + +export const getOpenAIKey = async (token: string = '') => { + let error = null; + + const res = await fetch(`${OPENAI_API_BASE_URL}/key`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.OPENAI_API_KEY; +}; + +export const updateOpenAIKey = async (token: string = '', key: string) => { + let error = null; + + const res = await fetch(`${OPENAI_API_BASE_URL}/key/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + key: key + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.OPENAI_API_KEY; +}; + +export const getOpenAIModels = async (token: string = '') => { + let error = null; + + const res = await fetch(`${OPENAI_API_BASE_URL}/models`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`; + return []; + }); + + if (error) { + throw error; + } + + const models = Array.isArray(res) ? res : res?.data ?? null; + + return models + ? models + .map((model) => ({ name: model.id, external: true })) + .sort((a, b) => { + return a.name.localeCompare(b.name); + }) + : models; +}; + +export const getOpenAIModelsDirect = async ( base_url: string = 'https://api.openai.com/v1', api_key: string = '' ) => { @@ -34,3 +206,26 @@ export const getOpenAIModels = async ( return a.name.localeCompare(b.name); }); }; + +export const generateOpenAIChatCompletion = async (token: string = '', body: object) => { + let error = null; + + const res = await fetch(`${OPENAI_API_BASE_URL}/chat/completions`, { + method: 'POST', + headers: { + Authorization: `Bearer ${token}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify(body) + }).catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/chat/Messages/Placeholder.svelte b/src/lib/components/chat/Messages/Placeholder.svelte index 1d709b5f..59b47ba4 100644 --- a/src/lib/components/chat/Messages/Placeholder.svelte +++ b/src/lib/components/chat/Messages/Placeholder.svelte @@ -27,7 +27,7 @@ > {#if model in modelfiles} modelfileModels - {/if} - + + {/if}