forked from open-webui/open-webui
feat: terminate request on user stop
This commit is contained in:
parent
684bdf5151
commit
442e3d978a
4 changed files with 170 additions and 86 deletions
|
@ -5,6 +5,7 @@ from fastapi.concurrency import run_in_threadpool
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import json
|
import json
|
||||||
|
import uuid
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from apps.web.models.users import Users
|
from apps.web.models.users import Users
|
||||||
|
@ -26,6 +27,9 @@ app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
|
||||||
# TARGET_SERVER_URL = OLLAMA_API_BASE_URL
|
# TARGET_SERVER_URL = OLLAMA_API_BASE_URL
|
||||||
|
|
||||||
|
|
||||||
|
REQUEST_POOL = []
|
||||||
|
|
||||||
|
|
||||||
@app.get("/url")
|
@app.get("/url")
|
||||||
async def get_ollama_api_url(user=Depends(get_current_user)):
|
async def get_ollama_api_url(user=Depends(get_current_user)):
|
||||||
if user and user.role == "admin":
|
if user and user.role == "admin":
|
||||||
|
@ -49,6 +53,16 @@ async def update_ollama_api_url(
|
||||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/cancel/{request_id}")
|
||||||
|
async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
|
||||||
|
if user:
|
||||||
|
if request_id in REQUEST_POOL:
|
||||||
|
REQUEST_POOL.remove(request_id)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||||
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
|
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
|
||||||
target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
|
target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
|
||||||
|
@ -74,7 +88,27 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
|
||||||
|
|
||||||
def get_request():
|
def get_request():
|
||||||
nonlocal r
|
nonlocal r
|
||||||
|
|
||||||
|
request_id = str(uuid.uuid4())
|
||||||
try:
|
try:
|
||||||
|
REQUEST_POOL.append(request_id)
|
||||||
|
|
||||||
|
def stream_content():
|
||||||
|
try:
|
||||||
|
if path in ["chat"]:
|
||||||
|
yield json.dumps({"id": request_id, "done": False}) + "\n"
|
||||||
|
|
||||||
|
for chunk in r.iter_content(chunk_size=8192):
|
||||||
|
if request_id in REQUEST_POOL:
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
print("User: canceled request")
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
if hasattr(r, "close"):
|
||||||
|
r.close()
|
||||||
|
REQUEST_POOL.remove(request_id)
|
||||||
|
|
||||||
r = requests.request(
|
r = requests.request(
|
||||||
method=request.method,
|
method=request.method,
|
||||||
url=target_url,
|
url=target_url,
|
||||||
|
@ -85,8 +119,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
|
||||||
|
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
|
# r.close()
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
r.iter_content(chunk_size=8192),
|
stream_content(),
|
||||||
status_code=r.status_code,
|
status_code=r.status_code,
|
||||||
headers=dict(r.headers),
|
headers=dict(r.headers),
|
||||||
)
|
)
|
||||||
|
|
|
@ -206,9 +206,11 @@ export const generatePrompt = async (token: string = '', model: string, conversa
|
||||||
};
|
};
|
||||||
|
|
||||||
export const generateChatCompletion = async (token: string = '', body: object) => {
|
export const generateChatCompletion = async (token: string = '', body: object) => {
|
||||||
|
let controller = new AbortController();
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
const res = await fetch(`${OLLAMA_API_BASE_URL}/chat`, {
|
const res = await fetch(`${OLLAMA_API_BASE_URL}/chat`, {
|
||||||
|
signal: controller.signal,
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'text/event-stream',
|
'Content-Type': 'text/event-stream',
|
||||||
|
@ -224,6 +226,27 @@ export const generateChatCompletion = async (token: string = '', body: object) =
|
||||||
throw error;
|
throw error;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return [res, controller];
|
||||||
|
};
|
||||||
|
|
||||||
|
export const cancelChatCompletion = async (token: string = '', requestId: string) => {
|
||||||
|
let error = null;
|
||||||
|
|
||||||
|
const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, {
|
||||||
|
method: 'GET',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'text/event-stream',
|
||||||
|
Authorization: `Bearer ${token}`
|
||||||
|
}
|
||||||
|
}).catch((err) => {
|
||||||
|
error = err;
|
||||||
|
return null;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores';
|
import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores';
|
||||||
import { copyToClipboard, splitStream } from '$lib/utils';
|
import { copyToClipboard, splitStream } from '$lib/utils';
|
||||||
|
|
||||||
import { generateChatCompletion, generateTitle } from '$lib/apis/ollama';
|
import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama';
|
||||||
import { createNewChat, getChatList, updateChatById } from '$lib/apis/chats';
|
import { createNewChat, getChatList, updateChatById } from '$lib/apis/chats';
|
||||||
import { queryVectorDB } from '$lib/apis/rag';
|
import { queryVectorDB } from '$lib/apis/rag';
|
||||||
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
|
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
|
||||||
|
@ -24,6 +24,8 @@
|
||||||
let autoScroll = true;
|
let autoScroll = true;
|
||||||
let processing = '';
|
let processing = '';
|
||||||
|
|
||||||
|
let currentRequestId = null;
|
||||||
|
|
||||||
let selectedModels = [''];
|
let selectedModels = [''];
|
||||||
|
|
||||||
let selectedModelfile = null;
|
let selectedModelfile = null;
|
||||||
|
@ -279,7 +281,7 @@
|
||||||
// Scroll down
|
// Scroll down
|
||||||
window.scrollTo({ top: document.body.scrollHeight });
|
window.scrollTo({ top: document.body.scrollHeight });
|
||||||
|
|
||||||
const res = await generateChatCompletion(localStorage.token, {
|
const [res, controller] = await generateChatCompletion(localStorage.token, {
|
||||||
model: model,
|
model: model,
|
||||||
messages: [
|
messages: [
|
||||||
$settings.system
|
$settings.system
|
||||||
|
@ -307,6 +309,8 @@
|
||||||
});
|
});
|
||||||
|
|
||||||
if (res && res.ok) {
|
if (res && res.ok) {
|
||||||
|
console.log('controller', controller);
|
||||||
|
|
||||||
const reader = res.body
|
const reader = res.body
|
||||||
.pipeThrough(new TextDecoderStream())
|
.pipeThrough(new TextDecoderStream())
|
||||||
.pipeThrough(splitStream('\n'))
|
.pipeThrough(splitStream('\n'))
|
||||||
|
@ -317,6 +321,14 @@
|
||||||
if (done || stopResponseFlag || _chatId !== $chatId) {
|
if (done || stopResponseFlag || _chatId !== $chatId) {
|
||||||
responseMessage.done = true;
|
responseMessage.done = true;
|
||||||
messages = messages;
|
messages = messages;
|
||||||
|
|
||||||
|
if (stopResponseFlag) {
|
||||||
|
controller.abort('User: Stop Response');
|
||||||
|
await cancelChatCompletion(localStorage.token, currentRequestId);
|
||||||
|
}
|
||||||
|
|
||||||
|
currentRequestId = null;
|
||||||
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -332,52 +344,57 @@
|
||||||
throw data;
|
throw data;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (data.done == false) {
|
if ('id' in data) {
|
||||||
if (responseMessage.content == '' && data.message.content == '\n') {
|
console.log(data);
|
||||||
continue;
|
currentRequestId = data.id;
|
||||||
} else {
|
|
||||||
responseMessage.content += data.message.content;
|
|
||||||
messages = messages;
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
responseMessage.done = true;
|
if (data.done == false) {
|
||||||
|
if (responseMessage.content == '' && data.message.content == '\n') {
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
responseMessage.content += data.message.content;
|
||||||
|
messages = messages;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
responseMessage.done = true;
|
||||||
|
|
||||||
if (responseMessage.content == '') {
|
if (responseMessage.content == '') {
|
||||||
responseMessage.error = true;
|
responseMessage.error = true;
|
||||||
responseMessage.content =
|
responseMessage.content =
|
||||||
'Oops! No text generated from Ollama, Please try again.';
|
'Oops! No text generated from Ollama, Please try again.';
|
||||||
}
|
}
|
||||||
|
|
||||||
responseMessage.context = data.context ?? null;
|
responseMessage.context = data.context ?? null;
|
||||||
responseMessage.info = {
|
responseMessage.info = {
|
||||||
total_duration: data.total_duration,
|
total_duration: data.total_duration,
|
||||||
load_duration: data.load_duration,
|
load_duration: data.load_duration,
|
||||||
sample_count: data.sample_count,
|
sample_count: data.sample_count,
|
||||||
sample_duration: data.sample_duration,
|
sample_duration: data.sample_duration,
|
||||||
prompt_eval_count: data.prompt_eval_count,
|
prompt_eval_count: data.prompt_eval_count,
|
||||||
prompt_eval_duration: data.prompt_eval_duration,
|
prompt_eval_duration: data.prompt_eval_duration,
|
||||||
eval_count: data.eval_count,
|
eval_count: data.eval_count,
|
||||||
eval_duration: data.eval_duration
|
eval_duration: data.eval_duration
|
||||||
};
|
};
|
||||||
messages = messages;
|
messages = messages;
|
||||||
|
|
||||||
if ($settings.notificationEnabled && !document.hasFocus()) {
|
if ($settings.notificationEnabled && !document.hasFocus()) {
|
||||||
const notification = new Notification(
|
const notification = new Notification(
|
||||||
selectedModelfile
|
selectedModelfile
|
||||||
? `${
|
? `${
|
||||||
selectedModelfile.title.charAt(0).toUpperCase() +
|
selectedModelfile.title.charAt(0).toUpperCase() +
|
||||||
selectedModelfile.title.slice(1)
|
selectedModelfile.title.slice(1)
|
||||||
}`
|
}`
|
||||||
: `Ollama - ${model}`,
|
: `Ollama - ${model}`,
|
||||||
{
|
{
|
||||||
body: responseMessage.content,
|
body: responseMessage.content,
|
||||||
icon: selectedModelfile?.imageUrl ?? '/favicon.png'
|
icon: selectedModelfile?.imageUrl ?? '/favicon.png'
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if ($settings.responseAutoCopy) {
|
if ($settings.responseAutoCopy) {
|
||||||
copyToClipboard(responseMessage.content);
|
copyToClipboard(responseMessage.content);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -297,7 +297,7 @@
|
||||||
// Scroll down
|
// Scroll down
|
||||||
window.scrollTo({ top: document.body.scrollHeight });
|
window.scrollTo({ top: document.body.scrollHeight });
|
||||||
|
|
||||||
const res = await generateChatCompletion(localStorage.token, {
|
const [res, controller] = await generateChatCompletion(localStorage.token, {
|
||||||
model: model,
|
model: model,
|
||||||
messages: [
|
messages: [
|
||||||
$settings.system
|
$settings.system
|
||||||
|
@ -335,6 +335,10 @@
|
||||||
if (done || stopResponseFlag || _chatId !== $chatId) {
|
if (done || stopResponseFlag || _chatId !== $chatId) {
|
||||||
responseMessage.done = true;
|
responseMessage.done = true;
|
||||||
messages = messages;
|
messages = messages;
|
||||||
|
|
||||||
|
if (stopResponseFlag) {
|
||||||
|
controller.abort('User: Stop Response');
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -350,52 +354,56 @@
|
||||||
throw data;
|
throw data;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (data.done == false) {
|
if ('id' in data) {
|
||||||
if (responseMessage.content == '' && data.message.content == '\n') {
|
console.log(data);
|
||||||
continue;
|
|
||||||
} else {
|
|
||||||
responseMessage.content += data.message.content;
|
|
||||||
messages = messages;
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
responseMessage.done = true;
|
if (data.done == false) {
|
||||||
|
if (responseMessage.content == '' && data.message.content == '\n') {
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
responseMessage.content += data.message.content;
|
||||||
|
messages = messages;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
responseMessage.done = true;
|
||||||
|
|
||||||
if (responseMessage.content == '') {
|
if (responseMessage.content == '') {
|
||||||
responseMessage.error = true;
|
responseMessage.error = true;
|
||||||
responseMessage.content =
|
responseMessage.content =
|
||||||
'Oops! No text generated from Ollama, Please try again.';
|
'Oops! No text generated from Ollama, Please try again.';
|
||||||
}
|
}
|
||||||
|
|
||||||
responseMessage.context = data.context ?? null;
|
responseMessage.context = data.context ?? null;
|
||||||
responseMessage.info = {
|
responseMessage.info = {
|
||||||
total_duration: data.total_duration,
|
total_duration: data.total_duration,
|
||||||
load_duration: data.load_duration,
|
load_duration: data.load_duration,
|
||||||
sample_count: data.sample_count,
|
sample_count: data.sample_count,
|
||||||
sample_duration: data.sample_duration,
|
sample_duration: data.sample_duration,
|
||||||
prompt_eval_count: data.prompt_eval_count,
|
prompt_eval_count: data.prompt_eval_count,
|
||||||
prompt_eval_duration: data.prompt_eval_duration,
|
prompt_eval_duration: data.prompt_eval_duration,
|
||||||
eval_count: data.eval_count,
|
eval_count: data.eval_count,
|
||||||
eval_duration: data.eval_duration
|
eval_duration: data.eval_duration
|
||||||
};
|
};
|
||||||
messages = messages;
|
messages = messages;
|
||||||
|
|
||||||
if ($settings.notificationEnabled && !document.hasFocus()) {
|
if ($settings.notificationEnabled && !document.hasFocus()) {
|
||||||
const notification = new Notification(
|
const notification = new Notification(
|
||||||
selectedModelfile
|
selectedModelfile
|
||||||
? `${
|
? `${
|
||||||
selectedModelfile.title.charAt(0).toUpperCase() +
|
selectedModelfile.title.charAt(0).toUpperCase() +
|
||||||
selectedModelfile.title.slice(1)
|
selectedModelfile.title.slice(1)
|
||||||
}`
|
}`
|
||||||
: `Ollama - ${model}`,
|
: `Ollama - ${model}`,
|
||||||
{
|
{
|
||||||
body: responseMessage.content,
|
body: responseMessage.content,
|
||||||
icon: selectedModelfile?.imageUrl ?? '/favicon.png'
|
icon: selectedModelfile?.imageUrl ?? '/favicon.png'
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if ($settings.responseAutoCopy) {
|
if ($settings.responseAutoCopy) {
|
||||||
copyToClipboard(responseMessage.content);
|
copyToClipboard(responseMessage.content);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue