diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py
index 02d1f5e8..4f65b5f5 100644
--- a/backend/apps/audio/main.py
+++ b/backend/apps/audio/main.py
@@ -28,6 +28,7 @@ from config import (
     UPLOAD_DIR,
     WHISPER_MODEL,
     WHISPER_MODEL_DIR,
+    WHISPER_MODEL_AUTO_UPDATE,
     DEVICE_TYPE,
 )
 
@@ -69,12 +70,28 @@ def transcribe(
             f.write(contents)
             f.close()
 
-        model = WhisperModel(
-            WHISPER_MODEL,
-            device=whisper_device_type,
-            compute_type="int8",
-            download_root=WHISPER_MODEL_DIR,
-        )
+        # Construct the model from keyword arguments so local_files_only can
+        # be flipped for a retry; it starts as the inverse of
+        # WHISPER_MODEL_AUTO_UPDATE.
+        whisper_kwargs = {
+            "model_size_or_path": WHISPER_MODEL,
+            "device": whisper_device_type,
+            "compute_type": "int8",
+            "download_root": WHISPER_MODEL_DIR,
+            "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
+        }
+
+        log.debug(f"whisper_kwargs: {whisper_kwargs}")
+
+        try:
+            model = WhisperModel(**whisper_kwargs)
+        except Exception:
+            # Not a bare `except:` so KeyboardInterrupt/SystemExit still
+            # propagate. Retry once with local_files_only=False; a second
+            # failure is not caught here.
+            log.debug("WhisperModel initialization failed, attempting download with local_files_only=False")
+            whisper_kwargs["local_files_only"] = False
+            model = WhisperModel(**whisper_kwargs)
 
         segments, info = model.transcribe(file_path, beam_size=5)
         log.info(
diff --git a/backend/config.py b/backend/config.py
index 6e3cf92a..4436a5a0 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -446,6 +446,10 @@ Query: [query]"""
 
 WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
 WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
+# Enabled only when the variable is the string "true" (case-insensitive).
+WHISPER_MODEL_AUTO_UPDATE = (
+    os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
+)
 
 
 ####################################