Skip to content

Commit 7122dee

Browse files
authored
Reuse sentence check (#28)
* Reuse sentence check functionality

Signed-off-by: aerdem4 <ahmeterd4@gmail.com>

* Change examples directory to lpz_examples in order to prevent import conflicts

Signed-off-by: aerdem4 <ahmeterd4@gmail.com>

* Add a new example and fix bugs

Signed-off-by: aerdem4 <ahmeterd4@gmail.com>

* Update trtllm examples readme

Signed-off-by: aerdem4 <ahmeterd4@gmail.com>

---------

Signed-off-by: aerdem4 <ahmeterd4@gmail.com>
1 parent 59b766e commit 7122dee

35 files changed

+312
-146
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ gen_output = model.generate(
5656
```
5757

5858

59-
For the detailed examples in each framework, please have a look at **example_notebook** directory.
59+
For the detailed examples in each framework, please have a look at **lpz_examples** directory.
6060

6161
## Available Logits Processors
6262

logits_processor_zoo/transformers/generation_length.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717

1818
import torch
1919
from transformers import PreTrainedTokenizer
20-
from logits_processor_zoo.utils import text_to_token
20+
from logits_processor_zoo.utils import text_to_token, SentenceChecker
2121
from logits_processor_zoo.transformers.base import BaseLogitsProcessor
2222

2323

24-
class GenLengthLogitsProcessor(BaseLogitsProcessor):
24+
class GenLengthLogitsProcessor(BaseLogitsProcessor, SentenceChecker):
2525
"""
2626
A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token
2727
based on the length of the generated sequence, encouraging or discouraging shorter answers.
@@ -39,14 +39,13 @@ class GenLengthLogitsProcessor(BaseLogitsProcessor):
3939
"""
4040
def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float,
4141
p: int = 2, complete_sentences: bool = False, boost_token_str: str = None):
42-
super().__init__()
42+
BaseLogitsProcessor.__init__(self)
43+
SentenceChecker.__init__(self, tokenizer)
4344
self.boost_token = tokenizer.eos_token_id
4445
if boost_token_str is not None:
4546
self.boost_token = text_to_token(tokenizer, boost_token_str, last=False)
4647
self.boost_factor = boost_factor
4748
self.p = p
48-
self.full_stop_token = text_to_token(tokenizer, "It is a sentence.", last=True)
49-
self.new_line_token = text_to_token(tokenizer, "It is a new line\n", last=True)
5049
self.complete_sentences = complete_sentences
5150

5251
def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor:
@@ -56,7 +55,7 @@ def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
5655

5756
enabled = (input_ids[:, -token_count:] == self.boost_token).sum(dim=1) == 0
5857
if self.complete_sentences:
59-
enabled = enabled & ((input_ids[:, -1] == self.full_stop_token) | (input_ids[:, -1] == self.new_line_token))
58+
enabled = enabled & self._check_sentence_end(input_ids)
6059

6160
scores[:, self.boost_token] += enabled * boost_val
6261

logits_processor_zoo/transformers/max_time.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
import torch
2020
from transformers import PreTrainedTokenizer
2121
from logits_processor_zoo.transformers.base import BaseLogitsProcessor
22-
from logits_processor_zoo.utils import text_to_token, enforce_tokens
22+
from logits_processor_zoo.utils import text_to_token, enforce_tokens, SentenceChecker
2323

2424

25-
class MaxTimeLogitsProcessor(BaseLogitsProcessor):
25+
class MaxTimeLogitsProcessor(BaseLogitsProcessor, SentenceChecker):
2626
"""
2727
A logits processor that enforces the end-of-sentence (EOS) token after a specified maximum time passes.
2828
Useful for controlling generation time and ensuring responses complete within time constraints.
@@ -44,13 +44,12 @@ def __init__(
4444
complete_sentences: bool = False,
4545
boost_token_str: str = None,
4646
):
47-
super().__init__()
47+
BaseLogitsProcessor.__init__(self)
48+
SentenceChecker.__init__(self, tokenizer)
4849
self.boost_token = tokenizer.eos_token_id
4950
if boost_token_str is not None:
5051
self.boost_token = text_to_token(tokenizer, boost_token_str, last=False)
5152
self.max_time = max_time
52-
self.full_stop_token = text_to_token(tokenizer, "It is a sentence.", last=True)
53-
self.new_line_token = text_to_token(tokenizer, "It is a new line\n", last=True)
5453
self.complete_sentences = complete_sentences
5554

5655
def _reset(self):
@@ -62,7 +61,7 @@ def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
6261

6362
enabled = (input_ids[:, -token_count:] == self.boost_token).sum(dim=1) == 0
6463
if self.complete_sentences:
65-
enabled = enabled & ((input_ids[:, -1] == self.full_stop_token) | (input_ids[:, -1] == self.new_line_token))
64+
enabled = enabled & self._check_sentence_end(input_ids)
6665

6766
if elapsed_time > self.max_time:
6867
for i in range(scores.shape[0]):

logits_processor_zoo/trtllm/generation_length.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
from transformers import PreTrainedTokenizer
2020
import torch
2121
from tensorrt_llm.sampling_params import LogitsProcessor
22-
from logits_processor_zoo.utils import text_to_token
22+
from logits_processor_zoo.utils import text_to_token, SentenceChecker
2323

2424

25-
class GenLengthLogitsProcessor(LogitsProcessor):
25+
class GenLengthLogitsProcessor(LogitsProcessor, SentenceChecker):
2626
"""
2727
A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token
2828
based on the length of the generated sequence, encouraging or discouraging shorter answers.
@@ -37,18 +37,16 @@ class GenLengthLogitsProcessor(LogitsProcessor):
3737
or a new line. Default is False.
3838
boost_token_str (str, optional): A string to be tokenized and used instead of EOS. Especially useful for </think>.
3939
"""
40-
def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float,
41-
p: int = 2, complete_sentences: bool = False, boost_token_str: str = None):
42-
40+
def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float, p: int = 2,
41+
complete_sentences: bool = False, boost_token_str: str = None):
42+
SentenceChecker.__init__(self, tokenizer)
4343
self.tokenizer = tokenizer
4444
self.boost_token = self.tokenizer.eos_token_id
4545
self.boost_token_str = boost_token_str
4646
if boost_token_str is not None:
4747
self.boost_token = text_to_token(self.tokenizer, boost_token_str, last=False)
4848
self.boost_factor = boost_factor
4949
self.p = p
50-
self.full_stop_token = text_to_token(self.tokenizer, "It is a sentence.", last=True)
51-
self.new_line_token = text_to_token(self.tokenizer, "It is a new line\n", last=True)
5250
self.complete_sentences = complete_sentences
5351
self.token_count = 0
5452

@@ -64,7 +62,7 @@ def __call__(self, req_id: int, logits: torch.Tensor,
6462
ids = torch.LongTensor(token_ids).to(logits.device, non_blocking=True)
6563

6664
if self.complete_sentences:
67-
enabled = (ids[:, -1] == self.full_stop_token) | (ids[:, -1] == self.new_line_token)
65+
enabled = self._check_sentence_end(ids)
6866
logits[:, :, self.boost_token] += enabled * boost_val
6967
else:
7068
logits[:, :, self.boost_token] += boost_val

logits_processor_zoo/trtllm/max_time.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
from transformers import PreTrainedTokenizer
2121
import torch
2222
from tensorrt_llm.sampling_params import LogitsProcessor
23-
from logits_processor_zoo.utils import text_to_token, enforce_tokens
23+
from logits_processor_zoo.utils import text_to_token, enforce_tokens, SentenceChecker
2424

2525

26-
class MaxTimeLogitsProcessor(LogitsProcessor):
26+
class MaxTimeLogitsProcessor(LogitsProcessor, SentenceChecker):
2727
"""
2828
A logits processor that enforces the end-of-sentence (EOS) token after a specified maximum time passes.
2929
Useful for controlling generation time and ensuring responses complete within time constraints.
@@ -44,13 +44,12 @@ def __init__(
4444
complete_sentences: bool = False,
4545
boost_token_str: str = None,
4646
):
47+
SentenceChecker.__init__(self, tokenizer)
4748
self.tokenizer = tokenizer
4849
self.boost_token = self.tokenizer.eos_token_id
4950
self.boost_token_str = boost_token_str
5051
if boost_token_str is not None:
5152
self.boost_token = text_to_token(self.tokenizer, boost_token_str, last=False)
52-
self.full_stop_token = text_to_token(self.tokenizer, "It is a sentence.", last=True)
53-
self.new_line_token = text_to_token(self.tokenizer, "It is a new line\n", last=True)
5453
self.complete_sentences = complete_sentences
5554
self.token_count = 0
5655
self.max_time = max_time
@@ -75,7 +74,7 @@ def __call__(
7574

7675
enabled = True
7776
if self.complete_sentences:
78-
enabled = (ids[:, -1] == self.full_stop_token) | (ids[:, -1] == self.new_line_token)
77+
enabled = self._check_sentence_end(ids)
7978

8079
if time_exceeded and enabled:
8180
# enforce the EOS token

logits_processor_zoo/utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#
1717

1818
from transformers import PreTrainedTokenizer
19-
from typing import List
19+
from typing import List, Union
2020
import torch
2121

2222

@@ -50,3 +50,15 @@ def enforce_tokens(scores: torch.Tensor, tokens: List[int]):
5050
scores.fill_(scores.min())
5151
scores[tokens] = choice_scores
5252
return scores
53+
54+
55+
class SentenceChecker:
56+
def __init__(self, tokenizer: PreTrainedTokenizer):
57+
self.full_stop_token = text_to_token(tokenizer, "It is a sentence.", last=True)
58+
self.new_line_token = text_to_token(tokenizer, "It is a new line\n", last=True)
59+
60+
def _check_sentence_end(self, input_ids: Union[List[int], torch.Tensor]):
61+
if isinstance(input_ids, list) or isinstance(input_ids, tuple): # vllm input
62+
return (input_ids[-1] == self.full_stop_token) | (input_ids[-1] == self.new_line_token)
63+
else:
64+
return (input_ids[:, -1] == self.full_stop_token) | (input_ids[:, -1] == self.new_line_token)

logits_processor_zoo/vllm/generation_length.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
from typing import List, Union
1919
import torch
2020
from transformers import PreTrainedTokenizer, AutoTokenizer
21-
from logits_processor_zoo.utils import text_to_token
21+
from logits_processor_zoo.utils import text_to_token, SentenceChecker
2222

2323

24-
class GenLengthLogitsProcessor:
24+
class GenLengthLogitsProcessor(SentenceChecker):
2525
"""
2626
A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token
2727
based on the length of the generated sequence, encouraging or discouraging shorter answers.
@@ -38,30 +38,31 @@ class GenLengthLogitsProcessor:
3838
"""
3939
def __init__(self, tokenizer: Union[PreTrainedTokenizer, str], boost_factor: float,
4040
p: int = 2, complete_sentences: bool = False, boost_token_str: str = None):
41-
4241
self.tokenizer = tokenizer
4342
if isinstance(self.tokenizer, str):
4443
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer)
44+
SentenceChecker.__init__(self, self.tokenizer)
4545

4646
self.boost_token = self.tokenizer.eos_token_id
4747
self.boost_token_str = boost_token_str
4848
if boost_token_str is not None:
4949
self.boost_token = text_to_token(self.tokenizer, boost_token_str, last=False)
5050
self.boost_factor = boost_factor
5151
self.p = p
52-
self.full_stop_token = text_to_token(self.tokenizer, "It is a sentence.", last=True)
53-
self.new_line_token = text_to_token(self.tokenizer, "It is a new line\n", last=True)
5452
self.complete_sentences = complete_sentences
5553

5654
def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
55+
if self.boost_token in past_token_ids: # do not boost repeatedly
56+
return scores
57+
5758
gen_length = len(past_token_ids)
5859

5960
boost_val = 0
6061
if not (self.boost_token in past_token_ids):
6162
boost_val = self.boost_factor * (gen_length ** self.p) / (10 ** self.p)
6263

6364
if self.complete_sentences and gen_length > 0:
64-
enabled = (past_token_ids[-1] == self.full_stop_token) | (past_token_ids[-1] == self.new_line_token)
65+
enabled = self._check_sentence_end(past_token_ids)
6566
scores[self.boost_token] += enabled * boost_val
6667
else:
6768
scores[self.boost_token] += boost_val

logits_processor_zoo/vllm/max_time.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
from typing import List
2020
import torch
2121
from transformers import PreTrainedTokenizer, AutoTokenizer
22-
from logits_processor_zoo.utils import text_to_token, enforce_tokens
22+
from logits_processor_zoo.utils import text_to_token, enforce_tokens, SentenceChecker
2323

2424

25-
class MaxTimeLogitsProcessor:
25+
class MaxTimeLogitsProcessor(SentenceChecker):
2626
"""
2727
A logits processor that enforces the end-of-sentence (EOS) token after a specified maximum time passes.
2828
Useful for controlling generation time and ensuring responses complete within time constraints.
@@ -47,13 +47,12 @@ def __init__(
4747
self.tokenizer = tokenizer
4848
if isinstance(self.tokenizer, str):
4949
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer)
50+
SentenceChecker.__init__(self, self.tokenizer)
5051

5152
self.boost_token = self.tokenizer.eos_token_id
5253
self.boost_token_str = boost_token_str
5354
if boost_token_str is not None:
5455
self.boost_token = text_to_token(self.tokenizer, boost_token_str, last=False)
55-
self.full_stop_token = text_to_token(self.tokenizer, "It is a sentence.", last=True)
56-
self.new_line_token = text_to_token(self.tokenizer, "It is a new line\n", last=True)
5756
self.complete_sentences = complete_sentences
5857
self.max_time = max_time
5958
self._reset()
@@ -77,14 +76,16 @@ def __call__(
7776
past_token_ids: List[int],
7877
scores: torch.Tensor,
7978
) -> torch.Tensor:
79+
if self.boost_token in past_token_ids: # do not force repeatedly
80+
return scores
8081

8182
elapsed_time = time.time() - self.start_time
8283
time_exceeded = elapsed_time > self.max_time
8384
gen_length = len(past_token_ids)
8485

8586
enabled = True
8687
if self.complete_sentences and gen_length > 0:
87-
enabled = (past_token_ids[-1] == self.full_stop_token) | (past_token_ids[-1] == self.new_line_token)
88+
enabled = self._check_sentence_end(past_token_ids)
8889

8990
if time_exceeded and enabled:
9091
scores = enforce_tokens(scores, [self.boost_token])

examples/transformers/cite_prompt_logits_processor.ipynb renamed to lpz_examples/transformers/cite_prompt_logits_processor.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
}
3434
],
3535
"source": [
36-
"from examples.transformers.utils import LLMRunner\n",
36+
"from lpz_examples.transformers.utils import LLMRunner\n",
3737
"from logits_processor_zoo.transformers import CiteFromPromptLogitsProcessor\n",
3838
"\n",
3939
"\n",

examples/transformers/force_last_phrase_logits_processor.ipynb renamed to lpz_examples/transformers/force_last_phrase_logits_processor.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
}
3838
],
3939
"source": [
40-
"from examples.transformers.utils import LLMRunner\n",
40+
"from lpz_examples.transformers.utils import LLMRunner\n",
4141
"from logits_processor_zoo.transformers import ForceLastPhraseLogitsProcessor\n",
4242
"\n",
4343
"\n",

0 commit comments

Comments (0)