Add return_predicted_tokens support for cross-entropy kernels#1091
Open
yukiu00 wants to merge 1 commit into linkedin:main from
Conversation
Add return_predicted_tokens support for cross-entropy kernels

Enable returning per-token argmax predictions (as an int64 tensor) without materializing full logits, reusing the existing argmax tracking shared with token_accuracy. Propagate the new flag through ops, transformers wrappers, the functional API, loss_utils, all model forward functions, and output classes. Add tests for correctness, combined flags (return_predicted_tokens + return_token_accuracy), and the backward pass.
Summary
Add a `return_predicted_tokens` flag to `LigerCrossEntropyLoss` and `LigerFusedLinearCrossEntropyLoss` that returns per-token argmax predictions (as an `int64` tensor) without materializing full logits.

Motivation
During training, it is often useful to access the model's predicted tokens (argmax of logits) for logging, visualization, and metric computation — for example, inspecting what the model actually predicts at each position, or tracking prediction distributions over time.
Currently, obtaining predicted tokens requires either:
- materializing the full logits and calling `.argmax(dim=-1)`, which defeats the memory savings of `FusedLinearCrossEntropy`, or
- …

Since the cross-entropy kernel already tracks the `argmax` internally (for `return_token_accuracy`, introduced in #910), we can return the predicted token indices as a byproduct at near-zero additional cost.

Design
This builds on the `return_token_accuracy` infrastructure (#910). The existing `argmax_idx` tracking in the Triton kernel is reused, so:

- When `return_predicted_tokens=False` (default), there is zero overhead: the `RETURN_PREDICTED_TOKENS` constexpr is compiled out.
- When both `return_token_accuracy` and `return_predicted_tokens` are enabled, the argmax computation is shared (no duplicate work).
- Ignored tokens (matching `ignore_index`) return `-1` as a sentinel value (see the sketch below).
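For reference, the output semantics described above can be expressed in eager PyTorch. This is only an illustrative sketch of what the fused kernel is expected to produce, not the kernel code itself, and the helper name is made up:

```python
import torch

def reference_predicted_tokens(logits: torch.Tensor, target: torch.Tensor,
                               ignore_index: int = -100) -> torch.Tensor:
    """Eager reference for the predicted-tokens output: per-token argmax as
    int64, with -1 written at positions whose target equals ignore_index."""
    preds = logits.argmax(dim=-1).to(torch.int64)
    return torch.where(target == ignore_index, torch.full_like(preds, -1), preds)
```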
Changes

- `ops/cross_entropy.py`, `ops/fused_linear_cross_entropy.py`: Add a `RETURN_PREDICTED_TOKENS` constexpr to the Triton kernel; store `argmax_idx` for non-ignored tokens, `-1` for ignored tokens.
- `transformers/cross_entropy.py`, `transformers/fused_linear_cross_entropy.py`, `transformers/functional.py`: Propagate `return_predicted_tokens` through the module and functional APIs. Return `CrossEntropyOutput` when any extra output is requested.
- `transformers/model/loss_utils.py`: Thread `return_predicted_tokens` through `LigerForCausalLMLoss` → `fixed_fused_linear_cross_entropy`.
- `transformers/model/output_classes.py`: Add a `predicted_tokens` field to all `Liger*CausalLMOutputWithPast` dataclasses.
- `transformers/model/*.py` (32 model files): Unpack and forward `predicted_tokens` in both tuple and dict return paths, following the same pattern as `token_accuracy`.

Usage
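A minimal usage sketch, assuming the structured output described in the Changes list exposes `loss` and `predicted_tokens` attributes (the exact field names are an assumption, not confirmed API):

```python
import torch
from liger_kernel.transformers import LigerCrossEntropyLoss

# Assumes this PR's return_predicted_tokens flag and a structured output with
# .loss and .predicted_tokens; attribute names here are illustrative.
loss_fn = LigerCrossEntropyLoss(return_predicted_tokens=True)

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
target = torch.randint(0, 32000, (8,), device="cuda")
target[0] = -100  # ignored position: its reported prediction should be -1

out = loss_fn(logits, target)
out.loss.backward()
print(out.predicted_tokens)  # int64 tensor of shape (8,), -1 where ignored
```

In the fused-linear path the same flag should return predictions without the `(BT, V)` logits ever being materialized, which is the point of the feature.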
Testing Done
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence

New/updated tests:
- `test_correctness_with_predicted_tokens` (cross-entropy): Verifies predicted tokens match the reference argmax, ignored tokens are `-1`, and backward works. Tests multiple dtypes, shapes, and ignore indices.
- `test_correctness_with_predicted_tokens` (fused linear cross-entropy): Same coverage with logit-value comparison (handles chunked bfloat16 matmul tie-breaking).
- `test_liger_cross_entropy_structured_output`: Extended to parametrize `return_predicted_tokens` across all 8 combinations of `(return_z_loss, return_token_accuracy, return_predicted_tokens)`. Includes a consistency check between `predicted_tokens` and `token_accuracy` when both are enabled (see the sketch below).
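The consistency check can be thought of roughly as the relation below; the helper is hypothetical and assumes `-1` sentinels for ignored positions and accuracy averaged over non-ignored tokens:

```python
import torch

def check_predictions_vs_accuracy(predicted_tokens: torch.Tensor,
                                  target: torch.Tensor,
                                  token_accuracy: torch.Tensor,
                                  ignore_index: int = -100) -> None:
    mask = target != ignore_index
    # Ignored positions must carry the -1 sentinel.
    assert (predicted_tokens[~mask] == -1).all()
    # Accuracy recomputed from the returned predictions should match the
    # kernel-reported token_accuracy.
    expected = (predicted_tokens[mask] == target[mask]).float().mean()
    torch.testing.assert_close(token_accuracy, expected)
```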