forked from open-webui/open-webui
		
	feat: async reverse proxy
This commit is contained in:
		
							parent
							
								
									76139fc8df
								
							
						
					
					
						commit
						47dc3b5fb2
					
				
					 1 changed files with 52 additions and 25 deletions
				
			
		|  | @ -11,6 +11,8 @@ from constants import ERROR_MESSAGES | ||||||
| from utils.utils import decode_token, get_current_user | from utils.utils import decode_token, get_current_user | ||||||
| from config import OLLAMA_API_BASE_URL, WEBUI_AUTH | from config import OLLAMA_API_BASE_URL, WEBUI_AUTH | ||||||
| 
 | 
 | ||||||
|  | import aiohttp | ||||||
|  | 
 | ||||||
| app = FastAPI() | app = FastAPI() | ||||||
| app.add_middleware( | app.add_middleware( | ||||||
|     CORSMiddleware, |     CORSMiddleware, | ||||||
|  | @ -30,8 +32,7 @@ async def get_ollama_api_url(user=Depends(get_current_user)): | ||||||
|     if user and user.role == "admin": |     if user and user.role == "admin": | ||||||
|         return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} |         return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} | ||||||
|     else: |     else: | ||||||
|         raise HTTPException(status_code=401, |         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||||
|                             detail=ERROR_MESSAGES.ACCESS_PROHIBITED) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class UrlUpdateForm(BaseModel): | class UrlUpdateForm(BaseModel): | ||||||
|  | @ -39,14 +40,29 @@ class UrlUpdateForm(BaseModel): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.post("/url/update") | @app.post("/url/update") | ||||||
| async def update_ollama_api_url(form_data: UrlUpdateForm, | async def update_ollama_api_url( | ||||||
|                                 user=Depends(get_current_user)): |     form_data: UrlUpdateForm, user=Depends(get_current_user) | ||||||
|  | ): | ||||||
|     if user and user.role == "admin": |     if user and user.role == "admin": | ||||||
|         app.state.OLLAMA_API_BASE_URL = form_data.url |         app.state.OLLAMA_API_BASE_URL = form_data.url | ||||||
|         return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} |         return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} | ||||||
|     else: |     else: | ||||||
|         raise HTTPException(status_code=401, |         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||||
|                             detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | 
 | ||||||
|  | 
 | ||||||
|  | # async def fetch_sse(method, target_url, body, headers): | ||||||
|  | #     async with aiohttp.ClientSession() as session: | ||||||
|  | #         try: | ||||||
|  | #             async with session.request( | ||||||
|  | #                 method, target_url, data=body, headers=headers | ||||||
|  | #             ) as response: | ||||||
|  | #                 print(response.status) | ||||||
|  | #                 async for line in response.content: | ||||||
|  | #                     yield line | ||||||
|  | #         except Exception as e: | ||||||
|  | #             print(e) | ||||||
|  | #             error_detail = "Ollama WebUI: Server Connection Error" | ||||||
|  | #             yield json.dumps({"error": error_detail, "message": str(e)}).encode() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | ||||||
|  | @ -59,42 +75,53 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): | ||||||
|     if user.role in ["user", "admin"]: |     if user.role in ["user", "admin"]: | ||||||
|         if path in ["pull", "delete", "push", "copy", "create"]: |         if path in ["pull", "delete", "push", "copy", "create"]: | ||||||
|             if user.role != "admin": |             if user.role != "admin": | ||||||
|                 raise HTTPException(status_code=401, |                 raise HTTPException( | ||||||
|                                     detail=ERROR_MESSAGES.ACCESS_PROHIBITED) |                     status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED | ||||||
|  |                 ) | ||||||
|     else: |     else: | ||||||
|         raise HTTPException(status_code=401, |         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||||
|                             detail=ERROR_MESSAGES.ACCESS_PROHIBITED) |  | ||||||
| 
 | 
 | ||||||
|     headers.pop("Host", None) |     headers.pop("Host", None) | ||||||
|     headers.pop("Authorization", None) |     headers.pop("Authorization", None) | ||||||
|     headers.pop("Origin", None) |     headers.pop("Origin", None) | ||||||
|     headers.pop("Referer", None) |     headers.pop("Referer", None) | ||||||
| 
 | 
 | ||||||
|  |     session = aiohttp.ClientSession() | ||||||
|  |     response = None | ||||||
|     try: |     try: | ||||||
|         r = requests.request( |         response = await session.request( | ||||||
|             method=request.method, |             request.method, target_url, data=body, headers=headers | ||||||
|             url=target_url, |  | ||||||
|             data=body, |  | ||||||
|             headers=headers, |  | ||||||
|             stream=True, |  | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         r.raise_for_status() |         if not response.ok: | ||||||
|  |             data = await response.json() | ||||||
|  |             print(data) | ||||||
|  |             response.raise_for_status() | ||||||
|  | 
 | ||||||
|  |         async def gen(): | ||||||
|  |             async for line in response.content: | ||||||
|  |                 yield line | ||||||
|  |             await session.close() | ||||||
|  | 
 | ||||||
|  |         return StreamingResponse(gen(), response.status) | ||||||
| 
 | 
 | ||||||
|         return StreamingResponse( |  | ||||||
|             r.iter_content(chunk_size=8192), |  | ||||||
|             status_code=r.status_code, |  | ||||||
|             headers=dict(r.headers), |  | ||||||
|         ) |  | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(e) |         print(e) | ||||||
|         error_detail = "Ollama WebUI: Server Connection Error" |         error_detail = "Ollama WebUI: Server Connection Error" | ||||||
|         if r is not None: |         if response is not None: | ||||||
|             try: |             try: | ||||||
|                 res = r.json() |                 res = await response.json() | ||||||
|                 if "error" in res: |                 if "error" in res: | ||||||
|                     error_detail = f"Ollama: {res['error']}" |                     error_detail = f"Ollama: {res['error']}" | ||||||
|             except: |             except: | ||||||
|                 error_detail = f"Ollama: {e}" |                 error_detail = f"Ollama: {e}" | ||||||
| 
 | 
 | ||||||
|         raise HTTPException(status_code=r.status_code, detail=error_detail) |         await session.close() | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=response.status if response else 500, | ||||||
|  |             detail=error_detail, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # print(e) | ||||||
|  |         # error_detail = "Ollama WebUI: Server Connection Error" | ||||||
|  |         # return {"error": error_detail, "message": str(e)} | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek