Skip to content

Commit 7817163

Browse files
authored
[mlir] [presburger] Add IntegerRelation::rangeProduct (#148092)
This is intended to match `isl::map`'s `flat_range_product`. --------- Co-authored-by: Jeremy Kun <[email protected]>
1 parent 4635743 commit 7817163

File tree

3 files changed

+145
-0
lines changed

3 files changed

+145
-0
lines changed

mlir/include/mlir/Analysis/Presburger/IntegerRelation.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,19 @@ class IntegerRelation {
707707
/// this for uniformity with `applyDomain`.
708708
void applyRange(const IntegerRelation &rel);
709709

710+
/// Let the relation `this` be R1, and the relation `rel` be R2. Requires
711+
/// R1 and R2 to have the same domain.
712+
///
713+
/// Let R3 be the rangeProduct of R1 and R2. Then x R3 (y, z) iff
714+
/// (x R1 y and x R2 z).
715+
///
716+
/// Example:
717+
///
718+
/// R1: (i, j) -> k : f(i, j, k) = 0
719+
/// R2: (i, j) -> l : g(i, j, l) = 0
720+
/// R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0
721+
IntegerRelation rangeProduct(const IntegerRelation &rel);
722+
710723
/// Given a relation `other: (A -> B)`, this operation merges the symbol and
711724
/// local variables and then takes the composition of `other` on `this: (B ->
712725
/// C)`. The resulting relation represents tuples of the form: `A -> C`.

mlir/lib/Analysis/Presburger/IntegerRelation.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2481,6 +2481,44 @@ void IntegerRelation::applyDomain(const IntegerRelation &rel) {
24812481

24822482
void IntegerRelation::applyRange(const IntegerRelation &rel) { compose(rel); }
24832483

2484+
IntegerRelation IntegerRelation::rangeProduct(const IntegerRelation &rel) {
2485+
/// R1: (i, j) -> k : f(i, j, k) = 0
2486+
/// R2: (i, j) -> l : g(i, j, l) = 0
2487+
/// R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0
2488+
assert(getNumDomainVars() == rel.getNumDomainVars() &&
2489+
"Range product is only defined for relations with equal domains");
2490+
2491+
// explicit copy of `this`
2492+
IntegerRelation result = *this;
2493+
unsigned relRangeVarStart = rel.getVarKindOffset(VarKind::Range);
2494+
unsigned numThisRangeVars = getNumRangeVars();
2495+
unsigned numNewSymbolVars = result.getNumSymbolVars() - getNumSymbolVars();
2496+
2497+
result.appendVar(VarKind::Range, rel.getNumRangeVars());
2498+
2499+
// Copy each equality from `rel` and update the copy to account for range
2500+
// variables from `this`. The `rel` equality is a list of coefficients of the
2501+
// variables from `rel`, and so the range variables need to be shifted right
2502+
// by the number of `this` range variables and symbols.
2503+
for (unsigned i = 0; i < rel.getNumEqualities(); ++i) {
2504+
SmallVector<DynamicAPInt> copy =
2505+
SmallVector<DynamicAPInt>(rel.getEquality(i));
2506+
copy.insert(copy.begin() + relRangeVarStart,
2507+
numThisRangeVars + numNewSymbolVars, DynamicAPInt(0));
2508+
result.addEquality(copy);
2509+
}
2510+
2511+
for (unsigned i = 0; i < rel.getNumInequalities(); ++i) {
2512+
SmallVector<DynamicAPInt> copy =
2513+
SmallVector<DynamicAPInt>(rel.getInequality(i));
2514+
copy.insert(copy.begin() + relRangeVarStart,
2515+
numThisRangeVars + numNewSymbolVars, DynamicAPInt(0));
2516+
result.addInequality(copy);
2517+
}
2518+
2519+
return result;
2520+
}
2521+
24842522
void IntegerRelation::printSpace(raw_ostream &os) const {
24852523
space.print(os);
24862524
os << getNumConstraints() << " constraints\n";

mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,3 +608,97 @@ TEST(IntegerRelationTest, convertVarKindToLocal) {
608608
EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
609609
EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
610610
}
611+
612+
TEST(IntegerRelationTest, rangeProduct) {
613+
IntegerRelation r1 = parseRelationFromSet(
614+
"(i, j, k) : (2*i + 3*k == 0, i >= 0, j >= 0, k >= 0)", 2);
615+
IntegerRelation r2 = parseRelationFromSet(
616+
"(i, j, l) : (4*i + 6*j + 9*l == 0, i >= 0, j >= 0, l >= 0)", 2);
617+
618+
IntegerRelation rangeProd = r1.rangeProduct(r2);
619+
IntegerRelation expected =
620+
parseRelationFromSet("(i, j, k, l) : (2*i + 3*k == 0, 4*i + 6*j + 9*l == "
621+
"0, i >= 0, j >= 0, k >= 0, l >= 0)",
622+
2);
623+
624+
EXPECT_TRUE(expected.isEqual(rangeProd));
625+
}
626+
627+
TEST(IntegerRelationTest, rangeProductMultdimRange) {
628+
IntegerRelation r1 =
629+
parseRelationFromSet("(i, k) : (2*i + 3*k == 0, i >= 0, k >= 0)", 1);
630+
IntegerRelation r2 = parseRelationFromSet(
631+
"(i, l, m) : (4*i + 6*m + 9*l == 0, i >= 0, l >= 0, m >= 0)", 1);
632+
633+
IntegerRelation rangeProd = r1.rangeProduct(r2);
634+
IntegerRelation expected =
635+
parseRelationFromSet("(i, k, l, m) : (2*i + 3*k == 0, 4*i + 6*m + 9*l == "
636+
"0, i >= 0, k >= 0, l >= 0, m >= 0)",
637+
1);
638+
639+
EXPECT_TRUE(expected.isEqual(rangeProd));
640+
}
641+
642+
TEST(IntegerRelationTest, rangeProductMultdimRangeSwapped) {
643+
IntegerRelation r1 = parseRelationFromSet(
644+
"(i, l, m) : (4*i + 6*m + 9*l == 0, i >= 0, l >= 0, m >= 0)", 1);
645+
IntegerRelation r2 =
646+
parseRelationFromSet("(i, k) : (2*i + 3*k == 0, i >= 0, k >= 0)", 1);
647+
648+
IntegerRelation rangeProd = r1.rangeProduct(r2);
649+
IntegerRelation expected =
650+
parseRelationFromSet("(i, l, m, k) : (2*i + 3*k == 0, 4*i + 6*m + 9*l == "
651+
"0, i >= 0, k >= 0, l >= 0, m >= 0)",
652+
1);
653+
654+
EXPECT_TRUE(expected.isEqual(rangeProd));
655+
}
656+
657+
TEST(IntegerRelationTest, rangeProductEmptyDomain) {
658+
IntegerRelation r1 =
659+
parseRelationFromSet("(i, j) : (4*i + 9*j == 0, i >= 0, j >= 0)", 0);
660+
IntegerRelation r2 =
661+
parseRelationFromSet("(k, l) : (2*k + 3*l == 0, k >= 0, l >= 0)", 0);
662+
IntegerRelation rangeProd = r1.rangeProduct(r2);
663+
IntegerRelation expected =
664+
parseRelationFromSet("(i, j, k, l) : (2*k + 3*l == 0, 4*i + 9*j == "
665+
"0, i >= 0, j >= 0, k >= 0, l >= 0)",
666+
0);
667+
EXPECT_TRUE(expected.isEqual(rangeProd));
668+
}
669+
670+
TEST(IntegerRelationTest, rangeProductEmptyRange) {
671+
IntegerRelation r1 =
672+
parseRelationFromSet("(i, j) : (4*i + 9*j == 0, i >= 0, j >= 0)", 2);
673+
IntegerRelation r2 =
674+
parseRelationFromSet("(i, j) : (2*i + 3*j == 0, i >= 0, j >= 0)", 2);
675+
IntegerRelation rangeProd = r1.rangeProduct(r2);
676+
IntegerRelation expected =
677+
parseRelationFromSet("(i, j) : (2*i + 3*j == 0, 4*i + 9*j == "
678+
"0, i >= 0, j >= 0)",
679+
2);
680+
EXPECT_TRUE(expected.isEqual(rangeProd));
681+
}
682+
683+
TEST(IntegerRelationTest, rangeProductEmptyDomainAndRange) {
684+
IntegerRelation r1 = parseRelationFromSet("() : ()", 0);
685+
IntegerRelation r2 = parseRelationFromSet("() : ()", 0);
686+
IntegerRelation rangeProd = r1.rangeProduct(r2);
687+
IntegerRelation expected = parseRelationFromSet("() : ()", 0);
688+
EXPECT_TRUE(expected.isEqual(rangeProd));
689+
}
690+
691+
TEST(IntegerRelationTest, rangeProductSymbols) {
692+
IntegerRelation r1 = parseRelationFromSet(
693+
"(i, j)[s] : (2*i + 3*j + s == 0, i >= 0, j >= 0)", 1);
694+
IntegerRelation r2 = parseRelationFromSet(
695+
"(i, l)[s] : (3*i + 4*l + s == 0, i >= 0, l >= 0)", 1);
696+
697+
IntegerRelation rangeProd = r1.rangeProduct(r2);
698+
IntegerRelation expected = parseRelationFromSet(
699+
"(i, j, l)[s] : (2*i + 3*j + s == 0, 3*i + 4*l + s == "
700+
"0, i >= 0, j >= 0, l >= 0)",
701+
1);
702+
703+
EXPECT_TRUE(expected.isEqual(rangeProd));
704+
}

0 commit comments

Comments
 (0)