
GRPO Unstable grad_norm after >10k steps #2980

Closed · 5 tasks done · nopepper opened this issue on Feb 27, 2025 · 4 comments · Fixed by #2992
Labels: 🐛 bug (Something isn't working), 🏋 GRPO (Related to GRPO)

Comments

nopepper (Contributor) commented on Feb 27, 2025

I'm not sure if this is a bug per se; it's possible that I'm missing something here. If this is a known, fixable issue, it would be great to note it in the GRPOTrainer docs.

Reproduction

It seems like grad_norm eventually becomes unstable and spikes higher and higher after more than 10k steps of GRPO training.

I've tried both single-GPU and multi-GPU training with various values of beta, batch size, and learning rate. I have not tested whether this problem happens without vLLM, though.

Oddly enough, this does not affect the reward, which keeps going up; only the grad_norm and clip_ratio metrics are affected.

Training args (the relevant ones):

warmup_steps = 100
num_train_epochs = 1

per_device_train_batch_size = 16

bf16 = true
optim = "adamw_bnb_8bit"
learning_rate = 3e-6

gradient_checkpointing = false

# --- New fields required by GRPOConfig ---
seed = 42
gradient_accumulation_steps = 4
max_prompt_length = 256
max_completion_length = 128
num_generations = 16
temperature = 1.0
num_iterations = 4
beta = 0.0
use_vllm = true
vllm_gpu_memory_utilization = 0.8
vllm_max_model_len = 4096
vllm_dtype = "bfloat16"
vllm_device = "auto"
vllm_enable_prefix_caching = true
epsilon = 0.2
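
For completeness, here is roughly how these settings translate into code. This is a sketch: the output directory, model name, reward function, and dataset are placeholders, not my actual setup.

```python
# Sketch of the settings above expressed as a GRPOConfig/GRPOTrainer call.
# "grpo-run", "my-model", my_reward_fn, and my_dataset are hypothetical placeholders.
from trl import GRPOConfig, GRPOTrainer

config = GRPOConfig(
    output_dir="grpo-run",            # hypothetical output path
    warmup_steps=100,
    num_train_epochs=1,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    learning_rate=3e-6,
    bf16=True,
    optim="adamw_bnb_8bit",
    gradient_checkpointing=False,
    seed=42,
    max_prompt_length=256,
    max_completion_length=128,
    num_generations=16,
    temperature=1.0,
    num_iterations=4,
    beta=0.0,                         # KL coefficient; 0.0 skips loading the reference model
    epsilon=0.2,
    use_vllm=True,
    vllm_gpu_memory_utilization=0.8,
)

# trainer = GRPOTrainer(model="my-model", reward_funcs=my_reward_fn,
#                       args=config, train_dataset=my_dataset)
# trainer.train()
```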

[Screenshots: training curves showing grad_norm and clip_ratio spiking after ~10k steps while the reward keeps increasing]

System Info

  • Platform: Linux-5.10.0-33-cloud-amd64-x86_64-with-glibc2.31
  • Python version: 3.11.9
  • TRL version: 0.16.0.dev0
  • PyTorch version: 2.5.1+cu121
  • CUDA device(s): NVIDIA A100-SXM4-80GB
  • Transformers version: 4.49.0
  • Accelerate version: 1.4.0
  • Accelerate config:
    • compute_environment: LOCAL_MACHINE
    • distributed_type: NO
    • mixed_precision: bf16
    • use_cpu: False
    • debug: False
    • num_processes: 1
    • machine_rank: 0
    • num_machines: 1
    • gpu_ids: all
    • rdzv_backend: static
    • same_network: True
    • main_training_function: main
    • enable_cpu_affinity: True
    • downcast_bf16: no
    • tpu_use_cluster: False
    • tpu_use_sudo: False
    • tpu_env: []
  • Datasets version: 3.3.2
  • HF Hub version: 0.29.1
  • bitsandbytes version: 0.45.3
  • DeepSpeed version: 0.16.4
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: 1.64.0
  • PEFT version: 0.14.0
  • vLLM version: 0.7.2

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete
The github-actions bot added the 🏋 GRPO (Related to GRPO) and 🐛 bug (Something isn't working) labels on Feb 27, 2025.
nopepper (Contributor, Author) commented

I was wrong; I had `beta=0.0` in all my experiments. Setting `beta=0.001` was enough to prevent the gradient explosion. Perhaps we shouldn't suggest that option in the docs so prominently?

[Screenshot: grad_norm curve after setting `beta=0.001`]
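
For context, my rough understanding of why `beta = 0.0` can drift over long runs, following the GRPO objective as described in the paper and the TRL docs (so take the exact form with a grain of salt): the KL penalty is scaled by `beta`, so at zero nothing pulls the policy back toward the reference model.

$$
\mathcal{L}_{\text{GRPO}}(\theta) \;=\; -\,\frac{1}{G}\sum_{i=1}^{G}\frac{1}{|o_i|}\sum_{t=1}^{|o_i|}
\Big[\min\!\big(r_{i,t}(\theta)\,\hat{A}_{i,t},\;\mathrm{clip}\big(r_{i,t}(\theta),\,1-\epsilon,\,1+\epsilon\big)\,\hat{A}_{i,t}\big)
\;-\;\beta\,\mathbb{D}_{\mathrm{KL}}\big[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\big]\Big]
$$

where $r_{i,t}(\theta) = \dfrac{\pi_\theta(o_{i,t}\mid q,\,o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t}\mid q,\,o_{i,<t})}$ is the per-token probability ratio. With `beta = 0.0` the last term vanishes, so nothing anchors the policy to the reference model; any small positive `beta` restores that regularization.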

qgallouedec (Member) commented

I suspected that this could produce surprising results on a long run: see #2806 (comment).

Would you recommend adding some sort of warning in the documentation?

nopepper (Contributor, Author) commented on Feb 28, 2025

Sounds good. Perhaps something like this?

> KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training speed, but training may be numerically unstable for long runs.
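
The docs could also include a short illustrative snippet for the workaround, something like the sketch below (the value is just what worked for me, not an official recommendation; the output path is a placeholder):

```python
from trl import GRPOConfig

# Illustrative only: keep a small nonzero KL coefficient so the reference
# model still regularizes the policy on long runs (0.001 worked in my case).
config = GRPOConfig(
    output_dir="grpo-run",  # hypothetical output path
    beta=0.001,             # instead of beta=0.0
    # ... other settings unchanged
)
```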

qgallouedec (Member) commented

Looks good! Are you willing to open a PR?
