@@ -282,7 +282,7 @@ module {
282282 return %unpack : tensor <2048 xf32 >
283283 }
284284}
285-
285+
286286module attributes {transform.with_named_sequence } {
287287 transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
288288 %slice_op = transform.structured.match ops {[" tensor.parallel_insert_slice" ]} in %arg1
@@ -343,7 +343,7 @@ module {
343343 return %unpack : tensor <2047 xf32 >
344344 }
345345}
346-
346+
347347module attributes {transform.with_named_sequence } {
348348 transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
349349 %slice_op = transform.structured.match ops {[" tensor.parallel_insert_slice" ]} in %arg1
@@ -404,7 +404,7 @@ module {
404404 return %pack : tensor <4 x32 x16 xf32 >
405405 }
406406}
407-
407+
408408module attributes {transform.with_named_sequence } {
409409 transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
410410 %slice_op = transform.structured.match ops {[" tensor.parallel_insert_slice" ]} in %arg1
@@ -610,7 +610,7 @@ module attributes {transform.with_named_sequence} {
610610// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
611611// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
612612// CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
613- // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
613+ // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
614614// CHECK-SAME: {
615615// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
616616// CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
@@ -676,3 +676,127 @@ module attributes {transform.with_named_sequence} {
676676// CHECK: }
677677// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
678678// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
679+
680+ // -----
681+
682+ module {  // Test input: a two-result scf.forall whose results both feed one elemwise_binary.
683+   func.func @forall_producer_multiple_result_single_consumer(%arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
684+     %c4 = arith.constant 4 : index  // %c4, %c64 and %c0 are not referenced by any op in this function.
685+     %c64 = arith.constant 64 : index
686+     %c0 = arith.constant 0 : index
687+     %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {  // Both shared outputs start from %arg2.
688+       %outs = tensor.empty() : tensor<32x32xf32>
689+       %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
690+       %3 = linalg.matmul ins(%extracted_slice, %extracted_slice : tensor<32x32xf32>, tensor<32x32xf32>) outs(%outs : tensor<32x32xf32>) -> tensor<32x32xf32>
691+       scf.forall.in_parallel {
692+         tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>  // Matmul tile goes to the second shared output.
693+         tensor.parallel_insert_slice %extracted_slice into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>  // Extracted tile is written back to the first shared output.
694+       }
695+     }
696+     %final_out = tensor.empty() : tensor<64x64xf32>
697+     %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#0, %1#1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%final_out : tensor<64x64xf32>) -> tensor<64x64xf32>  // Single consumer of both loop results.
698+     return %2 : tensor<64x64xf32>
699+   }
700+ }
701+
702+ module attributes {transform.with_named_sequence} {  // Transform script paired with the scf.forall test input above.
703+   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
704+     %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op  // Op name must carry no leading space or nothing is matched.
705+     %1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)  // One handle per matched parallel_insert_slice.
706+     %consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)  // Only the first handle is passed on.
707+     transform.yield
708+   }
709+ }
710+
711+ // CHECK-LABEL: func.func @forall_producer_multiple_result_single_consumer(
712+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<64x64xf32>
713+
714+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<64x64xf32>
715+ // CHECK: %[[LOOP_RESULT:.+]]:3 = scf.forall (%[[I:.+]], %[[J:.+]]) in (2, 2) shared_outs(%[[SHARED0:.+]] = %[[ARG0]], %[[SHARED1:.+]] = %[[ARG0]], %[[SHARED2:.+]] = %[[INIT]])
716+
717+ // CHECK: %[[TILE_INIT:.+]] = tensor.empty() : tensor<32x32xf32>
718+ // CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
719+ // CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[TILE_INIT]] : tensor<32x32xf32>)
720+ // CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[MATMUL]] into %[[SHARED1]][%[[I]], %[[J]]] [32, 32] [1, 1]
721+ // CHECK: %[[INSERTED_SLICE0:.+]] = tensor.insert_slice %[[EXTRACTED_SLICE]] into %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
722+ // CHECK: %[[EXTRACTED_SLICE1:.+]] = tensor.extract_slice %[[SHARED2]][%[[I]], %[[J]]] [32, 32] [1, 1]
723+ // CHECK: %[[ADD:.+]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%[[EXTRACTED_SLICE]], %[[MATMUL]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[EXTRACTED_SLICE1]] : tensor<32x32xf32>)
724+
725+ // CHECK: scf.forall.in_parallel {
726+ // CHECK: tensor.parallel_insert_slice %[[MATMUL]] into %[[SHARED1]][%[[I]], %[[J]]] [32, 32] [1, 1]
727+ // CHECK: tensor.parallel_insert_slice %[[EXTRACTED_SLICE]] into %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
728+ // CHECK: tensor.parallel_insert_slice %[[ADD]] into %[[SHARED2]][%[[I]], %[[J]]] [32, 32] [1, 1]
729+ // CHECK: }
730+
731+ // CHECK: return %[[LOOP_RESULT]]#2 : tensor<64x64xf32>
732+
733+
734+ // -----
735+
736+ #map = affine_map<(d0) -> (d0)>
737+ module {  // Test input: a two-result scf.for whose results both feed one elemwise_binary.
738+   func.func @for_producer_producing_multiple_result_single_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
739+     %c4 = arith.constant 4 : index
740+     %c64 = arith.constant 64 : index
741+     %c0 = arith.constant 0 : index
742+     %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {  // Both iter_args start from %arg2.
743+       %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
744+       %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {  // Multiply-accumulate into the extracted slice.
745+       ^bb0(%in: f32, %in_16: f32, %out: f32):
746+         %13 = arith.mulf %in, %in_16 : f32
747+         %14 = arith.addf %out, %13 : f32
748+         linalg.yield %14 : f32
749+       } -> tensor<32xf32>
750+       %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
751+       %5 = tensor.insert_slice %3 into %arg5[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
752+       scf.yield %5, %4 : tensor<64xf32>, tensor<64xf32>  // %5 (inserted into %arg5) is yielded first, %4 second.
753+     }
754+     %out_operand = tensor.empty() : tensor<64xf32>
755+     %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %1#0 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand : tensor<64xf32>) -> tensor<64xf32>  // Single consumer; reads result #1 before result #0.
756+     return %2 : tensor<64xf32>
757+   }
758+ }
759+
760+ module attributes {transform.with_named_sequence} {  // Transform script paired with the scf.for test input above.
761+   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
762+     %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op  // Op name must carry no leading space or nothing is matched.
763+     %1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)  // One handle per matched insert_slice.
764+     %consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)  // Only the first handle is passed on.
765+     transform.yield
766+   }
767+ }
768+
769+ // CHECK-LABEL: func.func @for_producer_producing_multiple_result_single_consumer(
770+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<32xf32>,
771+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<32xf32>,
772+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<64xf32>
773+
774+ // CHECK: %[[C4:.+]] = arith.constant 4 : index
775+ // CHECK: %[[C64:.+]] = arith.constant 64 : index
776+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
777+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<64xf32>
778+
779+ // CHECK: %[[LOOP_RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C64]] step %[[C4]]
780+ // CHECK-SAME: iter_args(%[[ITER0:.+]] = %[[ARG2]], %[[ITER1:.+]] = %[[ARG2]], %[[ITER2:.+]] = %[[INIT]])
781+ // CHECK-SAME: -> (tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
782+
783+ // CHECK: %[[EXTRACT_SLICE:.+]] = tensor.extract_slice %[[ITER0]][%[[IV]]] [32] [1]
784+ // CHECK: %[[GENERIC:.+]] = linalg.generic
785+ // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<32xf32>, tensor<32xf32>)
786+ // CHECK-SAME: outs(%[[EXTRACT_SLICE]] : tensor<32xf32>)
787+ // CHECK: ^{{.*}}(%[[IN0:.+]]: f32, %[[IN1:.+]]: f32, %[[OUT:.+]]: f32):
788+ // CHECK: %[[MUL:.+]] = arith.mulf %[[IN0]], %[[IN1]] : f32
789+ // CHECK: %[[ADD:.+]] = arith.addf %[[OUT]], %[[MUL]] : f32
790+ // CHECK: linalg.yield %[[ADD]] : f32
791+
792+ // CHECK: %[[INSERT_SLICE0:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER0]][%[[IV]]] [32] [1]
793+ // CHECK: %[[INSERT_SLICE1:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER1]][%[[IV]]] [32] [1]
794+ // CHECK: %[[EXTRACT_SLICE2:.+]] = tensor.extract_slice %[[ITER2]][%[[IV]]] [32] [1]
795+ // CHECK: %[[BINARY:.+]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
796+ // CHECK-SAME: ins(%[[GENERIC]], %[[GENERIC]] : tensor<32xf32>, tensor<32xf32>)
797+ // CHECK-SAME: outs(%[[EXTRACT_SLICE2]] : tensor<32xf32>)
798+ // CHECK: %[[INSERT_SLICE2:.+]] = tensor.insert_slice %[[BINARY]] into %[[ITER2]][%[[IV]]] [32] [1]
799+
800+ // CHECK: scf.yield %[[INSERT_SLICE1]], %[[INSERT_SLICE0]], %[[INSERT_SLICE2]]
801+
802+ // CHECK: return %[[LOOP_RESULT]]#2 : tensor<64xf32>
0 commit comments