From a5b9bbf10b89a53cf8f4ed13c14eeae674a75237 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 11 Feb 2024 00:17:50 -0800 Subject: [PATCH] feat: whisper support --- backend/apps/audio/main.py | 80 +++++++++++++++++++++ backend/config.py | 5 ++ backend/main.py | 4 ++ backend/requirements.txt | 2 + src/lib/apis/audio/index.ts | 31 ++++++++ src/lib/components/chat/MessageInput.svelte | 4 ++ src/lib/constants.ts | 1 + 7 files changed, 127 insertions(+) create mode 100644 backend/apps/audio/main.py create mode 100644 src/lib/apis/audio/index.ts diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py new file mode 100644 index 00000000..f80e3ac8 --- /dev/null +++ b/backend/apps/audio/main.py @@ -0,0 +1,80 @@ +from fastapi import ( + FastAPI, + Request, + Depends, + HTTPException, + status, + UploadFile, + File, + Form, +) +from fastapi.middleware.cors import CORSMiddleware +from faster_whisper import WhisperModel + +from constants import ERROR_MESSAGES +from utils.utils import ( + decode_token, + get_current_user, + get_verified_user, + get_admin_user, +) +from utils.misc import calculate_sha256 + +from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL_NAME + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.post("/transcribe") +def transcribe( + file: UploadFile = File(...), + user=Depends(get_current_user), +): + print(file.content_type) + + if file.content_type not in ["audio/mpeg", "audio/wav"]: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED, + ) + + try: + filename = file.filename + file_path = f"{UPLOAD_DIR}/{filename}" + contents = file.file.read() + with open(file_path, "wb") as f: + f.write(contents) + f.close() + + model_name = WHISPER_MODEL_NAME + model = WhisperModel( + model_name, + device="cpu", + compute_type="int8", + download_root=f"{CACHE_DIR}/whisper/models", + ) + + segments, info = model.transcribe(file_path, beam_size=5) + print( + "Detected language '%s' with probability %f" + % (info.language, info.language_probability) + ) + + transcript = "".join([segment.text for segment in list(segments)]) + + return {"text": transcript} + + except Exception as e: + print(e) + + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) diff --git a/backend/config.py b/backend/config.py index 65ee2298..cf6e8139 100644 --- a/backend/config.py +++ b/backend/config.py @@ -132,3 +132,8 @@ CHROMA_CLIENT = chromadb.PersistentClient( ) CHUNK_SIZE = 1500 CHUNK_OVERLAP = 100 + +#################################### +# Transcribe +#################################### +WHISPER_MODEL_NAME = "tiny" diff --git a/backend/main.py b/backend/main.py index f7a82b66..3a28670e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -10,6 +10,8 @@ from starlette.exceptions import HTTPException as StarletteHTTPException from apps.ollama.main import app as ollama_app from apps.openai.main import app as openai_app +from apps.audio.main import app as audio_app + from apps.web.main import app as webui_app from apps.rag.main import app as rag_app @@ -55,6 +57,8 @@ app.mount("/api/v1", webui_app) app.mount("/ollama/api", ollama_app) app.mount("/openai/api", openai_app) + +app.mount("/audio/api/v1", audio_app) app.mount("/rag/api/v1", rag_app) diff --git a/backend/requirements.txt b/backend/requirements.txt index 68cba254..56e1d36e 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -30,6 +30,8 @@ openpyxl pyxlsb xlrd +faster-whisper + PyJWT pyjwt[crypto] diff --git a/src/lib/apis/audio/index.ts b/src/lib/apis/audio/index.ts new file mode 100644 index 00000000..d2848339 --- /dev/null +++ b/src/lib/apis/audio/index.ts @@ -0,0 +1,31 @@ +import { AUDIO_API_BASE_URL } from '$lib/constants'; + +export const transcribeAudio = async (token: string, file: File) => { + const data = new FormData(); + data.append('file', file); + + let error = null; + const res = await fetch(`${AUDIO_API_BASE_URL}/transcribe`, { + method: 'POST', + headers: { + Accept: 'application/json', + authorization: `Bearer ${token}` + }, + body: data + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 0844c489..5ad78119 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -11,6 +11,7 @@ import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS } from '$lib/constants'; import Documents from './MessageInput/Documents.svelte'; import Models from './MessageInput/Models.svelte'; + import { transcribeAudio } from '$lib/apis/audio'; export let submitPrompt: Function; export let stopResponse: Function; @@ -201,6 +202,9 @@ console.log(file, file.name.split('.').at(-1)); if (['image/gif', 'image/jpeg', 'image/png'].includes(file['type'])) { reader.readAsDataURL(file); + } else if (['audio/mpeg', 'audio/wav'].includes(file['type'])) { + const res = await transcribeAudio(localStorage.token, file); + console.log(res); } else if ( SUPPORTED_FILE_TYPE.includes(file['type']) || SUPPORTED_FILE_EXTENSIONS.includes(file.name.split('.').at(-1)) diff --git a/src/lib/constants.ts b/src/lib/constants.ts index b373eb11..ce25a314 100644 --- a/src/lib/constants.ts +++ b/src/lib/constants.ts @@ -7,6 +7,7 @@ export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`; export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`; export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`; export const RAG_API_BASE_URL = `${WEBUI_BASE_URL}/rag/api/v1`; +export const AUDIO_API_BASE_URL = `${WEBUI_BASE_URL}/audio/api/v1`; export const WEB_UI_VERSION = 'v1.0.0-alpha-static';