diff --git a/pyproject.toml b/pyproject.toml index 61dc5df4..15d392da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ urls = { repository = "https://github.com/m-bain/whisperx" } authors = [{ name = "Max Bain" }] name = "whisperx" -version = "3.8.0" +version = "3.8.1" description = "Time-Accurate Automatic Speech Recognition using Whisper." readme = "README.md" requires-python = ">=3.10, <3.14" diff --git a/uv.lock b/uv.lock index 66992af5..421c5044 100644 --- a/uv.lock +++ b/uv.lock @@ -3026,7 +3026,7 @@ wheels = [ [[package]] name = "whisperx" -version = "3.8.0" +version = "3.8.1" source = { editable = "." } dependencies = [ { name = "ctranslate2" }, diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 81c47566..ce92d7a4 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -77,7 +77,7 @@ } -def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None): +def load_align_model(language_code: str, device: str, model_name: Optional[str] = None, model_dir=None, model_cache_only: bool = False): if model_name is None: # use default model if language_code in DEFAULT_ALIGN_MODELS_TORCH: @@ -98,8 +98,8 @@ def load_align_model(language_code: str, device: str, model_name: Optional[str] align_dictionary = {c.lower(): i for i, c in enumerate(labels)} else: try: - processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=model_dir) - align_model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=model_dir) + processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=model_dir, local_files_only=model_cache_only) + align_model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=model_dir, local_files_only=model_cache_only) except Exception as e: print(e) print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models") diff --git a/whisperx/asr.py b/whisperx/asr.py index f9456be8..7540770f 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -314,6 +314,7 @@ def load_model( download_root: Optional[str] = None, local_files_only=False, threads=4, + use_auth_token: Optional[Union[str, bool]] = None, ) -> FasterWhisperPipeline: """Load a Whisper model for inference. Args: @@ -341,7 +342,8 @@ def load_model( compute_type=compute_type, download_root=download_root, local_files_only=local_files_only, - cpu_threads=threads) + cpu_threads=threads, + use_auth_token=use_auth_token) if language is not None: tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language) else: diff --git a/whisperx/diarize.py b/whisperx/diarize.py index 59b0f2f1..041fb129 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -94,12 +94,13 @@ def __init__( model_name=None, token=None, device: Optional[Union[str, torch.device]] = "cpu", + cache_dir=None, ): if isinstance(device, str): device = torch.device(device) model_config = model_name or "pyannote/speaker-diarization-community-1" logger.info(f"Loading diarization model: {model_config}") - self.model = Pipeline.from_pretrained(model_config, token=token).to(device) + self.model = Pipeline.from_pretrained(model_config, token=token, cache_dir=cache_dir).to(device) def __call__( self, diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 0aae410c..7c8be679 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -141,6 +141,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser): task=task, local_files_only=model_cache_only, threads=faster_whisper_threads, + use_auth_token=hf_token, ) for audio_path in args.pop("audio"): @@ -166,7 +167,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser): tmp_results = results results = [] align_model, align_metadata = load_align_model( - align_language, device, model_name=align_model + align_language, device, model_name=align_model, model_dir=model_dir, model_cache_only=model_cache_only ) for result, audio_path in tmp_results: # >> Align @@ -183,7 +184,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser): f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..." ) align_model, align_metadata = load_align_model( - result["language"], device + result["language"], device, model_dir=model_dir, model_cache_only=model_cache_only ) logger.info("Performing alignment...") result: AlignedTranscriptionResult = align( @@ -214,7 +215,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser): logger.info("Performing diarization...") logger.info(f"Using model: {diarize_model_name}") results = [] - diarize_model = DiarizationPipeline(model_name=diarize_model_name, token=hf_token, device=device) + diarize_model = DiarizationPipeline(model_name=diarize_model_name, token=hf_token, device=device, cache_dir=model_dir) for result, input_audio_path in tmp_results: diarize_result = diarize_model( input_audio_path,