feat: whisper support

Timothy J. Baek 2024-02-11 00:17:50 -08:00
parent 182ab8b8a2
commit a5b9bbf10b
7 changed files with 127 additions and 0 deletions

backend/apps/audio/main.py (new file)

@@ -0,0 +1,80 @@
from fastapi import (
    FastAPI,
    Request,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.middleware.cors import CORSMiddleware

from faster_whisper import WhisperModel

from constants import ERROR_MESSAGES
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
from utils.misc import calculate_sha256

from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL_NAME

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/transcribe")
def transcribe(
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    print(file.content_type)

    # Only MP3 and WAV uploads are accepted.
    if file.content_type not in ["audio/mpeg", "audio/wav"]:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
        )

    try:
        # Persist the upload to disk before handing it to faster-whisper.
        filename = file.filename
        file_path = f"{UPLOAD_DIR}/{filename}"
        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        # Load the configured model on CPU with int8 quantization,
        # caching downloaded weights under CACHE_DIR.
        model_name = WHISPER_MODEL_NAME
        model = WhisperModel(
            model_name,
            device="cpu",
            compute_type="int8",
            download_root=f"{CACHE_DIR}/whisper/models",
        )

        segments, info = model.transcribe(file_path, beam_size=5)
        print(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)
        )

        transcript = "".join([segment.text for segment in list(segments)])
        return {"text": transcript}
    except Exception as e:
        print(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
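For reference, a minimal sketch of how a client could exercise this endpoint once the app is mounted at /audio/api/v1 (see backend/main.py below). The base URL, token value, and sample file here are assumptions for illustration, not part of this commit:

# Sketch only: upload a WAV file to the new transcription endpoint.
# Assumes the server listens on http://localhost:8080 and that "token"
# is a valid bearer token accepted by get_current_user.
import requests

token = "YOUR_TOKEN"  # hypothetical placeholder
with open("sample.wav", "rb") as f:
    res = requests.post(
        "http://localhost:8080/audio/api/v1/transcribe",
        headers={"Authorization": f"Bearer {token}"},
        # The form field must be named "file" to match the endpoint signature,
        # and the content type must be audio/mpeg or audio/wav to pass the check.
        files={"file": ("sample.wav", f, "audio/wav")},
    )
res.raise_for_status()
print(res.json()["text"])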

backend/config.py

@@ -132,3 +132,8 @@ CHROMA_CLIENT = chromadb.PersistentClient(
)
CHUNK_SIZE = 1500
CHUNK_OVERLAP = 100

####################################
# Transcribe
####################################

WHISPER_MODEL_NAME = "tiny"
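WHISPER_MODEL_NAME is hardcoded to "tiny", the smallest Whisper checkpoint, which keeps CPU transcription fast at some cost in accuracy. A deployment wanting a larger model without editing config.py could read the name from the environment; a minimal sketch, noting that this environment variable is an assumption and not something this commit defines:

import os

# Sketch only: allow overriding the model via an environment variable.
# The variable name WHISPER_MODEL_NAME is an assumption, not part of this commit.
WHISPER_MODEL_NAME = os.environ.get("WHISPER_MODEL_NAME", "tiny")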

backend/main.py

@@ -10,6 +10,8 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app

from apps.audio.main import app as audio_app
from apps.web.main import app as webui_app
from apps.rag.main import app as rag_app

@@ -55,6 +57,8 @@ app.mount("/api/v1", webui_app)
app.mount("/ollama/api", ollama_app)
app.mount("/openai/api", openai_app)

app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)

backend/requirements.txt

@@ -30,6 +30,8 @@ openpyxl
pyxlsb
xlrd

faster-whisper

PyJWT
pyjwt[crypto]

src/lib/apis/audio/index.ts (new file)

@@ -0,0 +1,31 @@
import { AUDIO_API_BASE_URL } from '$lib/constants';

export const transcribeAudio = async (token: string, file: File) => {
	const data = new FormData();
	data.append('file', file);

	let error = null;
	const res = await fetch(`${AUDIO_API_BASE_URL}/transcribe`, {
		method: 'POST',
		headers: {
			Accept: 'application/json',
			authorization: `Bearer ${token}`
		},
		body: data
	})
		.then(async (res) => {
			if (!res.ok) throw await res.json();
			return res.json();
		})
		.catch((err) => {
			error = err.detail;
			console.log(err);
			return null;
		});

	if (error) {
		throw error;
	}

	return res;
};

src/lib/components/chat/MessageInput.svelte

@@ -11,6 +11,7 @@
	import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS } from '$lib/constants';
	import Documents from './MessageInput/Documents.svelte';
	import Models from './MessageInput/Models.svelte';
	import { transcribeAudio } from '$lib/apis/audio';

	export let submitPrompt: Function;
	export let stopResponse: Function;

@@ -201,6 +202,9 @@
			console.log(file, file.name.split('.').at(-1));
			if (['image/gif', 'image/jpeg', 'image/png'].includes(file['type'])) {
				reader.readAsDataURL(file);
			} else if (['audio/mpeg', 'audio/wav'].includes(file['type'])) {
				const res = await transcribeAudio(localStorage.token, file);
				console.log(res);
			} else if (
				SUPPORTED_FILE_TYPE.includes(file['type']) ||
				SUPPORTED_FILE_EXTENSIONS.includes(file.name.split('.').at(-1))

src/lib/constants.ts

@@ -7,6 +7,7 @@ export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`;
export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`;
export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`;
export const RAG_API_BASE_URL = `${WEBUI_BASE_URL}/rag/api/v1`;
export const AUDIO_API_BASE_URL = `${WEBUI_BASE_URL}/audio/api/v1`;

export const WEB_UI_VERSION = 'v1.0.0-alpha-static';