Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix build failure in issues #178 #187

Merged
merged 15 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ set(TRITON_SHARED_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) # Tablegen'd files
include_directories(${Python3_INCLUDE_DIR})
include_directories(${pybind11_INCLUDE_DIR})

add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(test)
Expand Down
4 changes: 4 additions & 0 deletions backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,12 @@ class CPUOptions:
extern_libs = None
cluster_dims: tuple = (1, 1, 1)
shared: bool = False
# Disable FP8 here since this is a sample CPU backend.
# Target specific backends can eanble it with supported types.
supported_fp8_dtypes: Tuple[str] = ()
allow_fp8e4nv: bool = False
allowed_dot_input_precisions: Tuple[str] = ("ieee", )
sanitize_overflow: bool = True

def __post_init__(self):
pass
Expand Down
17 changes: 15 additions & 2 deletions backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
def _ty_to_cpp(ty):
if ty[0] == '*':
return "void*"
if ty == "constexpr":
return "PyObject*"
return {
"i1": "int32_t",
"i8": "int8_t",
Expand All @@ -37,11 +39,14 @@ def _ty_to_cpp(ty):
def _extracted_type(ty):
if ty[0] == '*':
return "PyObject*"
if ty == "constexpr":
return "PyObject*"
return _ty_to_cpp(ty)

def _format_of(ty):
return {
"PyObject*": "O",
"constexpr": "O",
"float": "f",
"double": "d",
"long": "l",
Expand All @@ -61,10 +66,10 @@ def _generate_launcher(constants, signature, kernel_name):
format = "iiiOOOO" + args_format
args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''

kernel_arg_decls = ', '.join(_ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if i not in constants)
kernel_arg_decls = ', '.join(_ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if ty != "constexpr")
kernel_arg_decls += ', ' if kernel_arg_decls else ''

kernel_parameters = ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" for i, ty in signature.items() if i not in constants)
kernel_parameters = ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" for i, ty in signature.items() if ty != "constexpr")
kernel_parameters += ', ' if kernel_parameters else ''

return f"""
Expand Down Expand Up @@ -347,6 +352,10 @@ def __init__(self):
def is_active():
return False

def get_benchmarker(self):
from triton.testing import do_bench
return do_bench

def get_device_capability(self):
return ("cpu", 0)

Expand All @@ -365,5 +374,9 @@ def set_current_device(self, device):
def get_current_target(self):
return GPUTarget("cpu", 0, 0)

def get_active_torch_device(self):
import torch
return torch.device("cpu")

def assemble_tensormap_to_arg(self, tensormaps_info, args):
return args
Original file line number Diff line number Diff line change
Expand Up @@ -837,8 +837,7 @@ struct AssertConverter : public OpConversionPattern<triton::AssertOp> {
}

auto assertMessage =
llvm::formatv("{0}.py:{1}: {2} Assertion `{3}` failed", op.getFile(),
op.getLine(), op.getFunc(), op.getMessage());
llvm::formatv("Assertion `{0}` failed", op.getMessage());
rewriter.create<mlir::cf::AssertOp>(op.getLoc(), condVal,
assertMessage.str());

Expand Down
11 changes: 6 additions & 5 deletions lib/Conversion/StructuredToMemref/StructuredToMemref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,9 @@ struct MakeTensorPtrConverter
/* result shape */
SmallVector<int64_t>{

// Row stays the same
resultShape[0],
// Row stays the same, but mlir doesn't allow this anymore. Put
// dynamic.
ShapedType::kDynamic,

// Column is dynamic, in most cases, this
// should be the same as the original column.
Expand Down Expand Up @@ -286,9 +287,9 @@ struct MakeTensorPtrConverter
// around.
ShapedType::kDynamic,

// Col stays the same.
resultShape[1],
});
// Col stays the same, which is resultShape[1], but mlir doesn't
// allow this anymore. So we put dynamic instead.
ShapedType::kDynamic});

Value rowSize = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(op.getSizes()[0]));
Expand Down
8 changes: 4 additions & 4 deletions lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class LoopTypeConverter : public TypeConverter {
// reinterpret_cast.
addTargetMaterialization([&](OpBuilder &builder, MemRefType memrefType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
auto reinterpretCast =
inputs[0].getDefiningOp<memref::ReinterpretCastOp>();
if (!reinterpretCast) {
Expand All @@ -99,14 +99,14 @@ class LoopTypeConverter : public TypeConverter {

addSourceMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});

addArgumentMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
Expand All @@ -123,7 +123,7 @@ class PtrToUnrankedMemrefConverter : public TypeConverter {
addTargetMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonPtrToMemref/TritonPtrToMemrefPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class TritonFunctionSignatureConverter : public TypeConverter {

auto createUnrealizedCast = [&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
};
Expand Down
14 changes: 7 additions & 7 deletions lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class TritonToStructuredPass
RewritePatternSet patterns(&getContext());

auto context = &getContext();
OneToNTypeConverter converter;
TypeConverter converter;
converter.addConversion([](Type type) { return type; });

// We are doing a 1->1 type conversion here, where a triton pointer type
Expand Down Expand Up @@ -145,10 +145,10 @@ class TritonToStructuredPass
// Compute the target materialization, given a value with the pointer type,
// convert that value to a tuple type.
converter.addTargetMaterialization(
[](OpBuilder &builder, TypeRange resultTypes, Value input,
Location loc) -> std::optional<SmallVector<Value>> {
[](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
Location loc) -> SmallVector<Value> {
return builder
.create<UnrealizedConversionCastOp>(loc, resultTypes, input)
.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs.front())
->getResults();
});

Expand All @@ -172,7 +172,7 @@ class TritonToStructuredPass
auto moduleOp = getOperation();

auto context = &getContext();
OneToNTypeConverter converter;
TypeConverter converter;
converter.addConversion([](Type type) { return type; });

// We are doing a 1->N type conversion here, where a pointer tuple type
Expand Down Expand Up @@ -208,10 +208,10 @@ class TritonToStructuredPass
// At the end of pointer analysis, we will use the PtrState to create the
// correct offsets, strides, and remove these ops.
converter.addTargetMaterialization([](OpBuilder &builder,
TypeRange resultTypes, Value input,
TypeRange resultTypes, ValueRange inputs,
Location loc) {
auto placeholder = builder.create<tts::GetStructuredStateOp>(
loc, input.getDefiningOp()->getOperand(0));
loc, inputs.front().getDefiningOp()->getOperand(0));
assert(llvm::equal(placeholder.getResultTypes(), resultTypes));
return placeholder.getResults();
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class PtrToUnrankedMemrefConverter : public TypeConverter {
addTargetMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
Expand Down
24 changes: 22 additions & 2 deletions python/examples/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,32 @@ def device(request):
"test_if",
"test_if_call",
"test_convert2d",
"test_convertmma2mma",
"test_dot_max_num_imprecise_acc",
"test_propagate_nan",
"test_clamp_symmetric",
"test_temp_var_in_loop",
"test_math_extern"
"test_math_extern",
# attribute 'launch_cooperative_grid' not supported
"test_load_scope_sem_coop_grid_cta_one",
# fp8 support on CPUs is unclear
"test_scaled_dot",
# triton-shared-opt failures:
# PtrAnalysis: encountered addptr operand produced by an unsupported operation
"test_chained_reductions",
# failed to legalize unresolved materialization
"test_constexpr_if_return",
"test_unroll_attr",
# Dialect `ub' not found for custom op 'ub.poison'
"test_poison_return",
# tt.gather not supported yet
"test_gather",
"test_gather_warp_shuffle",
# device 'cpu' does not have 'index
"test_zero_strided_tensors",
# hard-coded with 'ttg' attributes
"test_convert_mma2mma",
"test_local_load_store",
"test_local_load_store_mma"
}

# probably different version of MLIR on the nightly build machine is complaining
Expand Down
8 changes: 6 additions & 2 deletions python/examples/test_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,12 @@ def test(device):
# TODO: need to check some conditions otherwise the code below does not make any difference for the test
src = triton.compiler.ASTSource(
fn=reduce_kernel_2d,
signature="*fp32,*fp32,i32,i32",
constants={"BLOCK_SIZE": 32}
signature={"x_ptr": "*fp32",
"output_ptr": "*fp32",
"stride": "i32",
"n_elements": "i32",
"BLOCK_SIZE": "constexpr"},
constexprs={"BLOCK_SIZE": 32}
)
ret = triton.compile(
src,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ module {
%8 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>>
%9 = tt.addptr %8, %4 : tensor<32x!tt.ptr<f32>>, tensor<32xi32>
%10 = tt.load %9 : tensor<32x!tt.ptr<f32>>
%11 = tt.reshape %10 {allow_reorder = false} : tensor<32xf32> -> tensor<1x32xf32>
%11 = tt.reshape %10 allow_reorder : tensor<32xf32> -> tensor<1x32xf32>
%12 = tt.broadcast %11 : tensor<1x32xf32> -> tensor<64x32xf32>
%13 = tt.reshape %12 {allow_reorder = false} : tensor<64x32xf32> -> tensor<2048xf32>
%13 = tt.reshape %12 allow_reorder : tensor<64x32xf32> -> tensor<2048xf32>
%14 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<2048x!tt.ptr<f32>>
%15 = tt.addptr %14, %7 : tensor<2048x!tt.ptr<f32>>, tensor<2048xi32>
tt.store %15, %13 : tensor<2048x!tt.ptr<f32>>
Expand Down
8 changes: 8 additions & 0 deletions test/Conversion/StructuredToMemref/get_num_programs.mlir
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
// XFAIL: *
// Note: PtrAnalysis pass can create a tts.makeptr for below pattern:
// %3 = arith.constant 0 : index
// %6 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1x!tt.ptr<i32>>
// %7 = tt.addptr %6, %3 : tensor<1x!tt.ptr<i32>>, tensor<1xi32>
// But not if creating constant 0 and add it to a pointer is optimized away:
// %6 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1x!tt.ptr<i32>>
// A patch that rewrites tt.splat in such case will be sent separately.
// RUN: triton-shared-opt --split-input-file --triton-to-linalg-experimental %s | FileCheck %s

module {
Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/StructuredToMemref/triton_assert.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ tt.func public @assert_lol(%arg0: i32) {
%c0_i32 = arith.constant 0 : i32
%0 = arith.cmpi sgt, %arg0, %c0_i32 : i32
%1 = tt.splat %0 : i1 -> tensor<1xi1>
tt.assert %1, "lol", "", "", 0 : tensor<1xi1>
tt.assert %1, "lol" : tensor<1xi1>
tt.return
}

Expand All @@ -12,6 +12,6 @@ tt.func public @assert_lol(%arg0: i32) {
// CHECK-SAME: ([[PARAM_0_:%.+]]: i32, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
// CHECK: [[CST_0_:%.+]] = arith.constant 0 : i32
// CHECK: [[VAR_0_:%.+]] = arith.cmpi sgt, [[PARAM_0_]], [[CST_0_]] : i32
// CHECK: cf.assert [[VAR_0_]], ".py:0: Assertion `lol` failed"
// CHECK: cf.assert [[VAR_0_]], "Assertion `lol` failed"
// CHECK: return
// CHECK: }
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,17 @@ module {
// CHECK-DAG: [[VAR_16_:%.+]] = arith.addi [[VAR_14_]], [[CST_4_]] : index
// CHECK: [[VAR_17_:%.+]] = arith.minsi [[VAR_16_]], [[VAR_5_]] : index
// CHECK: [[VAR_18_:%.+]] = arith.subi [[VAR_17_]], [[VAR_14_]] : index
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_13_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_18_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_13_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_18_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_19_:%.+]] = arith.subi [[CST_4_]], [[VAR_18_]] : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_15_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_19_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_15_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_19_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32>
// CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>)
// CHECK: [[VAR_20_:%.+]] = arith.minsi [[VAR_18_]], [[CST_4_]] : index
// CHECK-DAG: [[VAR_21_:%.+]] = arith.subi [[CST_4_]], [[VAR_20_]] : index
// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] [2, [[VAR_20_]]{{.}} [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] [2, [[VAR_20_]]{{.}} [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] [2, [[VAR_21_]]{{.}} [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] [2, [[VAR_21_]]{{.}} [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] [2, [[VAR_20_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1]>>
// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]][0, [[VAR_20_]]{{.}} [2, [[VAR_21_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1], offset: ?>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_3_]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1]>>
Expand Down
8 changes: 4 additions & 4 deletions test/Conversion/StructuredToMemref/wraparound_stacked.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,17 @@ module {
// CHECK: [[VAR_13_:%.+]] = arith.addi [[VAR_3_]], [[VAR_12_]] : index
// CHECK: [[VAR_14_:%.+]] = arith.subi [[VAR_13_]], [[VAR_11_]] : index
// CHECK: [[VAR_15_:%.+]] = arith.divsi [[VAR_14_]], [[VAR_1_]] : index
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_11_]]{{.}}, sizes: {{.}}[[VAR_15_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref<?x4xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_11_]]{{.}}, sizes: {{.}}[[VAR_15_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_16_:%.+]] = arith.subi [[CST_4_]], [[VAR_15_]] : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_12_]]{{.}}, sizes: {{.}}[[VAR_16_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref<?x4xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_12_]]{{.}}, sizes: {{.}}[[VAR_16_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32>
// CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>)
// CHECK: [[VAR_17_:%.+]] = arith.minsi [[VAR_15_]], [[CST_4_]] : index
// CHECK-DAG: [[VAR_18_:%.+]] = arith.subi [[CST_4_]], [[VAR_17_]] : index
// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] {{.}}[[VAR_18_]], 3] [1, 1] : memref<?x4xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] {{.}}[[VAR_18_]], 3] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<4x4xf32> to memref<?x3xf32, strided<[4, 1]>>
// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]]{{.}}[[VAR_17_]], 0] {{.}}[[VAR_18_]], 3] [1, 1] : memref<4x4xf32> to memref<?x3xf32, strided<[4, 1], offset: ?>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_3_]] : memref<?x3xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[4, 1]>>
Expand Down
Loading
Loading