Skip to content

Commit a98f691

Browse files
fix correctness test
1 parent 096dd4a commit a98f691

File tree

3 files changed, with 10 additions and 3 deletions.

csrc/cutlass_extensions/torch_utils.hpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,13 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
6868
name, ".stride(", idx, ") to be ", StrideEle::value);
6969
return StrideEle{};
7070
} else {
71-
return tensor.stride(idx);
71+
if (tensor.size(idx) == 1) {
72+
// use 0 stride for dim with size 1, this is easier for
73+
// cute/cutlass to optimize (helps the TMA code flatten dims)
74+
return StrideEle{0};
75+
} else {
76+
return tensor.stride(idx);
77+
}
7278
}
7379
} else {
7480
// Extra strides are assumed to be 0 or 1

csrc/quantization/machete/machete_mm_launcher.cuh

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ torch::Tensor run_impl(PyTorchArguments args) {
7171
auto arguments = MacheteKernel::create_arguments(
7272
stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
7373
layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
74-
args.group_size.value_or(K));
74+
args.group_size);
7575
TORCH_CHECK(MacheteKernel::can_implement(arguments),
7676
"Machete kernel cannot be run with these arguments");
7777

vllm/_custom_ops.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,8 @@ def machete_gemm_fake(
389389
@torch.library.register_fake("_C::machete_prepack_B")
390390
def machete_prepack_B_fake(b_q_weight: torch.Tensor,
391391
b_type: ScalarType) -> torch.Tensor:
392-
return torch.empty_like(b_q_weight)
392+
return torch.empty_like(b_q_weight,
393+
memory_format=torch.contiguous_format)
393394

394395
@torch.library.register_fake("_C::causal_conv1d_fwd")
395396
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,

0 commit comments

Comments
 (0)