open-webui/backend/apps/audio/main.py

96 lines
2.2 KiB
Python
Raw Normal View History

import os
import logging
2024-02-11 09:17:50 +01:00
from fastapi import (
FastAPI,
Request,
Depends,
HTTPException,
status,
UploadFile,
File,
Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from constants import ERROR_MESSAGES
from utils.utils import (
decode_token,
get_current_user,
get_verified_user,
get_admin_user,
)
from utils.misc import calculate_sha256
2024-03-31 10:13:39 +02:00
from config import (
SRC_LOG_LEVELS,
CACHE_DIR,
UPLOAD_DIR,
WHISPER_MODEL,
WHISPER_MODEL_DIR,
2024-04-02 14:47:52 +02:00
DEVICE_TYPE,
2024-03-31 10:13:39 +02:00
)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])

app = FastAPI()
# NOTE(review): allow_origins=["*"] together with allow_credentials=True means
# any origin may issue credentialed cross-origin requests to this sub-app;
# confirm this is intended (kept as-is to preserve current behavior).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Device passed to faster-whisper: "cuda" only when DEVICE_TYPE is exactly
# "cuda"; any other value (unset, "cpu", or another backend name) falls back
# to "cpu".
whisper_device_type = "cuda" if DEVICE_TYPE == "cuda" else "cpu"
log.info("whisper_device_type: %s", whisper_device_type)
# Lazily-initialised, process-wide WhisperModel shared by all requests.
_whisper_model = None


def _get_whisper_model():
    """Return the shared WhisperModel, constructing it on first use.

    Building a WhisperModel loads the model weights (``download_root`` is the
    download/cache location), so the instance is reused across requests
    instead of being rebuilt every time. If construction raises, nothing is
    cached and the next call retries. Two requests racing on the very first
    call may each build a model; the last assignment wins, which is harmless.
    """
    global _whisper_model
    if _whisper_model is None:
        _whisper_model = WhisperModel(
            WHISPER_MODEL,
            device=whisper_device_type,
            compute_type="int8",
            download_root=WHISPER_MODEL_DIR,
        )
    return _whisper_model


@app.post("/transcribe")
def transcribe(
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Transcribe an uploaded MP3/WAV file with Whisper.

    Args:
        file: The uploaded audio. Only ``audio/mpeg`` and ``audio/wav``
            content types are accepted.
        user: The caller resolved by ``get_current_user``; it is not read
            here and only serves to require an authenticated request.

    Returns:
        ``{"text": transcript}`` with leading/trailing whitespace stripped.

    Raises:
        HTTPException: 400 if the content type is unsupported, or if saving,
            loading the model, or transcribing fails.
    """
    import shutil
    import uuid

    log.info(f"file.content_type: {file.content_type}")

    if file.content_type not in ["audio/mpeg", "audio/wav"]:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
        )

    # Never build the path from the client-supplied filename: it may contain
    # "../" segments (path traversal) and would overwrite any existing upload
    # with the same name. Keep only its extension and use a random stem.
    extension = os.path.splitext(os.path.basename(file.filename or ""))[1]
    file_path = os.path.join(UPLOAD_DIR, f"{uuid.uuid4().hex}{extension}")

    try:
        # Stream the upload to disk instead of reading it all into memory.
        with open(file_path, "wb") as f:
            shutil.copyfileobj(file.file, f)

        model = _get_whisper_model()
        segments, info = model.transcribe(file_path, beam_size=5)
        log.info(
            "Detected language '%s' with probability %f",
            info.language,
            info.language_probability,
        )

        # Per faster-whisper's docs, `segments` is a lazy generator: the
        # transcription actually runs while it is consumed here.
        transcript = "".join(segment.text for segment in segments)
        return {"text": transcript.strip()}
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e
    finally:
        # The audio is only needed for this request; don't let uploads pile up.
        try:
            os.remove(file_path)
        except FileNotFoundError:
            pass
        except OSError:
            log.warning("Failed to remove temporary audio file %s", file_path)