
Feature optimize #1077

Open

Umang-projects wants to merge 2 commits into linkedin:main from Umang-projects:feature-optimize

@Umang-projects

🚀 Optimization Description

This PR introduces an inference_mode argument to the rms_norm_forward wrapper and kernels. When it is enabled (inference_mode=True), the kernels skip storing the Reciprocal Standard Deviation (RSTD) to global memory.
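A minimal usage sketch (the import path and the positional eps argument are assumptions for illustration; the PR only specifies the new flag):

```python
import torch
from liger_kernel.ops.rms_norm import rms_norm_forward  # import path assumed

x = torch.randn(8, 512, dtype=torch.bfloat16, device="cuda")
w = torch.ones(512, dtype=torch.bfloat16, device="cuda")

out_train = rms_norm_forward(x, w, 1e-6)                       # default: RSTD is stored
out_infer = rms_norm_forward(x, w, 1e-6, inference_mode=True)  # RSTD store is skipped
```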

🧠 Motivation

RMSNorm is significantly memory-bound.

  • Training: RSTD is required for the backward pass, so it must be stored.
  • Inference: RSTD is not needed after normalization. Writing it to HBM consumes write bandwidth unnecessarily.

By skipping this write, we reduce memory traffic. This yields significant throughput gains at small and medium hidden dimensions (critical for speculative decoding and MoE architectures) while still saving write bandwidth at larger dimensions.

🛠️ Implementation Details

  • Added STORE_RSTD constexpr to both standard and blocked Triton kernels.
  • Updated rms_norm_forward wrapper to accept inference_mode.
  • If inference_mode=True, the RSTD allocation is skipped (a zero-element tensor is passed) and the store is compiled out by Triton's JIT constexpr specialization (see the sketch after this list).
  • Backward compatibility: the default behavior of LigerRMSNormFunction (autograd) is preserved, i.e. training mode.
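A minimal sketch of the constexpr gating, in the spirit of this change (kernel and wrapper names, argument layout, and return values are simplified assumptions, not the PR's exact code):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _rms_norm_fwd(X, Y, W, RSTD, stride, n_cols, eps,
                  BLOCK_SIZE: tl.constexpr, STORE_RSTD: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(X + row * stride + cols, mask=mask, other=0.0).to(tl.float32)
    # rstd = 1 / sqrt(mean(x^2) + eps)
    rstd = 1.0 / tl.sqrt(tl.sum(x * x, axis=0) / n_cols + eps)
    if STORE_RSTD:
        # Dead-branch-eliminated by Triton's JIT when STORE_RSTD=False,
        # so the inference path never issues this global-memory write.
        tl.store(RSTD + row, rstd)
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
    y = x * rstd * w
    tl.store(Y + row * stride + cols, y.to(Y.dtype.element_ty), mask=mask)

def rms_norm_forward(x, weight, eps=1e-6, inference_mode=False):
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    # Inference: pass a zero-element placeholder so no (n_rows,) buffer is allocated.
    rstd = torch.empty(0 if inference_mode else n_rows,
                       dtype=torch.float32, device=x.device)
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    _rms_norm_fwd[(n_rows,)](x, y, weight, rstd, x.stride(0), n_cols, eps,
                             BLOCK_SIZE=BLOCK_SIZE, STORE_RSTD=not inference_mode)
    return y, rstd
```

Because STORE_RSTD is a tl.constexpr, Triton specializes the kernel per value, so the training and inference variants are separate compiled binaries and the inference one contains no RSTD store at all.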

📊 Benchmarks

Hardware: NVIDIA RTX 4050 (Ada Lovelace) | Precision: BFloat16

[Benchmark chart: Benchmark_rms]

(Note: Gains are highest at lower dimensions, where the RSTD write overhead is proportionally larger. Performance may vary slightly due to thermal throttling on consumer hardware.)

✅ Verification

  • Verified numerical correctness against PyTorch's native RMSNorm (a minimal check is sketched below).
  • Verified via Nsight Compute (NCU) analysis that RSTD is not written to memory.
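A correctness check along these lines, using the rms_norm_forward sketch above (the fp32 reference implementation and the bf16 tolerances are illustrative assumptions):

```python
import torch

def ref_rms_norm(x, w, eps=1e-6):
    # fp32 reference: y = x * rsqrt(mean(x^2) + eps) * w
    xf = x.float()
    rstd = torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return (xf * rstd * w.float()).to(x.dtype)

x = torch.randn(16, 512, dtype=torch.bfloat16, device="cuda")
w = torch.randn(512, dtype=torch.bfloat16, device="cuda")

y, _ = rms_norm_forward(x, w, eps=1e-6, inference_mode=True)
torch.testing.assert_close(y, ref_rms_norm(x, w), rtol=1.6e-2, atol=1e-2)
```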
