Skip to content

Commit 1f103b1

Browse files
committed
src,test: fix some bugs around ivar recovery and pos splits
This commit fixes several bugs around the recovery of index variables when splits/divides of the position spaces are involved. In particular, the old code would emit position loops like ``` for (int jposo = A2_pos[i] / chunkSize; jposo < (A2_pos[i+1] + (chunkSize - 1)) / chunkSize; jposo++) { for (int jposi = 0; jposi < chunkSize; jposi++) { int jposA = jposo * chunkSize + jposi; ... } } ``` This does not correctly iterate over the position space. A correct code is: ``` for (int jposo = 0; jposo < ((A2_pos[i+1] - A2_pos[i]) + (chunkSize - 1)) / chunkSize; jposo++) { for (int jposi = 0; jposi < chunkSize; jposi++) { int jposA = jposo * chunkSize + jposi + A2_pos[i]; ... } } ```
1 parent 0ede002 commit 1f103b1

File tree

3 files changed

+78
-12
lines changed

3 files changed

+78
-12
lines changed

src/index_notation/provenance_graph.cpp

+25-11
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,16 @@ std::vector<ir::Expr> SplitRelNode::deriveIterBounds(taco::IndexVar indexVar,
212212
std::vector<ir::Expr> parentBound = parentIterBounds.at(getParentVar());
213213
Datatype splitFactorType = parentBound[0].type();
214214
if (indexVar == getOuterVar()) {
215-
ir::Expr minBound = ir::Div::make(parentBound[0], ir::Literal::make(getSplitFactor(), splitFactorType));
216-
ir::Expr maxBound = ir::Div::make(ir::Add::make(parentBound[1], ir::Literal::make(getSplitFactor()-1, splitFactorType)), ir::Literal::make(getSplitFactor(), splitFactorType));
215+
// The outer variable must always range from 0 to the extent of the bounds (chunked up).
216+
// This is a noop for the common case where all of our loops start at 0. However, it is
217+
// important when doing a pos split, where the resulting bounds may not start at 0,
218+
// and instead start at a position like T_pos[i].lo.
219+
ir::Expr minBound = 0;
220+
auto upper = ir::Sub::make(parentBound[1], parentBound[0]);
221+
ir::Expr maxBound = ir::Div::make(ir::Add::make(upper, ir::Literal::make(getSplitFactor() - 1, splitFactorType)),
222+
ir::Literal::make(getSplitFactor(), splitFactorType));
217223
return {minBound, maxBound};
218-
}
219-
else if (indexVar == getInnerVar()) {
224+
} else if (indexVar == getInnerVar()) {
220225
ir::Expr minBound = 0;
221226
ir::Expr maxBound = ir::Literal::make(getSplitFactor(), splitFactorType);
222227
return {minBound, maxBound};
@@ -231,7 +236,12 @@ ir::Expr SplitRelNode::recoverVariable(taco::IndexVar indexVar,
231236
taco_iassert(indexVar == getParentVar());
232237
taco_iassert(variableNames.count(getParentVar()) && variableNames.count(getOuterVar()) && variableNames.count(getInnerVar()));
233238
Datatype splitFactorType = variableNames[getParentVar()].type();
234-
return ir::Add::make(ir::Mul::make(variableNames[getOuterVar()], ir::Literal::make(getSplitFactor(), splitFactorType)), variableNames[getInnerVar()]);
239+
// Include the lower bound of the variable being recovered. Normally, this is 0, but
240+
// in the case of a position split it is not.
241+
return ir::Add::make(
242+
ir::Add::make(ir::Mul::make(variableNames[getOuterVar()], ir::Literal::make(getSplitFactor(), splitFactorType)), variableNames[getInnerVar()]),
243+
parentIterBounds[indexVar][0]
244+
);
235245
}
236246

237247
ir::Stmt SplitRelNode::recoverChild(taco::IndexVar indexVar,
@@ -364,11 +374,12 @@ std::vector<ir::Expr> DivideRelNode::deriveIterBounds(taco::IndexVar indexVar,
364374
ir::Expr minBound = 0;
365375
ir::Expr maxBound = divFactor;
366376
return {minBound, maxBound};
367-
}
368-
else if (indexVar == getInnerVar()) {
369-
// The inner loop ranges over a chunk of size parentBound / divFactor.
370-
ir::Expr minBound = ir::Div::make(parentBound[0], divFactor);
371-
ir::Expr maxBound = ir::Div::make(ir::Add::make(parentBound[1], ir::Literal::make(getDivFactor()-1, divFactorType)), divFactor);
377+
} else if (indexVar == getInnerVar()) {
378+
// The inner loop ranges over a chunk of size parentBound / divFactor. Similarly
379+
// to split, the loop must range from 0 to parentBound.
380+
ir::Expr minBound = 0;
381+
auto upper = ir::Sub::make(parentBound[1], parentBound[0]);
382+
ir::Expr maxBound = ir::Div::make(ir::Add::make(upper, ir::Literal::make(getDivFactor()-1, divFactorType)), divFactor);
372383
return {minBound, maxBound};
373384
}
374385
taco_ierror;
@@ -390,7 +401,10 @@ ir::Expr DivideRelNode::recoverVariable(taco::IndexVar indexVar,
390401
// The bounds for the dimension are adjusted so that dimensions that aren't
391402
// divisible by divFactor have the last piece included.
392403
auto bounds = ir::Div::make(ir::Add::make(dimSize, divFactorMinusOne), divFactor);
393-
return ir::Add::make(ir::Mul::make(variableNames[getOuterVar()], bounds), variableNames[getInnerVar()]);
404+
// We multiply this all together and then add on the base of the parentBounds
405+
// to shift up into the range of the parent. This is normally 0, but for cases
406+
// like position loops it is not.
407+
return ir::Add::make(ir::Add::make(ir::Mul::make(variableNames[getOuterVar()], bounds), variableNames[getInnerVar()]), parentIterBounds[indexVar][0]);
394408
}
395409

396410
ir::Stmt DivideRelNode::recoverChild(taco::IndexVar indexVar,

src/lower/lowerer_impl_imperative.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -695,11 +695,15 @@ Stmt LowererImplImperative::lowerForall(Forall forall)
695695
// Find the iteration bounds of the inner variable -- that is the size
696696
// that the outer loop was broken into.
697697
auto bounds = this->provGraph.deriveIterBounds(inner, definedIndexVarsOrdered, underivedBounds, indexVarToExprMap, iterators);
698+
auto parentBounds = this->provGraph.deriveIterBounds(varToRecover, this->definedIndexVarsOrdered, this->underivedBounds, this->indexVarToExprMap, this->iterators);
698699
// Use the difference between the bounds to find the size of the loop.
699700
auto dimLen = ir::Sub::make(bounds[1], bounds[0]);
700701
// For a variable f divided into into f1 and f2, the guard ensures that
701702
// for iteration f, f should be within f1 * dimLen and (f1 + 1) * dimLen.
702-
auto guard = ir::Gte::make(this->indexVarToExprMap[varToRecover], ir::Mul::make(ir::Add::make(this->indexVarToExprMap[outer], 1), dimLen));
703+
// Additionally, similarly to the recovery of variables in the DivideRelNode,
704+
// we also need to include the lower bound of the original variable here.
705+
auto upper = ir::Add::make(ir::Mul::make(ir::Add::make(this->indexVarToExprMap[outer], 1), dimLen), parentBounds[0]);
706+
auto guard = ir::Gte::make(this->indexVarToExprMap[varToRecover], ir::simplify(upper));
703707
recoverySteps.push_back(IfThenElse::make(guard, ir::Continue::make()));
704708
}
705709
}

test/tests-scheduling.cpp

+48
Original file line numberDiff line numberDiff line change
@@ -990,3 +990,51 @@ TEST(scheduling, divide) {
990990
return stmt.fuse(i, j, f).pos(f, fpos, A(i, j)).divide(fpos, f0, f1, 4).split(f1, i1, i2, 16).split(i2, i3, i4, 8);
991991
});
992992
}
993+
994+
TEST(scheduling, posSplitAndDivide) {
995+
int dim = 10;
996+
Tensor<int> A("A", {dim, dim}, {Dense, Sparse});
997+
Tensor<int> x("x", {dim}, Dense);
998+
999+
auto sparsity = 0.5;
1000+
srand(59393);
1001+
for (int i = 0; i < dim; i++) {
1002+
x.insert({i}, i);
1003+
for (int j = 0; j < dim; j++) {
1004+
auto rand_float = (float)rand()/(float)(RAND_MAX);
1005+
if (rand_float < sparsity) {
1006+
A.insert({i, j},((int)(rand_float * 10 / sparsity)));
1007+
}
1008+
}
1009+
}
1010+
A.pack();
1011+
x.pack();
1012+
1013+
IndexVar i("i"), j("j"), ipos("ipos"), iposo("iposo"), iposi("iposi");
1014+
auto test = [&](std::function<IndexStmt(IndexStmt)> f) {
1015+
Tensor<int> y("y", {dim}, Dense);
1016+
y(i) = A(i, j) * x(j);
1017+
auto stmt = f(y.getAssignment().concretize());
1018+
y.compile(stmt);
1019+
y.evaluate();
1020+
Tensor<int> expected("expected", {dim}, Dense);
1021+
expected(i) = A(i, j) * x(j);
1022+
expected.evaluate();
1023+
ASSERT_TRUE(equals(expected, y)) << expected << endl << y << endl;
1024+
};
1025+
1026+
// TODO (rohany): The old split code did not work prior to this commit on large
1027+
// problem instances from suitesparse, but I'm not able to reproduce the bug on
1028+
// small test cases here.
1029+
// test([&](IndexStmt stmt) {
1030+
// return stmt.pos(j, ipos, A(i, j))
1031+
// .split(ipos, iposo, iposi, 12)
1032+
// ;
1033+
// });
1034+
1035+
test([&](IndexStmt stmt) {
1036+
return stmt.pos(j, ipos, A(i, j))
1037+
.divide(ipos, iposo, iposi, 8)
1038+
;
1039+
});
1040+
}

0 commit comments

Comments
 (0)