Skip to content

Commit 8ed48b9

Browse files
authored
Make loop splitting respect dependency that goes through air.wait_all (Xilinx#788)
1 parent 13fd3c6 commit 8ed48b9

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

mlir/lib/Util/Dependency.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,11 @@ bool areAsyncDependent(Operation *a, Operation *b) {
614614
for (auto dep : dep_b)
615615
if (dep == token_a)
616616
return true;
617+
// Deep async dependency tracing through air.wait_all.
618+
if (isAsyncDependent(a, b))
619+
return true;
620+
if (isAsyncDependent(b, a))
621+
return true;
617622

618623
auto chanA = dyn_cast<air::ChannelInterface>(a);
619624
auto chanB = dyn_cast<air::ChannelInterface>(b);

mlir/test/Transform/AIRDependencyScheduleOpt/isolate_async_dma_loop_nest.mlir

+62
Original file line numberDiff line numberDiff line change
@@ -616,3 +616,65 @@ module {
616616
return
617617
}
618618
}
619+
620+
// -----
621+
622+
// Deep dependency tracing through air.wait_all.
623+
624+
// CHECK: scf.for
625+
// CHECK: air.channel.get
626+
// CHECK: air.channel.get
627+
// CHECK: scf.for
628+
// CHECK: air.channel.put
629+
// CHECK: air.channel.put
630+
// CHECK: scf.yield
631+
// CHECK: scf.yield
632+
633+
#map = affine_map<()[s0] -> (s0 * 96)>
634+
#map1 = affine_map<()[s0] -> (s0 * 3)>
635+
module {
636+
air.channel @channel_0 []
637+
air.channel @channel_1 []
638+
func.func @func7() {
639+
%c1 = arith.constant 1 : index
640+
%0 = air.launch async (%arg5, %arg6) in (%arg7=%c1, %arg8=%c1) attributes {id = 2 : i32} {
641+
%1 = air.segment @segment_0 async attributes {id = 1 : i32} {
642+
%c96 = arith.constant 96 : index
643+
%c1_0 = arith.constant 1 : index
644+
%c3 = arith.constant 3 : index
645+
%c0 = arith.constant 0 : index
646+
%async_token, %results = air.execute -> (memref<288xi8, 1 : i32>) {
647+
%alloc = memref.alloc() : memref<288xi8, 1 : i32>
648+
air.execute_terminator %alloc : memref<288xi8, 1 : i32>
649+
} {id = 1 : i32}
650+
%async_token_1, %results_2 = air.execute -> (memref<9xf32, 1 : i32>) {
651+
%alloc = memref.alloc() : memref<9xf32, 1 : i32>
652+
air.execute_terminator %alloc : memref<9xf32, 1 : i32>
653+
} {id = 2 : i32}
654+
%2 = air.wait_all async [%async_token, %async_token_1] {id = 4 : i32}
655+
%3 = scf.for %arg9 = %c0 to %c3 step %c1_0 iter_args(%arg10 = %2) -> (!air.async.token) {
656+
%4 = air.channel.get async [%arg10] @channel_0[] (%results[] [] []) {id = 1 : i32} : (memref<288xi8, 1 : i32>)
657+
%5 = air.channel.get async [%arg10] @channel_0[] (%results_2[] [] []) {id = 2 : i32} : (memref<9xf32, 1 : i32>)
658+
%6 = air.wait_all async [%4, %5] {id = 2 : i32}
659+
%7 = scf.for %arg11 = %c0 to %c3 step %c1_0 iter_args(%arg12 = %6) -> (!air.async.token) {
660+
%async_token_3, %results_4 = air.execute [%arg12] -> (index) {
661+
%12 = affine.apply #map()[%arg11]
662+
air.execute_terminator %12 : index
663+
} {id = 5 : i32}
664+
%9 = air.channel.put async [%async_token_3] @channel_1[] (%results[%results_4] [%c96] [%c1_0]) {id = 3 : i32} : (memref<288xi8, 1 : i32>)
665+
%async_token_5, %results_6 = air.execute [%arg12] -> (index) {
666+
%12 = affine.apply #map1()[%arg11]
667+
air.execute_terminator %12 : index
668+
} {id = 6 : i32}
669+
%10 = air.channel.put async [%async_token_5] @channel_1[] (%results_2[%results_6] [%c3] [%c1_0]) {id = 4 : i32} : (memref<9xf32, 1 : i32>)
670+
%11 = air.wait_all async [%arg12, %9, %10] {id = 1 : i32}
671+
scf.yield %11 : !air.async.token
672+
}
673+
%8 = air.wait_all async [%arg10, %7] {id = 3 : i32}
674+
scf.yield %8 : !air.async.token
675+
}
676+
}
677+
}
678+
return
679+
}
680+
}

0 commit comments

Comments
 (0)