fix(ONNX): avoids resizing unsupported dimensions #3945
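What this PR does (as reflected in the diff below): the ONNX Resize lowering now treats the batch and channel dimensions as non-resizable. Statically known input/output sizes that disagree on those dimensions produce a compile-time match failure; dynamically shaped inputs instead get runtime asserts on the proposed scale factors or target sizes, and only the trailing spatial dimensions are forwarded to Torch::Aten__InterpolateSizeListScaleListOp.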
Changes from all commits: c0e47ff, bf60cfa, 82f6445, 50c2529, e4f4425, bd23b93, 87f9f54, 874827d, 1d77eb9, e41fa62, 140b628, fff08f2, 1f7cdf0, e835f1a, a858e45, 3f41467, 948a53e, fd20a79, b897a34, 01e2274, 266a820, c9f2197, 6dc3fdf, a230084
@@ -2700,12 +2700,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
       });
   patterns.onOp(
       "Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        Torch::ValueTensorType resultType;
+        Torch::ValueTensorType outputTensorType;
         llvm::SmallVector<Value> operands;
         std::string mode, nearest_mode, coordTfMode;
         int64_t antialias, exclude_outside;
         float extrapolation_value, cubic_coeff_a;
-        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

         if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
           return rewriter.notifyMatchFailure(
@@ -2720,7 +2719,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         }

         if (binder.tensorOperandsList(operands) ||
-            binder.tensorResultType(resultType) ||
+            binder.tensorResultType(outputTensorType) ||
             binder.customOpNameStringAttr(mode, "mode", "nearest") ||
             binder.customOpNameStringAttr(
                 coordTfMode, "coordinate_transformation_mode", "half_pixel") ||
@@ -2732,6 +2731,42 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                 "round_prefer_floor") ||
             binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
           return failure();

+        int64_t const /* */ batchDim = 0;
+        int64_t const /**/ channelDim = 1;
+
+        SmallVector<int64_t> nonResizableDims{
+            batchDim,
+            channelDim,
+        };
+
+        Value inputTensor = operands[0];
+        auto inputTensorType =
+            cast<Torch::BaseTensorType>(inputTensor.getType());
+        auto sizesOfInputTensor = inputTensorType.getSizes();
+        auto sizesOfOutputTensor = outputTensorType.getSizes();
+
+        auto unknownSize = Torch::kUnknownSize;
+
+        // Compile-time check for dimensions of static size
+        for (auto &eachDim : nonResizableDims) {
+          auto eachSizeOfInputTensor = sizesOfInputTensor[eachDim];
+          auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDim];
+
+          if (eachSizeOfInputTensor == unknownSize ||
+              eachSizeOfOutputTensor == unknownSize)
+            continue;
+          if (eachSizeOfInputTensor == eachSizeOfOutputTensor)
+            continue;
+
+          auto resizingIntentErrorMessage =
+              "unsupported: non-trivial intent to resize dimension: " +
+              std::to_string(eachDim);
+
+          return rewriter.notifyMatchFailure(binder.op,
+                                             resizingIntentErrorMessage);
+        };
+
         if (antialias != 0) {
           return rewriter.notifyMatchFailure(
               binder.op,
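The compile-time loop above can only reject what the static shapes prove; any dimension of unknown size falls through to the runtime checks emitted later in the diff. A stand-alone distillation of the guard (an editor's sketch, not code from the PR; it assumes Torch::kUnknownSize is the usual -1 sentinel for dynamic dimensions):

#include <cstdint>
#include <vector>

constexpr int64_t kUnknownSize = -1; // assumed stand-in for Torch::kUnknownSize

// True when every non-resizable dim is either dynamic (deferred to the
// runtime asserts) or provably unchanged between input and output.
bool staticGuardPasses(const std::vector<int64_t> &inputSizes,
                       const std::vector<int64_t> &outputSizes,
                       const std::vector<int64_t> &nonResizableDims) {
  for (int64_t dim : nonResizableDims) {
    int64_t in = inputSizes[dim];
    int64_t out = outputSizes[dim];
    if (in == kUnknownSize || out == kUnknownSize)
      continue; // undecidable at compile time; the runtime assert covers it
    if (in != out)
      return false; // provable batch/channel resize: reject the pattern
  }
  return true;
}

For example, resizing [1, 3, 64, 64] to [1, 3, 128, 128] passes (only spatial dims change), while [1, 3, 64, 64] to [1, 6, 128, 128] fails on the channel dim.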
@@ -2764,35 +2799,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
               binder.op, "unimplemented: cubic coeff must be -0.75");
         }

-        unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0].getType())
-                            .getSizes()
-                            .size();
+        auto loc = binder.getLoc();

-        Value cstFalse =
-            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
-        Value cstTrue =
-            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+        Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
         Value modeStrValue;

-        Value scalesValueList = noneVal;
-        Value sizesValueList = noneVal;
         Value alignCorners =
             coordTfMode == "align_corners" ? cstTrue : cstFalse;
         if (mode == "cubic") {
           std::string modeStr = "cubic";
           if (coordTfMode != "half_pixel")
             modeStr = modeStr + "_" + coordTfMode;
-          modeStrValue =
-              rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
+          modeStrValue = rewriter.create<Torch::ConstantStrOp>(loc, modeStr);
         }

+        auto rankOfInputTensor = sizesOfInputTensor.size();
+
         // supported modes:
         // bilinear (half_pixel), bilinear with align_corners,
         // bilinear_pytorch_half_pixel, bilinear_asymmetric nearest
         // (asymmetric), nearest with align_corners, nearest_half_pixel,
         // nearest_pytorch_half_pixel
         if (mode == "linear") {
           std::string modeStr;
-          switch (rank) {
+          switch (rankOfInputTensor) {
           case 3:
             modeStr = "linear";
             break;
@@ -2809,8 +2840,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
           // mode is apparently half_pixel, NOT pytorch_half_pixel
           if (coordTfMode != "half_pixel" && coordTfMode != "align_corners")
             modeStr = (modeStr + "_") + coordTfMode;
-          modeStrValue =
-              rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
+          modeStrValue = rewriter.create<Torch::ConstantStrOp>(loc, modeStr);
         }
         if (mode == "nearest") {
           std::string modeStr = "nearest";
@@ -2820,33 +2850,84 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             modeStr = (modeStr + "_") + coordTfMode;
           if (nearest_mode != "floor" && nearest_mode != "")
             modeStr = modeStr + "," + nearest_mode;
-          modeStrValue =
-              rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
+          modeStrValue = rewriter.create<Torch::ConstantStrOp>(loc, modeStr);
         }

-        int64_t assumedForemostSpatialDim = 2;
+        auto numberOfOperands = operands.size();

-        if (operands.size() < 4) {
-          Value scaleOperand = operands[2];
-          scalesValueList =
-              createScalarSublist(binder.getLoc(), scaleOperand,
-                                  assumedForemostSpatialDim, rewriter);
-          sizesValueList = noneVal;
-        } else {
-          Value sizeOperand = operands[3];
-          scalesValueList = noneVal;
-          sizesValueList =
-              createScalarSublist(binder.getLoc(), sizeOperand,
-                                  assumedForemostSpatialDim, rewriter);
-        }
-        if (isa<Torch::NoneType>(scalesValueList.getType()) &&
-            isa<Torch::NoneType>(sizesValueList.getType())) {
+        Type boolType = rewriter.getType<Torch::BoolType>();
+
+        int64_t assumedForemostSpatialDim = 1 + nonResizableDims.back();
+
+        Value supportedScaleFactors;
+        Value supportedSizes;
+
+        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(loc);
+
+        if (numberOfOperands == 3) {
+          Value proposedScaleFactors = operands[2];
+
+          Value scaleIdentity = rewriter.create<Torch::ConstantFloatOp>(
+              loc, rewriter.getF64FloatAttr(1.0));
Comment on lines +2870 to +2871 (the Torch::ConstantFloatOp above):

FYI, this appears to have caused some test failures downstream in the IREE project on iree-org/iree#19976. I did not bisect to this specific change or line of code, but this looked most relevant. These are the logs: https://github.com/iree-org/iree/actions/runs/13292751088/job/37117771168?pr=19976#step:8:50

By default, IREE demotes f64 to f32, since 64 bits of precision is rarely needed in ML models and many hardware targets either do not support f64 at all or support it only with significant performance penalties. The tests there do override that default by setting […]. Is f64 needed here, or would f32 work? I see lots of uses of […].

More context: some of the tests in the ONNX test suite require f64, which is why we run those tests without f64-to-f32 demotion: iree-org/iree#18111.

We don't need f64; this is a small bug with the changes. Will post a quick fix in a minute.

If I remember correctly when writing this, using f32 for […]. @zjgarvey, any insights here?

Wait, F64 is the correct attr type for constant float ops. I'll take a look at the test failures.

Looks like a simple issue.

I re-ran the IREE tests with #4022; the failing tests go back to passing with that change.

Nice, thanks!
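To make the demotion concern in this thread concrete: a blanket f64-to-f32 pass is not semantics-preserving, which is how it can flip the result of the exact AtenEqFloatOp comparison emitted below. A self-contained illustration (editor's sketch; this is not IREE or torch-mlir code):

#include <cstdio>

int main() {
  double d = 16777217.0;           // 2^24 + 1 is exactly representable in f64,
  float f = static_cast<float>(d); // but rounds to 2^24 = 16777216 in f32.
  std::printf("f64: %.1f, after demotion: %.1f\n", d, static_cast<double>(f));
  // An equality check that held (or failed) in f64 may give the opposite
  // answer once both operands have been demoted to f32.
  return 0;
}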
+          // run-time scale factor check for dynamic sizes
+          for (auto &eachDim : nonResizableDims) {
+            Value eachProposedScaleFactor = extractTorchScalar(
+                loc, eachDim, proposedScaleFactors, rewriter);
+
+            Value eachScaleFactorIsIdentity =
+                rewriter.create<Torch::AtenEqFloatOp>(
+                    loc, boolType, eachProposedScaleFactor, scaleIdentity);
+
+            auto errorMessageForEachDim =
+                "Unsupported: non-trivial scale factor for dimension " +
+                std::to_string(eachDim);
+
+            rewriter.create<Torch::RuntimeAssertOp>(
+                loc, eachScaleFactorIsIdentity,
+                rewriter.getStringAttr(errorMessageForEachDim));
+          };
+
+          supportedScaleFactors = createScalarSublist(
+              loc, proposedScaleFactors, assumedForemostSpatialDim, rewriter);
+          supportedSizes = noneVal;
+        } else if (numberOfOperands == 4) {
+          Value proposedSizes = operands[3];
+
+          // run-time target size check for dynamic sizes
+          for (auto &eachDimAsInt : nonResizableDims) {
+            Value eachDimAsValue =
+                rewriter.create<Torch::ConstantIntOp>(loc, eachDimAsInt);
+
+            Value eachSizeOfInputTensor = rewriter.create<Torch::AtenSizeIntOp>(
+                loc, inputTensor, eachDimAsValue);
+
+            Value eachProposedSize =
+                extractTorchScalar(loc, eachDimAsInt, proposedSizes, rewriter);
+
+            Value eachProposedSizeIsTrivial =
+                rewriter.create<Torch::AtenEqIntOp>(
+                    loc, boolType, eachProposedSize, eachSizeOfInputTensor);
+
+            auto errorMessageForEachDim =
+                "Unsupported: non-trivial resizing of dimension " +
+                std::to_string(eachDimAsInt);
+
+            rewriter.create<Torch::RuntimeAssertOp>(
+                loc, eachProposedSizeIsTrivial,
+                rewriter.getStringAttr(errorMessageForEachDim));
+          };
+
+          supportedScaleFactors = noneVal;
+          supportedSizes = createScalarSublist(
+              loc, proposedSizes, assumedForemostSpatialDim, rewriter);
+        } else
+          return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode");
-        }

         rewriter
             .replaceOpWithNewOp<Torch::Aten__InterpolateSizeListScaleListOp>(
-                binder.op, resultType, operands[0], sizesValueList,
-                scalesValueList, modeStrValue,
+                binder.op, outputTensorType, inputTensor, supportedSizes,
+                supportedScaleFactors, modeStrValue,
                 /* AnyTorchOptionalBoolType:$align_corners */ alignCorners,
                 /* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal,
                 /*Torch_BoolType:$antialias*/ cstFalse);
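For dynamically shaped inputs, the branches above emit AtenSizeIntOp / AtenEqIntOp (or AtenEqFloatOp) plus RuntimeAssertOp, so the guard runs inside the compiled program rather than at conversion time. Its runtime behavior for the sizes operand is roughly the following (editor's sketch of the semantics, not the lowering code itself):

#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Each non-resizable dim's proposed size must equal the input's actual size;
// otherwise the emitted RuntimeAssertOp aborts with a matching message.
void runtimeGuard(const std::vector<int64_t> &inputSizes,
                  const std::vector<int64_t> &proposedSizes,
                  const std::vector<int64_t> &nonResizableDims) {
  for (int64_t dim : nonResizableDims) {
    if (proposedSizes[dim] != inputSizes[dim])
      throw std::runtime_error(
          "Unsupported: non-trivial resizing of dimension " +
          std::to_string(dim));
  }
}

The scales branch is analogous, asserting that each non-resizable dim's scale factor equals 1.0.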