Skip to content

Commit ec4e07f

Browse files
authored
Reset vllm TriggerPhraseLogitsProcessor for each new sample (#11)
1 parent 3d80cb3 commit ec4e07f

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

logits_processor_zoo/vllm/trigger_phrase.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,18 @@ def __init__(self, phrase: str, trigger_token_phrase: str, tokenizer: PreTrained
3737
trigger_after: bool = False):
3838
self.trigger_token = text_to_token(tokenizer, trigger_token_phrase, last=False)
3939
self.phrase_tokens = tokenizer.encode(phrase, add_special_tokens=False)
40+
self.initial_trigger_count = trigger_count
41+
self.trigger_after = trigger_after
42+
self.very_large_number = 999
43+
# Initialize state for a new sequence
4044
self.index = -1
4145
self.trigger_count = trigger_count
42-
self.very_large_number = 999
43-
self.trigger_after = trigger_after
4446

4547
def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
48+
if not past_token_ids:
49+
self.index = -1
50+
self.trigger_count = self.initial_trigger_count
51+
4652
if self.trigger_count <= 0:
4753
return scores
4854

0 commit comments

Comments
 (0)