# @package _global_
defaults:
  - mode: ${mode:async} # Default to async mode, can be overridden by scripts
  - _self_
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled
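# Example invocations (a sketch; it assumes the grpo-sync.py / grpo-async.py
# scripts referenced in the train section below, which may also select the
# mode group themselves):
#   python grpo-sync.py mode=sync
#   python grpo-async.py mode=async train.exp_name=my-run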

# Environment configuration
env:
  dataset: gsm8k # choices: [gsm8k, ifeval]
  # Number of environments to run in parallel. This determines the batch size passed to vLLM.
  # More envs consume more GPU memory.
  num_envs: 8
  # Number of times to repeat the same prompt for GRPO. This does not affect GPU memory usage.
  repeats: 32
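  # Worked example under the values above: each generation call sends
  # num_envs = 8 distinct prompts to vLLM, and each prompt is sampled
  # repeats = 32 times, so one GRPO group holds 32 completions per prompt.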

# Base model configuration
model:
  # A 3B model is sufficient for this task:
  name: Qwen/Qwen2.5-3B
  compile: false

# Base training configuration - will be merged with mode-specific settings
train:
  # Some fields are defined in mode configs (async.yaml and sync.yaml)
  # The following fields are task-specific:
  exp_name: "grpo-gsm8k"

  # Whether to use mixed precision training.
  mixed_precision: true

  # Number of top-k rewards to select for training.
  topk_size: 8

  # Total number of dialog turns to collect during training.
  total_dialog_turns: 100_000

  # Number of steps in each batch. Higher values make the inference step slower, but do not use more GPU memory.
  steps_per_batch: 256
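  # Rough collection budget under these values (assuming one dialog turn
  # corresponds to one collected step):
  #   total_dialog_turns / steps_per_batch = 100_000 / 256 ≈ 390 batches per run.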

  # Replay buffer size. For a given prompt, we will query the LLM a total of `env.repeats` times.
  # Then, the top-k rewards will be selected from these `env.repeats` rewards.
  # A single batch collected has size `train.steps_per_batch`, and the fraction written to the replay buffer is `train.topk_size / env.repeats`.
  # If `buffer_size` is not set, it will default to `train.steps_per_batch * train.topk_size / env.repeats`.
  buffer_size:
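  # Worked example of the default above:
  #   steps_per_batch * topk_size / repeats = 256 * 8 / 32 = 64
  # transitions written to the buffer per collected batch.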

  # Number of gradient accumulation steps. Higher values use less GPU memory
  # (compared with a larger optim_batch_size and fewer accumulation steps),
  # but make the optimization step slower.
  gradient_accumulation_steps: 16
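  # Assuming optim_batch_size (below) is the per-micro-batch size, as the
  # memory note above suggests, the effective optimization batch is
  #   optim_batch_size * gradient_accumulation_steps = 2 * 16 = 32
  # samples per optimizer step, processed as 16 small forward/backward passes.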

  # Fields used by both scripts but with different semantics
  checkpoint_frequency: 100 # Save checkpoint every N steps/batches

  # Batch size for optimization. Higher values will use more GPU memory.
  optim_batch_size: 2

  # KL coefficient for the KL divergence to the reference policy
  kl_to_ref_coeff: 1e-1

  # Fields used only by grpo-async.py / grpo-sync.py
  logging_frequency: 10 # Log metrics every N steps

# Training model configuration
train_model:
  gradient_checkpointing: true # Enabled for memory efficiency
  num_devices: 1 # Number of devices to use
  lora:
    enabled: true # Using LoRA for memory efficiency
    r: 8 # LoRA rank - controls capacity of adaptations
    alpha: 16 # LoRA alpha - scales the adaptations
    dropout: 0.1 # Dropout probability for LoRA layers
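    # The adapter output is scaled by alpha / r = 16 / 8 = 2.0; raising r adds
    # capacity (and trainable parameters), and alpha must grow proportionally
    # to keep the same scaling.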
  quantization:
    enabled: false # Enable 4-bit quantization for base model
  attn_implementation: sdpa # PyTorch scaled dot-product attention for memory efficiency
  torch_dtype: bfloat16

# Inference model configuration
inference_model:
  num_devices: 1 # Number of devices to use
  quantization:
    enabled: false # Enable 4-bit quantization for base model
  attn_implementation: sdpa # PyTorch scaled dot-product attention for memory efficiency
  torch_dtype: bfloat16
  gpu_memory_utilization: 0.5 # Fraction of GPU memory vLLM may pre-allocate (weights + KV cache)
  temperature: 0.8 # Sampling temperature
  max_tokens: 1024 # Maximum new tokens generated per completion
  include_stop_str_in_output: true # Keep the stop string in the returned text
  enforce_eager: false # Allow vLLM CUDA graph capture (faster decoding, slower startup)

# Reference model configuration
ref_model:
  gradient_checkpointing: false # Always false, no backprop
  num_devices: 1 # Number of devices to use
  lora:
    enabled: true # Using LoRA for memory efficiency
    r: 8 # LoRA rank - controls capacity of adaptations
    alpha: 16 # LoRA alpha - scales the adaptations
    dropout: 0.1 # Dropout probability for LoRA layers
  quantization:
    enabled: false # Enable 4-bit quantization for base model
  attn_implementation: sdpa # PyTorch scaled dot-product attention for memory efficiency
  torch_dtype: bfloat16

# Optimizer configuration
optimizer:
  name: AdamW
  lr: 2e-5
  clip_grad_norm: 100.0
  weight_decay: 0.0
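# Sketch of how a launcher script might consume this block (the variable names
# are illustrative, not taken from the actual scripts):
#   optim_cls = getattr(torch.optim, cfg.optimizer.name)  # -> torch.optim.AdamW
#   optim = optim_cls(params, lr=cfg.optimizer.lr, weight_decay=cfg.optimizer.weight_decay)
#   torch.nn.utils.clip_grad_norm_(params, cfg.optimizer.clip_grad_norm)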

# Ray configuration
ray:
  init_config:
    num_cpus: 96 # Total available CPUs
    num_gpus: 8 # Explicitly set number of GPUs
    runtime_env:
      working_dir: "."
    _temp_dir: "/tmp/ray_grpo" # Custom temp directory
    _system_config:
      object_spilling_threshold: 0.8 # Spill when 80% full
      max_direct_memory_size: 10737418240 # 10 GiB in bytes (YAML does not evaluate arithmetic expressions)
      object_store_full_delay_ms: 100 # Delay when store is full
      object_store_full_max_retries: 3 # Max retries when store is full
  collector_config:
    num_cpus: 24 # CPUs for inference and ref model
  train_handler_config:
    num_cpus: 24 # Dedicated CPUs for training
  replay_buffer_config:
    num_cpus: 24 # CPUs for replay buffer
    num_gpus: 0.0 # No GPU needed for replay buffer
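# Note: the three handler groups above pin 24 + 24 + 24 = 72 of the 96 CPUs,
# leaving the remainder available for the driver and Ray's internal workers.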

# Logging configuration
logging:
  experiment_name: null # Will be auto-generated if not provided
  checkpoint_dir: "checkpoints"
  checkpoint_frequency: 10 # Save checkpoint every N batches

hydra:
  run:
    dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
  sweep:
    dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
    subdir: ${hydra.job.num}
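# Example resolved run dir (illustrative timestamp): outputs/2025-01-31/12-00-00;
# multirun sweeps additionally nest one numbered subdirectory per job.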