Commit 751f3a0
Offline data generation for eagle3 training (#157)
This PR includes the entire data generation pipeline. It grew organically
and became gigantic, so to help mitigate the stress of reviewing, here's a
guide to facilitate the review process. (Hope this helps!)
🔴 means most important and contains key logic
🟡 means less important, usually user-facing entry points
🟢 means least important, supporting infrastructure
Here are the details on each part organized based on the dependencies:
## preprocessing
🔴 `src/speculators/data_generation/preprocessing.py`
This file contains the main preprocessing logic: raw dataset loading and
formatting, loss mask calculation, sample tokenization, cache key
generation, and the main entry point `load_and_preprocess_dataset`.
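Cache key generation typically boils down to hashing the parameters that affect preprocessing output, so a cache hit is only possible when every relevant setting matches. A minimal sketch of that idea (the function name and parameter set here are hypothetical, not the actual implementation in `preprocessing.py`):

```python
import hashlib
import json

def make_cache_key(params: dict) -> str:
    """Derive a deterministic cache key from preprocessing parameters.

    Sketch only: hash a canonical JSON encoding of the parameters that
    influence preprocessing, so any change produces a different key.
    """
    # Sort keys so the same params always serialize to the same string.
    canonical = json.dumps(params, sort_keys=True)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]

key = make_cache_key({
    "dataset": "sharegpt",
    "chat_template": "llama3",
    "max_samples": 5000,
})
```

Because the encoding is canonical, the key is stable across runs and insensitive to dict ordering.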
🟡 `scripts/preprocess_data.py`
A simple standalone script that preprocesses raw conversational data and
saves it to a user-specified cache file. It's optional: users can run it
if they want to separate preprocessing from hidden states generation. The
default is to run `data_generation_offline.py` directly, which handles
preprocessing, hidden states generation, caching, and saving.
🟢 `src/speculators/data_generation/configs.py`
Configuration registries: defines chat templates for different model
formats and dataset loading configurations.
## Hidden States Generation
🔴 `src/speculators/data_generation/vllm_hidden_states_generator.py`
This file contains the core logic for extracting hidden states from
intermediate layers during prefill using vLLM. Key features include:
- Auto-selection of layers following EAGLE3 pattern: [layer_2,
layer_mid, layer_-3, layer_-1] (first 3 for fusion, last for target
logits)
- Multi-GPU tensor parallelism support via MultiprocExecutor
- Batch processing with variable-length sequences
- Returns dict with input_ids, hidden_state (concatenated layers), and
loss_mask
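The layer auto-selection pattern above is simple to express in code. A minimal sketch, assuming 0-based layer indexing and using a hypothetical helper name (not the actual function in `vllm_hidden_states_generator.py`):

```python
def auto_select_layers(num_layers: int) -> list[int]:
    """EAGLE3-style layer selection as described above: layer 2, the
    middle layer, the third-from-last, and the last layer. The first
    three feed the fusion, the last provides target logits."""
    return [2, num_layers // 2, num_layers - 3, num_layers - 1]

# For a 32-layer target model:
layers = auto_select_layers(32)
```

The negative positions (`layer_-3`, `layer_-1`) resolve to `num_layers - 3` and `num_layers - 1` here.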
🔴 `src/speculators/data_generation/custom_worker.py`
Custom vLLM worker extension that hooks into the model's forward pass to
capture hidden states from specified layers during prefill. This
integrates with vLLM's v1 API architecture.
🟡 `scripts/data_generation_offline.py`
Main end-to-end pipeline script that users should use. This script:
- Automatically handles preprocessing (runs it if cache not found, loads
from cache otherwise)
- Uses vLLM to extract hidden states in batches
- Saves each sample as a separate .pt file (to support variable-length
sequences)
- Generates data_config.json with metadata for reproducibility
- Supports auto-resume by detecting existing files
- Optimized for throughput with configurable batch_size and
max_num_batched_tokens
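Per-sample output files make auto-resume straightforward: a sample is only generated if its file does not already exist. A dependency-free sketch of that pattern (the real pipeline writes `.pt` files with `torch.save`; `pickle` stands in here, and the helper name and file naming are hypothetical):

```python
import pickle
import tempfile
from pathlib import Path

def save_samples(samples, output_dir: str) -> int:
    """Write each sample to its own file, skipping files that already
    exist so an interrupted run can be restarted where it left off.
    Returns the number of newly written files."""
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    written = 0
    for idx, sample in enumerate(samples):
        path = out / f"sample_{idx}.pt"
        if path.exists():  # auto-resume: this sample was already generated
            continue
        with path.open("wb") as f:
            pickle.dump(sample, f)
        written += 1
    return written

out_dir = tempfile.mkdtemp()
first = save_samples([{"ids": [1]}, {"ids": [2]}], out_dir)
second = save_samples([{"ids": [1]}, {"ids": [2]}], out_dir)  # all skipped
```

On the second call nothing is rewritten, which is what lets a crashed run pick up without redoing completed work.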
## Vocabulary Mapping
🟡 `src/speculators/data_generation/vocab_mapping.py`
Contains logic for building vocabulary mappings (d2t and t2d) between
draft and target models when vocabularies differ. Maps draft tokens to
target tokens based on frequency distribution.
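A frequency-based mapping like this usually keeps the most frequent target tokens as the draft vocabulary. A sketch of one plausible layout for the two arrays (hypothetical; the actual `d2t`/`t2d` construction lives in `vocab_mapping.py`):

```python
def build_vocab_mapping(token_freq: list[int], draft_vocab_size: int):
    """Keep the draft_vocab_size most frequent target tokens as the
    draft vocab. d2t[i] is the target id of draft token i; t2d[j] is
    True if target token j is in the draft vocab."""
    # Rank target ids by descending frequency, breaking ties by id.
    ranked = sorted(range(len(token_freq)), key=lambda t: (-token_freq[t], t))
    d2t = ranked[:draft_vocab_size]
    t2d = [False] * len(token_freq)
    for tid in d2t:
        t2d[tid] = True
    return d2t, t2d

# Toy target vocab of 4 tokens, draft vocab of 2:
d2t, t2d = build_vocab_mapping([5, 1, 9, 3], 2)
```

In practice these arrays would be saved as `d2t.npy` and `t2d.npy`, as the build script below does.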
🟡 `scripts/build_vocab_mapping.py`
Standalone script to generate d2t.npy and t2d.npy files from token
frequency distributions collected during preprocessing.
## Infrastructure & Utilities
🟢 `src/speculators/data_generation/logging_utils.py`
Custom logger with colored sections, subsections, and config display for
better pipeline visibility.
🟢 `src/speculators/data_generation/__init__.py`
Package exports.
## Tests
🟡 `tests/integration/data_generation/test_vllm_hidden_states.py`
Focused tests for the vLLM hidden states generator component.
## Key Design Decisions
1. Per-sample storage: each training sample is saved as a separate .pt
file to support variable-length sequences efficiently. Training code now
handles batching nicely, but we're planning on moving batching to the
data generation side to avoid the batch -> unbatch -> rebatch overhead.
2. 4-layer extraction: automatically selects 4 layers from the target
model when layer ids are not specified.
3. Automatic preprocessing: data_generation_offline.py handles
preprocessing transparently via cache, eliminating manual steps and
preventing parameter mismatches.
4. Metadata tracking: data_config.json saves all generation parameters
for reproducibility.
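The metadata file is essentially a JSON record of the generation parameters. A sketch of what writing and reloading it might look like (the field names here are illustrative, not the actual `data_config.json` schema; a temporary directory stands in for the output dir):

```python
import json
import tempfile
from pathlib import Path

# Illustrative reproducibility metadata for a generation run.
config = {
    "target_model_path": "meta-llama/Llama-3.1-8B",
    "train_data_path": "sharegpt",
    "chat_template": "llama3",
    "max_samples": 5000,
    "batch_size": 16,
}

out_dir = Path(tempfile.mkdtemp())
(out_dir / "data_config.json").write_text(json.dumps(config, indent=2))

# A later run can reload it to verify parameters match before resuming.
loaded = json.loads((out_dir / "data_config.json").read_text())
```

Checking the reloaded config against the current run's parameters is what prevents silently mixing data generated under different settings.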
## Usage Example
```
python scripts/data_generation_offline.py \
--target-model-path meta-llama/Llama-3.1-8B \
--train-data-path sharegpt \
--chat-template llama3 \
--output-dir ./training_data \
--max-samples 5000 \
--batch-size 16
```
```
# Generate vocabulary mappings (if needed)
python scripts/build_vocab_mapping.py \
--token-freq-path ./cache/token_frequencies/xxx_token_freq.pt \
--draft-vocab-size 32000 \
--target-vocab-size 128256 \
--output-path ./training_data/
```
More details can be found in
`src/speculators/data_generation/README.md`. Please feel free to reach
out with any questions.
---------
Signed-off-by: shanjiaz <[email protected]>
File tree: 17 files changed, +2455 −1 lines
- scripts
- src/speculators/data_generation
- tests
  - datagen
  - integration