
Commit c7e34fa

[Algorithm] Expert Iteration and SFT
1 parent e36d562 commit c7e34fa

File tree

20 files changed: +2523 −27 lines

docs/source/reference/llms.rst

Lines changed: 24 additions & 0 deletions
@@ -256,6 +256,9 @@ LLM post training require some appropriate versions of the losses implemented in
 GRPO
 ~~~~
 
+The :class:`~torchrl.objectives.llm.GRPOLoss` class is a thin wrapper around the :class:`~torchrl.objectives.PPOLoss` class
+that implements the LLM-specific functionality.
+
 .. currentmodule:: torchrl.objectives.llm
 
 .. autosummary::
@@ -265,3 +268,24 @@ GRPO
     GRPOLoss
     GRPOLossOutput
     MCAdvantage
+
+
+SFT
+~~~
+
+.. currentmodule:: torchrl.objectives.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    SFTLoss
+    SFTLossOutput
+
+.. currentmodule:: torchrl.data.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    TopKRewardSelector
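The `TopKRewardSelector` entry added above corresponds to the top-k filtering described in the config comments below: each prompt is sampled `env.repeats` times and only the `train.topk_size` highest-reward completions are written to the replay buffer. A minimal, illustrative sketch of that selection idea in plain PyTorch follows; it is not the actual class API, and `rewards`, `completions`, and `select_topk` are placeholder names.

# Illustrative sketch only (not the TopKRewardSelector API): keep the
# `topk_size` highest-reward completions out of the `repeats` completions
# generated for a single prompt.
import torch

def select_topk(rewards: torch.Tensor, completions: list[str], topk_size: int) -> list[str]:
    # rewards: shape [repeats], one scalar reward per completion of the same prompt
    topk_idx = torch.topk(rewards, k=topk_size).indices
    return [completions[i] for i in topk_idx.tolist()]

rewards = torch.tensor([0.1, 0.9, 0.4, 0.7])
completions = ["a", "b", "c", "d"]
print(select_topk(rewards, completions, topk_size=2))  # ['b', 'd']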
Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
# @package _global_
defaults:
  - mode: ${mode:async} # Default to async mode, can be overridden by scripts
  - _self_
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

# Environment configuration
env:
  dataset: gsm8k # choices: [gsm8k, ifeval]
  # Number of environments to run in parallel. This determines the batch size passed to vLLM.
  # More envs consume more GPU memory.
  num_envs: 8
  # Number of times to repeat the same prompt for GRPO. This does not affect the GPU memory usage.
  repeats: 32

# Base model configuration
model:
  # A 3B model is sufficient for this task:
  name: Qwen/Qwen2.5-3B
  compile: false

# Base training configuration - will be merged with mode-specific settings
train:
  # Some fields are defined in mode configs (async.yaml and sync.yaml)
  # The following fields are task-specific:
  exp_name: "grpo-gsm8k"

  # Whether to use mixed precision training.
  mixed_precision: true

  # Number of top-k rewards to select for training.
  topk_size: 8

  # Total number of dialog turns to collect during training.
  total_dialog_turns: 100_000

  # Number of steps in each batch. Higher values make the inference step slower, but do not use more GPU memory.
  steps_per_batch: 256

  # Replay buffer size. For a given prompt, the LLM is queried a total of `env.repeats` times,
  # and the top-k rewards are then selected from these `env.repeats` rewards.
  # A single collected batch has size `train.steps_per_batch`, and the fraction written to the replay buffer is `train.topk_size / env.repeats`.
  # If `buffer_size` is not set, it defaults to `train.steps_per_batch * train.topk_size / env.repeats`.
  buffer_size:

  # Number of gradient accumulation steps. Higher values use less GPU memory (compared with larger batches and fewer accumulation steps),
  # but make the optimization step slower.
  gradient_accumulation_steps: 16

  # Fields used by both scripts but with different semantics
  checkpoint_frequency: 100 # Save checkpoint every N steps/batches

  # Batch size for optimization. Higher values will use more GPU memory.
  optim_batch_size: 2

  # KL coefficients for the KL divergence to the reference and inference policies
  kl_to_ref_coeff: 1e-1

  # Fields used only by grpo-async.py / grpo-sync.py
  logging_frequency: 10 # Log metrics every N steps

# Training model configuration
train_model:
  gradient_checkpointing: true # Enabled for memory efficiency
  num_devices: 1 # Number of devices to use
  lora:
    enabled: true # Using LoRA for memory efficiency
    r: 8 # LoRA rank - controls capacity of adaptations
    alpha: 16 # LoRA alpha - scales the adaptations
    dropout: 0.1 # Dropout probability for LoRA layers
  quantization:
    enabled: false # Enable 4-bit quantization for base model
  attn_implementation: sdpa # Scaled dot-product attention for memory efficiency
  torch_dtype: bfloat16

# Inference model configuration
inference_model:
  num_devices: 1 # Number of devices to use
  quantization:
    enabled: false # Enable 4-bit quantization for base model
  attn_implementation: sdpa # Scaled dot-product attention for memory efficiency
  torch_dtype: bfloat16
  gpu_memory_utilization: 0.5 # Limit GPU memory usage
  temperature: 0.8
  max_tokens: 1024
  include_stop_str_in_output: true
  enforce_eager: false

# Reference model configuration
ref_model:
  gradient_checkpointing: false # Always false, no backprop
  num_devices: 1 # Number of devices to use
  lora:
    enabled: true # Using LoRA for memory efficiency
    r: 8 # LoRA rank - controls capacity of adaptations
    alpha: 16 # LoRA alpha - scales the adaptations
    dropout: 0.1 # Dropout probability for LoRA layers
  quantization:
    enabled: false # Enable 4-bit quantization for base model
  attn_implementation: sdpa # Scaled dot-product attention for memory efficiency
  torch_dtype: bfloat16

# Optimizer configuration
optimizer:
  name: AdamW
  lr: 2e-5
  clip_grad_norm: 100.0
  weight_decay: 0.0

# Ray configuration
ray:
  init_config:
    num_cpus: 96 # Total available CPUs
    num_gpus: 8 # Explicitly set number of GPUs
    runtime_env:
      working_dir: "."
    _temp_dir: "/tmp/ray_grpo" # Custom temp directory
    _system_config:
      object_spilling_threshold: 0.8 # Spill when 80% full
      max_direct_memory_size: 10737418240 # 10 GB limit (literal bytes; YAML does not evaluate arithmetic)
      object_store_full_delay_ms: 100 # Delay when store is full
      object_store_full_max_retries: 3 # Max retries when store is full
  collector_config:
    num_cpus: 24 # CPUs for inference and ref model
  train_handler_config:
    num_cpus: 24 # Dedicated CPUs for training
  replay_buffer_config:
    num_cpus: 24 # CPUs for replay buffer
    num_gpus: 0.0 # No GPU needed for replay buffer

# Logging configuration
logging:
  experiment_name: null # Will be auto-generated if not provided
  checkpoint_dir: "checkpoints"
  checkpoint_frequency: 10 # Save checkpoint every N batches

hydra:
  run:
    dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
  sweep:
    dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
    subdir: ${hydra.job.num}
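A quick check of the replay-buffer sizing documented in the comments of this config, using its own values; the arithmetic mirrors the stated default `train.steps_per_batch * train.topk_size / env.repeats`.

# Worked example of the documented buffer_size default (values from the config above).
steps_per_batch = 256
topk_size = 8
repeats = 32

write_fraction = topk_size / repeats                           # 8 / 32 = 0.25 of each collected batch is kept
default_buffer_size = steps_per_batch * topk_size // repeats   # 256 * 8 / 32 = 64
print(write_fraction, default_buffer_size)                     # 0.25 64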
Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
# @package _global_
defaults:
  - mode: ${mode:async} # Default to async mode, can be overridden by scripts
  - _self_
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

# Environment configuration
env:
  dataset: ifeval # choices: [gsm8k, ifeval]
  # Number of environments to run in parallel. This determines the batch size passed to vLLM.
  # More envs consume more GPU memory.
  num_envs: 4
  # Number of times to repeat the same prompt for GRPO. This does not affect the GPU memory usage.
  repeats: 32

# Base model configuration
model:
  # A 7B model works well for this task.
  name: Qwen/Qwen2.5-7b
  compile: false

# Base training configuration - will be merged with mode-specific settings
train:
  # Some fields are defined in mode configs (async.yaml and sync.yaml)
  # The following fields are task-specific:
  exp_name: "grpo-ifeval"

  # Whether to use mixed precision training.
  mixed_precision: true

  # Number of top-k rewards to select for training.
  topk_size: 8

  # Total number of dialog turns to collect during training.
  total_dialog_turns: 100_000

  # Number of steps in each batch. Higher values make the inference step slower, but do not use more GPU memory.
  steps_per_batch: 64

  # Number of gradient accumulation steps. Higher values use less GPU memory (compared with larger batches and fewer accumulation steps),
  # but make the optimization step slower.
  gradient_accumulation_steps: 4

  # Fields used by both scripts but with different semantics
  checkpoint_frequency: 100 # Save checkpoint every N steps/batches

  # Batch size for optimization. Higher values will use more GPU memory.
  optim_batch_size: 2

  # KL coefficients for the KL divergence to the reference and inference policies
  kl_to_ref_coeff:

  # Fields used only by grpo-async.py / grpo-sync.py
  logging_frequency: 1 # Log metrics every N steps - here at each optimization step

# Training model configuration
train_model:
  gradient_checkpointing: true # Enabled for memory efficiency
  num_devices: 4 # Number of devices to use
  lora:
    enabled: true # Using LoRA for memory efficiency
    r: 8 # LoRA rank - controls capacity of adaptations
    alpha: 16 # LoRA alpha - scales the adaptations
    dropout: 0.1 # Dropout probability for LoRA layers
  quantization:
    enabled: false # Enable 4-bit quantization for base model
  attn_implementation: sdpa # Scaled dot-product attention for memory efficiency
  torch_dtype: bfloat16

# Inference model configuration
inference_model:
  num_devices: 2 # Number of devices to use
  quantization:
    enabled: false # Enable 4-bit quantization for base model
  attn_implementation: sdpa # Scaled dot-product attention for memory efficiency
  torch_dtype: bfloat16
  gpu_memory_utilization: 0.5 # Limit GPU memory usage
  temperature: 0.8
  max_tokens: 2048
  include_stop_str_in_output: true
  enforce_eager: false

# Reference model configuration
ref_model:
  gradient_checkpointing: false # Always false, no backprop
  num_devices: 2 # Number of devices to use
  lora:
    enabled: true # Using LoRA for memory efficiency
    r: 8 # LoRA rank - controls capacity of adaptations
    alpha: 16 # LoRA alpha - scales the adaptations
    dropout: 0.1 # Dropout probability for LoRA layers
  quantization:
    enabled: false # Enable 4-bit quantization for base model
  attn_implementation: sdpa # Scaled dot-product attention for memory efficiency
  torch_dtype: bfloat16

# Optimizer configuration
optimizer:
  name: AdamW
  lr: 1e-5
  clip_grad_norm: 10.0
  weight_decay: 0.0

# Ray configuration
ray:
  init_config:
    num_cpus: 96 # Total available CPUs
    num_gpus: 8 # Explicitly set number of GPUs
    runtime_env:
      working_dir: "."
    _temp_dir: "/tmp/ray_grpo" # Custom temp directory
    _system_config:
      object_spilling_threshold: 0.8 # Spill when 80% full
      max_direct_memory_size: 10737418240 # 10 GB limit (literal bytes; YAML does not evaluate arithmetic)
      object_store_full_delay_ms: 100 # Delay when store is full
      object_store_full_max_retries: 3 # Max retries when store is full
  collector_config:
    num_cpus: 24 # CPUs for inference and ref model (co-located)
  train_handler_config:
    num_cpus: 24 # Dedicated CPUs for training
  replay_buffer_config:
    num_cpus: 24 # CPUs for replay buffer
    num_gpus: 0.0 # No GPU needed for replay buffer

# Logging configuration
logging:
  experiment_name: null # Will be auto-generated if not provided
  checkpoint_dir: "checkpoints"
  checkpoint_frequency: 10 # Save checkpoint every N batches

hydra:
  run:
    dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
  sweep:
    dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
    subdir: ${hydra.job.num}
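For reference, a hypothetical sketch of how a config like the one above can be loaded and overridden with OmegaConf, the library Hydra builds on. The file name "grpo_ifeval.yaml" is an assumption; the actual training scripts resolve the config through Hydra's `defaults` list and command-line overrides.

# Hypothetical loading sketch - not the project's entry point. Assumes the
# YAML above is saved as "grpo_ifeval.yaml" next to this script.
from omegaconf import OmegaConf

cfg = OmegaConf.load("grpo_ifeval.yaml")
# Dotlist overrides mimic Hydra-style command-line overrides.
overrides = OmegaConf.from_dotlist(["train.topk_size=4", "model.compile=true"])
cfg = OmegaConf.merge(cfg, overrides)
print(cfg.env.dataset, cfg.train.topk_size)  # ifeval 4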
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# @package _global_
train:
  # Mode-specific setting
  sync: false # Force asynchronous mode

  # Number of epochs to train for each time a batch is collected. Not directly used in async mode, aside from computing the total number of steps.
  epochs: 1
  # Replay buffer size.
  buffer_size: 128
  # Update policy weights every N steps - can be set to any positive integer in async mode
  weight_update_frequency: 10
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
# @package _global_
train:
  # Mode-specific setting
  sync: true # Force synchronous mode

  # Number of epochs to train for, every time a batch is collected.
  epochs: 1
  # Update policy weights every N steps - must be left empty in sync mode
  weight_update_frequency:
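The two mode files above encode a constraint stated in their comments: `weight_update_frequency` can be any positive integer in async mode but must be left empty in sync mode. A small illustrative check of that rule, not taken from the training scripts:

# Illustrative validation of the mode constraint described in the comments above.
def check_mode(train_cfg: dict) -> None:
    wuf = train_cfg.get("weight_update_frequency")
    if train_cfg["sync"]:
        assert wuf is None, "weight_update_frequency must be left empty in sync mode"
    else:
        assert isinstance(wuf, int) and wuf > 0, "async mode expects a positive integer"

check_mode({"sync": True, "weight_update_frequency": None})  # ok (sync mode)
check_mode({"sync": False, "weight_update_frequency": 10})   # ok (async mode)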
