
chore: [DEMONSTRATION ONLY] 1st Mass integration of release/0.19 #3850

Closed
wants to merge 39 commits into from

39 commits
3471d6c
chore: bump version to 0.19.0 (#3598)
ZhanruiSunCh Apr 16, 2025
715428c
test: add test cases for 0.19 release (#3608)
crazydemo Apr 16, 2025
e36092b
squash (#3642)
syuoni Apr 17, 2025
5cc1d38
fix: nvbugs/5187237: fix deterministic mode crash (#3448)
VALLIS-NERIA Apr 17, 2025
458203d
update fp8 doc (#3647)
litaotju Apr 17, 2025
b1a65c0
tests: change qa perf test to trtllm-bench (#3619)
ruodil Apr 17, 2025
c8cea30
fix: FP8 quantized lm_head (NvBug 5214229) (#3567)
syuoni Apr 17, 2025
56c9dd4
infra: Add PR approval protection for the release branch (#3634)
chzblych Apr 18, 2025
5bf8fdc
fix: nvbugs/5231298: pytorch allreduce issue (#3673)
VALLIS-NERIA Apr 18, 2025
1c6e85b
Fix: nvbugs/5222698 variable not defined (#3630)
zongfeijing Apr 18, 2025
07688cd
test:sync waives.txt from main branch by disabling test_perf/gpt_350m…
nv-guomingz Apr 18, 2025
c70b24c
test:restore fp8 kv cache testing for L0 (#3671)
nv-guomingz Apr 19, 2025
422c1b3
doc: Update DeepSeek perf docs (#3693)
kaiyux Apr 19, 2025
fb8ddfa
tests: waive test_llm_multi_node (#3664)
QiJune Apr 17, 2025
a04b585
fix: update test_user_buffers_mm_add_prologue atol (#3711)
liji-nv Apr 21, 2025
b3ce638
Fix: cherry-pick hmac encryption from main branch (#3635)
yibinl-nvidia Apr 21, 2025
8f17f3f
Un-waive DS-V3-Lite tests. (#3621)
Tracin Apr 21, 2025
8a8a55a
fix: FP8 kv accuracy (#3675)
DylanChen-NV Apr 21, 2025
81d1f4f
Fix script options for engines. (#3622)
Tracin Apr 21, 2025
e69d7bb
unwaive multi-node test (#3721)
Superjomn Apr 21, 2025
e19309c
chore : Split more tests out of gpt tests (#3524) (#3674)
peaceh-nv Apr 22, 2025
ba15155
doc:add torch examples link into torch backend documentation (#3749)
nv-guomingz Apr 22, 2025
611ef8e
test: Get Eagle tests working (#3593) (#3722)
yweng0828 Apr 22, 2025
52e6702
Waive L0 test (#3756)
yiqingy0 Apr 22, 2025
793d010
waive failed case in perf test, change default max_batch_size to 512 …
ruodil Apr 22, 2025
792b71f
Update ds v3 parameters in stress test. (#3676)
dominicshanshan Apr 22, 2025
b11cb2f
waive gemma on L20 (#3766)
crazydemo Apr 22, 2025
a824946
https://nvbugs/5141291: Fix convert.py script for Qwen model. (#3758)
hyukn Apr 22, 2025
851e2f5
fix: PP4 fixes and cleanup (#3688)
amukkara Apr 23, 2025
3e56e40
remove benchmark test list (#3643)
crazydemo Apr 23, 2025
f08e599
skip disagg deepseek test if sm!=90 (#3720)
chuangz0 Apr 23, 2025
b0ac7c9
test: skip failed cases on B200 (#3710)
xinhe-nv Apr 23, 2025
f056d44
test: [nvbug: 5234494] skip_pre_ada for fp8 cases (#3718)
crazydemo Apr 23, 2025
2f02263
add know issue to deepseek doc. (#3800)
lfr-0531 Apr 23, 2025
41b0371
Fix ModelOpt Mixtral AWQ OOM (#3714) (#3761)
Barry-Delaney Apr 23, 2025
e0691e6
Waive L0 tests (#3826)
yiqingy0 Apr 24, 2025
ab2f663
fix: Reduce memory usage in fused moe op associated with AutoTuning a…
hyukn Apr 24, 2025
33c4d49
[doc] Better document for Draft-Target-Model (DTM) speculative decodi…
wili-65535 Apr 24, 2025
19bd6f8
chore: [DEMONSTRATION ONLY] 1st Mass integration of release/0.19
tongyuantongyu Apr 25, 2025
5 changes: 5 additions & 0 deletions .github/CODEOWNERS
@@ -0,0 +1,5 @@
# This file defines code ownership rules for the repository.
# The rule below requires that any PR to release/**/* branches must be approved by at least one member
# of the NVIDIA/trt-llm-release-branch-approval team, regardless of who else approves the PR.
# Without approval from a member of this team, PRs cannot be merged to release branches.
* @NVIDIA/trt-llm-release-branch-approval
13 changes: 9 additions & 4 deletions cpp/tensorrt_llm/common/attentionOp.cpp
@@ -915,7 +915,8 @@ int AttentionOp::mlaGeneration(
params.quant_scale_kv = generation_params.kv_scale_orig_quant;
params.dequant_scale_q = generation_params.kv_scale_quant_orig;
params.dequant_scale_kv = generation_params.kv_scale_quant_orig;
params.host_bmm1_scale = 1 / (sqrt((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim)));
params.host_bmm1_scale
= 1 / (mQScaling * sqrt((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim)));

invokeMLARopeGeneration<T>(params, kv_cache_buffer, stream);
sync_check_cuda_error(stream);
@@ -1001,9 +1002,13 @@ int AttentionOp::mlaGeneration(
tllmRunnerParams.mSfStartTokenIdx = generation_params.start_token_idx_sf;

// Scales for quantization
static constexpr int bmm1_scale_offset = 1;
tllmRunnerParams.outputScalePtr = reinterpret_cast<float const*>(params.bmm2_scale);
tllmRunnerParams.scaleSoftmaxLog2Ptr = reinterpret_cast<float const*>(params.bmm1_scale) + bmm1_scale_offset;
if (mFP8GenerationMLA)
{
static constexpr int bmm1_scale_offset = 1;
tllmRunnerParams.outputScalePtr = reinterpret_cast<float const*>(params.bmm2_scale);
tllmRunnerParams.scaleSoftmaxLog2Ptr
= reinterpret_cast<float const*>(params.bmm1_scale) + bmm1_scale_offset;
}

TLLM_CHECK_WITH_INFO(mTllmGenFMHARunner.get(), "mTllmGenFMHARunner not initialized.");
mTllmGenFMHARunner->run(tllmRunnerParams);
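For context, the first hunk above folds the op's `mQScaling` factor into the MLA softmax scale, and the second hunk restricts the FP8 scale pointers to the FP8 generation path. A minimal sketch of the corrected scale formula follows; the 128/64 head-dimension split and unit `q_scaling` are illustrative assumptions, not values taken from this diff.

```cpp
#include <cmath>
#include <cstdio>

// Minimal sketch of the corrected MLA softmax scale, not the library's API.
// Dividing by qScaling matches the fixed host_bmm1_scale above.
float mlaBmm1Scale(float qScaling, int qkNopeHeadDim, int qkRopeHeadDim)
{
    // Before the fix, qScaling was missing from the denominator, so models
    // with a non-unit q_scaling used the wrong softmax temperature.
    return 1.0f / (qScaling * std::sqrt(static_cast<float>(qkNopeHeadDim + qkRopeHeadDim)));
}

int main()
{
    // Example with assumed DeepSeek-style head dimensions (128 nope + 64 rope).
    std::printf("bmm1 scale = %f\n", mlaBmm1Scale(1.0f, 128, 64)); // ~0.0722
    return 0;
}
```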
@@ -28,7 +28,8 @@ __global__ void lamport_initialize_kernel(float* ptr, int size)

void lamport_initialize(void* ptr, int bytes, cudaStream_t stream)
{
lamport_initialize_kernel<<<bytes / 128, 128, 0, stream>>>(reinterpret_cast<float*>(ptr), bytes / sizeof(float));
int grid_size = (bytes + 127) / 128;
lamport_initialize_kernel<<<grid_size, 128, 0, stream>>>(reinterpret_cast<float*>(ptr), bytes / sizeof(float));
}

Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
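The grid-size change above replaces truncating division with a round-up so that trailing bytes which do not fill a complete 128-thread block are still initialized. A standalone sketch of the idiom (the `ceilDiv` helper name is ours, not part of the codebase):

```cpp
#include <cassert>

// Ceiling division: how many 128-wide blocks are needed to cover `value` items.
// Equivalent to the (bytes + 127) / 128 expression in the fix above.
constexpr int ceilDiv(int value, int divisor)
{
    return (value + divisor - 1) / divisor;
}

int main()
{
    assert(ceilDiv(256, 128) == 2); // exact multiple: same as plain division
    assert(ceilDiv(300, 128) == 3); // truncating division would give 2 and
                                    // leave the last 44 bytes untouched
    assert(ceilDiv(0, 128) == 0);   // zero work still yields a zero-block grid,
                                    // which callers must guard against separately
    return 0;
}
```

The `size == 0` early return added to `lamportInitialize` in the next file handles that degenerate case before any launch happens.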
4 changes: 4 additions & 0 deletions cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
@@ -1989,6 +1989,10 @@ void residualRmsNorm(
void lamportInitialize(void* buffer, size_t size, nvinfer1::DataType dataType, cudaStream_t stream)
{
sync_check_cuda_error(stream);
if (size == 0)
{
return;
}
switch (dataType)
{
case nvinfer1::DataType::kFLOAT:
31 changes: 17 additions & 14 deletions cpp/tensorrt_llm/thop/moeOp.cpp
@@ -163,17 +163,12 @@ class FusedMoeRunner : public torch::CustomClassHolder
torch::optional<c10::ArrayRef<int64_t>> profile_ids)
{
// Free the profile workspace to save memory
if (mProfileWorkspace != nullptr)
{
auto const cu_free_status = cudaFree(mProfileWorkspace);
TORCH_CHECK(
cu_free_status == cudaSuccess, "Can't free profile workspace for MoE GEMM profile before runMoe.");
mProfileWorkspace = nullptr;
}
freeProfileWorkspace();

std::lock_guard<std::mutex> lock(mMutex);

TORCH_CHECK(cluster_size == 1 && cluster_rank == 0, "smart_router is supported in min_latency mode");

CHECK_INPUT(input, mActivationDtype)
CHECK_INPUT(token_selected_experts, at::ScalarType::Int)
if (token_final_scales)
@@ -251,6 +246,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
{
std::lock_guard<std::mutex> lock(mMutex);

// Free the profile workspace to save memory
freeProfileWorkspace();

CHECK_INPUT(input, mActivationDtype)
CHECK_INPUT(token_selected_experts, at::ScalarType::Int)
if (token_final_scales)
@@ -381,13 +379,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
hidden_size, inter_size, GROUP_SIZE, tensorrt_llm::ActivationType::Swiglu, USE_BIAS, USE_LORA,
min_latency_mode, parallelism_config);

if (mProfileWorkspace != nullptr)
{
auto const cu_free_status = cudaFree(mProfileWorkspace);
TORCH_CHECK(cu_free_status == cudaSuccess,
"Can't free profile workspace for MoE GEMM profile during memory reallocation.");
mProfileWorkspace = nullptr;
}
freeProfileWorkspace();
size_t profile_workspace_size = mProfiler->getWorkspaceSize(num_rows);
auto const cu_malloc_status = cudaMalloc(&mProfileWorkspace, profile_workspace_size);
TORCH_CHECK(cu_malloc_status == cudaSuccess, "Can't allocate profile workspace for MoE GEMM profile.");
@@ -422,6 +414,17 @@ class FusedMoeRunner : public torch::CustomClassHolder
using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
std::vector<Profile> mAllProfiles;

void freeProfileWorkspace()
{
if (mProfileWorkspace != nullptr)
{
auto const cu_free_status = cudaFree(mProfileWorkspace);
TORCH_CHECK(cu_free_status == cudaSuccess,
"Can't free profile workspace for MoE GEMM profile during memory reallocation.");
mProfileWorkspace = nullptr;
}
}

void setRunnerProfiles(torch::optional<c10::ArrayRef<int64_t>> profile_ids)
{
if (mUseFp8BlockScaling)
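The moeOp.cpp change above deduplicates the `cudaFree` + `TORCH_CHECK` sequence into a private `freeProfileWorkspace()` helper and also releases the workspace before the min-latency path to save memory. As a design note only, the same lifetime could be handled with RAII; the sketch below is an alternative under that assumption, not how the runner actually manages `mProfileWorkspace`.

```cpp
#include <cuda_runtime_api.h>
#include <cstddef>
#include <memory>
#include <stdexcept>

// Hypothetical RAII alternative: a unique_ptr with a cudaFree deleter releases
// the profiling workspace on every early return or exception path automatically.
struct CudaFreeDeleter
{
    void operator()(void* ptr) const noexcept
    {
        if (ptr != nullptr)
        {
            cudaFree(ptr); // best effort; errors are swallowed inside a deleter
        }
    }
};

using CudaBuffer = std::unique_ptr<void, CudaFreeDeleter>;

CudaBuffer allocateProfileWorkspace(std::size_t bytes)
{
    void* raw = nullptr;
    if (cudaMalloc(&raw, bytes) != cudaSuccess)
    {
        throw std::runtime_error("Can't allocate profile workspace for MoE GEMM profile.");
    }
    return CudaBuffer(raw); // freed when the owner is reset or destroyed
}
```

Keeping the explicit helper, as this PR does, preserves the descriptive `TORCH_CHECK` message on the free path and avoids changing the member's type.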
@@ -4,6 +4,33 @@ NVIDIA has announced world-record DeepSeek-R1 inference performance at NVIDIA GT

In this blog, we share the configurations and procedures for reproducing these numbers on both B200 and H200 with the PyTorch workflow.

## Table of Contents

- [How to get best performance on DeepSeek-R1 in TensorRT-LLM](#how-to-get-best-performance-on-deepseek-r1-in-tensorrt-llm)
- [Table of Contents](#table-of-contents)
- [Prerequisites: Install TensorRT-LLM and download models](#prerequisites-install-tensorrt-llm-and-download-models)
- [1. Download TensorRT-LLM](#1-download-tensorrt-llm)
- [2. Download the DeepSeek R1 models](#2-download-the-deepseek-r1-models)
- [3. Build and run TensorRT-LLM container](#3-build-and-run-tensorrt-llm-container)
- [4. Compile and Install TensorRT-LLM](#4-compile-and-install-tensorrt-llm)
- [5. Optional: Tune GPU clocks](#5-optional-tune-gpu-clocks)
- [6. Dataset preparation](#6-dataset-preparation)
- [Reproducing steps](#reproducing-steps)
- [B200 min-latency](#b200-min-latency)
- [Expected Results](#expected-results)
- [B200 max-throughput](#b200-max-throughput)
- [Benchmark](#benchmark)
- [Expected Result Format](#expected-result-format)
- [H200 min-latency](#h200-min-latency)
- [Expected Result Format](#expected-result-format-1)
- [H200 max-throughput](#h200-max-throughput)
- [Expected Result Format](#expected-result-format-2)
- [Exploring more ISL/OSL combinations](#exploring-more-islosl-combinations)
- [WIP: Enable more features by default](#wip-enable-more-features-by-default)
- [WIP: Chunked context support on DeepSeek models](#wip-chunked-context-support-on-deepseek-models)
- [Out of memory issues](#out-of-memory-issues)


## Prerequisites: Install TensorRT-LLM and download models

This section can be skipped if you already have TensorRT-LLM installed and have already downloaded the DeepSeek R1 model checkpoint.
@@ -324,3 +351,25 @@ Total Token Throughput (tokens/sec): 15707.0888
Total Latency (ms): 993548.8470
Average request latency (ms): 197768.0434
```

## Exploring more ISL/OSL combinations

To benchmark TensorRT-LLM on DeepSeek models with more ISL/OSL combinations, you can use `prepare_dataset.py` to generate the dataset and reuse commands similar to those in the previous section. TensorRT-LLM is working on enhancements that will make the benchmarking process smoother.

### WIP: Enable more features by default

Currently, some features, such as CUDA graphs, the overlap scheduler, and attention DP, need to be enabled through a user-defined file `extra-llm-api-config.yml`. We're working on enabling those features by default so that users get good out-of-the-box performance on DeepSeek models.

Note that `max_batch_size` and `max_num_tokens` can significantly affect performance. Their default values are carefully chosen and should deliver good performance in most cases; however, you may still need to tune them for peak performance.

Generally, make sure that `max_batch_size` is not so low that it bottlenecks throughput, and that `max_num_tokens` is large enough to cover the maximum input sequence length of the samples in the dataset, as mentioned in the section "WIP: Chunked context support on DeepSeek models" below.

For more details on `max_batch_size` and `max_num_tokens`, refer to [Tuning Max Batch Size and Max Num Tokens](../performance/performance-tuning-guide/tuning-max-batch-size-and-max-num-tokens.md).

### WIP: Chunked context support on DeepSeek models

The TensorRT-LLM team is actively working on chunked context support for DeepSeek models. Until that feature is available, there is a limitation that `max_num_tokens` has to be larger than the maximum input sequence length of the samples in the dataset.
For more details on `max_num_tokens`, refer to [Tuning Max Batch Size and Max Num Tokens](../performance/performance-tuning-guide/tuning-max-batch-size-and-max-num-tokens.md).

### Out of memory issues

You may see OOM issues in some cases. Consider reducing `kv_cache_free_gpu_mem_fraction` to a smaller value as a workaround. We're investigating and working to address the problem.
3 changes: 2 additions & 1 deletion docs/source/torch.md
@@ -41,6 +41,7 @@ scripts/huggingface_example.sh --model <huggingface_model_card> --quant fp8 --ex

- [Architecture Overview](./torch/arch_overview.md)
- [Adding a New Model](./torch/adding_new_model.md)
- [Examples](../../examples/pytorch/README.md)

## Key Components

@@ -50,4 +51,4 @@ scripts/huggingface_example.sh --model <huggingface_model_card> --quant fp8 --ex

## Known Issues

- The PyTorch workflow on SBSA is incompatible with bare metal environments like Ubuntu 24.04. Please use the [PyTorch NGC Container (https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) for optimal support on SBSA platforms.
- The PyTorch workflow on SBSA is incompatible with bare metal environments like Ubuntu 24.04. Please use the [PyTorch NGC Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) for optimal support on SBSA platforms.