feat: openai tts support

This commit is contained in:
Timothy J. Baek 2024-02-05 22:51:08 -08:00
parent ce31113abd
commit 0b8df52c97
5 changed files with 216 additions and 23 deletions

View file

@ -1,15 +1,19 @@
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
import requests
import json
from pydantic import BaseModel
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR
import hashlib
from pathlib import Path
app = FastAPI()
app.add_middleware(
@ -66,6 +70,73 @@ async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_u
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"
if user.role not in ["user", "admin"]:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
if app.state.OPENAI_API_KEY == "":
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
body = await request.body()
filename = hashlib.sha256(body).hexdigest() + ".mp3"
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
file_path = SPEECH_CACHE_DIR.joinpath(filename)
print(file_path)
# Check if the file already exists in the cache
if file_path.is_file():
return FileResponse(file_path)
headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
try:
print("openai")
r = requests.post(
url=target_url,
data=body,
headers=headers,
stream=True,
)
print(r)
r.raise_for_status()
# Save the streaming content to a file
with open(file_path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
# Return the saved file
return FileResponse(file_path)
# return StreamingResponse(
# r.iter_content(chunk_size=8192),
# status_code=r.status_code,
# headers=dict(r.headers),
# )
except Exception as e:
print(e)
error_detail = "Ollama WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(status_code=r.status_code, detail=error_detail)
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
@ -129,8 +200,6 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
response_data = r.json()
print(type(response_data))
if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
response_data["data"] = list(
filter(lambda model: "gpt" in model["id"], response_data["data"])

View file

@ -35,6 +35,14 @@ FRONTEND_BUILD_DIR = str(Path(os.getenv("FRONTEND_BUILD_DIR", "../build")))
UPLOAD_DIR = f"{DATA_DIR}/uploads"
Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Cache DIR
####################################
CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
####################################
# OLLAMA_API_BASE_URL
####################################