From 8d34324d1211715d94338ade5d7694662ade67be Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 5 Mar 2024 01:41:22 -0800 Subject: [PATCH] feat: cancel request Resolves #1006 --- backend/apps/ollama/main.py | 102 ++++++++++++++++++++++- src/routes/(app)/playground/+page.svelte | 87 ++++++------------- 2 files changed, 126 insertions(+), 63 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index fbaf622b..988d4bf4 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -3,7 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from fastapi.concurrency import run_in_threadpool -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict import random import requests @@ -684,7 +684,7 @@ class GenerateChatCompletionForm(BaseModel): @app.post("/api/chat") @app.post("/api/chat/{url_idx}") -async def generate_completion( +async def generate_chat_completion( form_data: GenerateChatCompletionForm, url_idx: Optional[int] = None, user=Depends(get_current_user), @@ -765,6 +765,104 @@ async def generate_completion( ) +# TODO: we should update this part once Ollama supports other types +class OpenAIChatMessage(BaseModel): + role: str + content: str + + model_config = ConfigDict(extra="allow") + + +class OpenAIChatCompletionForm(BaseModel): + model: str + messages: List[OpenAIChatMessage] + + model_config = ConfigDict(extra="allow") + + +@app.post("/v1/chat/completions") +@app.post("/v1/chat/completions/{url_idx}") +async def generate_openai_chat_completion( + form_data: OpenAIChatCompletionForm, + url_idx: Optional[int] = None, + user=Depends(get_current_user), +): + + if url_idx == None: + if form_data.model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[form_data.model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + + r = None + + def get_request(): + nonlocal form_data + nonlocal r + + request_id = str(uuid.uuid4()) + try: + REQUEST_POOL.append(request_id) + + def stream_content(): + try: + if form_data.stream: + yield json.dumps( + {"request_id": request_id, "done": False} + ) + "\n" + + for chunk in r.iter_content(chunk_size=8192): + if request_id in REQUEST_POOL: + yield chunk + else: + print("User: canceled request") + break + finally: + if hasattr(r, "close"): + r.close() + if request_id in REQUEST_POOL: + REQUEST_POOL.remove(request_id) + + r = requests.request( + method="POST", + url=f"{url}/v1/chat/completions", + data=form_data.model_dump_json(exclude_none=True), + stream=True, + ) + + r.raise_for_status() + + return StreamingResponse( + stream_content(), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + + try: + return await run_in_threadpool(get_request) + except Exception as e: + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise HTTPException( + status_code=r.status_code if r else 500, + detail=error_detail, + ) + + @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)): url = app.state.OLLAMA_BASE_URLS[0] diff --git a/src/routes/(app)/playground/+page.svelte b/src/routes/(app)/playground/+page.svelte index 81c3032b..36f4a2bc 100644 --- a/src/routes/(app)/playground/+page.svelte +++ b/src/routes/(app)/playground/+page.svelte @@ -26,7 +26,7 @@ let selectedModelId = ''; let loading = false; - let currentRequestId; + let currentRequestId = null; let stopResponseFlag = false; let messagesContainerElement: HTMLDivElement; @@ -92,6 +92,10 @@ while (true) { const { value, done } = await reader.read(); if (done || stopResponseFlag) { + if (stopResponseFlag) { + await cancelChatCompletion(localStorage.token, currentRequestId); + } + currentRequestId = null; break; } @@ -108,7 +112,11 @@ let data = JSON.parse(line.replace(/^data: /, '')); console.log(data); - text += data.choices[0].delta.content ?? ''; + if ('request_id' in data) { + currentRequestId = data.request_id; + } else { + text += data.choices[0].delta.content ?? ''; + } } } } @@ -146,16 +154,6 @@ : `${OLLAMA_API_BASE_URL}/v1` ); - // const [res, controller] = await generateChatCompletion(localStorage.token, { - // model: selectedModelId, - // messages: [ - // { - // role: 'assistant', - // content: text - // } - // ] - // }); - let responseMessage; if (messages.at(-1)?.role === 'assistant') { responseMessage = messages.at(-1); @@ -180,6 +178,11 @@ while (true) { const { value, done } = await reader.read(); if (done || stopResponseFlag) { + if (stopResponseFlag) { + await cancelChatCompletion(localStorage.token, currentRequestId); + } + + currentRequestId = null; break; } @@ -196,17 +199,21 @@ let data = JSON.parse(line.replace(/^data: /, '')); console.log(data); - if (responseMessage.content == '' && data.choices[0].delta.content == '\n') { - continue; + if ('request_id' in data) { + currentRequestId = data.request_id; } else { - textareaElement.style.height = textareaElement.scrollHeight + 'px'; + if (responseMessage.content == '' && data.choices[0].delta.content == '\n') { + continue; + } else { + textareaElement.style.height = textareaElement.scrollHeight + 'px'; - responseMessage.content += data.choices[0].delta.content ?? ''; - messages = messages; + responseMessage.content += data.choices[0].delta.content ?? ''; + messages = messages; - textareaElement.style.height = textareaElement.scrollHeight + 'px'; + textareaElement.style.height = textareaElement.scrollHeight + 'px'; - await tick(); + await tick(); + } } } } @@ -217,48 +224,6 @@ scrollToBottom(); } - - // while (true) { - // const { value, done } = await reader.read(); - // if (done || stopResponseFlag) { - // if (stopResponseFlag) { - // await cancelChatCompletion(localStorage.token, currentRequestId); - // } - - // currentRequestId = null; - // break; - // } - - // try { - // let lines = value.split('\n'); - - // for (const line of lines) { - // if (line !== '') { - // console.log(line); - // let data = JSON.parse(line); - - // if ('detail' in data) { - // throw data; - // } - - // if ('id' in data) { - // console.log(data); - // currentRequestId = data.id; - // } else { - // if (data.done == false) { - // text += data.message.content; - // } else { - // console.log('done'); - // } - // } - // } - // } - // } catch (error) { - // console.log(error); - // } - - // scrollToBottom(); - // } } };