Add Context Parallel (CP) support for RMSNorm#1076

Open

yukiu00 wants to merge 3 commits into linkedin:main from yukiu00:feature/context-parallel-rms-norm

Conversation

yukiu00 (Contributor) commented Feb 5, 2026

Summary

Enable LigerRMSNormFunction to efficiently handle Context Parallel (CP) DTensor inputs where the sequence dimension is sharded, in addition to the existing Tensor Parallel (TP) support.

Motivation

RMSNorm normalizes over the hidden dimension independently for each sequence position:

RMS(x) = sqrt(mean(x^2))  # computed over hidden dim only

This means:

  • Tensor Parallel (TP): Hidden dim sharded → requires full_tensor() to gather before computing RMS
  • Context Parallel (CP): Sequence dim sharded → can compute locally since each position is independent (see the sketch below)
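
For intuition, here is a minimal plain-torch sketch (not part of the PR) showing that RMSNorm over the hidden dim is independent per sequence position, so a sequence shard can be normalized locally without any gather:

```python
import torch

def rms_norm(x, weight, eps=1e-6):
    # Normalize over the hidden (last) dim only; each position is independent.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x / rms) * weight

x = torch.randn(2, 8, 16)         # (batch, seq, hidden)
w = torch.ones(16)

full = rms_norm(x, w)             # normalize the whole sequence
shard = rms_norm(x[:, :4, :], w)  # normalize only a sequence shard
assert torch.allclose(full[:, :4, :], shard)  # identical: CP needs no gather
```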

The current implementation treats all DTensor inputs as TP, unnecessarily gathering CP inputs. This addresses the TODO at src/liger_kernel/ops/rms_norm.py:L578:

# TODO: support context parallel in addition to tensor parallel

Changes

src/liger_kernel/ops/rms_norm.py

Add _is_hidden_dim_sharded() helper (sketched below)

  • Checks if DTensor is sharded on the last (hidden) dimension
  • Returns True for TP, False for CP
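
A minimal sketch of how such a helper can be written against DTensor placements (illustrative only, not the exact code in the PR; the import path assumes a PyTorch version that exposes DTensor under torch.distributed.tensor):

```python
from torch.distributed.tensor import DTensor, Shard  # path may differ by PyTorch version


def _is_hidden_dim_sharded(x) -> bool:
    """True if `x` is a DTensor sharded on its last (hidden) dimension."""
    if not isinstance(x, DTensor):
        return False
    hidden_dim = x.ndim - 1
    return any(isinstance(p, Shard) and p.dim == hidden_dim for p in x.placements)
```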

Update forward() (sketched below)

  • TP path: X.full_tensor() (existing behavior)
  • CP path: X.to_local() + store metadata for backward reconstruction
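
An illustrative sketch of the forward dispatch just described. The names `rms_norm_dispatch` and `_rms_norm_ref` are hypothetical; the latter is a plain-torch stand-in for the fused kernel, and `W` is assumed to be a plain (replicated) tensor:

```python
import torch
from torch.distributed.tensor import DTensor, Shard


def _rms_norm_ref(X, W, eps=1e-6):
    # Plain-torch stand-in for the Triton kernel, for illustration only.
    rms = torch.sqrt(X.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (X / rms) * W


def rms_norm_dispatch(X, W, eps=1e-6):
    if isinstance(X, DTensor):
        hidden_sharded = any(
            isinstance(p, Shard) and p.dim == X.ndim - 1 for p in X.placements
        )
        if hidden_sharded:
            # TP: hidden dim is split across ranks, so gather before computing.
            return _rms_norm_ref(X.full_tensor(), W, eps)
        # CP: each rank holds full hidden vectors for its sequence shard,
        # so compute on the local shard and re-wrap with the same placements.
        Y_local = _rms_norm_ref(X.to_local(), W, eps)
        return DTensor.from_local(Y_local, X.device_mesh, X.placements)
    return _rms_norm_ref(X, W, eps)
```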

Update backward() (sketched below)

  • CP path: dY.to_local() for local gradient computation
  • Wrap dX back to DTensor
  • All-reduce dW across all sharded mesh dimensions (supports multi-dimensional meshes like batch + sequence sharding)
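
A hedged sketch of the CP backward handling described above. `finalize_cp_backward`, `dX_local`, and `dW_local` are hypothetical names standing in for whatever the kernel's backward produces:

```python
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Shard


def finalize_cp_backward(dX_local, dW_local, device_mesh, placements):
    # Re-wrap the input gradient so it carries the same sharding as X.
    dX = DTensor.from_local(dX_local, device_mesh, placements)
    # The weight is shared by every sequence position, but each rank only saw
    # its own shard of positions, so dW is summed over every mesh dimension
    # that shards the input (not just the first one).
    for mesh_dim, placement in enumerate(placements):
        if isinstance(placement, Shard):
            dist.all_reduce(dW_local, group=device_mesh.get_group(mesh_dim))
    return dX, dW_local
```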

Add PyTorch compatibility (sketched below)

  • try-except for Shard import (DTensor requires PyTorch >= 2.0)
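
A sketch of such a guarded import (the exact module path used in the repo may differ between PyTorch releases):

```python
try:
    # DTensor/Shard are only available on sufficiently recent PyTorch builds.
    from torch.distributed.tensor import DTensor, Shard

    _HAS_DTENSOR = True
except ImportError:
    DTensor = None
    Shard = None
    _HAS_DTENSOR = False
```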

test/transformers/test_rms_norm.py

  • Add test_dtensor_rms_norm_context_parallel with Shard(1) (sequence dim); see the sketch after this list
  • Verify forward output matches non-sharded computation
  • Verify gradients for both input and weight
  • Include non-power-of-2 "weird shapes" test case (2, 3, 6, 17)
  • Add proper destroy_process_group() cleanup
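
A hedged sketch of what the CP test exercises (not the actual test code: it uses a plain-torch reference instead of LigerRMSNormFunction, and assumes launch via torchrun with two processes, e.g. `torchrun --nproc_per_node=2 cp_rms_norm_check.py`):

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor


def rms_norm_ref(x, w, eps=1e-6):
    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w


def main():
    dist.init_process_group("gloo")
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))

    torch.manual_seed(0)           # same global tensor on every rank
    batch, seq, hidden = 3, 6, 17  # non-power-of-2 "weird shape"
    x = torch.randn(batch, seq, hidden)
    w = torch.randn(hidden)

    # Shard the sequence dimension (dim=1), as a Context Parallel setup would.
    x_dt = distribute_tensor(x, mesh, [Shard(1)])

    # Local computation on the sequence shard (no gather) must match the
    # corresponding slice of the non-sharded reference.
    y_local = rms_norm_ref(x_dt.to_local(), w)
    y_ref = rms_norm_ref(x, w)
    per_rank = seq // dist.get_world_size()
    start = dist.get_rank() * per_rank
    assert torch.allclose(y_local, y_ref[:, start : start + per_rank, :])

    dist.destroy_process_group()  # matches the cleanup added in the test


if __name__ == "__main__":
    main()
```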

Test Plan

  • ruff check passes
  • ruff format --check passes
  • pytest test/transformers/test_rms_norm.py

Notes

  • Backward compatibility: TP behavior unchanged, only CP path added
  • The all-reduce for dW in CP is necessary because each device only sees a subset of sequence positions, but the weight is shared across all positions
  • Multi-dimensional mesh support: dW is all-reduced across all sharded mesh dimensions, not just the first one

This enables RMSNorm to work efficiently with DTensor inputs that are
sharded on the sequence dimension (Context Parallel), in addition to
the existing Tensor Parallel support.

Key changes:
- Add _is_hidden_dim_sharded() helper to detect TP vs CP sharding
- For CP inputs, compute locally without full_tensor() gathering
- All-reduce dW gradient in backward for CP to aggregate across devices
- Add try-except for Shard import for older PyTorch compatibility
- Add tests for Context Parallel DTensor inputs

Add non-power-of-2 test dimensions (batch=3, seq=6, hidden=17) to catch edge cases, and add proper process group cleanup with destroy_process_group().

Remove the `break` statement so that dW is all-reduced across all sharded mesh dimensions, not just the first one. This ensures correct weight gradients when using multi-dimensional meshes (e.g., batch + sequence sharding).