@@ -126,42 +126,118 @@ def _pad_to_max_length(
     current_segments,
     pad_token_id,
     device,
-    padding="right",
+    padding_side="right",
+    padding="longest",
     bos_token_tensor=None,
     cut_off_length=None,
+    return_token_timestamps=False,
+    force_unique_generate_call=False,
 ):
     max_total_length = 0
     sequences = []
-    if padding not in ["right", "left"]:
-        raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}")
+    token_timestamps_list = []
+
+    if padding_side not in ["right", "left"]:
+        raise ValueError(
+            f"`padding_side` must be either 'right' or 'left', not {padding_side}"
+        )
+
+    if padding not in ["longest", "max_length"]:
+        raise ValueError(
+            f"`padding` must be either 'longest' or 'max_length', not {padding}"
+        )
+    elif padding == "max_length" and cut_off_length is None:
+        raise ValueError(
+            "`cut_off_length` must be specified when `padding='max_length'`"
+        )
+
+    if force_unique_generate_call:
+        sequences_list = []
+        timestamps_list = []
+        for segments in current_segments:
+            result = segments[0]["result"]
+            sequences_list.append(
+                result if isinstance(result, torch.Tensor) else result["sequences"]
+            )
+            if return_token_timestamps:
+                timestamps_list.append(result["token_timestamps"])
+
+        sequences = torch.stack(sequences_list, dim=0)
+        if return_token_timestamps:
+            token_timestamps = torch.stack(timestamps_list, dim=0)
+            return sequences, token_timestamps
+        return sequences
 
     for current_segment_list in current_segments:
         if (
             current_segment_list is not None
             and len([d["tokens"] for d in current_segment_list]) > 0
         ):
             sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
+            if return_token_timestamps:
+                token_timestamps = torch.cat(
+                    [
+                        d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]]
+                        for d in current_segment_list
+                    ],
+                    dim=-1,
+                )
 
             if cut_off_length is not None:
                 sequence = sequence[-cut_off_length:]
+                if return_token_timestamps:
+                    token_timestamps = token_timestamps[-cut_off_length:]
 
             if bos_token_tensor is not None:
                 sequence = torch.cat([bos_token_tensor, sequence])
-
+                if return_token_timestamps:
+                    token_timestamps = torch.cat(
+                        [
+                            torch.ones_like(bos_token_tensor, device=device) * 0.0,
+                            token_timestamps,
+                        ]
+                    )
             sequences.append(sequence)
+            if return_token_timestamps:
+                token_timestamps_list.append(token_timestamps)
             max_total_length = max(max_total_length, len(sequences[-1]))
         elif bos_token_tensor is not None:
             sequences.append(bos_token_tensor)
+            if return_token_timestamps:
+                token_timestamps_list.append(
+                    torch.ones_like(bos_token_tensor, device=device) * 0.0
+                )
         else:
             sequences.append(torch.tensor([], device=device))
+            if return_token_timestamps:
+                token_timestamps_list.append(torch.tensor([], device=device))
 
+    max_total_length = (
+        cut_off_length + 1 if padding == "max_length" else max_total_length
+    )
     for i in range(len(current_segments)):
         pad_length = max_total_length - len(sequences[i])
-        pad = (0, pad_length) if padding == "right" else (pad_length, 0)
+        pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
+
         sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
+        if return_token_timestamps:
+            token_timestamps_list[i] = F.pad(
+                token_timestamps_list[i],
+                pad=pad,
+                value=(
+                    token_timestamps_list[i][-1]
+                    if len(token_timestamps_list[i]) > 0
+                    else 0.0
+                ),
+            )
 
     sequences = torch.stack(sequences, dim=0)
-    return sequences
+
+    if return_token_timestamps:
+        token_timestamps = torch.stack(token_timestamps_list, dim=0)
+        return sequences, token_timestamps
+    else:
+        return sequences
 
 
 def whisper_generate(
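
A quick illustration of the padding path above (a minimal sketch, not part of the commit; it assumes the module-level `torch` and `torch.nn.functional as F` imports the diff relies on, and reduces each segment dict to the `tokens` key the function reads):

    import torch

    segments = [
        [{"tokens": torch.tensor([1, 2, 3])}],  # one longer transcript
        [{"tokens": torch.tensor([4, 5])}],     # one shorter transcript
    ]
    padded = _pad_to_max_length(
        current_segments=segments,
        pad_token_id=0,
        device="cpu",
        padding_side="right",
    )
    # padded -> tensor([[1, 2, 3],
    #                   [4, 5, 0]])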
@@ -186,9 +262,11 @@ def whisper_generate(
     num_segment_frames: Optional[int] = None,
     attention_mask: Optional[torch.Tensor] = None,
     time_precision: float = 0.02,
+    time_precision_features: float = 0.01,
     return_token_timestamps: Optional[bool] = None,
     return_segments: bool = False,
     return_dict_in_generate: Optional[bool] = None,
+    force_unique_generate_call: Optional[bool] = None,
     **kwargs,
 ):
     # 0. deprecate old inputs
@@ -270,11 +348,23 @@ def whisper_generate(
         else input_features.device
     )
     begin_index = init_tokens.shape[1]
+    num_beams = kwargs.get(
+        "num_beams",
+        (
+            generation_config.num_beams
+            if hasattr(generation_config, "num_beams")
+            and generation_config.num_beams is not None
+            else 1
+        ),
+    )
+    if "assistant_model" in kwargs:
+        # speculative decoding: the model should be able to return the eos token
+        generation_config.begin_suppress_tokens = None
     logits_processor = self._retrieve_logit_processors(
         generation_config=generation_config,
         logits_processor=logits_processor,
         begin_index=begin_index,  # begin index is index of first generated decoder token
-        num_beams=kwargs.get("num_beams", 1),
+        num_beams=num_beams,
         device=device,
     )
 
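
The resolution order for `num_beams` above: an explicit kwarg wins, then a non-None value on the generation config, then the default of 1. With illustrative values (not from the commit):

    # kwargs == {"num_beams": 4}                         -> num_beams == 4
    # kwargs == {}, generation_config.num_beams == 2     -> num_beams == 2
    # kwargs == {}, generation_config.num_beams is None  -> num_beams == 1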
@@ -321,7 +411,23 @@ def whisper_generate(
         batch_size=cur_bsz,
         generation_config=generation_config,
     )
-
+    # 5bis. speculative decoding: the assistant model should make only one call
+    # to generate so that it returns the decoder input token ids and the eos
+    # token id; we set a flag on its generation config to force that single
+    # call to generate
+    if "assistant_model" in kwargs:
+        assistant_model = kwargs["assistant_model"]
+        assistant_model.generation_config.force_unique_generate_call = True
+
+    if force_unique_generate_call is None:
+        if hasattr(generation_config, "force_unique_generate_call"):
+            force_unique_generate_call = generation_config.force_unique_generate_call
+        elif hasattr(self.generation_config, "force_unique_generate_call"):
+            force_unique_generate_call = (
+                self.generation_config.force_unique_generate_call
+            )
+        else:
+            force_unique_generate_call = False
     # 6 Transcribe audio until we reach the end of all input audios
     while (seek < max_frames).any():
         # 6.1 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically
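
The `force_unique_generate_call` fallback above can be read as a `getattr` chain (an equivalent rewrite for illustration, not the commit's code): the explicit argument wins, then the passed generation config, then the model's own generation config, then False:

    if force_unique_generate_call is None:
        force_unique_generate_call = getattr(
            generation_config,
            "force_unique_generate_call",
            getattr(self.generation_config, "force_unique_generate_call", False),
        )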
@@ -336,7 +442,11 @@ def whisper_generate(
             cur_bsz=cur_bsz,
             batch_idx_map=batch_idx_map,
         )
-        time_offset = seek * time_precision / input_stride
+        time_offset = (
+            seek.to(torch.float32 if device.type == "mps" else torch.float64)
+            * time_precision
+            / input_stride
+        )
         seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
 
         # 6.2 cut out next 30s segment from input features
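
Two notes on the new `time_offset` (assumed context, using Whisper's standard constants, which the fragment does not restate): the cast exists because PyTorch's MPS backend does not support float64, and the arithmetic yields seconds. A worked example with `input_stride = 2` and `time_precision = 0.02`:

    # seek counts mel frames; each encoder output frame spans input_stride mel
    # frames and lasts time_precision seconds:
    seek_frames = 3000                    # one full 30-second Whisper window
    time_offset = seek_frames * 0.02 / 2  # -> 30.0 seconds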
@@ -355,6 +465,9 @@ def whisper_generate(
             transformers.generation.logits_process.SuppressTokensLogitsProcessor,
             "suppress_tokens",
         )
+        extra_kwargs = {}
+        if version.parse(transformers.__version__) >= version.parse("4.47.0"):
+            extra_kwargs["timestamp_begin"] = timestamp_begin
 
         decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
             cur_bsz=cur_bsz,
@@ -367,6 +480,7 @@ def whisper_generate(
367480 config = self .config ,
368481 device = init_tokens .device ,
369482 suppress_tokens = suppress_tokens ,
483+ ** extra_kwargs ,
370484 kwargs = kwargs ,
371485 )
372486
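
A note on the version gates used here and again below: `version` is presumably `packaging.version` (its import is outside this fragment), which compares release numbers per PEP 440 rather than lexicographically:

    from packaging import version

    assert version.parse("4.47.1") >= version.parse("4.47.0")
    assert version.parse("4.47.0rc1") < version.parse("4.47.0")  # pre-releases sort lower
    assert version.parse("4.9.0") < version.parse("4.47.0")      # string comparison would misorder this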
@@ -419,7 +533,11 @@ def whisper_generate(
                 if should_skip[i]:
                     seek[prev_i] += seek_num_frames[prev_i]
                     continue
-
+                extra_kwargs = {}
+                if version.parse(transformers.__version__) >= version.parse("4.48.0"):
+                    extra_kwargs["decoder_input_ids"] = decoder_input_ids
+                if version.parse(transformers.__version__) >= version.parse("4.47.0"):
+                    extra_kwargs["time_precision_features"] = time_precision_features
                 segments, segment_offset = self._retrieve_segment(
                     seek_sequence=seek_sequence,
                     seek_outputs=seek_outputs,
@@ -431,14 +549,13 @@ def whisper_generate(
                     prev_idx=prev_i,
                     idx=i,
                     return_token_timestamps=return_token_timestamps,
+                    **extra_kwargs,
                 )
-
+                seek[prev_i] += segment_offset
                 current_segments[prev_i] += segments
 
-                if is_shortform:
-                    seek[prev_i] += max_frames[i]
-                else:
-                    seek[prev_i] += segment_offset
+                if force_unique_generate_call:
+                    break
 
     # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
     # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
@@ -451,65 +568,69 @@ def whisper_generate(
         else current_segments
     )
 
-    sequences = _pad_to_max_length(
-        final_segments,
-        generation_config.pad_token_id,
+    # if return_dict_in_generate=True and either we forced a unique call to
+    # generate or return_timestamps=False, we know only one call to generate
+    # was made -> we can return a ModelOutput; otherwise,
+    # return_dict_in_generate is applied to the 'result' of each segment in
+    # final_segments
+    if (
+        return_dict_in_generate
+        and generation_config.return_dict_in_generate
+        and (force_unique_generate_call or not return_timestamps)
+    ):
+        # only one call to generate_with_fallback, we can return a ModelOutput
+        outputs = self._stack_split_outputs(
+            seek_outputs, model_output_type, self.device, kwargs
+        )
+        if num_return_sequences > 1:
+            if (
+                hasattr(outputs, "encoder_attentions")
+                and outputs.encoder_attentions is not None
+            ):
+                outputs.encoder_attentions = tuple(
+                    outputs.encoder_attentions[i][::num_return_sequences]
+                    for i in range(len(outputs.encoder_attentions))
+                )
+            if (
+                hasattr(outputs, "encoder_hidden_states")
+                and outputs.encoder_hidden_states is not None
+            ):
+                outputs.encoder_hidden_states = tuple(
+                    outputs.encoder_hidden_states[i][::num_return_sequences]
+                    for i in range(len(outputs.encoder_hidden_states))
+                )
+        return outputs
+
+    padded_outputs = _pad_to_max_length(
+        current_segments=final_segments,
+        pad_token_id=generation_config.pad_token_id,
         device=self.device,
-        padding="right",
+        padding_side="right",
+        return_token_timestamps=return_token_timestamps,
+        force_unique_generate_call=force_unique_generate_call,
     )
 
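
Why the `[::num_return_sequences]` slices in the branch above (an illustrative sketch, not the commit's code): encoder attentions and hidden states are expanded once per returned sequence even though the encoder only ran once per input, so keeping every n-th batch entry de-duplicates them:

    import torch

    num_return_sequences = 2
    layer_attn = torch.arange(6).reshape(6, 1)    # 3 inputs x 2 return sequences
    deduped = layer_attn[::num_return_sequences]  # keeps rows 0, 2, 4 -> one per input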
-    # 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
-    if return_segments:
-        return {"sequences": sequences, "segments": final_segments}
-
-    if is_shortform:
-        # add eos token:
-        if (
-            generation_config.max_new_tokens is None
-            and generation_config.max_length is None
-        ):
-            eos_tokens = torch.full(
-                (sequences.shape[0], 1), generation_config.eos_token_id
-            )
-            sequences = torch.cat([sequences, eos_tokens], dim=-1)
-
-        if return_token_timestamps:
-            outputs = {}
-            outputs["sequences"] = sequences
-            outputs["token_timestamps"] = torch.stack(
-                [d["token_timestamps"] for d in seek_outputs], dim=0
-            )
-        elif hasattr(self.config, "token_latency") and self.config.token_latency:
-            outputs = (sequences, seek_outputs[0])
-        else:
-            outputs = sequences
-
-    if return_dict_in_generate and generation_config.return_dict_in_generate:
-        dict_outputs = self._stack_split_outputs(
-            seek_outputs, model_output_type, sequences.device, kwargs
-        )
-
-        if num_return_sequences > 1:
-            if (
-                hasattr(dict_outputs, "encoder_attentions")
-                and dict_outputs.encoder_attentions is not None
-            ):
-                dict_outputs.encoder_attentions = tuple(
-                    dict_outputs.encoder_attentions[i][::num_return_sequences]
-                    for i in range(len(dict_outputs.encoder_attentions))
-                )
-            if (
-                hasattr(dict_outputs, "encoder_hidden_states")
-                and dict_outputs.encoder_hidden_states is not None
-            ):
-                dict_outputs.encoder_hidden_states = tuple(
-                    dict_outputs.encoder_hidden_states[i][::num_return_sequences]
-                    for i in range(len(dict_outputs.encoder_hidden_states))
-                )
-        if return_token_timestamps:
-            dict_outputs["token_timestamps"] = outputs["token_timestamps"]
-        return dict_outputs
+    if return_dict_in_generate and generation_config.return_dict_in_generate:
+        return_segments = True
+    elif not return_segments and not return_token_timestamps:
+        if hasattr(self.config, "token_latency") and self.config.token_latency:
+            return (padded_outputs, seek_outputs[0])
+        return padded_outputs
+
+    if return_token_timestamps:
+        sequences, token_timestamps = padded_outputs
+        outputs = {
+            "sequences": sequences,
+            "token_timestamps": token_timestamps,
+        }
+    elif hasattr(self.config, "token_latency") and self.config.token_latency:
+        # note: padded_outputs is the plain sequences tensor on this path
+        outputs = (padded_outputs, seek_outputs[0])
+    else:
+        sequences = padded_outputs
+        outputs = {
+            "sequences": sequences,
+        }
 
-    return outputs
+    if return_segments:
+        outputs["segments"] = final_segments
 
-    return sequences
+    return outputs
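
To summarize the return behavior as rewritten above (as read from the code, with hypothetical call configurations):

    # return_dict_in_generate + a single call to generate -> ModelOutput (early return)
    # no segments / no timestamps requested               -> padded sequences tensor,
    #                                                        or (sequences, latency) if
    #                                                        config.token_latency is set
    # otherwise                                           -> {"sequences": ...} plus
    #                                                        optional "token_timestamps"
    #                                                        and "segments"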