
Add return_predicted_tokens support for cross-entropy kernels #1091

Open

yukiu00 wants to merge 1 commit into linkedin:main from yukiu00:feat/return-predicted-tokens

Conversation

@yukiu00 (Contributor) commented Feb 10, 2026

Summary

Add a return_predicted_tokens flag to LigerCrossEntropyLoss and LigerFusedLinearCrossEntropyLoss that returns per-token argmax predictions (as an int64 tensor) without materializing full logits.

Motivation

During training, it is often useful to access the model's predicted tokens (argmax of logits) for logging, visualization, and metric computation — for example, inspecting what the model actually predicts at each position, or tracking prediction distributions over time.

Currently, obtaining predicted tokens requires either:

  1. Materializing full logits and calling .argmax(dim=-1), which defeats the memory savings of FusedLinearCrossEntropy (see the sketch after this list), or
  2. Recomputing the forward pass separately for metrics.
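
For reference, this is the first workaround in eager PyTorch. A minimal sketch with hypothetical toy shapes, assuming a standard lm_head projection:

import torch

# Hypothetical toy shapes for illustration.
BT, H, V = 8, 16, 32
hidden_states = torch.randn(BT, H)
lm_head_weight = torch.randn(V, H)
target = torch.randint(0, V, (BT,))

# Workaround 1: materialize the full (B*T, V) logits just to take an argmax.
# With a realistic vocabulary size V, this allocation is exactly what
# LigerFusedLinearCrossEntropyLoss exists to avoid.
logits = hidden_states @ lm_head_weight.T          # (B*T, V)
loss = torch.nn.functional.cross_entropy(logits, target)
predicted_tokens = logits.argmax(dim=-1)           # (B*T,) int64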

Since the cross-entropy kernel already tracks argmax internally (for return_token_accuracy, introduced in #910), we can return the predicted token indices as a byproduct at near-zero additional cost.

Design

This builds on the return_token_accuracy infrastructure (#910). The existing argmax_idx tracking in the Triton kernel is reused, so:

  • When return_predicted_tokens=False (default), there is zero overhead — the RETURN_PREDICTED_TOKENS constexpr is compiled out.
  • When both return_token_accuracy and return_predicted_tokens are enabled, the argmax computation is shared (no duplicate work).
  • Ignored tokens (ignore_index) return -1 as a sentinel value (reference semantics sketched below).
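
In eager PyTorch terms, the semantics the kernel implements are equivalent to the following reference (a sketch for exposition, not the Triton code; the helper name is illustrative):

import torch

def reference_predicted_tokens(logits: torch.Tensor,
                               target: torch.Tensor,
                               ignore_index: int = -100) -> torch.Tensor:
    # Hypothetical reference helper: per-token argmax, with -1 as the
    # sentinel at ignored positions.
    predicted = logits.argmax(dim=-1).to(torch.int64)  # (B*T,)
    predicted[target == ignore_index] = -1
    return predicted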

Changes

  • ops/cross_entropy.py, ops/fused_linear_cross_entropy.py: Add RETURN_PREDICTED_TOKENS constexpr to the Triton kernel; store argmax_idx for non-ignored tokens, -1 for ignored tokens.
  • transformers/cross_entropy.py, transformers/fused_linear_cross_entropy.py, transformers/functional.py: Propagate return_predicted_tokens through module and functional APIs. Return CrossEntropyOutput when any extra output is requested (sketched after this list).
  • transformers/model/loss_utils.py: Thread return_predicted_tokens through LigerForCausalLMLoss and fixed_fused_linear_cross_entropy.
  • transformers/model/output_classes.py: Add predicted_tokens field to all Liger*CausalLMOutputWithPast dataclasses.
  • transformers/model/*.py (32 model files): Unpack and forward predicted_tokens in both tuple and dict return paths, following the same pattern as token_accuracy.
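
The structured output might look roughly like this. A sketch only: the field set is inferred from the flags and usage in this PR, and the extras are assumed to default to None when their flag is off:

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class CrossEntropyOutput:
    loss: torch.Tensor                               # scalar loss (always present)
    z_loss: Optional[torch.Tensor] = None            # when return_z_loss=True
    token_accuracy: Optional[torch.Tensor] = None    # when return_token_accuracy=True
    predicted_tokens: Optional[torch.Tensor] = None  # (B*T,) int64, new in this PR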

Usage

# Assuming the public exports from liger_kernel.transformers
from liger_kernel.transformers import LigerCrossEntropyLoss, LigerFusedLinearCrossEntropyLoss

# Standalone
loss_fn = LigerCrossEntropyLoss(return_predicted_tokens=True)
result = loss_fn(logits, target)  # logits: (B*T, V), target: (B*T,)
result.loss              # scalar loss
result.predicted_tokens  # (B*T,) int64 tensor, -1 for ignored tokens

# Fused (no logits materialization)
loss_fn = LigerFusedLinearCrossEntropyLoss(return_predicted_tokens=True)
result = loss_fn(lm_head_weight, hidden_states, target)  # hidden_states: (B*T, H)
result.predicted_tokens  # (B*T,) int64 tensor

# Can combine with token_accuracy
loss_fn = LigerCrossEntropyLoss(
    return_token_accuracy=True,
    return_predicted_tokens=True,
)
result = loss_fn(logits, target)
result.token_accuracy    # scalar
result.predicted_tokens  # (B*T,) int64 tensor
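
When both flags are enabled, the two outputs can be cross-checked against each other (a sketch, assuming the default ignore_index of -100 and the target tensor from the calls above):

valid = target != -100
manual_accuracy = (result.predicted_tokens[valid] == target[valid]).float().mean()
# manual_accuracy should agree with result.token_accuracy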

Note: predicted_tokens is returned as a flat (B*T,) tensor, matching the input shape convention of the cross-entropy API (which expects (B*T, V) logits and (B*T,) targets, consistent with torch.nn.CrossEntropyLoss). Reshape as needed:

result.predicted_tokens.view(B, T)

Testing Done

  • Hardware Type: NVIDIA GPU
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

New/updated tests:

  • test_correctness_with_predicted_tokens (cross-entropy): Verifies predicted tokens match reference argmax, ignored tokens are -1, and backward works. Tests multiple dtypes, shapes, and ignore indices (condensed sketch after this list).
  • test_correctness_with_predicted_tokens (fused linear cross-entropy): Same coverage with logit-value comparison (handles chunked bfloat16 matmul tie-breaking).
  • test_liger_cross_entropy_structured_output: Extended to parametrize return_predicted_tokens across all 8 combinations of (return_z_loss, return_token_accuracy, return_predicted_tokens). Includes consistency check between predicted_tokens and token_accuracy when both are enabled.
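
A condensed version of the standalone correctness check, written pytest-style. This is an illustrative sketch, not the test code in the PR:

import torch
from liger_kernel.transformers import LigerCrossEntropyLoss

def test_predicted_tokens_match_reference():
    BT, V, ignore_index = 64, 128, -100
    device = "cuda"  # Liger's Triton kernels require a GPU
    logits = torch.randn(BT, V, device=device, requires_grad=True)
    target = torch.randint(0, V, (BT,), device=device)
    target[::7] = ignore_index  # sprinkle in ignored positions

    loss_fn = LigerCrossEntropyLoss(ignore_index=ignore_index, return_predicted_tokens=True)
    result = loss_fn(logits, target)

    expected = logits.argmax(dim=-1)
    expected[target == ignore_index] = -1
    assert torch.equal(result.predicted_tokens, expected)

    result.loss.backward()  # backward still works with the extra output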

Add return_predicted_tokens support for cross-entropy kernels

Enable returning per-token argmax predictions (as int64 tensor) without materializing
full logits, reusing the existing argmax tracking shared with token_accuracy.
Propagate the new flag through ops, transformers wrappers, functional API, loss_utils,
all model forward functions, and output classes. Add tests for correctness, combined
flags (return_predicted_tokens + return_token_accuracy), and backward pass.
