Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ gen_output = model.generate(
```


For the detailed examples in each framework, please have a look at **example_notebook** directory.
For the detailed examples in each framework, please have a look at **lpz_examples** directory.

## Available Logits Processors

Expand Down
11 changes: 5 additions & 6 deletions logits_processor_zoo/transformers/generation_length.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

import torch
from transformers import PreTrainedTokenizer
from logits_processor_zoo.utils import text_to_token
from logits_processor_zoo.utils import text_to_token, SentenceChecker
from logits_processor_zoo.transformers.base import BaseLogitsProcessor


class GenLengthLogitsProcessor(BaseLogitsProcessor):
class GenLengthLogitsProcessor(BaseLogitsProcessor, SentenceChecker):
"""
A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token
based on the length of the generated sequence, encouraging or discouraging shorter answers.
Expand All @@ -39,14 +39,13 @@ class GenLengthLogitsProcessor(BaseLogitsProcessor):
"""
def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float,
p: int = 2, complete_sentences: bool = False, boost_token_str: str = None):
super().__init__()
BaseLogitsProcessor.__init__(self)
SentenceChecker.__init__(self, tokenizer)
self.boost_token = tokenizer.eos_token_id
if boost_token_str is not None:
self.boost_token = text_to_token(tokenizer, boost_token_str, last=False)
self.boost_factor = boost_factor
self.p = p
self.full_stop_token = text_to_token(tokenizer, "It is a sentence.", last=True)
self.new_line_token = text_to_token(tokenizer, "It is a new line\n", last=True)
self.complete_sentences = complete_sentences

def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor:
Expand All @@ -56,7 +55,7 @@ def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to

enabled = (input_ids[:, -token_count:] == self.boost_token).sum(dim=1) == 0
if self.complete_sentences:
enabled = enabled & ((input_ids[:, -1] == self.full_stop_token) | (input_ids[:, -1] == self.new_line_token))
enabled = enabled & self._check_sentence_end(input_ids)

scores[:, self.boost_token] += enabled * boost_val

Expand Down
11 changes: 5 additions & 6 deletions logits_processor_zoo/transformers/max_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
import torch
from transformers import PreTrainedTokenizer
from logits_processor_zoo.transformers.base import BaseLogitsProcessor
from logits_processor_zoo.utils import text_to_token, enforce_tokens
from logits_processor_zoo.utils import text_to_token, enforce_tokens, SentenceChecker


class MaxTimeLogitsProcessor(BaseLogitsProcessor):
class MaxTimeLogitsProcessor(BaseLogitsProcessor, SentenceChecker):
"""
A logits processor that enforces the end-of-sentence (EOS) token after a specified maximum time passes.
Useful for controlling generation time and ensuring responses complete within time constraints.
Expand All @@ -44,13 +44,12 @@ def __init__(
complete_sentences: bool = False,
boost_token_str: str = None,
):
super().__init__()
BaseLogitsProcessor.__init__(self)
SentenceChecker.__init__(self, tokenizer)
self.boost_token = tokenizer.eos_token_id
if boost_token_str is not None:
self.boost_token = text_to_token(tokenizer, boost_token_str, last=False)
self.max_time = max_time
self.full_stop_token = text_to_token(tokenizer, "It is a sentence.", last=True)
self.new_line_token = text_to_token(tokenizer, "It is a new line\n", last=True)
self.complete_sentences = complete_sentences

def _reset(self):
Expand All @@ -62,7 +61,7 @@ def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to

enabled = (input_ids[:, -token_count:] == self.boost_token).sum(dim=1) == 0
if self.complete_sentences:
enabled = enabled & ((input_ids[:, -1] == self.full_stop_token) | (input_ids[:, -1] == self.new_line_token))
enabled = enabled & self._check_sentence_end(input_ids)

if elapsed_time > self.max_time:
for i in range(scores.shape[0]):
Expand Down
14 changes: 6 additions & 8 deletions logits_processor_zoo/trtllm/generation_length.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from transformers import PreTrainedTokenizer
import torch
from tensorrt_llm.sampling_params import LogitsProcessor
from logits_processor_zoo.utils import text_to_token
from logits_processor_zoo.utils import text_to_token, SentenceChecker


class GenLengthLogitsProcessor(LogitsProcessor):
class GenLengthLogitsProcessor(LogitsProcessor, SentenceChecker):
"""
A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token
based on the length of the generated sequence, encouraging or discouraging shorter answers.
Expand All @@ -37,18 +37,16 @@ class GenLengthLogitsProcessor(LogitsProcessor):
or a new line. Default is False.
boost_token_str (str, optional): A string to be tokenized and used instead of EOS. Especially useful for </think>.
"""
def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float,
p: int = 2, complete_sentences: bool = False, boost_token_str: str = None):

def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float, p: int = 2,
complete_sentences: bool = False, boost_token_str: str = None):
SentenceChecker.__init__(self, tokenizer)
self.tokenizer = tokenizer
self.boost_token = self.tokenizer.eos_token_id
self.boost_token_str = boost_token_str
if boost_token_str is not None:
self.boost_token = text_to_token(self.tokenizer, boost_token_str, last=False)
self.boost_factor = boost_factor
self.p = p
self.full_stop_token = text_to_token(self.tokenizer, "It is a sentence.", last=True)
self.new_line_token = text_to_token(self.tokenizer, "It is a new line\n", last=True)
self.complete_sentences = complete_sentences
self.token_count = 0

Expand All @@ -64,7 +62,7 @@ def __call__(self, req_id: int, logits: torch.Tensor,
ids = torch.LongTensor(token_ids).to(logits.device, non_blocking=True)

if self.complete_sentences:
enabled = (ids[:, -1] == self.full_stop_token) | (ids[:, -1] == self.new_line_token)
enabled = self._check_sentence_end(ids)
logits[:, :, self.boost_token] += enabled * boost_val
else:
logits[:, :, self.boost_token] += boost_val
Expand Down
9 changes: 4 additions & 5 deletions logits_processor_zoo/trtllm/max_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from transformers import PreTrainedTokenizer
import torch
from tensorrt_llm.sampling_params import LogitsProcessor
from logits_processor_zoo.utils import text_to_token, enforce_tokens
from logits_processor_zoo.utils import text_to_token, enforce_tokens, SentenceChecker


class MaxTimeLogitsProcessor(LogitsProcessor):
class MaxTimeLogitsProcessor(LogitsProcessor, SentenceChecker):
"""
A logits processor that enforces the end-of-sentence (EOS) token after a specified maximum time passes.
Useful for controlling generation time and ensuring responses complete within time constraints.
Expand All @@ -44,13 +44,12 @@ def __init__(
complete_sentences: bool = False,
boost_token_str: str = None,
):
SentenceChecker.__init__(self, tokenizer)
self.tokenizer = tokenizer
self.boost_token = self.tokenizer.eos_token_id
self.boost_token_str = boost_token_str
if boost_token_str is not None:
self.boost_token = text_to_token(self.tokenizer, boost_token_str, last=False)
self.full_stop_token = text_to_token(self.tokenizer, "It is a sentence.", last=True)
self.new_line_token = text_to_token(self.tokenizer, "It is a new line\n", last=True)
self.complete_sentences = complete_sentences
self.token_count = 0
self.max_time = max_time
Expand All @@ -75,7 +74,7 @@ def __call__(

enabled = True
if self.complete_sentences:
enabled = (ids[:, -1] == self.full_stop_token) | (ids[:, -1] == self.new_line_token)
enabled = self._check_sentence_end(ids)

if time_exceeded and enabled:
# enforce the EOS token
Expand Down
14 changes: 13 additions & 1 deletion logits_processor_zoo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#

from transformers import PreTrainedTokenizer
from typing import List
from typing import List, Union
import torch


Expand Down Expand Up @@ -50,3 +50,15 @@ def enforce_tokens(scores: torch.Tensor, tokens: List[int]):
scores.fill_(scores.min())
scores[tokens] = choice_scores
return scores


class SentenceChecker:
def __init__(self, tokenizer: PreTrainedTokenizer):
self.full_stop_token = text_to_token(tokenizer, "It is a sentence.", last=True)
self.new_line_token = text_to_token(tokenizer, "It is a new line\n", last=True)

def _check_sentence_end(self, input_ids: Union[List[int], torch.Tensor]):
if isinstance(input_ids, list) or isinstance(input_ids, tuple): # vllm input
return (input_ids[-1] == self.full_stop_token) | (input_ids[-1] == self.new_line_token)
else:
return (input_ids[:, -1] == self.full_stop_token) | (input_ids[:, -1] == self.new_line_token)
13 changes: 7 additions & 6 deletions logits_processor_zoo/vllm/generation_length.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
from typing import List, Union
import torch
from transformers import PreTrainedTokenizer, AutoTokenizer
from logits_processor_zoo.utils import text_to_token
from logits_processor_zoo.utils import text_to_token, SentenceChecker


class GenLengthLogitsProcessor:
class GenLengthLogitsProcessor(SentenceChecker):
"""
A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token
based on the length of the generated sequence, encouraging or discouraging shorter answers.
Expand All @@ -38,30 +38,31 @@ class GenLengthLogitsProcessor:
"""
def __init__(self, tokenizer: Union[PreTrainedTokenizer, str], boost_factor: float,
p: int = 2, complete_sentences: bool = False, boost_token_str: str = None):

self.tokenizer = tokenizer
if isinstance(self.tokenizer, str):
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer)
SentenceChecker.__init__(self, self.tokenizer)

self.boost_token = self.tokenizer.eos_token_id
self.boost_token_str = boost_token_str
if boost_token_str is not None:
self.boost_token = text_to_token(self.tokenizer, boost_token_str, last=False)
self.boost_factor = boost_factor
self.p = p
self.full_stop_token = text_to_token(self.tokenizer, "It is a sentence.", last=True)
self.new_line_token = text_to_token(self.tokenizer, "It is a new line\n", last=True)
self.complete_sentences = complete_sentences

def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
if self.boost_token in past_token_ids: # do not boost repeatedly
return scores

gen_length = len(past_token_ids)

boost_val = 0
if not (self.boost_token in past_token_ids):
boost_val = self.boost_factor * (gen_length ** self.p) / (10 ** self.p)

if self.complete_sentences and gen_length > 0:
enabled = (past_token_ids[-1] == self.full_stop_token) | (past_token_ids[-1] == self.new_line_token)
enabled = self._check_sentence_end(past_token_ids)
scores[self.boost_token] += enabled * boost_val
else:
scores[self.boost_token] += boost_val
Expand Down
11 changes: 6 additions & 5 deletions logits_processor_zoo/vllm/max_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from typing import List
import torch
from transformers import PreTrainedTokenizer, AutoTokenizer
from logits_processor_zoo.utils import text_to_token, enforce_tokens
from logits_processor_zoo.utils import text_to_token, enforce_tokens, SentenceChecker


class MaxTimeLogitsProcessor:
class MaxTimeLogitsProcessor(SentenceChecker):
"""
A logits processor that enforces the end-of-sentence (EOS) token after a specified maximum time passes.
Useful for controlling generation time and ensuring responses complete within time constraints.
Expand All @@ -47,13 +47,12 @@ def __init__(
self.tokenizer = tokenizer
if isinstance(self.tokenizer, str):
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer)
SentenceChecker.__init__(self, self.tokenizer)

self.boost_token = self.tokenizer.eos_token_id
self.boost_token_str = boost_token_str
if boost_token_str is not None:
self.boost_token = text_to_token(self.tokenizer, boost_token_str, last=False)
self.full_stop_token = text_to_token(self.tokenizer, "It is a sentence.", last=True)
self.new_line_token = text_to_token(self.tokenizer, "It is a new line\n", last=True)
self.complete_sentences = complete_sentences
self.max_time = max_time
self._reset()
Expand All @@ -77,14 +76,16 @@ def __call__(
past_token_ids: List[int],
scores: torch.Tensor,
) -> torch.Tensor:
if self.boost_token in past_token_ids: # do not force repeatedly
return scores

elapsed_time = time.time() - self.start_time
time_exceeded = elapsed_time > self.max_time
gen_length = len(past_token_ids)

enabled = True
if self.complete_sentences and gen_length > 0:
enabled = (past_token_ids[-1] == self.full_stop_token) | (past_token_ids[-1] == self.new_line_token)
enabled = self._check_sentence_end(past_token_ids)

if time_exceeded and enabled:
scores = enforce_tokens(scores, [self.boost_token])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
}
],
"source": [
"from examples.transformers.utils import LLMRunner\n",
"from lpz_examples.transformers.utils import LLMRunner\n",
"from logits_processor_zoo.transformers import CiteFromPromptLogitsProcessor\n",
"\n",
"\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
}
],
"source": [
"from examples.transformers.utils import LLMRunner\n",
"from lpz_examples.transformers.utils import LLMRunner\n",
"from logits_processor_zoo.transformers import ForceLastPhraseLogitsProcessor\n",
"\n",
"\n",
Expand Down
Loading