Skip to content

Commit 91702e4

Browse files
authored
Fix boost_first_words of MultipleChoice to consider other newline tokens
* Fix boost_first_words of MultipleChoice to consider other newline alternatives * Update version * Cover more newline tokens
1 parent 4b7eb22 commit 91702e4

File tree

8 files changed

+74
-28
lines changed

8 files changed

+74
-28
lines changed

example_notebooks/transformers/multiple_choice_logits_processor.ipynb

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,25 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": null,
5+
"execution_count": 1,
66
"id": "28ed6952",
77
"metadata": {},
8-
"outputs": [],
8+
"outputs": [
9+
{
10+
"name": "stdout",
11+
"output_type": "stream",
12+
"text": [
13+
"/home/aerdem/projects/nvidia/logits-processor-zoo\n"
14+
]
15+
}
16+
],
917
"source": [
1018
"%cd ../.."
1119
]
1220
},
1321
{
1422
"cell_type": "code",
15-
"execution_count": null,
23+
"execution_count": 2,
1624
"id": "a85f8503",
1725
"metadata": {},
1826
"outputs": [],
@@ -103,7 +111,7 @@
103111
},
104112
{
105113
"cell_type": "code",
106-
"execution_count": null,
114+
"execution_count": 4,
107115
"id": "7d74eb26",
108116
"metadata": {},
109117
"outputs": [
@@ -160,7 +168,7 @@
160168
},
161169
{
162170
"cell_type": "code",
163-
"execution_count": null,
171+
"execution_count": 5,
164172
"id": "b2297aab",
165173
"metadata": {},
166174
"outputs": [
@@ -177,7 +185,7 @@
177185
"\n",
178186
"\n",
179187
"LLM response:\n",
180-
"1\n",
188+
"3\n",
181189
"-----END-----\n",
182190
"\n",
183191
"\n",
@@ -190,7 +198,7 @@
190198
"\n",
191199
"\n",
192200
"LLM response:\n",
193-
"b\n",
201+
"a\n",
194202
"-----END-----\n",
195203
"\n",
196204
"\n"
@@ -214,7 +222,7 @@
214222
],
215223
"metadata": {
216224
"kernelspec": {
217-
"display_name": ".venv",
225+
"display_name": "Python 3 (ipykernel)",
218226
"language": "python",
219227
"name": "python3"
220228
},
@@ -228,7 +236,7 @@
228236
"name": "python",
229237
"nbconvert_exporter": "python",
230238
"pygments_lexer": "ipython3",
231-
"version": "3.12.7"
239+
"version": "3.10.13"
232240
}
233241
},
234242
"nbformat": 4,

example_notebooks/vllm/multiple_choice_logits_processor.ipynb

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"name": "stdout",
1111
"output_type": "stream",
1212
"text": [
13-
"/home/aerdem/projects/logits-processor-zoo\n"
13+
"/home/aerdem/projects/nvidia/logits-processor-zoo\n"
1414
]
1515
}
1616
],
@@ -25,22 +25,35 @@
2525
"metadata": {},
2626
"outputs": [
2727
{
28-
"name": "stderr",
28+
"name": "stdout",
2929
"output_type": "stream",
3030
"text": [
31-
"/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
32-
" warnings.warn(\n"
31+
"WARNING 12-19 10:37:26 config.py:1563] Casting torch.bfloat16 to torch.float16.\n",
32+
"INFO 12-19 10:37:26 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='google/gemma-1.1-2b-it', speculative_config=None, tokenizer='google/gemma-1.1-2b-it', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=google/gemma-1.1-2b-it, use_v2_block_manager=False, enable_prefix_caching=False)\n",
33+
"INFO 12-19 10:37:27 model_runner.py:879] Starting to load model google/gemma-1.1-2b-it...\n",
34+
"INFO 12-19 10:37:28 weight_utils.py:236] Using model weights format ['*.safetensors']\n"
3335
]
3436
},
37+
{
38+
"data": {
39+
"application/vnd.jupyter.widget-view+json": {
40+
"model_id": "243efc7aaada47fd82cc1043c275f03d",
41+
"version_major": 2,
42+
"version_minor": 0
43+
},
44+
"text/plain": [
45+
"Loading safetensors checkpoint shards: 0% Completed | 0/2 [00:00<?, ?it/s]\n"
46+
]
47+
},
48+
"metadata": {},
49+
"output_type": "display_data"
50+
},
3551
{
3652
"name": "stdout",
3753
"output_type": "stream",
3854
"text": [
39-
"WARNING 07-23 11:04:22 config.py:1222] Casting torch.bfloat16 to torch.float16.\n",
40-
"INFO 07-23 11:04:22 llm_engine.py:161] Initializing an LLM engine (v0.5.0.post1) with config: model='google/gemma-1.1-2b-it', speculative_config=None, tokenizer='google/gemma-1.1-2b-it', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=google/gemma-1.1-2b-it)\n",
41-
"INFO 07-23 11:04:25 weight_utils.py:218] Using model weights format ['*.safetensors']\n",
42-
"INFO 07-23 11:04:27 model_runner.py:160] Loading model weights took 4.6720 GB\n",
43-
"INFO 07-23 11:04:28 gpu_executor.py:83] # GPU blocks: 52902, # CPU blocks: 14563\n"
55+
"INFO 12-19 10:37:30 model_runner.py:890] Loading model weights took 4.6720 GB\n",
56+
"INFO 12-19 10:37:32 gpu_executor.py:121] # GPU blocks: 49691, # CPU blocks: 14563\n"
4457
]
4558
}
4659
],

