@@ -212,11 +212,16 @@ std::vector<ir::Expr> SplitRelNode::deriveIterBounds(taco::IndexVar indexVar,
212
212
std::vector<ir::Expr> parentBound = parentIterBounds.at (getParentVar ());
213
213
Datatype splitFactorType = parentBound[0 ].type ();
214
214
if (indexVar == getOuterVar ()) {
215
- ir::Expr minBound = ir::Div::make (parentBound[0 ], ir::Literal::make (getSplitFactor (), splitFactorType));
216
- ir::Expr maxBound = ir::Div::make (ir::Add::make (parentBound[1 ], ir::Literal::make (getSplitFactor ()-1 , splitFactorType)), ir::Literal::make (getSplitFactor (), splitFactorType));
215
+ // The outer variable must always range from 0 to the extent of the bounds (chunked up).
216
+ // This is a noop for the common case where all of our loops start at 0. However, it is
217
+ // important when doing a pos split, where the resulting bounds may not start at 0,
218
+ // and instead start at a position like T_pos[i].lo.
219
+ ir::Expr minBound = 0 ;
220
+ auto upper = ir::Sub::make (parentBound[1 ], parentBound[0 ]);
221
+ ir::Expr maxBound = ir::Div::make (ir::Add::make (upper, ir::Literal::make (getSplitFactor () - 1 , splitFactorType)),
222
+ ir::Literal::make (getSplitFactor (), splitFactorType));
217
223
return {minBound, maxBound};
218
- }
219
- else if (indexVar == getInnerVar ()) {
224
+ } else if (indexVar == getInnerVar ()) {
220
225
ir::Expr minBound = 0 ;
221
226
ir::Expr maxBound = ir::Literal::make (getSplitFactor (), splitFactorType);
222
227
return {minBound, maxBound};
@@ -231,7 +236,12 @@ ir::Expr SplitRelNode::recoverVariable(taco::IndexVar indexVar,
231
236
taco_iassert (indexVar == getParentVar ());
232
237
taco_iassert (variableNames.count (getParentVar ()) && variableNames.count (getOuterVar ()) && variableNames.count (getInnerVar ()));
233
238
Datatype splitFactorType = variableNames[getParentVar ()].type ();
234
- return ir::Add::make (ir::Mul::make (variableNames[getOuterVar ()], ir::Literal::make (getSplitFactor (), splitFactorType)), variableNames[getInnerVar ()]);
239
+ // Include the lower bound of the variable being recovered. Normally, this is 0, but
240
+ // in the case of a position split it is not.
241
+ return ir::Add::make (
242
+ ir::Add::make (ir::Mul::make (variableNames[getOuterVar ()], ir::Literal::make (getSplitFactor (), splitFactorType)), variableNames[getInnerVar ()]),
243
+ parentIterBounds[indexVar][0 ]
244
+ );
235
245
}
236
246
237
247
ir::Stmt SplitRelNode::recoverChild (taco::IndexVar indexVar,
@@ -364,11 +374,12 @@ std::vector<ir::Expr> DivideRelNode::deriveIterBounds(taco::IndexVar indexVar,
364
374
ir::Expr minBound = 0 ;
365
375
ir::Expr maxBound = divFactor;
366
376
return {minBound, maxBound};
367
- }
368
- else if (indexVar == getInnerVar ()) {
369
- // The inner loop ranges over a chunk of size parentBound / divFactor.
370
- ir::Expr minBound = ir::Div::make (parentBound[0 ], divFactor);
371
- ir::Expr maxBound = ir::Div::make (ir::Add::make (parentBound[1 ], ir::Literal::make (getDivFactor ()-1 , divFactorType)), divFactor);
377
+ } else if (indexVar == getInnerVar ()) {
378
+ // The inner loop ranges over a chunk of size parentBound / divFactor. Similarly
379
+ // to split, the loop must start at 0 and span the parent's extent divided by divFactor (rounded up).
380
+ ir::Expr minBound = 0 ;
381
+ auto upper = ir::Sub::make (parentBound[1 ], parentBound[0 ]);
382
+ ir::Expr maxBound = ir::Div::make (ir::Add::make (upper, ir::Literal::make (getDivFactor ()-1 , divFactorType)), divFactor);
372
383
return {minBound, maxBound};
373
384
}
374
385
taco_ierror;
@@ -390,7 +401,10 @@ ir::Expr DivideRelNode::recoverVariable(taco::IndexVar indexVar,
390
401
// The bounds for the dimension are adjusted so that dimensions that aren't
391
402
// divisible by divFactor have the last piece included.
392
403
auto bounds = ir::Div::make (ir::Add::make (dimSize, divFactorMinusOne), divFactor);
393
- return ir::Add::make (ir::Mul::make (variableNames[getOuterVar ()], bounds), variableNames[getInnerVar ()]);
404
+ // We multiply this all together and then add on the base of the parentBounds
405
+ // to shift up into the range of the parent. This is normally 0, but for cases
406
+ // like position loops it is not.
407
+ return ir::Add::make (ir::Add::make (ir::Mul::make (variableNames[getOuterVar ()], bounds), variableNames[getInnerVar ()]), parentIterBounds[indexVar][0 ]);
394
408
}
395
409
396
410
ir::Stmt DivideRelNode::recoverChild (taco::IndexVar indexVar,
0 commit comments