forked from open-webui/open-webui
		
	feat: terminate request on user stop
This commit is contained in:
		
							parent
							
								
									684bdf5151
								
							
						
					
					
						commit
						442e3d978a
					
				
					 4 changed files with 170 additions and 86 deletions
				
			
		|  | @ -5,6 +5,7 @@ from fastapi.concurrency import run_in_threadpool | |||
| 
 | ||||
| import requests | ||||
| import json | ||||
| import uuid | ||||
| from pydantic import BaseModel | ||||
| 
 | ||||
| from apps.web.models.users import Users | ||||
|  | @ -26,6 +27,9 @@ app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL | |||
| # TARGET_SERVER_URL = OLLAMA_API_BASE_URL | ||||
| 
 | ||||
| 
 | ||||
| REQUEST_POOL = [] | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/url") | ||||
| async def get_ollama_api_url(user=Depends(get_current_user)): | ||||
|     if user and user.role == "admin": | ||||
|  | @ -49,6 +53,16 @@ async def update_ollama_api_url( | |||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/cancel/{request_id}") | ||||
| async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)): | ||||
|     if user: | ||||
|         if request_id in REQUEST_POOL: | ||||
|             REQUEST_POOL.remove(request_id) | ||||
|         return True | ||||
|     else: | ||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||
| 
 | ||||
| 
 | ||||
| @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | ||||
| async def proxy(path: str, request: Request, user=Depends(get_current_user)): | ||||
|     target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" | ||||
|  | @ -74,7 +88,27 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): | |||
| 
 | ||||
|     def get_request(): | ||||
|         nonlocal r | ||||
| 
 | ||||
|         request_id = str(uuid.uuid4()) | ||||
|         try: | ||||
|             REQUEST_POOL.append(request_id) | ||||
| 
 | ||||
|             def stream_content(): | ||||
|                 try: | ||||
|                     if path in ["chat"]: | ||||
|                         yield json.dumps({"id": request_id, "done": False}) + "\n" | ||||
| 
 | ||||
|                     for chunk in r.iter_content(chunk_size=8192): | ||||
|                         if request_id in REQUEST_POOL: | ||||
|                             yield chunk | ||||
|                         else: | ||||
|                             print("User: canceled request") | ||||
|                             break | ||||
|                 finally: | ||||
|                     if hasattr(r, "close"): | ||||
|                         r.close() | ||||
|                         REQUEST_POOL.remove(request_id) | ||||
| 
 | ||||
|             r = requests.request( | ||||
|                 method=request.method, | ||||
|                 url=target_url, | ||||
|  | @ -85,8 +119,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): | |||
| 
 | ||||
|             r.raise_for_status() | ||||
| 
 | ||||
|             # r.close() | ||||
| 
 | ||||
|             return StreamingResponse( | ||||
|                 r.iter_content(chunk_size=8192), | ||||
|                 stream_content(), | ||||
|                 status_code=r.status_code, | ||||
|                 headers=dict(r.headers), | ||||
|             ) | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek