
Commit 08c87be

update templates

1 parent 16340fb commit 08c87be

3 files changed: +261 −8 lines changed


src/triton_cli/templates/trt_llm/tensorrt_llm/1/model.py

Lines changed: 137 additions & 5 deletions
@@ -9,6 +9,7 @@
 from typing import Any, List
 
 import numpy as np
+import pandas as pd
 import torch
 import triton_python_backend_utils as pb_utils
 from torch import from_numpy
@@ -239,7 +240,12 @@ def get_output_config_from_request(request, batch_size=1, batch_index=0):
     kwargs["return_generation_logits"] = get_input_scalar_by_name(
         request, 'return_generation_logits', batch_size, batch_index)
     kwargs["return_perf_metrics"] = get_input_scalar_by_name(
-        request, 'return_kv_cache_reuse_stats', batch_size, batch_index)
+        request, 'return_perf_metrics', batch_size, batch_index)
+    if get_input_scalar_by_name(request, 'return_kv_cache_reuse_stats',
+                                batch_size, batch_index):
+        pb_utils.Logger.log_warn(
+            "return_kv_cache_reuse_stats is deprecated, please use return_perf_metrics instead."
+        )
     kwargs = {k: v for k, v in kwargs.items() if v is not None}
     return trtllm.OutputConfig(**kwargs)
 
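For reference, a minimal client-side sketch of opting into the renamed flag. This is not part of the commit; it assumes the tritonclient gRPC API against a server at localhost:8001 serving this model:

import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient("localhost:8001")

# Ask for per-request performance metrics via the new flag; the old
# return_kv_cache_reuse_stats input now only triggers a deprecation warning.
flag = grpcclient.InferInput("return_perf_metrics", [1, 1], "BOOL")
flag.set_data_from_numpy(np.array([[True]], dtype=bool))
# ...add the model's required inputs, then call
# client.infer("tensorrt_llm", [flag, ...]).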
@@ -427,6 +433,39 @@ def get_tensor_and_check_length(name: str, expected_length: int):
     return None
 
 
+def get_lookahead_decoding_config_from_request(request,
+                                               executor_lookahead_config,
+                                               batch_size=1,
+                                               batch_index=0):
+    lookahead_window_size = get_input_tensor_by_name(request,
+                                                     "lookahead_window_size",
+                                                     batch_size, batch_index)
+
+    lookahead_ngram_size = get_input_tensor_by_name(request,
+                                                    "lookahead_ngram_size",
+                                                    batch_size, batch_index)
+
+    lookahead_verification_set_size = get_input_tensor_by_name(
+        request, "lookahead_verification_set_size", batch_size, batch_index)
+
+    # No lookahead config was given for this request.
+    if all(x is None for x in [
+            lookahead_window_size, lookahead_ngram_size,
+            lookahead_verification_set_size
+    ]):
+        return None
+
+    # The request has a lookahead config but the executor does not.
+    if executor_lookahead_config is None:
+        raise RuntimeError(
+            "The request lookahead decoding input tensors (window_size, ngram_size and verification_set_size) can only be set if the model instance lookahead parameters are also specified"
+        )
+
+    return trtllm.LookaheadDecodingConfig(lookahead_window_size,
+                                          lookahead_ngram_size,
+                                          lookahead_verification_set_size)
+
+
 def build_1_2_5_buckets(max_value: int) -> List[int]:
     """
     Builds a list of buckets with increasing powers of 10 multiplied by
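The rule encoded above is all-or-nothing: a request that sets none of the three tensors inherits no per-request lookahead config, while a request that sets any of them requires the model instance to have been started with executor-level lookahead parameters. A standalone sketch of the same rule, with hypothetical values:

def resolve_request_lookahead(window, ngram, verification_set, executor_cfg):
    # No per-request lookahead tensors: nothing to configure.
    if all(v is None for v in (window, ngram, verification_set)):
        return None
    # Request-level lookahead needs executor-level lookahead to exist.
    if executor_cfg is None:
        raise RuntimeError("model instance lookahead parameters must be set")
    return (window, ngram, verification_set)

assert resolve_request_lookahead(None, None, None, None) is None
assert resolve_request_lookahead(7, 5, 3, executor_cfg=(7, 5, 3)) == (7, 5, 3)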
@@ -450,7 +489,10 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
         exponent += 1
 
 
-def convert_request(request, exclude_input_from_output, decoupled):
+def convert_request(request,
+                    exclude_input_from_output,
+                    decoupled,
+                    executor_lookahead_config=None):
     inputs = {}
     input_token_ids = get_input_tensor_by_name(request, 'input_ids')
     if input_token_ids is None:
@@ -526,6 +568,8 @@ def convert_request(request, exclude_input_from_output, decoupled):
                                                  batch_index)
         kv_cache_retention_config = get_kv_cache_retention_config_from_request(
             request, batch_size, batch_index)
+        request_lookahead_config = get_lookahead_decoding_config_from_request(
+            request, executor_lookahead_config, batch_size, batch_index)
 
         # Inputs for mllama support
         encoder_input_features = get_input_tensor_by_name(
@@ -579,6 +623,7 @@ def convert_request(request, exclude_input_from_output, decoupled):
                 prompt_tuning_config=prompt_tuning_config,
                 lora_config=lora_config,
                 guided_decoding_params=guided_decoding_params,
+                lookahead_config=request_lookahead_config,
                 kv_cache_retention_config=kv_cache_retention_config))
     return requests
 
@@ -674,6 +719,55 @@ def convert_response(response,
                     np.array([kv_cache_metrics.num_total_allocated_blocks],
                              np.int32), 0)))
 
+        timing_metrics = result.request_perf_metrics.timing_metrics
+        output_tensors.append(
+            pb_utils.Tensor(
+                "arrival_time_ns",
+                np.expand_dims(
+                    np.array([pd.Timedelta(timing_metrics.arrival_time).value],
+                             np.int64), 0)))
+        output_tensors.append(
+            pb_utils.Tensor(
+                "first_scheduled_time_ns",
+                np.expand_dims(
+                    np.array([
+                        pd.Timedelta(timing_metrics.first_scheduled_time).value
+                    ], np.int64), 0)))
+        output_tensors.append(
+            pb_utils.Tensor(
+                "first_token_time_ns",
+                np.expand_dims(
+                    np.array(
+                        [pd.Timedelta(timing_metrics.first_token_time).value],
+                        np.int64), 0)))
+        output_tensors.append(
+            pb_utils.Tensor(
+                "last_token_time_ns",
+                np.expand_dims(
+                    np.array(
+                        [pd.Timedelta(timing_metrics.last_token_time).value],
+                        np.int64), 0)))
+
+        spec_dec_metrics = result.request_perf_metrics.speculative_decoding
+        output_tensors.append(
+            pb_utils.Tensor(
+                "acceptance_rate",
+                np.expand_dims(
+                    np.array([spec_dec_metrics.acceptance_rate], np.float32),
+                    0)))
+        output_tensors.append(
+            pb_utils.Tensor(
+                "total_accepted_draft_tokens",
+                np.expand_dims(
+                    np.array([spec_dec_metrics.total_accepted_draft_tokens],
+                             np.int32), 0)))
+        output_tensors.append(
+            pb_utils.Tensor(
+                "total_draft_tokens",
+                np.expand_dims(
+                    np.array([spec_dec_metrics.total_draft_tokens], np.int32),
+                    0)))
+
     return pb_utils.InferenceResponse(
         output_tensors), result.is_final, output_lengths
 
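The pd.Timedelta(...).value idiom is what motivates the new pandas import at the top of the file: it converts the executor's timing metrics into integer nanoseconds for the TYPE_INT64 output tensors. A quick illustration with a made-up duration:

import datetime

import pandas as pd

# .value yields the duration as an integer number of nanoseconds.
delta = datetime.timedelta(seconds=1, milliseconds=250)
assert pd.Timedelta(delta).value == 1_250_000_000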
@@ -830,11 +924,48 @@ def get_peft_cache_config(self, model_config):
                           float),
             "host_cache_size":
             get_parameter(model_config, "lora_cache_host_memory_bytes", int),
+            "lora_prefetch_dir":
+            get_parameter(model_config, "lora_prefetch_dir", str),
         }
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
         return trtllm.PeftCacheConfig(**kwargs)
 
+    def get_executor_lookahead_config(self, model_config):
+        lookahead_window_size = get_parameter(model_config,
+                                              "lookahead_window_size", int)
+        lookahead_ngram_size = get_parameter(model_config,
+                                             "lookahead_ngram_size", int)
+        lookahead_verification_set_size = get_parameter(
+            model_config, "lookahead_verification_set_size", int)
+        # executor_lookahead_config is not set at all.
+        if all(item is None for item in [
+                lookahead_window_size, lookahead_ngram_size,
+                lookahead_verification_set_size
+        ]):
+            return None
+
+        incomplete_config = None in [
+            lookahead_window_size, lookahead_ngram_size,
+            lookahead_verification_set_size
+        ]
+
+        assert (
+            not incomplete_config
+        ), "Please set lookahead_window_size, lookahead_ngram_size and lookahead_verification_set_size together."
+
+        return trtllm.LookaheadDecodingConfig(lookahead_window_size,
+                                              lookahead_ngram_size,
+                                              lookahead_verification_set_size)
+
     def get_decoding_config(self, model_config):
+
+        decoding_mode = convert_decoding_mode(
+            get_parameter(model_config, "decoding_mode"))
+        self.executor_lookahead_config = None
+        if decoding_mode == trtllm.DecodingMode.Lookahead():
+            # Add the lookahead decoding config.
+            self.executor_lookahead_config = self.get_executor_lookahead_config(
+                model_config)
         eagle_choices = parse_eagle_choices(
             get_parameter(model_config, "eagle_choices"))
         kwargs = {
@@ -844,9 +975,10 @@ def get_decoding_config(self, model_config):
             "eagle_config":
             None
             if eagle_choices is None else trtllm.EagleConfig(eagle_choices),
+            "lookahead_decoding_config":
+            self.executor_lookahead_config,
             "decoding_mode":
-            convert_decoding_mode(get_parameter(model_config,
-                                                "decoding_mode")),
+            decoding_mode,
         }
         print(kwargs)
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
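As elsewhere in this file, the kwargs dict is pruned of None entries before being handed to the executor bindings, so an unset lookahead_decoding_config falls back to the TensorRT-LLM default rather than being passed explicitly:

# Unset options are dropped so constructor defaults apply downstream.
kwargs = {"decoding_mode": None, "lookahead_decoding_config": None}
kwargs = {k: v for k, v in kwargs.items() if v is not None}
assert kwargs == {}  # nothing overridden; defaults win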
@@ -1232,7 +1364,7 @@ def execute(self, requests):
             try:
                 converted_reqs = convert_request(
                     request, self.exclude_input_from_output,
-                    self.decoupled)
+                    self.decoupled, self.executor_lookahead_config)
             except Exception as e:
                 response_sender.send(
                     pb_utils.InferenceResponse(error=pb_utils.TritonError(

src/triton_cli/templates/trt_llm/tensorrt_llm/config.pbtxt

Lines changed: 87 additions & 1 deletion
@@ -276,7 +276,7 @@ input [
     optional: true
   },
   {
-    name: "return_kv_cache_reuse_stats"
+    name: "return_perf_metrics"
     data_type: TYPE_BOOL
     dims: [ 1 ]
     reshape: { shape: [ ] }
@@ -444,6 +444,27 @@ input [
     dims: [ 1 ]
     optional: true
     allow_ragged_batch: true
+  },
+  {
+    name: "lookahead_window_size"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "lookahead_ngram_size"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "lookahead_verification_set_size"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    optional: true
+    allow_ragged_batch: true
   }
 ]
 output [
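With these inputs declared, a client can opt into per-request lookahead by sending the three optional INT32 tensors. A minimal tritonclient sketch (hypothetical values; it assumes the model instance was launched with executor-level lookahead enabled, since the backend rejects request-level lookahead otherwise):

import numpy as np
import tritonclient.grpc as grpcclient

# Hypothetical per-request lookahead shape (window, ngram, verification set).
lookahead = {
    "lookahead_window_size": 7,
    "lookahead_ngram_size": 5,
    "lookahead_verification_set_size": 3,
}
inputs = []
for name, value in lookahead.items():
    tensor = grpcclient.InferInput(name, [1, 1], "INT32")
    tensor.set_data_from_numpy(np.array([[value]], dtype=np.int32))
    inputs.append(tensor)
# ...append the usual required inputs (input_ids, request_output_len, ...)
# before calling client.infer("tensorrt_llm", inputs).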
@@ -506,6 +527,41 @@ output [
     name: "kv_cache_alloc_total_blocks"
     data_type: TYPE_INT32
     dims: [ 1 ]
+  },
+  {
+    name: "arrival_time_ns"
+    data_type: TYPE_INT64
+    dims: [ 1 ]
+  },
+  {
+    name: "first_scheduled_time_ns"
+    data_type: TYPE_INT64
+    dims: [ 1 ]
+  },
+  {
+    name: "first_token_time_ns"
+    data_type: TYPE_INT64
+    dims: [ 1 ]
+  },
+  {
+    name: "last_token_time_ns"
+    data_type: TYPE_INT64
+    dims: [ 1 ]
+  },
+  {
+    name: "acceptance_rate"
+    data_type: TYPE_FP32
+    dims: [ 1 ]
+  },
+  {
+    name: "total_accepted_draft_tokens"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  },
+  {
+    name: "total_draft_tokens"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
   }
 ]
 instance_group [
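On the client side, the new outputs come back on responses whose request set return_perf_metrics. A sketch assuming result is a tritonclient InferResult:

# Nanosecond timestamps from the executor's timing metrics.
arrival_ns = result.as_numpy("arrival_time_ns")
first_token_ns = result.as_numpy("first_token_time_ns")
print("TTFT (ms):", (first_token_ns - arrival_ns) / 1e6)

# Speculative decoding counters (meaningful when a draft model is active).
print("acceptance rate:", result.as_numpy("acceptance_rate"))
print("draft tokens accepted/total:",
      result.as_numpy("total_accepted_draft_tokens"),
      result.as_numpy("total_draft_tokens"))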
@@ -684,6 +740,12 @@ parameters: {
     string_value: "${lora_cache_host_memory_bytes}"
   }
 }
+parameters: {
+  key: "lora_prefetch_dir"
+  value: {
+    string_value: "${lora_prefetch_dir}"
+  }
+}
 parameters: {
   key: "decoding_mode"
   value: {
@@ -696,6 +758,24 @@ parameters: {
     string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
   }
 }
+parameters: {
+  key: "lookahead_window_size"
+  value: {
+    string_value: "${lookahead_window_size}"
+  }
+}
+parameters: {
+  key: "lookahead_ngram_size"
+  value: {
+    string_value: "${lookahead_ngram_size}"
+  }
+}
+parameters: {
+  key: "lookahead_verification_set_size"
+  value: {
+    string_value: "${lookahead_verification_set_size}"
+  }
+}
 parameters: {
   key: "medusa_choices"
   value: {
@@ -756,3 +836,9 @@ parameters: {
     string_value: "${guided_decoding_backend}"
   }
 }
+parameters: {
+  key: "xgrammar_tokenizer_info_path"
+  value: {
+    string_value: "${xgrammar_tokenizer_info_path}"
+  }
+}
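For a deployment that enables lookahead decoding, the relevant placeholders would be substituted with concrete values at template-fill time. An illustrative (not prescriptive) result, assuming convert_decoding_mode accepts a "lookahead" decoding_mode string:

parameters: {
  key: "decoding_mode"
  value: {
    string_value: "lookahead"
  }
}
parameters: {
  key: "lookahead_window_size"
  value: {
    string_value: "7"
  }
}
parameters: {
  key: "lookahead_ngram_size"
  value: {
    string_value: "5"
  }
}
parameters: {
  key: "lookahead_verification_set_size"
  value: {
    string_value: "3"
  }
}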
