[Bug]: Failed to lower TorchInductor generated permute kernel #201

Closed
Nullkooland opened this issue Dec 12, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@Nullkooland
Contributor

Nullkooland commented Dec 12, 2024

Triton python code

import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[524288], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cpu', index=None, cc='', major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, multi_processor_count=None), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'FFF4DE261050A21E66F9B32F4582F64DC0D8366DAC7E0E4E4BFCCE4F24218A4B', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 409600
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex % 128
    x1 = (xindex // 128) % 100
    x2 = (xindex // 12800)
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (128*x2) + (4096*x1)), None)
    tl.store(tl.make_block_ptr(out_ptr0, shape=[409600], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp0, [XBLOCK]).to(tl.float16))

Triton IR

module {
  tt.func public @triton_(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c1_i64 = arith.constant 1 : i64
    %c409600_i64 = arith.constant 409600 : i64
    %cst = arith.constant dense<4096> : tensor<1024xi32>
    %cst_0 = arith.constant dense<12800> : tensor<1024xi32>
    %cst_1 = arith.constant dense<100> : tensor<1024xi32>
    %cst_2 = arith.constant dense<128> : tensor<1024xi32>
    %c1024_i32 = arith.constant 1024 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c1024_i32 : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %3 = tt.splat %1 : i32 -> tensor<1024xi32>
    %4 = arith.addi %3, %2 : tensor<1024xi32>
    %5 = arith.remsi %4, %cst_2 : tensor<1024xi32>
    %6 = arith.divsi %4, %cst_2 : tensor<1024xi32>
    %7 = arith.remsi %6, %cst_1 : tensor<1024xi32>
    %8 = arith.divsi %4, %cst_0 : tensor<1024xi32>
    %9 = arith.muli %8, %cst_2 : tensor<1024xi32>
    %10 = arith.addi %5, %9 : tensor<1024xi32>
    %11 = arith.muli %7, %cst : tensor<1024xi32>
    %12 = arith.addi %10, %11 : tensor<1024xi32>
    %13 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<1024x!tt.ptr<f16>>
    %14 = tt.addptr %13, %12 : tensor<1024x!tt.ptr<f16>>, tensor<1024xi32>
    %15 = tt.load %14 : tensor<1024x!tt.ptr<f16>>
    %16 = tt.make_tensor_ptr %arg1, [%c409600_i64], [%c1_i64], [%1] {order = array<i32: 0>} : <tensor<1024xf16>>
    tt.store %16, %15 : !tt.ptr<tensor<1024xf16>>
    tt.return
  }
}

Crash log

triton-shared-opt --triton-to-linalg-experimental output/kernels/triton/permute.mlir
PtrAnalysis: encountered addptr operand produced by an unsupported operation
%9 = arith.divsi %5, %cst_0 : tensor<1024xi32>
output/kernels/triton/permute.mlir:24:11: remark: PtrAnalysis: Failed to rewrite AddPtrOp
    %14 = tt.addptr %13, %12 : tensor<1024x!tt.ptr<f16>>, tensor<1024xi32>
          ^
output/kernels/triton/permute.mlir:24:11: note: see current operation: %15 = tt.addptr %14, %13 : tensor<1024x!tt.ptr<f16>>, tensor<1024xi32>
output/kernels/triton/permute.mlir:25:11: remark: PtrAnalysis: pointer is not replace with tts.make_tptr so loadOp cannot be rewritten
    %15 = tt.load %14 : tensor<1024x!tt.ptr<f16>>
          ^
output/kernels/triton/permute.mlir:25:11: note: see current operation: %16 = tt.load %15 : tensor<1024x!tt.ptr<f16>>
output/kernels/triton/permute.mlir:25:11: remark: PtrAnalysis: Failed to rewrite LoadOp
    %15 = tt.load %14 : tensor<1024x!tt.ptr<f16>>
          ^
output/kernels/triton/permute.mlir:25:11: note: see current operation: %16 = tt.load %15 : tensor<1024x!tt.ptr<f16>>
triton-shared-opt: triton_shared/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp:186: std::optional<SmallVector<Value>> (anonymous namespace)::buildCastAndOffsetOps(OpBuilder &, TypeRange, Value, Location): Assertion `castOp && "Unexpected defining op for input of type tt.ptr"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: triton-shared-opt --triton-to-linalg-experimental output/kernels/triton/permute.mlir
 #0 0x000055ab44813a9d llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) llvm_project/llvm/lib/Support/Unix/Signals.inc:723:11
 #1 0x000055ab44813f8b PrintStackTraceSignalHandler(void*) llvm_project/llvm/lib/Support/Unix/Signals.inc:798:1
 #2 0x000055ab448122a6 llvm::sys::RunSignalHandlers() llvm_project/llvm/lib/Support/Signals.cpp:105:5
 #3 0x000055ab44814595 SignalHandler(int) llvm_project/llvm/lib/Support/Unix/Signals.inc:413:1
 #4 0x00007f8075bbc520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #5 0x00007f8075c109fc pthread_kill (/lib/x86_64-linux-gnu/libc.so.6+0x969fc)
 #6 0x00007f8075bbc476 gsignal (/lib/x86_64-linux-gnu/libc.so.6+0x42476)
 #7 0x00007f8075ba27f3 abort (/lib/x86_64-linux-gnu/libc.so.6+0x287f3)
 #8 0x00007f8075ba271b (/lib/x86_64-linux-gnu/libc.so.6+0x2871b)
 #9 0x00007f8075bb3e96 (/lib/x86_64-linux-gnu/libc.so.6+0x39e96)
#10 0x000055ab3fb73d7f (anonymous namespace)::buildCastAndOffsetOps(mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location) triton_shared/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp:189:24
#11 0x000055ab3fb811dd decltype(std::declval<std::__1::optional<llvm::SmallVector<mlir::Value, 6u>> (*&)(mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location)>()(std::declval<mlir::OpBuilder&>(), std::declval<mlir::TypeRange>(), std::declval<mlir::Value>(), std::declval<mlir::Location>())) std::__1::__invoke[abi:nn180100]<std::__1::optional<llvm::SmallVector<mlir::Value, 6u>> (*&)(mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location), mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location>(std::__1::optional<llvm::SmallVector<mlir::Value, 6u>> (*&)(mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location), mlir::OpBuilder&, mlir::TypeRange&&, mlir::Value&&, mlir::Location&&) /usr/lib/llvm-18/bin/../include/c++/v1/__type_traits/invoke.h:344:25
#12 0x000055ab3fb81140 std::__1::optional<llvm::SmallVector<mlir::Value, 6u>> std::__1::__invoke_void_return_wrapper<std::__1::optional<llvm::SmallVector<mlir::Value, 6u>>, false>::__call[abi:nn180100]<std::__1::optional<llvm::SmallVector<mlir::Value, 6u>> (*&)(mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location), mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location>(std::__1::optional<llvm::SmallVector<mlir::Value, 6u>> (*&)(mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location), mlir::OpBuilder&, mlir::TypeRange&&, mlir::Value&&, mlir::Location&&) /usr/lib/llvm-18/bin/../include/c++/v1/__type_traits/invoke.h:411:12
#13 0x000055ab3fb810f0 std::__1::__function::__alloc_func<std::__1::optional<llvm::SmallVector<mlir::Value, 6u>> (*)(mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location), std::__1::allocator<std::__1::optional<llvm::SmallVector<mlir::Value, 6u>> (*)(mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location)>, std::__1::optional<llvm::SmallVector<mlir::Value, 6u>> (mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location)>::operator()[abi:nn180100](mlir::OpBuilder&, mlir::TypeRange&&, mlir::Value&&, mlir::Location&&) /usr/lib/llvm-18/bin/../include/c++/v1/__functional/function.h:169:12
#14 0x000055ab3fb80584 std::__1::__function::__func<std::__1::optional<llvm::SmallVector<mlir::Value, 6u>> (*)(mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location), std::__1::allocator<std::__1::optional<llvm::SmallVector<mlir::Value, 6u>> (*)(mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location)>, std::__1::optional<llvm::SmallVector<mlir::Value, 6u>> (mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location)>::operator()(mlir::OpBuilder&, mlir::TypeRange&&, mlir::Value&&, mlir::Location&&) /usr/lib/llvm-18/bin/../include/c++/v1/__functional/function.h:311:10
#15 0x000055ab43e63af5 std::__1::__function::__value_func<std::__1::optional<llvm::SmallVector<mlir::Value, 6u>> (mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location)>::operator()[abi:nn180100](mlir::OpBuilder&, mlir::TypeRange&&, mlir::Value&&, mlir::Location&&) const /usr/lib/llvm-18/bin/../include/c++/v1/__functional/function.h:428:12
#16 0x000055ab43e61c08 std::__1::function<std::__1::optional<llvm::SmallVector<mlir::Value, 6u>> (mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location)>::operator()(mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location) const /usr/lib/llvm-18/bin/../include/c++/v1/__functional/function.h:981:10
#17 0x000055ab43e5bc57 mlir::OneToNTypeConverter::materializeTargetConversion(mlir::OpBuilder&, mlir::Location, mlir::TypeRange, mlir::Value) const llvm_project/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp:27:43
#18 0x000055ab43e5db2e mlir::applyPartialOneToNConversion(mlir::Operation*, mlir::OneToNTypeConverter&, mlir::FrozenRewritePatternSet const&) llvm_project/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp:376:12
#19 0x000055ab3fb6cb6d (anonymous namespace)::StructuredToMemrefPass::convertAddPtrToReinterpretCast() triton_shared/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp:345:16
#20 0x000055ab3fb6c3d5 (anonymous namespace)::StructuredToMemrefPass::runOnOperation() triton_shared/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp:367:16
#21 0x000055ab43c5f0ab mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1::operator()() const llvm_project/mlir/lib/Pass/Pass.cpp:0:17
#22 0x000055ab43c5f045 void llvm::function_ref<void ()>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1>(long) llvm_project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45:5
#23 0x000055ab43c6bd89 llvm::function_ref<void ()>::operator()() const llvm_project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:5
#24 0x000055ab43c63225 void mlir::MLIRContext::executeAction<mlir::PassExecutionAction, mlir::Pass&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, mlir::Pass&) llvm_project/mlir/include/mlir/IR/MLIRContext.h:276:3
#25 0x000055ab43c5bc23 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) llvm_project/mlir/lib/Pass/Pass.cpp:533:17
#26 0x000055ab43c5c1c4 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) llvm_project/mlir/lib/Pass/Pass.cpp:593:16
#27 0x000055ab43c5f021 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_0::operator()(mlir::OpPassManager&, mlir::Operation*) const llvm_project/mlir/lib/Pass/Pass.cpp:510:12
#28 0x000055ab43c5ed95 llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (mlir::OpPassManager&, mlir::Operation*)>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_0>(long, mlir::OpPassManager&, mlir::Operation*) llvm_project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12
#29 0x000055ab3fb1d8e9 llvm::function_ref<llvm::LogicalResult (mlir::OpPassManager&, mlir::Operation*)>::operator()(mlir::OpPassManager&, mlir::Operation*) const llvm_project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12
#30 0x000055ab3fb1a8a5 mlir::Pass::runPipeline(mlir::OpPassManager&, mlir::Operation*) llvm_project/mlir/include/mlir/Pass/Pass.h:200:12
#31 0x000055ab3fb30a66 (anonymous namespace)::TritonToLinalgExperimentalPass::runOnOperation() triton_shared/lib/Conversion/TritonToLinalgExperimental/TritonToLinalgExperimentalPass.cpp:59:16
#32 0x000055ab43c5f0ab mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1::operator()() const llvm_project/mlir/lib/Pass/Pass.cpp:0:17
#33 0x000055ab43c5f045 void llvm::function_ref<void ()>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1>(long) llvm_project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45:5
#34 0x000055ab43c6bd89 llvm::function_ref<void ()>::operator()() const llvm_project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:5
#35 0x000055ab43c63225 void mlir::MLIRContext::executeAction<mlir::PassExecutionAction, mlir::Pass&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, mlir::Pass&) llvm_project/mlir/include/mlir/IR/MLIRContext.h:276:3
#36 0x000055ab43c5bc23 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) llvm_project/mlir/lib/Pass/Pass.cpp:533:17
#37 0x000055ab43c5c1c4 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) llvm_project/mlir/lib/Pass/Pass.cpp:593:16
#38 0x000055ab43c5dbf8 mlir::PassManager::runPasses(mlir::Operation*, mlir::AnalysisManager) llvm_project/mlir/lib/Pass/Pass.cpp:904:10
#39 0x000055ab43c5db22 mlir::PassManager::run(mlir::Operation*) llvm_project/mlir/lib/Pass/Pass.cpp:884:60
#40 0x000055ab43bcdfd2 performActions(llvm::raw_ostream&, std::__1::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) llvm_project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:413:17
#41 0x000055ab43bcdc08 processBuffer(llvm::raw_ostream&, std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, mlir::MlirOptMainConfig const&, mlir::DialectRegistry&, llvm::ThreadPoolInterface*) llvm_project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:478:12
#42 0x000055ab43bcd9ec mlir::MlirOptMain(llvm::raw_ostream&, std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_0::operator()(std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) const llvm_project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:561:12
#43 0x000055ab43bcd986 llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_0>(long, std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) llvm_project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12
#44 0x000055ab446d2692 llvm::function_ref<llvm::LogicalResult (std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::operator()(std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) const llvm_project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12
#45 0x000055ab446d1c7c mlir::splitAndProcessBuffer(std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) llvm_project/mlir/lib/Support/ToolUtilities.cpp:27:12
#46 0x000055ab43bc65b8 mlir::MlirOptMain(llvm::raw_ostream&, std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) llvm_project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:564:10
#47 0x000055ab43bc6979 mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) llvm_project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:605:14
#48 0x000055ab43bc6b4f mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) llvm_project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:621:10
#49 0x000055ab3fbe6351 main triton_shared/tools/triton-shared-opt/triton-shared-opt.cpp:16:33
#50 0x00007f8075ba3d90 (/lib/x86_64-linux-gnu/libc.so.6+0x29d90)
#51 0x00007f8075ba3e40 __libc_start_main (/lib/x86_64-linux-gnu/libc.so.6+0x29e40)
#52 0x000055ab3f4e46a5 _start (triton_shared/triton/python/build/cmake.linux-x86_64-cpython-3.10/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2b366a5)
zsh: IOT instruction (core dumped)  triton-shared-opt --triton-to-linalg-experimental

Additional information

The permute kernel is the codegen result of torch-inductor, from:

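import torch

device = "cpu"  # assumption: the kernel metadata above reports DeviceProperties(type='cpu')

x = torch.randn(
    size=(100, 32, 128),
    dtype=torch.float16,
    device=device
)

@torch.compile(fullgraph=True)
def test_func(x: torch.Tensor) -> torch.Tensor:
    # Swap the first two dims, then materialize the result contiguously.
    out = torch.permute(x, dims=(1, 0, 2)).contiguous()
    return out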

This issue may be related to #16 and #138, which concern representing N-D block pointers on a flattened 1-D pointer using division and modulo arithmetic on the offsets. It looks like torch-inductor fails to pattern-match and generate a proper tt.make_tensor_ptr for the load when a tensor dimension is not a power of two (100 in my case). As a result, an arith.divsi op remains in the pointer offset arithmetic, which causes PtrAnalysis to fail.
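To make the offset arithmetic concrete, here is a minimal pure-Python sketch (not part of the original report) that checks the div/mod expression from the generated kernel against a plain 3-D strided access. The function names are hypothetical; the shapes, the constants 128, 100, 12800 and 4096, and XBLOCK = 1024 are taken from the kernel and Triton IR above.

```python
# Shapes from the issue: x is a contiguous (100, 32, 128) tensor.
D0, D1, D2 = 100, 32, 128
SRC_STRIDES = (D1 * D2, D2, 1)  # (4096, 128, 1)


def kernel_src_offset(xindex: int) -> int:
    """Source offset computed the same way as in the generated kernel."""
    x0 = xindex % 128
    x1 = (xindex // 128) % 100
    x2 = xindex // 12800
    return x0 + 128 * x2 + 4096 * x1


# Equivalent structured view: the (32, 100, 128) output reads the source
# with permuted strides (128, 4096, 1).
VIEW_SHAPE = (D1, D0, D2)
VIEW_STRIDES = (SRC_STRIDES[1], SRC_STRIDES[0], SRC_STRIDES[2])


def strided_src_offset(i: int, j: int, k: int) -> int:
    """Source offset of output element (i, j, k) using strides only."""
    return i * VIEW_STRIDES[0] + j * VIEW_STRIDES[1] + k * VIEW_STRIDES[2]


def offsets_agree() -> bool:
    """Walk the output in row-major order and compare both formulas."""
    xindex = 0
    for i in range(VIEW_SHAPE[0]):
        for j in range(VIEW_SHAPE[1]):
            for k in range(VIEW_SHAPE[2]):
                if kernel_src_offset(xindex) != strided_src_offset(i, j, k):
                    return False
                xindex += 1
    return xindex == D0 * D1 * D2


def block_crosses_x2_boundary(pid: int, xblock: int = 1024) -> bool:
    """True if one program's flat index range spans two values of x2."""
    first = pid * xblock
    last = first + xblock - 1
    return first // 12800 != last // 12800
```

Under these assumptions, offsets_agree() returns True, so the access is an ordinary strided 3-D read once the div/mod structure is recognized. block_crosses_x2_boundary(12) also returns True: a 1024-element block covers 8 rows of 128, and 100 is not a multiple of 8, so some blocks wrap around the size-100 dimension. That may be part of why the non-power-of-two dimension gets in the way of a simple block pointer, though the report does not confirm this.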

Nullkooland added the bug label on Dec 12, 2024
@nhat-nguyen
Collaborator

There is a work-in-progress fallback mode for cases where PtrAnalysis fails, which will help here. The mode lets compilation succeed, but the generated code will be less efficient because each element has to be loaded into the tensor individually. The fallback mode should be ready in the coming weeks.
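As a purely conceptual illustration of the trade-off described above (this is not triton-shared's fallback implementation, and every name here is hypothetical), the following Python sketch contrasts a per-element gather with a structured load that copies whole contiguous rows. It uses a small (3, 2, 4) stand-in for the (100, 32, 128) tensor.

```python
from typing import List, Sequence


def gather_per_element(src: Sequence[float], offsets: Sequence[int]) -> List[float]:
    """Fallback-style load: one scalar read per offset, no structure assumed."""
    return [src[o] for o in offsets]


def load_structured(src: Sequence[float], base: int, rows: int,
                    row_stride: int, row_len: int) -> List[float]:
    """Structured load: each row is copied as one contiguous slice."""
    out: List[float] = []
    for r in range(rows):
        start = base + r * row_stride
        out.extend(src[start:start + row_len])
    return out


# Small stand-in for the tensor in the issue.
D0, D1, D2 = 3, 2, 4
src = [float(v) for v in range(D0 * D1 * D2)]

# Source offsets for out = permute(x, (1, 0, 2)).contiguous(), limited to out[0].
offsets = [j * (D1 * D2) + k for j in range(D0) for k in range(D2)]

per_element = gather_per_element(src, offsets)
structured = load_structured(src, base=0, rows=D0, row_stride=D1 * D2, row_len=D2)
```

Both paths produce the same values. The difference is that the structured version needs to know the row stride and row length up front, which is the information a successful pointer analysis would provide.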

@nhat-nguyen
Collaborator

@Nullkooland, can you check whether this works now with the latest triton-shared? Thank you.

@nhat-nguyen
Collaborator

I verified that the latest triton works with this case now. Let's close this. :)
