[NPU]: fused_add_rms_norm kernel distinguish the chunking strategy #1100
TianHao324 wants to merge 1 commit into linkedin:main from
Conversation
benchmark:
@Tcc0403 would you mind taking a look?
What's the key difference between the no-tiling and tiling kernels? Isn't it just a special case when num_chunks == 1? i.e., looping over the col dimension but with only 1 iteration.
No. The former accesses all the data exactly once. The latter seems to behave the same as the former when num_chunks = 1, but looking at the code, it has two loops: regardless of the value of num_chunks, some data is accessed more than once, i.e., the values loaded in the first loop are loaded again in the second loop.
Thanks for the clarification, I totally missed that part. However, if we set n_cols as a constexpr, I believe the compiler might be able to optimize the redundant accesses away when num_chunks == 1. That being said, it would require checking the compilation results to verify this assumption. It's easier to just add a new function for it to prevent multiple accesses, as this PR does. I'll take a look.
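For readers following along, here is a minimal, illustrative NumPy sketch of the access-pattern difference being discussed. It is not the actual Triton kernels, and the residual add and weight offset are omitted: the single-pass version reads each element of the row once, while the chunked version sweeps every chunk twice, once to accumulate the sum of squares and once to normalize.

import numpy as np

def rms_norm_row_no_tiling(x, w, eps=1e-6):
    # Single pass: the whole row is read once and kept "in registers".
    rstd = 1.0 / np.sqrt(np.mean(x.astype(np.float32) ** 2) + eps)
    return (x * rstd) * w

def rms_norm_row_chunked(x, w, block_size=8, eps=1e-6):
    n_cols = x.shape[0]
    # First loop: accumulate the sum of squares chunk by chunk.
    sum_square = 0.0
    for start in range(0, n_cols, block_size):
        chunk = x[start:start + block_size].astype(np.float32)
        sum_square += np.sum(chunk * chunk)
    rstd = 1.0 / np.sqrt(sum_square / n_cols + eps)
    # Second loop: every chunk is read AGAIN to produce the output,
    # even when n_cols <= block_size (i.e. num_chunks == 1).
    y = np.empty_like(x)
    for start in range(0, n_cols, block_size):
        y[start:start + block_size] = x[start:start + block_size] * rstd * w[start:start + block_size]
    return y

x = np.random.rand(16).astype(np.float32)
w = np.ones(16, dtype=np.float32)
assert np.allclose(rms_norm_row_no_tiling(x, w), rms_norm_row_chunked(x, w), atol=1e-5)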
# Grid-stride loop over rows
for row_idx in tl.range(pid, n_rows, num_progs, num_stages=NUM_STAGES):
    Y_row_ptr = Y_ptr + row_idx * Y_row_stride
    S_row_ptr = S_ptr + row_idx * S_row_stride
    X_row_ptr = X_ptr + row_idx * X_row_stride
    R_row_ptr = R_ptr + row_idx * R_row_stride
    RSTD_row_ptr = RSTD_ptr + row_idx * RSTD_row_stride

    # Load entire rows and compute S = X + R (keep in registers)
    X_row = tl.load(X_row_ptr + col_offsets, mask=mask, other=0.0, eviction_policy="evict_first")
    R_row = tl.load(R_row_ptr + col_offsets, mask=mask, other=0.0, eviction_policy="evict_first")
    S_row = X_row + R_row

    # Compute sum_square
    if casting_mode == _CASTING_MODE_LLAMA or casting_mode == _CASTING_MODE_GEMMA:
        S_row = S_row.to(tl.float32)

    sum_square = tl.sum(tl.where(mask, S_row * S_row, 0.0))
    # Compute rstd for this row
    mean_square = sum_square / n_cols
    rstd = rsqrt(mean_square + eps)

    tl.store(S_row_ptr + col_offsets, S_row, mask=mask, cache_modifier=".cg")
    tl.store(RSTD_row_ptr, rstd)

    # Load W_row (while stores are in flight)
    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)

    # Normalize and apply weight - optimized for each casting mode
    if casting_mode == _CASTING_MODE_GEMMA:
        Y_row = ((S_row * rstd) * (offset + W_row)).to(X_DTYPE)
    elif casting_mode == _CASTING_MODE_LLAMA:
        S_normalized = (S_row * rstd).to(X_DTYPE)
        Y_row = S_normalized * (offset + W_row)
    else:
        Y_row = (S_row * rstd) * (offset + W_row)

    # Store result
    tl.store(Y_row_ptr + col_offsets, Y_row, mask=mask)
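As a reading aid, the per-row math the snippet above implements can be summarized with a small NumPy reference. The function name here is made up, and the LLAMA/GEMMA casting-mode details are simplified to a plain float32 upcast:

import numpy as np

def fused_add_rms_norm_row_reference(x, r, w, offset=0.0, eps=1e-6):
    # S = X + R is materialized and written back; RSTD is saved alongside Y.
    s = (x + r).astype(np.float32)
    rstd = 1.0 / np.sqrt(np.mean(s * s) + eps)
    y = (s * rstd) * (offset + w)
    return y, s, rstd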
Since the tile size is smaller, maybe we can process multiple rows at once in each program? E.g. (row_offsets, col_offsets) = (1, 2048), (2, 1024), (4, 512)
Do you mean changing the shape of the grid to what you mentioned, (row_offsets, col_offsets)? First of all, this actually gives almost the same performance. Moreover, the NPU prefers the grid size to be the number of NPU cores or a multiple of it.
> Do you mean changing the shape of the grid to what you mentioned, (row_offsets, col_offsets)?

I mean changing the tile size while keeping the grid shape, so the grid loop takes fewer iterations.

> Moreover, the NPU prefers the grid size to be the number of NPU cores or a multiple of it.

Yes, keep the grid size but change the loop stride. We can pass BLOCK_SIZE_M and take it into account when designing the loop, e.g.:
grid_stride = num_progs * BLOCK_SIZE_M
# ceil_div(n_rows, grid_stride)
num_iterations = (n_rows + grid_stride - 1) // grid_stride
col_offsets = tl.arange(0, BLOCK_SIZE_N)
col_mask = col_offsets < n_cols
row_offsets = tl.arange(0, BLOCK_SIZE_M)
for i in tl.range(num_iterations, num_stages=NUM_STAGES):
    row_idx = pid * BLOCK_SIZE_M + i * grid_stride + row_offsets
    row_mask = row_idx < n_rows
    block_mask = row_mask[:, None] & col_mask[None, :]
    # Load multiple rows at once
    X_rows = tl.load(X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :], mask=block_mask, other=0.0)
    ...
    # Store multiple rows at once
    tl.store(Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :], Y_rows, mask=block_mask)

Does it make sense?
This suggestion comes from seeing that the benchmark results show the kernel runtime isn't lower with smaller n_cols, which is due to the same degree of parallelism. If one block can allocate a tile of shape (1, 2048), it can certainly allocate (2, 1024) or (4, 512) as well. Processing multiple rows at once --> fewer iterations --> shorter runtime.
Indeed, I didn't consider this aspect. It makes perfect sense.
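To make the "fewer iterations" point above concrete, here is a quick back-of-the-envelope check; the row count and program count are made-up numbers, not taken from this PR's benchmark:

n_rows, num_progs = 8192, 40  # hypothetical workload size and core count
for block_size_m in (1, 2, 4):
    grid_stride = num_progs * block_size_m
    num_iterations = (n_rows + grid_stride - 1) // grid_stride  # ceil_div(n_rows, grid_stride)
    print(f"BLOCK_SIZE_M={block_size_m}: {num_iterations} iterations")  # 205, 103, 52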
OK. Thank you very much.
Force-pushed from 5451549 to d253e49, then from d253e49 to a012dce.
@Tcc0403 I have completed the revisions. I sincerely wish you a happy Lunar New Year in advance.
Summary
Based on #1070
Because the original kernel uses n_cols as BLOCK_SIZE, and n_cols is small in the tests, the tests pass normally. In the benchmark, however, n_cols is larger, and running on the NPU causes a UB overflow. Therefore, each row is now processed in chunks of BLOCK_SIZE. This keeps performance high for the smaller hidden sizes used by most models, while still supporting larger hidden sizes.
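A minimal sketch of the dispatch described above, with hypothetical names (MAX_FUSED_SIZE and _pick_block_size are illustrative, not the PR's actual code): take the single-pass kernel when the whole row fits in one block, otherwise fall back to the chunked kernel.

import triton

MAX_FUSED_SIZE = 65536  # hypothetical cap on how many columns one block may hold

def _pick_block_size(n_cols: int):
    # Round n_cols up to a power of two, capped so the row fits on-chip.
    block_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
    use_no_tiling = block_size >= n_cols  # whole row fits in a single block
    return block_size, use_no_tiling

# e.g. n_cols = 4096   -> single-pass (no-tiling) kernel
#      n_cols = 131072 -> chunked kernel with BLOCK_SIZE = MAX_FUSED_SIZE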
Testing Done
- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence