
Commit 8a2d483

Eagle3 Training (#143)
This PR introduces Eagle3 model training into the speculators repo. The implementation is specific to Eagle3 but designed in a way that enables future generalization to other speculative decoding algorithms.

# Components

<img width="1418" height="734" alt="Eagle3 Training Components" src="https://github.com/user-attachments/assets/418a7d1f-0078-412a-ae56-a6427b756a05" />

## Example training script (~~`scripts/train_llama3_8b_drafter.py`~~ `scripts/train.py`)

Shows how to set up and run training.

~~Currently specific to the `meta-llama/Llama-3.1-8B-Instruct` model but doesn't require many changes to run with a different model. Just need to update:~~

```python
VERIFIER_MODEL_NAME_OR_PATH = "meta-llama/Llama-3.1-8B-Instruct"
HIDDEN_SIZE = 4096  # Must match the verifier model's hidden size
VERIFIER_VOCAB_SIZE = 128256  # Must match the verifier model's vocab size
```

**Update:** I've generalized the training script. It now has a required CLI arg `--verifier-name-or-path` and supports arbitrary verifier models. Note: this uses `LlamaConfig.from_pretrained(args.verifier_name_or_path)` under the hood, which does work for non-llama models (e.g. a Qwen model) but prints a warning and may not work for every type of verifier. You will also need to pass in a dataset and `t2d` / `d2t` tensors that correspond to the verifier you are using.

## Flex Attention

Files:
- `src/speculators/train/eagle3/attention.py`
- `tests/unit/train/test_eagle3_attention.py`

The training code uses FlexAttention, which provides substantial speedups and memory savings over full dense attention operations.

Functions (a generic sketch of this style of mask follows the list):
- `create_combined_mask_mod(lengths, total_seq_len)`: creates the mask function used by flex attention.
- `extend_mask_for_draft_tokens(block_mask)`: helper function to extend the block mask without needing to check each new square's mask value.
- `block_mask_to_dense_attention_mask`: only used for debugging purposes.
- `flex_attention_forward`: lightweight wrapper around the flex attention call.
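To make the masking scheme concrete, here is a minimal sketch of a FlexAttention `mask_mod` that is causal within each packed sequence and blocks attention across sequence boundaries, which is the kind of mask a packed-batch helper like `create_combined_mask_mod` produces. This is not the repo's implementation; the helper name, lengths, and shapes are invented for illustration.

```python
# Minimal sketch of a packed-sequence causal mask for PyTorch FlexAttention.
# NOT the speculators implementation: names, lengths, and shapes are illustrative only.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def make_packed_causal_mask_mod(lengths: list[int], total_seq_len: int, device):
    """Causal attention within each packed sequence, none across boundaries."""
    # doc_ids[i] = index of the packed sequence that token position i belongs to
    doc_ids = torch.repeat_interleave(
        torch.arange(len(lengths), device=device),
        torch.tensor(lengths, device=device),
    )
    # Give trailing (unused) positions their own id so they stay isolated
    doc_ids = torch.nn.functional.pad(
        doc_ids, (0, total_seq_len - doc_ids.numel()), value=len(lengths)
    )

    def mask_mod(b, h, q_idx, kv_idx):
        return (doc_ids[q_idx] == doc_ids[kv_idx]) & (q_idx >= kv_idx)

    return mask_mod


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    lengths, total_seq_len = [160, 96], 256  # two sequences packed into one row
    block_mask = create_block_mask(
        make_packed_causal_mask_mod(lengths, total_seq_len, device),
        B=None, H=None, Q_LEN=total_seq_len, KV_LEN=total_seq_len, device=device,
    )
    if device == "cuda":  # FlexAttention kernels target GPU and are usually compiled
        q = k = v = torch.randn(1, 4, total_seq_len, 64, device=device)
        out = flex_attention(q, k, v, block_mask=block_mask)
        print(out.shape)  # torch.Size([1, 4, 256, 64])
```

Building the `BlockMask` once per batch (and extending it for draft tokens, as `extend_mask_for_draft_tokens` does in the repo) avoids materializing a dense `[total_seq_len, total_seq_len]` mask.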
## Data processing

<img width="4531" height="2384" alt="Eagle3 Data Flow" src="https://github.com/user-attachments/assets/b972ef8c-92d4-4d46-969f-66d33f801ceb" />

Files:
- `src/speculators/train/data.py`

Data is currently expected in the format of one file per data sample. We load these samples and perform a shift to align `input_ids`, `hidden_states`, `loss_mask`, and `verifier_last_hidden_state` correctly. We also automatically collate these samples into batches. Rather than padding and wasting compute on padded tokens, we instead concatenate the sequences along the sequence dimension, keeping track of the boundaries between sequences and setting the attention mask accordingly.

## Batch sampling

Files:
- `src/speculators/train/distributed_batch_sampler.py`
- `src/speculators/train/data.py`

Due to hardware limitations, we set a maximum sequence length for each batch. We would like each batch of data to be close in size to this max length, so that each batch has a similar number of tokens. We achieve this through the `MultipackDistributedBatchSamplerV2`, taken from prior work I did on [instructlab/training](https://github.com/instructlab/training). This class produces indices of files that, when batched together, come close to reaching the max length without exceeding it. It also does this in a distributed-aware manner so that there is no overlap in the data each rank sees. To run the packing algorithm, we need to know the length of each sample in the dataset. Unfortunately, this would require opening every file in the dataset, which is expensive, so instead we approximate the lengths (`_compute_approx_lengths` in `data.py`) using the length of the first sample and the relative file sizes of the samples.

## `Eagle3DraftModel`

Files:
- `src/speculators/train/eagle3/core.py`

The draft model itself. Sets up and loads verifier components, as well as the draft layers / weights. Contains the model `forward()` pass, which:
- Sets up the block mask for the batch
- Computes the target logits using the attached `verifier_lm_head`. Note: this is computed here for data storage efficiency reasons; otherwise we would need to save the full logits (`[seq_len, vocab_size]`) to disk instead of the last-layer hidden states (`[seq_len, hidden_size]`). The verifier `vocab_size` is often > 100k, whereas `hidden_size` might be around 4-8k.
- For each TTT step:
  - Embeds tokens
  - Concatenates with `hidden_states`
  - Applies decoder layers
  - Computes logits
  - Computes loss and step accuracy
  - Prepares next-step tokens
  - Updates the block mask

## Layer definitions

Files:
- `src/speculators/train/eagle3/model_definitions.py`

Currently just contains model definitions for llama3-style draft models. Supports `norm_before_residual=True` or `False`. Attempted to keep modifications to the original llama models minimal.

## Distributed training via FSDP

Files:
- `src/speculators/train/utils.py`
- `src/speculators/train/checkpointer.py`
- `src/speculators/train/trainer.py` (`setup_model` fn)

Full support for FSDP training by launching the training script with `torchrun --nnodes=1 --nproc_per_node=N`, where `N` is the number of GPUs. Tested with `N=2,3,4,8` and all work. FSDP training also enables Automatic Mixed Precision (AMP) for improved performance. `checkpointer.py` contains checkpointing logic for FSDP distributed model weights (gather all weights on rank 0 before saving).

Note: the way distributed training works in general is that `N` copies of the script are started and all run the same code, but with environment variables set that let each process know its rank. Explicit `dist.barrier()` calls, or implicit ones inside FSDP forward/backward hooks, then force each process to wait until they all reach the same point in the code before continuing. It is important that all ranks reach these operations, as it allows them to perform synchronized operations (such as gathering, reducing, etc.). However, we can also limit certain code to only one rank (rank 0) so that we only log once, or save a checkpoint once, using simple `if local_rank == 0` statements (a generic sketch of this pattern follows).
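As an illustration of that rank-gating pattern, here is a generic sketch using `torch.distributed`. It is not the repo's `checkpointer.py`: the FSDP shard-gathering step is omitted, and the function/file names are made up for the example.

```python
# Generic sketch of rank-0 gating + barrier with torch.distributed.
# Not checkpointer.py: the FSDP weight-gathering collective is omitted.
import os

import torch
import torch.distributed as dist


def save_checkpoint_rank0(model: torch.nn.Module, path: str) -> None:
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank == 0:
        # Only one process writes to disk...
        torch.save(model.state_dict(), path)
    if dist.is_initialized():
        # ...but every process waits here, so no rank races ahead and reads a
        # half-written checkpoint or starts the next epoch early.
        dist.barrier()


if __name__ == "__main__":
    # torchrun sets RANK/WORLD_SIZE/LOCAL_RANK; fall back to single-process otherwise.
    if "RANK" in os.environ:
        dist.init_process_group(backend="gloo")  # gloo keeps the sketch CPU-friendly
    model = torch.nn.Linear(8, 8)
    save_checkpoint_rank0(model, "checkpoint.pt")
    if dist.is_initialized():
        dist.destroy_process_group()
```

In the real `checkpointer.py`, the rank-0 write is preceded by gathering the FSDP-sharded weights onto rank 0, which is itself a collective that every rank must enter.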
## Logging

Files:
- `src/speculators/train/logger.py`
- `scripts/train.py` (setup logger calls at the start of `main()`)
- `src/speculators/train/trainer.py` and other files: usage of `metric_logger` and `root_logger`

Another implementation mostly copied from prior work I did on [instructlab/training](https://github.com/instructlab/training). This uses Python's standard library `logging` module and extends it to support training metric logging. We can log a nested dict of metrics anywhere in the codebase like so:

```python
# Setup once
import logging

metric_logger = logging.getLogger("speculators.metrics")

# Log call
metric_logger.info(
    {"train": {"loss": loss.item(), **acc_values}, "epoch": epoch},
    extra={"step": self.global_step},
)
```

When running the training script, the user can select one (or multiple) of `tensorboard`, `wandb`, and `trackio`, and the results will be logged to the respective experiment tracker(s).

There is also a `root_logger` which can be used for regular update logging, and everything logged to either the `root_logger` or the `metric_logger` will be pretty-printed to the console.

## `Trainer`

Files:
- `src/speculators/train/trainer.py`

The `Trainer` class is initialized with the model, data loaders, and a config, and:
- Sets up the model / optimizer (loads weights and configures distributed training if needed)
- Contains the training and validation loops (`train_epoch` and `val_epoch`, respectively)
- Runs the overall training loop, which alternates between training, validation, and saving checkpoints

Todos:
- [x] Eagle3Draft model definition with TTT steps and loss calculations
- [x] Patched decoder layer definitions
- [x] Simple data loading from sample files
- [x] FlexAttention masking and implementation
- [x] Loss masking
- [x] Training loop
- [x] Train data loader
- [x] `loss.backward()` + optimizer steps
- [x] Distributed loss reduction
- [x] Val data loader
- [x] Metric collection/reporting
- [x] Model checkpointing
- [x] Data batching
- [x] Collate fn
- [x] Batch sampler (dynamic batch size through sample packing)
- [x] Distributed (rank) aware sampling
- [x] Distributed support
- [ ] ~~Code relocation / merging with existing definitions (currently everything is under `speculators/train`, but this will need to change)~~ FUTURE PR
- [x] Verify correctness of key components (attention masking, data token alignment, etc.)
- [x] General testing

Essential todos (as of 10/22/2025):
- [x] Save checkpoints to safetensors format w/ required config info
- [ ] ~~Implement save-best or save-last logic (currently saving every epoch)~~ FUTURE PR
- [x] Better verifier `lm_head` / `embed_tokens` loading (requires #144)
- [x] `Eagle3DraftModel.__init__` signature cleanup / better configuration
- [ ] ~~Config/argparsing for `scripts/train.py`~~ FUTURE PR
- [x] Ensure flex attention impl works with `torch==2.9` and `torch.compile`
- [x] Fix lint / quality / type errors and pass CI

---------

Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
Co-authored-by: Brian Dellabetta <[email protected]>
1 parent 751f3a0 commit 8a2d483

File tree: 17 files changed (+2788, -3 lines)


pyproject.toml

Lines changed: 4 additions & 3 deletions
```diff
@@ -48,8 +48,10 @@ dependencies = [
     "pydantic>=2.0.0",
     "pydantic-settings>=2.0.0",
     "pyyaml>=6.0.0",
+    "rich",
     "safetensors",
     "torch",
+    "tqdm",
     "transformers",
     "typer-slim>=0.12.0",
 ]
@@ -102,6 +104,7 @@ dev = [
     "types-PyYAML~=6.0.1",
     "types-requests~=2.32.0",
     "types-toml",
+    "types-tqdm",
 
     # link checking
     "mkdocs-linkcheck~=1.0.6",
@@ -211,9 +214,6 @@ select = [
     "UP", # pyupgrade: automatically upgrades syntax for newer versions of Python
     "W", # Warning: provides warnings about potential issues in the code
     "YTT", # flake8-2020: identifies code that will break with future Python releases
-
-    # Code Documentation
-    "FIX", # flake8-fixme: detects FIXMEs and other temporary comments that should be resolved
 ]
 
 [tool.ruff.lint.extend-per-file-ignores]
@@ -251,6 +251,7 @@ select = [
     "PTH", # os.path is acceptable in data generation
 ]
 "scripts/**/*.py" = [
+    "INP001", # allow implicit namespace packages in scripts
     "PTH", # os.path is acceptable in scripts
 ]
```
scripts/TRAINING.md

Lines changed: 56 additions & 0 deletions
# Eagle3 Training

`scripts/train.py` provides the main entry point for training Eagle3 models.

## Running the training script

To run in a distributed training setup with FSDP, the script should be launched with `torchrun`:

```bash
torchrun --nnodes=1 --nproc_per_node=<num_gpus> scripts/train.py
```

For single-GPU training (useful for debugging), the script can be run directly:

```bash
python scripts/train.py
```

## Arguments

The script has one required argument: `--verifier-name-or-path`, the name or path of the verifier model to use.

The script has the following optional arguments:

- `--data-path`: The path to the data directory. Defaults to `./data`. The script will collect all `.pt` files in this directory or its subdirectories and use them as training data.
- `--save-path`: The path to save checkpoints to. Defaults to `./checkpoints`. The script will create a subdirectory for each epoch to save the model weights and optimizer states, e.g. `./checkpoints/0/`.
- `--epochs`: The number of epochs to train for. Defaults to 20.
- `--lr`: The learning rate to use. Defaults to 1e-4.
- `--no-resume-from-checkpoint`: If set, the script will not resume from the last checkpoint if one exists, and will instead start from scratch and overwrite existing checkpoints.
- `--logger`: The logger to use. Defaults to the empty string, which means no logging. Supported loggers are `trackio`, `wandb`, and `tensorboard`.
- `--total-seq-len`: The total sequence length to use. Defaults to 8192.
- `--data-format-version`: The structure of the data to train on. Defaults to `1`, which is the format produced by the Speculators generation scripts; `0` exists for backwards compatibility with the old data format.
- `--log-dir`: The path to save the logs. Defaults to `./logs`.
- `--run-name`: The name of the run. Defaults to None.
- `--num-layers`: The number of layers to use. Defaults to 1.
- `--d2t-path`: The path to the d2t tensor. Defaults to `d2t.npy`.
- `--t2d-path`: The path to the t2d tensor. Defaults to `t2d.npy`.
- `--ttt-steps`: The number of TTT steps to use. Defaults to 3.
- `--ttt-step-loss-decay`: The loss decay factor to use for the TTT steps. Defaults to 1.0.

## Example run command

```bash
torchrun --nnodes=1 --nproc_per_node=8 scripts/train.py \
    --verifier-name-or-path "meta-llama/Llama-3.1-8B" \
    --data-path "./data/llama-3.1-8b_sharegpt/gen/" \
    --save-path "./checkpoints/llama-3.1-8b.eagle3" \
    --epochs 10 \
    --lr 1e-4 \
    --no-resume-from-checkpoint \
    --logger "tensorboard" \
    --total-seq-len 8192 \
    --data-format-version 1 \
    --log-dir "./logs/llama-3.1-8b.eagle3" \
    --run-name "llama-3.1-8b.eagle3" \
    --num-layers 1 \
    --d2t-path "./data/llama-3.1-8b_sharegpt/d2t.npy" \
    --t2d-path "./data/llama-3.1-8b_sharegpt/t2d.npy" \
    --ttt-steps 3 \
    --ttt-step-loss-decay 1.0
```

scripts/train.py

Lines changed: 213 additions & 0 deletions
```python
import argparse

import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import LlamaConfig

from speculators.config import SpeculatorsConfig, VerifierConfig
from speculators.models.eagle3 import Eagle3SpeculatorConfig
from speculators.proposals.greedy import GreedyTokenProposalConfig
from speculators.train.data import (
    Eagle3SampleFileDataset,
    create_collate_fn,
    split_files,
    standardize_data_v0,
    standardize_data_v1,
)
from speculators.train.distributed_batch_sampler import (
    MultipackDistributedBatchSamplerV2,
)
from speculators.train.eagle3.core import Eagle3DraftModel
from speculators.train.logger import setup_metric_logger, setup_root_logger
from speculators.train.noise_transforms import AddUniformNoise
from speculators.train.trainer import Trainer, TrainerConfig
from speculators.train.utils import maybe_destroy_distributed, maybe_setup_distributed

# DRAFTER MODEL HYPERPARAMETERS
NORM_BEFORE_RESIDUAL = True

# Dataloader
NUM_WORKERS = 12
PREFETCH_FACTOR = 4
NOISE_STD = 0.05


def setup_dataloader(
    file_list: list[str],
    world_size: int,
    local_rank: int,
    add_noise: bool = True,
    data_format_version: int = 1,
) -> DataLoader:
    """Setup dataloader for training.

    Args:
        file_list: List of file paths to load data from.
        world_size: Number of processes in the distributed training.
        local_rank: Rank of the current process.
        add_noise: Whether to add noise to the data.
        data_format_version: Version of the data format. Default is 1.

    Returns:
        DataLoader: Dataloader for training.
    """
    if add_noise:
        noise_transform = AddUniformNoise(
            std=NOISE_STD, tensors=("hidden_states", "verifier_last_hidden_states")
        )
    else:
        noise_transform = None

    standardize_fn = (
        standardize_data_v1 if data_format_version == 1 else standardize_data_v0
    )

    # `args` is read from module scope (set in the __main__ block below)
    dataset = Eagle3SampleFileDataset(
        file_list=file_list,
        max_len=args.total_seq_len,
        transform=noise_transform,
        standardize_fn=standardize_fn,
    )
    batch_sampler = MultipackDistributedBatchSamplerV2(
        batch_max_length=args.total_seq_len,
        lengths=dataset.approx_lengths,
        num_replicas=world_size,
        rank=local_rank,
    )
    return DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=NUM_WORKERS,
        prefetch_factor=PREFETCH_FACTOR,
        pin_memory=True,
        collate_fn=create_collate_fn(args.total_seq_len),
        persistent_workers=True,
    )


def main(args: argparse.Namespace):
    # Setup logging
    setup_root_logger()
    setup_metric_logger(
        loggers=args.logger, run_name=args.run_name, output_dir=args.log_dir
    )

    # Setup distributed training
    local_rank, world_size, rank, is_distributed = maybe_setup_distributed()
    device = torch.device(local_rank)

    # Setup speculator config
    llama_config = LlamaConfig.from_pretrained(args.verifier_name_or_path)
    llama_config.num_hidden_layers = args.num_layers
    llama_config.model_type = "llama"  # reset to llama (handles non-llama verifiers)
    llama_config._attn_implementation = "simple_flex_attention"  # noqa: SLF001

    # Load t2d and d2t tensors
    d2t = torch.from_numpy(np.load(args.d2t_path)).to(device)
    t2d = torch.from_numpy(np.load(args.t2d_path)).to(device)
    draft_vocab_size = d2t.shape[0]

    speculator_config = Eagle3SpeculatorConfig(
        transformer_layer_config=llama_config,
        draft_vocab_size=draft_vocab_size,
        norm_before_residual=NORM_BEFORE_RESIDUAL,
        speculators_config=SpeculatorsConfig(
            algorithm="eagle3",
            proposal_methods=[
                GreedyTokenProposalConfig(
                    proposal_type="greedy",
                    speculative_tokens=args.ttt_steps,
                )
            ],
            default_proposal_method="greedy",
            verifier=VerifierConfig(
                name_or_path=args.verifier_name_or_path,
                architectures=["LlamaForCausalLM"],
            ),
        ),
    )

    # Setup draft model
    draft_model = Eagle3DraftModel(config=speculator_config, t2d=t2d, d2t=d2t)

    # Setup dataloaders
    train_files, val_files = split_files(args.data_path, ratio=0.9)
    train_loader = setup_dataloader(
        train_files,
        world_size,
        local_rank,
        add_noise=True,
        data_format_version=args.data_format_version,
    )
    val_loader = setup_dataloader(
        val_files,
        world_size,
        local_rank,
        add_noise=False,
        data_format_version=args.data_format_version,
    )

    # Setup trainer
    trainer_config = TrainerConfig(
        num_epochs=args.epochs,
        save_path=args.save_path,
        lr=args.lr,
        resume_from_checkpoint=not args.no_resume_from_checkpoint,
        is_distributed=is_distributed,
        local_rank=local_rank,
        train_call_kwargs={
            "use_off_policy_tokens": False,
            "ttt_steps": args.ttt_steps,
            "ttt_step_loss_decay": args.ttt_step_loss_decay,
        },
        val_call_kwargs={
            "use_off_policy_tokens": False,
            "ttt_steps": args.ttt_steps,
            "ttt_step_loss_decay": args.ttt_step_loss_decay,
        },
    )
    trainer = Trainer(draft_model, trainer_config, train_loader, val_loader)

    # Run training
    trainer.run_training()

    # Cleanup
    maybe_destroy_distributed()


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--verifier-name-or-path", type=str, required=True)
    parser.add_argument("--data-path", type=str, default="./data")
    parser.add_argument("--save-path", type=str, default="./checkpoints")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--no-resume-from-checkpoint", action="store_true")
    parser.add_argument(
        "--logger",
        type=str,
        default="",
        help="One of 'trackio', 'wandb', 'tensorboard' or comma separated list of them",
    )
    parser.add_argument("--total-seq-len", type=int, default=8192)
    parser.add_argument("--data-format-version", type=int, default=1)
    parser.add_argument("--log-dir", type=str, default="./logs")
    parser.add_argument("--run-name", type=str, default=None)
    parser.add_argument("--num-layers", type=int, default=1)
    parser.add_argument("--d2t-path", type=str, default="d2t.npy")
    parser.add_argument("--t2d-path", type=str, default="t2d.npy")
    parser.add_argument("--ttt-steps", type=int, default=3)
    parser.add_argument("--ttt-step-loss-decay", type=float, default=1.0)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(args)


# RUN WITH:
# torchrun --nnodes=1 --nproc_per_node=<num_gpus> scripts/train.py
# for FSDP training
# OR
# python scripts/train.py
# for single GPU training
```

src/speculators/train/__init__.py

Whitespace-only changes.
