diff --git a/pyproject.toml b/pyproject.toml index 15d392da..94626d60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ urls = { repository = "https://github.com/m-bain/whisperx" } authors = [{ name = "Max Bain" }] name = "whisperx" -version = "3.8.1" +version = "3.8.2" description = "Time-Accurate Automatic Speech Recognition using Whisper." readme = "README.md" requires-python = ">=3.10, <3.14" diff --git a/uv.lock b/uv.lock index 421c5044..22d61222 100644 --- a/uv.lock +++ b/uv.lock @@ -3026,7 +3026,7 @@ wheels = [ [[package]] name = "whisperx" -version = "3.8.1" +version = "3.8.2" source = { editable = "." } dependencies = [ { name = "ctranslate2" }, diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 9e5b63a2..f3095ba8 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -2,8 +2,6 @@ Forced Alignment with Whisper C. Max Bain """ -import math - from dataclasses import dataclass from typing import Iterable, Optional, Union, List @@ -86,8 +84,8 @@ def load_align_model(language_code: str, device: str, model_name: Optional[str] model_name = DEFAULT_ALIGN_MODELS_HF[language_code] else: logger.error(f"No default alignment model for language: {language_code}. 
" - f"Please find a wav2vec2.0 model finetuned on this language at https://huggingface.co/models, " - f"then pass the model name via --align_model [MODEL_NAME]") + f"Please find a wav2vec2.0 model finetuned on this language at https://huggingface.co/models, " + f"then pass the model name via --align_model [MODEL_NAME]") raise ValueError(f"No default align-model for language: {language_code}") if model_name in torchaudio.pipelines.__all__: @@ -178,19 +176,11 @@ def align( elif char_ in model_dictionary.keys(): clean_char.append(char_) clean_cdx.append(cdx) - else: - # add placeholder - clean_char.append('*') - clean_cdx.append(cdx) clean_wdx = [] for wdx, wrd in enumerate(per_word): if any([c in model_dictionary.keys() for c in wrd.lower()]): clean_wdx.append(wdx) - else: - # index for placeholder - clean_wdx.append(wdx) - # Use language-specific Punkt model if available otherwise we fallback to English. punkt_lang = PUNKT_LANGUAGES.get(model_lang, 'english') @@ -244,7 +234,7 @@ def align( continue text_clean = "".join(segment_data[sdx]["clean_char"]) - tokens = [model_dictionary.get(c, -1) for c in text_clean] + tokens = [model_dictionary[c] for c in text_clean] f1 = int(t1 * SAMPLE_RATE) f2 = int(t2 * SAMPLE_RATE) @@ -277,8 +267,7 @@ def align( blank_id = code trellis = get_trellis(emission, tokens, blank_id) - # path = backtrack(trellis, emission, tokens, blank_id) - path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2) + path = backtrack(trellis, emission, tokens, blank_id) if path is None: logger.warning(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original') @@ -405,55 +394,25 @@ def get_trellis(emission, tokens, blank_id=0): num_frame = emission.size(0) num_tokens = len(tokens) - trellis = torch.zeros((num_frame, num_tokens)) - trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0) - trellis[0, 1:] = -float("inf") - trellis[-num_tokens + 1:, 0] = float("inf") + # Trellis has extra dimensions for both 
time axis and tokens. + # The extra dim for tokens represents (start-of-sentence) + # The extra dim for time axis is for simplification of the code. + trellis = torch.empty((num_frame + 1, num_tokens + 1)) + trellis[0, 0] = 0 + trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0) + trellis[0, -num_tokens:] = -float("inf") + trellis[-num_tokens:, 0] = float("inf") - for t in range(num_frame - 1): + for t in range(num_frame): trellis[t + 1, 1:] = torch.maximum( # Score for staying at the same token trellis[t, 1:] + emission[t, blank_id], # Score for changing to the next token - # trellis[t, :-1] + emission[t, tokens[1:]], - trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id), + trellis[t, :-1] + emission[t, tokens], ) return trellis -def get_wildcard_emission(frame_emission, tokens, blank_id): - """Processing token emission scores containing wildcards (vectorized version) - - Args: - frame_emission: Emission probability vector for the current frame - tokens: List of token indices - blank_id: ID of the blank token - - Returns: - tensor: Maximum probability score for each token position - """ - assert 0 <= blank_id < len(frame_emission) - - # Convert tokens to a tensor if they are not already - tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens - - # Create a mask to identify wildcard positions - wildcard_mask = (tokens == -1) - - # Get scores for non-wildcard positions - regular_scores = frame_emission[tokens.clamp(min=0).long()] # clamp to avoid -1 index - - # Create a mask and compute the maximum value without modifying frame_emission - max_valid_score = frame_emission.clone() # Create a copy - max_valid_score[blank_id] = float('-inf') # Modify the copy to exclude the blank token - max_valid_score = max_valid_score.max() - - # Use where operation to combine results - result = torch.where(wildcard_mask, max_valid_score, regular_scores) - - return result - - @dataclass class Point: token_index: int @@ -462,138 
+421,41 @@ class Point: def backtrack(trellis, emission, tokens, blank_id=0): - t, j = trellis.size(0) - 1, trellis.size(1) - 1 - - path = [Point(j, t, emission[t, blank_id].exp().item())] - while j > 0: - # Should not happen but just in case - assert t > 0 - + # Note: + # j and t are indices for trellis, which has extra dimensions + # for time and tokens at the beginning. + # When referring to time frame index `T` in trellis, + # the corresponding index in emission is `T-1`. + # Similarly, when referring to token index `J` in trellis, + # the corresponding index in transcript is `J-1`. + j = trellis.size(1) - 1 + t_start = torch.argmax(trellis[:, j]).item() + + path = [] + for t in range(t_start, 0, -1): # 1. Figure out if the current position was stay or change - # Frame-wise score of stay vs change - p_stay = emission[t - 1, blank_id] - # p_change = emission[t - 1, tokens[j]] - p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0] - - # Context-aware score for stay vs change - stayed = trellis[t - 1, j] + p_stay - changed = trellis[t - 1, j - 1] + p_change - - # Update position - t -= 1 + # Note (again): + # `emission[T-1]` is the emission at time frame `T` of trellis dimension. + # Score for token staying the same from time frame T-1 to T. + stayed = trellis[t - 1, j] + emission[t - 1, blank_id] + # Score for token changing from J-1 at T-1 to J at T. + changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]] + + # 2. Store the path with frame-wise probability. + prob = emission[t - 1, tokens[j - 1] if changed > stayed else blank_id].exp().item() + # Return token index and time index in non-trellis coordinate. + path.append(Point(j - 1, t - 1, prob)) + + # 3. Update the token if changed > stayed: j -= 1 - - # Store the path with frame-wise probability. - prob = (p_change if changed > stayed else p_stay).exp().item() - path.append(Point(j, t, prob)) - - # Now j == 0, which means, it reached the SoS. 
- # Fill up the rest for the sake of visualization - while t > 0: - prob = emission[t - 1, blank_id].exp().item() - path.append(Point(j, t - 1, prob)) - t -= 1 - - return path[::-1] - - - -@dataclass -class Path: - points: List[Point] - score: float - - -@dataclass -class BeamState: - """State in beam search.""" - token_index: int # Current token position - time_index: int # Current time step - score: float # Cumulative score - path: List[Point] # Path history - - -def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5): - """Standard CTC beam search backtracking implementation. - - Args: - trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps - and N is the number of tokens (including the blank token). - emission (torch.Tensor): The emission probabilities of shape (T, N). - tokens (List[int]): List of token indices (excluding the blank token). - blank_id (int, optional): The ID of the blank token. Defaults to 0. - beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5. 
- - Returns: - List[Point]: the best path - """ - T, J = trellis.size(0) - 1, trellis.size(1) - 1 - - init_state = BeamState( - token_index=J, - time_index=T, - score=trellis[T, J], - path=[Point(J, T, emission[T, blank_id].exp().item())] - ) - - beams = [init_state] - - while beams and beams[0].token_index > 0: - next_beams = [] - - for beam in beams: - t, j = beam.time_index, beam.token_index - - if t <= 0: - continue - - p_stay = emission[t - 1, blank_id] - p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0] - - stay_score = trellis[t - 1, j] - change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf') - - # Stay - if not math.isinf(stay_score): - new_path = beam.path.copy() - new_path.append(Point(j, t - 1, p_stay.exp().item())) - next_beams.append(BeamState( - token_index=j, - time_index=t - 1, - score=stay_score, - path=new_path - )) - - # Change - if j > 0 and not math.isinf(change_score): - new_path = beam.path.copy() - new_path.append(Point(j - 1, t - 1, p_change.exp().item())) - next_beams.append(BeamState( - token_index=j - 1, - time_index=t - 1, - score=change_score, - path=new_path - )) - - # sort by score - beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width] - - if not beams: - break - - if not beams: + if j == 0: + break + else: + # failed return None - best_beam = beams[0] - t = best_beam.time_index - j = best_beam.token_index - while t > 0: - prob = emission[t - 1, blank_id].exp().item() - best_beam.path.append(Point(j, t - 1, prob)) - t -= 1 - - return best_beam.path[::-1] + return path[::-1] # Merge the labels @@ -643,4 +505,4 @@ def merge_words(segments, separator="|"): i2 = i1 else: i2 += 1 - return words + return words \ No newline at end of file