@@ -286,7 +286,7 @@ module {
286
286
// Offset propagation with wrap-and-stride canonicalization.
287
287
// CHECK-LABEL: test9
288
288
// CHECK: %[[VAL0:.*]] = affine.apply #map()[%arg1]
289
- // CHECK: put @channel_21[] (%arg0[%c0, %c0, %[[VAL0]], %c0 ] [%c8, %c2 , %c32, %c32 ] [%c32, %c8192 , %c256, %c1]) : (memref<128x256xi32>)
289
+ // CHECK: put @channel_21[] (%arg0[%c0, %c0, %[[VAL0]]] [%c8, %c64 , %c32] [%c32, %c256, %c1]) : (memref<128x256xi32>)
290
290
// CHECK: air.channel.put @channel_22[] (%arg2[%c256, %c0, %c0] [%c8, %c32, %c4] [%c4, %c32, %c1]) : (memref<1x2x32x32xi32, 1 : i32>)
291
291
// CHECK: air.channel.put @channel_23[] (%arg3[%c128, %c0, %c0] [%c4, %c32, %c8] [%c8, %c32, %c1]) : (memref<2x1x32x32xi32, 1 : i32>)
292
292
// CHECK: %[[VAL1:.*]] = affine.apply
@@ -386,25 +386,32 @@ module {
386
386
// affine.apply with a map joining two for loops in a loop nest.
387
387
// CHECK-LABEL: test11
388
388
389
- // CHECK: air.channel.put async [%{{.*}}] @channel_26[%c0, %c0] (%{{.*}}[%c0, %c0, %c0] [%c4_0, %c18, %c4_0] [%c96, %c16, %c1]) : (memref<1x6x6x16xbf16, 1>)
389
+ // CHECK: air.channel.put async {{.*}}@channel_26[%c0{{.*}}, %c0{{.*}}] (%{{.*}}[%c0{{.*}}, %c0{{.*}}, %c0{{.*}}] [%c4{{.*}}, %c18{{.*}}, %c4{{.*}}] [%c96{{.*}}, %c16{{.*}}, %c1{{.*}}]) : (memref<1x6x6x16xbf16, 1>)
390
+ // CHECK: air.channel.put async {{.*}}@channel_26[%c0{{.*}}, %c0{{.*}}] (%{{.*}}[%c0{{.*}}, %c0{{.*}}, %c0{{.*}}, %c12{{.*}}] [%c3{{.*}}, %c3{{.*}}, %c4{{.*}}, %c4{{.*}}] [%c96{{.*}}, %c16{{.*}}, %c16{{.*}}, %c1{{.*}}]) : (memref<1x3x6x16xi32, 1>)
390
391
391
392
func.func @test11 () {
392
393
%c3 = arith.constant 3 : index
393
394
%c4 = arith.constant 4 : index
394
395
%0 = air.launch async (%arg3 , %arg4 , %arg5 ) in (%arg6 =%c3 , %arg7 =%c3 , %arg8 =%c4 ) {
395
396
%1 = air.segment @segment_0 async {
396
397
%c576 = arith.constant 576 : index
398
+ %c288 = arith.constant 288 : index
397
399
%c96 = arith.constant 96 : index
398
400
%c3_0 = arith.constant 3 : index
399
401
%c1 = arith.constant 1 : index
400
402
%c16 = arith.constant 16 : index
403
+ %c12 = arith.constant 12 : index
401
404
%c6 = arith.constant 6 : index
402
405
%c0 = arith.constant 0 : index
403
406
%c4_1 = arith.constant 4 : index
404
407
%async_token , %results = air.execute -> (memref <1 x6 x6 x16 xbf16 , 1 >) {
405
408
%alloc = memref.alloc () : memref <1 x6 x6 x16 xbf16 , 1 >
406
409
air.execute_terminator %alloc : memref <1 x6 x6 x16 xbf16 , 1 >
407
410
}
411
+ %async_token_23 , %results_25 = air.execute -> (memref <1 x3 x6 x16 xi32 , 1 >) {
412
+ %alloc = memref.alloc () : memref <1 x3 x6 x16 xi32 , 1 >
413
+ air.execute_terminator %alloc : memref <1 x3 x6 x16 xi32 , 1 >
414
+ }
408
415
%4 = scf.for %arg9 = %c0 to %c4_1 step %c1 iter_args (%arg13 = %async_token ) -> (!air.async.token ) {
409
416
%2 = scf.for %arg10 = %c0 to %c3_0 step %c1 iter_args (%arg11 = %arg13 ) -> (!air.async.token ) {
410
417
%async_token_2 , %results_3 = air.execute [%arg11 ] -> (index ) {
@@ -416,6 +423,15 @@ module {
416
423
}
417
424
scf.yield %2 : !air.async.token
418
425
}
426
+ scf.for %arg9 = %c0 to %c3_0 step %c1 {
427
+ %60 = scf.for %arg10 = %c0 to %c3_0 step %c1 iter_args (%arg13 = %async_token ) -> (!air.async.token ) {
428
+ %async_token_54 , %results_55 = air.execute [%arg13 ] -> (index ) {
429
+ air.execute_terminator %arg9 : index
430
+ }
431
+ %61 = air.channel.put async [%async_token_54 ] @channel_26 [%c0 , %c0 ] (%results_25 [%c0 , %results_55 , %arg10 , %c12 ] [%c1 , %c1 , %c4_1 , %c4_1 ] [%c288 , %c96 , %c16 , %c1 ]) : (memref <1 x3 x6 x16 xi32 , 1 >)
432
+ scf.yield %61 : !air.async.token
433
+ }
434
+ }
419
435
}
420
436
}
421
437
return
@@ -460,10 +476,11 @@ module {
460
476
// CHECK-LABEL: test13
461
477
462
478
// CHECK: air.channel.put async [%{{.*}}] @channel_14[] (%{{.*}}[%c0, %1, %results, %c0] [%c8, %c2_0, %c32, %c32] [%c32, %c8192, %c256, %c1]) : (memref<2x128x256xi32>)
479
+ // CHECK: air.channel.put async [%{{.*}}] @channel_15[%c0, %c0] (%{{.*}}[%c0, %results, %c32768] [%c8, %c32, %c32] [%c32, %c256, %c1]) : (memref<512x512xi32>)
463
480
464
- func.func @test13 (%arg0: memref <2 x128 x256 xi32 >, %arg1: memref <2 x 256 x 128 x i32 >) {
481
+ func.func @test13 (%arg0: memref <2 x128 x256 xi32 >, %arg1: memref <512 x 512 x i32 >) {
465
482
%c2 = arith.constant 2 : index
466
- %0 = air.launch async (%arg3 , %arg4 , %arg5 ) in (%arg6 =%c2 , %arg7 =%c2 , %arg8 =%c2 ) args (%arg10 =%arg0 , %arg11 =%arg1 ) : memref <2 x128 x256 xi32 >, memref <2 x 256 x 128 x i32 > {
483
+ %0 = air.launch async (%arg3 , %arg4 , %arg5 ) in (%arg6 =%c2 , %arg7 =%c2 , %arg8 =%c2 ) args (%arg10 =%arg0 , %arg11 =%arg1 ) : memref <2 x128 x256 xi32 >, memref <512 x 512 x i32 > {
467
484
%c4096 = arith.constant 4096 : index
468
485
%c8 = arith.constant 8 : index
469
486
%c16384 = arith.constant 16384 : index
@@ -484,6 +501,10 @@ module {
484
501
%7 = air.channel.put async [%arg13 , %async_token ] @channel_14 [] (%arg10 [%arg3 , %c0 , %c0 , %results , %arg12 ] [%c1 , %c2_0 , %c1 , %c32 , %c32 ] [%c32768 , %c8192 , %c32 , %c256 , %c1 ]) : (memref <2 x128 x256 xi32 >)
485
502
scf.yield %7 : !air.async.token
486
503
}
504
+ %3 = scf.for %arg12 = %c0 to %c256 step %c32 iter_args (%arg13 = %async_token ) -> (!air.async.token ) {
505
+ %7 = air.channel.put async [%arg13 , %async_token ] @channel_15 [%c0 , %c0 ] (%arg11 [%c2_0 , %c0 , %results , %arg12 ] [%c1 , %c1 , %c32 , %c32 ] [%c16384 , %c32 , %c256 , %c1 ]) {id = 1 : i32 } : (memref <512 x512 xi32 >)
506
+ scf.yield %7 : !air.async.token
507
+ }
487
508
}
488
509
return
489
510
}
0 commit comments