feat: terminate request on user stop

This commit is contained in:
Timothy J. Baek 2024-01-17 19:19:44 -08:00
parent 684bdf5151
commit 442e3d978a
4 changed files with 170 additions and 86 deletions

View file

@ -5,6 +5,7 @@ from fastapi.concurrency import run_in_threadpool
import requests import requests
import json import json
import uuid
from pydantic import BaseModel from pydantic import BaseModel
from apps.web.models.users import Users from apps.web.models.users import Users
@ -26,6 +27,9 @@ app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
# TARGET_SERVER_URL = OLLAMA_API_BASE_URL # TARGET_SERVER_URL = OLLAMA_API_BASE_URL
REQUEST_POOL = []
@app.get("/url") @app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)): async def get_ollama_api_url(user=Depends(get_current_user)):
if user and user.role == "admin": if user and user.role == "admin":
@ -49,6 +53,16 @@ async def update_ollama_api_url(
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.get("/cancel/{request_id}")
async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
    """Cancel an in-flight Ollama proxy request by removing its id from REQUEST_POOL.

    The streaming worker checks REQUEST_POOL membership between chunks and
    stops yielding once the id is gone, so removal here terminates the stream.

    Returns:
        True if the id was found and removed, False if it was already gone
        (stream finished or id unknown) -- previously this path returned None.

    Raises:
        HTTPException: 401 when no authenticated user is present.
    """
    if user:
        try:
            # Single atomic remove() instead of `in` + remove(): the threadpool
            # worker also removes the id when the stream completes, so a
            # check-then-act pair could raise ValueError in a race.
            REQUEST_POOL.remove(request_id)
            return True
        except ValueError:
            # Nothing to cancel; report explicitly rather than returning None.
            return False
    else:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)): async def proxy(path: str, request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
@ -74,7 +88,27 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
def get_request(): def get_request():
nonlocal r nonlocal r
request_id = str(uuid.uuid4())
try: try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if path in ["chat"]:
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
print("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
REQUEST_POOL.remove(request_id)
r = requests.request( r = requests.request(
method=request.method, method=request.method,
url=target_url, url=target_url,
@ -85,8 +119,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
r.raise_for_status() r.raise_for_status()
# r.close()
return StreamingResponse( return StreamingResponse(
r.iter_content(chunk_size=8192), stream_content(),
status_code=r.status_code, status_code=r.status_code,
headers=dict(r.headers), headers=dict(r.headers),
) )

View file

@ -206,9 +206,11 @@ export const generatePrompt = async (token: string = '', model: string, conversa
}; };
export const generateChatCompletion = async (token: string = '', body: object) => { export const generateChatCompletion = async (token: string = '', body: object) => {
let controller = new AbortController();
let error = null; let error = null;
const res = await fetch(`${OLLAMA_API_BASE_URL}/chat`, { const res = await fetch(`${OLLAMA_API_BASE_URL}/chat`, {
signal: controller.signal,
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',
@ -224,6 +226,27 @@ export const generateChatCompletion = async (token: string = '', body: object) =
throw error; throw error;
} }
return [res, controller];
};
export const cancelChatCompletion = async (token: string = '', requestId: string) => {
let error = null;
const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, {
method: 'GET',
headers: {
'Content-Type': 'text/event-stream',
Authorization: `Bearer ${token}`
}
}).catch((err) => {
error = err;
return null;
});
if (error) {
throw error;
}
return res; return res;
}; };

View file

@ -9,7 +9,7 @@
import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores'; import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores';
import { copyToClipboard, splitStream } from '$lib/utils'; import { copyToClipboard, splitStream } from '$lib/utils';
import { generateChatCompletion, generateTitle } from '$lib/apis/ollama'; import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama';
import { createNewChat, getChatList, updateChatById } from '$lib/apis/chats'; import { createNewChat, getChatList, updateChatById } from '$lib/apis/chats';
import { queryVectorDB } from '$lib/apis/rag'; import { queryVectorDB } from '$lib/apis/rag';
import { generateOpenAIChatCompletion } from '$lib/apis/openai'; import { generateOpenAIChatCompletion } from '$lib/apis/openai';
@ -24,6 +24,8 @@
let autoScroll = true; let autoScroll = true;
let processing = ''; let processing = '';
let currentRequestId = null;
let selectedModels = ['']; let selectedModels = [''];
let selectedModelfile = null; let selectedModelfile = null;
@ -279,7 +281,7 @@
// Scroll down // Scroll down
window.scrollTo({ top: document.body.scrollHeight }); window.scrollTo({ top: document.body.scrollHeight });
const res = await generateChatCompletion(localStorage.token, { const [res, controller] = await generateChatCompletion(localStorage.token, {
model: model, model: model,
messages: [ messages: [
$settings.system $settings.system
@ -307,6 +309,8 @@
}); });
if (res && res.ok) { if (res && res.ok) {
console.log('controller', controller);
const reader = res.body const reader = res.body
.pipeThrough(new TextDecoderStream()) .pipeThrough(new TextDecoderStream())
.pipeThrough(splitStream('\n')) .pipeThrough(splitStream('\n'))
@ -317,6 +321,14 @@
if (done || stopResponseFlag || _chatId !== $chatId) { if (done || stopResponseFlag || _chatId !== $chatId) {
responseMessage.done = true; responseMessage.done = true;
messages = messages; messages = messages;
if (stopResponseFlag) {
controller.abort('User: Stop Response');
await cancelChatCompletion(localStorage.token, currentRequestId);
}
currentRequestId = null;
break; break;
} }
@ -332,52 +344,57 @@
throw data; throw data;
} }
if (data.done == false) { if ('id' in data) {
if (responseMessage.content == '' && data.message.content == '\n') { console.log(data);
continue; currentRequestId = data.id;
} else {
responseMessage.content += data.message.content;
messages = messages;
}
} else { } else {
responseMessage.done = true; if (data.done == false) {
if (responseMessage.content == '' && data.message.content == '\n') {
continue;
} else {
responseMessage.content += data.message.content;
messages = messages;
}
} else {
responseMessage.done = true;
if (responseMessage.content == '') { if (responseMessage.content == '') {
responseMessage.error = true; responseMessage.error = true;
responseMessage.content = responseMessage.content =
'Oops! No text generated from Ollama, Please try again.'; 'Oops! No text generated from Ollama, Please try again.';
} }
responseMessage.context = data.context ?? null; responseMessage.context = data.context ?? null;
responseMessage.info = { responseMessage.info = {
total_duration: data.total_duration, total_duration: data.total_duration,
load_duration: data.load_duration, load_duration: data.load_duration,
sample_count: data.sample_count, sample_count: data.sample_count,
sample_duration: data.sample_duration, sample_duration: data.sample_duration,
prompt_eval_count: data.prompt_eval_count, prompt_eval_count: data.prompt_eval_count,
prompt_eval_duration: data.prompt_eval_duration, prompt_eval_duration: data.prompt_eval_duration,
eval_count: data.eval_count, eval_count: data.eval_count,
eval_duration: data.eval_duration eval_duration: data.eval_duration
}; };
messages = messages; messages = messages;
if ($settings.notificationEnabled && !document.hasFocus()) { if ($settings.notificationEnabled && !document.hasFocus()) {
const notification = new Notification( const notification = new Notification(
selectedModelfile selectedModelfile
? `${ ? `${
selectedModelfile.title.charAt(0).toUpperCase() + selectedModelfile.title.charAt(0).toUpperCase() +
selectedModelfile.title.slice(1) selectedModelfile.title.slice(1)
}` }`
: `Ollama - ${model}`, : `Ollama - ${model}`,
{ {
body: responseMessage.content, body: responseMessage.content,
icon: selectedModelfile?.imageUrl ?? '/favicon.png' icon: selectedModelfile?.imageUrl ?? '/favicon.png'
} }
); );
} }
if ($settings.responseAutoCopy) { if ($settings.responseAutoCopy) {
copyToClipboard(responseMessage.content); copyToClipboard(responseMessage.content);
}
} }
} }
} }

View file

@ -297,7 +297,7 @@
// Scroll down // Scroll down
window.scrollTo({ top: document.body.scrollHeight }); window.scrollTo({ top: document.body.scrollHeight });
const res = await generateChatCompletion(localStorage.token, { const [res, controller] = await generateChatCompletion(localStorage.token, {
model: model, model: model,
messages: [ messages: [
$settings.system $settings.system
@ -335,6 +335,10 @@
if (done || stopResponseFlag || _chatId !== $chatId) { if (done || stopResponseFlag || _chatId !== $chatId) {
responseMessage.done = true; responseMessage.done = true;
messages = messages; messages = messages;
if (stopResponseFlag) {
controller.abort('User: Stop Response');
}
break; break;
} }
@ -350,52 +354,56 @@
throw data; throw data;
} }
if (data.done == false) { if ('id' in data) {
if (responseMessage.content == '' && data.message.content == '\n') { console.log(data);
continue;
} else {
responseMessage.content += data.message.content;
messages = messages;
}
} else { } else {
responseMessage.done = true; if (data.done == false) {
if (responseMessage.content == '' && data.message.content == '\n') {
continue;
} else {
responseMessage.content += data.message.content;
messages = messages;
}
} else {
responseMessage.done = true;
if (responseMessage.content == '') { if (responseMessage.content == '') {
responseMessage.error = true; responseMessage.error = true;
responseMessage.content = responseMessage.content =
'Oops! No text generated from Ollama, Please try again.'; 'Oops! No text generated from Ollama, Please try again.';
} }
responseMessage.context = data.context ?? null; responseMessage.context = data.context ?? null;
responseMessage.info = { responseMessage.info = {
total_duration: data.total_duration, total_duration: data.total_duration,
load_duration: data.load_duration, load_duration: data.load_duration,
sample_count: data.sample_count, sample_count: data.sample_count,
sample_duration: data.sample_duration, sample_duration: data.sample_duration,
prompt_eval_count: data.prompt_eval_count, prompt_eval_count: data.prompt_eval_count,
prompt_eval_duration: data.prompt_eval_duration, prompt_eval_duration: data.prompt_eval_duration,
eval_count: data.eval_count, eval_count: data.eval_count,
eval_duration: data.eval_duration eval_duration: data.eval_duration
}; };
messages = messages; messages = messages;
if ($settings.notificationEnabled && !document.hasFocus()) { if ($settings.notificationEnabled && !document.hasFocus()) {
const notification = new Notification( const notification = new Notification(
selectedModelfile selectedModelfile
? `${ ? `${
selectedModelfile.title.charAt(0).toUpperCase() + selectedModelfile.title.charAt(0).toUpperCase() +
selectedModelfile.title.slice(1) selectedModelfile.title.slice(1)
}` }`
: `Ollama - ${model}`, : `Ollama - ${model}`,
{ {
body: responseMessage.content, body: responseMessage.content,
icon: selectedModelfile?.imageUrl ?? '/favicon.png' icon: selectedModelfile?.imageUrl ?? '/favicon.png'
} }
); );
} }
if ($settings.responseAutoCopy) { if ($settings.responseAutoCopy) {
copyToClipboard(responseMessage.content); copyToClipboard(responseMessage.content);
}
} }
} }
} }