Skip to content

Commit 6db8cd6

Browse files
committed
[TOSA] Add transposed conv support
Lower aten.conv_transpose2d into tosa.transpose_conv2d. Refresh FX importer TOSA xfails to drop the transpose-conv cases that now pass, and document the weight layout mapping. Change-Id: I709579e40a1ccaf9b9188392c7c78fcb653109ce
1 parent 8b77de9 commit 6db8cd6

File tree

4 files changed

+140
-23
lines changed

4 files changed

+140
-23
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 120 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2304,9 +2304,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23042304
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
23052305
return rewriter.notifyMatchFailure(
23062306
op, "Unimplemented: non-constant value for transposed not supported");
2307-
if (transposed)
2308-
return rewriter.notifyMatchFailure(
2309-
op, "Unimplemented: transposed convolution not supported");
23102307

23112308
auto input = adaptor.getInput();
23122309
auto weight = adaptor.getWeight();
@@ -2338,12 +2335,17 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23382335
auto bias = adaptor.getBias();
23392336

23402337
if (isa<Torch::NoneType>(bias.getType())) {
2341-
auto bias_result = tosa::getConvBiasForNoneType(op, rewriter, inputElemTy,
2342-
outputElemTy, weightShape);
2343-
if (failed(bias_result))
2338+
SmallVector<int64_t, 4> biasWeightShape =
2339+
transposed ? SmallVector<int64_t, 4>{weightShape[1], weightShape[0],
2340+
weightShape[2], weightShape[3]}
2341+
: weightShape;
2342+
2343+
auto biasResult = tosa::getConvBiasForNoneType(
2344+
op, rewriter, inputElemTy, outputElemTy, biasWeightShape);
2345+
if (failed(biasResult))
23442346
return rewriter.notifyMatchFailure(
23452347
op, "Failed to create bias tensor for none type.");
2346-
bias = bias_result.value();
2348+
bias = biasResult.value();
23472349
} else {
23482350
if (!isa<RankedTensorType>(bias.getType()))
23492351
return rewriter.notifyMatchFailure(
@@ -2370,8 +2372,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23702372
m_TorchListOfConstantInts(padding_2d)))
23712373
return rewriter.notifyMatchFailure(op,
23722374
"non-const padding list unsupported");
2373-
// TOSA uses 4D padding {top, bottom, left, right} while Torch defines 2D
2374-
// padding {height, width}. The Torch OFM computation uses 2*pad in each
2375+
// TOSA uses 4D padding {top, bottom, left, right} while PyTorch defines 2D
2376+
// padding {height, width}. The PyTorch OFM computation uses 2*pad in each
23752377
// spatial direction, implying the same top=bottom=height and left=right=width
23762378
// values for TOSA.
23772379
SmallVector<int64_t> padding(
@@ -2388,11 +2390,19 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23882390
return rewriter.notifyMatchFailure(
23892391
op, "failed to get accumulator type for convolution ops");
23902392

2393+
// Weight layout reference:
2394+
// Conv : PyTorch OIHW -> TOSA OHWI
2395+
// Depthwise : PyTorch OIHW* -> TOSA HWIM
2396+
// (PyTorch depthwise uses out_ch=in_ch*depth_multiplier)
2397+
// Grouped : PyTorch O(I/G)HW -> N/A
2398+
// Transposed : PyTorch IOHW -> TOSA OHWI
23912399
// TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
23922400
// Perform the necessary transformations.
23932401
SmallVector<int32_t> nchwToNhwcDims({0, 2, 3, 1});
2394-
SmallVector<int64_t> transposedInputShape(
2395-
{inputShape[0], inputShape[2], inputShape[3], inputShape[1]});
2402+
SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
2403+
SmallVector<int64_t, 4> transposedInputShape;
2404+
for (int32_t dim : nchwToNhwcDims)
2405+
transposedInputShape.push_back(inputShape[dim]);
23962406
auto transposedInputType = RankedTensorType::get(
23972407
makeShapeLLVMCompatible(transposedInputShape), inputElemTy);
23982408
auto transposedInput =
@@ -2403,6 +2413,104 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24032413
rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
24042414
.getResult();
24052415

2416+
if (transposed) {
2417+
if (groups != 1)
2418+
return rewriter.notifyMatchFailure(
2419+
op, "Unimplemented: grouped transposed convolution not supported by "
2420+
"TOSA");
2421+
if (dilation[0] != 1 || dilation[1] != 1)
2422+
return rewriter.notifyMatchFailure(
2423+
op, "Unimplemented: dilated transposed convolution not supported by "
2424+
"TOSA");
2425+
2426+
SmallVector<int32_t> iohwToOhwi({1, 2, 3, 0});
2427+
SmallVector<int64_t, 4> ohwiWeightShape;
2428+
for (int32_t dim : iohwToOhwi)
2429+
ohwiWeightShape.push_back(weightShape[dim]);
2430+
auto ohwiWeightType = RankedTensorType::get(
2431+
makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
2432+
Value transformedWeight =
2433+
rewriter
2434+
.create<tosa::TransposeOp>(
2435+
op->getLoc(), getTypeConverter()->convertType(ohwiWeightType),
2436+
weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi))
2437+
.getResult();
2438+
2439+
// TOSA 'out_pad' is a 4D array {top,bottom,left,right}.
2440+
// Map from PyTorch's (padding, output_padding):
2441+
// out_pad_total(H/W) = output_padding(H/W) - 2*padding(H/W)
2442+
// Negative values are allowed and will be handled by the TOSA
2443+
// decomposition.
2444+
SmallVector<int64_t, 2> outPadding2D;
2445+
if (!matchPattern(adaptor.getOutputPadding(),
2446+
m_TorchListOfConstantInts(outPadding2D)))
2447+
return rewriter.notifyMatchFailure(
2448+
op, "non-const output_padding list unsupported for transposed conv");
2449+
2450+
int64_t outPadH = outPadding2D[0] - 2 * padding_2d[0];
2451+
int64_t outPadW = outPadding2D[1] - 2 * padding_2d[1];
2452+
int64_t outPadTop = outPadH / 2;
2453+
int64_t outPadBottom = outPadH - outPadTop;
2454+
int64_t outPadLeft = outPadW / 2;
2455+
int64_t outPadRight = outPadW - outPadLeft;
2456+
SmallVector<int64_t, 4> outPad(
2457+
{outPadTop, outPadBottom, outPadLeft, outPadRight});
2458+
2459+
// Result type is NHWC (we'll transpose back).
2460+
auto outNCHW = makeShapeTorchCompatible(outputTy.getShape());
2461+
SmallVector<int64_t, 4> outNHWC;
2462+
for (int32_t dim : nchwToNhwcDims)
2463+
outNHWC.push_back(outNCHW[dim]);
2464+
auto transConvOpTy =
2465+
RankedTensorType::get(makeShapeLLVMCompatible(outNHWC), biasElemTy);
2466+
2467+
// Zero-points.
2468+
auto zps = tosa::createZPsAsConst(rewriter, input, weight);
2469+
Value inputZp = zps.first ? zps.first
2470+
: tosa::createZeroPointTensor(
2471+
rewriter, op->getLoc(), inputElemTy, 0)
2472+
.value();
2473+
Value weightZp = zps.second ? zps.second
2474+
: tosa::createZeroPointTensor(
2475+
rewriter, op->getLoc(), weightElemTy, 0)
2476+
.value();
2477+
2478+
Value convTOut =
2479+
rewriter
2480+
.create<tosa::TransposeConv2DOp>(
2481+
op->getLoc(), getTypeConverter()->convertType(transConvOpTy),
2482+
transposedInput, transformedWeight, bias, inputZp, weightZp,
2483+
rewriter.getDenseI64ArrayAttr(outPad),
2484+
rewriter.getDenseI64ArrayAttr(stride), accType)
2485+
.getResult();
2486+
2487+
SmallVector<int64_t, 4> transposedOutputShape;
2488+
for (int32_t dim : nhwcToNchwDims)
2489+
transposedOutputShape.push_back(outNHWC[dim]);
2490+
auto transposedOutputType = RankedTensorType::get(
2491+
makeShapeLLVMCompatible(transposedOutputShape), biasElemTy);
2492+
Value transposedOutput =
2493+
rewriter
2494+
.create<tosa::TransposeOp>(
2495+
op->getLoc(),
2496+
getTypeConverter()->convertType(transposedOutputType), convTOut,
2497+
rewriter.getDenseI32ArrayAttr(nhwcToNchwDims))
2498+
.getResult();
2499+
2500+
// Quantized rescale.
2501+
Value rescaledResult = transposedOutput;
2502+
if (isa<quant::QuantizedType>(inputElemTy)) {
2503+
rescaledResult = tosa::buildRescaleOpConvOutput(
2504+
rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
2505+
}
2506+
2507+
// Final cast to requested output type.
2508+
rewriter.replaceOp(
2509+
op, {tosa::tosaCastTensorToType(rewriter, rescaledResult, outputTy)
2510+
.value()});
2511+
return success();
2512+
}
2513+
24062514
SmallVector<int64_t> transformedWeightShape;
24072515
RankedTensorType transformedWeightType;
24082516
Value transformedWeight;
@@ -2487,7 +2595,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24872595
if (remainderHDim != 0) {
24882596
if (remainderHDim > padding[1]) {
24892597
SmallVector<int64_t> startHSlice(inputTy.getRank(), 0);
2490-
SmallVector<int64_t> sizeHSlice(transposedInputShape);
2598+
SmallVector<int64_t, 4> sizeHSlice(transposedInputShape);
24912599
// TOSA uses NHWC, so we will slice dim 1 for Height value
24922600
sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]);
24932601
transposedInput = tosa::CreateOpAndInfer<tosa::SliceOp>(
@@ -2583,7 +2691,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
25832691
llvm_unreachable("Unhandled convolution type");
25842692
}
25852693

2586-
SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
25872694
SmallVector<int64_t> transposedOutputShape(
25882695
{outputShape[0], outputShape[3], outputShape[1], outputShape[2]});
25892696
auto transposedOutputType = RankedTensorType::get(

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3588,7 +3588,6 @@
35883588
"AvgPool3dCountIncludePadFalseWithoutPadding_basic",
35893589
"Conv_Transpose1dModule_basic",
35903590
"Conv_Transpose1dStaticModule_basic",
3591-
"Conv_Transpose2dStaticModule_basic",
35923591
"Conv_Transpose3dModule_basic",
35933592
"Conv_Transpose3dStaticModule_basic",
35943593
"IndexPutWithNoneAndBroadcastModule_basic",
@@ -3713,16 +3712,11 @@
37133712
"Conv3dWithValidPaddingModule_basic",
37143713
"ConvTbcModule_basic",
37153714
"ConvTranspose2DQInt8_basic",
3716-
"Conv_Transpose2dModule_basic",
37173715
"ConvolutionBackwardModule2DPadded_basic",
3718-
"ConvolutionBackwardModule2DStatic_basic",
37193716
"ConvolutionBackwardModule2DStrided_basic",
37203717
"ConvolutionBackwardModule2D_basic",
37213718
"ConvolutionModule2DGroups_basic",
37223719
"ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
3723-
"ConvolutionModule2DTransposeStridedStatic_basic",
3724-
"ConvolutionModule2DTransposeStrided_basic",
3725-
"ConvolutionModule2DTranspose_basic",
37263720
"ConvolutionModule2DGroupedTranspose_basic",
37273721
"ConvolutionModule3DGroups_basic",
37283722
"ConvolutionModule3DGroupsStrided_basic",

projects/pt1/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
# that depend on TOSA as well as TOSA-to-Standard.
3030
"tosa-to-arith",
3131
"tosa-to-scf",
32+
# Required for transposed convolution support (decomposes to conv ops).
33+
"tosa-optional-decompositions",
3234
# Named ops must be legalized prior to general tosa-to-linalg
3335
"tosa-to-linalg-named",
3436
# TOSA-to-LinAlg may generate tosa.const() ops, so we want to lower them
Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,23 @@
1-
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics
1+
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file | FileCheck %s
22

3-
// The following test ensures that a transposed convolution op is not
4-
// lowered in the torch-to-tosa conversion pass.
3+
// The lowering now legalizes transpose convolutions into the TOSA dialect.
4+
// Verify that we emit tosa.transpose_conv2d with the expected reshapes/
5+
// permutations.
56

7+
// CHECK-LABEL: func.func @forward
8+
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
9+
// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[1,64,1,100],f32> -> tensor<1x64x1x100xf32>
10+
// CHECK: %[[WEIGHT:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<64x64x3x3xf32>}> : () -> tensor<64x64x3x3xf32>
11+
// CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<64xf32>}> : () -> tensor<64xf32>
12+
// CHECK: %[[TRANS_IN:.*]] = tosa.transpose %[[IN_TENSOR]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x64x1x100xf32>) -> tensor<1x1x100x64xf32>
13+
// CHECK: %[[W_OHWI:.*]] = tosa.transpose %[[WEIGHT]] {perms = array<i32: 1, 2, 3, 0>} : (tensor<64x64x3x3xf32>) -> tensor<64x3x3x64xf32>
14+
// CHECK: %[[ZP0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
15+
// CHECK: %[[ZP1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
16+
// CHECK: %[[TCONV:.*]] = tosa.transpose_conv2d %[[TRANS_IN]], %[[W_OHWI]], %[[BIAS]], %[[ZP0]], %[[ZP1]] {acc_type = f32, out_pad = array<i64: 0, -1, 0, -1>, stride = array<i64: 2, 2>} : (tensor<1x1x100x64xf32>, tensor<64x3x3x64xf32>, tensor<64xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x2x200x64xf32>
17+
// CHECK: %[[TRANS_OUT:.*]] = tosa.transpose %[[TCONV]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x2x200x64xf32>) -> tensor<1x64x2x200xf32>
18+
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[TRANS_OUT]] : tensor<1x64x2x200xf32> -> !torch.vtensor<[1,64,2,200],f32>
19+
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,64,2,200],f32>
20+
// CHECK: }
621
func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
722
%true = torch.constant.bool true
823
%int1 = torch.constant.int 1
@@ -11,7 +26,6 @@ func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[
1126
%bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
1227
%stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
1328
%int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
14-
// expected-error@+1 {{failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal}}
1529
%output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,64,2,200],f32>
1630
return %output : !torch.vtensor<[1,64,2,200],f32>
1731
}

0 commit comments

Comments
 (0)