
Commit c59ec6d

docs: Expand training customization examples
Resolves #4379

- Add custom callbacks example for logging and monitoring
- Add custom evaluation metrics example
- Add mixed precision training example (bf16/fp16)
- Add gradient accumulation example
- Add custom data collator example
- Update introduction for better clarity
1 parent 41c8ca1 commit c59ec6d

1 file changed: docs/source/customization.md (+150, -1)
@@ -1,6 +1,6 @@
# Training customization

-TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques. Note: Although these examples use the DPOTrainer, the customization applies to most (if not all) trainers.
+TRL is designed with modularity in mind so that users can efficiently customize the training loop for their needs. Below are examples of how you can apply and test different techniques. Note: Although these examples use the [`DPOTrainer`], these customization methods apply to most (if not all) trainers in TRL.

## Use different optimizers and schedulers

@@ -117,3 +117,152 @@ When training large models, you should better handle the accelerator cache by it

```python
training_args = DPOConfig(..., optimize_device_cache=True)
```

## Add custom callbacks

You can customize the training loop by adding callbacks for logging, monitoring, or early stopping. Callbacks allow you to execute custom code at specific points during training.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from trl import DPOConfig, DPOTrainer


class CustomLoggingCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            print(f"Step {state.global_step}: {logs}")


model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    callbacks=[CustomLoggingCallback()],
)
trainer.train()
```
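
The paragraph above also mentions early stopping. A minimal sketch (not part of the original example) using the built-in [`EarlyStoppingCallback`] from transformers; it assumes a held-out evaluation split (`eval_dataset`) and requires periodic evaluation, checkpointing, and `load_best_model_at_end=True` so that there is a metric to monitor:

```python
from transformers import EarlyStoppingCallback

# Early stopping needs periodic evaluation, checkpointing, and a metric to monitor
training_args = DPOConfig(
    output_dir="Qwen2.5-0.5B-DPO",
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=eval_dataset,  # assumption: a held-out split, e.g. split="test"
    processing_class=tokenizer,
    # Stop if eval_loss has not improved for 3 consecutive evaluations
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
trainer.train()
```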

## Add custom evaluation metrics

You can define custom evaluation metrics to track during training. This is useful for monitoring model performance on specific tasks.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer


def compute_metrics(eval_preds):
    # Custom metric computation
    logits, labels = eval_preds
    # Add your metric computation here
    return {"custom_metric": 0.0}


model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
eval_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="test[:10%]")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO", eval_strategy="steps", eval_steps=100)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
```
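
The placeholder above just returns a constant. As a purely illustrative sketch (the exact contents of `eval_preds` depend on the trainer and model outputs, so treat this as a template rather than a DPO-specific metric), a token-level accuracy assuming the `(predictions, label_ids)` layout unpacked above could look like:

```python
import numpy as np


def compute_metrics(eval_preds):
    # Assumes eval_preds unpacks to (predictions, label_ids); verify this for your trainer
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    # Ignore padded positions, conventionally labeled -100
    mask = labels != -100
    accuracy = (predictions == labels)[mask].mean()
    return {"accuracy": float(accuracy)}
```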

## Use mixed precision training

Mixed precision training can significantly speed up training and reduce memory usage. You can enable it by setting `bf16=True` or `fp16=True` in the training config.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

# Use bfloat16 precision (recommended for modern GPUs)
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO", bf16=True)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
```

Note: Use `bf16=True` for Ampere GPUs (A100, RTX 30xx) or newer, and `fp16=True` for older GPUs.
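
If you are unsure which flag applies to your hardware, you can choose it at runtime. A small sketch, assuming PyTorch with a CUDA device available:

```python
import torch

# Prefer bfloat16 where the hardware supports it, otherwise fall back to float16
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO", bf16=use_bf16, fp16=not use_bf16)
```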

## Use gradient accumulation

When training with limited GPU memory, gradient accumulation allows you to simulate larger batch sizes by accumulating gradients over multiple steps before updating weights.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

# Simulate an effective batch size of 32 per device with per_device_train_batch_size=4
# and gradient_accumulation_steps=8; with multiple GPUs the global batch size also
# scales with the number of devices
training_args = DPOConfig(
    output_dir="Qwen2.5-0.5B-DPO",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
```

## Use a custom data collator

You can provide a custom data collator to handle special data preprocessing or padding strategies.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
from trl.trainer.dpo_trainer import DataCollatorForPreference

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

# Create a custom data collator with a specific padding token
data_collator = DataCollatorForPreference(pad_token_id=tokenizer.pad_token_id)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    data_collator=data_collator,
)
trainer.train()
```
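
Beyond changing the padding token, the data collator can be any callable that maps a list of examples to a batch. A minimal sketch (assuming the batch layout produced by [`DataCollatorForPreference`]) that delegates padding to the default collator and logs the padded tensor shapes:

```python
base_collator = DataCollatorForPreference(pad_token_id=tokenizer.pad_token_id)


def inspecting_collator(examples):
    # Delegate padding to the default collator, then log the padded tensor shapes
    batch = base_collator(examples)
    shapes = {key: tuple(value.shape) for key, value in batch.items() if hasattr(value, "shape")}
    print(f"Collated batch shapes: {shapes}")
    return batch


trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    data_collator=inspecting_collator,
)
trainer.train()
```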
