 from transformers.models.auto.modeling_auto import _BaseAutoModelClass, MODEL_MAPPING_NAMES, \
     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, \
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
-from transformers.generation.utils import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
+from transformers.generation.utils import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, StoppingCriteria
 from transformers import (
     LogitsProcessorList,
     MinLengthLogitsProcessor, TemperatureLogitsWarper,
@@ -507,6 +507,51 @@ class HFGenerativeLMTokenizerConfig(HFTokenizerConfig):
     truncation_side: Literal['left', 'right'] = 'left'  ## Keeps tokens at the end of the string, useful for LLMs


+class HFSubstringMatchStoppingCriteria(StoppingCriteria):
+    ## Stops generation once every sequence in the batch contains at least one stop string.
+    def __init__(
+        self,
+        *,
+        stop_sequences: List[str],
+        tokenizer: Any,
+        tokenizer_decode_dict: Dict,
+        prompt_input_ids: Tensor,
+    ):
+        self.tokenizer: PreTrainedTokenizerBase = tokenizer
+        self.tokenizer_decode_dict: Dict = tokenizer_decode_dict
+        self.stop_sequences: List[str] = as_list(stop_sequences)
+        self.prompt_input_ids: Tensor = prompt_input_ids
+
+    def __call__(self, input_ids, scores, **kwargs):
+        ## Decode only the newly-generated tokens (i.e. skip the prompt tokens):
+        generated_texts: List[str] = self.tokenizer.batch_decode(
+            input_ids[:, self.prompt_input_ids.shape[1]:],
+            **self.tokenizer_decode_dict,
+        )
+        ## Check whether at least one stop sequence appears in ALL generated texts:
+        should_stop_generating: List[bool] = []
+        for generated_text in generated_texts:
+            should_stop_generating.append(False)
+            for stop_seq in self.stop_sequences:
+                if stop_seq in generated_text:
+                    should_stop_generating[-1] = True
+                    break
+        if bool(all(should_stop_generating)):
+            return True  ## Stop generation
+        return False  ## Continue generation
+
+    ## __len__ and __iter__ let this object be passed bare where `generate`
+    ## expects a StoppingCriteriaList: the criteria-merging code length-checks
+    ## and iterates the value, and iterating this object yields itself.
+    def __len__(self):
+        return len(self.stop_sequences)
+
+    def __iter__(self):
+        yield self
+
+
 class HFPyTorchGenerativeLMMixin(GenerativeLM, HFPyTorchTextModel, ABC):
     class Hyperparameters(HFPyTorchTextModel.Hyperparameters):
         prompt_prefix: str = ''
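A note on the `__len__` and `__iter__` overrides above: in the `transformers` versions this code appears to target, `generate` expects its `stopping_criteria` argument to be a `StoppingCriteriaList`, which the criteria-merging code length-checks and iterates. Because this object yields itself from `__iter__`, it can be passed bare. A minimal sketch under assumptions (the 'gpt2' checkpoint, prompt text, and stop string are placeholders, and the class above is assumed importable):

import torch
from transformers import AutoTokenizer, StoppingCriteriaList

tok = AutoTokenizer.from_pretrained('gpt2')  ## Placeholder checkpoint
prompt_ids: torch.Tensor = tok('Hello', return_tensors='pt')['input_ids']
criteria = HFSubstringMatchStoppingCriteria(
    stop_sequences=['\n\n'],
    tokenizer=tok,
    tokenizer_decode_dict=dict(skip_special_tokens=True),
    prompt_input_ids=prompt_ids,
)
merged = StoppingCriteriaList()
merged.extend(criteria)  ## Iterating `criteria` yields the object itself
assert len(merged) == 1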
@@ -529,6 +574,14 @@ def set_generative_lm_params(cls, params: Dict) -> Dict:
     def max_num_generated_tokens(self) -> int:
         return self.hyperparams.generation_params.max_new_tokens

+    @property
+    def tokenizer_decode_dict(self) -> Dict:
+        return self.hyperparams.tokenizer_decode.dict()
+
+    @property
+    def stop_sequences(self) -> Optional[List[str]]:
+        return self.hyperparams.generation_params.stop_sequences
+
     def _task_preprocess(self, batch: Prompts, **kwargs) -> Prompts:
         batch: Prompts = super(HFPyTorchGenerativeLMMixin, self)._task_preprocess(
             batch,
@@ -539,12 +592,20 @@ def _task_preprocess(self, batch: Prompts, **kwargs) -> Prompts:
     def forward(self, input: Dict, **kwargs) -> Dict:
         ## Feed the input_ids and masks to the model:
         input.pop('token_type_ids', None)
+        input_ids: Tensor = input['input_ids']
         with disable_hf_logging():
             gen_kwargs: Dict = {
                 **input,
                 **self.hyperparams.generation_params.hf_dict(),
                 **dict(return_dict_in_generate=True),  ## Always return a *DecoderOnlyOutput
             }
+            if self.stop_sequences is not None:
+                gen_kwargs['stopping_criteria'] = HFSubstringMatchStoppingCriteria(
+                    stop_sequences=self.stop_sequences,
+                    tokenizer=self.tokenizer,
+                    tokenizer_decode_dict=self.tokenizer_decode_dict,
+                    prompt_input_ids=input_ids,
+                )
             out: Union[GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput] = self.model.generate(**gen_kwargs)
         return dict(out)

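For reference, a standalone sketch of the same wiring with plain `transformers`, outside this mixin. Everything here is hypothetical: the checkpoint, prompt, and stop string are placeholders, and `StopOnSubstring` is a simplified analogue of `HFSubstringMatchStoppingCriteria`, not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSubstring(StoppingCriteria):
    def __init__(self, stop_sequences, tokenizer, prompt_len):
        self.stop_sequences = stop_sequences
        self.tokenizer = tokenizer
        self.prompt_len = prompt_len  ## Number of prompt tokens to skip when decoding

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        texts = self.tokenizer.batch_decode(input_ids[:, self.prompt_len:], skip_special_tokens=True)
        ## Batch-global stop: only True once every row contains some stop string.
        return all(any(s in t for s in self.stop_sequences) for t in texts)

tok = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
inputs = tok('Q: What is 2+2?\nA:', return_tensors='pt')
out = model.generate(
    **inputs,
    max_new_tokens=50,
    return_dict_in_generate=True,
    pad_token_id=tok.eos_token_id,
    stopping_criteria=StoppingCriteriaList([
        StopOnSubstring(['\nQ:'], tok, inputs['input_ids'].shape[1]),
    ]),
)
print(tok.batch_decode(out.sequences[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True))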
@@ -566,8 +627,22 @@ def prepare_predictions(self, output: Dict, input: Dict, **kwargs) -> Any:
         num_generated_tokens: int = generated_sequences.shape[1]
         generated_texts: List[str] = self.tokenizer.batch_decode(
             generated_sequences,
-            **self.hyperparams.tokenizer_decode.dict(),
+            **self.tokenizer_decode_dict,
         )
+        ## Post-process stop sequences: truncate each text at its earliest stop match.
+        if self.stop_sequences is not None:
+            for gen_text_i, generated_text in enumerate(generated_texts):
+                earliest_stop_idx: Optional[int] = None
+                for stop_seq in self.stop_sequences:
+                    stop_idx: int = generated_text.find(stop_seq)
+                    if stop_idx != -1:
+                        if earliest_stop_idx is None:
+                            earliest_stop_idx: int = stop_idx
+                        else:
+                            earliest_stop_idx: int = min(earliest_stop_idx, stop_idx)
+                if earliest_stop_idx is not None:
+                    generated_texts[gen_text_i]: str = generated_text[:earliest_stop_idx]
+
         predictions: Dict = {
             GENERATED_TEXTS_COL: generated_texts
         }
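Because the stopping criterion is batch-global (generation halts only when every row contains a stop string), an individual row can run past its own stop sequence, and the matched stop text itself remains in the decoded output; the truncation loop above removes both. A hypothetical standalone helper (not part of this commit) with the same semantics:

from typing import List

def trim_at_earliest_stop(text: str, stop_sequences: List[str]) -> str:
    ## Cut `text` at the earliest occurrence of any stop sequence;
    ## texts without a match pass through unchanged.
    match_idxs: List[int] = [idx for idx in (text.find(s) for s in stop_sequences) if idx != -1]
    return text if len(match_idxs) == 0 else text[:min(match_idxs)]

assert trim_at_earliest_stop('4\nQ: next question', ['\nQ:', '\n\n']) == '4'
assert trim_at_earliest_stop('no stop here', ['\nQ:']) == 'no stop here'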