|
45 | 45 | #include "llvm/ADT/ScopeExit.h"
|
46 | 46 | #include "llvm/ADT/TypeSwitch.h"
|
47 | 47 | #include "llvm/Support/Debug.h"
|
| 48 | +#include "llvm/Support/LogicalResult.h" |
48 | 49 | #include <type_traits>
|
49 | 50 |
|
50 | 51 | using namespace mlir;
|
@@ -2155,6 +2156,166 @@ LogicalResult transform::PadOp::verify() {
|
2155 | 2156 | return success();
|
2156 | 2157 | }
|
2157 | 2158 |
|
| 2159 | +//===---------------------------------------------------------------------===// |
| 2160 | +// PadTilingInterfaceOp |
| 2161 | +//===---------------------------------------------------------------------===// |
| 2162 | + |
| 2163 | +void transform::PadTilingInterfaceOp::build(OpBuilder &b, |
| 2164 | + OperationState &result, |
| 2165 | + Value target, |
| 2166 | + ArrayRef<int64_t> paddingDimensions, |
| 2167 | + ArrayRef<int64_t> paddingSizes, |
| 2168 | + bool padToMultipleOf) { |
| 2169 | + auto resultType = transform::AnyOpType::get(b.getContext()); |
| 2170 | + return build(/*builder=*/b, |
| 2171 | + /*result=*/result, |
| 2172 | + /*types=*/TypeRange{resultType, resultType}, |
| 2173 | + /*target=*/target, |
| 2174 | + /*paddingValues=*/ArrayAttr(), // let inference handle this |
| 2175 | + /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions), |
| 2176 | + /*paddingSizes=*/ValueRange{}, |
| 2177 | +               /*staticPaddingSizes=*/ |
| 2178 | + (paddingSizes.empty() ? DenseI64ArrayAttr() |
| 2179 | + : b.getDenseI64ArrayAttr(paddingSizes)), |
| 2180 | + /*padToMultipleOf=*/ |
| 2181 | + padToMultipleOf ? b.getUnitAttr() : nullptr); |
| 2182 | +} |
| 2183 | + |
| 2184 | +void transform::PadTilingInterfaceOp::build( |
| 2185 | + OpBuilder &b, OperationState &result, Value target, |
| 2186 | + ArrayRef<int64_t> paddingDimensions, |
| 2187 | + ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) { |
| 2188 | + auto resultType = transform::AnyOpType::get(b.getContext()); |
| 2189 | + SmallVector<int64_t> staticPaddingSizes; |
| 2190 | + SmallVector<Value> dynamicPaddingSizes; |
| 2191 | + dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes, |
| 2192 | + staticPaddingSizes); |
| 2193 | + return build(/*builder=*/b, |
| 2194 | + /*result=*/result, |
| 2195 | + /*types=*/TypeRange{resultType, resultType}, |
| 2196 | + /*target=*/target, |
| 2197 | + /*paddingValues=*/ArrayAttr(), // let inference handle this |
| 2198 | + /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions), |
| 2199 | + /*paddingSizes=*/dynamicPaddingSizes, |
| 2200 | +               /*staticPaddingSizes=*/staticPaddingSizes, |
| 2201 | +               /*padToMultipleOf=*/padToMultipleOf); |
| 2202 | +} |
| 2203 | + |
| 2204 | +void transform::PadTilingInterfaceOp::getEffects( |
| 2205 | + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 2206 | + consumesHandle(getTargetMutable(), effects); |
| 2207 | + onlyReadsHandle(getPaddingSizesMutable(), effects); |
| 2208 | + producesHandle(getOperation()->getOpResults(), effects); |
| 2209 | + modifiesPayload(effects); |
| 2210 | +} |
| 2211 | + |
| 2212 | +SmallVector<OpFoldResult> |
| 2213 | +transform::PadTilingInterfaceOp::getMixedPaddingSizes() { |
| 2214 | + Builder b(getContext()); |
| 2215 | + return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b); |
| 2216 | +} |
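For context on the static/dynamic round-trip used above: `dispatchIndexOpFoldResults` splits a list of `OpFoldResult`s into constant integers and SSA values (recording `ShapedType::kDynamic` as the placeholder in the static list), and `getMixedValues` reassembles them. A minimal sketch, assuming an `OpBuilder b` with a valid insertion point and a `Location loc` (both illustrative):

```cpp
// Sketch only: the split/merge helpers used by the builders above.
SmallVector<OpFoldResult> mixed = {
    b.getIndexAttr(8),                                     // static size
    b.create<arith::ConstantIndexOp>(loc, 16).getResult()  // dynamic size
};

SmallVector<int64_t> staticSizes;
SmallVector<Value> dynamicSizes;
dispatchIndexOpFoldResults(mixed, dynamicSizes, staticSizes);
// staticSizes == {8, ShapedType::kDynamic}; dynamicSizes holds one Value.
// Note: Values go to the dynamic list even when they wrap constants.

// getMixedValues inverts the split, recovering an equivalent mixed list.
SmallVector<OpFoldResult> roundTripped =
    getMixedValues(staticSizes, dynamicSizes, b);
```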
| 2217 | + |
| 2218 | +DiagnosedSilenceableFailure |
| 2219 | +transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter, |
| 2220 | + transform::TransformResults &results, |
| 2221 | + transform::TransformState &state) { |
| 2222 | + SmallVector<Operation *> paddedOps, padOps; |
| 2223 | + |
| 2224 | + for (Operation *target : state.getPayloadOps(getTarget())) { |
| 2225 | + auto targetOp = dyn_cast<TilingInterface>(target); |
| 2226 | + if (!targetOp) { |
| 2227 | + auto diag = emitSilenceableError() << "expected TilingInterface target"; |
| 2228 | + diag.attachNote(target->getLoc()) << "target op"; |
| 2229 | + return diag; |
| 2230 | + } |
| 2231 | + |
| 2232 | + // Only Linalg ops for now, until TilingInterface exposes a loopsToOperand |
| 2233 | + // map / C++ APIs to compute the effect of padding on operands. |
| 2234 | + if (!isa<LinalgOp>(targetOp.getOperation())) { |
| 2235 | +       auto diag = emitSilenceableError() << "only LinalgOp is currently supported"; |
| 2236 | + diag.attachNote(target->getLoc()) << "target op"; |
| 2237 | + return diag; |
| 2238 | + } |
| 2239 | + |
| 2240 | + // Convert the padding values to attributes. |
| 2241 | + SmallVector<Attribute> paddingValues; |
| 2242 | +     for (const auto &[untypedAttr, elementOrTensorType] : |
| 2243 | + llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) { |
| 2244 | + auto attr = dyn_cast<TypedAttr>(untypedAttr); |
| 2245 | + Type elementType = getElementTypeOrSelf(elementOrTensorType); |
| 2246 | + if (!attr) { |
| 2247 | + emitOpError("expects padding values to be typed attributes"); |
| 2248 | + return DiagnosedSilenceableFailure::definiteFailure(); |
| 2249 | + } |
| 2250 | + // Try to parse string attributes to obtain an attribute of element type. |
| 2251 | + if (auto stringAttr = dyn_cast<StringAttr>(attr)) { |
| 2252 | + auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute( |
| 2253 | + stringAttr, getContext(), elementType, |
| 2254 | + /*numRead=*/nullptr, /*isKnownNullTerminated=*/true)); |
| 2255 | + if (!parsedAttr || parsedAttr.getType() != elementType) { |
| 2256 | +           auto diag = this->emitOpError("expects a padding value that parses to ") |
| 2257 | + << elementType << ", got " << attr; |
| 2258 | + diag.attachNote(targetOp.getLoc()) << "when applied to this op"; |
| 2259 | + return DiagnosedSilenceableFailure::definiteFailure(); |
| 2260 | + } |
| 2261 | + paddingValues.push_back(parsedAttr); |
| 2262 | + continue; |
| 2263 | + } |
| 2264 | + // Otherwise, add the attribute directly. |
| 2265 | + if (attr.getType() != elementType) { |
| 2266 | + auto diag = this->emitOpError("expects a padding value of type ") |
| 2267 | + << elementType << ", got " << attr; |
| 2268 | + diag.attachNote(targetOp.getLoc()) << "when applied to this op"; |
| 2269 | + return DiagnosedSilenceableFailure::definiteFailure(); |
| 2270 | + } |
| 2271 | + paddingValues.push_back(attr); |
| 2272 | + } |
| 2273 | + |
| 2274 | + // Set options. |
| 2275 | + TilingInterface paddedOp; |
| 2276 | + PadTilingInterfaceOptions options; |
| 2277 | + options.setPaddingValues(paddingValues) |
| 2278 | + .setPaddingDimensions( |
| 2279 | + extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions())) |
| 2280 | + .setPaddingSizes(getMixedPaddingSizes()) |
| 2281 | + .setPadToMultipleOf(getPadToMultipleOf()); |
| 2282 | + |
| 2283 | + // Apply padding. |
| 2284 | + SmallVector<tensor::PadOp> newPadOps; |
| 2285 | + FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp( |
| 2286 | +         rewriter, targetOp, options, |
| 2287 | + newPadOps); |
| 2288 | + if (failed(maybePaddedOp)) { |
| 2289 | + auto diag = emitSilenceableError() << "failed to pad op"; |
| 2290 | + diag.attachNote(target->getLoc()) << "target op"; |
| 2291 | + return diag; |
| 2292 | + } |
| 2293 | + |
| 2294 | + // Set transform results. |
| 2295 | +     paddedOps.push_back(maybePaddedOp->getOperation()); |
| 2296 | + padOps.append(newPadOps.begin(), newPadOps.end()); |
| 2297 | + } |
| 2298 | + |
| 2299 | + results.set(cast<OpResult>(getPadded()), paddedOps); |
| 2300 | + results.set(cast<OpResult>(getPad()), padOps); |
| 2301 | + return DiagnosedSilenceableFailure::success(); |
| 2302 | +} |
| 2303 | + |
| 2304 | +LogicalResult transform::PadTilingInterfaceOp::verify() { |
| 2305 | + SmallVector<int64_t> paddingDimensions = |
| 2306 | + extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions()); |
| 2307 | + if (any_of(paddingDimensions, |
| 2308 | + [](int64_t paddingDimension) { return paddingDimension < 0; })) { |
| 2309 | +     return emitOpError() << "expects padding_dimensions to contain " |
| 2310 | +                             "non-negative integers, found " |
| 2311 | + << getPaddingDimensions(); |
| 2312 | + } |
| 2313 | + if (getMixedPaddingSizes().size() != paddingDimensions.size()) { |
| 2314 | +     return emitOpError() << "expects as many padding sizes as padding_dimensions"; |
| 2315 | + } |
| 2316 | + return success(); |
| 2317 | +} |
| 2318 | + |
2158 | 2319 | //===---------------------------------------------------------------------===//
|
2159 | 2320 | // HoistPadOp
|
2160 | 2321 | //===---------------------------------------------------------------------===//
|
|