[NPU] FIX fused_linear_jsd ub overflow and OOM on NPU#1043
Open
MAYUNHUI666 wants to merge 3 commits intolinkedin:mainfrom
Open
[NPU] FIX fused_linear_jsd ub overflow and OOM on NPU#1043MAYUNHUI666 wants to merge 3 commits intolinkedin:mainfrom
MAYUNHUI666 wants to merge 3 commits intolinkedin:mainfrom
Conversation
Collaborator
|
Given the device-specific shapes currently scattered across the codebase, I opened #1051 to discuss this issues and a possible path toward standardization. Feedback is very welcome! |
Tcc0403
reviewed
Jan 29, 2026
| # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling | ||
| # The optimal maximum block size depends on your hardware, your kernel, and your dtype | ||
| MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 | ||
| MAX_FUSED_SIZE = 4096 if infer_device() == "npu" else 65536 // 2 |
Collaborator
There was a problem hiding this comment.
append instead of replace
Suggested change
| MAX_FUSED_SIZE = 4096 if infer_device() == "npu" else 65536 // 2 | |
| MAX_FUSED_SIZE = 4096 if infer_device() in ["npu", "xpu"] else 65536 // 2 |
Comment on lines
+237
to
+249
| gpu_memory_gbs = get_total_gpu_memory() | ||
| if gpu_memory_gbs >= 69: | ||
| vocab_size = 128256 | ||
| else: | ||
| vocab_size = 65536 | ||
|
|
||
| common_configs = { | ||
| "kernel_name": "fused_linear_jsd", | ||
| "x_name": "BT", | ||
| "x_label": "B x T", | ||
| "x_values": [2**i for i in range(10, 14)], | ||
| "kernel_providers": ["liger", "torch"], | ||
| "extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}], | ||
| "extra_benchmark_configs": [{"H": 4096, "V": vocab_size, "mode": "forward", "dtype": torch.bfloat16}], |
Collaborator
There was a problem hiding this comment.
Let's lower the upper bound of x_values instead of vocab_size for now.
We can discuss what configs should be scalable if there's memory constraint, see #1051
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Distinguish the memory limits and the maximum supported shape across different hardware scenarios
Testing Done
make testto ensure correctnessmake checkstyleto ensure code stylemake test-convergenceto ensure convergence