Skip to content

Commit 5acca23

Browse files
committed
Make CiteFromPrompt robust against tokenizer vocab size discrepancies
Signed-off-by: aerdem4 <ahmeterd4@gmail.com>
1 parent 4acbb29 commit 5acca23

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

logits_processor_zoo/transformers/cite_prompt.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -41,11 +41,12 @@ def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float = 1.0, bo
4141
self.boost_eos = boost_eos
4242

4343
def _process(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
44+
voc_size = scores.shape[1]
4445
for i in range(scores.shape[0]):
4546
tokens = set(self.prompt_token_ids[i])
4647
if self.boost_eos:
4748
tokens.add(self.eos_token_id)
4849

49-
tokens = list(tokens)
50+
tokens = [t for t in tokens if t < voc_size]
5051
scores[i, tokens] += self.boost_factor
5152
return scores

logits_processor_zoo/vllm/cite_prompt.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,6 @@ def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scor
4646
if self.boost_eos:
4747
tokens.add(self.eos_token_id)
4848

49-
tokens = list(tokens)
49+
tokens = [t for t in tokens if t < scores.shape[0]]
5050
scores[tokens] += self.boost_factor
5151
return scores

0 commit comments

Comments
 (0)