@@ -571,7 +571,7 @@ def eval_logits(self) -> Deque[List[float]]:
        )

    def tokenize(
-        self, text: bytes, add_bos: bool = True, special: bool = False
+        self, vocab: llama_cpp.llama_vocab_p, text: bytes, add_bos: bool = True, special: bool = False
    ) -> List[int]:
        """Tokenize a string.

@@ -586,10 +586,11 @@ def tokenize(
        Returns:
            A list of tokens.
        """
-        return self.tokenizer_.tokenize(text, add_bos, special)
+        return self.tokenizer_.tokenize(vocab, text, add_bos, special)

    def detokenize(
        self,
+        vocab: llama_cpp.llama_vocab_p,
        tokens: List[int],
        prev_tokens: Optional[List[int]] = None,
        special: bool = False,
@@ -605,7 +606,7 @@ def detokenize(
            The detokenized string.
        """
        return self.tokenizer_.detokenize(
-            tokens, prev_tokens=prev_tokens, special=special
+            vocab, tokens, prev_tokens=prev_tokens, special=special
        )

    def set_cache(self, cache: Optional[BaseLlamaCache]):
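Note: callers now thread the vocab handle through both helpers. A minimal usage sketch, assuming a loaded `Llama` instance named `llm` that exposes the handle as `llm._vocab` (`self._vocab` is used by the hunks below; the external variable names here are assumptions):

    # Sketch only: `llm` is an assumed name, not part of this diff.
    tokens = llm.tokenize(llm._vocab, b"Hello, world!", add_bos=True, special=False)
    text = llm.detokenize(llm._vocab, tokens, prev_tokens=None, special=False)
    assert isinstance(text, bytes)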
@@ -1073,7 +1074,7 @@ def decode_batch(seq_sizes: List[int]):

        # accumulate batches and encode
        for text in inputs:
-            tokens = self.tokenize(text.encode("utf-8"))
+            tokens = self.tokenize(self._vocab, text.encode("utf-8"))
            if truncate:
                tokens = tokens[:n_batch]

@@ -1152,11 +1153,11 @@ def _create_completion(
        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
        created: int = int(time.time())
        bos_token_id: int = self.token_bos()
-        cls_token_id: int = self._model.token_cls()
-        sep_token_id: int = self._model.token_sep()
-        prefix_token_id: int = self._model.token_prefix()
-        middle_token_id: int = self._model.token_middle()
-        suffix_token_id: int = self._model.token_suffix()
+        cls_token_id: int = self._model.token_cls(self._vocab)
+        sep_token_id: int = self._model.token_sep(self._vocab)
+        prefix_token_id: int = self._model.token_prefix(self._vocab)
+        middle_token_id: int = self._model.token_middle(self._vocab)
+        suffix_token_id: int = self._model.token_suffix(self._vocab)
        add_space_prefix: bool = (
            self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true"
        )
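Note: the same vocab-first pattern applies to the special-token getters. A hedged sketch of equivalent lookups through the raw ctypes bindings; `llama_vocab_is_eog` appears later in this diff, and the other names are assumed to follow the same `llama_vocab_*` convention:

    # Sketch only: `vocab` is the llama_vocab_p handle this PR threads through.
    bos = llama_cpp.llama_vocab_bos(vocab)           # assumed vocab-first getter
    eos = llama_cpp.llama_vocab_eos(vocab)           # assumed vocab-first getter
    done = llama_cpp.llama_vocab_is_eog(vocab, eos)  # shown in this diff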
@@ -1167,13 +1168,13 @@ def _create_completion(

        if (
            (isinstance(prompt, list) and suffix is None)
-            or not self._model.add_bos_token()
+            or not self._model.add_bos_token(self._vocab)
            or bos_tokens[:1] == [-1]
        ):
            bos_tokens = []

        if (isinstance(prompt, list) and suffix is None) or (
-            not self._model.add_eos_token() and sep_token_id == -1
+            not self._model.add_eos_token(self._vocab) and sep_token_id == -1
        ):
            eos_tokens = []

@@ -1192,6 +1193,7 @@ def _create_completion(
        ) + (
            (
                self.tokenize(
+                    self._vocab,
                    prompt.encode("utf-8"),
                    add_bos=False,
                    special=(prefix_token_id < 0 or suffix is None),
@@ -1206,7 +1208,7 @@ def _create_completion(
            (
                [suffix_token_id]
                + (
-                    self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)[
+                    self.tokenize(self._vocab, suffix.encode("utf-8"), add_bos=False, special=False)[
                        suffix_space_prefix:
                    ]
                    if suffix
@@ -1334,14 +1336,14 @@ def logit_bias_processor(
            logits_processor=logits_processor,
            grammar=grammar,
        ):
-            if llama_cpp.llama_vocab_is_eog(self._model.model, token):
-                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+            if llama_cpp.llama_vocab_is_eog(self._vocab, token):
+                text = self.detokenize(self._vocab, completion_tokens, prev_tokens=prompt_tokens)
                finish_reason = "stop"
                break

            completion_tokens.append(token)

-            all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+            all_text = self.detokenize(self._vocab, completion_tokens, prev_tokens=prompt_tokens)

            # Contains multi-byte UTF8
            for k, char in enumerate(all_text[-3:]):
@@ -1366,6 +1368,7 @@ def logit_bias_processor(
            if stream:
                remaining_tokens = completion_tokens[returned_tokens:]
                remaining_text = self.detokenize(
+                    self._vocab,
                    remaining_tokens,
                    prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
                )
@@ -1392,6 +1395,7 @@ def logit_bias_processor(
                            continue
                        token_end_position += len(
                            self.detokenize(
+                                self._vocab,
                                [token],
                                prev_tokens=prompt_tokens
                                + completion_tokens[:returned_tokens],
@@ -1403,12 +1407,14 @@ def logit_bias_processor(
                        ):
                            break
                        token_str = self.detokenize(
+                            self._vocab,
                            [token],
                            prev_tokens=prompt_tokens
                            + completion_tokens[:returned_tokens],
                        ).decode("utf-8", errors="ignore")
                        text_offset = len(prompt) + len(
                            self.detokenize(
+                                self._vocab,
                                completion_tokens[:returned_tokens],
                                prev_tokens=prompt_tokens
                                + completion_tokens[:returned_tokens],
@@ -1433,6 +1439,7 @@ def logit_bias_processor(
                        logprobs_or_none = {
                            "tokens": [
                                self.detokenize(
+                                    self._vocab,
                                    [token],
                                    prev_tokens=prompt_tokens
                                    + completion_tokens[:returned_tokens],
@@ -1451,6 +1458,7 @@ def logit_bias_processor(
14511458 "choices" : [
14521459 {
14531460 "text" : self .detokenize (
1461+ self ._vocab ,
14541462 [token ],
14551463 prev_tokens = prompt_tokens
14561464 + completion_tokens [:returned_tokens ],
@@ -1467,6 +1475,7 @@ def logit_bias_processor(
                for i in range(1, len(remaining_tokens) + 1):
                    try:
                        bs = self.detokenize(
+                            self._vocab,
                            remaining_tokens[:i],
                            prev_tokens=prompt_tokens
                            + completion_tokens[:returned_tokens],
@@ -1505,14 +1514,14 @@ def logit_bias_processor(
                        }

            if len(completion_tokens) >= max_tokens:
-                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+                text = self.detokenize(self._vocab, completion_tokens, prev_tokens=prompt_tokens)
                finish_reason = "length"
                break

        if stopping_criteria is not None and stopping_criteria(
            self._input_ids, self._scores[-1, :]
        ):
-            text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+            text = self.detokenize(self._vocab, completion_tokens, prev_tokens=prompt_tokens)
            finish_reason = "stop"

        if self.verbose:
@@ -1521,6 +1530,7 @@ def logit_bias_processor(
        if stream:
            remaining_tokens = completion_tokens[returned_tokens:]
            remaining_text = self.detokenize(
+                self._vocab,
                remaining_tokens,
                prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
            )
@@ -1534,6 +1544,7 @@ def logit_bias_processor(
            for token in remaining_tokens:
                token_end_position += len(
                    self.detokenize(
+                        self._vocab,
                        [token],
                        prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
                    )
@@ -1543,7 +1554,7 @@ def logit_bias_processor(
                if logprobs is not None:
                    if token == bos_token_id:
                        continue
-                    token_str = self.detokenize([token]).decode(
+                    token_str = self.detokenize(self._vocab, [token]).decode(
                        "utf-8", errors="ignore"
                    )
                    text_offset = len(prompt) + len(
@@ -1569,15 +1580,15 @@ def logit_bias_processor(
                    top_logprob.update({token_str: current_logprobs[int(token)]})
                    logprobs_or_none = {
                        "tokens": [
-                            self.detokenize([token]).decode("utf-8", errors="ignore")
+                            self.detokenize(self._vocab, [token]).decode("utf-8", errors="ignore")
                        ],
                        "text_offset": [text_offset],
                        "token_logprobs": [current_logprobs[int(token)]],
                        "top_logprobs": [top_logprob],
                    }

                    if token_end_position >= end:
-                        last_text = self.detokenize(self._vocab, [token])
+                        last_text = self.detokenize(self._vocab, [token])
                        if token_end_position == end - 1:
                            break
                        returned_tokens += 1
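Note: the `self._vocab` handle threaded through all of these calls would be created once per model. A minimal sketch of a plausible init site, assuming the handle lives next to `self._model` (the exact location is not shown in this diff); `llama_model_get_vocab` is the llama.cpp accessor returning the model's `llama_vocab` pointer:

    # Sketch only: where this initialization happens is an assumption.
    self._vocab = llama_cpp.llama_model_get_vocab(self._model.model)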