diff --git a/README.md b/README.md index ff86ffb..13a5d4e 100644 --- a/README.md +++ b/README.md @@ -81,11 +81,14 @@ I am getting a lot of calls during the day. What is more important for me to con The goal is to make LLM generate "3" as an answer. ### TriggerPhraseLogitsProcessor -A logits processor which triggers phrases when it encounters a given token. +A logits processor which triggers phrases when it encounters a given token or after a specified time. One common use case is to force writing python code just after thinking: ```python trigger_python = TriggerPhraseLogitsProcessor(phrase="\n```python", trigger_token_phrase="", tokenizer=tokenizer, trigger_count=1, trigger_after=True) ``` ### PreventHallucinationLogitsProcessor -A logits processor that mitigates hallucinated model outputs by enforcing a predefined fallback phrase when token confidence falls below a specified threshold. \ No newline at end of file +A logits processor that mitigates hallucinated model outputs by enforcing a predefined fallback phrase when token confidence falls below a specified threshold. + +### MaxTimeLogitsProcessor +A logits processor that enforces the end-of-sentence (EOS) token after a specified maximum time passes, optionally waiting for a new line or a full stop. Useful for controlling generation time and ensuring responses complete within time constraints. \ No newline at end of file diff --git a/examples/transformers/max_time_logits_processor.ipynb b/examples/transformers/max_time_logits_processor.ipynb new file mode 100644 index 0000000..346f46c --- /dev/null +++ b/examples/transformers/max_time_logits_processor.ipynb @@ -0,0 +1,447 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "28ed6952", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/data/projects/logproc_ws/logits-processor-zoo\n" + ] + } + ], + "source": [ + "%cd ../.." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0ea01217", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/envs/logproc/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from examples.transformers.utils import LLMRunner\n", + "\n", + "runner = LLMRunner(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\")\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "859aef8d", + "metadata": {}, + "source": [ + "## Default Response" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "56d5e65f", + "metadata": {}, + "outputs": [], + "source": [ + "example_prompts = [\n", + "\"\"\"\n", + "A farmer has a rectangular field. The length of the field is 20 meters longer than its width. \n", + "If the perimeter of the field is 200 meters, find the dimensions of the field.\n", + "\"\"\"\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "cbf4c2d5", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. 
Please pass your input's `attention_mask` to obtain reliable results.\n", + "Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.\n", + "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt: \n", + "A farmer has a rectangular field. The length of the field is 20 meters longer than its width. \n", + "If the perimeter of the field is 200 meters, find the dimensions of the field.\n", + "\n", + "\n", + "LLM response:\n", + "First, I'll define the width of the field as \\( w \\) meters. Since the length is 20 meters longer than the width, the length will be \\( w + 20 \\) meters.\n", + "\n", + "Next, I'll use the formula for the perimeter of a rectangle, which is \\( 2 \\times (\\text{length} + \\text{width}) \\). Plugging in the expressions for length and width, the equation becomes:\n", + "\\[\n", + "2(w + (w + 20)) = 200\n", + "\\]\n", + "\n", + "Simplifying the equation:\n", + "\\[\n", + "2(2w + 20) = 200 \\\\\n", + "4w + 40 = 200 \\\\\n", + "4w = 160 \\\\\n", + "w = 40\n", + "\\]\n", + "\n", + "Finally, the width is 40 meters, and the length is \\( 40 + 20 = 60 \\) meters.\n", + "\n", + "\n", + "Let's solve the problem step by step.\n", + "\n", + "**Given:**\n", + "- The field is rectangular.\n", + "- The length is \\( 20 \\) meters longer than the width.\n", + "- The perimeter of the field is \\( 200 \\) meters.\n", + "\n", + "**Let:**\n", + "- \\( w \\) = width of the field (in meters)\n", + "- \\( l \\) = length of the field (in meters)\n", + "\n", + "**Step 1: Express the Length in Terms of the Width**\n", + "\n", + "Since the length is \\( 20 \\) meters longer than the width:\n", + "\\[\n", + "l = w + 20\n", + "\\]\n", + "\n", + "**Step 2: Use the Perimeter Formula**\n", + "\n", + "The perimeter \\( P \\) of a rectangle is given by:\n", + "\\[\n", + "P = 2l + 2w\n", + "\\]\n", + "Given that the perimeter is \\( 200 \\) meters:\n", + "\\[\n", + "2l + 2w = 200\n", + "\\]\n", + "\n", + "**Step 3: Substitute the Expression for \\( l \\) into the Perimeter Equation**\n", + "\n", + "\\[\n", + "2(w + 20) + 2w = 200\n", + "\\]\n", + "\n", + "**Step 4: Simplify and Solve for \\( w \\)**\n", + "\n", + "\\[\n", + "2w + 40 + 2w = 200 \\\\\n", + "4w + 40 = 200 \\\\\n", + "4w = 200 - 40 \\\\\n", + "4w = 160 \\\\\n", + "w = \\frac{160}{4} \\\\\n", + "w = 40 \\text{ meters}\n", + "\\]\n", + "\n", + "**Step 5: Find the Length \\( l \\)**\n", + "\n", + "\\[\n", + "l = w + 20 = 40 + 20 = 60 \\text{ meters}\n", + "\\]\n", + "\n", + "**Final Answer:**\n", + "\\[\n", + "\\boxed{\\text{Width: } 40 \\text{ meters}, \\text{ Length: } 60 \\text{ meters}}\n", + "\\]\n", + "-----END-----\n", + "\n" + ] + } + ], + "source": [ + "runner.generate_response(example_prompts, max_tokens=4096) " + ] + }, + { + "cell_type": "markdown", + "id": "88bc2f8a", + "metadata": {}, + "source": [ + "## Interrupt after N seconds" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7d74eb26", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. 
Please pass your input's `attention_mask` to obtain reliable results.\n", + "Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt: \n", + "A farmer has a rectangular field. The length of the field is 20 meters longer than its width. \n", + "If the perimeter of the field is 200 meters, find the dimensions of the field.\n", + "\n", + "\n", + "LLM response:\n", + "First, I'll define the width of the field as \\( w \\) meters. Since the length is 20 meters longer than the width, the length will be \\( w + 20 \\) meters.\n", + "\n", + "Next, I'll use the formula for the perimeter of a rectangle, which is \\( 2 \\times (\\text{length} + \\text{width}) \\). Plugging in the expressions for length and width, the equation becomes:\n", + "\\[\n", + "2(w + (w + 20)) = 200\n", + "-----END-----\n", + "\n" + ] + } + ], + "source": [ + "from logits_processor_zoo.transformers.max_time import MaxTimeLogitsProcessor\n", + "\n", + "runner.generate_response(example_prompts,\n", + " [MaxTimeLogitsProcessor(runner.tokenizer, \n", + " complete_sentences=True, \n", + " max_time=1)],\n", + " max_tokens=4096)" + ] + }, + { + "cell_type": "markdown", + "id": "f3da686c", + "metadata": {}, + "source": [ + "### Combine with TriggerPhrase: give the final answer" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f1e982c3", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", + "Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt: \n", + "A farmer has a rectangular field. The length of the field is 20 meters longer than its width. \n", + "If the perimeter of the field is 200 meters, find the dimensions of the field.\n", + "\n", + "\n", + "LLM response:\n", + "First, I'll define the width of the field as \\( w \\) meters. Since the length is 20 meters longer than the width, the length will be \\( w + 20 \\) meters.\n", + "\n", + "Next, I'll use the formula for the perimeter of a rectangle, which is \\( 2 \\times (\\text{length} + \\text{width}) \\). Plugging in the expressions for length and width, the equation becomes:\n", + "\\[\n", + "2(w + (w + 20)) = 200\n", + "\\]\n", + "\n", + "Simplifying the equation:\n", + "\\[\n", + "2(2w + 20) = 200 \\\\\n", + "4w + 40 = The time is over. The final answer is: The width is 45 meters and the length is 65 meters.\n", + "\\]\n", + "\n", + "\n", + "Let's solve the problem step by step.\n", + "\n", + "**Given:**\n", + "- The field is rectangular.\n", + "- The length is 20 meters longer than the width.\n", + "- The perimeter of the field is 200 meters.\n", + "\n", + "**Let:**\n", + "- \\( w \\) = width of the field (in meters)\n", + "- \\( l \\) = length of the field (in meters)\n", + "\n", + "**Step 1: Express the Length in Terms of the Width**\n", + "\n", + "According to the problem:\n", + "\\[\n", + "l = w + 20\n", + "-----END-----\n", + "\n" + ] + } + ], + "source": [ + "from logits_processor_zoo.transformers import TriggerPhraseLogitsProcessor\n", + "\n", + "hurry_up = TriggerPhraseLogitsProcessor(runner.tokenizer,\n", + " batch_size=1, \n", + " phrase=\" The time is over. 
The final answer is:\", \n", + " trigger_time=2, \n", + " trigger_count=1, \n", + " trigger_after=True)\n", + "max_time = MaxTimeLogitsProcessor(runner.tokenizer,\n", + " complete_sentences=True,\n", + " max_time=2.5)\n", + "\n", + "runner.generate_response(example_prompts,\n", + " [hurry_up, max_time],\n", + " max_tokens=4096)" + ] + }, + { + "cell_type": "markdown", + "id": "8cc8cbe4", + "metadata": {}, + "source": [ + "### Combine with TriggerPhrase: Hurry up" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "779cca39", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", + "Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt: Count up to 100.\n", + "\n", + "LLM response:\n", + "I need to count up from 1 to 100. I'll start at 1 and continue adding one each time until I reach 100. I'll make sure to include every number in between without skipping any. This way, I can ensure that I count up correctly and comprehensively.\n", + "\n", + "\n", + "**Counting Up to 100**\n", + "\n", + "To count up from 1 to 100, follow these steps:\n", + "\n", + "1. **Start at 1**: Begin your count with the number 1.\n", + "2. **Add 1 Each Time**: Continue adding 1 to the previous number to get the next number in the sequence.\n", + "3. **Continue Until 100**: Keep doing this until you reach the number 100.\n", + "\n", + "Here is the sequence of numbers from 1 to 100:\n", + "\n", + "\\[\n", + "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, \\ldots, 98, 99, 100\n", + "\\]\n", + "\n", + "By following this method, you will have successfully counted up to 100.\n", + "-----END-----\n", + "\n" + ] + } + ], + "source": [ + "\n", + "example_prompts = [\n", + "\"\"\"Count up to 100.\"\"\"\n", + "]\n", + "\n", + "runner.generate_response(example_prompts,\n", + " max_tokens=4096)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9a26d5ee", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", + "Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt: Count up to 100.\n", + "\n", + "LLM response:\n", + "I need to count up from 1 to 100. I'll start at 1 and continue adding one each time until I reach 100. I'll make sure to include every number in between without skipping any. This way, I can ensure that I count up correctly and comprehensively.\n", + "\n", + "\n", + "**Counting Up to Hurry up! The time is almost over. Stop thinking and go straight to the answer. **\n", + "\n", + "**Answer:**\n", + "\n", + "\\[\n", + "\\boxed{1,\\ 2,\\ 3,\\ \\ldots,\\ 100}\n", + "\\]\n", + "-----END-----\n", + "\n" + ] + } + ], + "source": [ + "hurry_up = TriggerPhraseLogitsProcessor(runner.tokenizer,\n", + " batch_size=1, \n", + " phrase=\" Hurry up! The time is almost over. Stop thinking and go straight to the answer. 
\", \n", + " trigger_time=1, \n", + " trigger_count=1, \n", + " trigger_after=True)\n", + "max_time = MaxTimeLogitsProcessor(runner.tokenizer,\n", + " complete_sentences=True,\n", + " max_time=3)\n", + "\n", + "runner.generate_response(example_prompts,\n", + " [hurry_up, max_time],\n", + " max_tokens=4096)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "logproc", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/transformers/trigger_phrase_logits_processor.ipynb b/examples/transformers/trigger_phrase_logits_processor.ipynb index 72daa80..1490697 100644 --- a/examples/transformers/trigger_phrase_logits_processor.ipynb +++ b/examples/transformers/trigger_phrase_logits_processor.ipynb @@ -10,7 +10,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "/home/aerdem/projects/nvidia/logits-processor-zoo\n" + "/data/projects/logproc_ws/logits-processor-zoo\n" ] } ], @@ -28,7 +28,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.\n" + "/data/envs/logproc/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], @@ -291,8 +292,10 @@ } ], "source": [ + "\n", "runner.generate_response(example_prompts,\n", - " [TriggerPhraseLogitsProcessor(\"Wait,\", \"\", runner.tokenizer, batch_size=1, \n", + " [TriggerPhraseLogitsProcessor(runner.tokenizer, batch_size=1, phrase=\"Wait,\", \n", + " trigger_token_phrase=\"\", \n", " trigger_count=2, trigger_after=False)],\n", " max_tokens=4096)" ] @@ -377,8 +380,11 @@ } ], "source": [ + "\n", + "\n", "runner.generate_response(example_prompts,\n", - " [TriggerPhraseLogitsProcessor(\"\\n```python\", \"\", runner.tokenizer, batch_size=1,\n", + " [TriggerPhraseLogitsProcessor(runner.tokenizer, batch_size=1, phrase=\"\\n```python\", \n", + " trigger_token_phrase=\"\", \n", " trigger_count=1, trigger_after=True)],\n", " max_tokens=4096)" ] @@ -440,23 +446,16 @@ " [\n", " GenLengthLogitsProcessor(runner.tokenizer, boost_factor=0.1, complete_sentences=True,\n", " boost_token_str=\"\"),\n", - " TriggerPhraseLogitsProcessor(\"\\n```python\", \"\", runner.tokenizer, batch_size=1,\n", + " TriggerPhraseLogitsProcessor(runner.tokenizer, batch_size=1, phrase=\"\\n```python\", \n", + " trigger_token_phrase=\"\", \n", " trigger_count=1, trigger_after=True)],\n", " max_tokens=4096)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bf7a7a91", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "logproc", "language": "python", "name": "python3" }, @@ -470,7 +469,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.17" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/examples/trtllm/max_time_logits_processor.py b/examples/trtllm/max_time_logits_processor.py new file mode 100644 index 0000000..006e6ac --- /dev/null +++ b/examples/trtllm/max_time_logits_processor.py @@ -0,0 
+1,16 @@ +from transformers import AutoTokenizer +from logits_processor_zoo.trtllm import MaxTimeLogitsProcessor +from utils import TRTLLMTester, get_parser + + +if __name__ == "__main__": + args = get_parser() + + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + llm_tester = TRTLLMTester(args.model_name) + + lp = MaxTimeLogitsProcessor(tokenizer, max_time=100, complete_sentences=True) + llm_tester.run([args.prompt], logits_processor=lp) + + lp = MaxTimeLogitsProcessor(tokenizer, max_time=1.0, complete_sentences=True) + llm_tester.run([args.prompt], logits_processor=lp) diff --git a/examples/trtllm/trigger_phrase_logits_processor.py b/examples/trtllm/trigger_phrase_logits_processor.py index 2056bd4..9cf50be 100644 --- a/examples/trtllm/trigger_phrase_logits_processor.py +++ b/examples/trtllm/trigger_phrase_logits_processor.py @@ -9,9 +9,15 @@ tokenizer = AutoTokenizer.from_pretrained(args.model_name) llm_tester = TRTLLMTester(args.model_name) - lp = TriggerPhraseLogitsProcessor("...Wait, let me think more.", " function", tokenizer, - trigger_count=2, trigger_after=False) + lp = TriggerPhraseLogitsProcessor( + tokenizer, "...Wait, let me think more.", " function", trigger_count=2, trigger_after=False + ) llm_tester.run([args.prompt], logits_processor=lp) - lp = TriggerPhraseLogitsProcessor("\n```python", " function", tokenizer, trigger_count=1, trigger_after=True) + lp = TriggerPhraseLogitsProcessor(tokenizer, "\n```python", " function", trigger_count=1, trigger_after=True) + llm_tester.run([args.prompt], logits_processor=lp) + + lp = TriggerPhraseLogitsProcessor( + tokenizer, " only a few seconds left...", trigger_time=2, trigger_count=1, trigger_after=True + ) llm_tester.run([args.prompt], logits_processor=lp) diff --git a/examples/vllm/max_time_logits_processor.ipynb b/examples/vllm/max_time_logits_processor.ipynb new file mode 100644 index 0000000..568e7f7 --- /dev/null +++ b/examples/vllm/max_time_logits_processor.ipynb @@ -0,0 +1,482 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "28ed6952", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/data/projects/logproc_ws/logits-processor-zoo\n" + ] + } + ], + "source": [ + "%cd ../.." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5627e226", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/envs/logproc/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO 07-05 10:13:49 [__init__.py:244] Automatically detected platform cuda.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-07-05 10:13:51,942\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO 07-05 10:14:01 [config.py:823] This model supports multiple tasks: {'generate', 'score', 'classify', 'embed', 'reward'}. Defaulting to 'generate'.\n", + "WARNING 07-05 10:14:01 [config.py:3271] Casting torch.bfloat16 to torch.float16.\n", + "WARNING 07-05 10:14:02 [cuda.py:91] To see benefits of async output processing, enable CUDA graph. 
Since, enforce-eager is enabled, async output processor cannot be used\n", + "INFO 07-05 10:14:02 [llm_engine.py:230] Initializing a V0 LLM engine (v0.9.1) with config: model='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', speculative_config=None, tokenizer='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=16384, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=None, served_model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=None, chunked_prefill_enabled=False, use_async_output_proc=False, pooler_config=None, compilation_config={\"level\":0,\"debug_dump_path\":\"\",\"cache_dir\":\"\",\"backend\":\"\",\"custom_ops\":[],\"splitting_ops\":[],\"use_inductor\":true,\"compile_sizes\":[],\"inductor_compile_config\":{\"enable_auto_functionalized_v2\":false},\"inductor_passes\":{},\"use_cudagraph\":true,\"cudagraph_num_of_warmups\":0,\"cudagraph_capture_sizes\":[],\"cudagraph_copy_inputs\":false,\"full_cuda_graph\":false,\"max_capture_size\":0,\"local_cache_dir\":null}, use_cached_outputs=False, \n", + "INFO 07-05 10:14:03 [cuda.py:327] Using Flash Attention backend.\n", + "INFO 07-05 10:14:04 [parallel_state.py:1065] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0\n", + "INFO 07-05 10:14:04 [model_runner.py:1171] Starting to load model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B...\n", + "INFO 07-05 10:14:04 [weight_utils.py:292] Using model weights format ['*.safetensors']\n", + "INFO 07-05 10:14:04 [weight_utils.py:345] No model.safetensors.index.json found in remote.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00\n", + "\n", + "Let's solve the problem step by step.\n", + "\n", + "**Given:**\n", + "- The field is rectangular.\n", + "- The length is \\( 20 \\) meters longer than the width.\n", + "- The perimeter of the field is \\( 200 \\) meters.\n", + "\n", + "**Let:**\n", + "- \\( w \\) = width of the field (in meters)\n", + "- \\( l \\) = length of the field (in meters)\n", + "\n", + "**Step 1: Express the Length in Terms of the Width**\n", + "\n", + "Since the length is \\( 20 \\) meters longer than the width:\n", + "\\[\n", + "l = w + 20\n", + "\\]\n", + "\n", + "**Step 2: Use the Perimeter Formula**\n", + "\n", + "The perimeter \\( P \\) of a rectangle is given by:\n", + "\\[\n", + "P = 2l + 2w\n", + "\\]\n", + "Given that the perimeter is \\( 200 \\) meters:\n", + "\\[\n", + "2l + 2w = 200\n", + "\\]\n", + "\n", + "**Step 3: Substitute the Expression for \\( l \\) into the Perimeter Equation**\n", + "\n", + "\\[\n", + "2(w + 20) + 2w = 200\n", + "\\]\n", + "\n", + "**Step 4: Simplify and Solve for \\( w \\)**\n", + "\n", + "\\[\n", + "2w + 40 + 2w = 200 \\\\\n", + "4w + 40 = 200 \\\\\n", + "4w = 200 - 40 \\\\\n", + "4w = 160 \\\\\n", + "w = \\frac{160}{4} \\\\\n", + "w 
= 40 \\text{ meters}\n", + "\\]\n", + "\n", + "**Step 5: Find the Length \\( l \\)**\n", + "\n", + "\\[\n", + "l = w + 20 = 40 + 20 = 60 \\text{ meters}\n", + "\\]\n", + "\n", + "**Final Answer:**\n", + "\\[\n", + "\\boxed{\\text{Width: } 40 \\text{ meters}, \\text{ Length: } 60 \\text{ meters}}\n", + "\\]\n", + "-----END-----\n", + "\n", + "Prompt: Count up to 100.\n", + "I need to count up from 1 to 100. I'll start at 1 and continue adding one each time until I reach 100. I'll make sure to include every number in between without skipping any. This way, I can ensure that I count up correctly and comprehensively.\n", + "\n", + "\n", + "**Counting Up to 100**\n", + "\n", + "To count up from 1 to 100, follow these steps:\n", + "\n", + "1. **Start at 1**: Begin your count with the number 1.\n", + "2. **Add 1 Each Time**: Continue adding 1 to the previous number to get the next number in the sequence.\n", + "3. **Continue Until 100**: Keep doing this until you reach the number 100.\n", + "\n", + "Here is the sequence of numbers from 1 to 100:\n", + "\n", + "\\[\n", + "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, \\ldots, 98, 99, 100\n", + "\\]\n", + "\n", + "By following this method, you will have successfully counted up to 100.\n", + "-----END-----\n", + "\n" + ] + } + ], + "source": [ + "runner.generate_response(example_prompts, max_tokens=4096)" + ] + }, + { + "cell_type": "markdown", + "id": "88bc2f8a", + "metadata": {}, + "source": [ + "## Interrupt after N seconds" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7d74eb26", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt: \n", + "A farmer has a rectangular field. The length of the field is 20 meters longer than its width. \n", + "If the perimeter of the field is 200 meters, find the dimensions of the field.\n", + "\n", + "First, I'll define the width of the field as \\( w \\) meters. Since the length is 20 meters longer than the width, the length will be \\( w + 20 \\) meters.\n", + "\n", + "Next, I'll use the formula for the perimeter of a rectangle, which is \\( 2 \\times (\\text{length} + \\text{width}) \\). Plugging in the expressions for length and width, the equation becomes:\n", + "\\[\n", + "2(w + (w + 20)) =\n", + "-----END-----\n", + "\n", + "Prompt: Count up to 100.\n", + "I need to count up from 1 to 100. I'll start at 1 and continue adding one each time until I reach 100. I'll make sure to include every number in between without skipping any. This way, I can ensure that I count up correctly and comprehensively.\n", + "\n", + "\n", + "**Counting Up to 100**\n", + "\n", + "To count up from 1 to 100, follow these steps:\n", + "\n", + "1. **Start at 1**: Begin your count with the number\n", + "-----END-----\n", + "\n" + ] + } + ], + "source": [ + "runner.generate_response(example_prompts,\n", + " [MaxTimeLogitsProcessor(runner.tokenizer, \n", + " complete_sentences=False, \n", + " max_time=1,\n", + " )],\n", + " max_tokens=2000)" + ] + }, + { + "cell_type": "markdown", + "id": "549f741e", + "metadata": {}, + "source": [ + "## Combine with TriggerPhrase " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "40679857", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt: \n", + "A farmer has a rectangular field. The length of the field is 20 meters longer than its width. 
\n", + "If the perimeter of the field is 200 meters, find the dimensions of the field.\n", + "\n", + "First, I'll define the width of the field as \\( w \\) meters. Since the length is 20 meters longer than the width, the length will be \\( w + 20 \\) meters.\n", + "\n", + "Next, I'll use the formula for the perimeter of a rectangle, which is \\( 2 \\times (\\text{length} + \\text{width}) \\). Plugging in the expressions for length and width, the equation becomes:\n", + "\\[\n", + "2(w + (w + 20)) = 200\n", + "\\]\n", + "\n", + "Simplifying the equation:\n", + "\\[\n", + "2(2w + 20) = 200 \\\\\n", + "4w + 40 = 200 \\\\\n", + "4w = 160 \\\\\n", + "w = 40\n", + "\\]\n", + "\n", + "Finally, the width is 40 meters, and the length is \\( 40 + 20 = 60 \\) meters.\n", + "\n", + "\n", + "Let's solve the problem step by step.\n", + "\n", + "**Given:**\n", + "- The field is rectangular.\n", + "- The length is \\( 2 The time is over. The final answer is: **Width = 40 meters, Length = 60 meters.\n", + "-----END-----\n", + "\n" + ] + } + ], + "source": [ + "from logits_processor_zoo.vllm import TriggerPhraseLogitsProcessor\n", + "\n", + "hurry_up = TriggerPhraseLogitsProcessor(runner.tokenizer,\n", + " phrase=\" The time is over. The final answer is:\", \n", + " trigger_time=2, \n", + " trigger_count=1, \n", + " trigger_after=True)\n", + "max_time = MaxTimeLogitsProcessor(runner.tokenizer,\n", + " complete_sentences=True,\n", + " max_time=2.2)\n", + "\n", + "runner.generate_response(example_prompts,\n", + " [hurry_up, max_time],\n", + " max_tokens=4096)" + ] + }, + { + "cell_type": "markdown", + "id": "1fdad362", + "metadata": {}, + "source": [ + "### Combine with TriggerPhrase: Hurry up!" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "72b70940", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt: Count up to 100.\n", + "I need to count up from 1 to 100. I'll start at 1 and continue adding one each time until I reach 100. I'll make sure to include every number in between without skipping any. This way, I can ensure that I count up correctly and comprehensively.\n", + "\n", + "\n", + "**Counting Up to 100**\n", + "\n", + "To count up from 1 to 100, follow these steps:\n", + "\n", + "1. **Start at 1**: Begin your count with the number 1.\n", + "2. **Add 1 Each Time**: Continue adding 1 to the previous number to get the next number in the sequence.\n", + "3. **Continue Until 100**: Keep doing this until you reach the number 100.\n", + "\n", + "Here is the sequence of numbers from 1 to 100:\n", + "\n", + "\\[\n", + "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, \\ldots, 98, 99, 100\n", + "\\]\n", + "\n", + "By following this method, you will have successfully counted up to 100.\n", + "-----END-----\n", + "\n" + ] + } + ], + "source": [ + "example_prompts = [\n", + " \"\"\"Count up to 100.\"\"\"\n", + "]\n", + "\n", + "runner.generate_response(example_prompts,\n", + " max_tokens=4096) " + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "7dfff28b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt: Count up to 100.\n", + "I need to count up from 1 to 100. I'll start at 1 and continue adding one each time until I reach 100. I'll make sure to include every number in between without skipping any. This way, I can ensure that I Hurry up! The time is almost over. Stop thinking and go straight to a short answer. The count from 1 to 100 is straightforward. 
I'll list each number sequentially, making sure to include every integer from 1 up to 100. This will provide a clear and complete sequence of numbers in the specified range. \n", + "\n", + "\n", + "**Counting from 1 to 100**\n", + "\n", + "Here is the sequence of numbers from 1 to 100:\n", + "\n", + "\\[\n", + "1, 2, 3, \\ldots, 98, 99, 100\n", + "\\]\n", + "\n", + "This is a straightforward count where each number is incremented by 1 from the previous number, starting at 1 and ending at 100.\n", + "-----END-----\n", + "\n" + ] + } + ], + "source": [ + "\n", + "\n", + "example_prompts = [\n", + " \"\"\"Count up to 100.\"\"\"\n", + "]\n", + "\n", + "hurry_up = TriggerPhraseLogitsProcessor(runner.tokenizer,\n", + " phrase=\" Hurry up! The time is almost over. Stop thinking and go straight to a short answer. \", \n", + " trigger_time=0.5, \n", + " trigger_count=1, \n", + " trigger_after=True)\n", + "max_time = MaxTimeLogitsProcessor(runner.tokenizer,\n", + " complete_sentences=True,\n", + " max_time=3)\n", + "\n", + "runner.generate_response(example_prompts,\n", + " [hurry_up, max_time],\n", + " max_tokens=4096)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "logproc", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/vllm/trigger_phrase_logits_processor.ipynb b/examples/vllm/trigger_phrase_logits_processor.ipynb index 2f796d4..e438883 100644 --- a/examples/vllm/trigger_phrase_logits_processor.ipynb +++ b/examples/vllm/trigger_phrase_logits_processor.ipynb @@ -10,7 +10,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "/home/aerdem/projects/nvidia/logits-processor-zoo\n" + "/data/projects/logproc_ws/logits-processor-zoo\n" ] } ], @@ -24,67 +24,77 @@ "id": "b89279fe", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/envs/logproc/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "WARNING 04-29 15:31:18 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead. See https://pypi.org/project/pynvml for more information.\n" + "INFO 07-08 11:01:41 [__init__.py:244] Automatically detected platform cuda.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", - " warnings.warn(\n" + "2025-07-08 11:01:49,020\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "WARNING 04-29 15:31:21 config.py:1563] Casting torch.bfloat16 to torch.float16.\n", - "WARNING 04-29 15:31:21 arg_utils.py:839] Chunked prefill is enabled by default for models with max_model_len > 32K. 
Currently, chunked prefill might not work with some features or models. If you encounter any issues, please disable chunked prefill by setting --enable-chunked-prefill=False.\n", - "INFO 04-29 15:31:21 config.py:911] Chunked prefill is enabled with max_num_batched_tokens=512.\n", - "INFO 04-29 15:31:21 llm_engine.py:184] Initializing an LLM engine (v0.5.5) with config: model='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', speculative_config=None, tokenizer='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, use_v2_block_manager=False, enable_prefix_caching=False)\n" + "INFO 07-08 11:02:04 [config.py:823] This model supports multiple tasks: {'generate', 'score', 'classify', 'embed', 'reward'}. Defaulting to 'generate'.\n", + "WARNING 07-08 11:02:04 [config.py:3271] Casting torch.bfloat16 to torch.float16.\n", + "WARNING 07-08 11:02:04 [cuda.py:91] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used\n", + "INFO 07-08 11:02:04 [llm_engine.py:230] Initializing a V0 LLM engine (v0.9.1) with config: model='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', speculative_config=None, tokenizer='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=16384, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=None, served_model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=None, chunked_prefill_enabled=False, use_async_output_proc=False, pooler_config=None, compilation_config={\"level\":0,\"debug_dump_path\":\"\",\"cache_dir\":\"\",\"backend\":\"\",\"custom_ops\":[],\"splitting_ops\":[],\"use_inductor\":true,\"compile_sizes\":[],\"inductor_compile_config\":{\"enable_auto_functionalized_v2\":false},\"inductor_passes\":{},\"use_cudagraph\":true,\"cudagraph_num_of_warmups\":0,\"cudagraph_capture_sizes\":[],\"cudagraph_copy_inputs\":false,\"full_cuda_graph\":false,\"max_capture_size\":0,\"local_cache_dir\":null}, use_cached_outputs=False, \n", + "INFO 07-08 11:02:06 [cuda.py:327] Using Flash Attention backend.\n", + "INFO 07-08 11:02:07 [parallel_state.py:1065] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0\n", + "INFO 07-08 
11:02:07 [model_runner.py:1171] Starting to load model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B...\n", + "INFO 07-08 11:02:07 [weight_utils.py:292] Using model weights format ['*.safetensors']\n", + "INFO 07-08 11:02:07 [weight_utils.py:345] No model.safetensors.index.json found in remote.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + "Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00\n", @@ -241,21 +251,23 @@ "\n", "Let me test this function with some examples. For n=0, it returns 0. For n=1, returns 1. For n=2, it's F(1)+F(0) = 1+0=1. For n=3, F(2)+F(1)=1+1=2. That looks correct.\n", "\n", - "Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which matches the standard definition. So, the function should work regardless of the starting point as long as the base cases are correct.\n", + "Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which is correct.\n", "\n", - "Another thing to consider is the base cases. If the function is called with n=0, it returns 0, which is correct. For n=1, returns 1. For n=2, returns 1, which is correct. So, the function should handle all non-negative integers correctly.\n", + "Another test case: n=5. Let's compute it step by step. F(0)=0, F(1)=1, F(2)=1, F(3)=2, F(4)=3, F(5)=5. So the function should return 5 for n=5.\n", "\n", "I think this should work. So, the function is straightforward. It's a simple recursive implementation, but it's not the most efficient for large n. However, for the purpose of this problem, it's acceptable.\n", "\n", - "Wait, but in the problem statement, it says to make it recursive. So, the function should call itself, which it does. So, the function is recursive as required.\n", + "Wait, but in the problem statement, it says to make it recursive. So, the function as written is recursive, but perhaps it's better to memoize it for better performance. But that's beyond the scope here. So, I'll proceed with the basic recursive approach.\n", + "\n", + "Wait, but in the function I wrote, for n=0, it returns 0, which is correct. For n=1, returns 1. For n=2, returns 1, which is correct. So, the function seems to handle all cases correctly.\n", "\n", - "Wait, but in the function I wrote, for n=0, it returns 0, which is correct. For n=1, returns 1. For n=2, returns 1, which is correct. For n=3, returns 2, etc. So, the function is correct.\n", + "I think that's a solid approach. So, the function is as written above.\n", "\n", "\n", "To solve this problem, we need to generate the nth Fibonacci number using a recursive approach. The Fibonacci sequence is a series of numbers where each number is the sum of the two preceding ones, starting from 0 and 1. \n", "\n", "### Approach\n", - "The approach to solve this problem involves using recursion, which is a method where a function calls itself to solve smaller instances of the same problem. Here's a step-by-step breakdown of the approach:\n", + "The approach to solve this problem involves using recursion, which is a method where the function calls itself with a smaller input until it reaches a base case. Here's a step-by-step breakdown of the approach:\n", "\n", "1. 
**Base Cases**: \n", " - If `n` is 0, return 0.\n", @@ -275,14 +287,14 @@ " elif n == 1:\n", " return 1\n", " else:\n", - " return fibonacci(n-1) + fibonacci(n-2)\n", + " return fibonacci(n - 1) + fibonacci(n - 2)\n", "```\n", "\n", "### Explanation\n", "- **Base Cases**: The function first checks if `n` is 0 or 1. If `n` is 0, it returns 0. If `n` is 1, it returns 1. These are the simplest cases of the Fibonacci sequence.\n", "- **Recursive Case**: For any `n` greater than 1, the function calls itself with `n-1` and `n-2` and returns the sum of these two recursive calls. This builds up the solution by solving smaller subproblems and combining their results.\n", "\n", - "This approach is straightforward and leverages the divide-and-conquer strategy inherent in recursion, making it easy to understand and implement. However, it's important to note that this approach has a time complexity of O(2^n) due to the exponential number of function calls, which is not efficient for large values of `n`. For larger values, an iterative approach or memoization would be more efficient.\n", + "This approach is straightforward and easy to understand, but it's important to note that for large values of `n`, this method can be inefficient due to repeated calculations. However, for the purpose of this problem, the recursive approach is sufficient.\n", "-----END-----\n", "\n" ] @@ -290,7 +302,7 @@ ], "source": [ "runner.generate_response(example_prompts,\n", - " [TriggerPhraseLogitsProcessor(\"\\nWait,\", \"\", runner.tokenizer, \n", + " [TriggerPhraseLogitsProcessor(runner.tokenizer, \"\\nWait,\", \"\", \n", " trigger_count=2, trigger_after=False)],\n", " max_tokens=4096)" ] @@ -342,9 +354,9 @@ "\n", "Let me test this function with some examples. For n=0, it returns 0. For n=1, returns 1. For n=2, it's F(1)+F(0) = 1+0=1. For n=3, F(2)+F(1)=1+1=2. That looks correct.\n", "\n", - "Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which matches the standard definition. So, the function should work regardless of the starting point as long as the base cases are correct.\n", + "Wait, but sometimes people define the Fibonacci sequence starting with F(1)=1, F(2)=1, F(3)=2, etc. So, if the function is called with n=5, it should return 5. Let me see: F(5) is 5, which is correct.\n", "\n", - "Another thing to consider is the base cases. If the function is called with n=0, it returns 0, which is correct. For n=1, returns 1. For n=2, returns 1, which is correct. So, the function should handle all non-negative integers correctly.\n", + "Another test case: n=5. Let's compute it step by step. F(0)=0, F(1)=1, F(2)=1, F(3)=2, F(4)=3, F(5)=5. So the function should return 5 for n=5.\n", "\n", "I think this should work. So, the function is straightforward. It's a simple recursive implementation, but it's not the most efficient for large n. 
However, for the purpose of this problem, it's acceptable.\n", "\n", @@ -365,7 +377,7 @@ } ], "source": [ - "trigger_python = TriggerPhraseLogitsProcessor(\"\\n```python\", \"\", runner.tokenizer, \n", + "trigger_python = TriggerPhraseLogitsProcessor(runner.tokenizer, \"\\n```python\", \"\", \n", " trigger_count=1, trigger_after=True)\n", "\n", "runner.generate_response(example_prompts,\n", @@ -426,19 +438,11 @@ " ],\n", " max_tokens=4096)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bf7a7a91", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "logproc", "language": "python", "name": "python3" }, @@ -452,7 +456,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.17" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/logits_processor_zoo/transformers/__init__.py b/logits_processor_zoo/transformers/__init__.py index 8e412fe..b6f51a3 100644 --- a/logits_processor_zoo/transformers/__init__.py +++ b/logits_processor_zoo/transformers/__init__.py @@ -21,6 +21,8 @@ from .multiple_choice import MultipleChoiceLogitsProcessor from .trigger_phrase import TriggerPhraseLogitsProcessor from .prevent_hallucination import PreventHallucinationLogitsProcessor +from .max_time import MaxTimeLogitsProcessor __all__ = ['GenLengthLogitsProcessor', 'CiteFromPromptLogitsProcessor', 'ForceLastPhraseLogitsProcessor', - 'MultipleChoiceLogitsProcessor', 'TriggerPhraseLogitsProcessor', 'PreventHallucinationLogitsProcessor'] + 'MultipleChoiceLogitsProcessor', 'TriggerPhraseLogitsProcessor', 'PreventHallucinationLogitsProcessor', + 'MaxTimeLogitsProcessor'] diff --git a/logits_processor_zoo/transformers/max_time.py b/logits_processor_zoo/transformers/max_time.py new file mode 100644 index 0000000..8f7d6a1 --- /dev/null +++ b/logits_processor_zoo/transformers/max_time.py @@ -0,0 +1,73 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import time +import torch +from transformers import PreTrainedTokenizer +from logits_processor_zoo.transformers.base import BaseLogitsProcessor +from logits_processor_zoo.utils import text_to_token, enforce_tokens + + +class MaxTimeLogitsProcessor(BaseLogitsProcessor): + """ + A logits processor that enforces the end-of-sentence (EOS) token after a specified maximum time passes. + Useful for controlling generation time and ensuring responses complete within time constraints. + + Parameters + ---------- + tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM. + max_time (float): Maximum time (wall-clock time) in seconds after which the EOS token must be enforced. + complete_sentences (bool, optional): If True, enforces EOS token only when the last token is a full stop + or a new line. Default is False. 
+ boost_token_str (str, optional): A string to be tokenized and used instead of EOS. + + """ + + def __init__( + self, + tokenizer: PreTrainedTokenizer, + max_time: float, + complete_sentences: bool = False, + boost_token_str: str = None, + ): + super().__init__() + self.boost_token = tokenizer.eos_token_id + if boost_token_str is not None: + self.boost_token = text_to_token(tokenizer, boost_token_str, last=False) + self.max_time = max_time + self.full_stop_token = text_to_token(tokenizer, "It is a sentence.", last=True) + self.new_line_token = text_to_token(tokenizer, "It is a new line\n", last=True) + self.complete_sentences = complete_sentences + + def _reset(self): + self.start_time = time.time() + + def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor: + elapsed_time = time.time() - self.start_time + token_count = input_ids.shape[1] - self.prompt_token_ids.shape[1] + + enabled = (input_ids[:, -token_count:] == self.boost_token).sum(dim=1) == 0 + if self.complete_sentences: + enabled = enabled & ((input_ids[:, -1] == self.full_stop_token) | (input_ids[:, -1] == self.new_line_token)) + + if elapsed_time > self.max_time: + for i in range(scores.shape[0]): + if enabled[i]: + scores[i] = enforce_tokens(scores[i], [self.boost_token]) + return scores + + return scores diff --git a/logits_processor_zoo/transformers/trigger_phrase.py b/logits_processor_zoo/transformers/trigger_phrase.py index 949de51..8baf965 100644 --- a/logits_processor_zoo/transformers/trigger_phrase.py +++ b/logits_processor_zoo/transformers/trigger_phrase.py @@ -15,6 +15,8 @@ # limitations under the License. # +import time +from typing import Optional from transformers import PreTrainedTokenizer import torch from logits_processor_zoo.utils import text_to_token, enforce_tokens @@ -27,24 +29,39 @@ class TriggerPhraseLogitsProcessor(BaseLogitsProcessor): Parameters ---------- - phrase (str): The phrase to be generated by LLM when it encounters the trigger token. - trigger_token_phrase (str): One token phrase in string to trigger phrases. tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM. + batch_size (int): The batch size. + phrase (str): The phrase to be generated by LLM when it encounters the trigger token. + trigger_token_phrase (str): (Optional) One token phrase in string to trigger phrases. + trigger_time (float): (Optional) Time (wall-clock time) in seconds after which the phrase will be triggered. trigger_count (int): How many times the phrase will be triggered. trigger_after (bool): Whether the phrase is written after the trigger token or instead of the trigger token. 
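    Example (illustrative sketch only; assumes `tokenizer` is an already-loaded
    Hugging Face tokenizer, mirroring the notebook usage above):

        hurry_up = TriggerPhraseLogitsProcessor(tokenizer, batch_size=1,
                                                phrase=" Hurry up!",
                                                trigger_time=1.0,
                                                trigger_count=1,
                                                trigger_after=True)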
""" - def __init__(self, phrase: str, trigger_token_phrase: str, tokenizer: PreTrainedTokenizer, batch_size: int, + + def __init__(self, tokenizer: PreTrainedTokenizer, batch_size: int, phrase: str, + trigger_token_phrase: Optional[str] = None, trigger_time: Optional[float] = None, trigger_count: int = 1, trigger_after: bool = False): + + assert ( + trigger_token_phrase is not None or trigger_time is not None + ), "Either trigger_token_phrase or trigger_time must be provided" + super().__init__() - self.trigger_token = text_to_token(tokenizer, trigger_token_phrase, last=False) + + self.trigger_token = None + if trigger_token_phrase is not None: + self.trigger_token = text_to_token(tokenizer, trigger_token_phrase, last=False) + self.phrase_tokens = tokenizer.encode(phrase, add_special_tokens=False) self.trigger_after = trigger_after self.batch_size = batch_size self.initial_trigger_count = trigger_count + self.trigger_time = trigger_time or float("inf") def _reset(self): self.iterators = -torch.ones(self.batch_size, dtype=torch.int32) - self.trigger_count = self.initial_trigger_count*torch.ones(self.batch_size, dtype=torch.int32) + self.trigger_count = self.initial_trigger_count * torch.ones(self.batch_size, dtype=torch.int32) + self.start_time = time.time() def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor: for i in range(scores.shape[0]): @@ -52,7 +69,9 @@ def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to continue it = self.iterators[i].item() - if scores[i, :].argmax() == self.trigger_token and it == -1: + + time_over = time.time() - self.start_time > self.trigger_time + if (scores[i, :].argmax() == self.trigger_token or time_over) and it == -1: self.iterators[i] = 0 if not self.trigger_after: scores[i] = enforce_tokens(scores[i], [self.phrase_tokens[0]]) @@ -64,5 +83,6 @@ def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to if len(self.phrase_tokens) == self.iterators[i].item(): # phrase completed, reset for next trigger self.iterators[i] = -1 self.trigger_count[i] -= 1 + self.start_time = time.time() return scores diff --git a/logits_processor_zoo/trtllm/__init__.py b/logits_processor_zoo/trtllm/__init__.py index 97e5cf2..7ac298d 100644 --- a/logits_processor_zoo/trtllm/__init__.py +++ b/logits_processor_zoo/trtllm/__init__.py @@ -21,6 +21,8 @@ from .multiple_choice import MultipleChoiceLogitsProcessor from .prevent_hallucination import PreventHallucinationLogitsProcessor from .trigger_phrase import TriggerPhraseLogitsProcessor +from .max_time import MaxTimeLogitsProcessor __all__ = ['GenLengthLogitsProcessor', 'ForceLastPhraseLogitsProcessor', 'CiteFromPromptLogitsProcessor', - 'MultipleChoiceLogitsProcessor', 'PreventHallucinationLogitsProcessor', 'TriggerPhraseLogitsProcessor'] + 'MultipleChoiceLogitsProcessor', 'PreventHallucinationLogitsProcessor', 'TriggerPhraseLogitsProcessor', + 'MaxTimeLogitsProcessor'] diff --git a/logits_processor_zoo/trtllm/max_time.py b/logits_processor_zoo/trtllm/max_time.py new file mode 100644 index 0000000..01f2f3e --- /dev/null +++ b/logits_processor_zoo/trtllm/max_time.py @@ -0,0 +1,85 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import List, Optional +import time +from transformers import PreTrainedTokenizer +import torch +from tensorrt_llm.sampling_params import LogitsProcessor +from logits_processor_zoo.utils import text_to_token, enforce_tokens + + +class MaxTimeLogitsProcessor(LogitsProcessor): + """ + A logits processor that enforces the end-of-sentence (EOS) token after a specified maximum time passes. + Useful for controlling generation time and ensuring responses complete within time constraints. + + Parameters + ---------- + tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM. + max_time (float): Maximum time (wall-clock time) in seconds after which the EOS token must be enforced. + complete_sentences (bool, optional): If True, enforces EOS token only when the last token is a full stop + or a new line. Default is False. + boost_token_str (str, optional): A string to be tokenized and used instead of EOS. + """ + + def __init__( + self, + tokenizer: PreTrainedTokenizer, + max_time: float, + complete_sentences: bool = False, + boost_token_str: str = None, + ): + self.tokenizer = tokenizer + self.boost_token = self.tokenizer.eos_token_id + self.boost_token_str = boost_token_str + if boost_token_str is not None: + self.boost_token = text_to_token(self.tokenizer, boost_token_str, last=False) + self.full_stop_token = text_to_token(self.tokenizer, "It is a sentence.", last=True) + self.new_line_token = text_to_token(self.tokenizer, "It is a new line\n", last=True) + self.complete_sentences = complete_sentences + self.token_count = 0 + self.max_time = max_time + self.start_time = time.time() + + def __call__( + self, + req_id: int, + logits: torch.Tensor, + token_ids: List[List[int]], + stream_ptr: Optional[int], + client_id: Optional[int], + ) -> None: + + elapsed_time = time.time() - self.start_time + time_exceeded = elapsed_time > self.max_time + + stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr) + + with torch.cuda.stream(stream): + ids = torch.LongTensor(token_ids).to(logits.device, non_blocking=True) + + enabled = True + if self.complete_sentences: + enabled = (ids[:, -1] == self.full_stop_token) | (ids[:, -1] == self.new_line_token) + + if time_exceeded and enabled: + # enforce the EOS token + for i in range(logits.shape[1]): + enforce_tokens(logits[0, i], [self.boost_token]) + + self.token_count += 1 diff --git a/logits_processor_zoo/trtllm/trigger_phrase.py b/logits_processor_zoo/trtllm/trigger_phrase.py index 8afa7b8..0482930 100644 --- a/logits_processor_zoo/trtllm/trigger_phrase.py +++ b/logits_processor_zoo/trtllm/trigger_phrase.py @@ -16,6 +16,7 @@ # from typing import List, Optional +import time from transformers import PreTrainedTokenizer import torch from logits_processor_zoo.utils import enforce_tokens, text_to_token @@ -28,25 +29,35 @@ class TriggerPhraseLogitsProcessor(LogitsProcessor): Parameters ---------- - phrase (str): The phrase to be generated by LLM when it encounters the trigger token. - trigger_token_phrase (str): One token phrase in string to trigger phrases. 
diff --git a/logits_processor_zoo/trtllm/trigger_phrase.py b/logits_processor_zoo/trtllm/trigger_phrase.py
index 8afa7b8..0482930 100644
--- a/logits_processor_zoo/trtllm/trigger_phrase.py
+++ b/logits_processor_zoo/trtllm/trigger_phrase.py
@@ -16,6 +16,7 @@
 #

 from typing import List, Optional
+import time
 from transformers import PreTrainedTokenizer
 import torch
 from logits_processor_zoo.utils import enforce_tokens, text_to_token
@@ -28,25 +29,35 @@ class TriggerPhraseLogitsProcessor(LogitsProcessor):

     Parameters
     ----------
-    phrase (str): The phrase to be generated by LLM when it encounters the trigger token.
-    trigger_token_phrase (str): One token phrase in string to trigger phrases.
     tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM.
+    phrase (str): The phrase to be generated by the LLM when it encounters the trigger token.
+    trigger_token_phrase (str): (Optional) A single-token string that triggers the phrase.
+    trigger_time (float): (Optional) Wall-clock time in seconds after which the phrase is triggered.
     trigger_count (int): How many times the phrase will be triggered.
     trigger_after (bool): Whether the phrase is written after the trigger token or instead of the trigger token.
     """
-    def __init__(self, phrase: str, trigger_token_phrase: str, tokenizer: PreTrainedTokenizer,
-                 trigger_count: int = 1, trigger_after: bool = False):
+
+    def __init__(self, tokenizer: PreTrainedTokenizer, phrase: str, trigger_token_phrase: Optional[str] = None,
+                 trigger_time: Optional[float] = None, trigger_count: int = 1, trigger_after: bool = False):
+        assert (
+            trigger_token_phrase is not None or trigger_time is not None
+        ), "Either trigger_token_phrase or trigger_time must be provided"
         self.tokenizer = tokenizer
-        self.trigger_token = text_to_token(self.tokenizer, trigger_token_phrase, last=False)
+        self.trigger_token = None
+        if trigger_token_phrase is not None:
+            self.trigger_token = text_to_token(self.tokenizer, trigger_token_phrase, last=False)
+
+        self.trigger_time = trigger_time or float("inf")
         self.phrase_tokens = self.tokenizer.encode(phrase, add_special_tokens=False)
         self.initial_trigger_count = trigger_count
         self.trigger_after = trigger_after
         self.iterators = None
         self.trigger_counts = None
+        self.start_time = time.time()

     def _init_before_gen(self, beam_width):
         self.iterators = -torch.ones(beam_width, dtype=torch.int32)
-        self.trigger_counts = self.initial_trigger_count*torch.ones(beam_width, dtype=torch.int32)
+        self.trigger_counts = self.initial_trigger_count * torch.ones(beam_width, dtype=torch.int32)

     def __call__(self, req_id: int, logits: torch.Tensor,
                  token_ids: List[List[int]],
                  stream_ptr: Optional[int],
@@ -64,7 +75,8 @@ def __call__(self, req_id: int, logits: torch.Tensor,

             current_index = self.iterators[i].item()

-            if logits[0, i].argmax() == self.trigger_token and current_index == -1:
+            time_over = time.time() - self.start_time > self.trigger_time
+            if (logits[0, i].argmax() == self.trigger_token or time_over) and current_index == -1:
                 self.iterators[i] = 0
                 if not self.trigger_after:
                     enforce_tokens(logits[0, i], [self.phrase_tokens[0]])
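Construction of the TensorRT-LLM trigger-phrase port above mirrors the transformers version, minus the `batch_size` argument. A minimal sketch; the model name and the wrap-up phrase are placeholders:

```python
from transformers import AutoTokenizer
from logits_processor_zoo.trtllm import TriggerPhraseLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")

# After 3 seconds, inject the phrase once instead of waiting for a trigger token.
processor = TriggerPhraseLogitsProcessor(
    tokenizer,
    phrase="\nLet me wrap up with the final answer:",
    trigger_time=3.0,
    trigger_after=True,
)
```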
diff --git a/logits_processor_zoo/vllm/__init__.py b/logits_processor_zoo/vllm/__init__.py
index 8e412fe..b6f51a3 100644
--- a/logits_processor_zoo/vllm/__init__.py
+++ b/logits_processor_zoo/vllm/__init__.py
@@ -21,6 +21,8 @@
 from .multiple_choice import MultipleChoiceLogitsProcessor
 from .trigger_phrase import TriggerPhraseLogitsProcessor
 from .prevent_hallucination import PreventHallucinationLogitsProcessor
+from .max_time import MaxTimeLogitsProcessor

 __all__ = ['GenLengthLogitsProcessor', 'CiteFromPromptLogitsProcessor', 'ForceLastPhraseLogitsProcessor',
-           'MultipleChoiceLogitsProcessor', 'TriggerPhraseLogitsProcessor', 'PreventHallucinationLogitsProcessor']
+           'MultipleChoiceLogitsProcessor', 'TriggerPhraseLogitsProcessor', 'PreventHallucinationLogitsProcessor',
+           'MaxTimeLogitsProcessor']
diff --git a/logits_processor_zoo/vllm/max_time.py b/logits_processor_zoo/vllm/max_time.py
new file mode 100644
index 0000000..467795e
--- /dev/null
+++ b/logits_processor_zoo/vllm/max_time.py
@@ -0,0 +1,92 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+from typing import List, Optional, Union
+import torch
+from transformers import PreTrainedTokenizer, AutoTokenizer
+from logits_processor_zoo.utils import text_to_token, enforce_tokens
+
+
+class MaxTimeLogitsProcessor:
+    """
+    A logits processor that enforces the end-of-sentence (EOS) token after a specified maximum time passes.
+    Useful for controlling generation time and ensuring responses complete within time constraints.
+
+    Parameters
+    ----------
+    tokenizer (Union[PreTrainedTokenizer, str]): The tokenizer used by the LLM, or a model name to load it from.
+    max_time (float): Maximum wall-clock time in seconds, after which the EOS token is enforced.
+    complete_sentences (bool, optional): If True, enforces the EOS token only when the last token is a full stop
+                                         or a new line. Default is False.
+    boost_token_str (str, optional): A string to be tokenized and used instead of EOS.
+    """
+
+    def __init__(
+        self,
+        tokenizer: Union[PreTrainedTokenizer, str],
+        max_time: float,
+        complete_sentences: bool = False,
+        boost_token_str: Optional[str] = None,
+    ):
+        self.tokenizer = tokenizer
+        if isinstance(self.tokenizer, str):
+            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer)
+
+        self.boost_token = self.tokenizer.eos_token_id
+        self.boost_token_str = boost_token_str
+        if boost_token_str is not None:
+            self.boost_token = text_to_token(self.tokenizer, boost_token_str, last=False)
+        self.full_stop_token = text_to_token(self.tokenizer, "It is a sentence.", last=True)
+        self.new_line_token = text_to_token(self.tokenizer, "It is a new line\n", last=True)
+        self.complete_sentences = complete_sentences
+        self.max_time = max_time
+        self._reset()
+
+    # Mutable logits processor gets cloned for each prompt in a batch in order to prevent updating the same object
+    # https://github.com/vllm-project/vllm/blob/19dcc02a72e3ed52e3bf95aae44ea1f40ce42ea0/vllm/sampling_params.py#L537-L550
+    def clone(self):
+        return MaxTimeLogitsProcessor(
+            self.tokenizer,
+            self.max_time,
+            self.complete_sentences,
+            self.boost_token_str,
+        )
+
+    def _reset(self):
+        self.start_time = time.time()
+
+    def __call__(
+        self,
+        prompt_tokens_ids: List[int],
+        past_token_ids: List[int],
+        scores: torch.Tensor,
+    ) -> torch.Tensor:
+
+        elapsed_time = time.time() - self.start_time
+        time_exceeded = elapsed_time > self.max_time
+        gen_length = len(past_token_ids)
+
+        enabled = True
+        if self.complete_sentences and gen_length > 0:
+            enabled = past_token_ids[-1] == self.full_stop_token or past_token_ids[-1] == self.new_line_token
+
+        if time_exceeded and enabled:
+            scores = enforce_tokens(scores, [self.boost_token])
+
+        return scores
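The vLLM processor above targets the per-request `logits_processors` hook of the v0 sampling path, the same code path its `clone` comment links to. A minimal sketch, assuming that hook is available in the installed vLLM version; the model name and prompt are placeholders:

```python
from vllm import LLM, SamplingParams
from logits_processor_zoo.vllm import MaxTimeLogitsProcessor

model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
llm = LLM(model=model_name)

# Force EOS after 10 seconds, but only at a sentence boundary.
proc = MaxTimeLogitsProcessor(model_name, max_time=10.0, complete_sentences=True)
params = SamplingParams(max_tokens=4096, logits_processors=[proc])

outputs = llm.generate(["Explain KV caching in one paragraph."], params)
print(outputs[0].outputs[0].text)
```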
diff --git a/logits_processor_zoo/vllm/trigger_phrase.py b/logits_processor_zoo/vllm/trigger_phrase.py
index 356ae7f..569d61d 100644
--- a/logits_processor_zoo/vllm/trigger_phrase.py
+++ b/logits_processor_zoo/vllm/trigger_phrase.py
@@ -15,8 +15,9 @@
 # limitations under the License.
 #

+import time
 from transformers import PreTrainedTokenizer, AutoTokenizer
-from typing import List, Union
+from typing import List, Optional, Union
 import torch
 from logits_processor_zoo.utils import text_to_token, enforce_tokens
@@ -27,14 +28,22 @@ class TriggerPhraseLogitsProcessor:

     Parameters
     ----------
+    tokenizer (Union[PreTrainedTokenizer, str]): The tokenizer to use.
     phrase (str): The phrase to be generated by the LLM when it encounters the trigger token.
-    trigger_token_phrase (str): One token phrase in string to trigger phrases.
-    tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM.
+    trigger_token_phrase (str): (Optional) A single-token string that triggers the phrase.
+    trigger_time (float): (Optional) Wall-clock time in seconds after which the phrase is triggered.
     trigger_count (int): How many times the phrase will be triggered.
     trigger_after (bool): Whether the phrase is written after the trigger token or instead of the trigger token.
     """
-    def __init__(self, phrase: str, trigger_token_phrase: str, tokenizer: Union[PreTrainedTokenizer, str],
+
+    def __init__(self, tokenizer: Union[PreTrainedTokenizer, str], phrase: str,
+                 trigger_token_phrase: Optional[str] = None, trigger_time: Optional[float] = None,
                  trigger_count: int = 1, trigger_after: bool = False):
+
+        assert (
+            trigger_token_phrase is not None or trigger_time is not None
+        ), "Either trigger_token_phrase or trigger_time must be provided"
+
         self.tokenizer = tokenizer
         if isinstance(self.tokenizer, str):
             self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer)
@@ -42,21 +51,26 @@ def __init__(self, tokenizer: Union[PreTrainedTokenizer, str], phrase: str,
         self.phrase = phrase
         self.trigger_token_phrase = trigger_token_phrase
         self.trigger_count = trigger_count
-        self.trigger_token = text_to_token(self.tokenizer, trigger_token_phrase, last=False)
+        self.trigger_token = None
+        if trigger_token_phrase is not None:
+            self.trigger_token = text_to_token(self.tokenizer, trigger_token_phrase, last=False)
+
         self.phrase_tokens = self.tokenizer.encode(phrase, add_special_tokens=False)
         self.initial_trigger_count = trigger_count
         self.trigger_after = trigger_after
+        self.trigger_time = trigger_time or float("inf")
         self._reset()

     # Mutable logits processor gets cloned for each prompt in a batch in order to prevent updating the same object
     # https://github.com/vllm-project/vllm/blob/19dcc02a72e3ed52e3bf95aae44ea1f40ce42ea0/vllm/sampling_params.py#L537-L550
     def clone(self):
-        return TriggerPhraseLogitsProcessor(self.phrase, self.trigger_token_phrase, self.tokenizer,
+        return TriggerPhraseLogitsProcessor(self.tokenizer, self.phrase, self.trigger_token_phrase, self.trigger_time,
                                             self.initial_trigger_count, self.trigger_after)

     def _reset(self):
         self.index = -1
         self.trigger_count = self.initial_trigger_count
+        self.start_time = time.time()

     def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
         if not past_token_ids:  # new generation
@@ -65,7 +79,8 @@ def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
         if self.trigger_count <= 0:
             return scores

-        if scores.argmax() == self.trigger_token and self.index == -1:
+        time_over = time.time() - self.start_time > self.trigger_time
+        if (scores.argmax() == self.trigger_token or time_over) and self.index == -1:
             self.index = 0
             if not self.trigger_after:
                 scores = enforce_tokens(scores, [self.phrase_tokens[self.index]])
@@ -77,5 +92,6 @@ def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
         if len(self.phrase_tokens) == self.index:  # phrase completed, reset for next trigger
             self.index = -1
             self.trigger_count -= 1
+            self.start_time = time.time()

         return scores
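The two vLLM processors compose: a time-based trigger can steer the model into a closing phrase shortly before a hard `MaxTimeLogitsProcessor` cutoff. A sketch with illustrative time budgets; model name, prompt, and phrase are placeholders:

```python
from vllm import LLM, SamplingParams
from logits_processor_zoo.vllm import MaxTimeLogitsProcessor, TriggerPhraseLogitsProcessor

model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
llm = LLM(model=model_name)

# At 8s, steer the model into a closing phrase; at 10s, force EOS at the next sentence boundary.
wrap_up = TriggerPhraseLogitsProcessor(model_name, phrase="\nFinal answer:", trigger_time=8.0, trigger_after=True)
hard_stop = MaxTimeLogitsProcessor(model_name, max_time=10.0, complete_sentences=True)

params = SamplingParams(max_tokens=4096, logits_processors=[wrap_up, hard_stop])
print(llm.generate(["Derive the quadratic formula."], params)[0].outputs[0].text)
```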
diff --git a/pyproject.toml b/pyproject.toml
index 4eb012d..41b9002 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "logits-processor-zoo"
-version = "0.2.0"
+version = "0.2.1"
 description = "A collection of LogitsProcessors to customize and enhance LLM behavior for specific tasks."
 authors = ["Ahmet Erdem", "Ivan Sorokin", "Maximilian Jeblick", "Darragh Hanley", "David Austin"]
diff --git a/tests/transformers/test_max_time.py b/tests/transformers/test_max_time.py
new file mode 100644
index 0000000..27d32d7
--- /dev/null
+++ b/tests/transformers/test_max_time.py
@@ -0,0 +1,39 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+from logits_processor_zoo.transformers import MaxTimeLogitsProcessor
+
+
+def test_max_time_logits_processor(llm_runner):
+    """Test that generation is forced to stop within the maximum time budget."""
+    example_prompts = [
+        "Hello, how are you?",
+        "What is the capital of France?",
+        "What is the capital of Germany?",
+    ]
+
+    max_time = 2
+    tolerance = 1
+    start_time = time.time()
+
+    logits_processors = [MaxTimeLogitsProcessor(llm_runner.tokenizer, max_time=max_time, complete_sentences=False)]
+    outs = llm_runner.generate_response(example_prompts, logits_processors, max_new_tokens=1000)
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    print(outs)
+    assert elapsed_time <= max_time + tolerance
diff --git a/tests/transformers/test_prevent_halluciantion.py b/tests/transformers/test_prevent_hallucination.py
similarity index 80%
rename from tests/transformers/test_prevent_halluciantion.py
rename to tests/transformers/test_prevent_hallucination.py
index e68b683..792f287 100644
--- a/tests/transformers/test_prevent_halluciantion.py
+++ b/tests/transformers/test_prevent_hallucination.py
@@ -19,13 +19,9 @@


 def test_gen_length_logits_processor(llm_runner):
-    example_prompts = [
-        "Please describe what macaques are.",
-        "Tell me a story about a kid lost in forest."
-    ]
+    example_prompts = ["Please describe what macaques are.", "Tell me a story about a kid lost in forest."]

-    logits_processors = [PreventHallucinationLogitsProcessor(llm_runner.tokenizer, batch_size=2,
-                                                             minp=0.99, tolerate=2)]
+    logits_processors = [PreventHallucinationLogitsProcessor(llm_runner.tokenizer, batch_size=2, minp=0.99, tolerate=2)]

     processed_gen_output = llm_runner.generate_response(example_prompts, logits_processors)

     assert all(["I don't know" in out for out in processed_gen_output])
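A natural companion to `test_max_time.py` above, sketched here only as a suggestion and not part of this patch, would cover `complete_sentences=True`, where the stop may land slightly after the deadline while the current sentence finishes:

```python
import time
from logits_processor_zoo.transformers import MaxTimeLogitsProcessor


def test_max_time_complete_sentences(llm_runner):
    """Suggested follow-up: with complete_sentences=True the hard stop waits for a sentence boundary."""
    max_time = 2
    tolerance = 2  # finishing the current sentence can take a moment beyond the deadline

    logits_processors = [MaxTimeLogitsProcessor(llm_runner.tokenizer, max_time=max_time, complete_sentences=True)]

    start_time = time.time()
    llm_runner.generate_response(["Tell me a story about a fox."], logits_processors, max_new_tokens=1000)
    assert time.time() - start_time <= max_time + tolerance
```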
diff --git a/tests/transformers/test_trigger_phrase.py b/tests/transformers/test_trigger_phrase.py
new file mode 100644
index 0000000..dd5d709
--- /dev/null
+++ b/tests/transformers/test_trigger_phrase.py
@@ -0,0 +1,63 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from logits_processor_zoo.transformers import TriggerPhraseLogitsProcessor
+
+
+def test_trigger_phrase_token_based_triggering(llm_runner):
+    """Test that the phrase is triggered when the specified token is generated."""
+    example_prompts = ["Query: "]
+
+    trigger_token = "fig"
+    phrase = "This is a triggered phrase."
+
+    logits_processors = [
+        TriggerPhraseLogitsProcessor(
+            llm_runner.tokenizer,
+            batch_size=len(example_prompts),
+            phrase=phrase,
+            trigger_token_phrase=trigger_token,
+            trigger_after=True,
+        )
+    ]
+
+    processed_gen_output = llm_runner.generate_response(example_prompts, logits_processors, max_new_tokens=1000)
+    assert phrase in processed_gen_output[0]
+
+
+def test_trigger_phrase_time_based_triggering(llm_runner):
+    """Test that the phrase is triggered after the specified time elapses."""
+    example_prompts = [
+        "Generate a python function to calculate fibonacci numbers.",
+        "Simple python function to calculate fibonacci numbers.",
+    ]
+
+    trigger_time = 2
+    phrase = "This is a triggered phrase."
+
+    logits_processors = [
+        TriggerPhraseLogitsProcessor(
+            llm_runner.tokenizer,
+            batch_size=len(example_prompts),
+            phrase=phrase,
+            trigger_time=trigger_time,
+            trigger_after=True,
+        )
+    ]
+
+    processed_gen_output = llm_runner.generate_response(example_prompts, logits_processors, max_new_tokens=1000)
+    assert phrase in processed_gen_output[0]
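For completeness, the behavior exercised by these tests can be reproduced without the `llm_runner` fixture. A minimal `transformers` sketch; the model name and prompt are placeholders, and the `LogitsProcessorList` wiring is standard `transformers` usage:

```python
import time
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
from logits_processor_zoo.transformers import MaxTimeLogitsProcessor

model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Force EOS once 5 seconds have elapsed; construct just before generate(), since the timer starts in __init__.
max_time = MaxTimeLogitsProcessor(tokenizer, max_time=5.0, complete_sentences=False)

inputs = tokenizer("Write a long essay about rectangles.", return_tensors="pt")
start = time.time()
out = model.generate(**inputs, max_new_tokens=4096, logits_processor=LogitsProcessorList([max_time]))
print(f"finished in {time.time() - start:.1f}s")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```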