From d00ec69ac212d06de922e09bff56bcaca508ac76 Mon Sep 17 00:00:00 2001 From: Claude-Assistant Date: Wed, 11 Mar 2026 10:12:54 +0100 Subject: [PATCH] feat: add progress_callback to transcribe, align, and diarize Add optional `progress_callback: Callable[[float], None]` parameter to the three public API functions for real-time progress tracking. Each callback receives 0-100% for its own stage independently. Diarization wraps the callback into pyannote's internal hook protocol, keeping pyannote internals fully encapsulated. Co-Authored-By: Claude Opus 4.6 --- whisperx/alignment.py | 5 +++++ whisperx/asr.py | 5 ++++- whisperx/diarize.py | 26 +++++++++++++++++++++++++- whisperx/schema.py | 4 +++- 4 files changed, 37 insertions(+), 3 deletions(-) diff --git a/whisperx/alignment.py b/whisperx/alignment.py index f3095ba8..0786d0eb 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -19,6 +19,7 @@ SingleAlignedSegment, SingleWordSegment, SegmentData, + ProgressCallback, ) import nltk from nltk.data import load as nltk_load @@ -122,6 +123,7 @@ def align( return_char_alignments: bool = False, print_progress: bool = False, combined_progress: bool = False, + progress_callback: ProgressCallback = None, ) -> AlignedTranscriptionResult: """ Align phoneme recognition predictions to known transcription. @@ -376,6 +378,9 @@ def align( agg_dict["avg_logprob"] = "first" aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict) aligned_subsegments = aligned_subsegments.to_dict('records') + if progress_callback is not None: + progress_callback(((sdx + 1) / total_segments) * 100) + aligned_segments += aligned_subsegments # create word_segments list diff --git a/whisperx/asr.py b/whisperx/asr.py index a5c80711..40114864 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -12,7 +12,7 @@ from transformers.pipelines.pt_utils import PipelineIterator from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram -from whisperx.schema import SingleSegment, TranscriptionResult +from whisperx.schema import SingleSegment, TranscriptionResult, ProgressCallback from whisperx.vads import Vad, Silero, Pyannote from whisperx.log_utils import get_logger @@ -205,6 +205,7 @@ def transcribe( print_progress=False, combined_progress=False, verbose=False, + progress_callback: ProgressCallback = None, ) -> TranscriptionResult: if isinstance(audio, str): audio = load_audio(audio) @@ -268,6 +269,8 @@ def data(audio, segments): base_progress = ((idx + 1) / total_segments) * 100 percent_complete = base_progress / 2 if combined_progress else base_progress print(f"Progress: {percent_complete:.2f}%...") + if progress_callback is not None: + progress_callback(((idx + 1) / total_segments) * 100) text = out['text'] avg_logprob = out['avg_logprob'] if batch_size in [0, 1, None]: diff --git a/whisperx/diarize.py b/whisperx/diarize.py index 041fb129..b767416e 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -5,7 +5,7 @@ import torch from whisperx.audio import load_audio, SAMPLE_RATE -from whisperx.schema import TranscriptionResult, AlignedTranscriptionResult +from whisperx.schema import TranscriptionResult, AlignedTranscriptionResult, ProgressCallback from whisperx.log_utils import get_logger logger = get_logger(__name__) @@ -109,6 +109,7 @@ def __call__( min_speakers: Optional[int] = None, max_speakers: Optional[int] = None, return_embeddings: bool = False, + progress_callback: ProgressCallback = None, ) -> Union[tuple[pd.DataFrame, Optional[dict[str, list[float]]]], pd.DataFrame]: """ Perform speaker diarization on audio. @@ -119,6 +120,7 @@ def __call__( min_speakers: Minimum number of speakers to detect max_speakers: Maximum number of speakers to detect return_embeddings: Whether to return speaker embeddings + progress_callback: Optional callable receiving a float (0-100) with progress percentage Returns: If return_embeddings is True: @@ -133,13 +135,35 @@ def __call__( 'sample_rate': SAMPLE_RATE } + hook = None + if progress_callback is not None: + # pyannote's diarization has two progress-trackable steps, each with + # its own completed/total counter that resets between steps. Map each + # step into a sub-range so progress is monotonic and meaningful. + _STEP_RANGES = { + "segmentation": (0.0, 50.0), + "embeddings": (50.0, 99.0), + } + last_pct = [0.0] + def hook(step_name, step_artifact, file=None, total=None, completed=None): + if total is not None and completed is not None and total > 0: + offset, end = _STEP_RANGES.get(step_name, (0.0, 99.0)) + pct = offset + min(completed / total, 1.0) * (end - offset) + if pct > last_pct[0]: + last_pct[0] = pct + progress_callback(pct) + output = self.model( audio_data, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers, + **({"hook": hook} if hook is not None else {}), ) + if progress_callback is not None: + progress_callback(100.0) + diarization = output.speaker_diarization embeddings = output.speaker_embeddings if return_embeddings else None diff --git a/whisperx/schema.py b/whisperx/schema.py index 83d9147f..e5287794 100644 --- a/whisperx/schema.py +++ b/whisperx/schema.py @@ -1,4 +1,6 @@ -from typing import TypedDict, Optional, List, Tuple +from typing import Callable, TypedDict, Optional, List, Tuple + +ProgressCallback = Optional[Callable[[float], None]] try: from typing import NotRequired