
[NPU]: fused_add_rms_norm kernel distinguish the chunking strategy#1100

Open

TianHao324 wants to merge 1 commit into linkedin:main from TianHao324:add_rms_ext

Conversation

@TianHao324
Contributor

@TianHao324 TianHao324 commented Feb 12, 2026

Summary

Based on #1070.

Because the original kernel uses n_cols as BLOCK_SIZE, and n_cols is small in the tests, the tests pass; in the benchmark, however, n_cols is larger, and running on the NPU causes a UB (unified buffer) overflow. This PR therefore processes each row in chunks of BLOCK_SIZE, which keeps performance high for the smaller hidden sizes used by most models while also supporting larger hidden sizes.
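The chunked strategy can be sketched in plain Python as a reference model (illustrative only — the function name, single-row scope, and default block_size are assumptions, not the PR's actual Triton kernel):

```python
import math

def fused_add_rms_norm_chunked(x_row, r_row, w, eps=1e-6, block_size=1024):
    """Reference model of the chunked strategy for one row: the hidden
    dimension is walked in chunks of block_size, so an on-chip buffer
    never has to hold the full row when the hidden size is large."""
    n_cols = len(x_row)
    s_row = [0.0] * n_cols
    # Pass 1: S = X + R and the sum of squares, chunk by chunk.
    sum_sq = 0.0
    for start in range(0, n_cols, block_size):
        for i in range(start, min(start + block_size, n_cols)):
            s_row[i] = x_row[i] + r_row[i]
            sum_sq += s_row[i] * s_row[i]
    rstd = 1.0 / math.sqrt(sum_sq / n_cols + eps)
    # Pass 2: normalize and apply the weight, re-reading S chunk by chunk.
    y_row = [s_row[i] * rstd * w[i] for i in range(n_cols)]
    return y_row, s_row
```

Any block_size gives the same result; only the per-chunk working-set size changes, which is what avoids the UB overflow at large hidden sizes.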

Testing Done

  • Hardware Type: Atlas 800I A2
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@TianHao324
Contributor Author

TianHao324 commented Feb 12, 2026

benchmark:

**************************************
     BENCHMARKING SPEED for FUSED_ADD_RMS_NORM
**************************************
********** Benchmark Data **********
[
  {
    "kernel_name": "fused_add_rms_norm",
    "kernel_provider": "liger_fused_add_rms_norm",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "H",
    "x_label": "hidden size",
    "x_values": [
      512,
      1024,
      2048
    ],
    "y_values_50": [
      0.3016600012779236,
      0.30347001552581787,
      0.3464900255203247
    ],
    "y_values_20": [
      0.2922600209712982,
      0.2950280010700226,
      0.34233999252319336
    ],
    "y_values_80": [
      0.30614399909973145,
      0.3140000104904175,
      0.3537440001964569
    ],
    "timestamp": "2026-02-14 02:32:04",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"M\": 2048, \"dtype\": \"torch.float32\", \"eps\": 1e-06}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "fused_add_rms_norm",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "H",
    "x_label": "hidden size",
    "x_values": [
      512,
      1024,
      2048
    ],
    "y_values_50": [
      0.16526000201702118,
      0.1593399941921234,
      0.2619200050830841
    ],
    "y_values_20": [
      0.15649600327014923,
      0.15703999996185303,
      0.26043999195098877
    ],
    "y_values_80": [
      0.17308400571346283,
      0.17042000591754913,
      0.2635200023651123
    ],
    "timestamp": "2026-02-14 02:32:06",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"M\": 2048, \"dtype\": \"torch.float32\", \"eps\": 1e-06}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "fused_add_rms_norm",
    "kernel_provider": "liger_fused_add_rms_norm",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "H",
    "x_label": "hidden size",
    "x_values": [
      512,
      1024,
      2048
    ],
    "y_values_50": [
      1.0353000164031982,
      0.9589799642562866,
      1.0795600414276123
    ],
    "y_values_20": [
      1.0069040060043335,
      0.9506640434265137,
      0.9830079674720764
    ],
    "y_values_80": [
      1.0679240226745605,
      0.9769719839096069,
      1.1022840738296509
    ],
    "timestamp": "2026-02-14 02:32:08",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"M\": 2048, \"dtype\": \"torch.float32\", \"eps\": 1e-06}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "fused_add_rms_norm",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "H",
    "x_label": "hidden size",
    "x_values": [
      512,
      1024,
      2048
    ],
    "y_values_50": [
      0.735759973526001,
      0.7554799914360046,
      0.814520001411438
    ],
    "y_values_20": [
      0.7252560257911682,
      0.7480800151824951,
      0.8116359710693359
    ],
    "y_values_80": [
      0.7504920363426208,
      0.7660800218582153,
      0.8187599778175354
    ],
    "timestamp": "2026-02-14 02:32:09",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"M\": 2048, \"dtype\": \"torch.float32\", \"eps\": 1e-06}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "fused_add_rms_norm",
    "kernel_provider": "liger_fused_add_rms_norm",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "H",
    "x_label": "hidden size",
    "x_values": [
      512,
      1024,
      2048
    ],
    "y_values_50": [
      0.5300999879837036,
      0.5419399738311768,
      0.6003699898719788
    ],
    "y_values_20": [
      0.5045040249824524,
      0.519320011138916,
      0.5943999886512756
    ],
    "y_values_80": [
      0.5404000282287598,
      0.552836000919342,
      0.6061080098152161
    ],
    "timestamp": "2026-02-14 02:32:11",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"M\": 2048, \"dtype\": \"torch.float32\", \"eps\": 1e-06}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "fused_add_rms_norm",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "H",
    "x_label": "hidden size",
    "x_values": [
      512,
      1024,
      2048
    ],
    "y_values_50": [
      0.3912000060081482,
      0.40658000111579895,
      0.6460199952125549
    ],
    "y_values_20": [
      0.38254398107528687,
      0.4007920026779175,
      0.6433959603309631
    ],
    "y_values_80": [
      0.40358400344848633,
      0.4165160059928894,
      0.6494320034980774
    ],
    "timestamp": "2026-02-14 02:32:13",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"M\": 2048, \"dtype\": \"torch.float32\", \"eps\": 1e-06}",
    "liger_version": "0.0.0"
  }
]
**************************************
     BENCHMARKING MEMORY for FUSED_ADD_RMS_NORM
**************************************
********** Benchmark Data **********
[
  {
    "kernel_name": "fused_add_rms_norm",
    "kernel_provider": "liger_fused_add_rms_norm",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "H",
    "x_label": "hidden size",
    "x_values": [
      512,
      1024,
      2048
    ],
    "y_values_50": [
      36.185546875,
      72.353515625,
      144.689453125
    ],
    "y_values_20": [
      36.185546875,
      72.353515625,
      144.689453125
    ],
    "y_values_80": [
      36.185546875,
      72.353515625,
      144.689453125
    ],
    "timestamp": "2026-02-14 02:32:13",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"M\": 2048, \"dtype\": \"torch.float32\", \"eps\": 1e-06}",
    "liger_version": "0.0.0"
  },
  {
    "kernel_name": "fused_add_rms_norm",
    "kernel_provider": "huggingface",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "H",
    "x_label": "hidden size",
    "x_values": [
      512,
      1024,
      2048
    ],
    "y_values_50": [
      52.0185546875,
      104.0283203125,
      208.0478515625
    ],
    "y_values_20": [
      52.0185546875,
      104.0283203125,
      208.0478515625
    ],
    "y_values_80": [
      52.0185546875,
      104.0283203125,
      208.0478515625
    ],
    "timestamp": "2026-02-14 02:32:13",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"M\": 2048, \"dtype\": \"torch.float32\", \"eps\": 1e-06}",
    "liger_version": "0.0.0"
  }
]

@TianHao324
Contributor Author

@Tcc0403 would you mind having a preview?

@Tcc0403
Collaborator

Tcc0403 commented Feb 12, 2026

What's the key difference between no tiling and tiling kernels? Isn't it just a special case when num_chunks==1? i.e., looping over col dimension but only 1 iteration

@TianHao324
Contributor Author

What's the key difference between no tiling and tiling kernels? Isn't it just a special case when num_chunks==1? i.e., looping over col dimension but only 1 iteration

No. In the former, all of the data is accessed exactly once. The latter does behave the same as the former when num_chunks == 1, but looking at the code you can see it has two loops: regardless of the value of num_chunks, some data is accessed multiple times, i.e. the values loaded in the first loop are loaded again in the second loop.
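A tiny access-counting model (illustrative, not the real kernel) makes this concrete: because rstd is only known after the first loop finishes, the second loop has to touch every element of S again, even when num_chunks == 1:

```python
def count_s_accesses(n_cols, block_size):
    """Count how many times each element of S is touched by the
    two-loop tiled kernel (a model of the access pattern only)."""
    touches = [0] * n_cols
    # Loop 1: compute S and accumulate the sum of squares.
    for start in range(0, n_cols, block_size):
        for i in range(start, min(start + block_size, n_cols)):
            touches[i] += 1
    # Loop 2: reload S to normalize; rstd exists only after loop 1.
    for start in range(0, n_cols, block_size):
        for i in range(start, min(start + block_size, n_cols)):
            touches[i] += 1
    return touches
```

Even with block_size >= n_cols (num_chunks == 1), every element is touched twice, whereas the no-tiling kernel keeps the row in registers and touches each element once.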

@Tcc0403
Collaborator

Tcc0403 commented Feb 12, 2026

Thanks for the clarification, I totally missed that part. However, if we set n_cols as a constexpr, I believe num_col_blocks would be a compile-time known value, and the loops and multiple global memory accesses could be optimized away by the compiler as well.

That being said, it would require checking the compilation results to verify this assumption. It's easier to just add a new function to prevent multiple accesses, as this PR does. I'll take a look.

Comment on lines 74 to 112
# Grid-stride loop over rows
for row_idx in tl.range(pid, n_rows, num_progs, num_stages=NUM_STAGES):
    Y_row_ptr = Y_ptr + row_idx * Y_row_stride
    S_row_ptr = S_ptr + row_idx * S_row_stride
    X_row_ptr = X_ptr + row_idx * X_row_stride
    R_row_ptr = R_ptr + row_idx * R_row_stride
    RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride

    # Load entire rows and compute S = X + R (keep in registers)
    X_row = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, eviction_policy="evict_first")
    R_row = tl.load(R_row_ptr + col_offsets, mask=mask, other=0.0, eviction_policy="evict_first")
    S_row = X_row + R_row

    # Compute sum_square
    if casting_mode == _CASTING_MODE_LLAMA or casting_mode == _CASTING_MODE_GEMMA:
        S_row = S_row.to(tl.float32)

    sum_square = tl.sum(tl.where(mask, S_row * S_row, 0.0))
    # Compute rstd for this row
    mean_square = sum_square / n_cols
    rstd = rsqrt(mean_square + eps)

    tl.store(S_row_ptr + col_offsets, S_row, mask=mask, cache_modifier=".cg")
    tl.store(RSTD_row_ptr, rstd)

    # Load W_row (while stores are in flight)
    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)

    # Normalize and apply weight - optimized for each casting mode
    if casting_mode == _CASTING_MODE_GEMMA:
        Y_row = ((S_row * rstd) * (offset + W_row)).to(X_DTYPE)
    elif casting_mode == _CASTING_MODE_LLAMA:
        S_normalized = (S_row * rstd).to(X_DTYPE)
        Y_row = S_normalized * (offset + W_row)
    else:
        Y_row = (S_row * rstd) * (offset + W_row)

    # Store result
    tl.store(Y_row_ptr + col_offsets, Y_row, mask=mask)
Collaborator

@Tcc0403 Tcc0403 Feb 12, 2026

Since the tile size is smaller, maybe we can process multiple rows at once in each program? E.g. (row_offsets, col_offsets) = (1, 2048), (2, 1024), (4, 512)

Contributor Author

Do you mean changing the shape of the grid to the (row_offsets, col_offsets) you mentioned? First, this actually gives almost the same performance. Moreover, the NPU prefers the number of NPU cores, or a multiple of it, as the grid size.

Collaborator

Do you mean changing the shape of the grid to the (row_offsets, col_offsets) you mentioned?

I mean changing the tile size while keeping the grid shape, so the grid loop would take fewer iterations.

Moreover, the NPU prefers the number of NPU cores, or a multiple of it, as the grid size.

Yes, keep the grid size but change the loop stride. We can pass BLOCK_SIZE_M and consider it when designing the loop.

grid_stride = num_progs * BLOCK_SIZE_M
# ceil_div(n_rows, grid_stride)
num_iterations = (n_rows + grid_stride - 1) // grid_stride
col_offsets = tl.arange(0, BLOCK_SIZE_N)
col_mask = col_offsets < n_cols
row_offsets = tl.arange(0, BLOCK_SIZE_M)
for i in tl.range(num_iterations, num_stages=NUM_STAGES):
    row_idx = pid * BLOCK_SIZE_M + i * grid_stride + row_offsets
    row_mask = row_idx < n_rows
    block_mask = row_mask[:, None] & col_mask[None, :]
    # Load multiple rows at once
    X_rows = tl.load(X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :], mask=block_mask, other=0.0)
    ...
    # Store multiple rows at once
    tl.store(Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :], Y_rows, mask=block_mask)

Does it make sense?
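The proposed row schedule can be sketched in plain Python (illustrative; in particular the per-program offset pid * BLOCK_SIZE_M is an assumption about how rows would be assigned to programs):

```python
def rows_per_iteration_schedule(n_rows, num_progs, block_size_m):
    """Model of the multi-row grid-stride loop: each program handles
    block_size_m rows per iteration, so the loop needs fewer
    iterations than the one-row-per-iteration version."""
    grid_stride = num_progs * block_size_m
    num_iterations = (n_rows + grid_stride - 1) // grid_stride  # ceil_div
    schedule = []
    for pid in range(num_progs):
        blocks = []
        for i in range(num_iterations):
            base = pid * block_size_m + i * grid_stride
            rows = [r for r in range(base, base + block_size_m) if r < n_rows]
            if rows:
                blocks.append(rows)
        schedule.append(blocks)
    return num_iterations, schedule
```

With n_rows = 16 and num_progs = 2, moving from block_size_m = 1 to block_size_m = 4 drops the iteration count from 8 to 2 while still covering every row exactly once.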

Collaborator

This suggestion comes from the benchmark results showing that the kernel runtime isn't lower with smaller n_cols, which is due to the same degree of parallelism. If one block can allocate tiles of shape (1, 2048), it can certainly allocate (2, 1024) or (4, 512) as well. Processing multiple rows at once --> fewer iterations --> shorter runtime

Contributor Author

Indeed, I didn't consider this aspect. It makes perfect sense.

@TianHao324
Contributor Author

That being said, it would require checking the compilation results to verify this assumption. It's easier to just add a new function to prevent multiple accesses, as this PR does. I'll take a look.

OK. Thank you very much.

@TianHao324
Contributor Author

@Tcc0403 I have completed the revisions. I sincerely wish you a happy Lunar New Year in advance.
