Skip to content

Commit 5d65e36

Browse files
Add possibility to do batch inference in vllm (#15)
1 parent ab6ad7e commit 5d65e36

File tree

6 files changed

+31
-1
lines changed

6 files changed

+31
-1
lines changed

logits_processor_zoo/vllm/cite_prompt.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,14 @@ class CiteFromPromptLogitsProcessor:
3333
boost_eos (bool, optional): If True, boosts EOS token too.
3434
"""
3535
def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float = 1.0, boost_eos: bool = True):
36+
self.tokenizer = tokenizer
3637
self.boost_factor = boost_factor
3738
self.eos_token_id = tokenizer.eos_token_id
3839
self.boost_eos = boost_eos
3940

41+
def clone(self):
    """Produce an independent copy of this processor.

    vLLM deep-copies a logits processor per sequence during batch
    inference by calling this optional ``clone`` hook.
    """
    return CiteFromPromptLogitsProcessor(
        tokenizer=self.tokenizer,
        boost_factor=self.boost_factor,
        boost_eos=self.boost_eos,
    )
4044
def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
4145
tokens = set(prompt_tokens_ids)
4246
if self.boost_eos:

logits_processor_zoo/vllm/generation_length.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float,
4646
self.full_stop_token = text_to_token(tokenizer, "It is a sentence.", last=True)
4747
self.new_line_token = text_to_token(tokenizer, "It is a new line\n", last=True)
4848
self.complete_sentences = complete_sentences
49+
self.tokenizer = tokenizer
50+
51+
def clone(self):
    """Return a fresh, state-free copy for use with batched vLLM inference."""
    # Re-run __init__ with the stored constructor arguments (positional,
    # in the same order as the original constructor call).
    ctor_args = (
        self.tokenizer,
        self.boost_factor,
        self.p,
        self.complete_sentences,
        self.boost_token_str,
    )
    return GenLengthLogitsProcessor(*ctor_args)
4954

5055
def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
5156
gen_length = len(past_token_ids)

logits_processor_zoo/vllm/last_phrase.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ def __init__(self, phrase: str, tokenizer: PreTrainedTokenizer):
3434
self.eos_token_id = tokenizer.eos_token_id
3535
self.phrase_tokens = tokenizer.encode(phrase, add_special_tokens=False)
3636
self._reset()
37+
self.phrase = phrase
38+
self.tokenizer = tokenizer
39+
40+
# vLLM deep-copies each logits processor per sequence in a batch via the
# optional ``clone`` attribute:
# https://github.com/vllm-project/vllm/blob/19dcc02a72e3ed52e3bf95aae44ea1f40ce42ea0/vllm/sampling_params.py#L537-L550
def clone(self):
    """Build a brand-new processor from the saved phrase and tokenizer."""
    return ForceLastPhraseLogitsProcessor(self.phrase, self.tokenizer)
3744

3845
def _reset(self):
3946
self.index = 0

logits_processor_zoo/vllm/multiple_choice.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ class MultipleChoiceLogitsProcessor:
4343
"""
4444
def __init__(self, tokenizer: PreTrainedTokenizer, choices: List[str] = None,
4545
delimiter: str = ".", boost_first_words: float = 0.0):
46+
self.tokenizer = tokenizer
47+
self.choices = choices
48+
self.delimiter = delimiter
4649
if choices is None:
4750
choices = ["1", "2", "3", "4"]
4851

@@ -52,6 +55,9 @@ def __init__(self, tokenizer: PreTrainedTokenizer, choices: List[str] = None,
5255
self.boost_first_words = boost_first_words
5356
self.very_large_number = 999
5457

58+
def clone(self):
    """Duplicate this processor so vLLM batch inference can deep-copy it."""
    return MultipleChoiceLogitsProcessor(
        self.tokenizer,
        choices=self.choices,
        delimiter=self.delimiter,
        boost_first_words=self.boost_first_words,
    )
60+
5561
def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
5662

5763
if self.boost_first_words:

logits_processor_zoo/vllm/trigger_phrase.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,21 @@ class TriggerPhraseLogitsProcessor:
3535
"""
3636
def __init__(self, phrase: str, trigger_token_phrase: str, tokenizer: PreTrainedTokenizer, trigger_count: int = 1,
3737
trigger_after: bool = False):
38+
self.phrase = phrase
39+
self.trigger_token_phrase = trigger_token_phrase
40+
self.tokenizer = tokenizer
41+
self.trigger_count = trigger_count
3842
self.trigger_token = text_to_token(tokenizer, trigger_token_phrase, last=False)
3943
self.phrase_tokens = tokenizer.encode(phrase, add_special_tokens=False)
4044
self.initial_trigger_count = trigger_count
4145
self.trigger_after = trigger_after
4246
self.very_large_number = 999
4347
self._reset()
4448

49+
def clone(self):
    """Return an unused copy; per-sequence trigger state is rebuilt by __init__.

    Note: the ORIGINAL trigger count is passed (``initial_trigger_count``),
    not the possibly-decremented ``trigger_count``.
    """
    return TriggerPhraseLogitsProcessor(
        self.phrase,
        self.trigger_token_phrase,
        self.tokenizer,
        trigger_count=self.initial_trigger_count,
        trigger_after=self.trigger_after,
    )
52+
4553
def _reset(self):
4654
self.index = -1
4755
self.trigger_count = self.initial_trigger_count

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "logits-processor-zoo"
3-
version = "0.1.4"
3+
version = "0.1.5"
44
description = "A collection of LogitsProcessors to customize and enhance LLM behavior for specific tasks."
55
authors = ["Ahmet Erdem", "Ivan Sorokin", "Maximilian Jeblick", "Darragh Hanley", "David Austin"]
66
readme = "README.md"

0 commit comments

Comments (0)