From bb2971260d7df8406e8e87369c4e435872c83939 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 5 Jan 2024 17:16:35 -0800 Subject: [PATCH] fix: backend proxy --- Dockerfile | 5 +- backend/apps/ollama/main.py | 72 ++++------- backend/apps/ollama/old_main.py | 223 ++++++++++++-------------------- 3 files changed, 116 insertions(+), 184 deletions(-) diff --git a/Dockerfile b/Dockerfile index 504cfff6..2e57ee66 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,10 +12,9 @@ RUN npm run build FROM python:3.11-slim-buster as base -ARG OLLAMA_API_BASE_URL='/ollama/api' - ENV ENV=prod -ENV OLLAMA_API_BASE_URL $OLLAMA_API_BASE_URL + +ENV OLLAMA_API_BASE_URL "/ollama/api" ENV OPENAI_API_BASE_URL "" ENV OPENAI_API_KEY "" diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 961d1935..dc0c9d3f 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -1,6 +1,7 @@ from fastapi import FastAPI, Request, Response, HTTPException, Depends from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse +from fastapi.concurrency import run_in_threadpool import requests import json @@ -11,8 +12,6 @@ from constants import ERROR_MESSAGES from utils.utils import decode_token, get_current_user from config import OLLAMA_API_BASE_URL, WEBUI_AUTH -import aiohttp - app = FastAPI() app.add_middleware( CORSMiddleware, @@ -50,25 +49,9 @@ async def update_ollama_api_url( raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) -# async def fetch_sse(method, target_url, body, headers): -# async with aiohttp.ClientSession() as session: -# try: -# async with session.request( -# method, target_url, data=body, headers=headers -# ) as response: -# print(response.status) -# async for line in response.content: -# yield line -# except Exception as e: -# print(e) -# error_detail = "Ollama WebUI: Server Connection Error" -# yield json.dumps({"error": error_detail, "message": str(e)}).encode() - - @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def proxy(path: str, request: Request, user=Depends(get_current_user)): target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" - print(target_url) body = await request.body() headers = dict(request.headers) @@ -87,41 +70,42 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): headers.pop("Origin", None) headers.pop("Referer", None) - session = aiohttp.ClientSession() - response = None + r = None + + def get_request(): + nonlocal r + try: + r = requests.request( + method=request.method, + url=target_url, + data=body, + headers=headers, + stream=True, + ) + + r.raise_for_status() + + return StreamingResponse( + r.iter_content(chunk_size=8192), + status_code=r.status_code, + headers=dict(r.headers), + ) + except Exception as e: + raise e + try: - response = await session.request( - request.method, target_url, data=body, headers=headers - ) - - print(response) - if not response.ok: - data = await response.json() - print(data) - response.raise_for_status() - - async def generate(): - async for line in response.content: - print(line) - yield line - await session.close() - - return StreamingResponse(generate(), response.status) - + return await run_in_threadpool(get_request) except Exception as e: - print(e) error_detail = "Ollama WebUI: Server Connection Error" - - if response is not None: + if r is not None: try: - res = await response.json() + res = r.json() if "error" in res: error_detail = f"Ollama: {res['error']}" except: error_detail = f"Ollama: {e}" - await session.close() raise HTTPException( - status_code=response.status if response else 500, + status_code=r.status_code if r else 500, detail=error_detail, ) diff --git a/backend/apps/ollama/old_main.py b/backend/apps/ollama/old_main.py index f809d442..961d1935 100644 --- a/backend/apps/ollama/old_main.py +++ b/backend/apps/ollama/old_main.py @@ -1,178 +1,127 @@ -from flask import Flask, request, Response, jsonify -from flask_cors import CORS +from fastapi import FastAPI, Request, Response, HTTPException, Depends +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse import requests import json +from pydantic import BaseModel from apps.web.models.users import Users from constants import ERROR_MESSAGES -from utils.utils import decode_token +from utils.utils import decode_token, get_current_user from config import OLLAMA_API_BASE_URL, WEBUI_AUTH -app = Flask(__name__) -CORS( - app -) # Enable Cross-Origin Resource Sharing (CORS) to allow requests from different domains +import aiohttp -# Define the target server URL -TARGET_SERVER_URL = OLLAMA_API_BASE_URL +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL + +# TARGET_SERVER_URL = OLLAMA_API_BASE_URL -@app.route("/url", methods=["GET"]) -def get_ollama_api_url(): - headers = dict(request.headers) - if "Authorization" in headers: - _, credentials = headers["Authorization"].split() - token_data = decode_token(credentials) - if token_data is None or "email" not in token_data: - return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401 - - user = Users.get_user_by_email(token_data["email"]) - if user and user.role == "admin": - return ( - jsonify({"OLLAMA_API_BASE_URL": TARGET_SERVER_URL}), - 200, - ) - else: - return ( - jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), - 401, - ) +@app.get("/url") +async def get_ollama_api_url(user=Depends(get_current_user)): + if user and user.role == "admin": + return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} else: - return ( - jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), - 401, - ) + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) -@app.route("/url/update", methods=["POST"]) -def update_ollama_api_url(): - headers = dict(request.headers) - data = request.get_json(force=True) +class UrlUpdateForm(BaseModel): + url: str - if "Authorization" in headers: - _, credentials = headers["Authorization"].split() - token_data = decode_token(credentials) - if token_data is None or "email" not in token_data: - return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401 - user = Users.get_user_by_email(token_data["email"]) - if user and user.role == "admin": - TARGET_SERVER_URL = data["url"] - return ( - jsonify({"OLLAMA_API_BASE_URL": TARGET_SERVER_URL}), - 200, - ) - else: - return ( - jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), - 401, - ) +@app.post("/url/update") +async def update_ollama_api_url( + form_data: UrlUpdateForm, user=Depends(get_current_user) +): + if user and user.role == "admin": + app.state.OLLAMA_API_BASE_URL = form_data.url + return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL} else: - return ( - jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), - 401, - ) + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) -@app.route("/", - defaults={"path": ""}, - methods=["GET", "POST", "PUT", "DELETE"]) -@app.route("/", methods=["GET", "POST", "PUT", "DELETE"]) -def proxy(path): - # Combine the base URL of the target server with the requested path - target_url = f"{TARGET_SERVER_URL}/{path}" +# async def fetch_sse(method, target_url, body, headers): +# async with aiohttp.ClientSession() as session: +# try: +# async with session.request( +# method, target_url, data=body, headers=headers +# ) as response: +# print(response.status) +# async for line in response.content: +# yield line +# except Exception as e: +# print(e) +# error_detail = "Ollama WebUI: Server Connection Error" +# yield json.dumps({"error": error_detail, "message": str(e)}).encode() + + +@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) +async def proxy(path: str, request: Request, user=Depends(get_current_user)): + target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" print(target_url) - # Get data from the original request - data = request.get_data() + body = await request.body() headers = dict(request.headers) - # Basic RBAC support - if WEBUI_AUTH: - if "Authorization" in headers: - _, credentials = headers["Authorization"].split() - token_data = decode_token(credentials) - if token_data is None or "email" not in token_data: - return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401 - - user = Users.get_user_by_email(token_data["email"]) - if user: - # Only user and admin roles can access - if user.role in ["user", "admin"]: - if path in ["pull", "delete", "push", "copy", "create"]: - # Only admin role can perform actions above - if user.role == "admin": - pass - else: - return ( - jsonify({ - "detail": - ERROR_MESSAGES.ACCESS_PROHIBITED - }), - 401, - ) - else: - pass - else: - return jsonify( - {"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401 - else: - return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401 - else: - return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401 + if user.role in ["user", "admin"]: + if path in ["pull", "delete", "push", "copy", "create"]: + if user.role != "admin": + raise HTTPException( + status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED + ) else: - pass - - r = None + raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) headers.pop("Host", None) headers.pop("Authorization", None) headers.pop("Origin", None) headers.pop("Referer", None) + session = aiohttp.ClientSession() + response = None try: - # Make a request to the target server - r = requests.request( - method=request.method, - url=target_url, - data=data, - headers=headers, - stream=True, # Enable streaming for server-sent events + response = await session.request( + request.method, target_url, data=body, headers=headers ) - r.raise_for_status() + print(response) + if not response.ok: + data = await response.json() + print(data) + response.raise_for_status() - # Proxy the target server's response to the client - def generate(): - for chunk in r.iter_content(chunk_size=8192): - yield chunk + async def generate(): + async for line in response.content: + print(line) + yield line + await session.close() - response = Response(generate(), status=r.status_code) + return StreamingResponse(generate(), response.status) - # Copy headers from the target server's response to the client's response - for key, value in r.headers.items(): - response.headers[key] = value - - return response except Exception as e: print(e) error_detail = "Ollama WebUI: Server Connection Error" - if r != None: - print(r.text) - res = r.json() - if "error" in res: - error_detail = f"Ollama: {res['error']}" - print(res) - return ( - jsonify({ - "detail": error_detail, - "message": str(e), - }), - 400, + if response is not None: + try: + res = await response.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + await session.close() + raise HTTPException( + status_code=response.status if response else 500, + detail=error_detail, ) - - -if __name__ == "__main__": - app.run(debug=True)