diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index bc797f08..7965ff32 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -11,7 +11,7 @@ from pydantic import BaseModel from apps.web.models.users import Users from constants import ERROR_MESSAGES from utils.utils import decode_token, get_current_user, get_admin_user -from config import OLLAMA_API_BASE_URL, WEBUI_AUTH +from config import OLLAMA_BASE_URL, WEBUI_AUTH app = FastAPI() app.add_middleware( @@ -22,7 +22,7 @@ app.add_middleware( allow_headers=["*"], ) -app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL +app.state.OLLAMA_BASE_URL = OLLAMA_BASE_URL # TARGET_SERVER_URL = OLLAMA_API_BASE_URL @@ -32,7 +32,7 @@ REQUEST_POOL = [] @app.get("/url") async def get_ollama_api_url(user=Depends(get_admin_user)): - return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} + return {"OLLAMA_BASE_URL": app.state.OLLAMA_BASE_URL} class UrlUpdateForm(BaseModel): @@ -41,8 +41,8 @@ class UrlUpdateForm(BaseModel): @app.post("/url/update") async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): - app.state.OLLAMA_API_BASE_URL = form_data.url - return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} + app.state.OLLAMA_BASE_URL = form_data.url + return {"OLLAMA_BASE_URL": app.state.OLLAMA_BASE_URL} @app.get("/cancel/{request_id}") @@ -57,7 +57,7 @@ async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)) @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def proxy(path: str, request: Request, user=Depends(get_current_user)): - target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" + target_url = f"{app.state.OLLAMA_BASE_URL}/{path}" body = await request.body() headers = dict(request.headers) @@ -91,7 +91,13 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): def stream_content(): try: - if path in ["chat"]: + if path == "generate": + data = json.loads(body.decode("utf-8")) + + if not ("stream" in data and data["stream"] == False): + yield json.dumps({"id": request_id, "done": False}) + "\n" + + elif path == "chat": yield json.dumps({"id": request_id, "done": False}) + "\n" for chunk in r.iter_content(chunk_size=8192): @@ -103,7 +109,8 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): finally: if hasattr(r, "close"): r.close() - REQUEST_POOL.remove(request_id) + if request_id in REQUEST_POOL: + REQUEST_POOL.remove(request_id) r = requests.request( method=request.method, diff --git a/backend/config.py b/backend/config.py index df24b97e..cd1a2702 100644 --- a/backend/config.py +++ b/backend/config.py @@ -211,6 +211,17 @@ if ENV == "prod": if OLLAMA_API_BASE_URL == "/ollama/api": OLLAMA_API_BASE_URL = "http://host.docker.internal:11434/api" + +OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "") + +if OLLAMA_BASE_URL == "": + OLLAMA_BASE_URL = ( + OLLAMA_API_BASE_URL[:-4] + if OLLAMA_API_BASE_URL.endswith("/api") + else OLLAMA_API_BASE_URL + ) + + #################################### # OPENAI_API #################################### diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index 5fc8a5fe..0c96b2ab 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -29,7 +29,7 @@ export const getOllamaAPIUrl = async (token: string = '') => { throw error; } - return res.OLLAMA_API_BASE_URL; + return res.OLLAMA_BASE_URL; }; export const updateOllamaAPIUrl = async (token: string = '', url: string) => { @@ -64,13 +64,13 @@ export const updateOllamaAPIUrl = async (token: string = '', url: string) => { throw error; } - return res.OLLAMA_API_BASE_URL; + return res.OLLAMA_BASE_URL; }; export const getOllamaVersion = async (token: string = '') => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/version`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/version`, { method: 'GET', headers: { Accept: 'application/json', @@ -102,7 +102,7 @@ export const getOllamaVersion = async (token: string = '') => { export const getOllamaModels = async (token: string = '') => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/tags`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/tags`, { method: 'GET', headers: { Accept: 'application/json', @@ -148,7 +148,7 @@ export const generateTitle = async ( console.log(template); - const res = await fetch(`${OLLAMA_API_BASE_URL}/generate`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, { method: 'POST', headers: { 'Content-Type': 'text/event-stream', @@ -186,7 +186,7 @@ export const generatePrompt = async (token: string = '', model: string, conversa conversation = '[no existing conversation]'; } - const res = await fetch(`${OLLAMA_API_BASE_URL}/generate`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, { method: 'POST', headers: { 'Content-Type': 'text/event-stream', @@ -217,11 +217,37 @@ export const generatePrompt = async (token: string = '', model: string, conversa return res; }; +export const generateTextCompletion = async (token: string = '', model: string, text: string) => { + let error = null; + + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, { + method: 'POST', + headers: { + 'Content-Type': 'text/event-stream', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + model: model, + prompt: text, + stream: true + }) + }).catch((err) => { + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const generateChatCompletion = async (token: string = '', body: object) => { let controller = new AbortController(); let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/chat`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/chat`, { signal: controller.signal, method: 'POST', headers: { @@ -265,7 +291,7 @@ export const cancelChatCompletion = async (token: string = '', requestId: string export const createModel = async (token: string, tagName: string, content: string) => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/create`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/create`, { method: 'POST', headers: { 'Content-Type': 'text/event-stream', @@ -290,7 +316,7 @@ export const createModel = async (token: string, tagName: string, content: strin export const deleteModel = async (token: string, tagName: string) => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/delete`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/delete`, { method: 'DELETE', headers: { 'Content-Type': 'text/event-stream', @@ -324,7 +350,7 @@ export const deleteModel = async (token: string, tagName: string) => { export const pullModel = async (token: string, tagName: string) => { let error = null; - const res = await fetch(`${OLLAMA_API_BASE_URL}/pull`, { + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull`, { method: 'POST', headers: { 'Content-Type': 'text/event-stream', diff --git a/src/lib/components/chat/MessageInput/Models.svelte b/src/lib/components/chat/MessageInput/Models.svelte index cedd1d5b..2c364810 100644 --- a/src/lib/components/chat/MessageInput/Models.svelte +++ b/src/lib/components/chat/MessageInput/Models.svelte @@ -79,14 +79,18 @@ throw data; } - if (data.done == false) { - if (prompt == '' && data.response == '\n') { - continue; - } else { - prompt += data.response; - console.log(data.response); - chatInputElement.scrollTop = chatInputElement.scrollHeight; - await tick(); + if ('id' in data) { + console.log(data); + } else { + if (data.done == false) { + if (prompt == '' && data.response == '\n') { + continue; + } else { + prompt += data.response; + console.log(data.response); + chatInputElement.scrollTop = chatInputElement.scrollHeight; + await tick(); + } } } } diff --git a/src/lib/components/chat/Settings/Connections.svelte b/src/lib/components/chat/Settings/Connections.svelte index fc0de9a1..55c38358 100644 --- a/src/lib/components/chat/Settings/Connections.svelte +++ b/src/lib/components/chat/Settings/Connections.svelte @@ -114,12 +114,12 @@
-
Ollama API URL
+
Ollama Base URL
diff --git a/src/lib/components/layout/Sidebar.svelte b/src/lib/components/layout/Sidebar.svelte index 6201e077..1061a1ab 100644 --- a/src/lib/components/layout/Sidebar.svelte +++ b/src/lib/components/layout/Sidebar.svelte @@ -594,6 +594,32 @@
Admin Panel
+ + {/if} +
+ +
+