
Commit ab6ad7e

Reset processors after each batch to be able to re-use (#13)
* Reset processors after each batch to be able to re-use
* Re-use processor objects in example notebooks
* Add TriggerPhraseLogitsProcessor to readme
* Fix grammar mistake
* Comment reset_if_new_batch logic
* Make new generation detection logic more robust

Signed-off-by: aerdem4 <ahmeterd4@gmail.com>
1 parent ec4e07f commit ab6ad7e

16 files changed: +229 -85 lines

README.md

Lines changed: 9 additions & 1 deletion
@@ -78,4 +78,12 @@ I am getting a lot of calls during the day. What is more important for me to con
 2. Operating System
 3. Battery
 ```
-The goal is to make LLM generate "3" as an answer.
+The goal is to make LLM generate "3" as an answer.
+
+### TriggerPhraseLogitsProcessor
+A logits processor which triggers phrases when it encounters a given token.
+One common use case is to force writing python code just after thinking:
+```python
+trigger_python = TriggerPhraseLogitsProcessor(phrase="\n```python", trigger_token_phrase="</think>",
+                                              tokenizer=tokenizer, trigger_count=1, trigger_after=True)
+```
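For orientation, here is a minimal end-to-end sketch of how the README snippet above might be used with Hugging Face transformers. The `TriggerPhraseLogitsProcessor` call is taken verbatim from the README addition; the import path, model choice, prompt, and generation settings are illustrative assumptions, not part of this diff:

```python
# A minimal usage sketch, assuming the logits_processor_zoo transformers API;
# model name, prompt, and generation settings are illustrative, not from this diff.
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
from logits_processor_zoo.transformers import TriggerPhraseLogitsProcessor

model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # any model that emits </think>
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Verbatim from the README addition above: start a python code block
# right after the model closes its thinking section.
trigger_python = TriggerPhraseLogitsProcessor(phrase="\n```python", trigger_token_phrase="</think>",
                                              tokenizer=tokenizer, trigger_count=1, trigger_after=True)

inputs = tokenizer("Write a function that reverses a string.", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=512,
                        logits_processor=LogitsProcessorList([trigger_python]))
print(tokenizer.decode(output[0], skip_special_tokens=True))
```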

example_notebooks/transformers/cite_prompt_logits_processor.ipynb

Lines changed: 21 additions & 5 deletions
@@ -136,7 +136,7 @@
 " \n",
 "\n",
 "LLM response:\n",
-"The user seems to have mixed feelings about the price of the product. They find it expensive, but they also appreciate its softness, colorfulness, and style, which suggests that the product is well-made and worth the cost.\n",
+"The user seems to have mixed feelings about the price of the product. They find it expensive, but they also appreciate its softness, colorfulness, and style.\n",
 "-----END-----\n",
 "\n",
 "Prompt: \n",
@@ -158,7 +158,7 @@
 "source": [
 "runner.generate_response(\n",
 "    example_prompts,\n",
-"    [CiteFromPromptLogitsProcessor(runner.tokenizer, example_prompts, boost_factor=2.0)]\n",
+"    [CiteFromPromptLogitsProcessor(runner.tokenizer, boost_factor=2.0, boost_eos=False)]\n",
 ")"
 ]
 },
@@ -187,9 +187,17 @@
 " \n",
 "\n",
 "LLM response:\n",
-"The reviewer seems to have mixed feelings towards the pricing of the product. They describe it as \"expensive\" and \"deserves its price,\" which suggests that they find it worth paying for its quality or unique features. The use of words like \"stylish\" further emphasizes their positive impression.\n",
+"The reviewer seems to have mixed feelings towards the pricing of the product:\n",
 "\n",
-"So in summary, while they appreciate the design and style of the product, they also acknowledge that it might be quite pricey. Therefore, their overall sentiment can be described as **mixed**, with appreciation for both aspects (quality and style).\n",
+"- They describe it as \"very soft\" and \"colorful\", suggesting that they appreciate these qualities.\n",
+"\n",
+"- They also mention that it is \"expensive,\" which might be seen as negative if you're looking for an affordable option or if this was their first time buying something like this.\n",
+"\n",
+"- However, they state that it \"deserves its price,\" indicating that they believe the high cost reflects on quality or value.\n",
+"\n",
+"Overall, while they seem satisfied with the overall experience and don't mind paying more for what they perceive as good-quality materials and design, they may feel that the price point could be higher than expected for everyday use or budget-conscious shoppers.\n",
+"\n",
+"So in summary, they find the item to be well-made and aesthetically pleasing despite feeling that it might not be suitable for everyone due to being too pricey for some people's budgets. The reviewer seems generally positive toward the purchase decision itself rather than just the specific item.\n",
 "-----END-----\n",
 "\n",
 "Prompt: \n",
@@ -215,9 +223,17 @@
 "source": [
 "runner.generate_response(\n",
 "    example_prompts,\n",
-"    [CiteFromPromptLogitsProcessor(runner.tokenizer, example_prompts, boost_factor=-2.0)]\n",
+"    [CiteFromPromptLogitsProcessor(runner.tokenizer, boost_factor=-2.0, boost_eos=False)]\n",
 ")"
 ]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "c29fedb3",
+"metadata": {},
+"outputs": [],
+"source": []
 }
 ],
 "metadata": {

example_notebooks/vllm/force_last_phrase_logits_processor.ipynb

Lines changed: 54 additions & 15 deletions
@@ -28,17 +28,17 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"WARNING 02-12 13:42:36 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead. See https://pypi.org/project/pynvml for more information.\n",
-"WARNING 02-12 13:42:39 config.py:1563] Casting torch.bfloat16 to torch.float16.\n",
-"INFO 02-12 13:42:39 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='google/gemma-1.1-2b-it', speculative_config=None, tokenizer='google/gemma-1.1-2b-it', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=google/gemma-1.1-2b-it, use_v2_block_manager=False, enable_prefix_caching=False)\n",
-"INFO 02-12 13:42:40 model_runner.py:879] Starting to load model google/gemma-1.1-2b-it...\n",
-"INFO 02-12 13:42:40 weight_utils.py:236] Using model weights format ['*.safetensors']\n"
+"WARNING 03-18 13:40:54 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead. See https://pypi.org/project/pynvml for more information.\n",
+"WARNING 03-18 13:40:58 config.py:1563] Casting torch.bfloat16 to torch.float16.\n",
+"INFO 03-18 13:40:58 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='google/gemma-1.1-2b-it', speculative_config=None, tokenizer='google/gemma-1.1-2b-it', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=google/gemma-1.1-2b-it, use_v2_block_manager=False, enable_prefix_caching=False)\n",
+"INFO 03-18 13:40:59 model_runner.py:879] Starting to load model google/gemma-1.1-2b-it...\n",
+"INFO 03-18 13:41:00 weight_utils.py:236] Using model weights format ['*.safetensors']\n"
 ]
 },
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "",
+"model_id": "ebad9294acfd4e15aa9272a1aac448df",
 "version_major": 2,
 "version_minor": 0
 },
@@ -53,8 +53,8 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"INFO 02-12 13:42:42 model_runner.py:890] Loading model weights took 4.6720 GB\n",
-"INFO 02-12 13:42:44 gpu_executor.py:121] # GPU blocks: 49686, # CPU blocks: 14563\n"
+"INFO 03-18 13:41:02 model_runner.py:890] Loading model weights took 4.6720 GB\n",
+"INFO 03-18 13:41:05 gpu_executor.py:121] # GPU blocks: 49742, # CPU blocks: 14563\n"
 ]
 }
 ],
@@ -156,10 +156,10 @@
 }
 ],
 "source": [
-"phrase = \"\\n\\nReferences:\"\n",
+"reference = ForceLastPhraseLogitsProcessor(\"\\n\\nReferences:\", runner.tokenizer)\n",
 "\n",
 "runner.generate_response(example_prompts,\n",
-"                         [ForceLastPhraseLogitsProcessor(phrase, runner.tokenizer)])"
+"                         [reference])"
 ]
 },
 {
@@ -199,19 +199,58 @@
 }
 ],
 "source": [
-"phrase = \"\\n\\nThanks for trying our RAG application! If you have more questions about\"\n",
+"thank = ForceLastPhraseLogitsProcessor(\"\\n\\nThanks for trying our RAG application! If you have more questions about\",\n",
+"                                       runner.tokenizer)\n",
 "\n",
 "runner.generate_response(example_prompts,\n",
-"                         [ForceLastPhraseLogitsProcessor(phrase, runner.tokenizer)])"
+"                         [thank])"
+]
+},
+{
+"cell_type": "markdown",
+"id": "34735f41",
+"metadata": {},
+"source": [
+"## Both"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 6,
 "id": "88032bf3",
 "metadata": {},
-"outputs": [],
-"source": []
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Prompt: \n",
+"    Retrieved information from: https://en.wikipedia.org/wiki/Bulbasaur\n",
+"    Bulbasaur is a fictional Pokémon species in Nintendo and Game Freak's Pokémon franchise. \n",
+"    Designed by Atsuko Nishida, Bulbasaur is a Grass and Poison-type, first appearing in Pocket Monsters: Red and Green (Pokémon Red and Blue outside Japan) as a starter Pokémon. \n",
+"    Since then, it has reappeared in sequels, spin-off games, related merchandise, and animated and printed adaptations of the franchise. \n",
+"    It is a central character in the Pokémon anime, being one of Ash Ketchum's main Pokémon for the first season, with a different one later being obtained by supporting character May. \n",
+"    It is featured in various manga and is owned by protagonist Red in Pokémon Adventures.\n",
+"    \n",
+"    What is Bulbasaur?\n",
+"    \n",
+"Bulbasaur is a fictional Pokémon species that is a Grass and Poison-type.\n",
+"\n",
+"Thanks for trying our RAG application! If you have more questions about Bulbasaur or anything else, please let me know.\n",
+"\n",
+"References:\n",
+"\n",
+"* Wikipedia: Bulbasaur\n",
+"* Pokémon franchise website\n",
+"-----END-----\n",
+"\n"
+]
+}
+],
+"source": [
+"runner.generate_response(example_prompts,\n",
+"                         [thank, reference])"
+]
 }
 ],
 "metadata": {

example_notebooks/vllm/trigger_phrase_logits_processor.ipynb

Lines changed: 21 additions & 17 deletions
@@ -28,20 +28,20 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"WARNING 02-13 10:32:45 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead. See https://pypi.org/project/pynvml for more information.\n",
-"WARNING 02-13 10:32:49 config.py:1563] Casting torch.bfloat16 to torch.float16.\n",
-"WARNING 02-13 10:32:49 arg_utils.py:839] Chunked prefill is enabled by default for models with max_model_len > 32K. Currently, chunked prefill might not work with some features or models. If you encounter any issues, please disable chunked prefill by setting --enable-chunked-prefill=False.\n",
-"INFO 02-13 10:32:49 config.py:911] Chunked prefill is enabled with max_num_batched_tokens=512.\n",
-"INFO 02-13 10:32:49 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', speculative_config=None, tokenizer='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, use_v2_block_manager=False, enable_prefix_caching=False)\n",
-"INFO 02-13 10:32:50 model_runner.py:879] Starting to load model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B...\n",
-"INFO 02-13 10:32:51 weight_utils.py:236] Using model weights format ['*.safetensors']\n",
-"INFO 02-13 10:32:52 weight_utils.py:280] No model.safetensors.index.json found in remote.\n"
+"WARNING 03-18 13:37:20 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead. See https://pypi.org/project/pynvml for more information.\n",
+"WARNING 03-18 13:37:24 config.py:1563] Casting torch.bfloat16 to torch.float16.\n",
+"WARNING 03-18 13:37:24 arg_utils.py:839] Chunked prefill is enabled by default for models with max_model_len > 32K. Currently, chunked prefill might not work with some features or models. If you encounter any issues, please disable chunked prefill by setting --enable-chunked-prefill=False.\n",
+"INFO 03-18 13:37:24 config.py:911] Chunked prefill is enabled with max_num_batched_tokens=512.\n",
+"INFO 03-18 13:37:24 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', speculative_config=None, tokenizer='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, use_v2_block_manager=False, enable_prefix_caching=False)\n",
+"INFO 03-18 13:37:25 model_runner.py:879] Starting to load model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B...\n",
+"INFO 03-18 13:37:26 weight_utils.py:236] Using model weights format ['*.safetensors']\n",
+"INFO 03-18 13:37:27 weight_utils.py:280] No model.safetensors.index.json found in remote.\n"
 ]
 },
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "35561044f6c848eb9c56be591d6c50c8",
+"model_id": "e3ff26a7bf23415c9e29952e2a5f2d1a",
 "version_major": 2,
 "version_minor": 0
 },
@@ -56,8 +56,8 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"INFO 02-13 10:32:53 model_runner.py:890] Loading model weights took 3.3460 GB\n",
-"INFO 02-13 10:32:53 gpu_executor.py:121] # GPU blocks: 37897, # CPU blocks: 9362\n"
+"INFO 03-18 13:37:29 model_runner.py:890] Loading model weights took 3.3460 GB\n",
+"INFO 03-18 13:37:29 gpu_executor.py:121] # GPU blocks: 37898, # CPU blocks: 9362\n"
 ]
 }
 ],
@@ -338,9 +338,11 @@
 }
 ],
 "source": [
+"trigger_python = TriggerPhraseLogitsProcessor(\"\\n```python\", \"</think>\", runner.tokenizer, \n",
+"                                              trigger_count=1, trigger_after=True)\n",
+"\n",
 "runner.generate_response(example_prompts,\n",
-"                         [TriggerPhraseLogitsProcessor(\"\\n```python\", \"</think>\", runner.tokenizer, \n",
-"                                                       trigger_count=1, trigger_after=True)],\n",
+"                         [trigger_python],\n",
 "                         max_tokens=4096)"
 ]
 },
@@ -387,12 +389,14 @@
 }
 ],
 "source": [
+"keep_thinking_short = GenLengthLogitsProcessor(runner.tokenizer, boost_factor=0.1, complete_sentences=True,\n",
+"                                               boost_token_str=\"</think>\")\n",
+"\n",
 "runner.generate_response(example_prompts,\n",
 "                         [\n",
-"                             GenLengthLogitsProcessor(runner.tokenizer, boost_factor=0.1, complete_sentences=True,\n",
-"                                                      boost_token_str=\"</think>\"),\n",
-"                             TriggerPhraseLogitsProcessor(\"\\n```python\", \"</think>\", runner.tokenizer, \n",
-"                                                          trigger_count=1, trigger_after=True)],\n",
+"                             keep_thinking_short,\n",
+"                             trigger_python\n",
+"                         ],\n",
 "                         max_tokens=4096)"
 ]
 },
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch


class BaseLogitsProcessor:
    def __init__(self):
        self.prompt_token_ids = None
        self.prev_token_ids = None

    def _reset(self):
        # Subclasses override this to clear any per-generation state.
        pass

    def _check_new_generation(self, input_ids: torch.LongTensor):
        first_time = self.prompt_token_ids is None
        if first_time:
            self._reset()
            self.prompt_token_ids = input_ids
        else:
            # We are still in the same generation only if the current ids
            # extend the ids seen on the previous step by exactly one token.
            same_gen = False
            if input_ids.shape[1] > 1:
                same_gen = torch.equal(input_ids[:, :-1], self.prev_token_ids)

            if not same_gen:
                # A new batch has started: reset state so the same processor
                # object can be re-used across generate calls.
                self._reset()
                self.prompt_token_ids = input_ids

        self.prev_token_ids = input_ids

    def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor:
        # Subclasses override this with the actual logits transformation.
        return scores

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor:
        self._check_new_generation(input_ids)
        scores = self._process(input_ids, scores)
        return scores
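To make the base-class contract concrete, here is a hypothetical subclass (not part of this commit; class name and behavior are illustrative only) showing how `_reset` and `_process` are meant to be used, and why the new-generation detection lets the same object be passed to repeated generate calls:

```python
# Hypothetical subclass illustrating the contract above; not part of this commit.
import torch


class MaxNewTokensLogitsProcessor(BaseLogitsProcessor):
    """Forces EOS once `max_new_tokens` tokens have been generated."""

    def __init__(self, eos_token_id: int, max_new_tokens: int):
        super().__init__()
        self.eos_token_id = eos_token_id
        self.max_new_tokens = max_new_tokens
        self.generated = 0

    def _reset(self):
        # Called by _check_new_generation whenever a new batch begins,
        # so the counter starts fresh on every re-use of this object.
        self.generated = 0

    def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor:
        self.generated += 1
        if self.generated >= self.max_new_tokens:
            # Leave only the EOS logit finite, forcing generation to stop.
            mask = torch.full_like(scores, float("-inf"))
            mask[:, self.eos_token_id] = 0
            scores = scores + mask
        return scores
```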
