Skip to content

Commit 751f3a0

Browse files
authored
Offline data generation for eagle3 training (#157)
This PR includes the entire pipeline of data generation, the PR kinda grew organically and become gigantic. To help mitigating the stress of reviewing, here's a guide to facilitate the review process. (Hope this helps!) 🔴 means most important and contains key logic 🟡 means less important, usually user-facing entry points 🟢 means least important, supporting infrastructure Here are the details on each part organized based on the dependencies: ## preprocessing 🔴`src/speculators/data_generation/preprocessing.py` This file contains the main logic of preprocessing, including raw dataset loading and formatting, loss mask calculation, sample tokenization, cache key generation and the main function `load_and_preprocess_dataset`. 🟡`scripts/preprocess_data.py` This is a standalone simple script that preprocess raw conversational data and saves them in a cache file specified by users. It's optional, users have the option to run this script if they want to separate preprocessing from hidden states generation. The default is running `data_generation_offline.py` directly, which handles preprocessing, hidden states generation, caching and savings. 🟢 `src/speculators/data_generation/configs.py` Configuration registries, defines chat templates for different model formats and dataset loading configurations. ## Hidden States Generation 🔴 `src/speculators/data_generation/vllm_hidden_states_generator.py` This file contains the core logic for extracting hidden states from intermediate layers during prefill using vLLM. Key features include: - Auto-selection of layers following EAGLE3 pattern: [layer_2, layer_mid, layer_-3, layer_-1] (first 3 for fusion, last for target logits) - Multi-GPU tensor parallelism support via MultiprocExecutor - Batch processing with variable-length sequences - Returns dict with input_ids, hidden_state (concatenated layers), and loss_mask 🔴 `src/speculators/data_generation/custom_worker.py` Custom vLLM worker extension that hooks into the model's forward pass to capture hidden states from specified layers during prefill. This integrates with vLLM's v1 API architecture. 🟡 `scripts/data_generation_offline.py` Main end-to-end pipeline script that users should use. This script: - Automatically handles preprocessing (runs it if cache not found, loads from cache otherwise) - Uses vLLM to extract hidden states in batches - Saves each sample as a separate .pt file (to support variable-length sequences) - Generates data_config.json with metadata for reproducibility - Supports auto-resume by detecting existing files - Optimized for throughput with configurable batch_size and max_num_batched_tokens ## Vocabulary Mapping 🟡 `src/speculators/data_generation/vocab_mapping.py` Contains logic for building vocabulary mappings (d2t and t2d) between draft and target models when vocabularies differ. Maps draft tokens to target tokens based on frequency distribution. 🟡 `scripts/build_vocab_mapping.py` Standalone script to generate d2t.npy and t2d.npy files from token frequency distributions collected during preprocessing. ## Infrastructure & Utilities 🟢 `src/speculators/data_generation/logging_utils.pyCustom` logger with colored sections, subsections, and config display for better pipeline visibility. 🟢 `src/speculators/data_generation/__init__.py` Package exports. ## Tests 🟡 `tests/integration/data_generation/test_vllm_hidden_states.py` Focused tests for the vLLM hidden states generator component. ## Key Design Decisions 1. Single-file storage: Each training sample is saved as a separate .pt file to support variable-length sequences efficiently. Training code now handles batching nicely, but we're planning on moving batching to the data generation side to avoid the batch -> unbatch -> rebatch overhead. 2. 4-layer extraction, auto-selects 4 layers from the target model if layer ids are not specified 3. Automatic preprocessing: data_generation_offline.py handles preprocessing transparently via cache, eliminating manual steps and preventing parameter mismatches. 4. Metadata tracking: data_config.json saves all generation parameters for reproducibility. ## Usage Example ``` python scripts/data_generation_offline.py \ --target-model-path meta-llama/Llama-3.1-8B \ --train-data-path sharegpt \ --chat-template llama3 \ --output-dir ./training_data \ --max-samples 5000 \ --batch-size 16 ``` ``` # Generate vocabulary mappings (if needed) python scripts/build_vocab_mapping.py \ --token-freq-path ./cache/token_frequencies/xxx_token_freq.pt \ --draft-vocab-size 32000 \ --target-vocab-size 128256 \ --output-path ./training_data/ ``` More details can be found in `src/speculators/data_generation/README.md`. Please feel free to reach out with any questions. --------- Signed-off-by: shanjiaz <[email protected]> Signed-off-by: shanjiaz <[email protected]>
1 parent 6c3678f commit 751f3a0

17 files changed

+2455
-1
lines changed

pyproject.toml

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ dev = [
107107
"mkdocs-linkcheck~=1.0.6",
108108
]
109109

110+
datagen = ["vllm~=0.11.0"]
111+
110112
[project.entry-points.console_scripts]
111113
speculators = "speculators.__main__:app"
112114

@@ -136,7 +138,7 @@ exclude = ["venv", ".tox", "build", "dist"]
136138
follow_imports = 'silent'
137139

138140
[[tool.mypy.overrides]]
139-
module = ["datasets.*", "transformers.*", "setuptools.*", "setuptools_git_versioning.*"]
141+
module = ["datasets.*", "transformers.*", "setuptools.*", "setuptools_git_versioning.*", "vllm.*"]
140142
ignore_missing_imports=true
141143

142144
[tool.ruff]
@@ -235,6 +237,23 @@ select = [
235237
"BLE001", # allow catching Exception for conversion errors
236238
]
237239

240+
"src/speculators/data_generation/**/*.py" = [
241+
"S106", # false positives for chat template tokens
242+
"S324", # MD5 is used for cache keys, not security
243+
"FIX002", # TODOs are tracked
244+
"C901", # complexity is acceptable for data processing
245+
"BLE001", # catching Exception in __del__ is acceptable
246+
"S110", # try-except-pass in __del__ is acceptable
247+
"SIM105", # contextlib.suppress not needed for __del__
248+
"S101", # assert in worker is acceptable
249+
"SIM102", # nested if is clearer
250+
"ERA001", # comment is documentation, not code
251+
"PTH", # os.path is acceptable in data generation
252+
]
253+
"scripts/**/*.py" = [
254+
"PTH", # os.path is acceptable in scripts
255+
]
256+
238257
[tool.ruff.lint.isort]
239258
known-first-party = ["speculators", "tests"]
240259

scripts/DATAGEN.md

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
# EAGLE Data Generation Pipeline
2+
3+
This module provides a complete pipeline for generating training data for EAGLE-style speculative decoding models.
4+
5+
## Overview
6+
7+
The pipeline consists of two main stages:
8+
9+
1. **Preprocessing**: Tokenize conversational data and create loss masks
10+
2. **Hidden State Extraction**: Use vLLM to extract intermediate layer hidden states from a target model
11+
12+
## Quick Start
13+
14+
### Basic Usage
15+
16+
Generate training data from ShareGPT using Llama 3.1 8B:
17+
18+
```bash
19+
python scripts/data_generation_offline.py \
20+
--target-model-path meta-llama/Llama-3.1-8B \
21+
--train-data-path sharegpt \
22+
--output-dir ./training_data \
23+
--max-samples 5000
24+
```
25+
26+
The script automatically uses the tokenizer's built-in chat template via `apply_chat_template`.
27+
28+
### Advanced Usage
29+
30+
With custom settings and multi-GPU:
31+
32+
```bash
33+
python scripts/data_generation_offline.py \
34+
--target-model-path meta-llama/Llama-3.1-70B \
35+
--train-data-path ./my_data.jsonl \
36+
--seq-length 4096 \
37+
--cache-dir ./cache \
38+
--output-dir ./training_data \
39+
--layer-ids 2 28 54 \
40+
--tensor-parallel-size 4 \
41+
--batch-size 16 \
42+
--max-samples 10000
43+
```
44+
45+
## Architecture
46+
47+
### Core Components
48+
49+
#### 1. VllmHiddenStatesGenerator
50+
51+
Extracts hidden states from intermediate layers during prefill using vLLM's efficient engine.
52+
53+
```python
54+
from speculators.data_generation import VllmHiddenStatesGenerator
55+
56+
generator = VllmHiddenStatesGenerator(
57+
model_path="meta-llama/Llama-3.1-8B",
58+
layer_ids=[2, 14, 24], # or None for auto-select
59+
tensor_parallel_size=1,
60+
)
61+
62+
token_ids = [[1, 234, 567, 890]] # batch of sequences
63+
results = generator.generate(token_ids=token_ids)
64+
```
65+
66+
**Features:**
67+
68+
- Auto-selects layers using EAGLE3 pattern if not specified
69+
- Supports multi-GPU tensor parallelism
70+
- Prefill-only mode (no decode overhead)
71+
- Validates layer indices at initialization
72+
73+
#### 2. Preprocessing Pipeline
74+
75+
Tokenizes conversations and creates loss masks to identify trainable tokens.
76+
77+
```python
78+
from speculators.data_generation.preprocessing import load_and_preprocess_dataset
79+
80+
dataset, tokenizer = load_and_preprocess_dataset(
81+
target_model_path="meta-llama/Llama-3.1-8B",
82+
train_data_path="sharegpt",
83+
seq_length=2048,
84+
max_samples=1000,
85+
token_freq_path="./token_freq.pt",
86+
cache_dir="/path/to/cache", # Optional
87+
)
88+
```
89+
90+
**Features:**
91+
92+
- Uses tokenizer's built-in chat template via `apply_chat_template`
93+
- Automatically creates loss masks for assistant responses
94+
- ShareGPT format datasets
95+
- HuggingFace datasets (sharegpt, ultrachat)
96+
- Local JSON/JSONL files
97+
98+
#### 3. Custom Worker Extension
99+
100+
vLLM worker extension that captures hidden states during model forward pass.
101+
102+
**Features:**
103+
104+
- Minimal overhead - only captures target layers
105+
- TP rank 0 only (prevents duplicate captures)
106+
- Automatic batching across sequences
107+
108+
### Configuration
109+
110+
#### Dataset Configs
111+
112+
Built-in datasets in `configs.py`:
113+
114+
- `sharegpt` - ShareGPT Vicuna unfiltered
115+
- `ultrachat` - HuggingFace UltraChat 200k
116+
117+
Add custom datasets by extending `DATASET_CONFIGS`.
118+
119+
## Output Format
120+
121+
Each training sample is saved as a `.pt` file containing:
122+
123+
```python
124+
{
125+
'input_ids': torch.Tensor, # [seq_len]
126+
'hidden_state': torch.Tensor, # [seq_len, hidden_dim * num_layers]
127+
'loss_mask': torch.Tensor, # [seq_len] - 1 for trainable tokens
128+
}
129+
```
130+
131+
## Performance Optimization
132+
133+
### Memory Usage
134+
135+
The pipeline has a TODO to optimize KV cache allocation for prefill-only workloads:
136+
137+
```python
138+
# TODO at vllm_hidden_states_generator.py:133
139+
# Currently allocating based on available memory, but we only need minimal cache
140+
# since we abort after prefill. Could reduce to: min_blocks = (max_num_batched_tokens // block_size) + 1
141+
```
142+
143+
### Batch Size Tuning
144+
145+
- **Small models (7-8B)**: `--batch-size 16-32`
146+
- **Medium models (13-30B)**: `--batch-size 8-16`
147+
- **Large models (70B+)**: `--batch-size 4-8`
148+
149+
Adjust based on GPU memory and sequence length.
150+
151+
### Caching
152+
153+
Preprocessing is automatically cached by HuggingFace datasets using fingerprint-based cache invalidation. The cache automatically updates when:
154+
155+
- Tokenizer changes
156+
- Preprocessing parameters change (seq_length, etc.)
157+
- Dataset changes
158+
159+
**Cache Location:**
160+
161+
- Default: `~/.cache/huggingface/datasets`
162+
- Custom: Set `HF_DATASETS_CACHE` environment variable
163+
164+
```bash
165+
# Example: Use custom cache directory
166+
export HF_DATASETS_CACHE=/path/to/your/cache
167+
python scripts/data_generation_offline.py ...
168+
```
169+
170+
Or set it per-command:
171+
172+
```bash
173+
HF_DATASETS_CACHE=./my_cache python scripts/data_generation_offline.py ...
174+
```
175+
176+
## Module Structure
177+
178+
```
179+
data_generation/
180+
├── __init__.py # Exports VllmHiddenStatesGenerator
181+
├── vllm_hidden_states_generator.py # Main hidden states extraction
182+
├── custom_worker.py # vLLM worker extension
183+
├── preprocessing.py # Dataset preprocessing
184+
├── configs.py # Chat templates & dataset configs
185+
├── vocab_mapping.py # Vocabulary mapping utilities
186+
└── logging_utils.py # Clean logging utilities
187+
```
188+
189+
## API Reference
190+
191+
### VllmHiddenStatesGenerator
192+
193+
```python
194+
class VllmHiddenStatesGenerator:
195+
def __init__(
196+
self,
197+
model_path: str,
198+
layer_ids: List[int] = None, # Auto-select if None
199+
max_model_len: int = 2048,
200+
gpu_memory_utilization: float = 0.8,
201+
tensor_parallel_size: int = 1,
202+
)
203+
204+
def generate(
205+
self,
206+
token_ids: Union[List[int], List[List[int]], torch.Tensor]
207+
) -> List[Dict]
208+
```
209+
210+
### Preprocessing Functions
211+
212+
```python
213+
def load_and_preprocess_dataset(
214+
target_model_path: str,
215+
train_data_path: str,
216+
seq_length: int,
217+
build_dataset_num_proc: int = 8,
218+
seed: int = 0,
219+
max_samples: Optional[int] = None,
220+
token_freq_path: str = "./token_freq.pt",
221+
cache_dir: Optional[str] = None,
222+
) -> Tuple[HFDataset, PreTrainedTokenizer]
223+
224+
def build_eagle3_dataset(
225+
dataset: HFDataset,
226+
tokenizer: PreTrainedTokenizer,
227+
max_length: int = 2048,
228+
num_proc: int = 8,
229+
) -> HFDataset
230+
```
231+
232+
**Note:** Both functions now use the tokenizer's built-in chat template via `apply_chat_template`.
233+
234+
## Troubleshooting
235+
236+
### Common Issues
237+
238+
**Issue**: Out of memory during hidden state extraction
239+
240+
- Reduce `--batch-size`
241+
- Reduce `--seq-length`
242+
- Increase `--tensor-parallel-size`
243+
244+
**Issue**: Layer index out of bounds
245+
246+
- Check model's actual number of layers
247+
- Auto-selection uses: `[2, num_layers // 2, num_layers - 3]`
248+
249+
**Issue**: No assistant response spans found
250+
251+
- Ensure tokenizer has a chat template (supports `apply_chat_template`)
252+
- Check that conversations have assistant responses in correct format (role/content keys)
253+
254+
**Issue**: Cache invalidation
255+
256+
- Delete cache directory if changing preprocessing parameters
257+
- Ensure `--seed` matches between runs for reproducibility
258+
259+
## Development
260+
261+
### Adding a New Dataset
262+
263+
Edit `configs.py`:
264+
265+
```python
266+
DATASET_CONFIGS["my_dataset"] = DatasetConfig(
267+
name="my_dataset",
268+
hf_path="username/dataset-name",
269+
split="train",
270+
normalize_fn=_my_normalize_fn, # Optional
271+
)
272+
```

0 commit comments

Comments
 (0)