def text_to_token(tokenizer: PreTrainedTokenizer, text: str, last: bool):
    """Return the id of the token that ``text`` encodes to.

    ``text`` is encoded with ``add_special_tokens=False`` and the last id of
    the result is returned.

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=...)``.
            The optional attributes ``bos_token``, ``bos_token_id`` and
            ``add_prefix_space`` are consulted when validating the length.
        text: Text expected to map to a single token.
        last: When true, skip the length validation and simply return the
            last id of the encoding, however many tokens it contains.

    Returns:
        The last token id produced by ``tokenizer.encode``.

    Raises:
        IndexError: If the encoding is empty.
        Exception: If ``last`` is false and the encoding is longer than the
            allowed number of tokens.
    """
    tokens = tokenizer.encode(text, add_special_tokens=False)

    if len(tokens) == 0:
        # Indexing an empty result would raise a bare "list index out of
        # range"; keep the exception type but say what actually went wrong.
        raise IndexError(f"Can't convert {text} to token. It has 0 tokens.")

    # A single token is expected. One extra token is tolerated to account for
    # a BOS token or a prefix token placed in front of the main token.
    bos_token_added = bool(
        getattr(tokenizer, 'bos_token', None)
        and getattr(tokenizer, 'bos_token_id', None) in tokens
    )
    # Only an explicit ``add_prefix_space=False`` rules a prefix token out; a
    # missing attribute is deliberately treated as "may add a prefix".
    prefix_token_added = getattr(tokenizer, 'add_prefix_space', None) is not False
    max_token_count = 2 if (bos_token_added or prefix_token_added) else 1

    if not last and len(tokens) > max_token_count:
        raise Exception(f"Can't convert {text} to token. It has {len(tokens)} tokens.")

    return tokens[-1]