diff --git a/whisperx/asr.py b/whisperx/asr.py index c35900cf..bddc0b78 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -58,11 +58,7 @@ def generate_segment_batched( ) encoder_output = self.encode(features) - - max_initial_timestamp_index = int( - round(options.max_initial_timestamp / self.time_precision) - ) - + result = self.model.generate( encoder_output, [prompt] * batch_size, @@ -72,11 +68,13 @@ def generate_segment_batched( max_length=self.max_length, suppress_blank=options.suppress_blank, suppress_tokens=options.suppress_tokens, + no_repeat_ngram_size=options.no_repeat_ngram_size, + repetition_penalty=options.repetition_penalty, ) tokens_batch = [x.sequences_ids[0] for x in result] - def decode_batch(tokens: List[List[int]]) -> str: + def decode_batch(tokens: List[List[int]]) -> List[str]: res = [] for tk in tokens: res.append([token for token in tk if token < tokenizer.eot])