feat: terminate request on user stop

This commit is contained in:
Timothy J. Baek 2024-01-17 19:19:44 -08:00
parent 684bdf5151
commit 442e3d978a
4 changed files with 170 additions and 86 deletions

View file

@ -5,6 +5,7 @@ from fastapi.concurrency import run_in_threadpool
import requests
import json
import uuid
from pydantic import BaseModel
from apps.web.models.users import Users
@ -26,6 +27,9 @@ app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
# TARGET_SERVER_URL = OLLAMA_API_BASE_URL
REQUEST_POOL = []
@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
if user and user.role == "admin":
@ -49,6 +53,16 @@ async def update_ollama_api_url(
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.get("/cancel/{request_id}")
async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
if user:
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
return True
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
@ -74,7 +88,27 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
def get_request():
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if path in ["chat"]:
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
print("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
REQUEST_POOL.remove(request_id)
r = requests.request(
method=request.method,
url=target_url,
@ -85,8 +119,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
r.raise_for_status()
# r.close()
return StreamingResponse(
r.iter_content(chunk_size=8192),
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)