forked from open-webui/open-webui
		
	feat: terminate request on user stop
This commit is contained in:
		
							parent
							
								
									684bdf5151
								
							
						
					
					
						commit
						442e3d978a
					
				
					 4 changed files with 170 additions and 86 deletions
				
			
		|  | @ -5,6 +5,7 @@ from fastapi.concurrency import run_in_threadpool | ||||||
| 
 | 
 | ||||||
| import requests | import requests | ||||||
| import json | import json | ||||||
|  | import uuid | ||||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||||
| 
 | 
 | ||||||
| from apps.web.models.users import Users | from apps.web.models.users import Users | ||||||
|  | @ -26,6 +27,9 @@ app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL | ||||||
| # TARGET_SERVER_URL = OLLAMA_API_BASE_URL | # TARGET_SERVER_URL = OLLAMA_API_BASE_URL | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | REQUEST_POOL = [] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @app.get("/url") | @app.get("/url") | ||||||
| async def get_ollama_api_url(user=Depends(get_current_user)): | async def get_ollama_api_url(user=Depends(get_current_user)): | ||||||
|     if user and user.role == "admin": |     if user and user.role == "admin": | ||||||
|  | @ -49,6 +53,16 @@ async def update_ollama_api_url( | ||||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) |         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @app.get("/cancel/{request_id}") | ||||||
|  | async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)): | ||||||
|  |     if user: | ||||||
|  |         if request_id in REQUEST_POOL: | ||||||
|  |             REQUEST_POOL.remove(request_id) | ||||||
|  |         return True | ||||||
|  |     else: | ||||||
|  |         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | ||||||
| async def proxy(path: str, request: Request, user=Depends(get_current_user)): | async def proxy(path: str, request: Request, user=Depends(get_current_user)): | ||||||
|     target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" |     target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" | ||||||
|  | @ -74,7 +88,27 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): | ||||||
| 
 | 
 | ||||||
|     def get_request(): |     def get_request(): | ||||||
|         nonlocal r |         nonlocal r | ||||||
|  | 
 | ||||||
|  |         request_id = str(uuid.uuid4()) | ||||||
|         try: |         try: | ||||||
|  |             REQUEST_POOL.append(request_id) | ||||||
|  | 
 | ||||||
|  |             def stream_content(): | ||||||
|  |                 try: | ||||||
|  |                     if path in ["chat"]: | ||||||
|  |                         yield json.dumps({"id": request_id, "done": False}) + "\n" | ||||||
|  | 
 | ||||||
|  |                     for chunk in r.iter_content(chunk_size=8192): | ||||||
|  |                         if request_id in REQUEST_POOL: | ||||||
|  |                             yield chunk | ||||||
|  |                         else: | ||||||
|  |                             print("User: canceled request") | ||||||
|  |                             break | ||||||
|  |                 finally: | ||||||
|  |                     if hasattr(r, "close"): | ||||||
|  |                         r.close() | ||||||
|  |                         REQUEST_POOL.remove(request_id) | ||||||
|  | 
 | ||||||
