  dataset: ifeval # choices: [gsm8k, ifeval]
  # Number of environments to run in parallel. This determines the batch size passed to vLLM.
  # More envs consume more GPU memory.
-  num_envs: 2
+  num_envs: 4
  # Number of times to repeat the same prompt for GRPO. This does not affect the GPU memory usage.
  repeats: 16

# Base model configuration
model:
-  name: Qwen/Qwen2.5-3B
+  # A 7B model works well for this task.
+  name: Qwen/Qwen2.5-7B
  compile: false

# Base training configuration - will be merged with mode-specific settings
train:
-  # Fields defined in mode configs (async.yaml and sync.yaml)
-  # mixed_precision: true # Whether to use mixed precision training
-  # epochs: 1 # Number of training epochs
-  # steps_per_batch: 32 # Number of steps per batch
-  # total_dialog_turns: 1_000_000 # Total number of dialog turns to collect
-  # optim_batch_size: 2 # Batch size for optimization
-  # gradient_accumulation_steps: 1 # Number of gradient accumulation steps
-  # kl_coef_in_loss: true # Whether to include KL coefficient in loss
-  # sync: false # Default to async, will be overridden by mode configs
-  # buffer_size: 128 # Size of replay buffer
+  # Some fields are defined in mode configs (async.yaml and sync.yaml).
+  # The following fields are task-specific:
  exp_name: "grpo-ifeval"

+  # Whether to use mixed precision training.
+  mixed_precision: true
+
+  # Total number of dialog turns to collect during training.
+  total_dialog_turns: 100_000
+
+  # Number of steps in each batch. Higher values make the inference step slower but do not use more GPU memory.
+  steps_per_batch: 64
+
+  # Number of gradient accumulation steps. Higher values use less GPU memory (compared with larger batches and
+  # lower gradient_accumulation_steps) but make the optimization step slower.
+  gradient_accumulation_steps: 4
+
  # Fields used by both scripts but with different semantics
  checkpoint_frequency: 100 # Save checkpoint every N steps/batches

+  # Batch size for optimization. Higher values use more GPU memory.
+  optim_batch_size: 2
+
+  # Whether to include the KL coefficient in the loss function. Alternatively, the ref-to-train KL is added to the reward.
+  kl_coef_in_loss: false
+
  # KL coefficients for the KL divergence to the reference and inference policies
-  kl_to_ref_coeff: 1e-2
-  kl_to_inference_coeff: 0.0
+  kl_to_ref_coeff: 1e-1
+  kl_to_inference_coeff: 1e-1
  entropy_coeff: 0.01

  # Fields used only by grpo-async.py / grpo-sync.py
-  logging_frequency: 10 # Log metrics every N steps
+  logging_frequency: 1 # Log metrics every N steps - here, at each optimization step

# Training model configuration
train_model:
  gradient_checkpointing: true # Enabled for memory efficiency
-  num_devices: 1 # Number of devices to use
+  num_devices: 4 # Number of devices to use
  lora:
    enabled: true # Using LoRA for memory efficiency
    r: 8 # LoRA rank - controls capacity of adaptations
@@ -60,7 +72,7 @@ train_model:

# Inference model configuration
inference_model:
-  num_devices: 1 # Number of devices to use
+  num_devices: 2 # Number of devices to use
  quantization:
    enabled: false # Enable 4-bit quantization for base model
  attn_implementation: sdpa # Using flash attention for memory efficiency
@@ -74,7 +86,7 @@ inference_model:
# Reference model configuration
ref_model:
  gradient_checkpointing: false # Always false, no backprop
-  num_devices: 1 # Number of devices to use
+  num_devices: 2 # Number of devices to use
  lora:
    enabled: true # Using LoRA for memory efficiency
    r: 8 # LoRA rank - controls capacity of adaptations
@@ -89,7 +101,7 @@ ref_model:
optimizer:
  name: AdamW
  lr: 1e-5
-  clip_grad_norm: 1.0
+  clip_grad_norm: 10.0
  weight_decay: 0.0

# Ray configuration
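
For context, the sketch below shows one way the base config above could be combined with a mode-specific file (async.yaml or sync.yaml), as the "will be merged with mode-specific settings" comment describes. It assumes OmegaConf and hypothetical file paths; the actual grpo-sync.py / grpo-async.py scripts may perform this merge differently (e.g. via Hydra).

# Minimal sketch: load the base GRPO config and overlay a mode-specific file.
# OmegaConf usage is an assumption; file names are hypothetical.
from omegaconf import OmegaConf

base = OmegaConf.load("grpo_ifeval.yaml")    # the base config edited in this diff
mode = OmegaConf.load("async.yaml")          # mode-specific fields (sync, buffer_size, epochs, ...)
cfg = OmegaConf.merge(base, mode)            # later values override earlier ones

print(cfg.train.steps_per_batch)             # 64
print(cfg.train.gradient_accumulation_steps) # 4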