2024-02-09 01:05:01 +01:00
|
|
|
from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
|
2024-01-04 22:06:31 +01:00
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from fastapi.responses import StreamingResponse
|
2024-01-06 02:16:35 +01:00
|
|
|
from fastapi.concurrency import run_in_threadpool
|
2023-11-15 01:28:51 +01:00
|
|
|
|
|
|
|
import requests
|
|
|
|
import json
|
2024-01-18 04:19:44 +01:00
|
|
|
import uuid
|
2024-01-04 22:06:31 +01:00
|
|
|
from pydantic import BaseModel
|
2023-11-15 01:28:51 +01:00
|
|
|
|
2023-11-19 01:47:12 +01:00
|
|
|
from apps.web.models.users import Users
|
|
|
|
from constants import ERROR_MESSAGES
|
2024-02-09 01:05:01 +01:00
|
|
|
from utils.utils import decode_token, get_current_user, get_admin_user
|
2023-11-19 09:41:29 +01:00
|
|
|
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
|
2023-11-15 01:28:51 +01:00
|
|
|
|
2024-01-04 22:06:31 +01:00
|
|
|
# FastAPI application that fronts the Ollama API.
app = FastAPI()

# NOTE(review): every origin, method and header is allowed while credentials
# are enabled — confirm this permissive CORS policy is intended for deployment.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Base URL of the upstream Ollama server; kept on app.state so the admin
# endpoints below can read and replace it at runtime.
app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL

# Ids of proxied requests that are still in flight. The cancel endpoint removes
# an id, and the streaming loop in the proxy stops once its id is gone.
REQUEST_POOL = []
|
|
|
|
|
|
|
|
|
2024-01-04 22:06:31 +01:00
|
|
|
# Admin-only: report the Ollama base URL the proxy currently targets.
@app.get("/url")
async def get_ollama_api_url(user=Depends(get_admin_user)):
    current_url = app.state.OLLAMA_API_BASE_URL
    return {"OLLAMA_API_BASE_URL": current_url}
|
2024-01-04 22:06:31 +01:00
|
|
|
|
2023-11-15 01:28:51 +01:00
|
|
|
|
2024-01-04 22:06:31 +01:00
|
|
|
# Request body accepted by POST /url/update.
class UrlUpdateForm(BaseModel):
    # New base URL for the upstream Ollama server.
    url: str
|
|
|
|
|
|
|
|
|
|
|
|
# Admin-only: replace the Ollama base URL used by the proxy and echo back the
# value now stored on app.state.
@app.post("/url/update")
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
    new_url = form_data.url
    app.state.OLLAMA_API_BASE_URL = new_url

    response = {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
    return response
|
2024-01-05 10:25:34 +01:00
|
|
|
|
|
|
|
|
2024-01-18 04:19:44 +01:00
|
|
|
# Cancel an in-flight proxied request by dropping its id from REQUEST_POOL.
# Returns True for any authenticated user, whether or not the id was still
# pending; raises HTTPException 401 when no user is resolved.
@app.get("/cancel/{request_id}")
async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
    # Guard clause: reject callers without a resolved user first.
    if not user:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    still_pending = request_id in REQUEST_POOL
    if still_pending:
        REQUEST_POOL.remove(request_id)

    return True
|
|
|
|
|
|
|
|
|
2024-01-04 22:06:31 +01:00
|
|
|
# Catch-all reverse proxy: forwards the incoming request to the configured
# Ollama server and streams the upstream response back to the client.
#
# Parameters:
#   path    - remainder of the URL, appended to app.state.OLLAMA_API_BASE_URL.
#   request - incoming request; its method, body and headers are forwarded.
#   user    - authenticated user resolved by get_current_user.
#
# Returns a StreamingResponse carrying the upstream status code and headers.
# For "chat", and for "generate" unless the body sets "stream" to false, the
# stream is prefixed with a JSON line {"id": <request id>, "done": false}; that
# id is what the cancel endpoint expects.
#
# Raises HTTPException:
#   401 - the user's role is neither "user" nor "admin", or a non-admin calls a
#         model-management path (pull/delete/push/copy/create).
#   upstream status code (500 if no response was received) - the upstream call
#         failed; the detail carries Ollama's own error message when available.
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"

    body = await request.body()
    headers = dict(request.headers)

    if user.role not in ["user", "admin"]:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    # Model-management endpoints are reserved for admins.
    if path in ["pull", "delete", "push", "copy", "create"] and user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    # Strip headers that must not reach the upstream server:
    #   * authorization / origin / referer would leak user information to Ollama;
    #   * host would misroute the request when an intermediary proxy in front of
    #     Ollama uses it to locate the downstream service.
    # request.headers exposes header names in lowercase regardless of how the
    # client spelled them, so the keys removed here must be lowercase too.
    for header_name in ("host", "authorization", "origin", "referer"):
        headers.pop(header_name, None)

    r = None

    def release_request(request_id):
        # The cancel endpoint may already have removed the id, so a missing
        # entry is expected and not an error.
        try:
            REQUEST_POOL.remove(request_id)
        except ValueError:
            pass

    def get_request():
        # Runs in a worker thread because `requests` is blocking.
        nonlocal r

        request_id = str(uuid.uuid4())
        REQUEST_POOL.append(request_id)

        def stream_content():
            try:
                if path == "generate":
                    data = json.loads(body.decode("utf-8"))

                    # Announce the request id unless streaming was explicitly
                    # disabled; `== False` (not `is False`) is kept on purpose
                    # so the accepted values stay exactly as before.
                    if not ("stream" in data and data["stream"] == False):  # noqa: E712
                        yield json.dumps({"id": request_id, "done": False}) + "\n"

                elif path == "chat":
                    yield json.dumps({"id": request_id, "done": False}) + "\n"

                for chunk in r.iter_content(chunk_size=8192):
                    # The id disappearing from the pool is the cancel signal.
                    if request_id not in REQUEST_POOL:
                        print("User: canceled request")
                        break
                    yield chunk
            finally:
                if hasattr(r, "close"):
                    r.close()
                release_request(request_id)

        try:
            r = requests.request(
                method=request.method,
                url=target_url,
                data=body,
                headers=headers,
                stream=True,
            )
            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception:
            # The stream generator will never run on this path, so its
            # `finally` cannot clean up; release the id here to avoid leaking
            # it in REQUEST_POOL.
            release_request(request_id)
            raise

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                # Upstream body was not usable JSON; fall back to the exception.
                error_detail = f"Ollama: {e}"
            finally:
                # The response was opened with stream=True; release it.
                r.close()

        raise HTTPException(
            # A requests.Response is falsy for 4xx/5xx statuses, so test
            # against None explicitly to propagate the real upstream status.
            status_code=r.status_code if r is not None else 500,
            detail=error_detail,
        )
|