Add Context Parallel (CP) support for RMSNorm#1076
Open
yukiu00 wants to merge 3 commits into linkedin:main
Conversation
This enables RMSNorm to work efficiently with DTensor inputs that are sharded on the sequence dimension (Context Parallel), in addition to the existing Tensor Parallel support. Key changes:
- Add _is_hidden_dim_sharded() helper to detect TP vs CP sharding (see the sketch below)
- For CP inputs, compute locally without full_tensor() gathering
- All-reduce the dW gradient in backward for CP to aggregate it across devices
- Add try-except for the Shard import for older PyTorch compatibility
- Add tests for Context Parallel DTensor inputs
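For orientation, such a helper could check whether any `Shard` placement targets the last (hidden) dimension. This is only a sketch of the idea under that assumption, not the code in this PR:

```python
# Sketch only: one way to tell TP from CP sharding; the PR's actual helper may
# differ in signature and edge-case handling.
try:
    from torch.distributed.tensor import DTensor, Shard  # newer PyTorch
except ImportError:
    from torch.distributed._tensor import DTensor, Shard  # older PyTorch >= 2.0


def _is_hidden_dim_sharded(x: DTensor) -> bool:
    """True if the last (hidden) dim is sharded (TP); False for sequence-only sharding (CP)."""
    hidden_dim = x.ndim - 1
    return any(isinstance(p, Shard) and p.dim == hidden_dim for p in x.placements)
```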
Add non-power-of-2 test dimensions (batch=3, seq=6, hidden=17) to catch edge cases, and add proper process group cleanup with destroy_process_group().
Remove the `break` statement so that dW is all-reduced across all sharded mesh dimensions, not just the first one. This ensures correct weight gradients when using multi-dimensional meshes (e.g., batch + sequence sharding).
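The pattern this commit describes could look roughly like the following; the function name and arguments here are illustrative, not the PR's actual code:

```python
# Illustrative only: sum the local weight gradient over every sharded mesh
# dimension (no early break), so multi-dimensional meshes are handled too.
import torch.distributed as dist

try:
    from torch.distributed.tensor import Shard
except ImportError:
    from torch.distributed._tensor import Shard


def _all_reduce_dW(dW, device_mesh, placements):
    for mesh_dim, placement in enumerate(placements):
        if isinstance(placement, Shard):
            # Each sharded mesh dimension (e.g. batch AND sequence) holds only a
            # partial sum of dW, so reduce over each corresponding group.
            dist.all_reduce(dW, op=dist.ReduceOp.SUM, group=device_mesh.get_group(mesh_dim))
    return dW
```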
Summary
Enable `LigerRMSNormFunction` to efficiently handle Context Parallel (CP) DTensor inputs where the sequence dimension is sharded, in addition to the existing Tensor Parallel (TP) support.

Motivation
RMSNorm normalizes over the hidden dimension independently for each sequence position: `y = x / sqrt(mean(x^2, dim=-1) + eps) * w`.
This means:
- TP (hidden dimension sharded): each device holds only part of every hidden vector, so the input needs `full_tensor()` to gather before computing the RMS
- CP (sequence dimension sharded): each device holds complete hidden vectors for its sequence positions, so the RMS can be computed locally without gathering

The current implementation treats all DTensor inputs as TP, unnecessarily gathering CP inputs. This addresses the TODO at `src/liger_kernel/ops/rms_norm.py:L578`:

`# TODO: support context parallel in addition to tensor parallel`
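A quick single-device check of this independence (plain PyTorch only; the `rms_norm` here is a reference re-implementation, not Liger's kernel):

```python
# Per-position independence of RMSNorm: normalizing a sequence chunk gives the
# same values as slicing the result of normalizing the full sequence.
import torch


def rms_norm(x, w, eps=1e-6):
    # The reduction runs only over the last (hidden) dimension.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * w


torch.manual_seed(0)
x = torch.randn(3, 6, 17)        # (batch, seq, hidden)
w = torch.randn(17)

full = rms_norm(x, w)
shard = rms_norm(x[:, :3], w)    # a "CP shard": the first half of the sequence

assert torch.allclose(full[:, :3], shard)
```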
Changes

`src/liger_kernel/ops/rms_norm.py`
- Add `_is_hidden_dim_sharded()` helper: returns `True` for TP, `False` for CP
- Update `forward()`: TP inputs keep `X.full_tensor()` (existing behavior); CP inputs use `X.to_local()` + store metadata for backward reconstruction (a combined forward/backward sketch follows after this section)
- Update `backward()`: use `dY.to_local()` for local gradient computation, convert `dX` back to DTensor, and all-reduce `dW` across all sharded mesh dimensions (supports multi-dimensional meshes like batch + sequence sharding)
- Add PyTorch compatibility: `try-except` for the `Shard` import (DTensor requires PyTorch >= 2.0)

`test/transformers/test_rms_norm.py`
- Add `test_dtensor_rms_norm_context_parallel` with `Shard(1)` (sequence dim)
- Use non-power-of-2 test dimensions `(2, 3, 6, 17)`
- Add `destroy_process_group()` cleanup
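As a rough, hedged sketch of how these pieces fit together (plain-PyTorch stand-ins for the Triton kernels; names and structure are illustrative, not the merged code):

```python
# Sketch only: plain-PyTorch stand-ins for the Triton kernels, wrapped in an
# autograd.Function that branches on TP vs CP DTensor inputs.
import torch
import torch.distributed as dist

try:
    from torch.distributed.tensor import DTensor, Shard
except ImportError:
    from torch.distributed._tensor import DTensor, Shard


def _rms_norm_fwd(X, W, eps):
    rstd = torch.rsqrt(X.pow(2).mean(dim=-1, keepdim=True) + eps)
    return X * rstd * W, rstd


def _rms_norm_bwd(dY, X, W, rstd):
    x_hat = X * rstd
    wdy = dY * W
    dX = rstd * (wdy - x_hat * (wdy * x_hat).mean(dim=-1, keepdim=True))
    dW = (dY * x_hat).sum(dim=tuple(range(dY.dim() - 1)))
    return dX, dW


class RMSNormCPSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, W, eps):
        ctx.x_spec = None
        if isinstance(X, DTensor):
            hidden = X.ndim - 1
            if any(isinstance(p, Shard) and p.dim == hidden for p in X.placements):
                X = X.full_tensor()                       # TP: gather the hidden dim
            else:
                ctx.x_spec = (X.device_mesh, X.placements)
                X = X.to_local()                          # CP: stay on the local shard
        Y, rstd = _rms_norm_fwd(X, W, eps)
        ctx.save_for_backward(X, W, rstd)
        return Y  # wrapping Y back into a DTensor is omitted to keep the sketch short

    @staticmethod
    def backward(ctx, dY):
        if isinstance(dY, DTensor):
            dY = dY.to_local()
        X, W, rstd = ctx.saved_tensors
        dX, dW = _rms_norm_bwd(dY, X, W, rstd)
        if ctx.x_spec is not None:
            mesh, placements = ctx.x_spec
            dX = DTensor.from_local(dX, mesh, placements)   # hand back a CP-sharded dX
            for mesh_dim, p in enumerate(placements):       # no break: every sharded dim
                if isinstance(p, Shard):
                    dist.all_reduce(dW, group=mesh.get_group(mesh_dim))
        return dX, dW, None
```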
Test Plan
- `ruff check` passes
- `ruff format --check` passes
- `pytest test/transformers/test_rms_norm.py`
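For reference, such a test could be scaffolded roughly as below; the `LigerRMSNorm` import path, constructor arguments, backend, shapes, and the backward call are assumptions about the setup, not a copy of the PR's test:

```python
# Sketch only: single-rank scaffolding for a CP DTensor test. Assumes one CUDA
# device; LigerRMSNorm's import path and constructor args may differ.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

try:
    from torch.distributed.tensor import Shard, distribute_tensor
except ImportError:
    from torch.distributed._tensor import Shard, distribute_tensor

from liger_kernel.transformers.rms_norm import LigerRMSNorm  # assumed import path


def test_dtensor_rms_norm_context_parallel_sketch():
    dist.init_process_group("nccl", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)
    try:
        mesh = init_device_mesh("cuda", (1,))
        x = torch.randn(3, 6, 17, device="cuda")                        # batch=3, seq=6, hidden=17
        x_dt = distribute_tensor(x, mesh, [Shard(1)]).requires_grad_()  # shard the sequence dim (CP)

        norm = LigerRMSNorm(17, eps=1e-6).cuda()
        out = norm(x_dt)

        out.sum().backward()                      # exercise the CP backward path
        assert norm.weight.grad is not None
    finally:
        dist.destroy_process_group()              # cleanup, as the PR adds
```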
Notes
- The all-reduce of `dW` in CP is necessary because each device only sees a subset of sequence positions, but the weight is shared across all positions
- `dW` is all-reduced across all sharded mesh dimensions, not just the first one
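Spelling out the first note with the standard RMSNorm gradient: writing x̂_t = x_t · rstd_t for each position t, the weight gradient is dW = Σ_t dY_t ⊙ x̂_t, a sum over every sequence position. Under CP each rank only accumulates the terms for its local positions, so a SUM all-reduce over the sequence-sharded mesh dimension reconstructs the full dW.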