Commit 4bcd8bb

Added stopping on substring for HF Transformers.

1 parent e17c57c · commit 4bcd8bb

3 files changed (+81, -3 lines)

src/synthesizrr/base/algorithm/huggingface/transformers.py
Lines changed: 77 additions & 2 deletions
```diff
@@ -26,7 +26,7 @@
 from transformers.models.auto.modeling_auto import _BaseAutoModelClass, MODEL_MAPPING_NAMES, \
     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, \
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
-from transformers.generation.utils import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
+from transformers.generation.utils import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, StoppingCriteria
 from transformers import (
     LogitsProcessorList,
     MinLengthLogitsProcessor, TemperatureLogitsWarper,
```
```diff
@@ -507,6 +507,51 @@ class HFGenerativeLMTokenizerConfig(HFTokenizerConfig):
     truncation_side: Literal['left', 'right'] = 'left'  ## Keeps tokens at the end of the string, useful for LLMs
 
 
+class HFSubstringMatchStoppingCriteria(StoppingCriteria):
+    def __init__(
+            self,
+            *,
+            stop_sequences: List[str],
+            tokenizer: Any,
+            tokenizer_decode_dict: Dict,
+            prompt_input_ids: Tensor,
+    ):
+        self.tokenizer: PreTrainedTokenizerBase = tokenizer
+        self.tokenizer_decode_dict: Dict = tokenizer_decode_dict
+        self.stop_sequences: List[str] = as_list(stop_sequences)
+        self.prompt_input_ids: Tensor = prompt_input_ids
+
+    def __call__(self, input_ids, scores, **kwargs):
+        ## Decode only the newly-generated tokens, i.e. everything after the prompt:
+        generated_texts: List[str] = self.tokenizer.batch_decode(
+            input_ids[:, self.prompt_input_ids.shape[1]:],
+            **self.tokenizer_decode_dict,
+        )
+        ## Stop only when EVERY text in the batch contains at least one stop sequence:
+        should_stop_generating: List[bool] = []
+        for generated_text in generated_texts:
+            should_stop_generating.append(False)
+            for stop_seq in self.stop_sequences:
+                if stop_seq in generated_text:
+                    should_stop_generating[-1] = True
+                    break
+        if bool(all(should_stop_generating)):
+            # print('=' * 40)
+            # print(f'Stopped at this point:')
+            # print('=' * 40)
+            # for generated_text in generated_texts:
+            #     print(generated_text, end='\n\n')
+            # print('=' * 40)
+            return True  ## Stop generation
+        return False  ## Continue generation
+
+    def __len__(self):
+        return len(self.stop_sequences)
+
+    def __iter__(self):
+        yield self
+
+
 class HFPyTorchGenerativeLMMixin(GenerativeLM, HFPyTorchTextModel, ABC):
     class Hyperparameters(HFPyTorchTextModel.Hyperparameters):
         prompt_prefix: str = ''
```
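For readers unfamiliar with this HF extension point: `StoppingCriteria.__call__` is invoked after every decoding step with the full `input_ids` generated so far, and generation halts once it returns a truthy value. Below is a minimal, self-contained sketch of the same substring-stopping idea against vanilla `transformers`; the model name, prompt, stop string, and the `StopOnSubstring` class are illustrative assumptions, not SynthesizRR code. (With transformers versions contemporary to this commit, returning a plain bool is the documented contract; newer versions also accept a per-sequence bool tensor.)

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList


class StopOnSubstring(StoppingCriteria):
    """Stop once every sequence in the batch contains `stop_string` after the prompt."""

    def __init__(self, stop_string: str, tokenizer, prompt_len: int):
        self.stop_string = stop_string
        self.tokenizer = tokenizer
        self.prompt_len = prompt_len  ## Number of prompt tokens to skip when decoding.

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        ## Decode only the tokens generated after the prompt, for every row in the batch:
        texts = self.tokenizer.batch_decode(input_ids[:, self.prompt_len:], skip_special_tokens=True)
        return all(self.stop_string in text for text in texts)


tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
inputs = tokenizer('Q: What is 2+2?\nA:', return_tensors='pt')
out = model.generate(
    **inputs,
    max_new_tokens=50,
    stopping_criteria=StoppingCriteriaList([
        StopOnSubstring('\n', tokenizer, prompt_len=inputs['input_ids'].shape[1]),
    ]),
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```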
```diff
@@ -529,6 +574,14 @@ def set_generative_lm_params(cls, params: Dict) -> Dict:
     def max_num_generated_tokens(self) -> int:
         return self.hyperparams.generation_params.max_new_tokens
 
+    @property
+    def tokenizer_decode_dict(self) -> Dict:
+        return self.hyperparams.tokenizer_decode.dict()
+
+    @property
+    def stop_sequences(self) -> Optional[List[str]]:
+        return self.hyperparams.generation_params.stop_sequences
+
     def _task_preprocess(self, batch: Prompts, **kwargs) -> Prompts:
         batch: Prompts = super(HFPyTorchGenerativeLMMixin, self)._task_preprocess(
             batch,
```
```diff
@@ -539,12 +592,20 @@ def _task_preprocess(self, batch: Prompts, **kwargs) -> Prompts:
     def forward(self, input: Dict, **kwargs) -> Dict:
         ## Feed the input_ids and masks to the model:
         input.pop('token_type_ids', None)
+        input_ids: Tensor = input['input_ids']  ## Captured so the stopping criteria can skip the prompt when decoding.
         with disable_hf_logging():
             gen_kwargs: Dict = {
                 **input,
                 **self.hyperparams.generation_params.hf_dict(),
                 **dict(return_dict_in_generate=True),  ## Always return a *DecoderOnlyOutput
             }
+            if self.stop_sequences is not None:
+                gen_kwargs['stopping_criteria'] = HFSubstringMatchStoppingCriteria(
+                    stop_sequences=self.stop_sequences,
+                    tokenizer=self.tokenizer,
+                    tokenizer_decode_dict=self.tokenizer_decode_dict,
+                    prompt_input_ids=input_ids,
+                )
             out: Union[GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput] = self.model.generate(**gen_kwargs)
         return dict(out)
 
```
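A note on why `HFSubstringMatchStoppingCriteria` defines `__len__` and `__iter__`: `generate` expects `stopping_criteria` to behave like a `StoppingCriteriaList`, which HF merges with its default criteria by taking `len()` of, and iterating over, whatever it is given. Yielding `self` from `__iter__` therefore lets the bare criteria instance be passed directly, as done here, without wrapping it in a `StoppingCriteriaList`. (This reading is inferred from the code; the commit itself does not comment on it.)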
```diff
@@ -566,8 +627,22 @@ def prepare_predictions(self, output: Dict, input: Dict, **kwargs) -> Any:
         num_generated_tokens: int = generated_sequences.shape[1]
         generated_texts: List[str] = self.tokenizer.batch_decode(
             generated_sequences,
-            **self.hyperparams.tokenizer_decode.dict(),
+            **self.tokenizer_decode_dict,
         )
+        ## Post-process stop-sequences: truncate each text at its earliest stop sequence, if any.
+        if self.stop_sequences is not None:
+            for gen_text_i, generated_text in enumerate(generated_texts):
+                earliest_stop_idx: Optional[int] = None
+                for stop_seq in self.stop_sequences:
+                    stop_idx: int = generated_text.find(stop_seq)
+                    if stop_idx != -1:
+                        if earliest_stop_idx is None:
+                            earliest_stop_idx: int = stop_idx
+                        else:
+                            earliest_stop_idx: int = min(earliest_stop_idx, stop_idx)
+                if earliest_stop_idx is not None:
+                    generated_texts[gen_text_i]: str = generated_text[:earliest_stop_idx]
+
         predictions: Dict = {
             GENERATED_TEXTS_COL: generated_texts
         }
```
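Why this truncation pass is still needed despite the stopping criteria: generation only halts once every sequence in the batch contains a stop sequence, so individual texts can run past their own stop point, and the matched stop sequence itself remains in the decoded text. A standalone sketch of the same earliest-match truncation (the function name and test strings are illustrative):

```python
from typing import List


def truncate_at_stop_sequences(text: str, stop_sequences: List[str]) -> str:
    ## Collect the index of each stop sequence that actually occurs in the text:
    hit_indices: List[int] = [idx for idx in (text.find(s) for s in stop_sequences) if idx != -1]
    ## Cut at the earliest hit; leave the text untouched if nothing matched:
    return text[:min(hit_indices)] if hit_indices else text


assert truncate_at_stop_sequences('foo\nbar###baz', ['###', '\n']) == 'foo'
assert truncate_at_stop_sequences('no stops here', ['###']) == 'no stops here'
```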

src/synthesizrr/base/framework/evaluator/LocalEvaluator.py
Lines changed: 1 addition & 1 deletion
```diff
@@ -17,7 +17,7 @@ class LocalEvaluator(Evaluator):
     aliases = ['local', 'SimpleEvaluator', 'simple']
 
-    ## Cache model locally for 15 mins:
-    cache_timeout: Optional[Union[Timeout, confloat(gt=0)]] = Timeout24Hr(timeout=60 * 15)
+    ## Cache model locally for 3 hours:
+    cache_timeout: Optional[Union[Timeout, confloat(gt=0)]] = Timeout24Hr(timeout=3 * 60 * 60)
 
     def _load_model(
         self,
```
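For scale: `timeout` here appears to be in seconds, since the old value 60 * 15 = 900 matched the original "15 mins" comment; the new value 3 * 60 * 60 = 10,800 seconds caches the model for 3 hours.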

src/synthesizrr/base/framework/task/text_generation.py
Lines changed: 3 additions & 0 deletions
```diff
@@ -553,6 +553,9 @@ def set_gen_params(cls, params: Dict) -> Dict:
             params['output_scores_tolerance']: Optional[float] = None  ## Do not filter out any tokens.
         else:
             raise NotImplementedError(f'Unsupported `output_scores_format`: "{params["output_scores_format"]}"')
+
+        if params.get('stop_sequences') is not None:
+            params['stop_sequences']: List[str] = as_list(params['stop_sequences'])
         return params
 
     def hf_dict(self) -> Dict:
```
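`as_list` is a SynthesizRR helper used throughout the codebase; this hunk relies on it so callers can pass `stop_sequences` either as a single string or as a list. A plausible minimal equivalent, given as an assumption rather than the actual implementation:

```python
from typing import Any, List


def as_list(value: Any) -> List:
    ## Wrap scalars in a list; pass existing collections through as lists.
    if isinstance(value, (list, tuple, set)):
        return list(value)
    return [value]


assert as_list('###') == ['###']                    ## A lone stop string becomes a one-element list.
assert as_list(['###', '\n\n']) == ['###', '\n\n']  ## Lists pass through unchanged.
```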
