Skip to content

Commit 9c6a575

Browse files
committed
update test
1 parent 8f27555 commit 9c6a575

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

python/examples/test_mm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,10 @@ def mm_kernel(
6262
a = tl.load(
6363
A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
6464
mask=mask_k[None, :],
65-
other=0.0
6665
)
6766
b = tl.load(
6867
B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
6968
mask=mask_k[:, None],
70-
other=0.0
7169
)
7270
if a.dtype != b.dtype:
7371
a = a.to(C.dtype.element_ty)

test/Conversion/StructuredToMemref/kernel-05-layer-norm-fwd.mlir

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,13 +225,20 @@ module {
225225
// CHECK: [[VAR_24_2_:%.+]] = arith.maxsi [[VAR_23_2_]], [[VAR_20_5_]] : index
226226
// CHECK-DAG: [[VAR_25_2_:%.+]] = arith.subi [[VAR_24_2_]], [[VAR_20_5_]] : index
227227
// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref<256xf32>
228+
// CHECK: [[CMPI_:%.+]] = arith.cmpi slt, [[VAR_25_2_]], [[CST_256_1_]] : index
229+
// CHECK: scf.if [[CMPI_]] {
230+
// CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_2_]] : memref<256xf32>)
231+
// CHECK: }
228232
// CHECK-NOT: separator of consecutive DAGs
229233
// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_2_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
230234
// CHECK-DAG: [[VAR_subview_6_2_:%.+]] = memref.subview [[RES_2_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32> to memref<?xf32, strided<[1]>>
231235
// CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_6_2_]] : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
232236
// CHECK-DAG: [[VAR_26_2_:%.+]] = bufferization.to_tensor [[RES_2_]] restrict writable : memref<256xf32>
233237
// CHECK-DAG: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_3_]] to offset: {{.}}[[VAR_20_5_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>>
234238
// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() : memref<256xf32>
239+
// CHECK: scf.if [[CMPI_]] {
240+
// CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_3_]] : memref<256xf32>)
241+
// CHECK: }
235242
// CHECK-NOT: separator of consecutive DAGs
236243
// CHECK-DAG: [[VAR_subview_9_:%.+]] = memref.subview [[VAR_reinterpret_cast_7_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
237244
// CHECK-DAG: [[VAR_subview_10_:%.+]] = memref.subview [[RES_3_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32> to memref<?xf32, strided<[1]>>
@@ -241,8 +248,7 @@ module {
241248
// CHECK-NOT: separator of consecutive DAGs
242249
// CHECK-DAG: [[VAR_reinterpret_cast_11_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_28_2_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>>
243250
// CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() : memref<256xf32>
244-
// CHECK-DAG: [[VAR_29_2_:%.+]] = arith.cmpi slt, [[VAR_25_2_]], [[CST_256_1_]] : index
245-
// CHECK: scf.if [[VAR_29_2_]] {
251+
// CHECK: scf.if [[CMPI_]] {
246252
// CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_4_]] : memref<256xf32>)
247253
// CHECK: }
248254
// CHECK-DAG: [[VAR_subview_13_:%.+]] = memref.subview [[VAR_reinterpret_cast_11_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>

0 commit comments

Comments
 (0)