|             r = requests.request( |             r = requests.request( | ||||||
|                 method=request.method, |                 method=request.method, | ||||||
|                 url=target_url, |                 url=target_url, | ||||||
|  | @ -85,8 +119,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): | ||||||
| 
 | 
 | ||||||
|             r.raise_for_status() |             r.raise_for_status() | ||||||
| 
 | 
 | ||||||
|  |             # r.close() | ||||||
|  | 
 | ||||||
|             return StreamingResponse( |             return StreamingResponse( | ||||||
|                 r.iter_content(chunk_size=8192), |                 stream_content(), | ||||||
|                 status_code=r.status_code, |                 status_code=r.status_code, | ||||||
|                 headers=dict(r.headers), |                 headers=dict(r.headers), | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  | @ -206,9 +206,11 @@ export const generatePrompt = async (token: string = '', model: string, conversa | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| export const generateChatCompletion = async (token: string = '', body: object) => { | export const generateChatCompletion = async (token: string = '', body: object) => { | ||||||
|  | 	let controller = new AbortController(); | ||||||
| 	let error = null; | 	let error = null; | ||||||
| 
 | 
 | ||||||
| 	const res = await fetch(`${OLLAMA_API_BASE_URL}/chat`, { | 	const res = await fetch(`${OLLAMA_API_BASE_URL}/chat`, { | ||||||
|  | 		signal: controller.signal, | ||||||
| 		method: 'POST', | 		method: 'POST', | ||||||
| 		headers: { | 		headers: { | ||||||
| 			'Content-Type': 'text/event-stream', | 			'Content-Type': 'text/event-stream', | ||||||
|  | @ -224,6 +226,27 @@ export const generateChatCompletion = async (token: string = '', body: object) = | ||||||
| 		throw error; | 		throw error; | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	return [res, controller]; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | export const cancelChatCompletion = async (token: string = '', requestId: string) => { | ||||||
|  | 	let error = null; | ||||||
|  | 
 | ||||||
|  | 	const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, { | ||||||
|  | 		method: 'GET', | ||||||
|  | 		headers: { | ||||||
|  | 			'Content-Type': 'text/event-stream', | ||||||
|  | 			Authorization: `Bearer ${token}` | ||||||
|  | 		} | ||||||
|  | 	}).catch((err) => { | ||||||
|  | 		error = err; | ||||||
|  | 		return null; | ||||||
|  | 	}); | ||||||
|  | 
 | ||||||
|  | 	if (error) { | ||||||
|  | 		throw error; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	return res; | 	return res; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -9,7 +9,7 @@ | ||||||
| 	import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores'; | 	import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores'; | ||||||
| 	import { copyToClipboard, splitStream } from '$lib/utils'; | 	import { copyToClipboard, splitStream } from '$lib/utils'; | ||||||
| 
 | 
 | ||||||
| 	import { generateChatCompletion, generateTitle } from '$lib/apis/ollama'; | 	import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama'; | ||||||
| 	import { createNewChat, getChatList, updateChatById } from '$lib/apis/chats'; | 	import { createNewChat, getChatList, updateChatById } from '$lib/apis/chats'; | ||||||
| 	import { queryVectorDB } from '$lib/apis/rag'; | 	import { queryVectorDB } from '$lib/apis/rag'; | ||||||
| 	import { generateOpenAIChatCompletion } from '$lib/apis/openai'; | 	import { generateOpenAIChatCompletion } from '$lib/apis/openai'; | ||||||
|  | @ -24,6 +24,8 @@ | ||||||
| 	let autoScroll = true; | 	let autoScroll = true; | ||||||
| 	let processing = ''; | 	let processing = ''; | ||||||
| 
 | 
 | ||||||
|  | 	let currentRequestId = null; | ||||||
|  | 
 | ||||||
| 	let selectedModels = ['']; | 	let selectedModels = ['']; | ||||||
| 
 | 
 | ||||||
| 	let selectedModelfile = null; | 	let selectedModelfile = null; | ||||||
|  | @ -279,7 +281,7 @@ | ||||||
| 		// Scroll down | 		// Scroll down | ||||||
| 		window.scrollTo({ top: document.body.scrollHeight }); | 		window.scrollTo({ top: document.body.scrollHeight }); | ||||||
| 
 | 
 | ||||||
| 		const res = await generateChatCompletion(localStorage.token, { | 		const [res, controller] = await generateChatCompletion(localStorage.token, { | ||||||
| 			model: model, | 			model: model, | ||||||
| 			messages: [ | 			messages: [ | ||||||
| 				$settings.system | 				$settings.system | ||||||
|  | @ -307,6 +309,8 @@ | ||||||
| 		}); | 		}); | ||||||
| 
 | 
 | ||||||
| 		if (res && res.ok) { | 		if (res && res.ok) { | ||||||
|  | 			console.log('controller', controller); | ||||||
|  | 
 | ||||||
| 			const reader = res.body | 			const reader = res.body | ||||||
| 				.pipeThrough(new TextDecoderStream()) | 				.pipeThrough(new TextDecoderStream()) | ||||||
| 				.pipeThrough(splitStream('\n')) | 				.pipeThrough(splitStream('\n')) | ||||||
|  | @ -317,6 +321,14 @@ | ||||||
| 				if (done || stopResponseFlag || _chatId !== $chatId) { | 				if (done || stopResponseFlag || _chatId !== $chatId) { | ||||||
| 					responseMessage.done = true; | 					responseMessage.done = true; | ||||||
| 					messages = messages; | 					messages = messages; | ||||||
|  | 
 | ||||||
|  | 					if (stopResponseFlag) { | ||||||
|  | 						controller.abort('User: Stop Response'); | ||||||
|  | 						await cancelChatCompletion(localStorage.token, currentRequestId); | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | 					currentRequestId = null; | ||||||
|  | 
 | ||||||
| 					break; | 					break; | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
|  | @ -332,52 +344,57 @@ | ||||||
| 								throw data; | 								throw data; | ||||||
| 							} | 							} | ||||||
| 
 | 
 | ||||||
| 							if (data.done == false) { | 							if ('id' in data) { | ||||||
| 								if (responseMessage.content == '' && data.message.content == '\n') { | 								console.log(data); | ||||||
| 									continue; | 								currentRequestId = data.id; | ||||||
| 								} else { |  | ||||||
| 									responseMessage.content += data.message.content; |  | ||||||
| 									messages = messages; |  | ||||||
| 								} |  | ||||||
| 							} else { | 							} else { | ||||||
| 								responseMessage.done = true; | 								if (data.done == false) { | ||||||
|  | 									if (responseMessage.content == '' && data.message.content == '\n') { | ||||||
|  | 										continue; | ||||||
|  | 									} else { | ||||||
|  | 										responseMessage.content += data.message.content; | ||||||
|  | 										messages = messages; | ||||||
|  | 									} | ||||||
|  | 								} else { | ||||||
|  | 									responseMessage.done = true; | ||||||
| 
 | 
 | ||||||
| 								if (responseMessage.content == '') { | 									if (responseMessage.content == '') { | ||||||
| 									responseMessage.error = true; | 										responseMessage.error = true; | ||||||
| 									responseMessage.content = | 										responseMessage.content = | ||||||
| 										'Oops! No text generated from Ollama, Please try again.'; | 											'Oops! No text generated from Ollama, Please try again.'; | ||||||
| 								} | 									} | ||||||
| 
 | 
 | ||||||
| 								responseMessage.context = data.context ?? null; | 									responseMessage.context = data.context ?? null; | ||||||
| 								responseMessage.info = { | 									responseMessage.info = { | ||||||
| 									total_duration: data.total_duration, | 										total_duration: data.total_duration, | ||||||
| 									load_duration: data.load_duration, | 										load_duration: data.load_duration, | ||||||
| 									sample_count: data.sample_count, | 										sample_count: data.sample_count, | ||||||
| 									sample_duration: data.sample_duration, | 										sample_duration: data.sample_duration, | ||||||
| 									prompt_eval_count: data.prompt_eval_count, | 										prompt_eval_count: data.prompt_eval_count, | ||||||
| 									prompt_eval_duration: data.prompt_eval_duration, | 										prompt_eval_duration: data.prompt_eval_duration, | ||||||
| 									eval_count: data.eval_count, | 										eval_count: data.eval_count, | ||||||
| 									eval_duration: data.eval_duration | 										eval_duration: data.eval_duration | ||||||
| 								}; | 									}; | ||||||
| 								messages = messages; | 									messages = messages; | ||||||
| 
 | 
 | ||||||
| 								if ($settings.notificationEnabled && !document.hasFocus()) { | 									if ($settings.notificationEnabled && !document.hasFocus()) { | ||||||
| 									const notification = new Notification( | 										const notification = new Notification( | ||||||
| 										selectedModelfile | 											selectedModelfile | ||||||
| 											? `${ | 												? `${ | ||||||
| 													selectedModelfile.title.charAt(0).toUpperCase() + | 														selectedModelfile.title.charAt(0).toUpperCase() + | ||||||
| 													selectedModelfile.title.slice(1) | 														selectedModelfile.title.slice(1) | ||||||
| 											  }` | 												  }` | ||||||
| 											: `Ollama - ${model}`, | 												: `Ollama - ${model}`, | ||||||
| 										{ | 											{ | ||||||
| 											body: responseMessage.content, | 												body: responseMessage.content, | ||||||
| 											icon: selectedModelfile?.imageUrl ?? '/favicon.png' | 												icon: selectedModelfile?.imageUrl ?? '/favicon.png' | ||||||
| 										} | 											} | ||||||
| 									); | 										); | ||||||
| 								} | 									} | ||||||
| 
 | 
 | ||||||
| 								if ($settings.responseAutoCopy) { | 									if ($settings.responseAutoCopy) { | ||||||
| 									copyToClipboard(responseMessage.content); | 										copyToClipboard(responseMessage.content); | ||||||
|  | 									} | ||||||
| 								} | 								} | ||||||
| 							} | 							} | ||||||
| 						} | 						} | ||||||
|  |  | ||||||
|  | @ -297,7 +297,7 @@ | ||||||
| 		// Scroll down | 		// Scroll down | ||||||
| 		window.scrollTo({ top: document.body.scrollHeight }); | 		window.scrollTo({ top: document.body.scrollHeight }); | ||||||
| 
 | 
 | ||||||
| 		const res = await generateChatCompletion(localStorage.token, { | 		const [res, controller] = await generateChatCompletion(localStorage.token, { | ||||||
| 			model: model, | 			model: model, | ||||||
| 			messages: [ | 			messages: [ | ||||||
| 				$settings.system | 				$settings.system | ||||||
|  | @ -335,6 +335,10 @@ | ||||||
| 				if (done || stopResponseFlag || _chatId !== $chatId) { | 				if (done || stopResponseFlag || _chatId !== $chatId) { | ||||||
| 					responseMessage.done = true; | 					responseMessage.done = true; | ||||||
| 					messages = messages; | 					messages = messages; | ||||||
|  | 
 | ||||||
|  | 					if (stopResponseFlag) { | ||||||
|  | 						controller.abort('User: Stop Response'); | ||||||
|  | 					} | ||||||
| 					break; | 					break; | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
|  | @ -350,52 +354,56 @@ | ||||||
| 								throw data; | 								throw data; | ||||||
| 							} | 							} | ||||||
| 
 | 
 | ||||||
| 							if (data.done == false) { | 							if ('id' in data) { | ||||||
| 								if (responseMessage.content == '' && data.message.content == '\n') { | 								console.log(data); | ||||||
| 									continue; |  | ||||||
| 								} else { |  | ||||||
| 									responseMessage.content += data.message.content; |  | ||||||
| 									messages = messages; |  | ||||||
| 								} |  | ||||||
| 							} else { | 							} else { | ||||||
| 								responseMessage.done = true; | 								if (data.done == false) { | ||||||
|  | 									if (responseMessage.content == '' && data.message.content == '\n') { | ||||||
|  | 										continue; | ||||||
|  | 									} else { | ||||||
|  | 										responseMessage.content += data.message.content; | ||||||
|  | 										messages = messages; | ||||||
|  | 									} | ||||||
|  | 								} else { | ||||||
|  | 									responseMessage.done = true; | ||||||
| 
 | 
 | ||||||
| 								if (responseMessage.content == '') { | 									if (responseMessage.content == '') { | ||||||
| 									responseMessage.error = true; | 										responseMessage.error = true; | ||||||
| 									responseMessage.content = | 										responseMessage.content = | ||||||
| 										'Oops! No text generated from Ollama, Please try again.'; | 											'Oops! No text generated from Ollama, Please try again.'; | ||||||
| 								} | 									} | ||||||
| 
 | 
 | ||||||
| 								responseMessage.context = data.context ?? null; | 									responseMessage.context = data.context ?? null; | ||||||
| 								responseMessage.info = { | 									responseMessage.info = { | ||||||
| 									total_duration: data.total_duration, | 										total_duration: data.total_duration, | ||||||
| 									load_duration: data.load_duration, | 										load_duration: data.load_duration, | ||||||
| 									sample_count: data.sample_count, | 										sample_count: data.sample_count, | ||||||
| 									sample_duration: data.sample_duration, | 										sample_duration: data.sample_duration, | ||||||
| 									prompt_eval_count: data.prompt_eval_count, | 										prompt_eval_count: data.prompt_eval_count, | ||||||
| 									prompt_eval_duration: data.prompt_eval_duration, | 										prompt_eval_duration: data.prompt_eval_duration, | ||||||
| 									eval_count: data.eval_count, | 										eval_count: data.eval_count, | ||||||
| 									eval_duration: data.eval_duration | 										eval_duration: data.eval_duration | ||||||
| 								}; | 									}; | ||||||
| 								messages = messages; | 									messages = messages; | ||||||
| 
 | 
 | ||||||
| 								if ($settings.notificationEnabled && !document.hasFocus()) { | 									if ($settings.notificationEnabled && !document.hasFocus()) { | ||||||
| 									const notification = new Notification( | 										const notification = new Notification( | ||||||
| 										selectedModelfile | 											selectedModelfile | ||||||
| 											? `${ | 												? `${ | ||||||
| 													selectedModelfile.title.charAt(0).toUpperCase() + | 														selectedModelfile.title.charAt(0).toUpperCase() + | ||||||
| 													selectedModelfile.title.slice(1) | 														selectedModelfile.title.slice(1) | ||||||
| 											  }` | 												  }` | ||||||
| 											: `Ollama - ${model}`, | 												: `Ollama - ${model}`, | ||||||
| 										{ | 											{ | ||||||
| 											body: responseMessage.content, | 												body: responseMessage.content, | ||||||
| 											icon: selectedModelfile?.imageUrl ?? '/favicon.png' | 												icon: selectedModelfile?.imageUrl ?? '/favicon.png' | ||||||
| 										} | 											} | ||||||
| 									); | 										); | ||||||
| 								} | 									} | ||||||
| 
 | 
 | ||||||
| 								if ($settings.responseAutoCopy) { | 									if ($settings.responseAutoCopy) { | ||||||
| 									copyToClipboard(responseMessage.content); | 										copyToClipboard(responseMessage.content); | ||||||
|  | 									} | ||||||
| 								} | 								} | ||||||
| 							} | 							} | ||||||
| 						} | 						} | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek