Commit 8a2d483
Eagle3 Training (#143)
This PR introduces Eagle3 model training into the speculators repo. The
implementation is specific to Eagle3 but is designed in a way that enables
future generalization to other speculative decoding algorithms.
# Components
<img width="1418" height="734" alt="Eagle3 Training Components"
src="https://github.com/user-attachments/assets/418a7d1f-0078-412a-ae56-a6427b756a05"
/>
## Example training script (~`scripts/train_llama3_8b_drafter.py`~ `scripts/train.py`)
Shows how to set up and run training. ~Currently specific to the
`meta-llama/Llama-3.1-8B-Instruct` model but doesn't require many changes
to run with a different model; you just need to update:~

```python
VERIFIER_MODEL_NAME_OR_PATH = "meta-llama/Llama-3.1-8B-Instruct"
HIDDEN_SIZE = 4096  # Must match the verifier model's hidden size
VERIFIER_VOCAB_SIZE = 128256  # Must match the verifier model's vocab size
```
**Update:** I've generalized the training script. It now has a required
CLI arg `--verifier_name_or_path` and supports arbitrary verifier
models. Note: this uses
`LlamaConfig.from_pretrained(args.verifier_name_or_path)` under the
hood, which does work for non-Llama models (e.g. a Qwen model) but
prints a warning and may not work for every type of verifier.
You will also need to pass in a dataset and `t2d` / `d2t` tensors that
correspond to the verifier you are using.
## Flex Attention
Files:
- `src/speculators/train/eagle3/attention.py`
- `tests/unit/train/test_eagle3_attention.py`
The training code uses flex attention, which provides substantial speedups
and memory savings over full dense attention operations.
Functions:
- `create_combined_mask_mod(lengths, total_seq_len)`: creates the mask function used by flex attention (see the sketch after this list).
- `extend_mask_for_draft_tokens(block_mask)`: helper function to extend the block mask without needing to check each new square's mask value.
- `block_mask_to_dense_attention_mask`: only used for debugging purposes.
- `flex_attention_forward`: lightweight wrapper around the flex attention call.
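As an illustration, here is a minimal sketch (not the repo's implementation) of how a combined causal + same-sequence mask can be expressed with PyTorch's flex attention API, assuming a recent torch (>= 2.5) and a 1-D tensor of per-sample lengths:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def make_combined_mask_mod(lengths: torch.Tensor):
    # Map each position of the packed batch to the sample it belongs to,
    # e.g. lengths=[3, 2] -> doc_ids=[0, 0, 0, 1, 1].
    doc_ids = torch.repeat_interleave(
        torch.arange(len(lengths), device=lengths.device), lengths
    )

    def mask_mod(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx                      # no attending to future tokens
        same_doc = doc_ids[q_idx] == doc_ids[kv_idx]  # no attending across samples
        return causal & same_doc

    return mask_mod


lengths = torch.tensor([5, 7, 4])
total_seq_len = int(lengths.sum())
block_mask = create_block_mask(
    make_combined_mask_mod(lengths),
    B=None, H=None, Q_LEN=total_seq_len, KV_LEN=total_seq_len, device="cpu",
)
# In practice this runs on GPU; shapes are [batch, heads, seq, head_dim].
q = k = v = torch.randn(1, 8, total_seq_len, 64)
out = flex_attention(q, k, v, block_mask=block_mask)
```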
## Data processing
<img width="4531" height="2384" alt="Eagle3 Data Flow"
src="https://github.com/user-attachments/assets/b972ef8c-92d4-4d46-969f-66d33f801ceb"
/>
Files:
- `src/speculators/train/data.py`
Data is currently expected in the format of one file per data sample. We
load these samples and perform a shift to align `input_ids`,
`hidden_states`, `loss_mask`, and `verifier_last_hidden_state` correctly.
We also automatically collate these samples into batches. Rather than
padding and wasting compute on padded tokens, we concatenate the
sequences along the sequence dimension, keeping track of the boundaries
between sequences and setting the attention mask accordingly.
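A minimal sketch of this padding-free collation (illustrative only; the actual collate function in `data.py` may differ): samples are concatenated along the sequence dimension and the per-sample lengths are kept so the attention mask can respect sequence boundaries.

```python
import torch


def concat_collate(samples: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    # Record where each sample ends so the block mask can forbid
    # attention across sequence boundaries.
    lengths = torch.tensor([s["input_ids"].shape[0] for s in samples])
    keys = ("input_ids", "hidden_states", "loss_mask", "verifier_last_hidden_state")
    batch = {key: torch.cat([s[key] for s in samples], dim=0) for key in keys}
    batch["lengths"] = lengths
    return batch
```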
## Batch sampling
Files:
- `src/speculators/train/distributed_batch_sampler.py`
- `src/speculators/train/data.py`
Due to hardware limitations, we set a maximum sequence length for each
batch. We would like each batch of data to be close in size to this max
length, so that each batch has a similar number of tokens. We achieve
this through the `MultipackDistributedBatchSamplerV2`, taken from prior
work I did on
[instructlab/training](https://github.com/instructlab/training). This
class produces indices of files that, when batched together, come close
to reaching the max length without exceeding it. It also does this in a
distributed-aware manner so that there is no overlap in the data each
rank sees.
To run the packing algorithm, we need to know the length of each sample
in the dataset. Unfortunately, this would require opening every file in
the dataset, which is expensive, so we instead approximate the lengths
(`_compute_approx_lengths` in `data.py`) using the length of the first
sample and the relative file sizes of the samples.
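A rough sketch of that approximation (hypothetical helper, not the actual `_compute_approx_lengths`): measure the first sample's token length, then scale it by each file's size relative to the first file.

```python
import os


def approx_lengths(file_paths: list[str], first_sample_len: int) -> list[int]:
    # Assume token count scales roughly linearly with on-disk file size.
    sizes = [os.path.getsize(path) for path in file_paths]
    ref_size = sizes[0]
    return [max(1, round(first_sample_len * size / ref_size)) for size in sizes]
```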
## `Eagle3DraftModel`
Files:
- `src/speculators/train/eagle3/core.py`
The draft model itself. Sets up and loads verifier components, as well
as the draft layers / weights. Contains the model `forward()` pass
which:
- sets up the block mask for the batch
- computes the target logits using the attached `verifier_lm_head`.
Note: this is computed here for data storage efficiency reasons;
otherwise we would need to save the full logits (`[seq_len, vocab_size]`)
to disk instead of the last-layer hidden states (`[seq_len, hidden_size]`).
The verifier `vocab_size` is often > 100k, whereas `hidden_size` might be
around 4-8k (see the back-of-envelope numbers after this list).
- for each TTT step:
  - embeds tokens
  - concatenates with `hidden_states`
  - applies decoder layers
  - computes logits
  - computes loss and step accuracy
  - prepares next-step tokens
- updates the block mask
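For concreteness, a back-of-envelope illustration of the storage argument above using the Llama-3.1-8B numbers from this PR (`vocab_size=128256`, `hidden_size=4096`); the 2048-token sample length and bf16 storage are assumptions made only for this example:

```python
# bf16 = 2 bytes per value; 2048-token sample assumed for illustration.
seq_len, vocab_size, hidden_size, bytes_per_value = 2048, 128256, 4096, 2
logits_bytes = seq_len * vocab_size * bytes_per_value   # ~525 MB per sample
hidden_bytes = seq_len * hidden_size * bytes_per_value  # ~16.8 MB per sample
print(logits_bytes / hidden_bytes)                      # ~31x less data written to disk
```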
## Layer definitions
Files:
- `src/speculators/train/eagle3/model_definitions.py`
Currently just contains model definitions for llama3-style draft models.
Supports `norm_before_residual=True` or `False`. Modifications to the
original llama model definitions were kept minimal.
## Distributed training via FSDP
Files:
- `src/speculators/train/utils.py`
- `src/speculators/train/checkpointer.py`
- `src/speculators/train/trainer.py` (`setup_model` fn)
Full support for FSDP training by launching the training script with
`torchrun --nnodes --nproc_per_node=N`, where `N` is the number of GPUs.
Tested with `N=2, 3, 4, 8` and all work. FSDP training also enables
Automatic Mixed Precision (AMP) for improved performance.
`checkpointer.py` contains checkpointing logic for FSDP distributed
model weights (gather all weights on rank 0 before saving).
Note: the way distributed training works in general is that `N` copies of
the script are started and all run the same code, but with environment
variables set that let each process know its rank. Then explicit
`dist.barrier()` calls, or implicit ones inside FSDP forward/backward
hooks, force each process to wait until they all reach the same point in
the code before continuing. It is important that all ranks reach these
operations, as it allows them to perform synchronized operations (such as
gathering, reducing, etc.). However, we can also limit certain code to
only one rank (rank 0) so that we only log once, or save a checkpoint
once, using simple `if local_rank == 0` statements.
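A minimal sketch of that pattern using generic `torch.distributed` calls (not the repo's trainer code; assumes a CUDA setup launched via `torchrun`):

```python
import os

import torch.distributed as dist


def main() -> None:
    # torchrun sets MASTER_ADDR/MASTER_PORT, RANK, WORLD_SIZE, and LOCAL_RANK.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])

    # ... training work performed by every rank ...

    dist.barrier()  # wait until every rank reaches this point
    if local_rank == 0:
        print("saving checkpoint on rank 0 only")
    dist.barrier()  # keep other ranks from racing ahead while rank 0 saves

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```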
## Logging
Files:
- `src/speculators/train/logger.py`
- `scripts/train.py` (logger setup calls at the start of `main()`)
- `src/speculators/train/trainer.py` and other files: usage of
`metric_logger` and `root_logger`
Another implementation mostly copied from prior work I did on
[instructlab/training](https://github.com/instructlab/training). This
uses Python's standard library `logging` module and extends it to support
training metric logging. We can log a nested dict of metrics anywhere in
the codebase like so:
```python
# Setup once
import logging

metric_logger = logging.getLogger("speculators.metrics")

# Log call
metric_logger.info(
    {"train": {"loss": loss.item(), **acc_values}, "epoch": epoch},
    extra={"step": self.global_step},
)
```
When running the training script, the user can select one (or more) of
`tensorboard`, `wandb`, and `trackio`, and the metrics will be logged to
the respective experiment tracker(s).
There is also a `root_logger` that can be used for regular status
logging; everything logged to either the `root_logger` or the
`metric_logger` is pretty-printed to the console.
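For intuition, here is a hypothetical illustration (not the actual `logger.py` implementation) of how a `logging.Handler` could intercept those metric dicts and forward them to an experiment tracker, using tensorboard's `SummaryWriter` as an example:

```python
import logging

from torch.utils.tensorboard import SummaryWriter


class TensorboardMetricHandler(logging.Handler):
    def __init__(self, log_dir: str = "runs") -> None:
        super().__init__()
        self.writer = SummaryWriter(log_dir)

    def emit(self, record: logging.LogRecord) -> None:
        if not isinstance(record.msg, dict):
            return
        step = getattr(record, "step", 0)  # populated via extra={"step": ...}

        # Flatten {"train": {"loss": ...}, "epoch": ...} into "train/loss", "epoch".
        def walk(prefix: str, obj) -> None:
            if isinstance(obj, dict):
                for key, val in obj.items():
                    walk(f"{prefix}/{key}" if prefix else key, val)
            else:
                self.writer.add_scalar(prefix, obj, global_step=step)

        walk("", record.msg)


logging.getLogger("speculators.metrics").addHandler(TensorboardMetricHandler())
```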
## `Trainer`
Files:
- `src/speculators/train/trainer.py`
The `Trainer` class is initialized with the model, data loaders, and a
config, and:
- sets up the model / optimizer (loads weights and configures distributed
training if needed)
- contains the training and validation loops (`train_epoch` and
`val_epoch`, respectively)
- runs the overall training loop, which alternates between training,
validation, and checkpoint saving (sketched below)
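A condensed sketch of that outer loop (hypothetical method names aside from `train_epoch` / `val_epoch`; the actual `Trainer` API may differ):

```python
class TrainerSketch:
    def run(self, num_epochs: int) -> None:
        for epoch in range(num_epochs):
            self.train_epoch(epoch)      # forward/backward + optimizer steps
            self.val_epoch(epoch)        # validation metrics, no gradient updates
            self.save_checkpoint(epoch)  # gather FSDP weights on rank 0 and save
```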
Todos:
- [x] Eagle3Draft Model definition with TTT steps and loss calculations
- [x] Patched Decoder layer definitions
- [x] Simple data loading from sample files
- [x] FlexAttention masking and implementation
- [x] Loss Masking
- [x] Training loop
- [x] Train data loader
- [x] `loss.backward()` + optimizer steps
- [x] Distributed loss reduction
- [x] Val data loader
- [x] Metric collection/reporting
- [x] Model checkpointing
- [x] Data batching
- [x] Collate fn
- [x] Batch sampler (dynamic batch size through sample packing)
- [x] Distributed (rank) aware sampling
- [x] Distributed support
- [ ] ~Code relocation / merging with existing definitions (Currently
just have everything under `speculators/train` but this will need to
change)~ FUTURE PR
- [x] Verify correctness of key components (attention masking, data
token alignment, etc).
- [x] General testing
Essential todos (as of 10/22/2025):
- [x] Save checkpoints to safetensors format w/ required config info
- [ ] ~Implement save best or save last logic (currently saving every
epoch)~ FUTURE PR
- [x] Better Verifier `lm_head`, `embed_tokens` loading (requires #144)
- [x] `Eagle3DraftModel.__init__` signature cleanup/better configuration
- [ ] ~Config/argparsing for `scripts/train.py`~ FUTURE PR
- [x] Ensure flex attention impl works with `torch==2.9` and
`torch.compile`
- [x] Fix lint / quality / type errors and pass CI
---------
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
Co-authored-by: Brian Dellabetta <[email protected]>
File tree: 17 files changed (+2788, −3) under `scripts`, `src/speculators/train`, `src/speculators/train/eagle3`, and `tests/unit/train`.