forked from open-webui/open-webui
feat: whisper support
parent 182ab8b8a2
commit a5b9bbf10b
7 changed files with 127 additions and 0 deletions
backend/apps/audio/main.py (new file, +80)
@@ -0,0 +1,80 @@
from fastapi import (
    FastAPI,
    Request,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel

from constants import ERROR_MESSAGES
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
from utils.misc import calculate_sha256

from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL_NAME

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/transcribe")
def transcribe(
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    print(file.content_type)

    if file.content_type not in ["audio/mpeg", "audio/wav"]:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
        )

    try:
        filename = file.filename
        file_path = f"{UPLOAD_DIR}/{filename}"
        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)
            f.close()

        model_name = WHISPER_MODEL_NAME
        model = WhisperModel(
            model_name,
            device="cpu",
            compute_type="int8",
            download_root=f"{CACHE_DIR}/whisper/models",
        )

        segments, info = model.transcribe(file_path, beam_size=5)
        print(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)
        )

        transcript = "".join([segment.text for segment in list(segments)])

        return {"text": transcript}

    except Exception as e:
        print(e)

        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
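Note: once this sub-app is mounted at /audio/api/v1 (see the app.mount hunk below), the endpoint takes a multipart upload named `file` and returns the transcript as JSON. A minimal sketch of calling it; the host, port, token, and file name are placeholders, not part of this commit:

import requests

# Placeholder host and token; only audio/mpeg and audio/wav are accepted.
url = "http://localhost:8080/audio/api/v1/transcribe"
headers = {"Authorization": "Bearer YOUR_JWT_TOKEN"}

with open("sample.wav", "rb") as audio:
    res = requests.post(
        url,
        headers=headers,
        files={"file": ("sample.wav", audio, "audio/wav")},
    )

print(res.json())  # {"text": "..."} on success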
@@ -132,3 +132,8 @@ CHROMA_CLIENT = chromadb.PersistentClient(
)
CHUNK_SIZE = 1500
CHUNK_OVERLAP = 100

####################################
# Transcribe
####################################
WHISPER_MODEL_NAME = "tiny"
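Note: WHISPER_MODEL_NAME selects the faster-whisper checkpoint that /transcribe loads; "tiny" keeps the first-run download and CPU latency small. A standalone sketch of trying a larger size, assuming faster-whisper is installed and a local sample.wav exists (both assumptions, not part of this commit):

from faster_whisper import WhisperModel

# "tiny", "base", "small", "medium", and "large-v2" are valid size names;
# bigger checkpoints are more accurate but noticeably slower on CPU.
model = WhisperModel("base", device="cpu", compute_type="int8")
segments, info = model.transcribe("sample.wav", beam_size=5)
print(info.language, info.language_probability)
print("".join(segment.text for segment in segments))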
@@ -10,6 +10,8 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app

from apps.audio.main import app as audio_app


from apps.web.main import app as webui_app
from apps.rag.main import app as rag_app
@@ -55,6 +57,8 @@ app.mount("/api/v1", webui_app)
app.mount("/ollama/api", ollama_app)
app.mount("/openai/api", openai_app)

app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)

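Note: app.mount attaches the audio app as an ASGI sub-application, so its /transcribe route resolves under the /audio/api/v1 prefix. A throwaway sketch of how the paths compose, using stub apps rather than the repo's (TestClient needs httpx installed):

from fastapi import FastAPI
from fastapi.testclient import TestClient

parent = FastAPI()
child = FastAPI()


@child.post("/transcribe")
def transcribe_stub():
    # Stands in for the real endpoint; returns a canned payload.
    return {"text": "stub"}


parent.mount("/audio/api/v1", child)

client = TestClient(parent)
print(client.post("/audio/api/v1/transcribe").json())  # {'text': 'stub'}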
@@ -30,6 +30,8 @@ openpyxl
pyxlsb
xlrd

faster-whisper

PyJWT
pyjwt[crypto]

src/lib/apis/audio/index.ts (new file, +31)
@@ -0,0 +1,31 @@
import { AUDIO_API_BASE_URL } from '$lib/constants';

export const transcribeAudio = async (token: string, file: File) => {
	const data = new FormData();
	data.append('file', file);

	let error = null;
	const res = await fetch(`${AUDIO_API_BASE_URL}/transcribe`, {
		method: 'POST',
		headers: {
			Accept: 'application/json',
			authorization: `Bearer ${token}`
		},
		body: data
	})
		.then(async (res) => {
			if (!res.ok) throw await res.json();
			return res.json();
		})
		.catch((err) => {
			error = err.detail;
			console.log(err);
			return null;
		});

	if (error) {
		throw error;
	}

	return res;
};
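Note: the helper reads err.detail because FastAPI serializes a raised HTTPException as a {"detail": ...} JSON body. A sketch of that error path against the backend above, reusing the placeholder host and token from the earlier example:

import requests

# An unsupported content type trips the 400 branch of /transcribe.
res = requests.post(
    "http://localhost:8080/audio/api/v1/transcribe",
    headers={"Authorization": "Bearer YOUR_JWT_TOKEN"},  # placeholder
    files={"file": ("notes.txt", b"not audio", "text/plain")},
)
print(res.status_code)  # 400
print(res.json())       # {"detail": "<FILE_NOT_SUPPORTED message>"}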
@@ -11,6 +11,7 @@
	import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS } from '$lib/constants';
	import Documents from './MessageInput/Documents.svelte';
	import Models from './MessageInput/Models.svelte';
	import { transcribeAudio } from '$lib/apis/audio';

	export let submitPrompt: Function;
	export let stopResponse: Function;
@@ -201,6 +202,9 @@
			console.log(file, file.name.split('.').at(-1));
			if (['image/gif', 'image/jpeg', 'image/png'].includes(file['type'])) {
				reader.readAsDataURL(file);
			} else if (['audio/mpeg', 'audio/wav'].includes(file['type'])) {
				const res = await transcribeAudio(localStorage.token, file);
				console.log(res);
			} else if (
				SUPPORTED_FILE_TYPE.includes(file['type']) ||
				SUPPORTED_FILE_EXTENSIONS.includes(file.name.split('.').at(-1))
@@ -7,6 +7,7 @@ export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`;
export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`;
export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`;
export const RAG_API_BASE_URL = `${WEBUI_BASE_URL}/rag/api/v1`;
export const AUDIO_API_BASE_URL = `${WEBUI_BASE_URL}/audio/api/v1`;

export const WEB_UI_VERSION = 'v1.0.0-alpha-static';