
[Bug]: failed to legalize operation 'tt.splat' marked as erased #65

Open
yuanfz98 opened this issue Nov 27, 2023 · 1 comment
Labels
bug Something isn't working

Comments

@yuanfz98
Contributor

Triton python code

import triton
import triton.language as tl

@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 2
    RBLOCK: tl.constexpr = 2
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r0 = rindex
    tmp0 = tl.load(in_ptr0 + (r0), rmask, other=0)
    tmp5 = tl.load(in_ptr1 + (r0), rmask, other=0)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp3 = tl.where(rmask, tmp1, 0)
    tmp4 = tl.sum(tmp3, 1)[:, None]
    tmp6 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
    tmp8 = tl.where(rmask, tmp6, 0)
    tmp9 = tl.sum(tmp8, 1)[:, None]
    tmp10 = tmp9.to(tl.float32)
    tmp11 = tmp4 / tmp10
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp11, None)
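For reference, the kernel above sums both 2-element inputs over the r axis, casts the second sum to float32, and stores the quotient at offset 0. A minimal NumPy sketch of the same semantics (a hypothetical reference model, not part of the original report):

```python
import numpy as np

def reference_kernel(in_out, in0, in1):
    # Mirrors the Triton kernel: xnumel = 1, rnumel = 2.
    tmp4 = in0.sum()                      # sum of in_ptr0 over r (float32)
    tmp9 = in1.sum().astype(np.float32)   # sum of in_ptr1 over r, i64 -> f32 (tmp10)
    in_out[0] = tmp4 / tmp9               # tmp11, stored at offset 0
    return in_out

out = reference_kernel(np.zeros(1, dtype=np.float32),
                       np.array([1.0, 3.0], dtype=np.float32),
                       np.array([2, 2], dtype=np.int64))
```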

Triton IR

module {
  tt.func public @triton__0d1d2d34(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64, 1> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32) attributes {noinline = false} {
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant dense<0> : tensor<1x2xi64>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<1x2xf32>
    %cst_1 = arith.constant dense<2> : tensor<1x2xi32>
    %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32>
    %1 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<2xi32>) -> tensor<1x2xi32>
    %2 = arith.cmpi slt, %1, %cst_1 : tensor<1x2xi32>
    %3 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<1x2x!tt.ptr<f32, 1>>
    %4 = tt.addptr %3, %1 : tensor<1x2x!tt.ptr<f32, 1>>, tensor<1x2xi32>
    %5 = tt.load %4, %2, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1x2xf32>
    %6 = tt.splat %arg2 : (!tt.ptr<i64, 1>) -> tensor<1x2x!tt.ptr<i64, 1>>
    %7 = tt.addptr %6, %1 : tensor<1x2x!tt.ptr<i64, 1>>, tensor<1x2xi32>
    %8 = tt.load %7, %2, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1x2xi64>
    %9 = arith.select %2, %5, %cst_0 : tensor<1x2xi1>, tensor<1x2xf32>
    %10 = "tt.reduce"(%9) <{axis = 1 : i32}> ({
    ^bb0(%arg5: f32, %arg6: f32):
      %19 = arith.addf %arg5, %arg6 : f32
      tt.reduce.return %19 : f32
    }) : (tensor<1x2xf32>) -> tensor<1xf32>
    %11 = tt.expand_dims %10 {axis = 1 : i32} : (tensor<1xf32>) -> tensor<1x1xf32>
    %12 = arith.select %2, %8, %cst : tensor<1x2xi1>, tensor<1x2xi64>
    %13 = "tt.reduce"(%12) <{axis = 1 : i32}> ({
    ^bb0(%arg5: i64, %arg6: i64):
      %19 = arith.addi %arg5, %arg6 : i64
      tt.reduce.return %19 : i64
    }) : (tensor<1x2xi64>) -> tensor<1xi64>
    %14 = tt.expand_dims %13 {axis = 1 : i32} : (tensor<1xi64>) -> tensor<1x1xi64>
    %15 = arith.sitofp %14 : tensor<1x1xi64> to tensor<1x1xf32>
    %16 = arith.divf %11, %15 : tensor<1x1xf32>
    gpu.barrier
    %17 = tt.addptr %arg0, %c0_i32 : !tt.ptr<f32, 1>, i32
    %18 = tt.splat %17 : (!tt.ptr<f32, 1>) -> tensor<1x1x!tt.ptr<f32, 1>>
    tt.store %18, %16 {cache = 1 : i32, evict = 1 : i32} : tensor<1x1xf32>
    tt.return
  }
}

Crash log

/workspace/hongjing/temp/jolwo38w/triton_.ttir:34:11: error: failed to legalize operation 'tt.splat' marked as erased
    %18 = tt.splat %17 : (!tt.ptr<f32, 1>) -> tensor<1x1x!tt.ptr<f32, 1>>
          ^
/workspace/hongjing/temp/jolwo38w/triton_.ttir:34:11: note: see current operation: %93 = "tt.splat"(%92) {MetaUse} : (!tt.ptr<f32, 1>) -> tensor<1x1x!tt.ptr<f32, 1>>
/workspace/hongjing/temp/jolwo38w/triton_.ttir:35:5: note: found live user of result #0: "memref.tensor_store"(%88, %93) : (tensor<1x1xf32>, tensor<1x1x!tt.ptr<f32, 1>>) -> ()
    tt.store %18, %16 {cache = 1 : i32, evict = 1 : i32} : tensor<1x1xf32>

Additional information

No response

@yuanfz98 yuanfz98 added the bug Something isn't working label Nov 27, 2023

yuanfz98 commented Nov 27, 2023

@nhat-nguyen
This error comes from the placement of tt.splat. MetaOpConverter erases tt.splat ops first; then AddPtrConverter and StoreConverter run (StoreConverter uses the adaptor to get the mutated pointer produced by AddPtrConverter). A tt.splat sitting between addptr and store is therefore inconsistent: it has already been erased but still has a live user.
We could either:

  1. move tt.splat before tt.addptr in a canonicalizer pass, or
  2. explore bottom-up in StoreConverter/LoadConverter.
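Option 1 would amount to rewriting the store tail of the IR above so that the base pointer is splatted before the tensor addptr, roughly like this (illustrative IR only, not actual pass output; %cst_zero is a hypothetical zero-offset tensor constant):

```
%cst_zero = arith.constant dense<0> : tensor<1x1xi32>
%17 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<1x1x!tt.ptr<f32, 1>>
%18 = tt.addptr %17, %cst_zero : tensor<1x1x!tt.ptr<f32, 1>>, tensor<1x1xi32>
tt.store %18, %16 {cache = 1 : i32, evict = 1 : i32} : tensor<1x1xf32>
```

With this ordering, StoreConverter sees the addptr result directly and no erased tt.splat remains between the address computation and the store.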

Refs #62.
Thanks for your reply!
