forked from open-webui/open-webui

feat: terminate request on user stop

parent 684bdf5151
commit 442e3d978a

4 changed files with 170 additions and 86 deletions
@@ -5,6 +5,7 @@ from fastapi.concurrency import run_in_threadpool

import requests
import json
import uuid
from pydantic import BaseModel

from apps.web.models.users import Users

@@ -26,6 +27,9 @@ app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
# TARGET_SERVER_URL = OLLAMA_API_BASE_URL


REQUEST_POOL = []


@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
    if user and user.role == "admin":

@@ -49,6 +53,16 @@ async def update_ollama_api_url(
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)


@app.get("/cancel/{request_id}")
async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
    if user:
        if request_id in REQUEST_POOL:
            REQUEST_POOL.remove(request_id)
        return True
    else:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"

@@ -74,7 +88,27 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):

    def get_request():
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    if path in ["chat"]:
                        yield json.dumps({"id": request_id, "done": False}) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            print("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                        REQUEST_POOL.remove(request_id)

            r = requests.request(
                method=request.method,
                url=target_url,

@@ -85,8 +119,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):

            r.raise_for_status()

            # r.close()

            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
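Taken together, the backend changes define a small cancellation protocol: every proxied request is registered in REQUEST_POOL under a fresh UUID, responses on the chat path are prefixed with a one-line JSON object carrying that id, and GET /cancel/{request_id} removes the id so that stream_content() stops forwarding chunks and closes the upstream connection. The sketch below shows how a client might consume that stream and trigger a cancel; it is illustrative only, and the apiBaseUrl parameter and the shouldStop callback are assumptions, not part of this commit.

```typescript
// Sketch (not from the commit): consume the proxied /chat stream and cancel it
// both client-side (AbortController) and server-side (GET /cancel/{request_id}).
const streamChat = async (
	apiBaseUrl: string, // assumed to point at the Ollama proxy mounted by this app
	token: string,
	body: object,
	shouldStop: () => boolean
) => {
	const controller = new AbortController();
	const res = await fetch(`${apiBaseUrl}/chat`, {
		signal: controller.signal,
		method: 'POST',
		headers: { Authorization: `Bearer ${token}` },
		body: JSON.stringify(body)
	});
	if (!res.ok || !res.body) throw new Error(`chat request failed: ${res.status}`);

	const reader = res.body.pipeThrough(new TextDecoderStream()).getReader();
	let requestId: string | null = null;
	let buffer = '';

	while (true) {
		const { value, done } = await reader.read();
		if (done) break;
		buffer += value;

		// The proxy emits newline-delimited JSON; the first object is
		// {"id": "<uuid>", "done": false}, everything after it is a normal Ollama chunk.
		const lines = buffer.split('\n');
		buffer = lines.pop() ?? '';
		for (const line of lines.filter((l) => l !== '')) {
			const data = JSON.parse(line);
			if ('id' in data) {
				requestId = data.id;
			} else if (!data.done) {
				console.log(data.message?.content ?? '');
			}
		}

		if (shouldStop()) {
			controller.abort(); // stop the client-side fetch
			if (requestId) {
				// Removing the id from REQUEST_POOL makes stream_content() break out of
				// its loop and close the upstream requests.Response on the server.
				await fetch(`${apiBaseUrl}/cancel/${requestId}`, {
					headers: { Authorization: `Bearer ${token}` }
				});
			}
			break;
		}
	}
};
```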
@@ -206,9 +206,11 @@ export const generatePrompt = async (token: string = '', model: string, conversa
};

export const generateChatCompletion = async (token: string = '', body: object) => {
	let controller = new AbortController();
	let error = null;

	const res = await fetch(`${OLLAMA_API_BASE_URL}/chat`, {
		signal: controller.signal,
		method: 'POST',
		headers: {
			'Content-Type': 'text/event-stream',

@@ -224,6 +226,27 @@ export const generateChatCompletion = async (token: string = '', body: object) =
		throw error;
	}

	return [res, controller];
};

export const cancelChatCompletion = async (token: string = '', requestId: string) => {
	let error = null;

	const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, {
		method: 'GET',
		headers: {
			'Content-Type': 'text/event-stream',
			Authorization: `Bearer ${token}`
		}
	}).catch((err) => {
		error = err;
		return null;
	});

	if (error) {
		throw error;
	}

	return res;
};
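generateChatCompletion now returns both the streaming Response and the AbortController that owns its fetch, and cancelChatCompletion wraps the new GET /cancel/{request_id} endpoint. A minimal usage sketch, assuming the caller captures the request id from the first JSON line of the stream (the startGeneration wrapper, the model name, and the messages below are placeholders, not part of this module):

```typescript
// Illustrative wrapper only; model and messages are placeholder values.
import { generateChatCompletion, cancelChatCompletion } from '$lib/apis/ollama';

const startGeneration = async (token: string) => {
	const [res, controller] = await generateChatCompletion(token, {
		model: 'llama2',
		messages: [{ role: 'user', content: 'Hello!' }]
	});

	// The request id arrives as the first newline-delimited JSON object in the
	// stream; the caller stores it so the stop handler can also cancel server-side.
	let requestId: string | null = null;

	const stopGeneration = async () => {
		controller.abort('User: Stop Response'); // abort the client-side fetch
		if (requestId) {
			await cancelChatCompletion(token, requestId); // drop the id from REQUEST_POOL
		}
	};

	return { res, setRequestId: (id: string) => (requestId = id), stopGeneration };
};
```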
@@ -9,7 +9,7 @@
	import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores';
	import { copyToClipboard, splitStream } from '$lib/utils';

	import { generateChatCompletion, generateTitle } from '$lib/apis/ollama';
	import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama';
	import { createNewChat, getChatList, updateChatById } from '$lib/apis/chats';
	import { queryVectorDB } from '$lib/apis/rag';
	import { generateOpenAIChatCompletion } from '$lib/apis/openai';

@@ -24,6 +24,8 @@
	let autoScroll = true;
	let processing = '';

	let currentRequestId = null;

	let selectedModels = [''];

	let selectedModelfile = null;

@@ -279,7 +281,7 @@
		// Scroll down
		window.scrollTo({ top: document.body.scrollHeight });

		const res = await generateChatCompletion(localStorage.token, {
		const [res, controller] = await generateChatCompletion(localStorage.token, {
			model: model,
			messages: [
				$settings.system

@@ -307,6 +309,8 @@
		});

		if (res && res.ok) {
			console.log('controller', controller);

			const reader = res.body
				.pipeThrough(new TextDecoderStream())
				.pipeThrough(splitStream('\n'))

@@ -317,6 +321,14 @@
				if (done || stopResponseFlag || _chatId !== $chatId) {
					responseMessage.done = true;
					messages = messages;

					if (stopResponseFlag) {
						controller.abort('User: Stop Response');
						await cancelChatCompletion(localStorage.token, currentRequestId);
					}

					currentRequestId = null;

					break;
				}

@@ -332,6 +344,10 @@
								throw data;
							}

							if ('id' in data) {
								console.log(data);
								currentRequestId = data.id;
							} else {
								if (data.done == false) {
									if (responseMessage.content == '' && data.message.content == '\n') {
										continue;

@@ -382,6 +398,7 @@
								}
							}
						}
					}
				} catch (error) {
					console.log(error);
					if ('detail' in error) {
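In the chat component the pieces come together as follows: the stop button flips stopResponseFlag, the next iteration of the read loop aborts the fetch, calls cancelChatCompletion with the currentRequestId captured from the stream's first JSON object, and then resets the id. A condensed sketch of that stop path, assuming the reader has already been piped through TextDecoderStream and splitStream('\n') as in the component (the stopResponse handler and the readStream signature are illustrative; the real loop also updates the message history and chat title):

```typescript
// Condensed sketch of the stop path, not a drop-in replacement for the component.
import { cancelChatCompletion } from '$lib/apis/ollama';

let stopResponseFlag = false;
let currentRequestId: string | null = null;

// Assumed to be bound to the UI's "stop responding" button.
const stopResponse = () => {
	stopResponseFlag = true;
};

const readStream = async (
	reader: ReadableStreamDefaultReader<string>,
	controller: AbortController,
	token: string
) => {
	while (true) {
		const { value, done } = await reader.read();
		if (done || stopResponseFlag) {
			if (stopResponseFlag) {
				controller.abort('User: Stop Response'); // client-side cancel
				if (currentRequestId) {
					// server-side cancel: remove the id from REQUEST_POOL
					await cancelChatCompletion(token, currentRequestId);
				}
			}
			currentRequestId = null;
			break;
		}

		for (const line of (value ?? '').split('\n').filter((l) => l !== '')) {
			const data = JSON.parse(line);
			if ('id' in data) {
				currentRequestId = data.id; // first object carries the request id
			}
			// ...otherwise append data.message.content to the response message...
		}
	}
};
```

The second chat page below wires up the same loop but only aborts the fetch client-side; it does not track currentRequestId or call cancelChatCompletion.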
@@ -297,7 +297,7 @@
		// Scroll down
		window.scrollTo({ top: document.body.scrollHeight });

		const res = await generateChatCompletion(localStorage.token, {
		const [res, controller] = await generateChatCompletion(localStorage.token, {
			model: model,
			messages: [
				$settings.system

@@ -335,6 +335,10 @@
				if (done || stopResponseFlag || _chatId !== $chatId) {
					responseMessage.done = true;
					messages = messages;

					if (stopResponseFlag) {
						controller.abort('User: Stop Response');
					}
					break;
				}

@@ -350,6 +354,9 @@
								throw data;
							}

							if ('id' in data) {
								console.log(data);
							} else {
								if (data.done == false) {
									if (responseMessage.content == '' && data.message.content == '\n') {
										continue;

@@ -400,6 +407,7 @@
								}
							}
						}
					}
				} catch (error) {
					console.log(error);
					if ('detail' in error) {
Timothy J. Baek