Skip to content

Commit

Permalink
Swap the operands of arith.add op in matmul converter,
Browse files Browse the repository at this point in the history
the tt.dot with accumulator will lower to linalg.matmul and arith.add,
and the arith.add will further lower to linalg.generic,
generic will take the lhs of add as the DPS init, so the lhs should be
the matmul accumulator. This is a temporary fix for issue microsoft#196.
  • Loading branch information
MercuryChen committed Dec 5, 2024
1 parent d5b7bee commit f186c1e
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1162,9 +1162,9 @@ struct MatmulConverter : public OpConversionPattern<triton::DotOp> {

if (!skipC) {
if (integers) {
res = rewriter.create<arith::AddIOp>(loc, res, opc);
res = rewriter.create<arith::AddIOp>(loc, opc, res);
} else {
res = rewriter.create<arith::AddFOp>(loc, res, opc);
res = rewriter.create<arith::AddFOp>(loc, opc, res);
}
}

Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/StructuredToMemref/dot.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ module {
// CHECK-DAG: [[VAR_4_:%.+]] = tensor.empty() : tensor<128x256xbf16>
// CHECK: [[VAR_5_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_4_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
// CHECK: [[VAR_6_:%.+]] = linalg.matmul ins([[VAR_0_]], [[VAR_transposed_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_5_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_6_]], [[VAR_3_]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs([[VAR_6_]] : tensor<128x256xbf16>) {
// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_3_]], [[VAR_6_]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs([[VAR_3_]] : tensor<128x256xbf16>) {
// CHECK: ^bb0([[IN_0_:%.+]]: bf16, [[IN_1_:%.+]]: bf16, [[IN_2_:%.+]]: bf16):
// CHECK: [[VAR_8_:%.+]] = arith.addf [[IN_0_]], [[IN_1_]] : bf16
// CHECK: linalg.yield [[VAR_8_]] : bf16
Expand Down
6 changes: 3 additions & 3 deletions test/Conversion/TritonArithToLinalg/dot.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,9 @@ module {
// CHECK-DAG: [[VAR_45_:%.+]] = tensor.empty() : tensor<128x256xbf16>
// CHECK: [[VAR_46_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_45_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
// CHECK: [[VAR_47_:%.+]] = linalg.matmul ins([[LOAD_VAR_34_MEM_]], [[VAR_transposed_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_46_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
// CHECK: [[VAR_48_:%.+]] = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel"]} ins([[VAR_47_]], [[LOAD_VAR_43_MEM_]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs([[VAR_47_]] : tensor<128x256xbf16>) {
// CHECK: ^bb0([[in_]]: bf16, [[in_1:.+]]: bf16, [[out_]]: bf16):
// CHECK: [[VAR_49_13_:%.+]] = arith.addf [[in_]], [[in_1]] : bf16
// CHECK: [[VAR_48_:%.+]] = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel"]} ins([[LOAD_VAR_43_MEM_]], [[VAR_47_]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs([[LOAD_VAR_43_MEM_]] : tensor<128x256xbf16>) {
// CHECK: ^bb0([[in_]]: bf16, [[in_]]_6: bf16, [[out_]]: bf16):
// CHECK: [[VAR_49_13_:%.+]] = arith.addf [[in_]], [[in_]]_6 : bf16
// CHECK: linalg.yield [[VAR_49_13_]] : bf16
// CHECK: } -> tensor<128x256xbf16>
// CHECK: tt.store [[VAR_43_]], [[VAR_48_]] : tensor<128x256x!tt.ptr<bf16>>
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/TritonToLinalg/dot.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ module {
// CHECK-DAG: [[VAR_4_:%.+]] = tensor.empty() : tensor<128x256xbf16>
// CHECK: [[VAR_5_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_4_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
// CHECK: [[VAR_6_:%.+]] = linalg.matmul ins([[VAR_0_]], [[VAR_transposed_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_5_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_6_]], [[VAR_3_]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs([[VAR_6_]] : tensor<128x256xbf16>) {
// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_3_]], [[VAR_6_]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs([[VAR_3_]] : tensor<128x256xbf16>) {
// CHECK: ^bb0([[in_:.+]]: bf16, [[in_1:.+]]: bf16, [[out_:.+]]: bf16):
// CHECK: [[VAR_8_:%.+]] = arith.addf [[in_]], [[in_1]] : bf16
// CHECK: linalg.yield [[VAR_8_]] : bf16
Expand Down

0 comments on commit f186c1e

Please sign in to comment.