fix: backend proxy

This commit is contained in:
Timothy J. Baek 2024-01-05 17:16:35 -08:00
parent 439185be80
commit bb2971260d
3 changed files with 116 additions and 184 deletions

View file

@ -12,10 +12,9 @@ RUN npm run build
FROM python:3.11-slim-buster as base FROM python:3.11-slim-buster as base
ARG OLLAMA_API_BASE_URL='/ollama/api'
ENV ENV=prod ENV ENV=prod
ENV OLLAMA_API_BASE_URL $OLLAMA_API_BASE_URL
ENV OLLAMA_API_BASE_URL "/ollama/api"
ENV OPENAI_API_BASE_URL "" ENV OPENAI_API_BASE_URL ""
ENV OPENAI_API_KEY "" ENV OPENAI_API_KEY ""

View file

@ -1,6 +1,7 @@
from fastapi import FastAPI, Request, Response, HTTPException, Depends from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
import requests import requests
import json import json
@ -11,8 +12,6 @@ from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user from utils.utils import decode_token, get_current_user
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
import aiohttp
app = FastAPI() app = FastAPI()
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@ -50,25 +49,9 @@ async def update_ollama_api_url(
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
# async def fetch_sse(method, target_url, body, headers):
# async with aiohttp.ClientSession() as session:
# try:
# async with session.request(
# method, target_url, data=body, headers=headers
# ) as response:
# print(response.status)
# async for line in response.content:
# yield line
# except Exception as e:
# print(e)
# error_detail = "Ollama WebUI: Server Connection Error"
# yield json.dumps({"error": error_detail, "message": str(e)}).encode()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)): async def proxy(path: str, request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}" target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
print(target_url)
body = await request.body() body = await request.body()
headers = dict(request.headers) headers = dict(request.headers)
@ -87,41 +70,42 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
headers.pop("Origin", None) headers.pop("Origin", None)
headers.pop("Referer", None) headers.pop("Referer", None)
session = aiohttp.ClientSession() r = None
response = None
def get_request():
nonlocal r
try:
r = requests.request(
method=request.method,
url=target_url,
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
return StreamingResponse(
r.iter_content(chunk_size=8192),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try: try:
response = await session.request( return await run_in_threadpool(get_request)
request.method, target_url, data=body, headers=headers
)
print(response)
if not response.ok:
data = await response.json()
print(data)
response.raise_for_status()
async def generate():
async for line in response.content:
print(line)
yield line
await session.close()
return StreamingResponse(generate(), response.status)
except Exception as e: except Exception as e:
print(e)
error_detail = "Ollama WebUI: Server Connection Error" error_detail = "Ollama WebUI: Server Connection Error"
if r is not None:
if response is not None:
try: try:
res = await response.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
await session.close()
raise HTTPException( raise HTTPException(
status_code=response.status if response else 500, status_code=r.status_code if r else 500,
detail=error_detail, detail=error_detail,
) )

View file

@ -1,178 +1,127 @@
from flask import Flask, request, Response, jsonify from fastapi import FastAPI, Request, Response, HTTPException, Depends
from flask_cors import CORS from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import requests import requests
import json import json
from pydantic import BaseModel
from apps.web.models.users import Users from apps.web.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import decode_token from utils.utils import decode_token, get_current_user
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
app = Flask(__name__) import aiohttp
CORS(
app
) # Enable Cross-Origin Resource Sharing (CORS) to allow requests from different domains
# Define the target server URL app = FastAPI()
TARGET_SERVER_URL = OLLAMA_API_BASE_URL app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
# TARGET_SERVER_URL = OLLAMA_API_BASE_URL
@app.route("/url", methods=["GET"]) @app.get("/url")
def get_ollama_api_url(): async def get_ollama_api_url(user=Depends(get_current_user)):
headers = dict(request.headers) if user and user.role == "admin":
if "Authorization" in headers: return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
_, credentials = headers["Authorization"].split()
token_data = decode_token(credentials)
if token_data is None or "email" not in token_data:
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
user = Users.get_user_by_email(token_data["email"])
if user and user.role == "admin":
return (
jsonify({"OLLAMA_API_BASE_URL": TARGET_SERVER_URL}),
200,
)
else:
return (
jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}),
401,
)
else: else:
return ( raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}),
401,
)
@app.route("/url/update", methods=["POST"]) class UrlUpdateForm(BaseModel):
def update_ollama_api_url(): url: str
headers = dict(request.headers)
data = request.get_json(force=True)
if "Authorization" in headers:
_, credentials = headers["Authorization"].split()
token_data = decode_token(credentials)
if token_data is None or "email" not in token_data:
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
user = Users.get_user_by_email(token_data["email"]) @app.post("/url/update")
if user and user.role == "admin": async def update_ollama_api_url(
TARGET_SERVER_URL = data["url"] form_data: UrlUpdateForm, user=Depends(get_current_user)
return ( ):
jsonify({"OLLAMA_API_BASE_URL": TARGET_SERVER_URL}), if user and user.role == "admin":
200, app.state.OLLAMA_API_BASE_URL = form_data.url
) return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
return (
jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}),
401,
)
else: else:
return ( raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}),
401,
)
@app.route("/", # async def fetch_sse(method, target_url, body, headers):
defaults={"path": ""}, # async with aiohttp.ClientSession() as session:
methods=["GET", "POST", "PUT", "DELETE"]) # try:
@app.route("/<path:path>", methods=["GET", "POST", "PUT", "DELETE"]) # async with session.request(
def proxy(path): # method, target_url, data=body, headers=headers
# Combine the base URL of the target server with the requested path # ) as response:
target_url = f"{TARGET_SERVER_URL}/{path}" # print(response.status)
# async for line in response.content:
# yield line
# except Exception as e:
# print(e)
# error_detail = "Ollama WebUI: Server Connection Error"
# yield json.dumps({"error": error_detail, "message": str(e)}).encode()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
print(target_url) print(target_url)
# Get data from the original request body = await request.body()
data = request.get_data()
headers = dict(request.headers) headers = dict(request.headers)
# Basic RBAC support if user.role in ["user", "admin"]:
if WEBUI_AUTH: if path in ["pull", "delete", "push", "copy", "create"]:
if "Authorization" in headers: if user.role != "admin":
_, credentials = headers["Authorization"].split() raise HTTPException(
token_data = decode_token(credentials) status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
if token_data is None or "email" not in token_data: )
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
user = Users.get_user_by_email(token_data["email"])
if user:
# Only user and admin roles can access
if user.role in ["user", "admin"]:
if path in ["pull", "delete", "push", "copy", "create"]:
# Only admin role can perform actions above
if user.role == "admin":
pass
else:
return (
jsonify({
"detail":
ERROR_MESSAGES.ACCESS_PROHIBITED
}),
401,
)
else:
pass
else:
return jsonify(
{"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401
else:
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
else:
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
else: else:
pass raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
r = None
headers.pop("Host", None) headers.pop("Host", None)
headers.pop("Authorization", None) headers.pop("Authorization", None)
headers.pop("Origin", None) headers.pop("Origin", None)
headers.pop("Referer", None) headers.pop("Referer", None)
session = aiohttp.ClientSession()
response = None
try: try:
# Make a request to the target server response = await session.request(
r = requests.request( request.method, target_url, data=body, headers=headers
method=request.method,
url=target_url,
data=data,
headers=headers,
stream=True, # Enable streaming for server-sent events
) )
r.raise_for_status() print(response)
if not response.ok:
data = await response.json()
print(data)
response.raise_for_status()
# Proxy the target server's response to the client async def generate():
def generate(): async for line in response.content:
for chunk in r.iter_content(chunk_size=8192): print(line)
yield chunk yield line
await session.close()
response = Response(generate(), status=r.status_code) return StreamingResponse(generate(), response.status)
# Copy headers from the target server's response to the client's response
for key, value in r.headers.items():
response.headers[key] = value
return response
except Exception as e: except Exception as e:
print(e) print(e)
error_detail = "Ollama WebUI: Server Connection Error" error_detail = "Ollama WebUI: Server Connection Error"
if r != None:
print(r.text)
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
print(res)
return ( if response is not None:
jsonify({ try:
"detail": error_detail, res = await response.json()
"message": str(e), if "error" in res:
}), error_detail = f"Ollama: {res['error']}"
400, except:
error_detail = f"Ollama: {e}"
await session.close()
raise HTTPException(
status_code=response.status if response else 500,
detail=error_detail,
) )
if __name__ == "__main__":
app.run(debug=True)