
[NPU] FIX fused_linear_jsd ub overflow and OOM on NPU #1043

Open
MAYUNHUI666 wants to merge 3 commits into linkedin:main from MAYUNHUI666:main

Conversation

@MAYUNHUI666

Summary

Distinguish the memory limits and the maximum supported shapes across different hardware scenarios, so fused_linear_jsd no longer overflows the unified buffer (ub) or runs out of memory on NPU.
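For context, the benchmark diff below reads the device's total memory via a helper named get_total_gpu_memory. A minimal sketch of what such a helper could look like, assuming a CUDA-visible device and GiB units (the PR's actual implementation may differ):

import torch

def get_total_gpu_memory() -> float:
    # Sketch only: total memory of device 0 in GiB, so that benchmark
    # shapes can be scaled per hardware; 0.0 when no accelerator is visible.
    if torch.cuda.is_available():
        return torch.cuda.get_device_properties(0).total_memory / 1024**3
    return 0.0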

Testing Done

  • Hardware Type: Ascend NPU 910B2
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@Tcc0403
Collaborator

Tcc0403 commented Jan 29, 2026

Given the device-specific shapes currently scattered across the codebase, I opened #1051 to discuss this issue and a possible path toward standardization. Feedback is very welcome!

  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
  # The optimal maximum block size depends on your hardware, your kernel, and your dtype
- MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2
+ MAX_FUSED_SIZE = 4096 if infer_device() == "npu" else 65536 // 2

append instead of replace

Suggested change
- MAX_FUSED_SIZE = 4096 if infer_device() == "npu" else 65536 // 2
+ MAX_FUSED_SIZE = 4096 if infer_device() in ["npu", "xpu"] else 65536 // 2
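For context, a cap like MAX_FUSED_SIZE is typically consumed when picking a kernel block size. A minimal sketch, where pick_block_size is a hypothetical helper and next_power_of_2 is written out instead of using triton's utility:

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, for n >= 1.
    return 1 << (n - 1).bit_length()

def pick_block_size(n_cols: int, max_fused_size: int) -> int:
    # Hypothetical helper: round the row width up to a power of two,
    # but never past the device cap; exceeding the cap is what can
    # overflow the NPU's on-chip buffer (the "ub overflow" in the title).
    return min(max_fused_size, next_power_of_2(n_cols))

Capping per device is why appending "npu" alongside "xpu", rather than replacing it, matters: both devices keep their smaller 4096 limit while CUDA keeps 32768.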

Comment on lines +237 to +249
+ gpu_memory_gbs = get_total_gpu_memory()
+ if gpu_memory_gbs >= 69:
+     vocab_size = 128256
+ else:
+     vocab_size = 65536

  common_configs = {
      "kernel_name": "fused_linear_jsd",
      "x_name": "BT",
      "x_label": "B x T",
      "x_values": [2**i for i in range(10, 14)],
      "kernel_providers": ["liger", "torch"],
-     "extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}],
+     "extra_benchmark_configs": [{"H": 4096, "V": vocab_size, "mode": "forward", "dtype": torch.bfloat16}],

Let's lower the upper bound of x_values instead of vocab_size for now.

We can discuss which configs should be scalable when there is a memory constraint; see #1051.
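A minimal sketch of that alternative, reusing get_total_gpu_memory from the diff above and its assumed 69 GB threshold: keep V fixed at 128256 and drop the largest B x T point on smaller devices.

# Sketch only: shrink the sweep's upper bound instead of the vocab size.
max_exp = 14 if get_total_gpu_memory() >= 69 else 13
x_values = [2**i for i in range(10, max_exp)]  # up to 2**13, or 2**12 on small devices
extra_benchmark_configs = [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}]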

