29 changes: 26 additions & 3 deletions .github/workflows/example_tests.yml
@@ -63,7 +63,7 @@ jobs:
strategy:
fail-fast: false
matrix:
example: [llm_distill, llm_qat, llm_sparsity, speculative_decoding]
example: [llm_distill, llm_qat, llm_sparsity]
uses: ./.github/workflows/_example_tests_runner.yml
secrets: inherit
with:
@@ -77,7 +77,7 @@ jobs:
strategy:
fail-fast: false
matrix:
example: [llm_distill, llm_qat, llm_sparsity, speculative_decoding]
example: [llm_distill, llm_qat, llm_sparsity]
uses: ./.github/workflows/_example_tests_runner.yml
secrets: inherit
with:
@@ -86,6 +86,28 @@
pip_install_extras: "[hf,dev-test]"
runner: linux-amd64-gpu-h100-latest-2

##### Speculative Decoding Example Tests (requires the 25.08 PyTorch image for torch 2.8) #####
speculative-decoding-pr:
needs: [check-file-changes, wait-checks]
if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true'
uses: ./.github/workflows/_example_tests_runner.yml
secrets: inherit
with:
docker_image: "nvcr.io/nvidia/pytorch:25.08-py3"
example: speculative_decoding
pip_install_extras: "[hf,dev-test]"
runner: linux-amd64-gpu-l4-latest-1

speculative-decoding-non-pr:
if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }}
uses: ./.github/workflows/_example_tests_runner.yml
secrets: inherit
with:
docker_image: "nvcr.io/nvidia/pytorch:25.08-py3"
example: speculative_decoding
pip_install_extras: "[hf,dev-test]"
runner: linux-amd64-gpu-h100-latest-2

##### TensorRT-LLM Example Tests #####
trtllm-pr:
needs: [check-file-changes, wait-checks]
@@ -150,14 +172,15 @@ jobs:
example-pr-required-check:
# Run even if example tests are skipped
if: ${{ startsWith(github.ref, 'refs/heads/pull-request/') && always() }}
needs: [check-file-changes, torch-pr, trtllm-pr, onnx-pr]
needs: [check-file-changes, torch-pr, speculative-decoding-pr, trtllm-pr, onnx-pr]
runs-on: ubuntu-latest
steps:
- name: Required GPU tests did not succeed
if: |
needs.check-file-changes.result != 'success' ||
(needs.check-file-changes.outputs.any_changed == 'true' && (
needs.torch-pr.result != 'success' ||
needs.speculative-decoding-pr.result != 'success' ||
needs.trtllm-pr.result != 'success' ||
needs.onnx-pr.result != 'success'
))
8 changes: 3 additions & 5 deletions examples/speculative_decoding/README.md
@@ -30,7 +30,7 @@ This example focuses on training with Hugging Face. To train with Megatron‑LM,

### Docker

Please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.06-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information.
Please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.08-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information.
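
If you are not already inside a container, an invocation along these lines can be used to start one; the flags and mount path are illustrative, and the `--gpus` flag assumes the NVIDIA Container Toolkit is installed:

```bash
# Hypothetical invocation: mount the current repo into the container and open an interactive shell.
docker run --gpus all -it --rm -v "$(pwd)":/workspace -w /workspace nvcr.io/nvidia/pytorch:25.08-py3
```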

Also follow the installation steps below to upgrade to the latest version of Model Optimizer and install dataset and example-specific dependencies.

@@ -56,7 +56,7 @@ See [other-datasets](#other-datasets) section for other dataset options and inst
## Getting Started: Simplified Workflow

```bash
bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct --num_gpu 4
bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct
```

This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model with ModelOpt. Specifically, it
@@ -74,12 +74,11 @@ For small base models that fit in GPU memory, we can collocate them with draft m
./launch_train.sh --model $BASE_MODEL \
--output_dir $OUTPUT_DIR \
--data input_conversations/daring-anteater.jsonl \
--num_gpu $NUM_GPU \
--num_epochs $NUM_EPOCH \
--eagle_config eagle_config.json
```

This command launches `main.py` with `accelerate`. See [section: interact with modelopt.torch.speculative](#interact-with-modelopttorchspeculative) for more details.
FSDP2 is used by default. To enable context parallelism for long-context training, specify `--cp_size n`; a sketch of such a run is shown below.
The saved ModelOpt checkpoint is architecturally similar to HF models and can be further optimized with **ModelOpt**, e.g., via PTQ and QAT.
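
For example (illustrative values only, assuming an 8-GPU node; adjust the model, data, and sizes to your setup), a long-context run could combine `--cp_size` with the default FSDP2 sharding:

```bash
# Hypothetical example: 8 GPUs split into cp_size=2 context-parallel groups,
# leaving dp_size=4 for FSDP2 sharding (dp_size defaults to num_gpus / cp_size).
./launch_train.sh --model $BASE_MODEL \
    --output_dir $OUTPUT_DIR \
    --data input_conversations/daring-anteater.jsonl \
    --num_epochs $NUM_EPOCH \
    --eagle_config eagle_config.json \
    --cp_size 2 \
    --dp_size 4
```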

## Training Draft Model with Offline Base Model
@@ -118,7 +117,6 @@ Once we finish dumping hidden states, launch offline training with an extra `--o
./launch_train.sh --model $BASE_MODEL \
--output_dir $OUTPUT_DIR \
--data $DATA \
--num_gpu $NUM_GPU \
--num_epochs $NUM_EPOCH \
--eagle_config eagle_config.json \
--offline-data $HIDDEN_STATES_DIR
149 changes: 149 additions & 0 deletions examples/speculative_decoding/eagle_utils.py
@@ -13,21 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import json
import os
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from types import FrameType
from typing import Any

import numpy as np
import torch
import transformers
from datasets import load_dataset
from packaging.version import Version
from PIL import Image
from scripts.ar_validate import validate_ar
from torch.distributed.tensor.experimental._attention import _SDPAMerger
from torch.utils.data import Dataset
from transformers import AutoProcessor, Trainer, TrainerCallback
from transformers.trainer_pt_utils import LabelSmoother

import modelopt
from modelopt.torch.speculative.utils import get_ttt_msk_func
from modelopt.torch.utils import print_rank_0
from modelopt.torch.utils.distributed import is_master

@@ -566,3 +576,142 @@ def on_step_end(self, args, state, control, **kwargs):
except Exception:
print_rank_0("AR validation not available.")
return control


def get_patched_templated_ring_attn(orig_templated_attn: Callable):
"""
Return a patched version of
torch.distributed.tensor.experimental._attention._templated_ring_attention
to support TTT.
"""

def _get_sharded_ttt_msk(i, rank, size, q_len, ttt_step, dtype):
"""Get chunk-interleaved TTT mask for current rank.
e.g.:
2 ranks, ttt_step=1;
full_ttt_mask = [[0, 0, 0, 0, x, 0, 0, 0],
[x, 0, 0, 0, 0, x, 0, 0],
[x, x, 0, 0, 0, 0, x, 0],
[x, x, x, 0, 0, 0, 0, x]]

rank 0, step0: [[0, 0, x, 0],
[x, 0, 0, x]]

rank 1, step0: [[0, 0, x, 0],
[x, 0, 0, x]]

rank 0, step1: [[0, 0, 0, 0],
[0, 0, 0, 0]]

rank 1, step1: [[x, x, 0, 0],
[x, x, 0, 0]]

"""
device = torch.cuda.current_device()
q_indices = torch.arange(q_len * rank, q_len * (rank + 1), device=device)
kv_indices = (
torch.arange(q_len * size * (ttt_step + 1), device=device)
.view(ttt_step + 1, size, q_len)[:, (rank - i) % size, :]
.reshape(-1)
)
msk_func = get_ttt_msk_func(q_len * size, ttt_step)
attn_mask = msk_func(
None,
None,
q_indices.view(1, 1, -1, 1),
kv_indices.view(1, 1, 1, -1),
)
attn_bias = torch.where(
attn_mask,
torch.zeros((), dtype=dtype, device=attn_mask.device),
torch.full((), torch.finfo(dtype).min, dtype=dtype, device=attn_mask.device),
)

return attn_bias

def patched_templated_attn(*args, **kwargs):
"""Patched version of torch.distributed.tensor.experimental._attention._templated_ring_attention."""
# Get original attention op
# Sensitive to impl of _templated_ring_attention
original_op = args[2]

# This patch is only enabled (via a context manager) for the EAGLE model, not the base model.
patch_enabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH

if patch_enabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention:
raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}")

# Unset is_causal to use custom attn mask
if patch_enabled:
kwargs["is_causal"] = False

def patched_op(*args, **kwargs):
# Inspect the parent frame to get current shard info
# This is sensitive to torch _templated_ring_attention impl
try:
frame: FrameType = inspect.currentframe()
f_back: FrameType = frame.f_back
rank = f_back.f_locals["rank"]
size = f_back.f_locals["size"]
query = f_back.f_locals["query"]
key = f_back.f_locals["key"]
i = f_back.f_locals["i"]
ttt_step = (key.shape[2] // query.shape[2]) - 1
except Exception as e:
raise RuntimeError(
f"Failed to capture loop variables in patched _templated_ring_attention: {e}"
) from e
# Set attn mask to permuted TTT mask
if "attn_bias" in kwargs:
kwargs["attn_bias"] = _get_sharded_ttt_msk(
i, rank, size, query.shape[2], ttt_step, query.dtype
)
# Perform shard attention
return original_op(*args, **kwargs)

return orig_templated_attn(args[0], args[1], patched_op, *args[3:], **kwargs)

return patched_templated_attn


def patch_ring_attention_for_ttt():
"""Patch torch ring attention to support context parallelism for TTT."""
# Torch ring attention only supports either no mask or a causal mask. We apply the following patches to enable the TTT mask.

if not (
Version(torch.__version__) > Version("2.7.1")
and Version(torch.__version__) < Version("2.9.0")
):
raise RuntimeError(
f"Context parallel TTT only supported for PyTorch 2.8.0 now. "
f"Got {torch.__version__}. "
f"Please use nvcr.io/nvidia/pytorch:25.08-py3 or torch 2.8.0 or cp_size=1."
)

# 1. Disable load balance, which is designed for causal mask.
# This affects how buffers are sharded, so it must be set before the accelerate/HF Trainer is initialized.
torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False

# 2. Patch templated ring attention for TTT mask.
original_templated_ring_attention = (
torch.distributed.tensor.experimental._attention._templated_ring_attention
)
original_templated_ring_attention_backward = (
torch.distributed.tensor.experimental._attention._templated_ring_attention_backward
)
torch.distributed.tensor.experimental._attention._templated_ring_attention = (
get_patched_templated_ring_attn(original_templated_ring_attention)
)
torch.distributed.tensor.experimental._attention._templated_ring_attention_backward = (
get_patched_templated_ring_attn(original_templated_ring_attention_backward)
)

# 3. Patch the SDPA merger to skip blank shards so they do not change the merged output.
original_sdpa_merger_step = _SDPAMerger.step

def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool):
if lse.sum() <= 0:
return
return original_sdpa_merger_step(self, out, lse, partial)

_SDPAMerger.step = patched_sdpa_merger_step
1 change: 1 addition & 0 deletions examples/speculative_decoding/fsdp_config.json
@@ -0,0 +1 @@
{"fsdp_version":2}
39 changes: 23 additions & 16 deletions examples/speculative_decoding/launch_train.sh
@@ -74,14 +74,6 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
EAGLE_CONFIG="${1#*=}"
;;
--fsdp_transformer_layer_cls_to_wrap*)
if [[ "$1" != *=* ]]; then shift; fi
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
;;
--num_gpu*)
if [[ "$1" != *=* ]]; then shift; fi
NUM_GPU="${1#*=}"
;;
--disable_tqdm*)
if [[ "$1" != *=* ]]; then shift; fi
DISABLE_TQDM="${1#*=}"
@@ -102,6 +94,14 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
AR_VALIDATE_STEPS="${1#*=}"
;;
--cp_size*)
if [[ "$1" != *=* ]]; then shift; fi
CP_SIZE="${1#*=}"
;;
--dp_size*)
if [[ "$1" != *=* ]]; then shift; fi
DP_SHARD_SIZE="${1#*=}"
;;
*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
@@ -129,15 +129,15 @@ LR=${LR:-"1e-4"}
TRAIN_BS=${TRAIN_BS:-4}
MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
NUM_GPU=${NUM_GPU:-1}
TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""}
DISABLE_TQDM=${DISABLE_TQDM:-False}
VLM_PROCESSOR=${VLM_PROCESSOR:-}
VLM_IMG_DIR=${VLM_IMG_DIR:-}
AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000}
ESTIMATE_AR=${ESTIMATE_AR:-False}
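# The data-parallel shard size defaults to GPU_COUNT/CP_SIZE, e.g. 8 GPUs with --cp_size 2
# gives a dp_shard_size of 4 (illustrative; assumes GPU_COUNT is divisible by CP_SIZE).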
CP_SIZE=${CP_SIZE:-1}
DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))}

if [[ "$MODE" == "medusa" ]]; then
SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
@@ -163,21 +163,25 @@ else
OFFLINE_TRAINING_ARGS=""
fi

if [[ "$NUM_GPU" == 1 ]]; then
MULTI_GPU=""
else
MULTI_GPU="--multi_gpu"
fi

if [[ "$VLM_PROCESSOR" != "" ]]; then
VLM_ARGS="--vlm_processor $VLM_PROCESSOR --vlm_img_dir $VLM_IMG_DIR"
else
VLM_ARGS=""
fi

if [[ "$GPU_COUNT" -gt 1 ]]; then
# Use FSDP2 when multiple GPUs are available
FSDP_ARGS="--fsdp 'full_shard' --fsdp_config fsdp_config.json"
else
# Otherwise, fall back to single-GPU training
FSDP_ARGS=""
fi


# Disable tokenizers parallelism to avoid warning
export TOKENIZERS_PARALLELISM=False
CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
CMD="accelerate launch --mixed_precision bf16 main.py \
--mode $MODE \
--eagle_decoder_type $EAGLE_DECODER_TYPE \
--model_name_or_path $MODEL \
@@ -206,6 +210,9 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
$VLM_ARGS \
$OFFLINE_TRAINING_ARGS \
$SPECULATIVE_ARGS \
$FSDP_ARGS \
--cp_size $CP_SIZE \
--dp_shard_size $DP_SHARD_SIZE \
"

start_time=$(date +%s)