@@ -2304,9 +2304,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
     return rewriter.notifyMatchFailure(
         op, "Unimplemented: non-constant value for transposed not supported");
-  if (transposed)
-    return rewriter.notifyMatchFailure(
-        op, "Unimplemented: transposed convolution not supported");
 
   auto input = adaptor.getInput();
   auto weight = adaptor.getWeight();
@@ -2338,12 +2335,17 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   auto bias = adaptor.getBias();
 
   if (isa<Torch::NoneType>(bias.getType())) {
-    auto bias_result = tosa::getConvBiasForNoneType(op, rewriter, inputElemTy,
-                                                    outputElemTy, weightShape);
-    if (failed(bias_result))
+    SmallVector<int64_t, 4> biasWeightShape =
+        transposed ? SmallVector<int64_t, 4>{weightShape[1], weightShape[0],
+                                             weightShape[2], weightShape[3]}
+                   : weightShape;
+
+    auto biasResult = tosa::getConvBiasForNoneType(
+        op, rewriter, inputElemTy, outputElemTy, biasWeightShape);
+    if (failed(biasResult))
       return rewriter.notifyMatchFailure(
           op, "Failed to create bias tensor for none type.");
-    bias = bias_result.value();
+    bias = biasResult.value();
   } else {
     if (!isa<RankedTensorType>(bias.getType()))
       return rewriter.notifyMatchFailure(
@@ -2370,8 +2372,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
                     m_TorchListOfConstantInts(padding_2d)))
     return rewriter.notifyMatchFailure(op,
                                        "non-const padding list unsupported");
-  // TOSA uses 4D padding {top, bottom, left, right} while Torch defines 2D
-  // padding {height, width}. The Torch OFM computation uses 2*pad in each
+  // TOSA uses 4D padding {top, bottom, left, right} while PyTorch defines 2D
+  // padding {height, width}. The PyTorch OFM computation uses 2*pad in each
   // spatial direction, implying the same top=bottom=height and left=right=width
   // values for TOSA.
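+  // For example, PyTorch padding {1, 2} ({height, width}) becomes TOSA
+  // padding {1, 1, 2, 2} ({top, bottom, left, right}).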
   SmallVector<int64_t> padding(
@@ -2388,11 +2390,19 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "failed to get accumulator type for convolution ops");
 
+  // Weight layout reference:
+  //   Conv       : PyTorch OIHW     -> TOSA OHWI
+  //   Depthwise  : PyTorch OIHW*    -> TOSA HWIM
+  //                (PyTorch depthwise uses out_ch = in_ch * depth_multiplier)
+  //   Grouped    : PyTorch O(I/G)HW -> N/A
+  //   Transposed : PyTorch IOHW     -> TOSA OHWI
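+  // e.g. a regular conv weight of shape {8, 4, 3, 3} (OIHW) is carried to
+  // TOSA as {8, 3, 3, 4} (OHWI).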
   // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
   // Perform the necessary transformations.
   SmallVector<int32_t> nchwToNhwcDims({0, 2, 3, 1});
-  SmallVector<int64_t> transposedInputShape(
-      {inputShape[0], inputShape[2], inputShape[3], inputShape[1]});
+  SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
+  SmallVector<int64_t, 4> transposedInputShape;
+  for (int32_t dim : nchwToNhwcDims)
+    transposedInputShape.push_back(inputShape[dim]);
   auto transposedInputType = RankedTensorType::get(
       makeShapeLLVMCompatible(transposedInputShape), inputElemTy);
   auto transposedInput =
@@ -2403,6 +2413,104 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
               rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
           .getResult();
 
+  if (transposed) {
+    if (groups != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: grouped transposed convolution not supported by "
+              "TOSA");
+    if (dilation[0] != 1 || dilation[1] != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: dilated transposed convolution not supported by "
+              "TOSA");
+
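+    // PyTorch stores transposed-convolution weights as IOHW; TOSA's
+    // tosa.transpose_conv2d expects OHWI, so permute with {1, 2, 3, 0}.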
+    SmallVector<int32_t> iohwToOhwi({1, 2, 3, 0});
+    SmallVector<int64_t, 4> ohwiWeightShape;
+    for (int32_t dim : iohwToOhwi)
+      ohwiWeightShape.push_back(weightShape[dim]);
+    auto ohwiWeightType = RankedTensorType::get(
+        makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
+    Value transformedWeight =
+        rewriter
+            .create<tosa::TransposeOp>(
+                op->getLoc(), getTypeConverter()->convertType(ohwiWeightType),
+                weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi))
+            .getResult();
+
+    // TOSA 'out_pad' is a 4D array {top, bottom, left, right}.
+    // Map from PyTorch's (padding, output_padding):
+    //   out_pad_total(H/W) = output_padding(H/W) - 2 * padding(H/W)
+    // Negative values are allowed and will be handled by the TOSA
+    // decomposition.
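+    // For example, padding {1, 1} with output_padding {0, 0} gives
+    // out_pad_total = -2 per spatial dim, i.e. out_pad {-1, -1, -1, -1},
+    // trimming one row/column from each edge of the output.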
+    SmallVector<int64_t, 2> outPadding2D;
+    if (!matchPattern(adaptor.getOutputPadding(),
+                      m_TorchListOfConstantInts(outPadding2D)))
+      return rewriter.notifyMatchFailure(
+          op, "non-const output_padding list unsupported for transposed conv");
+
+    int64_t outPadH = outPadding2D[0] - 2 * padding_2d[0];
+    int64_t outPadW = outPadding2D[1] - 2 * padding_2d[1];
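+    // Split each total as evenly as possible between the two opposite edges;
+    // any odd remainder goes to the bottom/right edge.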
+    int64_t outPadTop = outPadH / 2;
+    int64_t outPadBottom = outPadH - outPadTop;
+    int64_t outPadLeft = outPadW / 2;
+    int64_t outPadRight = outPadW - outPadLeft;
+    SmallVector<int64_t, 4> outPad(
+        {outPadTop, outPadBottom, outPadLeft, outPadRight});
+
+    // Result type is NHWC (we'll transpose back).
+    auto outNCHW = makeShapeTorchCompatible(outputTy.getShape());
+    SmallVector<int64_t, 4> outNHWC;
+    for (int32_t dim : nchwToNhwcDims)
+      outNHWC.push_back(outNCHW[dim]);
+    auto transConvOpTy =
+        RankedTensorType::get(makeShapeLLVMCompatible(outNHWC), biasElemTy);
+
+    // Zero-points: reuse the operands' quantization zero-points when present,
+    // otherwise fall back to constant zero tensors.
+    auto zps = tosa::createZPsAsConst(rewriter, input, weight);
+    Value inputZp = zps.first ? zps.first
+                              : tosa::createZeroPointTensor(
+                                    rewriter, op->getLoc(), inputElemTy, 0)
+                                    .value();
+    Value weightZp = zps.second ? zps.second
+                                : tosa::createZeroPointTensor(
+                                      rewriter, op->getLoc(), weightElemTy, 0)
+                                      .value();
+
+    Value convTOut =
+        rewriter
+            .create<tosa::TransposeConv2DOp>(
+                op->getLoc(), getTypeConverter()->convertType(transConvOpTy),
+                transposedInput, transformedWeight, bias, inputZp, weightZp,
+                rewriter.getDenseI64ArrayAttr(outPad),
+                rewriter.getDenseI64ArrayAttr(stride), accType)
+            .getResult();
+
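+    // Transpose the NHWC result back to the NCHW layout PyTorch expects.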
+    SmallVector<int64_t, 4> transposedOutputShape;
+    for (int32_t dim : nhwcToNchwDims)
+      transposedOutputShape.push_back(outNHWC[dim]);
+    auto transposedOutputType = RankedTensorType::get(
+        makeShapeLLVMCompatible(transposedOutputShape), biasElemTy);
+    Value transposedOutput =
+        rewriter
+            .create<tosa::TransposeOp>(
+                op->getLoc(),
+                getTypeConverter()->convertType(transposedOutputType), convTOut,
+                rewriter.getDenseI32ArrayAttr(nhwcToNchwDims))
+            .getResult();
+
+    // For quantized inputs, rescale the convolution output back to the
+    // expected quantized output type.
+    Value rescaledResult = transposedOutput;
+    if (isa<quant::QuantizedType>(inputElemTy)) {
+      rescaledResult = tosa::buildRescaleOpConvOutput(
+          rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
+    }
+
+    // Final cast to the requested output type.
+    rewriter.replaceOp(
+        op, {tosa::tosaCastTensorToType(rewriter, rescaledResult, outputTy)
+                 .value()});
+    return success();
+  }
+
   SmallVector<int64_t> transformedWeightShape;
   RankedTensorType transformedWeightType;
   Value transformedWeight;
@@ -2487,7 +2595,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   if (remainderHDim != 0) {
     if (remainderHDim > padding[1]) {
       SmallVector<int64_t> startHSlice(inputTy.getRank(), 0);
-      SmallVector<int64_t> sizeHSlice(transposedInputShape);
+      SmallVector<int64_t, 4> sizeHSlice(transposedInputShape);
       // TOSA uses NHWC, so we will slice dim 1 for the Height value
       sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]);
       transposedInput = tosa::CreateOpAndInfer<tosa::SliceOp>(
@@ -2583,7 +2691,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     llvm_unreachable("Unhandled convolution type");
   }
 
-  SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
   SmallVector<int64_t> transposedOutputShape(
       {outputShape[0], outputShape[3], outputShape[1], outputShape[2]});
   auto transposedOutputType = RankedTensorType::get(