@@ -9,6 +9,7 @@
 from typing import Any, List
 
 import numpy as np
+import pandas as pd
 import torch
 import triton_python_backend_utils as pb_utils
 from torch import from_numpy
@@ -239,7 +240,12 @@ def get_output_config_from_request(request, batch_size=1, batch_index=0):
     kwargs["return_generation_logits"] = get_input_scalar_by_name(
         request, 'return_generation_logits', batch_size, batch_index)
     kwargs["return_perf_metrics"] = get_input_scalar_by_name(
-        request, 'return_kv_cache_reuse_stats', batch_size, batch_index)
+        request, 'return_perf_metrics', batch_size, batch_index)
+    if get_input_scalar_by_name(request, 'return_kv_cache_reuse_stats',
+                                batch_size, batch_index):
+        pb_utils.Logger.log_warn(
+            "return_kv_cache_reuse_stats is deprecated, please use return_perf_metrics instead."
+        )
     kwargs = {k: v for k, v in kwargs.items() if v is not None}
     return trtllm.OutputConfig(**kwargs)
 
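
Note: after this change, clients opt into the new metrics by sending a return_perf_metrics input; the old flag only triggers a deprecation warning. A minimal client-side sketch, assuming the model's config.pbtxt exposes return_perf_metrics as an optional BOOL input of shape [1, 1] (the shape and endpoint are illustrative, not taken from this diff):

    import numpy as np
    import tritonclient.grpc as grpcclient

    client = grpcclient.InferenceServerClient("localhost:8001")
    flag = grpcclient.InferInput("return_perf_metrics", [1, 1], "BOOL")
    flag.set_data_from_numpy(np.array([[True]], dtype=np.bool_))
    # Append `flag` to the request's other inputs before calling client.infer().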
@@ -427,6 +433,39 @@ def get_tensor_and_check_length(name: str, expected_length: int):
     return None
 
 
+def get_lookahead_decoding_config_from_request(request,
+                                               executor_lookahead_config,
+                                               batch_size=1,
+                                               batch_index=0):
+    lookahead_window_size = get_input_tensor_by_name(request,
+                                                     "lookahead_window_size",
+                                                     batch_size, batch_index)
+
+    lookahead_ngram_size = get_input_tensor_by_name(request,
+                                                    "lookahead_ngram_size",
+                                                    batch_size, batch_index)
+
+    lookahead_verification_set_size = get_input_tensor_by_name(
+        request, "lookahead_verification_set_size", batch_size, batch_index)
+
+    # The request supplied no lookahead tensors; leave lookahead decoding off.
+    if all(x is None for x in [
+            lookahead_window_size, lookahead_ngram_size,
+            lookahead_verification_set_size
+    ]):
+        return None
+
+    # The request asks for lookahead decoding but the executor has no config.
+    if executor_lookahead_config is None:
+        raise RuntimeError(
+            "The request lookahead decoding input tensors (window_size, ngram_size and verification_set_size) can only be set if the model instance lookahead parameters are also specified."
+        )
+
+    return trtllm.LookaheadDecodingConfig(lookahead_window_size,
+                                          lookahead_ngram_size,
+                                          lookahead_verification_set_size)
+
+
 def build_1_2_5_buckets(max_value: int) -> List[int]:
     """
     Builds a list of buckets with increasing powers of 10 multiplied by
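
The new helper reads three optional per-request tensors and only builds a trtllm.LookaheadDecodingConfig when at least one is present, rejecting requests that set them without an executor-level lookahead config. A hedged client-side sketch of supplying the three tensors for one request, assuming INT32 inputs of shape [1, 1] (the actual shapes come from the model's config.pbtxt):

    import numpy as np
    import tritonclient.grpc as grpcclient

    def lookahead_inputs(window=4, ngram=3, verification_set=4):
        inputs = []
        for name, value in (("lookahead_window_size", window),
                            ("lookahead_ngram_size", ngram),
                            ("lookahead_verification_set_size", verification_set)):
            tensor = grpcclient.InferInput(name, [1, 1], "INT32")
            tensor.set_data_from_numpy(np.array([[value]], dtype=np.int32))
            inputs.append(tensor)
        return inputs  # extend the request's input list with these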
@@ -450,7 +489,10 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
         exponent += 1
 
 
-def convert_request(request, exclude_input_from_output, decoupled):
+def convert_request(request,
+                    exclude_input_from_output,
+                    decoupled,
+                    executor_lookahead_config=None):
     inputs = {}
     input_token_ids = get_input_tensor_by_name(request, 'input_ids')
     if input_token_ids is None:
@@ -526,6 +568,8 @@ def convert_request(request, exclude_input_from_output, decoupled):
                                                    batch_index)
         kv_cache_retention_config = get_kv_cache_retention_config_from_request(
             request, batch_size, batch_index)
+        request_lookahead_config = get_lookahead_decoding_config_from_request(
+            request, executor_lookahead_config, batch_size, batch_index)
 
         # Inputs for mllama support
         encoder_input_features = get_input_tensor_by_name(
@@ -579,6 +623,7 @@ def convert_request(request, exclude_input_from_output, decoupled):
                 prompt_tuning_config=prompt_tuning_config,
                 lora_config=lora_config,
                 guided_decoding_params=guided_decoding_params,
+                lookahead_config=request_lookahead_config,
                 kv_cache_retention_config=kv_cache_retention_config))
     return requests
 
@@ -674,6 +719,55 @@ def convert_response(response,
                 np.array([kv_cache_metrics.num_total_allocated_blocks],
                          np.int32), 0)))
 
+        timing_metrics = result.request_perf_metrics.timing_metrics
+        output_tensors.append(
+            pb_utils.Tensor(
+                "arrival_time_ns",
+                np.expand_dims(
+                    np.array([pd.Timedelta(timing_metrics.arrival_time).value],
+                             np.int64), 0)))
+        output_tensors.append(
+            pb_utils.Tensor(
+                "first_scheduled_time_ns",
+                np.expand_dims(
+                    np.array([
+                        pd.Timedelta(timing_metrics.first_scheduled_time).value
+                    ], np.int64), 0)))
+        output_tensors.append(
+            pb_utils.Tensor(
+                "first_token_time_ns",
+                np.expand_dims(
+                    np.array(
+                        [pd.Timedelta(timing_metrics.first_token_time).value],
+                        np.int64), 0)))
+        output_tensors.append(
+            pb_utils.Tensor(
+                "last_token_time_ns",
+                np.expand_dims(
+                    np.array(
+                        [pd.Timedelta(timing_metrics.last_token_time).value],
+                        np.int64), 0)))
+
+        spec_dec_metrics = result.request_perf_metrics.speculative_decoding
+        output_tensors.append(
+            pb_utils.Tensor(
+                "acceptance_rate",
+                np.expand_dims(
+                    np.array([spec_dec_metrics.acceptance_rate], np.float32),
+                    0)))
+        output_tensors.append(
+            pb_utils.Tensor(
+                "total_accepted_draft_tokens",
+                np.expand_dims(
+                    np.array([spec_dec_metrics.total_accepted_draft_tokens],
+                             np.int32), 0)))
+        output_tensors.append(
+            pb_utils.Tensor(
+                "total_draft_tokens",
+                np.expand_dims(
+                    np.array([spec_dec_metrics.total_draft_tokens], np.int32),
+                    0)))
+
     return pb_utils.InferenceResponse(
         output_tensors), result.is_final, output_lengths
 
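
Each timing metric is reported as integer nanoseconds, which is what the new pandas import at the top of the file is for: pd.Timedelta(duration).value is the duration as an int64 nanosecond count. A quick standalone check with an illustrative value:

    import datetime

    import pandas as pd

    delta = datetime.timedelta(seconds=1, microseconds=250)
    # .value converts the duration to int64 nanoseconds
    assert pd.Timedelta(delta).value == 1_000_250_000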
@@ -830,11 +924,48 @@ def get_peft_cache_config(self, model_config):
                           float),
             "host_cache_size":
             get_parameter(model_config, "lora_cache_host_memory_bytes", int),
+            "lora_prefetch_dir":
+            get_parameter(model_config, "lora_prefetch_dir", str),
         }
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
         return trtllm.PeftCacheConfig(**kwargs)
 
+    def get_executor_lookahead_config(self, model_config):
+        lookahead_window_size = get_parameter(model_config,
+                                              "lookahead_window_size", int)
+        lookahead_ngram_size = get_parameter(model_config,
+                                             "lookahead_ngram_size", int)
+        lookahead_verification_set_size = get_parameter(
+            model_config, "lookahead_verification_set_size", int)
+        # No executor-level lookahead parameters were set.
+        if all(item is None for item in [
+                lookahead_window_size, lookahead_ngram_size,
+                lookahead_verification_set_size
+        ]):
+            return None
+
+        incomplete_config = None in [
+            lookahead_window_size, lookahead_ngram_size,
+            lookahead_verification_set_size
+        ]
+
+        assert (
+            not incomplete_config
+        ), "Please set lookahead_window_size, lookahead_ngram_size and lookahead_verification_set_size together."
+
+        return trtllm.LookaheadDecodingConfig(lookahead_window_size,
+                                              lookahead_ngram_size,
+                                              lookahead_verification_set_size)
+
     def get_decoding_config(self, model_config):
+
+        decoding_mode = convert_decoding_mode(
+            get_parameter(model_config, "decoding_mode"))
+        self.executor_lookahead_config = None
+        if decoding_mode == trtllm.DecodingMode.Lookahead():
+            # Build the executor-level lookahead decoding config.
+            self.executor_lookahead_config = self.get_executor_lookahead_config(
+                model_config)
         eagle_choices = parse_eagle_choices(
             get_parameter(model_config, "eagle_choices"))
         kwargs = {
@@ -844,9 +975,10 @@ def get_decoding_config(self, model_config):
             "eagle_config":
             None
             if eagle_choices is None else trtllm.EagleConfig(eagle_choices),
+            "lookahead_decoding_config":
+            self.executor_lookahead_config,
             "decoding_mode":
-            convert_decoding_mode(get_parameter(model_config,
-                                                "decoding_mode")),
+            decoding_mode,
         }
         print(kwargs)
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
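
Executor-level lookahead decoding is thus driven entirely by model parameters: decoding_mode must resolve to lookahead, and the three lookahead_* parameters must be set together. A minimal standalone sketch of that all-or-nothing rule, useful for unit-testing it in isolation (the helper name and types are illustrative, not part of this diff):

    from typing import Optional, Tuple

    def check_lookahead_params(
            window: Optional[int], ngram: Optional[int],
            verification_set: Optional[int]) -> Optional[Tuple[int, int, int]]:
        values = (window, ngram, verification_set)
        if all(v is None for v in values):
            return None  # lookahead decoding disabled
        if any(v is None for v in values):
            raise ValueError(
                "lookahead_window_size, lookahead_ngram_size and "
                "lookahead_verification_set_size must be set together")
        return values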
@@ -1232,7 +1364,7 @@ def execute(self, requests):
                 try:
                     converted_reqs = convert_request(
                         request, self.exclude_input_from_output,
-                        self.decoupled)
+                        self.decoupled, self.executor_lookahead_config)
                 except Exception as e:
                     response_sender.send(
                         pb_utils.InferenceResponse(error=pb_utils.TritonError(