@@ -79,19 +79,22 @@ def __init__(self, params: GptParams) -> None:
         self.lparams.use_mlock = self.params.use_mlock
         self.lparams.use_mmap = self.params.use_mmap

-        self.model = llama_cpp.llama_load_model_from_file(
+        self.model = llama_cpp.llama_model_load_from_file(
             self.params.model.encode("utf8"), self.lparams
         )
+        self.vocab = llama_cpp.llama_model_get_vocab(self.model)

         # Context Params.
         self.cparams = llama_cpp.llama_context_default_params()

-        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.cparams)
+        self.ctx = llama_cpp.llama_init_from_model(self.model, self.cparams)
         if not self.ctx:
             raise RuntimeError(f"error: failed to load model '{self.params.model}'")

         if self.params.ignore_eos:
-            self.params.logit_bias[llama_cpp.llama_token_eos()] = -float("inf")
+            self.params.logit_bias[llama_cpp.llama_vocab_eos(self.vocab)] = -float(
+                "inf"
+            )

         if len(self.params.lora_adapter) > 0:
             if (
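
For reference, the renamed loading path pairs the model with an explicit vocab handle, and token lookups such as EOS now go through that handle instead of the context. A minimal sketch, assuming the llama_cpp bindings mirror the C API used in this diff (llama_model_default_params included) and with a placeholder model path:

    import llama_cpp

    lparams = llama_cpp.llama_model_default_params()
    # llama_model_load_from_file replaces llama_load_model_from_file.
    model = llama_cpp.llama_model_load_from_file(b"model.gguf", lparams)
    # Fetch the vocab handle once; all token queries below reuse it.
    vocab = llama_cpp.llama_model_get_vocab(model)

    cparams = llama_cpp.llama_context_default_params()
    # llama_init_from_model replaces llama_new_context_with_model.
    ctx = llama_cpp.llama_init_from_model(model, cparams)

    # llama_vocab_eos replaces llama_token_eos and reads off the vocab.
    eos_id = llama_cpp.llama_vocab_eos(vocab)
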
@@ -153,7 +156,7 @@ def __init__(self, params: GptParams) -> None:
                 _session_tokens = (llama_cpp.llama_token * (self.params.n_ctx))()
                 _n_token_count_out = llama_cpp.c_size_t()
                 if (
-                    llama_cpp.llama_load_session_file(
+                    llama_cpp.llama_state_load_file(
                         self.ctx,
                         self.params.path_session.encode("utf8"),
                         _session_tokens,
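
The session-file helpers are renamed into the llama_state_* family; the argument order (context, path, output token buffer, capacity, out-count) is unchanged. A rough sketch of a load attempt, reusing ctx from the sketch above, with the path and context size as placeholders:

    import ctypes
    import llama_cpp

    n_ctx = 2048  # assumed context size
    tokens = (llama_cpp.llama_token * n_ctx)()
    n_loaded = llama_cpp.c_size_t()
    # llama_state_load_file replaces llama_load_session_file.
    if llama_cpp.llama_state_load_file(
        ctx,
        b"session.bin",
        tokens,
        n_ctx,
        ctypes.byref(n_loaded),
    ):
        session_tokens = list(tokens[: n_loaded.value])
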
@@ -314,7 +317,7 @@ def __init__(self, params: GptParams) -> None:
     def _tokenize(self, prompt, bos=True):
         _arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
         _n = llama_cpp.llama_tokenize(
-            self.model,
+            self.vocab,
             prompt.encode("utf8", errors="ignore"),
             len(prompt),
             _arr,
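
llama_tokenize now takes the vocab handle where it previously took the model; the rest of the call is unchanged. A hedged sketch of a full invocation, assuming the trailing add_special/parse_special flags from the C API:

    prompt = "Hello, world"
    buf = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
    n = llama_cpp.llama_tokenize(
        vocab,  # vocab handle instead of the model
        prompt.encode("utf8", errors="ignore"),
        len(prompt),
        buf,
        len(buf),
        True,   # add_special: prepend BOS if the model expects one
        False,  # parse_special
    )
    tokens = list(buf[:n])
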
@@ -406,7 +409,7 @@ def generate(self):
             if len(self.embd_inp) <= self.input_consumed:  # && !is_interacting
                 # out of user input, sample next token
                 top_k = (
-                    llama_cpp.llama_n_vocab(self.ctx)
+                    llama_cpp.llama_vocab_n_tokens(self.vocab)
                     if self.params.top_k <= 0
                     else self.params.top_k
                 )
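
llama_vocab_n_tokens is the renamed llama_n_vocab and likewise reads off the vocab handle; a non-positive top_k is treated as "consider the whole vocabulary". For example (top_k_param stands in for self.params.top_k):

    n_vocab = llama_cpp.llama_vocab_n_tokens(vocab)
    top_k = n_vocab if top_k_param <= 0 else top_k_param
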
@@ -419,7 +422,7 @@ def generate(self):
                 # optionally save the session on first sample (for faster prompt loading next time)
                 if len(self.params.path_session) > 0 and self.need_to_save_session:
                     self.need_to_save_session = False
-                    llama_cpp.llama_save_session_file(
+                    llama_cpp.llama_state_save_file(
                         self.ctx,
                         self.params.path_session.encode("utf8"),
                         (llama_cpp.llama_token * len(self.session_tokens))(
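
llama_state_save_file is the matching rename on the write side: it persists the evaluated prompt so a later llama_state_load_file can skip recomputation. A sketch, with session_tokens assumed to hold the token ids evaluated so far:

    tok_arr = (llama_cpp.llama_token * len(session_tokens))(*session_tokens)
    # llama_state_save_file replaces llama_save_session_file.
    llama_cpp.llama_state_save_file(
        ctx,
        b"session.bin",
        tok_arr,
        len(session_tokens),
    )
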
@@ -431,7 +434,7 @@ def generate(self):
                 id = 0

                 logits = llama_cpp.llama_get_logits(self.ctx)
-                n_vocab = llama_cpp.llama_n_vocab(self.model)
+                n_vocab = llama_cpp.llama_vocab_n_tokens(self.vocab)

                 # Apply params.logit_bias map
                 for key, value in self.params.logit_bias.items():
@@ -448,7 +451,7 @@ def generate(self):
                 )

                 # Apply penalties
-                nl_logit = logits[llama_cpp.llama_token_nl(self.ctx)]
+                nl_logit = logits[llama_cpp.llama_vocab_nl(self.vocab)]
                 last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)

                 _arr = (llama_cpp.llama_token * last_n_repeat)(
@@ -470,7 +473,7 @@ def generate(self):
                 # last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty))

                 if not self.params.penalize_nl:
-                    logits[llama_cpp.llama_token_nl()] = nl_logit
+                    logits[llama_cpp.llama_vocab_nl(self.vocab)] = nl_logit

                 if self.params.temp <= 0:
                     # Greedy sampling
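
Both newline lookups now go through llama_vocab_nl(vocab); the old code passed self.ctx in one call and no argument at all in the other, so unifying on the vocab handle also removes that inconsistency. The save/restore pattern around the penalty step is roughly (penalize_nl stands in for self.params.penalize_nl):

    nl_id = llama_cpp.llama_vocab_nl(vocab)
    nl_logit = logits[nl_id]      # remember the newline logit
    # ... apply repetition/frequency/presence penalties to logits ...
    if not penalize_nl:
        logits[nl_id] = nl_logit  # restore it so "\n" is never penalized
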
@@ -539,7 +542,7 @@ def generate(self):

                 # replace end of text token with newline token when in interactive mode
                 if (
-                    id == llama_cpp.llama_token_eos(self.ctx)
+                    id == llama_cpp.llama_vocab_eos(self.vocab)
                     and self.params.interactive
                     and not self.params.instruct
                 ):
@@ -599,8 +602,8 @@ def generate(self):
                     break

             # end of text token
-            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos(
-                self.ctx
+            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_vocab_eos(
+                self.vocab
             ):
                 if not self.params.instruct:
                     for i in self.llama_token_eot:
@@ -636,7 +639,7 @@ def token_to_str(self, token_id: int) -> bytes:
         size = 32
         buffer = (ctypes.c_char * size)()
         n = llama_cpp.llama_token_to_piece(
-            self.model, llama_cpp.llama_token(token_id), buffer, size
+            self.vocab, llama_cpp.llama_token(token_id), buffer, size, 0, False
         )
         assert n <= size
         return bytes(buffer[:n])
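
llama_token_to_piece also switches to the vocab handle and gains two trailing parameters from the C API: lstrip (how many leading spaces to strip) and special (whether to render special tokens as text); passing 0 and False preserves the old behaviour. A sketch with a placeholder token id:

    import ctypes
    import llama_cpp

    token_id = 42  # placeholder token id
    size = 32
    buf = (ctypes.c_char * size)()
    n = llama_cpp.llama_token_to_piece(
        vocab,
        llama_cpp.llama_token(token_id),
        buf,
        size,
        0,      # lstrip: strip no leading spaces
        False,  # special: do not render special tokens
    )
    piece = bytes(buf[:n])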