-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][bufferization] Return BufferLikeType in BufferizableOpInterface #144867
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Support custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize into custom types. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType. Affected implementors of the interface are update accordingly.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-arith Author: Andrei Golubev (andrey-golubev) ChangesSupport custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize custom tensors into custom buffers. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType. Affected implementors of the interface are updated accordingly. Patch is 32.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144867.diff 14 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c1529a36465ac..6245f88db3d19 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// This is the default implementation of
/// BufferizableOpInterface::getBufferType. Should not be called from other
/// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index cafe05fe5f189..246ae77f327cf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
Note: This interface method should never be called directly from user
code. Always use `bufferization::getBufferType`.
}],
- /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+ /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 32c53ea9c494a..f175b15c8770f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -111,7 +111,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
AliasingValueList getAliasingValues(
OpOperand &opOperand, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);
@@ -478,10 +478,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
bool isWritable(Value value, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state, SmallVector<Value> &invocationStack) {
- return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
+ return getBuffer().getType();
}
}];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index cbb6054fcf886..da7fee4b4a220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Types.h"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index f56c10555f02c..e8a81c74bd77a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
if (!bufferType)
return op->emitOpError("could not infer buffer type of block argument");
- return bufferType;
+ return cast<BufferLikeType>(bufferType);
}
protected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 85d1b5ac73bf4..afee162053bea 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -181,7 +181,7 @@ struct SelectOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -196,17 +196,17 @@ struct SelectOpInterface
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
- return *trueType;
+ return cast<BufferLikeType>(*trueType);
if (trueType->getMemorySpace() != falseType->getMemorySpace())
return op->emitError("inconsistent memory space on true/false operands");
// If the buffers have different types, they differ only in their layout
// map.
auto memrefType = llvm::cast<MemRefType>(*trueType);
- return getMemRefTypeWithFullyDynamicLayout(
+ return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
RankedTensorType::get(memrefType.getShape(),
memrefType.getElementType()),
- memrefType.getMemorySpace());
+ memrefType.getMemorySpace()));
}
};
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 2ab182c9b7b2e..55784ac20d353 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -945,7 +945,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
return AliasingOpOperandList(std::move(result));
}
-FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &bufferizationState,
SmallVector<Value> &invocationStack) {
@@ -953,8 +953,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
auto tensorType = cast<TensorType>(value.getType());
// No further analysis is possible for a block argument.
- if (llvm::isa<BlockArgument>(value))
- return bufferization::getMemRefType(tensorType, options);
+ if (llvm::isa<BlockArgument>(value)) {
+ return cast<BufferLikeType>(
+ bufferization::getMemRefType(tensorType, options));
+ }
// Value is an OpResult.
Operation *op = getOwnerOfValue(value);
@@ -966,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// If the OpResult has an equivalent OpOperand, both OpResult and
// OpOperand bufferize to the exact same buffer type.
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
- return asMemRefType(getBufferType(equivalentOperand, options,
- bufferizationState, invocationStack));
+ return getBufferType(equivalentOperand, options, bufferizationState,
+ invocationStack);
}
// If we do not know the memory space and there is no default memory space,
@@ -977,7 +979,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
+ return cast<BufferLikeType>(
+ getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
}
bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 9bd87d66c7d36..66949c96798de 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
return {};
}
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) {
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
return getOperation()->emitError("could not infer memory space");
}
- return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+ return cast<BufferLikeType>(
+ getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
}
LogicalResult AllocTensorOp::verify() {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 453ed43bcadd2..bd2aebca68079 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -211,7 +211,7 @@ struct CallOpInterface
return result;
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -229,12 +229,13 @@ struct CallOpInterface
Type resultType =
funcType.getResult(cast<OpResult>(value).getResultNumber());
if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
- return bufferizedType;
+ return cast<BufferLikeType>(bufferizedType);
// Otherwise, call the type converter to compute the bufferized type.
auto tensorType = cast<TensorType>(resultType);
- return options.functionArgTypeConverterFn(
- tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+ return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+ tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+ options));
}
/// All function arguments are writable. It is the responsibility of the
@@ -396,7 +397,7 @@ struct FuncOpInterface
return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -405,8 +406,8 @@ struct FuncOpInterface
// Function arguments are special.
if (bbArg.getOwner() == &funcOp.getBody().front())
- return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
- options);
+ return cast<BufferLikeType>(
+ getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
getBufferType(op, value, options, state, invocationStack);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 58562536be61f..d36d91249ed36 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -274,7 +274,7 @@ struct IfOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -313,15 +313,15 @@ struct IfOpInterface
// Best case: Both branches have the exact same buffer type.
if (thenBufferType == elseBufferType)
- return thenBufferType;
+ return cast<BufferLikeType>(thenBufferType);
// Memory space mismatch.
if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
return op->emitError("inconsistent memory space on then/else branches");
// Layout maps are different: Promote to fully dynamic layout map.
- return getMemRefTypeWithFullyDynamicLayout(
- cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
+ return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+ cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
}
};
@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -436,7 +436,7 @@ struct IndexSwitchOpInterface
cast<TensorType>(value.getType()), bufferType.getMemorySpace());
}
- return bufferType;
+ return cast<BufferLikeType>(bufferType);
}
};
@@ -522,13 +522,13 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
/// If both buffer types are equal, no casts are needed the computed buffer type
/// can be used directly. Otherwise, the buffer types can only differ in their
/// layout map and a cast must be inserted.
-static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
const BufferizationOptions &options, const BufferizationState &state,
SmallVector<Value> &invocationStack) {
// Determine the buffer type of the init_arg.
- auto initArgBufferType = bufferization::detail::asMemRefType(
- bufferization::getBufferType(initArg, options, state, invocationStack));
+ auto initArgBufferType =
+ bufferization::getBufferType(initArg, options, state, invocationStack);
if (failed(initArgBufferType))
return failure();
@@ -547,16 +547,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
}
// Compute the buffer type of the yielded value.
- BaseMemRefType yieldedValueBufferType;
+ BufferLikeType yieldedValueBufferType;
if (isa<BaseMemRefType>(yieldedValue.getType())) {
// scf.yield was already bufferized.
- yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
+ yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
} else {
// Note: This typically triggers a recursive call for the buffer type of
// the iter_arg.
- auto maybeBufferType =
- bufferization::detail::asMemRefType(bufferization::getBufferType(
- yieldedValue, options, state, invocationStack));
+ auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
+ state, invocationStack);
if (failed(maybeBufferType))
return failure();
yieldedValueBufferType = *maybeBufferType;
@@ -584,8 +583,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
"expected same shape");
}
#endif // NDEBUG
- return getMemRefTypeWithFullyDynamicLayout(
- iterTensorType, yieldedBufferType.getMemorySpace());
+ return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+ iterTensorType, yieldedBufferType.getMemorySpace()));
}
/// Return `true` if the given loop may have 0 iterations.
@@ -708,7 +707,7 @@ struct ForOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -719,12 +718,8 @@ struct ForOpInterface
if (auto opResult = dyn_cast<OpResult>(value)) {
// The type of an OpResult must match the corresponding iter_arg type.
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
- auto bufferType =
- bufferization::getBufferType(bbArg, options, state, invocationStack);
- if (failed(bufferType))
- return failure();
- assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
- return cast<BaseMemRefType>(*bufferType);
+ return bufferization::getBufferType(bbArg, options, state,
+ invocationStack);
}
// Compute result/argument number.
@@ -1047,7 +1042,7 @@ struct WhileOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -1081,10 +1076,10 @@ struct WhileOpInterface
Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
if (!isa<TensorType>(conditionYieldedVal.getType())) {
// scf.condition was already bufferized.
- return cast<BaseMemRefType>(conditionYieldedVal.getType());
+ return cast<BufferLikeType>(conditionYieldedVal.getType());
}
- return bufferization::detail::asMemRefType(bufferization::getBufferType(
- conditionYieldedVal, options, state, invocationStack));
+ return bufferization::getBufferType(conditionYieldedVal, options, state,
+ invocationStack);
}
/// Assert that yielded values of an scf.while op are equivalent to their
@@ -1303,7 +1298,7 @@ struct ForallOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -1312,15 +1307,15 @@ struct ForallOpInterface
if (auto bbArg = dyn_cast<BlockArgument>(value))
// A tensor block argument has the same bufferized type as the
// corresponding output operand.
- return bufferization::detail::asMemRefType(
- bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
- options, state, invocationStack));
+ return bufferization::getBufferType(
+ forallOp.getTiedOpOperand(bbArg)->get(), options, state,
+ invocationStack);
// The bufferized result type is the same as the bufferized type of the
// corresponding output operand.
- return bufferization::detail::asMemRefType(bufferization::getBufferType(
+ return bufferization::getBufferType(
forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
- state, invocationStack));
+ state, invocationStack);
}
bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 729c048db4560..829b2ab92ac24 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -49,7 +49,7 @@ struct CastOpInterface
return {{op->getResult(0), BufferRelation::Equivalent}};
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Opera...
[truncated]
|
@llvm/pr-subscribers-mlir-tensor Author: Andrei Golubev (andrey-golubev) ChangesSupport custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize custom tensors into custom buffers. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType. Affected implementors of the interface are updated accordingly. Patch is 32.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144867.diff 14 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c1529a36465ac..6245f88db3d19 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// This is the default implementation of
/// BufferizableOpInterface::getBufferType. Should not be called from other
/// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index cafe05fe5f189..246ae77f327cf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
Note: This interface method should never be called directly from user
code. Always use `bufferization::getBufferType`.
}],
- /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+ /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 32c53ea9c494a..f175b15c8770f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -111,7 +111,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
AliasingValueList getAliasingValues(
OpOperand &opOperand, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);
@@ -478,10 +478,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
bool isWritable(Value value, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state, SmallVector<Value> &invocationStack) {
- return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
+ return getBuffer().getType();
}
}];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index cbb6054fcf886..da7fee4b4a220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Types.h"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index f56c10555f02c..e8a81c74bd77a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
if (!bufferType)
return op->emitOpError("could not infer buffer type of block argument");
- return bufferType;
+ return cast<BufferLikeType>(bufferType);
}
protected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 85d1b5ac73bf4..afee162053bea 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -181,7 +181,7 @@ struct SelectOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -196,17 +196,17 @@ struct SelectOpInterface
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
- return *trueType;
+ return cast<BufferLikeType>(*trueType);
if (trueType->getMemorySpace() != falseType->getMemorySpace())
return op->emitError("inconsistent memory space on true/false operands");
// If the buffers have different types, they differ only in their layout
// map.
auto memrefType = llvm::cast<MemRefType>(*trueType);
- return getMemRefTypeWithFullyDynamicLayout(
+ return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
RankedTensorType::get(memrefType.getShape(),
memrefType.getElementType()),
- memrefType.getMemorySpace());
+ memrefType.getMemorySpace()));
}
};
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 2ab182c9b7b2e..55784ac20d353 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -945,7 +945,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
return AliasingOpOperandList(std::move(result));
}
-FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &bufferizationState,
SmallVector<Value> &invocationStack) {
@@ -953,8 +953,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
auto tensorType = cast<TensorType>(value.getType());
// No further analysis is possible for a block argument.
- if (llvm::isa<BlockArgument>(value))
- return bufferization::getMemRefType(tensorType, options);
+ if (llvm::isa<BlockArgument>(value)) {
+ return cast<BufferLikeType>(
+ bufferization::getMemRefType(tensorType, options));
+ }
// Value is an OpResult.
Operation *op = getOwnerOfValue(value);
@@ -966,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// If the OpResult has an equivalent OpOperand, both OpResult and
// OpOperand bufferize to the exact same buffer type.
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
- return asMemRefType(getBufferType(equivalentOperand, options,
- bufferizationState, invocationStack));
+ return getBufferType(equivalentOperand, options, bufferizationState,
+ invocationStack);
}
// If we do not know the memory space and there is no default memory space,
@@ -977,7 +979,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
+ return cast<BufferLikeType>(
+ getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
}
bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 9bd87d66c7d36..66949c96798de 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
return {};
}
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) {
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
return getOperation()->emitError("could not infer memory space");
}
- return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+ return cast<BufferLikeType>(
+ getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
}
LogicalResult AllocTensorOp::verify() {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 453ed43bcadd2..bd2aebca68079 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -211,7 +211,7 @@ struct CallOpInterface
return result;
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -229,12 +229,13 @@ struct CallOpInterface
Type resultType =
funcType.getResult(cast<OpResult>(value).getResultNumber());
if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
- return bufferizedType;
+ return cast<BufferLikeType>(bufferizedType);
// Otherwise, call the type converter to compute the bufferized type.
auto tensorType = cast<TensorType>(resultType);
- return options.functionArgTypeConverterFn(
- tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+ return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+ tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+ options));
}
/// All function arguments are writable. It is the responsibility of the
@@ -396,7 +397,7 @@ struct FuncOpInterface
return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -405,8 +406,8 @@ struct FuncOpInterface
// Function arguments are special.
if (bbArg.getOwner() == &funcOp.getBody().front())
- return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
- options);
+ return cast<BufferLikeType>(
+ getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
getBufferType(op, value, options, state, invocationStack);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 58562536be61f..d36d91249ed36 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -274,7 +274,7 @@ struct IfOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -313,15 +313,15 @@ struct IfOpInterface
// Best case: Both branches have the exact same buffer type.
if (thenBufferType == elseBufferType)
- return thenBufferType;
+ return cast<BufferLikeType>(thenBufferType);
// Memory space mismatch.
if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
return op->emitError("inconsistent memory space on then/else branches");
// Layout maps are different: Promote to fully dynamic layout map.
- return getMemRefTypeWithFullyDynamicLayout(
- cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
+ return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+ cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
}
};
@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -436,7 +436,7 @@ struct IndexSwitchOpInterface
cast<TensorType>(value.getType()), bufferType.getMemorySpace());
}
- return bufferType;
+ return cast<BufferLikeType>(bufferType);
}
};
@@ -522,13 +522,13 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
/// If both buffer types are equal, no casts are needed the computed buffer type
/// can be used directly. Otherwise, the buffer types can only differ in their
/// layout map and a cast must be inserted.
-static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
const BufferizationOptions &options, const BufferizationState &state,
SmallVector<Value> &invocationStack) {
// Determine the buffer type of the init_arg.
- auto initArgBufferType = bufferization::detail::asMemRefType(
- bufferization::getBufferType(initArg, options, state, invocationStack));
+ auto initArgBufferType =
+ bufferization::getBufferType(initArg, options, state, invocationStack);
if (failed(initArgBufferType))
return failure();
@@ -547,16 +547,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
}
// Compute the buffer type of the yielded value.
- BaseMemRefType yieldedValueBufferType;
+ BufferLikeType yieldedValueBufferType;
if (isa<BaseMemRefType>(yieldedValue.getType())) {
// scf.yield was already bufferized.
- yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
+ yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
} else {
// Note: This typically triggers a recursive call for the buffer type of
// the iter_arg.
- auto maybeBufferType =
- bufferization::detail::asMemRefType(bufferization::getBufferType(
- yieldedValue, options, state, invocationStack));
+ auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
+ state, invocationStack);
if (failed(maybeBufferType))
return failure();
yieldedValueBufferType = *maybeBufferType;
@@ -584,8 +583,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
"expected same shape");
}
#endif // NDEBUG
- return getMemRefTypeWithFullyDynamicLayout(
- iterTensorType, yieldedBufferType.getMemorySpace());
+ return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+ iterTensorType, yieldedBufferType.getMemorySpace()));
}
/// Return `true` if the given loop may have 0 iterations.
@@ -708,7 +707,7 @@ struct ForOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -719,12 +718,8 @@ struct ForOpInterface
if (auto opResult = dyn_cast<OpResult>(value)) {
// The type of an OpResult must match the corresponding iter_arg type.
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
- auto bufferType =
- bufferization::getBufferType(bbArg, options, state, invocationStack);
- if (failed(bufferType))
- return failure();
- assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
- return cast<BaseMemRefType>(*bufferType);
+ return bufferization::getBufferType(bbArg, options, state,
+ invocationStack);
}
// Compute result/argument number.
@@ -1047,7 +1042,7 @@ struct WhileOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -1081,10 +1076,10 @@ struct WhileOpInterface
Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
if (!isa<TensorType>(conditionYieldedVal.getType())) {
// scf.condition was already bufferized.
- return cast<BaseMemRefType>(conditionYieldedVal.getType());
+ return cast<BufferLikeType>(conditionYieldedVal.getType());
}
- return bufferization::detail::asMemRefType(bufferization::getBufferType(
- conditionYieldedVal, options, state, invocationStack));
+ return bufferization::getBufferType(conditionYieldedVal, options, state,
+ invocationStack);
}
/// Assert that yielded values of an scf.while op are equivalent to their
@@ -1303,7 +1298,7 @@ struct ForallOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -1312,15 +1307,15 @@ struct ForallOpInterface
if (auto bbArg = dyn_cast<BlockArgument>(value))
// A tensor block argument has the same bufferized type as the
// corresponding output operand.
- return bufferization::detail::asMemRefType(
- bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
- options, state, invocationStack));
+ return bufferization::getBufferType(
+ forallOp.getTiedOpOperand(bbArg)->get(), options, state,
+ invocationStack);
// The bufferized result type is the same as the bufferized type of the
// corresponding output operand.
- return bufferization::detail::asMemRefType(bufferization::getBufferType(
+ return bufferization::getBufferType(
forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
- state, invocationStack));
+ state, invocationStack);
}
bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 729c048db4560..829b2ab92ac24 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -49,7 +49,7 @@ struct CastOpInterface
return {{op->getResult(0), BufferRelation::Equivalent}};
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Opera...
[truncated]
|
@llvm/pr-subscribers-mlir-bufferization Author: Andrei Golubev (andrey-golubev) ChangesSupport custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize custom tensors into custom buffers. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType. Affected implementors of the interface are updated accordingly. Patch is 32.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144867.diff 14 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c1529a36465ac..6245f88db3d19 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// This is the default implementation of
/// BufferizableOpInterface::getBufferType. Should not be called from other
/// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index cafe05fe5f189..246ae77f327cf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
Note: This interface method should never be called directly from user
code. Always use `bufferization::getBufferType`.
}],
- /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+ /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 32c53ea9c494a..f175b15c8770f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -111,7 +111,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
AliasingValueList getAliasingValues(
OpOperand &opOperand, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);
@@ -478,10 +478,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
bool isWritable(Value value, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state, SmallVector<Value> &invocationStack) {
- return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
+ return getBuffer().getType();
}
}];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index cbb6054fcf886..da7fee4b4a220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Types.h"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index f56c10555f02c..e8a81c74bd77a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
if (!bufferType)
return op->emitOpError("could not infer buffer type of block argument");
- return bufferType;
+ return cast<BufferLikeType>(bufferType);
}
protected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 85d1b5ac73bf4..afee162053bea 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -181,7 +181,7 @@ struct SelectOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -196,17 +196,17 @@ struct SelectOpInterface
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
- return *trueType;
+ return cast<BufferLikeType>(*trueType);
if (trueType->getMemorySpace() != falseType->getMemorySpace())
return op->emitError("inconsistent memory space on true/false operands");
// If the buffers have different types, they differ only in their layout
// map.
auto memrefType = llvm::cast<MemRefType>(*trueType);
- return getMemRefTypeWithFullyDynamicLayout(
+ return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
RankedTensorType::get(memrefType.getShape(),
memrefType.getElementType()),
- memrefType.getMemorySpace());
+ memrefType.getMemorySpace()));
}
};
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 2ab182c9b7b2e..55784ac20d353 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -945,7 +945,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
return AliasingOpOperandList(std::move(result));
}
-FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &bufferizationState,
SmallVector<Value> &invocationStack) {
@@ -953,8 +953,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
auto tensorType = cast<TensorType>(value.getType());
// No further analysis is possible for a block argument.
- if (llvm::isa<BlockArgument>(value))
- return bufferization::getMemRefType(tensorType, options);
+ if (llvm::isa<BlockArgument>(value)) {
+ return cast<BufferLikeType>(
+ bufferization::getMemRefType(tensorType, options));
+ }
// Value is an OpResult.
Operation *op = getOwnerOfValue(value);
@@ -966,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// If the OpResult has an equivalent OpOperand, both OpResult and
// OpOperand bufferize to the exact same buffer type.
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
- return asMemRefType(getBufferType(equivalentOperand, options,
- bufferizationState, invocationStack));
+ return getBufferType(equivalentOperand, options, bufferizationState,
+ invocationStack);
}
// If we do not know the memory space and there is no default memory space,
@@ -977,7 +979,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
+ return cast<BufferLikeType>(
+ getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
}
bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 9bd87d66c7d36..66949c96798de 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
return {};
}
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) {
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
return getOperation()->emitError("could not infer memory space");
}
- return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+ return cast<BufferLikeType>(
+ getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
}
LogicalResult AllocTensorOp::verify() {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 453ed43bcadd2..bd2aebca68079 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -211,7 +211,7 @@ struct CallOpInterface
return result;
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -229,12 +229,13 @@ struct CallOpInterface
Type resultType =
funcType.getResult(cast<OpResult>(value).getResultNumber());
if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
- return bufferizedType;
+ return cast<BufferLikeType>(bufferizedType);
// Otherwise, call the type converter to compute the bufferized type.
auto tensorType = cast<TensorType>(resultType);
- return options.functionArgTypeConverterFn(
- tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+ return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+ tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+ options));
}
/// All function arguments are writable. It is the responsibility of the
@@ -396,7 +397,7 @@ struct FuncOpInterface
return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -405,8 +406,8 @@ struct FuncOpInterface
// Function arguments are special.
if (bbArg.getOwner() == &funcOp.getBody().front())
- return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
- options);
+ return cast<BufferLikeType>(
+ getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
getBufferType(op, value, options, state, invocationStack);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 58562536be61f..d36d91249ed36 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -274,7 +274,7 @@ struct IfOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -313,15 +313,15 @@ struct IfOpInterface
// Best case: Both branches have the exact same buffer type.
if (thenBufferType == elseBufferType)
- return thenBufferType;
+ return cast<BufferLikeType>(thenBufferType);
// Memory space mismatch.
if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
return op->emitError("inconsistent memory space on then/else branches");
// Layout maps are different: Promote to fully dynamic layout map.
- return getMemRefTypeWithFullyDynamicLayout(
- cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
+ return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+ cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
}
};
@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -436,7 +436,7 @@ struct IndexSwitchOpInterface
cast<TensorType>(value.getType()), bufferType.getMemorySpace());
}
- return bufferType;
+ return cast<BufferLikeType>(bufferType);
}
};
@@ -522,13 +522,13 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
/// If both buffer types are equal, no casts are needed the computed buffer type
/// can be used directly. Otherwise, the buffer types can only differ in their
/// layout map and a cast must be inserted.
-static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
const BufferizationOptions &options, const BufferizationState &state,
SmallVector<Value> &invocationStack) {
// Determine the buffer type of the init_arg.
- auto initArgBufferType = bufferization::detail::asMemRefType(
- bufferization::getBufferType(initArg, options, state, invocationStack));
+ auto initArgBufferType =
+ bufferization::getBufferType(initArg, options, state, invocationStack);
if (failed(initArgBufferType))
return failure();
@@ -547,16 +547,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
}
// Compute the buffer type of the yielded value.
- BaseMemRefType yieldedValueBufferType;
+ BufferLikeType yieldedValueBufferType;
if (isa<BaseMemRefType>(yieldedValue.getType())) {
// scf.yield was already bufferized.
- yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
+ yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
} else {
// Note: This typically triggers a recursive call for the buffer type of
// the iter_arg.
- auto maybeBufferType =
- bufferization::detail::asMemRefType(bufferization::getBufferType(
- yieldedValue, options, state, invocationStack));
+ auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
+ state, invocationStack);
if (failed(maybeBufferType))
return failure();
yieldedValueBufferType = *maybeBufferType;
@@ -584,8 +583,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
"expected same shape");
}
#endif // NDEBUG
- return getMemRefTypeWithFullyDynamicLayout(
- iterTensorType, yieldedBufferType.getMemorySpace());
+ return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+ iterTensorType, yieldedBufferType.getMemorySpace()));
}
/// Return `true` if the given loop may have 0 iterations.
@@ -708,7 +707,7 @@ struct ForOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -719,12 +718,8 @@ struct ForOpInterface
if (auto opResult = dyn_cast<OpResult>(value)) {
// The type of an OpResult must match the corresponding iter_arg type.
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
- auto bufferType =
- bufferization::getBufferType(bbArg, options, state, invocationStack);
- if (failed(bufferType))
- return failure();
- assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
- return cast<BaseMemRefType>(*bufferType);
+ return bufferization::getBufferType(bbArg, options, state,
+ invocationStack);
}
// Compute result/argument number.
@@ -1047,7 +1042,7 @@ struct WhileOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -1081,10 +1076,10 @@ struct WhileOpInterface
Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
if (!isa<TensorType>(conditionYieldedVal.getType())) {
// scf.condition was already bufferized.
- return cast<BaseMemRefType>(conditionYieldedVal.getType());
+ return cast<BufferLikeType>(conditionYieldedVal.getType());
}
- return bufferization::detail::asMemRefType(bufferization::getBufferType(
- conditionYieldedVal, options, state, invocationStack));
+ return bufferization::getBufferType(conditionYieldedVal, options, state,
+ invocationStack);
}
/// Assert that yielded values of an scf.while op are equivalent to their
@@ -1303,7 +1298,7 @@ struct ForallOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -1312,15 +1307,15 @@ struct ForallOpInterface
if (auto bbArg = dyn_cast<BlockArgument>(value))
// A tensor block argument has the same bufferized type as the
// corresponding output operand.
- return bufferization::detail::asMemRefType(
- bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
- options, state, invocationStack));
+ return bufferization::getBufferType(
+ forallOp.getTiedOpOperand(bbArg)->get(), options, state,
+ invocationStack);
// The bufferized result type is the same as the bufferized type of the
// corresponding output operand.
- return bufferization::detail::asMemRefType(bufferization::getBufferType(
+ return bufferization::getBufferType(
forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
- state, invocationStack));
+ state, invocationStack);
}
bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 729c048db4560..829b2ab92ac24 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -49,7 +49,7 @@ struct CastOpInterface
return {{op->getResult(0), BufferRelation::Equivalent}};
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Opera...
[truncated]
|
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options, | |||
return getOperation()->emitError("could not infer memory space"); | |||
} | |||
|
|||
return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace); | |||
return cast<BufferLikeType>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note: unfortunately, i didn't seem to manage to avoid this even with an addition of an implicit ctor (tried locally but didn't add it in this patch): BufferLikeType(BaseMemRefType)
. I think this is due to also a FailureOr<> wrapper. I wonder yet again whether it makes sense to push a patch to extend FailureOr<> implicit conversion semantics (out of scope of this PR though).
@@ -196,17 +196,17 @@ struct SelectOpInterface | |||
if (failed(trueType) || failed(falseType)) | |||
return failure(); | |||
if (*trueType == *falseType) | |||
return *trueType; | |||
return cast<BufferLikeType>(*trueType); | |||
if (trueType->getMemorySpace() != falseType->getMemorySpace()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getMemorySpace()
is the main problem here why I couldn't drop bufferization::detail::asMemRefType()
(and thus work with BufferLikeType objects directly).
@matthias-springer is it reasonable to assume all buffers have an associated memory space? I guess I could follow this up by another patch that extends the BufferLikeType and streamlines multiple places around the code-base.
Support custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize custom tensors into custom buffers. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType.
Affected implementors of the interface are updated accordingly.
Relates to ee070d0.