
[Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss#551

Closed
hongpeng-guo wants to merge 34 commits into main from hpguo/lce_add_entropy_loss

Conversation


@hongpeng-guo hongpeng-guo commented Jan 30, 2025

Summary

In RLHF workflows such as verl, the actor forward function usually produces both a cross_entropy_loss (-log_probs) and an entropy_loss; the latter is used to encourage the policy not to become over-deterministic.

There is a real need for a kernel that produces both losses without materializing the huge logits tensor. Liger-Kernel's fused_linear_cross_entropy_loss already generates the cross_entropy_loss well, but it is missing the second part of the loss, i.e., the entropy loss.

This PR adds an entropy loss option to the existing FLCE loss and is an important step toward supporting verl.

  1. Add the entropy calculation to the second pass of the online softmax in cross_entropy.py::liger_cross_entropy_kernel; both the loss and its gradient with respect to the input are calculated and stored.
  2. Propagate the changes to the relevant modules in fused_linear_cross_entropy.py.
  3. Propagate the relevant changes to the other functional modules in the PyTorch interface.
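As a reviewer aid, the two quantities the fused kernel is expected to produce can be sketched with a naive, logits-materializing reference (illustrative names only, not the PR's actual API):

```python
import math

def ce_and_entropy(logits, target):
    """Naive reference: cross-entropy (-log p[target]) and entropy of softmax(logits).

    The fused kernel computes the same quantities without materializing
    the full softmax over a large vocabulary.
    """
    m = max(logits)                                  # max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    d = sum(exps)                                    # softmax denominator
    lse = m + math.log(d)                            # log-sum-exp
    ce = lse - logits[target]                        # -log softmax(logits)[target]
    probs = [e / d for e in exps]
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    return ce, entropy
```

For a uniform distribution over V classes, both losses reduce to log V, which makes a convenient sanity check.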

Testing Done

Existing unit tests pass; new unit tests are WIP.

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
@hongpeng-guo hongpeng-guo marked this pull request as draft January 30, 2025 04:38
@hongpeng-guo hongpeng-guo changed the title [Feature] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss [WIP][Feature][Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss Jan 30, 2025
@hongpeng-guo hongpeng-guo changed the title [WIP][Feature][Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss [WIP][Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss Jan 30, 2025
@hongpeng-guo hongpeng-guo requested a review from ByronHsu January 30, 2025 09:56

Tcc0403 commented Jan 30, 2025

Please add a unit test with return_entropy_loss. You can either write a new PyTorch reference implementation like CrossEntropyWithZLoss, or add the return_entropy_loss functionality on top of it.

@hongpeng-guo hongpeng-guo changed the base branch from main to hpguo/ruff_style_check February 3, 2025 04:41
@hongpeng-guo hongpeng-guo marked this pull request as ready for review February 3, 2025 05:17
@hongpeng-guo hongpeng-guo changed the title [WIP][Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss [Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss Feb 3, 2025
@hongpeng-guo
Contributor Author

Update: hit a numerical instability issue; investigating.


@qingquansong qingquansong left a comment


Thanks for the efforts! Let's test thoroughly for both accuracy (numerical stability) and speed (including the existing paths) before checking in. Considering that we may fuse more and more losses, such as the existing z-loss and the added entropy loss, the API outputs are starting to diverge, and the loss kernel is becoming quite heavy, with multiple branches coupled together (label smoothing, target weights, etc.). We probably need to refactor a bit to keep it clean to develop later. cc @ByronHsu @shivam15s @Tcc0403 @shimizust

# lse = max(X) + log(sum(e^(X_i - max(X)))) = m + log d
lse = m + tl.log(d)

# 3.5 Calculate the entropy loss

Can you put an equation in the PR description, and also a simple one in a comment here, to demonstrate how the entropy_loss is calculated (especially the reuse of m and d computed in the first-pass online softmax)?
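For example (my reading of the math, not the kernel's literal code): with m = max(X), d = sum(e^(X_i - m)), and lse = m + log(d) already available from the online-softmax pass, the entropy is H = lse - (1/d) * sum(e^(X_i - m) * X_i), so only one extra accumulator is needed:

```python
import math

X = [1.0, -2.0, 0.5, 3.0]

# Quantities already available from the online-softmax pass.
m = max(X)
d = sum(math.exp(x - m) for x in X)
lse = m + math.log(d)

# One extra accumulator: sum of e^(x - m) * x, reusing m and d.
s = sum(math.exp(x - m) * x for x in X)
entropy_from_lse = lse - s / d

# Direct definition for comparison: H = -sum p * log p.
probs = [math.exp(x - m) / d for x in X]
entropy_direct = -sum(p * math.log(p) for p in probs)

assert abs(entropy_from_lse - entropy_direct) < 1e-12
```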

# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
z_loss = z_loss / n_non_ignore
# TODO: Implement weighted entropy loss. Currently, entropy loss is not scaled by weight.
entropy_loss = entropy_loss / n_non_ignore
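A small sketch of what the `loss / n_non_ignore` scaling above corresponds to at the PyTorch level (hypothetical helper name; -100 follows the common PyTorch ignore_index convention):

```python
IGNORE_INDEX = -100  # illustrative; the common PyTorch ignore_index default

def mean_over_non_ignored(per_token_losses, targets, ignore_index=IGNORE_INDEX):
    """Mean-reduce per-token losses, skipping ignored targets,
    mirroring the `loss / n_non_ignore` scaling in the kernel."""
    kept = [l for l, t in zip(per_token_losses, targets) if t != ignore_index]
    n_non_ignore = len(kept)
    return sum(kept) / n_non_ignore if n_non_ignore else 0.0
```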

It seems you had already implemented the weight-provided case above: dX_entropy_block = dX_entropy_block / sum_non_ignore_weight. Did I misunderstand anything? If that is not the right equation for the weighted case, please use dX_entropy_block = dX_entropy_block / n_non_ignore above and add a TODO comment there.
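For reference on the entropy gradient being normalized here (my own derivation, not the kernel's code): with p = softmax(X) and H = -sum(p_i * log(p_i)), the per-logit gradient is dH/dX_i = -p_i * (log(p_i) + H), which a finite-difference check confirms:

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    d = sum(exps)
    return [e / d for e in exps]

def entropy(xs):
    return -sum(p * math.log(p) for p in softmax(xs))

X = [0.3, -1.2, 2.0, 0.7]
p = softmax(X)
H = entropy(X)

# Analytic gradient: dH/dX_i = -p_i * (log p_i + H)
grad = [-pi * (math.log(pi) + H) for pi in p]

# Central finite-difference check per coordinate.
eps = 1e-6
for i in range(len(X)):
    Xp = list(X); Xp[i] += eps
    Xm = list(X); Xm[i] -= eps
    fd = (entropy(Xp) - entropy(Xm)) / (2 * eps)
    assert abs(fd - grad[i]) < 1e-6
```

One useful property: the gradient components sum to zero, since -sum(p_i log p_i) - H sum(p_i) = H - H = 0.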

Contributor Author

Thanks for catching this. It was a bug in my code, and I have just fixed it. However, the numerical problem is still there; maybe we need to take a deeper look.

Contributor Author

BTW, it seems the CI has stopped running for this PR for some reason.

hongpeng-guo and others added 3 commits February 3, 2025 09:48
Co-authored-by: Qingquan Song <ustcsqq@gmail.com>
@hongpeng-guo hongpeng-guo changed the base branch from hpguo/ruff_style_check to main February 3, 2025 09:54
@hongpeng-guo hongpeng-guo changed the base branch from main to hpguo/ruff_style_check February 3, 2025 09:55
@hongpeng-guo hongpeng-guo changed the base branch from hpguo/ruff_style_check to main February 5, 2025 22:15
@hongpeng-guo

Closing this PR for now.
