
GRPO unbalanced-memory #2805

@mdy666

Description


Reproduction

The memory usage on each rank (0-6) is not the same, and I found that memory usage grows considerably as the number of training steps increases.
Step 0, using the original code:

Image
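To make the per-rank comparison reproducible from logs instead of screenshots, one option is to read PyTorch's allocator counters on every rank each step and compare them. This is a minimal sketch, not part of the original report or of TRL. `collect_rank_memory`, `gather_all_ranks` and `summarize` are hypothetical helper names, and the sketch assumes PyTorch with CUDA and an already-initialized `torch.distributed` process group (as set up by the launcher). Logging `allocated` (live tensors) next to `reserved` (memory the caching allocator keeps after tensors are freed) may help tell whether the growth comes from tensors that stay alive or from cached blocks.

```python
from typing import Dict, List

MIB = 1024 ** 2


def collect_rank_memory() -> Dict[str, float]:
    """Read this rank's CUDA allocator counters in MiB; call once per training step."""
    import torch  # lazy import so summarize() below also works without torch

    stats = {
        # bytes currently occupied by live tensors
        "allocated": torch.cuda.memory_allocated() / MIB,
        # bytes held by the caching allocator, including freed-but-cached blocks
        "reserved": torch.cuda.memory_reserved() / MIB,
        # highest "allocated" value since the last reset
        "peak_allocated": torch.cuda.max_memory_allocated() / MIB,
    }
    torch.cuda.reset_peak_memory_stats()  # so the next call reports a per-step peak
    return stats


def gather_all_ranks(stats: Dict[str, float]) -> List[Dict[str, float]]:
    """Collect every rank's stats on every rank (a collective: all ranks must call it).

    With the NCCL backend, each process's current CUDA device must already be set.
    """
    import torch.distributed as dist

    if not (dist.is_available() and dist.is_initialized()):
        return [stats]  # single-process fallback
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, stats)
    return gathered


def summarize(per_rank: List[Dict[str, float]], key: str = "reserved") -> Dict[str, float]:
    """Min, max and max-min spread of one counter across ranks."""
    values = [rank_stats[key] for rank_stats in per_rank]
    return {"min": min(values), "max": max(values), "spread": max(values) - min(values)}
```

In a training loop, every rank would call `summarize(gather_all_ranks(collect_rank_memory()))` after each optimizer step and rank 0 would print the result, giving one line per step with the cross-rank spread.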

Then I wrote an efficient GRPO loss kernel in Triton.
Step 0, using the Triton kernel:

Image

Step 5, using the Triton kernel:

Image

Step 20, using the Triton kernel:

Image
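The Triton kernel itself is not included in the report. As a framework-free reference for what such a kernel would compute, the sketch below evaluates the per-token GRPO loss and its gradient on plain Python floats. It assumes the loss form TRL's `GRPOTrainer` used around v0.14 (group-normalized advantages, the `exp(ref - logp) - (ref - logp) - 1` KL estimator weighted by `beta`, and a masked mean over completion tokens); check this against `compute_loss` in the installed version. All function names are hypothetical. `token_logprob` shows the memory-relevant idea: the log-probability of the sampled token needs only a logsumexp over the vocabulary, not a stored full log-softmax tensor.

```python
import math
from typing import List, Sequence, Tuple


def token_logprob(logits: Sequence[float], token_id: int) -> float:
    """log_softmax(logits)[token_id] via a numerically stable logsumexp.

    Only a max and a running sum are needed, so the full log-softmax over the
    vocabulary never has to be materialized.
    """
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[token_id] - lse


def group_advantages(rewards: Sequence[float], eps: float = 1e-4) -> List[float]:
    """(r - mean) / (std + eps) within one group of completions for the same prompt."""
    n = len(rewards)  # needs n >= 2 for the sample (unbiased) standard deviation
    mean = sum(rewards) / n
    std = math.sqrt(sum((r - mean) ** 2 for r in rewards) / (n - 1))
    return [(r - mean) / (std + eps) for r in rewards]


def token_loss_and_grad(logp: float, ref_logp: float, advantage: float,
                        beta: float) -> Tuple[float, float]:
    """Forward value and d(loss)/d(logp) for one completion token."""
    kl = math.exp(ref_logp - logp) - (ref_logp - logp) - 1.0
    # The policy term is exp(logp - stopgrad(logp)) * advantage: it equals
    # `advantage` in the forward pass and its gradient w.r.t. logp is `advantage`.
    loss = -(advantage - beta * kl)
    grad = -advantage + beta * (1.0 - math.exp(ref_logp - logp))
    return loss, grad


def sequence_loss(logps: Sequence[float], ref_logps: Sequence[float],
                  advantage: float, beta: float, mask: Sequence[int]) -> float:
    """Masked mean of the per-token loss over one completion.

    The batch loss would then be the mean of this value over all completions.
    """
    total = sum(token_loss_and_grad(lp, rlp, advantage, beta)[0] * m
                for lp, rlp, m in zip(logps, ref_logps, mask))
    return total / sum(mask)
```

A fused GPU kernel would presumably run this same per-token arithmetic in parallel and write only the loss and gradient, avoiding the large intermediate tensors; the report does not say exactly which intermediates the author's kernel avoids.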

System Info

trl = 0.14.0
torch = 2.5.1+cuda12.4
vllm = 0.7.1

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete

Metadata


Assignees

No one assigned

Labels: 🏋 GRPO (Related to GRPO), 🐛 bug (Something isn't working)
