Skip to content

Commit f3884b6

Browse files
authored
AIRSpecializeChannelWrapAndStride: More flexible wrap-and-stride offset canonicalization (Xilinx#791)
* Enable more flexible offset canonicalization, rather than only considering folding to the next dimension * Add a new board test for pack-peel gemm in i32, with 4x4 herd * Fixup comparison types * Remove debug prints * Fixup modulo result 0 cannot be converted to bool false * Wrap-and-stride for loop folding taking into account complex affine maps with both gradients and offset
1 parent 44f0c0b commit f3884b6

File tree

5 files changed

+583
-17
lines changed

5 files changed

+583
-17
lines changed

mlir/lib/Util/Util.cpp

+42-13
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,24 @@ LogicalResult eraseWrapNStrideDim(OpBuilder builder,
840840
builder.getUnknownLoc(), (*const_size) * (*const_size_next));
841841
return true;
842842
};
843+
// For a given offset[i], find the first offset[j] such that stride[j] is
844+
// divisible by stride[i], so that offset[i] can be composed onto offset[j].
845+
auto findFirstComposableOffsetIdx = [](int i, SmallVector<Value> offsets,
846+
SmallVector<Value> strides) {
847+
auto constStrideI = getConstantIntValue(strides[i]);
848+
std::optional<int> output = std::nullopt;
849+
for (int j = i + 1; j < (int)strides.size(); j++) {
850+
if (!getConstantIntValue(offsets[j]))
851+
continue; // Currently unable to compose offset[i] expr onto another
852+
// offset[j] expr.
853+
auto constStrideJ = getConstantIntValue(strides[j]);
854+
if ((*constStrideI) % (*constStrideJ) == 0) {
855+
output = j;
856+
return output;
857+
}
858+
}
859+
return output;
860+
};
843861
for (auto i : erase_dims) {
844862
auto const_offset = getConstantIntValue(offsets[i]);
845863
if (const_offset && *const_offset == 0) {
@@ -855,13 +873,14 @@ LogicalResult eraseWrapNStrideDim(OpBuilder builder,
855873
continue;
856874
auto const_stride = getConstantIntValue(strides[i]);
857875
assert(const_stride && "non-static stride, NYI.");
858-
auto const_offset_next = getConstantIntValue(offsets[i + 1]);
859-
if (!const_offset_next)
876+
auto j = findFirstComposableOffsetIdx(i, offsets, strides);
877+
if (!j)
860878
continue;
861-
auto const_stride_next = getConstantIntValue(strides[i + 1]);
862-
assert(const_stride_next && "non-static stride, NYI.");
879+
auto const_offset_next = getConstantIntValue(offsets[*j]);
880+
auto const_stride_next = getConstantIntValue(strides[*j]);
881+
// Attempting to compose i-th offset onto another offset.
863882
if (const_offset) {
864-
offsets[i + 1] = builder.create<arith::ConstantIndexOp>(
883+
offsets[*j] = builder.create<arith::ConstantIndexOp>(
865884
builder.getUnknownLoc(),
866885
(*const_stride) * (*const_offset) / (*const_stride_next) +
867886
(*const_offset_next));
@@ -912,7 +931,7 @@ LogicalResult eraseWrapNStrideDim(OpBuilder builder,
912931
auto next_offset_map = AffineMap::get(0, 1, offset_expr);
913932
affine_apply.setMap(next_offset_map);
914933
offsets[i] = affine_apply;
915-
offsets[i + 1] = offsets[i];
934+
offsets[*j] = offsets[i];
916935
}
917936
erased |= multiplyAdjWraps(builder, i, sizes);
918937
offsets.erase(offsets.begin() + i);
@@ -1029,6 +1048,12 @@ LogicalResult air::foldForLoopNestAsExtendedSizesAndStrides(
10291048
}
10301049

10311050
std::map<Operation *, int> op_to_count;
1051+
// Evaluate offset from affine map.
1052+
auto evalOffsetFromAffineMap = [&](MLIRContext *ctx, AffineMap map) {
1053+
return air::evaluateConstantsInMap(
1054+
map, SmallVector<std::optional<int64_t>>{std::optional<int64_t>{0}},
1055+
ctx);
1056+
};
10321057
for (auto o : for_loops) {
10331058
int64_t stepSize = -1;
10341059
int loop_lower_bound = 0;
@@ -1067,14 +1092,18 @@ LogicalResult air::foldForLoopNestAsExtendedSizesAndStrides(
10671092
if (iv_is_symbol) {
10681093
auto map = affop.getAffineMap();
10691094
ind_var_factor = *getConstantIntValue(strides[i]);
1070-
ind_var_factor *= air::evaluateConstantsInMap(
1071-
map,
1072-
SmallVector<std::optional<int64_t>>{
1073-
std::optional<int64_t>{stepSize}},
1074-
for_op->getContext())
1075-
.value();
1095+
int64_t map_offset =
1096+
evalOffsetFromAffineMap(for_op->getContext(), map).value();
1097+
int64_t map_gradient = air::evaluateConstantsInMap(
1098+
map,
1099+
SmallVector<std::optional<int64_t>>{
1100+
std::optional<int64_t>{stepSize}},
1101+
for_op->getContext())
1102+
.value() -
1103+
map_offset;
1104+
ind_var_factor *= map_gradient;
10761105
offsets[i] = builder.template create<arith::ConstantIndexOp>(
1077-
loc, loop_lower_bound);
1106+
loc, loop_lower_bound + map_offset);
10781107
break;
10791108
}
10801109
}

mlir/test/Transform/AIRDependencyScheduleOpt/specialize-channel-wrap-and-stride.mlir

+25-4
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ module {
286286
// Offset propagation with wrap-and-stride canonicalization.
287287
// CHECK-LABEL: test9
288288
// CHECK: %[[VAL0:.*]] = affine.apply #map()[%arg1]
289-
// CHECK: put @channel_21[] (%arg0[%c0, %c0, %[[VAL0]], %c0] [%c8, %c2, %c32, %c32] [%c32, %c8192, %c256, %c1]) : (memref<128x256xi32>)
289+
// CHECK: put @channel_21[] (%arg0[%c0, %c0, %[[VAL0]]] [%c8, %c64, %c32] [%c32, %c256, %c1]) : (memref<128x256xi32>)
290290
// CHECK: air.channel.put @channel_22[] (%arg2[%c256, %c0, %c0] [%c8, %c32, %c4] [%c4, %c32, %c1]) : (memref<1x2x32x32xi32, 1 : i32>)
291291
// CHECK: air.channel.put @channel_23[] (%arg3[%c128, %c0, %c0] [%c4, %c32, %c8] [%c8, %c32, %c1]) : (memref<2x1x32x32xi32, 1 : i32>)
292292
// CHECK: %[[VAL1:.*]] = affine.apply
@@ -386,25 +386,32 @@ module {
386386
// Affine.apply with map joining two for loops in a loop nest.
387387
// CHECK-LABEL: test11
388388

389-
// CHECK: air.channel.put async [%{{.*}}] @channel_26[%c0, %c0] (%{{.*}}[%c0, %c0, %c0] [%c4_0, %c18, %c4_0] [%c96, %c16, %c1]) : (memref<1x6x6x16xbf16, 1>)
389+
// CHECK: air.channel.put async {{.*}}@channel_26[%c0{{.*}}, %c0{{.*}}] (%{{.*}}[%c0{{.*}}, %c0{{.*}}, %c0{{.*}}] [%c4{{.*}}, %c18{{.*}}, %c4{{.*}}] [%c96{{.*}}, %c16{{.*}}, %c1{{.*}}]) : (memref<1x6x6x16xbf16, 1>)
390+
// CHECK: air.channel.put async {{.*}}@channel_26[%c0{{.*}}, %c0{{.*}}] (%{{.*}}[%c0{{.*}}, %c0{{.*}}, %c0{{.*}}, %c12{{.*}}] [%c3{{.*}}, %c3{{.*}}, %c4{{.*}}, %c4{{.*}}] [%c96{{.*}}, %c16{{.*}}, %c16{{.*}}, %c1{{.*}}]) : (memref<1x3x6x16xi32, 1>)
390391

391392
func.func @test11() {
392393
%c3 = arith.constant 3 : index
393394
%c4 = arith.constant 4 : index
394395
%0 = air.launch async (%arg3, %arg4, %arg5) in (%arg6=%c3, %arg7=%c3, %arg8=%c4) {
395396
%1 = air.segment @segment_0 async {
396397
%c576 = arith.constant 576 : index
398+
%c288 = arith.constant 288 : index
397399
%c96 = arith.constant 96 : index
398400
%c3_0 = arith.constant 3 : index
399401
%c1 = arith.constant 1 : index
400402
%c16 = arith.constant 16 : index
403+
%c12 = arith.constant 12 : index
401404
%c6 = arith.constant 6 : index
402405
%c0 = arith.constant 0 : index
403406
%c4_1 = arith.constant 4 : index
404407
%async_token, %results = air.execute -> (memref<1x6x6x16xbf16, 1>) {
405408
%alloc = memref.alloc() : memref<1x6x6x16xbf16, 1>
406409
air.execute_terminator %alloc : memref<1x6x6x16xbf16, 1>
407410
}
411+
%async_token_23, %results_25 = air.execute -> (memref<1x3x6x16xi32, 1>) {
412+
%alloc = memref.alloc() : memref<1x3x6x16xi32, 1>
413+
air.execute_terminator %alloc : memref<1x3x6x16xi32, 1>
414+
}
408415
%4 = scf.for %arg9 = %c0 to %c4_1 step %c1 iter_args(%arg13 = %async_token) -> (!air.async.token) {
409416
%2 = scf.for %arg10 = %c0 to %c3_0 step %c1 iter_args(%arg11 = %arg13) -> (!air.async.token) {
410417
%async_token_2, %results_3 = air.execute [%arg11] -> (index) {
@@ -416,6 +423,15 @@ module {
416423
}
417424
scf.yield %2 : !air.async.token
418425
}
426+
scf.for %arg9 = %c0 to %c3_0 step %c1 {
427+
%60 = scf.for %arg10 = %c0 to %c3_0 step %c1 iter_args(%arg13 = %async_token) -> (!air.async.token) {
428+
%async_token_54, %results_55 = air.execute [%arg13] -> (index) {
429+
air.execute_terminator %arg9 : index
430+
}
431+
%61 = air.channel.put async [%async_token_54] @channel_26[%c0, %c0] (%results_25[%c0, %results_55, %arg10, %c12] [%c1, %c1, %c4_1, %c4_1] [%c288, %c96, %c16, %c1]) : (memref<1x3x6x16xi32, 1>)
432+
scf.yield %61 : !air.async.token
433+
}
434+
}
419435
}
420436
}
421437
return
@@ -460,10 +476,11 @@ module {
460476
// CHECK-LABEL: test13
461477

462478
// CHECK: air.channel.put async [%{{.*}}] @channel_14[] (%{{.*}}[%c0, %1, %results, %c0] [%c8, %c2_0, %c32, %c32] [%c32, %c8192, %c256, %c1]) : (memref<2x128x256xi32>)
479+
// CHECK: air.channel.put async [%{{.*}}] @channel_15[%c0, %c0] (%{{.*}}[%c0, %results, %c32768] [%c8, %c32, %c32] [%c32, %c256, %c1]) : (memref<512x512xi32>)
463480

464-
func.func @test13(%arg0: memref<2x128x256xi32>, %arg1: memref<2x256x128xi32>) {
481+
func.func @test13(%arg0: memref<2x128x256xi32>, %arg1: memref<512x512xi32>) {
465482
%c2 = arith.constant 2 : index
466-
%0 = air.launch async (%arg3, %arg4, %arg5) in (%arg6=%c2, %arg7=%c2, %arg8=%c2) args(%arg10=%arg0, %arg11=%arg1) : memref<2x128x256xi32>, memref<2x256x128xi32> {
483+
%0 = air.launch async (%arg3, %arg4, %arg5) in (%arg6=%c2, %arg7=%c2, %arg8=%c2) args(%arg10=%arg0, %arg11=%arg1) : memref<2x128x256xi32>, memref<512x512xi32> {
467484
%c4096 = arith.constant 4096 : index
468485
%c8 = arith.constant 8 : index
469486
%c16384 = arith.constant 16384 : index
@@ -484,6 +501,10 @@ module {
484501
%7 = air.channel.put async [%arg13, %async_token] @channel_14[] (%arg10[%arg3, %c0, %c0, %results, %arg12] [%c1, %c2_0, %c1, %c32, %c32] [%c32768, %c8192, %c32, %c256, %c1]) : (memref<2x128x256xi32>)
485502
scf.yield %7 : !air.async.token
486503
}
504+
%3 = scf.for %arg12 = %c0 to %c256 step %c32 iter_args(%arg13 = %async_token) -> (!air.async.token) {
505+
%7 = air.channel.put async [%arg13, %async_token] @channel_15[%c0, %c0] (%arg11[%c2_0, %c0, %results, %arg12] [%c1, %c1, %c32, %c32] [%c16384, %c32, %c256, %c1]) {id = 1 : i32} : (memref<512x512xi32>)
506+
scf.yield %7 : !air.async.token
507+
}
487508
}
488509
return
489510
}

0 commit comments

Comments
 (0)