forked from open-webui/open-webui
		
	fix: backend proxy
This commit is contained in:
		
							parent
							
								
									439185be80
								
							
						
					
					
						commit
						bb2971260d
					
				
					 3 changed files with 116 additions and 184 deletions
				
			
		|  | @ -12,10 +12,9 @@ RUN npm run build | ||||||
| 
 | 
 | ||||||
| FROM python:3.11-slim-buster as base | FROM python:3.11-slim-buster as base | ||||||
| 
 | 
 | ||||||
| ARG OLLAMA_API_BASE_URL='/ollama/api' |  | ||||||
| 
 |  | ||||||
| ENV ENV=prod | ENV ENV=prod | ||||||
| ENV OLLAMA_API_BASE_URL $OLLAMA_API_BASE_URL | 
 | ||||||
|  | ENV OLLAMA_API_BASE_URL "/ollama/api" | ||||||
| 
 | 
 | ||||||
| ENV OPENAI_API_BASE_URL "" | ENV OPENAI_API_BASE_URL "" | ||||||
| ENV OPENAI_API_KEY "" | ENV OPENAI_API_KEY "" | ||||||
|  |  | ||||||
|  | @ -1,6 +1,7 @@ | ||||||
| from fastapi import FastAPI, Request, Response, HTTPException, Depends | from fastapi import FastAPI, Request, Response, HTTPException, Depends | ||||||
| from fastapi.middleware.cors import CORSMiddleware | from fastapi.middleware.cors import CORSMiddleware | ||||||
| from fastapi.responses import StreamingResponse | from fastapi.responses import StreamingResponse | ||||||
|  | from fastapi.concurrency import run_in_threadpool | ||||||
| 
 | 
 | ||||||
| import requests | import requests | ||||||
| import json | import json | ||||||
|  | @ -11,8 +12,6 @@ from constants import ERROR_MESSAGES | ||||||
| from utils.utils import decode_token, get_current_user | from utils.utils import decode_token, get_current_user | ||||||
| from config import OLLAMA_API_BASE_URL, WEBUI_AUTH | from config import OLLAMA_API_BASE_URL, WEBUI_AUTH | ||||||
| 
 | 
 | ||||||
| import aiohttp |  | ||||||
| 
 |  | ||||||
