@@ -43,7 +43,6 @@ def __init__(
4343 self ._exit_stack = ExitStack ()
4444
4545 model = None
46- vocab = None
4746
4847 if not os .path .exists (path_model ):
4948 raise ValueError (f"Model path does not exist: { path_model } " )
@@ -58,24 +57,12 @@ def __init__(
5857
5958 self .model = model
6059
61- vocab = llama_cpp .llama_model_get_vocab (self .model )
62-
63- if vocab is None :
64- raise ValueError (f"Failed to load vocab from file: { path_model } " )
65-
66- self .vocab = vocab
67-
6860 def free_model ():
6961 if self .model is None :
7062 return
7163 llama_cpp .llama_model_free (self .model )
7264 self .model = None
7365
74- if self .vocab is None :
75- return
76- llama_cpp .llama_model_free (self .vocab )
77- self .vocab = None
78-
7966 self ._exit_stack .callback (free_model )
8067
8168 def close (self ):
@@ -84,11 +71,11 @@ def close(self):
8471 def __del__ (self ):
8572 self .close ()
8673
87- def vocab_type (self ) -> int :
88- return llama_cpp .llama_vocab_type (self . vocab )
74+ def vocab_type (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
75+ return llama_cpp .llama_vocab_type (_vocab )
8976
90- def n_vocab (self ) -> int :
91- return llama_cpp .llama_vocab_n_tokens (self . vocab )
77+ def n_vocab (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
78+ return llama_cpp .llama_vocab_n_tokens (_vocab )
9279
9380 def n_ctx_train (self ) -> int :
9481 return llama_cpp .llama_model_n_ctx_train (self .model )
@@ -112,66 +99,66 @@ def n_params(self) -> int:
11299
113100 # Vocab
114101
115- def token_get_text (self , token : int ) -> str :
116- return llama_cpp .llama_vocab_get_text (self . vocab , token ).decode ("utf-8" )
102+ def token_get_text (self , _vocab : llama_cpp . llama_vocab_p , token : int ) -> str :
103+ return llama_cpp .llama_vocab_get_text (_vocab , token ).decode ("utf-8" )
117104
118- def token_get_score (self , token : int ) -> float :
119- return llama_cpp .llama_vocab_get_score (self . vocab , token )
105+ def token_get_score (self , _vocab : llama_cpp . llama_vocab_p , token : int ) -> float :
106+ return llama_cpp .llama_vocab_get_score (_vocab , token )
120107
121- def token_get_attr (self , token : int ) -> int :
122- return llama_cpp .llama_vocab_get_attr (self . vocab , token )
108+ def token_get_attr (self , _vocab : llama_cpp . llama_vocab_p , token : int ) -> int :
109+ return llama_cpp .llama_vocab_get_attr (_vocab , token )
123110
124111 # Special tokens
125112
126- def token_bos (self ) -> int :
127- return llama_cpp .llama_vocab_bos (self . vocab )
113+ def token_bos (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
114+ return llama_cpp .llama_vocab_bos (_vocab )
128115
129- def token_eos (self ) -> int :
130- return llama_cpp .llama_vocab_eos (self . vocab )
116+ def token_eos (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
117+ return llama_cpp .llama_vocab_eos (_vocab )
131118
132- def token_eot (self ) -> int :
133- return llama_cpp .llama_vocab_eot (self . vocab )
119+ def token_eot (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
120+ return llama_cpp .llama_vocab_eot (_vocab )
134121
135- def token_cls (self ) -> int :
136- return llama_cpp .llama_vocab_cls (self . vocab )
122+ def token_cls (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
123+ return llama_cpp .llama_vocab_cls (_vocab )
137124
138- def token_sep (self ) -> int :
139- return llama_cpp .llama_vocab_sep (self . vocab )
125+ def token_sep (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
126+ return llama_cpp .llama_vocab_sep (_vocab )
140127
141- def token_nl (self ) -> int :
142- return llama_cpp .llama_vocab_nl (self . vocab )
128+ def token_nl (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
129+ return llama_cpp .llama_vocab_nl (_vocab )
143130
144- def token_pad (self ) -> int :
145- return llama_cpp .llama_vocab_pad (self . vocab )
131+ def token_pad (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
132+ return llama_cpp .llama_vocab_pad (_vocab )
146133
147- def token_prefix (self ) -> int :
148- return llama_cpp .llama_vocab_fim_pre (self . vocab )
134+ def token_prefix (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
135+ return llama_cpp .llama_vocab_fim_pre (_vocab )
149136
150- def token_middle (self ) -> int :
151- return llama_cpp .llama_vocab_fim_mid (self . vocab )
137+ def token_middle (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
138+ return llama_cpp .llama_vocab_fim_mid (_vocab )
152139
153- def token_suffix (self ) -> int :
154- return llama_cpp .llama_vocab_fim_suf (self . vocab )
140+ def token_suffix (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
141+ return llama_cpp .llama_vocab_fim_suf (_vocab )
155142
156- def add_bos_token (self ) -> bool :
157- return llama_cpp .llama_vocab_get_add_bos (self . vocab )
143+ def add_bos_token (self , _vocab : llama_cpp . llama_vocab_p ) -> bool :
144+ return llama_cpp .llama_vocab_get_add_bos (_vocab )
158145
159- def add_eos_token (self ) -> bool :
160- return llama_cpp .llama_vocab_get_add_eos (self . vocab )
146+ def add_eos_token (self , _vocab : llama_cpp . llama_vocab_p ) -> bool :
147+ return llama_cpp .llama_vocab_get_add_eos (_vocab )
161148
162149 # Tokenization
163150
164- def tokenize (self , text : bytes , add_bos : bool , special : bool ):
151+ def tokenize (self , _vocab : llama_cpp . llama_vocab_p , text : bytes , add_bos : bool , special : bool ):
165152 n_ctx = self .n_ctx_train ()
166153 tokens = (llama_cpp .llama_token * n_ctx )()
167154 n_tokens = llama_cpp .llama_tokenize (
168- self . vocab , text , len (text ), tokens , n_ctx , add_bos , special
155+ _vocab , text , len (text ), tokens , n_ctx , add_bos , special
169156 )
170157 if n_tokens < 0 :
171158 n_tokens = abs (n_tokens )
172159 tokens = (llama_cpp .llama_token * n_tokens )()
173160 n_tokens = llama_cpp .llama_tokenize (
174- self . vocab , text , len (text ), tokens , n_tokens , add_bos , special
161+ _vocab , text , len (text ), tokens , n_tokens , add_bos , special
175162 )
176163 if n_tokens < 0 :
177164 raise RuntimeError (
@@ -618,10 +605,11 @@ def prev_str(self, ctx_main: LlamaContext, n: int) -> str:
618605 def sample (
619606 self ,
620607 ctx_main : LlamaContext ,
608+ _vocab :llama_cpp .llama_vocab_p ,
621609 idx : int = 0 ,
622610 logits_array : Optional [npt .NDArray [np .single ]] = None ,
623611 ):
624- n_vocab = ctx_main .model .n_vocab ()
612+ n_vocab = ctx_main .model .n_vocab (_vocab )
625613 id : int = 0
626614
627615 if logits_array is None :
0 commit comments