[Verl] Add entropy loss to cross_entropy_loss and fused_linear_cross_entropy_loss #551

hongpeng-guo wants to merge 34 commits into main

Conversation
Please add a unit test with …
Update: hit a numerical instability issue; investigating.
Thanks for the efforts! Let's test thoroughly on both accuracy (numerical stability) and speed (including the old code paths) before checking in. Considering that we may fuse more and more losses, such as the existing z-loss and the newly added entropy loss, the API outputs have started to diverge, and the loss kernel is getting quite heavy, with multiple branches coupled together (label smoothing, target weights, etc.). We probably need to refactor a bit to make it cleaner to develop against later. cc @ByronHsu @shivam15s @Tcc0403 @shimizust
```python
# = max(X) + log(sum(e^(X_i - max(X)))) = m + log d
lse = m + tl.log(d)

# 3.5 Calculate the entropy loss
```
Can probably put an equation in the PR description, and also a simple one in a comment here, to demonstrate how the entropy_loss is calculated (especially the reuse of m and d computed in the first-pass online softmax).
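For reference, here is one way the equation could be stated, as a minimal NumPy sketch (illustrative only, not the kernel code): with m = max(X) and d = Σ exp(X_i − m) from the online-softmax pass, the probabilities are p_i = exp(X_i − lse), so H(p) = −Σ p_i log p_i = lse − Σ p_i X_i, i.e., the entropy falls out of lse = m + log d plus one extra dot product.

```python
import numpy as np

def entropy_from_online_softmax(x: np.ndarray) -> float:
    """Illustrative sketch: entropy of softmax(x), reusing the online-softmax
    statistics m and d that the kernel's first pass already computes."""
    m = np.max(x)              # running max from the first pass
    d = np.sum(np.exp(x - m))  # running denominator from the first pass
    lse = m + np.log(d)        # log-sum-exp, exactly `lse = m + tl.log(d)` above
    p = np.exp(x - lse)        # softmax probabilities
    # H(p) = -sum(p_i * log p_i) = lse - sum(p_i * x_i)
    return float(lse - np.sum(p * x))
```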
```python
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
z_loss = z_loss / n_non_ignore
# TODO: Implement weighted entropy loss. Currently, entropy loss is not scaled by weight.
entropy_loss = entropy_loss / n_non_ignore
```
It seems you had already implemented the weight-provided case above: `dX_entropy_block = dX_entropy_block / sum_non_ignore_weight`. Did I misunderstand anything? If that is not the right equation for the weighted case, please use `dX_entropy_block = dX_entropy_block / n_non_ignore` there instead, and also leave a TODO comment above.
Thanks for catching this. It was a bug in my program, and I just fixed it. But the numerical problem is still there; maybe we need to take a deeper look.
BTW, it seems CI has stopped running for this PR for some reason.
Closing this PR for now.
Summary

In RLHF workflows such as verl, the actor forward pass usually produces two losses: a cross_entropy_loss (-log_probs) and an entropy_loss, the latter used to encourage the policy not to become over-deterministic. There is a real need for a kernel that generates both losses without materializing the huge logits tensor. Liger-Kernel's fused_linear_cross_entropy_loss already works well for generating the cross_entropy_loss, but it is missing the second part, i.e., the entropy loss.

This PR adds the entropy loss option to the existing FLCE loss and is one important step toward supporting verl:

- In cross_entropy.py::liger_cross_entropy_kernel, both the entropy loss and its gradient with respect to the input are calculated and stored;
- In fused_linear_cross_entropy.py, the corresponding changes expose the new entropy loss through the FLCE interface.
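As a point of reference, a plain PyTorch sketch of the two quantities the fused kernel is meant to produce (a hypothetical helper, not this PR's code; it materializes the full logits tensor, which is exactly what the FLCE path avoids):

```python
import torch
import torch.nn.functional as F

def ce_and_entropy_reference(hidden, lm_head_weight, targets):
    """Hypothetical unfused reference for the two losses."""
    logits = hidden @ lm_head_weight.T                  # (N, vocab): the huge tensor FLCE never materializes
    log_probs = F.log_softmax(logits, dim=-1)
    ce_loss = F.nll_loss(log_probs, targets)            # cross_entropy_loss = -log_probs[target]
    probs = log_probs.exp()
    entropy_loss = -(probs * log_probs).sum(-1).mean()  # discourages an over-deterministic policy
    return ce_loss, entropy_loss
```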
Testing Done

Made the existing unit tests pass; a new unit test for the entropy loss is WIP.

- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence