1919from transformers import PreTrainedTokenizer
2020import torch
2121from tensorrt_llm .sampling_params import LogitsProcessor
22- from logits_processor_zoo .utils import text_to_token
22+ from logits_processor_zoo .utils import text_to_token , SentenceChecker
2323
2424
25- class GenLengthLogitsProcessor (LogitsProcessor ):
25+ class GenLengthLogitsProcessor (LogitsProcessor , SentenceChecker ):
2626 """
2727 A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token
2828 based on the length of the generated sequence, encouraging or discouraging shorter answers.
@@ -37,18 +37,16 @@ class GenLengthLogitsProcessor(LogitsProcessor):
3737 or a new line. Default is False.
3838 boost_token_str (str, optional): A string to be tokenized and used instead of EOS. Especially useful for </think>.
3939 """
def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float, p: int = 2,
             complete_sentences: bool = False, boost_token_str: str = None):
    """
    Store the configuration and resolve which token id gets boosted.

    Args:
        tokenizer (PreTrainedTokenizer): Tokenizer used to look up the EOS id and to
            tokenize ``boost_token_str``; it is also handed to ``SentenceChecker``.
        boost_factor (float): Stored as ``self.boost_factor``.
            NOTE(review): presumably scales the boost computed in ``__call__`` - confirm.
        p (int, optional): Stored as ``self.p``. Default is 2.
            NOTE(review): presumably an exponent applied to the token count - confirm.
        complete_sentences (bool, optional): When True, ``__call__`` applies the boost only
            where ``_check_sentence_end`` reports a match. Default is False.
        boost_token_str (str, optional): Text to tokenize and boost instead of EOS.
            Default is None, which keeps the tokenizer's EOS token.
    """
    # The mixin is initialised by an explicit call rather than through super().
    # NOTE(review): presumably this is what backs _check_sentence_end - confirm.
    SentenceChecker.__init__(self, tokenizer)
    self.tokenizer = tokenizer

    # Start from EOS and override only when a custom boost string was supplied.
    self.boost_token_str = boost_token_str
    target_token = self.tokenizer.eos_token_id
    if boost_token_str is not None:
        target_token = text_to_token(self.tokenizer, boost_token_str, last=False)
    self.boost_token = target_token

    self.boost_factor, self.p = boost_factor, p
    self.complete_sentences = complete_sentences

    # Counter starts at zero.
    # NOTE(review): presumably advanced once per generation step in __call__ - confirm.
    self.token_count = 0
5452
@@ -64,7 +62,7 @@ def __call__(self, req_id: int, logits: torch.Tensor,
6462 ids = torch .LongTensor (token_ids ).to (logits .device , non_blocking = True )
6563
6664 if self .complete_sentences :
67- enabled = ( ids [:, - 1 ] == self .full_stop_token ) | (ids [:, - 1 ] == self . new_line_token )
65+ enabled = self ._check_sentence_end (ids )
6866 logits [:, :, self .boost_token ] += enabled * boost_val
6967 else :
7068 logits [:, :, self .boost_token ] += boost_val
0 commit comments