Add TRT-LLM params like max_num_tokens and opt_num_tokens (NVIDIA#9210)
* Add params like max_num_tokens and opt_num_tokens

Signed-off-by: Onur Yilmaz <[email protected]>

* Remove the previously added padding param

* Update params like max_num_tokens

Signed-off-by: Onur Yilmaz <[email protected]>

* Apply isort and black reformatting

Signed-off-by: oyilmaz-nvidia <[email protected]>

* Remove the context_fmha param for now

Signed-off-by: Onur Yilmaz <[email protected]>

* Add params like max_num_tokens to the script

Signed-off-by: Onur Yilmaz <[email protected]>

* Apply isort and black reformatting

Signed-off-by: oyilmaz-nvidia <[email protected]>

---------

Signed-off-by: Onur Yilmaz <[email protected]>
Signed-off-by: oyilmaz-nvidia <[email protected]>
Co-authored-by: oyilmaz-nvidia <[email protected]>
Co-authored-by: Pablo Garay <[email protected]>
3 people authored May 21, 2024
1 parent a69ace4 commit 2e1814c
Showing 5 changed files with 100 additions and 23 deletions.
21 changes: 16 additions & 5 deletions nemo/export/tensorrt_llm.py
@@ -117,15 +117,16 @@ def export(
max_batch_size: int = 8,
max_prompt_embedding_table_size=None,
use_parallel_embedding: bool = False,
use_inflight_batching: bool = False,
enable_context_fmha: bool = True,
paged_kv_cache: bool = False,
paged_kv_cache: bool = True,
remove_input_padding: bool = True,
dtype: str = "bfloat16",
load_model: bool = True,
enable_multi_block_mode: bool = False,
use_lora_plugin: str = None,
lora_target_modules: List[str] = None,
max_lora_rank: int = 64,
max_num_tokens: int = None,
opt_num_tokens: int = None,
save_nemo_model_config: bool = False,
):
"""
@@ -142,12 +143,18 @@ def export(
max_output_token (int): max output length.
max_batch_size (int): max batch size.
max_prompt_embedding_table_size (int): max prompt embedding size.
use_inflight_batching (bool): if True, enables inflight batching for TensorRT-LLM Triton backend.
enable_context_fmha (bool): if True, use fused Context MultiHeadedAttention.
use_parallel_embedding (bool): whether to use parallel embedding feature of TRT-LLM or not
paged_kv_cache (bool): if True, uses kv cache feature of the TensorRT-LLM.
remove_input_padding (bool): enables removing input padding or not.
dtype (str): Floating point type for model weights (Supports BFloat16/Float16).
load_model (bool): load TensorRT-LLM model after the export.
enable_multi_block_mode (bool): enable faster decoding in multihead attention. Required for long context.
use_lora_plugin (str): use dynamic lora or not.
lora_target_modules (List[str]): list of the target lora modules.
max_lora_rank (int): maximum lora rank.
max_num_tokens (int): max number of tokens the engine processes in a batch; when None, TensorRT-LLM derives it from max_batch_size and max_input_len.
opt_num_tokens (int): number of tokens the engine is optimized for; when None, TensorRT-LLM derives it from max_batch_size and max_beam_width.
save_nemo_model_config (bool): if True, saves the NeMo model config along with the exported model.
"""

if model_type not in self.get_supported_models_list:
@@ -238,6 +245,10 @@ def export(
lora_target_modules=lora_target_modules,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
enable_multi_block_mode=enable_multi_block_mode,
paged_kv_cache=paged_kv_cache,
remove_input_padding=remove_input_padding,
max_num_tokens=max_num_tokens,
opt_num_tokens=opt_num_tokens,
)

tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model")
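For reference, a minimal sketch of an export call that exercises the new parameters (the checkpoint path, model type, and sizes are illustrative assumptions, not values from this commit):

    from nemo.export import TensorRTLLM

    # Sketch only: path, model type, and sizes below are assumed for illustration.
    exporter = TensorRTLLM(model_dir="/tmp/trt_llm_engine")
    exporter.export(
        nemo_checkpoint_path="/tmp/model.nemo",
        model_type="llama",
        max_input_token=1024,
        max_output_token=256,
        max_batch_size=8,
        max_num_tokens=4096,  # upper bound on tokens processed per engine step
        opt_num_tokens=2048,  # token count the engine is tuned for
        paged_kv_cache=True,
        remove_input_padding=True,
    )

Leaving max_num_tokens and opt_num_tokens as None (their defaults) defers the choice to TensorRT-LLM's check_max_num_tokens, imported in the next file.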
28 changes: 25 additions & 3 deletions nemo/export/trt_llm/tensorrt_llm_build.py
@@ -24,6 +24,7 @@
import tensorrt_llm
import torch
from tensorrt_llm import str_dtype_to_trt
from tensorrt_llm._common import check_max_num_tokens
from tensorrt_llm._utils import np_dtype_to_trt
from tensorrt_llm.builder import BuildConfig, Builder
from tensorrt_llm.commands.build import build as build_trtllm
@@ -371,6 +372,12 @@ def build_and_save_engine(
lora_target_modules=None,
max_prompt_embedding_table_size=0,
enable_multi_block_mode: bool = False,
paged_kv_cache: bool = True,
remove_input_padding: bool = True,
max_num_tokens: int = None,
opt_num_tokens: int = None,
max_beam_width: int = 1,
tokens_per_block: int = 128,
):
try:
model_cls = getattr(tensorrt_llm.models, model_config.architecture)
@@ -383,15 +390,30 @@
plugin_config.set_gpt_attention_plugin(dtype=str_dtype)
plugin_config.set_gemm_plugin(dtype=str_dtype)
plugin_config.set_plugin("multi_block_mode", enable_multi_block_mode)
max_num_tokens = max_batch_size * max_input_len
if paged_kv_cache:
plugin_config.enable_paged_kv_cache(tokens_per_block=tokens_per_block)
else:
plugin_config.paged_kv_cache = False
plugin_config.remove_input_padding = remove_input_padding

max_num_tokens, opt_num_tokens = check_max_num_tokens(
max_num_tokens=max_num_tokens,
opt_num_tokens=opt_num_tokens,
max_batch_size=max_batch_size,
max_input_len=max_input_len,
max_beam_width=max_beam_width,
remove_input_padding=remove_input_padding,
enable_context_fmha=plugin_config.context_fmha,
tokens_per_block=tokens_per_block,
)

build_dict = {
'max_input_len': max_input_len,
'max_output_len': max_output_len,
'max_batch_size': max_batch_size,
'max_beam_width': 1,
'max_beam_width': max_beam_width,
'max_num_tokens': max_num_tokens,
'opt_num_tokens': None,
'opt_num_tokens': opt_num_tokens,
'max_prompt_embedding_table_size': max_prompt_embedding_table_size,
'gather_context_logits': False,
'gather_generation_logits': False,
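For intuition, the defaulting that check_max_num_tokens performs can be approximated by the sketch below. This is a hedged reconstruction of TensorRT-LLM's internal logic, not a copy of it; the exact rules, caps, and warnings live in tensorrt_llm._common and may differ.

    # Rough, assumed approximation of tensorrt_llm._common.check_max_num_tokens:
    # fill in defaults when the caller passes None, then keep the pair consistent.
    def resolve_num_tokens(max_num_tokens, opt_num_tokens, max_batch_size,
                           max_input_len, max_beam_width, remove_input_padding):
        if max_num_tokens is None or not remove_input_padding:
            # With padding kept (or no explicit cap), the token budget is fixed
            # by the worst-case batch shape.
            max_num_tokens = max_batch_size * max_input_len
        if opt_num_tokens is None:
            # Favor the decode phase: one token per beam per sequence in flight.
            opt_num_tokens = min(max_batch_size * max_beam_width, max_num_tokens)
        return max_num_tokens, opt_num_tokens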
53 changes: 45 additions & 8 deletions scripts/deploy/nlp/deploy_triton.py
@@ -27,7 +27,8 @@

def get_args(argv):
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=f"Deploy nemo models to Triton",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description=f"Deploy nemo models to Triton",
)
parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source .nemo file")
parser.add_argument(
@@ -73,18 +74,20 @@ def get_args(argv):
parser.add_argument("-mil", "--max_input_len", default=256, type=int, help="Max input length of the model")
parser.add_argument("-mol", "--max_output_len", default=256, type=int, help="Max output length of the model")
parser.add_argument("-mbs", "--max_batch_size", default=8, type=int, help="Max batch size of the model")
parser.add_argument("-mnt", "--max_num_tokens", default=None, type=int, help="Max number of tokens")
parser.add_argument("-ont", "--opt_num_tokens", default=None, type=int, help="Optimum number of tokens")
parser.add_argument(
"-mpet", "--max_prompt_embedding_table_size", default=None, type=int, help="Max prompt embedding table size"
)
parser.add_argument(
"-upkc", "--use_paged_kv_cache", default=False, action='store_true', help="Enable paged kv cache."
)
parser.add_argument(
"-dcf",
"--disable_context_fmha",
"-drip",
"--disable_remove_input_padding",
default=False,
action='store_true',
help="Disable fused Context MultiHeadedAttention (required for V100 support).",
help="Disables the remove input padding option.",
)
parser.add_argument(
"-mbm",
@@ -101,15 +104,23 @@
'--use_lora_plugin',
nargs='?',
const=None,
default=False,
choices=['float16', 'float32', 'bfloat16'],
help="Activates the lora plugin which enables embedding sharing.",
)
parser.add_argument(
'--lora_target_modules',
nargs='+',
default=None,
choices=["attn_qkv", "attn_q", "attn_k", "attn_v", "attn_dense", "mlp_h_to_4h", "mlp_gate", "mlp_4h_to_h",],
choices=[
"attn_qkv",
"attn_q",
"attn_k",
"attn_v",
"attn_dense",
"mlp_h_to_4h",
"mlp_gate",
"mlp_4h_to_h",
],
help="Add lora in which modules. Only be activated when use_lora_plugin is enabled.",
)
parser.add_argument(
@@ -198,6 +209,29 @@ def nemo_deploy(argv):
trt_llm_exporter = TensorRTLLM(model_dir=trt_llm_path, lora_ckpt_list=args.lora_ckpt)

if args.nemo_checkpoint is not None:

trt_llm_exporter.export(
nemo_checkpoint_path=args.nemo_checkpoint,
model_type=args.model_type,
n_gpus=args.num_gpus,
tensor_parallel_size=args.num_gpus,
pipeline_parallel_size=1,
max_input_token=args.max_input_len,
max_output_token=args.max_output_len,
max_batch_size=args.max_batch_size,
max_num_tokens=args.max_num_tokens,
opt_num_tokens=args.opt_num_tokens,
max_prompt_embedding_table_size=args.max_prompt_embedding_table_size,
paged_kv_cache=args.use_paged_kv_cache,
remove_input_padding=(not args.disable_remove_input_padding),
dtype=args.dtype,
enable_multi_block_mode=args.multi_block_mode,
use_lora_plugin=args.use_lora_plugin,
lora_target_modules=args.lora_target_modules,
max_lora_rank=args.max_lora_rank,
save_nemo_model_config=True,
)

try:
LOGGER.info("Export operation will be started to export the nemo checkpoint to TensorRT-LLM.")
trt_llm_exporter.export(
@@ -209,9 +243,11 @@
max_input_token=args.max_input_len,
max_output_token=args.max_output_len,
max_batch_size=args.max_batch_size,
max_num_tokens=args.max_num_tokens,
opt_num_tokens=args.opt_num_tokens,
max_prompt_embedding_table_size=args.max_prompt_embedding_table_size,
paged_kv_cache=args.use_paged_kv_cache,
enable_context_fmha=not args.disable_context_fmha,
remove_input_padding=(not args.disable_remove_input_padding),
dtype=args.dtype,
enable_multi_block_mode=args.multi_block_mode,
use_lora_plugin=args.use_lora_plugin,
@@ -236,7 +272,8 @@
)
)
trt_llm_exporter.add_prompt_table(
task_name=str(task_id), prompt_embeddings_checkpoint_path=prompt_embeddings_checkpoint_path,
task_name=str(task_id),
prompt_embeddings_checkpoint_path=prompt_embeddings_checkpoint_path,
)
except Exception as error:
LOGGER.error("An error has occurred during adding the prompt embedding table(s). Error message: " + str(error))
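With these flags in place, an illustrative invocation (sizes assumed, remaining required arguments omitted) is: python scripts/deploy/nlp/deploy_triton.py -nc /tmp/model.nemo -mil 1024 -mol 256 -mbs 8 -mnt 4096 -ont 2048 -upkc. Note the flag polarity: input-padding removal is on by default, and -drip/--disable_remove_input_padding turns it off, which the script forwards as remove_input_padding=(not args.disable_remove_input_padding).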
19 changes: 12 additions & 7 deletions scripts/export/export_to_trt_llm.py
@@ -53,18 +53,20 @@ def get_args(argv):
parser.add_argument("-mil", "--max_input_len", default=256, type=int, help="Max input length of the model")
parser.add_argument("-mol", "--max_output_len", default=256, type=int, help="Max output length of the model")
parser.add_argument("-mbs", "--max_batch_size", default=8, type=int, help="Max batch size of the model")
parser.add_argument("-mnt", "--max_num_tokens", default=None, type=int, help="Max number of tokens")
parser.add_argument("-ont", "--opt_num_tokens", default=None, type=int, help="Optimum number of tokens")
parser.add_argument(
"-mpet", "--max_prompt_embedding_table_size", default=None, type=int, help="Max prompt embedding table size"
)
parser.add_argument(
"-uib",
"--use_inflight_batching",
default=False,
action='store_true',
help="Enable inflight batching for TensorRT-LLM Triton backend.",
"-upkc", "--use_paged_kv_cache", default=False, action='store_true', help="Enable paged kv cache."
)
parser.add_argument(
"-upkc", "--use_paged_kv_cache", default=False, action='store_true', help="Enable paged kv cache."
"-drip",
"--disable_remove_input_padding",
default=False,
action='store_true',
help="Disables the remove input padding option.",
)
parser.add_argument(
"-mbm",
@@ -141,9 +143,12 @@ def nemo_export_trt_llm(argv):
max_input_token=args.max_input_len,
max_output_token=args.max_output_len,
max_batch_size=args.max_batch_size,
max_num_tokens=args.max_num_tokens,
opt_num_tokens=args.opt_num_tokens,
max_prompt_embedding_table_size=args.max_prompt_embedding_table_size,
use_inflight_batching=args.use_inflight_batching,
paged_kv_cache=args.use_paged_kv_cache,
remove_input_padding=(not args.disable_remove_input_padding),
dtype=args.dtype,
enable_multi_block_mode=args.multi_block_mode,
use_lora_plugin=args.use_lora_plugin,
lora_target_modules=args.lora_target_modules,
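Note that this script's -uib/--use_inflight_batching switch is gone: in TensorRT-LLM, in-flight batching needs a paged KV cache and packed (padding-removed) inputs, so the coarse switch is superseded by the two finer-grained flags, with the paged KV cache opt-in via -upkc here and input-padding removal on unless -drip is passed.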
2 changes: 2 additions & 0 deletions tests/export/test_nemo_export.py
@@ -214,6 +214,8 @@ def run_trt_llm_inference(
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
use_lora_plugin=use_lora_plugin,
lora_target_modules=lora_target_modules,
max_num_tokens=int(max_input_token * max_batch_size * 0.2),
opt_num_tokens=60,
save_nemo_model_config=True,
)

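The test derives its token budget from a simple heuristic rather than a fixed number. With sizes matching the scripts' defaults (values assumed here for illustration), the arithmetic works out as:

    # Worked example of the test's heuristic; input values are assumed.
    max_input_token = 256
    max_batch_size = 8
    max_num_tokens = int(max_input_token * max_batch_size * 0.2)  # int(409.6) == 409
    opt_num_tokens = 60  # small fixed value the engine is tuned for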