logits_processor_zoo/transformers/multiple_choice.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from transformers import PreTrainedTokenizer
1919
from typing import List
2020
import torch
21-
from logits_processor_zoo.utils import text_to_token
21+
from logits_processor_zoo.utils import text_to_token, get_new_line_tokens
2222

2323

2424
class MultipleChoiceLogitsProcessor:
@@ -46,7 +46,7 @@ def __init__(self, tokenizer: PreTrainedTokenizer, choices: List[str] = None,
4646
if choices is None:
4747
choices = ["1", "2", "3", "4"]
4848

49-
self.new_line_token = text_to_token(tokenizer, "\n", last=False)
49+
self.new_line_tokens = get_new_line_tokens(tokenizer)
5050
self.delimiter_token = text_to_token(tokenizer, delimiter, last=False)
5151
self.choice_tokens = [text_to_token(tokenizer, choice, last=False) for choice in choices]
5252
self.boost_first_words = boost_first_words
@@ -61,7 +61,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
6161
for i in range(len(input_ids[row_ind]) - 3):
6262
# A choice is like "\nA) hair dryer", where first token is "hair"
6363
choice_starts = (
64-
(input_ids[row_ind, i] == self.new_line_token) and
64+
(input_ids[row_ind, i].item() in self.new_line_tokens) and
6565
(input_ids[row_ind, i + 1] == self.choice_tokens[choice]) and
6666
(input_ids[row_ind, i + 2] == self.delimiter_token)
6767
)

logits_processor_zoo/trtllm/multiple_choice.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from transformers import PreTrainedTokenizer
1919
from typing import List, Optional
2020
import torch
21-
from logits_processor_zoo.utils import text_to_token
21+
from logits_processor_zoo.utils import text_to_token, get_new_line_tokens
2222

2323

2424
class MultipleChoiceLogitsProcessor:
@@ -46,7 +46,7 @@ def __init__(self, tokenizer: PreTrainedTokenizer, choices: List[str] = None,
4646
if choices is None:
4747
choices = ["1", "2", "3", "4"]
4848

49-
self.new_line_token = text_to_token(tokenizer, "\n", last=False)
49+
self.new_line_tokens = get_new_line_tokens(tokenizer)
5050
self.delimiter_token = text_to_token(tokenizer, delimiter, last=False)
5151
self.choice_tokens = [text_to_token(tokenizer, choice, last=False) for choice in choices]
5252
self.boost_first_words = boost_first_words
@@ -68,7 +68,7 @@ def __call__(self, req_ids_batch: List[int], logits_batch: List[torch.Tensor],
6868
for i in range(len(ids_batch[row_ind]) - 3):
6969
# A choice is like "\nA) hair dryer", where first token is "hair"
7070
choice_starts = (
71-
(ids_batch[row_ind, i] == self.new_line_token) and
71+
(ids_batch[row_ind, i].item() in self.new_line_tokens) and
7272
(ids_batch[row_ind, i + 1] == self.choice_tokens[choice]) and
7373
(ids_batch[row_ind, i + 2] == self.delimiter_token)
7474
)

logits_processor_zoo/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,10 @@ def text_to_token(tokenizer: PreTrainedTokenizer, text: str, last: bool):
2626
raise Exception(f"Can't convert {text} to token. It has {len(tokens)} tokens.")
2727

2828
return tokens[-1]
29+
30+
31+
def get_new_line_tokens(tokenizer):
    """
    Collect the ids of all vocabulary tokens whose decoded text ends with a newline.

    Every value of ``tokenizer.get_vocab()`` is decoded on its own, so tokens that
    merely end with a newline character (not only the bare newline token) are
    included as well.

    Args:
        tokenizer: Object providing ``get_vocab()`` (a mapping whose values are
            token ids) and ``decode(token_id)`` returning the token's text.

    Returns:
        set: Token ids whose decoded text ends with a newline character. Empty
        when no vocabulary entry qualifies.
    """
    # Build the set directly; an intermediate list would only be thrown away.
    return {
        token
        for token in tokenizer.get_vocab().values()
        if tokenizer.decode(token).endswith("\n")
    }

logits_processor_zoo/vllm/multiple_choice.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from transformers import PreTrainedTokenizer
1919
from typing import List
2020
import torch
21-
from logits_processor_zoo.utils import text_to_token
21+
from logits_processor_zoo.utils import text_to_token, get_new_line_tokens
2222

2323

2424
class MultipleChoiceLogitsProcessor:
@@ -46,7 +46,7 @@ def __init__(self, tokenizer: PreTrainedTokenizer, choices: List[str] = None,
4646
if choices is None:
4747
choices = ["1", "2", "3", "4"]
4848

49-
self.new_line_token = text_to_token(tokenizer, "\n", last=False)
49+
self.new_line_token = get_new_line_tokens(tokenizer)
5050
self.delimiter_token = text_to_token(tokenizer, delimiter, last=False)
5151
self.choice_tokens = [text_to_token(tokenizer, choice, last=False) for choice in choices]
5252
self.boost_first_words = boost_first_words
@@ -61,7 +61,7 @@ def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scor
6161
for i in range(len(prompt_tokens_ids) - 3):
6262
# A choice is like "\nA) hair dryer", where first token is "hair"
6363
choice_starts = (
64-
(prompt_tokens_ids[i] == self.new_line_token) and
64+
(prompt_tokens_ids[i] in self.new_line_token) and
6565
(prompt_tokens_ids[i + 1] == self.choice_tokens[choice]) and
6666
(prompt_tokens_ids[i + 2] == self.delimiter_token)
6767
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "logits-processor-zoo"
3-
version = "0.1.1"
3+
version = "0.1.2"
44
description = "A collection of LogitsProcessors to customize and enhance LLM behavior for specific tasks."
55
authors = ["Ahmet Erdem", "Ivan Sorokin", "Maximilian Jeblick", "Darragh Hanley", "David Austin"]
66
readme = "README.md"

tests/test_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from logits_processor_zoo.utils import text_to_token, get_new_line_tokens
2+
3+
4+
def test_text_to_token(llm_runner):
    """Check text_to_token against known token ids of the fixture's tokenizer."""
    tokenizer = llm_runner.tokenizer

    # Expected ids are specific to the tokenizer supplied by the llm_runner fixture.
    assert text_to_token(tokenizer, ",", last=False) == 1919
    assert text_to_token(tokenizer, "apple, orange,", last=True) == 29892
    assert text_to_token(tokenizer, "apple, orange\n", last=True) == 13

    # With last=False this multi-token text is expected to raise; -1 serves as
    # the sentinel that records the failure.
    try:
        converted = text_to_token(tokenizer, "apple, orange,", last=False)
    except Exception:
        converted = -1

    assert converted == -1
15+
16+
17+
def test_get_new_line_tokens(llm_runner):
    """Check that the fixture's tokenizer yields exactly one newline token id (13)."""
    newline_ids = get_new_line_tokens(llm_runner.tokenizer)
    assert newline_ids == {13}

0 commit comments

Comments
 (0)