| app = FastAPI() | app = FastAPI() | ||||||
| app.add_middleware( | app.add_middleware( | ||||||
|     CORSMiddleware, |     CORSMiddleware, | ||||||
|  | @ -50,25 +49,9 @@ async def update_ollama_api_url( | ||||||
|         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) |         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # async def fetch_sse(method, target_url, body, headers): |  | ||||||
| #     async with aiohttp.ClientSession() as session: |  | ||||||
| #         try: |  | ||||||
| #             async with session.request( |  | ||||||
| #                 method, target_url, data=body, headers=headers |  | ||||||
| #             ) as response: |  | ||||||
| #                 print(response.status) |  | ||||||
| #                 async for line in response.content: |  | ||||||
| #                     yield line |  | ||||||
| #         except Exception as e: |  | ||||||
| #             print(e) |  | ||||||
| #             error_detail = "Ollama WebUI: Server Connection Error" |  | ||||||
| #             yield json.dumps({"error": error_detail, "message": str(e)}).encode() |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | ||||||
| async def proxy(path: str, request: Request, user=Depends(get_current_user)): | async def proxy(path: str, request: Request, user=Depends(get_current_user)): | ||||||
|     target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" |     target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" | ||||||
|     print(target_url) |  | ||||||
| 
 | 
 | ||||||
|     body = await request.body() |     body = await request.body() | ||||||
|     headers = dict(request.headers) |     headers = dict(request.headers) | ||||||
|  | @ -87,41 +70,42 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): | ||||||
|     headers.pop("Origin", None) |     headers.pop("Origin", None) | ||||||
|     headers.pop("Referer", None) |     headers.pop("Referer", None) | ||||||
| 
 | 
 | ||||||
|     session = aiohttp.ClientSession() |     r = None | ||||||
|     response = None | 
 | ||||||
|  |     def get_request(): | ||||||
|  |         nonlocal r | ||||||
|  |         try: | ||||||
|  |             r = requests.request( | ||||||
|  |                 method=request.method, | ||||||
|  |                 url=target_url, | ||||||
|  |                 data=body, | ||||||
|  |                 headers=headers, | ||||||
|  |                 stream=True, | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             r.raise_for_status() | ||||||
|  | 
 | ||||||
|  |             return StreamingResponse( | ||||||
|  |                 r.iter_content(chunk_size=8192), | ||||||
|  |                 status_code=r.status_code, | ||||||
|  |                 headers=dict(r.headers), | ||||||
|  |             ) | ||||||
|  |         except Exception as e: | ||||||
|  |             raise e | ||||||
|  | 
 | ||||||
|     try: |     try: | ||||||
|         response = await session.request( |         return await run_in_threadpool(get_request) | ||||||
|             request.method, target_url, data=body, headers=headers |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         print(response) |  | ||||||
|         if not response.ok: |  | ||||||
|             data = await response.json() |  | ||||||
|             print(data) |  | ||||||
|             response.raise_for_status() |  | ||||||
| 
 |  | ||||||
|         async def generate(): |  | ||||||
|             async for line in response.content: |  | ||||||
|                 print(line) |  | ||||||
|                 yield line |  | ||||||
|             await session.close() |  | ||||||
| 
 |  | ||||||
|         return StreamingResponse(generate(), response.status) |  | ||||||
| 
 |  | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(e) |  | ||||||
|         error_detail = "Ollama WebUI: Server Connection Error" |         error_detail = "Ollama WebUI: Server Connection Error" | ||||||
| 
 |         if r is not None: | ||||||
|         if response is not None: |  | ||||||
|             try: |             try: | ||||||
|                 res = await response.json() |                 res = r.json() | ||||||
|                 if "error" in res: |                 if "error" in res: | ||||||
|                     error_detail = f"Ollama: {res['error']}" |                     error_detail = f"Ollama: {res['error']}" | ||||||
|             except: |             except: | ||||||
|                 error_detail = f"Ollama: {e}" |                 error_detail = f"Ollama: {e}" | ||||||
| 
 | 
 | ||||||
|         await session.close() |  | ||||||
|         raise HTTPException( |         raise HTTPException( | ||||||
|             status_code=response.status if response else 500, |             status_code=r.status_code if r else 500, | ||||||
|             detail=error_detail, |             detail=error_detail, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  | @ -1,178 +1,127 @@ | ||||||
| from flask import Flask, request, Response, jsonify | from fastapi import FastAPI, Request, Response, HTTPException, Depends | ||||||
| from flask_cors import CORS | from fastapi.middleware.cors import CORSMiddleware | ||||||
|  | from fastapi.responses import StreamingResponse | ||||||
| 
 | 
 | ||||||
| import requests | import requests | ||||||
| import json | import json | ||||||
|  | from pydantic import BaseModel | ||||||
| 
 | 
 | ||||||
| from apps.web.models.users import Users | from apps.web.models.users import Users | ||||||
| from constants import ERROR_MESSAGES | from constants import ERROR_MESSAGES | ||||||
| from utils.utils import decode_token | from utils.utils import decode_token, get_current_user | ||||||
| from config import OLLAMA_API_BASE_URL, WEBUI_AUTH | from config import OLLAMA_API_BASE_URL, WEBUI_AUTH | ||||||
| 
 | 
 | ||||||
| app = Flask(__name__) | import aiohttp | ||||||
| CORS( |  | ||||||
|     app |  | ||||||
| )  # Enable Cross-Origin Resource Sharing (CORS) to allow requests from different domains |  | ||||||
| 
 | 
 | ||||||
| # Define the target server URL | app = FastAPI() | ||||||
| TARGET_SERVER_URL = OLLAMA_API_BASE_URL | app.add_middleware( | ||||||
|  |     CORSMiddleware, | ||||||
|  |     allow_origins=["*"], | ||||||
|  |     allow_credentials=True, | ||||||
|  |     allow_methods=["*"], | ||||||
|  |     allow_headers=["*"], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL | ||||||
|  | 
 | ||||||
|  | # TARGET_SERVER_URL = OLLAMA_API_BASE_URL | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.route("/url", methods=["GET"]) | @app.get("/url") | ||||||
| def get_ollama_api_url(): | async def get_ollama_api_url(user=Depends(get_current_user)): | ||||||
|     headers = dict(request.headers) |     if user and user.role == "admin": | ||||||
|     if "Authorization" in headers: |         return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} | ||||||
|         _, credentials = headers["Authorization"].split() |  | ||||||
|         token_data = decode_token(credentials) |  | ||||||
|         if token_data is None or "email" not in token_data: |  | ||||||
|             return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401 |  | ||||||
| 
 |  | ||||||
|         user = Users.get_user_by_email(token_data["email"]) |  | ||||||
|         if user and user.role == "admin": |  | ||||||
|             return ( |  | ||||||
|                 jsonify({"OLLAMA_API_BASE_URL": TARGET_SERVER_URL}), |  | ||||||
|                 200, |  | ||||||
|             ) |  | ||||||
|         else: |  | ||||||
|             return ( |  | ||||||
|                 jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), |  | ||||||
|                 401, |  | ||||||
|             ) |  | ||||||
|     else: |     else: | ||||||
|         return ( |         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||||
|             jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), |  | ||||||
|             401, |  | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.route("/url/update", methods=["POST"]) | class UrlUpdateForm(BaseModel): | ||||||
| def update_ollama_api_url(): |     url: str | ||||||
|     headers = dict(request.headers) |  | ||||||
|     data = request.get_json(force=True) |  | ||||||
| 
 | 
 | ||||||
|     if "Authorization" in headers: |  | ||||||
|         _, credentials = headers["Authorization"].split() |  | ||||||
|         token_data = decode_token(credentials) |  | ||||||
|         if token_data is None or "email" not in token_data: |  | ||||||
|             return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401 |  | ||||||
| 
 | 
 | ||||||
|         user = Users.get_user_by_email(token_data["email"]) | @app.post("/url/update") | ||||||
|         if user and user.role == "admin": | async def update_ollama_api_url( | ||||||
|             TARGET_SERVER_URL = data["url"] |     form_data: UrlUpdateForm, user=Depends(get_current_user) | ||||||
|             return ( | ): | ||||||
|                 jsonify({"OLLAMA_API_BASE_URL": TARGET_SERVER_URL}), |     if user and user.role == "admin": | ||||||
|                 200, |         app.state.OLLAMA_API_BASE_URL = form_data.url | ||||||
|             ) |         return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} | ||||||
|         else: |  | ||||||
|             return ( |  | ||||||
|                 jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), |  | ||||||
|                 401, |  | ||||||
|             ) |  | ||||||
|     else: |     else: | ||||||
|         return ( |         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||||
|             jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), |  | ||||||
|             401, |  | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.route("/", | # async def fetch_sse(method, target_url, body, headers): | ||||||
|            defaults={"path": ""}, | #     async with aiohttp.ClientSession() as session: | ||||||
|            methods=["GET", "POST", "PUT", "DELETE"]) | #         try: | ||||||
| @app.route("/<path:path>", methods=["GET", "POST", "PUT", "DELETE"]) | #             async with session.request( | ||||||
| def proxy(path): | #                 method, target_url, data=body, headers=headers | ||||||
|     # Combine the base URL of the target server with the requested path | #             ) as response: | ||||||
|     target_url = f"{TARGET_SERVER_URL}/{path}" | #                 print(response.status) | ||||||
|  | #                 async for line in response.content: | ||||||
|  | #                     yield line | ||||||
|  | #         except Exception as e: | ||||||
|  | #             print(e) | ||||||
|  | #             error_detail = "Ollama WebUI: Server Connection Error" | ||||||
|  | #             yield json.dumps({"error": error_detail, "message": str(e)}).encode() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | ||||||
|  | async def proxy(path: str, request: Request, user=Depends(get_current_user)): | ||||||
|  |     target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" | ||||||
|     print(target_url) |     print(target_url) | ||||||
| 
 | 
 | ||||||
|     # Get data from the original request |     body = await request.body() | ||||||
|     data = request.get_data() |  | ||||||
|     headers = dict(request.headers) |     headers = dict(request.headers) | ||||||
| 
 | 
 | ||||||
|     # Basic RBAC support |     if user.role in ["user", "admin"]: | ||||||
|     if WEBUI_AUTH: |         if path in ["pull", "delete", "push", "copy", "create"]: | ||||||
|         if "Authorization" in headers: |             if user.role != "admin": | ||||||
|             _, credentials = headers["Authorization"].split() |                 raise HTTPException( | ||||||
|             token_data = decode_token(credentials) |                     status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED | ||||||
|             if token_data is None or "email" not in token_data: |                 ) | ||||||
|                 return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401 |  | ||||||
| 
 |  | ||||||
|             user = Users.get_user_by_email(token_data["email"]) |  | ||||||
|             if user: |  | ||||||
|                 # Only user and admin roles can access |  | ||||||
|                 if user.role in ["user", "admin"]: |  | ||||||
|                     if path in ["pull", "delete", "push", "copy", "create"]: |  | ||||||
|                         # Only admin role can perform actions above |  | ||||||
|                         if user.role == "admin": |  | ||||||
|                             pass |  | ||||||
|                         else: |  | ||||||
|                             return ( |  | ||||||
|                                 jsonify({ |  | ||||||
|                                     "detail": |  | ||||||
|                                     ERROR_MESSAGES.ACCESS_PROHIBITED |  | ||||||
|                                 }), |  | ||||||
|                                 401, |  | ||||||
|                             ) |  | ||||||
|                     else: |  | ||||||
|                         pass |  | ||||||
|                 else: |  | ||||||
|                     return jsonify( |  | ||||||
|                         {"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401 |  | ||||||
|             else: |  | ||||||
|                 return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401 |  | ||||||
|         else: |  | ||||||
|             return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401 |  | ||||||
|     else: |     else: | ||||||
|         pass |         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||||
| 
 |  | ||||||
|     r = None |  | ||||||
| 
 | 
 | ||||||
|     headers.pop("Host", None) |     headers.pop("Host", None) | ||||||
|     headers.pop("Authorization", None) |     headers.pop("Authorization", None) | ||||||
|     headers.pop("Origin", None) |     headers.pop("Origin", None) | ||||||
|     headers.pop("Referer", None) |     headers.pop("Referer", None) | ||||||
| 
 | 
 | ||||||
|  |     session = aiohttp.ClientSession() | ||||||
|  |     response = None | ||||||
|     try: |     try: | ||||||
|         # Make a request to the target server |         response = await session.request( | ||||||
|         r = requests.request( |             request.method, target_url, data=body, headers=headers | ||||||
|             method=request.method, |  | ||||||
|             url=target_url, |  | ||||||
|             data=data, |  | ||||||
|             headers=headers, |  | ||||||
|             stream=True,  # Enable streaming for server-sent events |  | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         r.raise_for_status() |         print(response) | ||||||
|  |         if not response.ok: | ||||||
|  |             data = await response.json() | ||||||
|  |             print(data) | ||||||
|  |             response.raise_for_status() | ||||||
| 
 | 
 | ||||||
|         # Proxy the target server's response to the client |         async def generate(): | ||||||
|         def generate(): |             async for line in response.content: | ||||||
|             for chunk in r.iter_content(chunk_size=8192): |                 print(line) | ||||||
|                 yield chunk |                 yield line | ||||||
|  |             await session.close() | ||||||
| 
 | 
 | ||||||
|         response = Response(generate(), status=r.status_code) |         return StreamingResponse(generate(), response.status) | ||||||
| 
 | 
 | ||||||
|         # Copy headers from the target server's response to the client's response |  | ||||||
|         for key, value in r.headers.items(): |  | ||||||
|             response.headers[key] = value |  | ||||||
| 
 |  | ||||||
|         return response |  | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(e) |         print(e) | ||||||
|         error_detail = "Ollama WebUI: Server Connection Error" |         error_detail = "Ollama WebUI: Server Connection Error" | ||||||
|         if r != None: |  | ||||||
|             print(r.text) |  | ||||||
|             res = r.json() |  | ||||||
|             if "error" in res: |  | ||||||
|                 error_detail = f"Ollama: {res['error']}" |  | ||||||
|             print(res) |  | ||||||
| 
 | 
 | ||||||
|         return ( |         if response is not None: | ||||||
|             jsonify({ |             try: | ||||||
|                 "detail": error_detail, |                 res = await response.json() | ||||||
|                 "message": str(e), |                 if "error" in res: | ||||||
|             }), |                     error_detail = f"Ollama: {res['error']}" | ||||||
|             400, |             except: | ||||||
|  |                 error_detail = f"Ollama: {e}" | ||||||
|  | 
 | ||||||
|  |         await session.close() | ||||||
|  |         raise HTTPException( | ||||||
|  |             status_code=response.status if response else 500, | ||||||
|  |             detail=error_detail, | ||||||
|         ) |         ) | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     app.run(debug=True) |  | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek