
[mlir][bufferization] Return BufferLikeType in BufferizableOpInterface #144867

Open · wants to merge 1 commit into main
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// This is the default implementation of
/// BufferizableOpInterface::getBufferType. Should not be called from other
/// places.
FailureOr<BaseMemRefType>
FailureOr<BufferLikeType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
Note: This interface method should never be called directly from user
code. Always use `bufferization::getBufferType`.
}],
/*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
/*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
@@ -111,7 +111,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
AliasingValueList getAliasingValues(
OpOperand &opOperand, const AnalysisState &state);

FailureOr<BaseMemRefType> getBufferType(
FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);
@@ -478,10 +478,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [

bool isWritable(Value value, const AnalysisState &state);

FailureOr<BaseMemRefType> getBufferType(
FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state, SmallVector<Value> &invocationStack) {
return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
return getBuffer().getType();
}
}];

@@ -13,6 +13,7 @@
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Types.h"

@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {

FailureOr<BaseMemRefType>
FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
if (!bufferType)
return op->emitOpError("could not infer buffer type of block argument");

return bufferType;
return cast<BufferLikeType>(bufferType);
}

protected:
@@ -181,7 +181,7 @@ struct SelectOpInterface
return success();
}

FailureOr<BaseMemRefType>
FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -196,17 +196,17 @@ struct SelectOpInterface
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
return *trueType;
return cast<BufferLikeType>(*trueType);
if (trueType->getMemorySpace() != falseType->getMemorySpace())
@andrey-golubev (Contributor, Author) commented on Jun 19, 2025:

getMemorySpace() is the main problem here and the reason why I couldn't drop bufferization::detail::asMemRefType() (and thus work with BufferLikeType objects directly).

@matthias-springer is it reasonable to assume that all buffers have an associated memory space? I could follow this up with another patch that extends BufferLikeType and streamlines multiple places around the code base.
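For readers following along, a minimal sketch of the pattern the comment describes (hypothetical call site, not code from this PR), assuming the post-patch FailureOr<BufferLikeType> return type:

```cpp
// Hypothetical call site: BufferLikeType exposes no memory-space accessor,
// so code that needs the memory space still narrows back to BaseMemRefType.
FailureOr<BufferLikeType> maybeBuffer =
    bufferization::getBufferType(value, options, state, invocationStack);
if (failed(maybeBuffer))
  return failure();
// This narrowing is essentially what bufferization::detail::asMemRefType()
// does, and why it cannot simply be dropped yet.
auto memrefType = cast<BaseMemRefType>(*maybeBuffer);
Attribute memSpace = memrefType.getMemorySpace();
```

A follow-up that gives BufferLikeType its own memory-space query, as suggested above, would remove the need for this narrowing.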

return op->emitError("inconsistent memory space on true/false operands");

// If the buffers have different types, they differ only in their layout
// map.
auto memrefType = llvm::cast<MemRefType>(*trueType);
return getMemRefTypeWithFullyDynamicLayout(
return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
RankedTensorType::get(memrefType.getShape(),
memrefType.getElementType()),
memrefType.getMemorySpace());
memrefType.getMemorySpace()));
}
};

15 changes: 9 additions & 6 deletions mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -945,16 +945,18 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
return AliasingOpOperandList(std::move(result));
}

FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &bufferizationState,
SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
auto tensorType = cast<TensorType>(value.getType());

// No further analysis is possible for a block argument.
if (llvm::isa<BlockArgument>(value))
return bufferization::getMemRefType(tensorType, options);
if (llvm::isa<BlockArgument>(value)) {
return cast<BufferLikeType>(
bufferization::getMemRefType(tensorType, options));
}

// Value is an OpResult.
Operation *op = getOwnerOfValue(value);
@@ -966,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// If the OpResult has an equivalent OpOperand, both OpResult and
// OpOperand bufferize to the exact same buffer type.
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
return asMemRefType(getBufferType(equivalentOperand, options,
bufferizationState, invocationStack));
return getBufferType(equivalentOperand, options, bufferizationState,
invocationStack);
}

// If we do not know the memory space and there is no default memory space,
@@ -977,7 +979,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
if (!memSpace.has_value())
return op->emitError("could not infer memory space");

return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
return cast<BufferLikeType>(
getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
}

bool bufferization::detail::defaultIsRepetitiveRegion(
5 changes: 3 additions & 2 deletions mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
return {};
}

FailureOr<BaseMemRefType>
FailureOr<BufferLikeType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) {
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
return getOperation()->emitError("could not infer memory space");
}

return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
return cast<BufferLikeType>(
@andrey-golubev (Contributor, Author) commented on Jun 19, 2025:

Note: unfortunately, I didn't manage to avoid this cast even by adding an implicit constructor BufferLikeType(BaseMemRefType) (tried locally, but not included in this patch). I think this is also due to the FailureOr<> wrapper. I wonder, yet again, whether it makes sense to push a patch that extends the FailureOr<> implicit conversion semantics (out of scope of this PR, though).
getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
}
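To make the inline comment above concrete, here is a self-contained sketch with stand-in types (not the actual MLIR classes) showing why an implicit BufferLikeType(BaseMemRefType) constructor alone does not remove the cast once the value is returned through a FailureOr-like wrapper: the conversion would have to chain two user-defined conversions, which C++ never performs implicitly.

```cpp
#include <optional>

struct Base {};                    // stand-in for BaseMemRefType
struct BufferLike {                // stand-in for BufferLikeType
  BufferLike() = default;
  BufferLike(Base) {}              // the hypothetical implicit constructor
};

template <typename T>
struct Result : std::optional<T> { // simplified stand-in for FailureOr<T>
  Result(const T &v) : std::optional<T>(v) {}
};

Result<BufferLike> getBufferType() {
  Base b;
  // return b;                     // error: Base -> BufferLike -> Result<BufferLike>
  //                               // chains two user-defined conversions
  return BufferLike(b);            // one explicit step, analogous to the
                                   // cast<BufferLikeType>(...) in the diff above
}

int main() { (void)getBufferType(); }
```

Whether FailureOr<> itself should absorb such conversions is the separate question the comment defers to a later patch.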

LogicalResult AllocTensorOp::verify() {
@@ -211,7 +211,7 @@ struct CallOpInterface
return result;
}

FailureOr<BaseMemRefType>
FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -229,12 +229,13 @@ struct CallOpInterface
Type resultType =
funcType.getResult(cast<OpResult>(value).getResultNumber());
if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
return bufferizedType;
return cast<BufferLikeType>(bufferizedType);

// Otherwise, call the type converter to compute the bufferized type.
auto tensorType = cast<TensorType>(resultType);
return options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
return cast<BufferLikeType>(options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
options));
}

/// All function arguments are writable. It is the responsibility of the
@@ -396,7 +397,7 @@ struct FuncOpInterface
return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
}

FailureOr<BaseMemRefType>
FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -405,8 +406,8 @@ struct FuncOpInterface

// Function arguments are special.
if (bbArg.getOwner() == &funcOp.getBody().front())
return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
options);
return cast<BufferLikeType>(
getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));

return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
getBufferType(op, value, options, state, invocationStack);
61 changes: 28 additions & 33 deletions mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -274,7 +274,7 @@ struct IfOpInterface
return success();
}

FailureOr<BaseMemRefType>
FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -313,15 +313,15 @@

// Best case: Both branches have the exact same buffer type.
if (thenBufferType == elseBufferType)
return thenBufferType;
return cast<BufferLikeType>(thenBufferType);

// Memory space mismatch.
if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
return op->emitError("inconsistent memory space on then/else branches");

// Layout maps are different: Promote to fully dynamic layout map.
return getMemRefTypeWithFullyDynamicLayout(
cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
}
};

@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
return success();
}

FailureOr<BaseMemRefType>
FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -436,7 +436,7 @@ struct IndexSwitchOpInterface
cast<TensorType>(value.getType()), bufferType.getMemorySpace());
}

return bufferType;
return cast<BufferLikeType>(bufferType);
}
};

@@ -522,13 +522,13 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
/// If both buffer types are equal, no casts are needed and the computed buffer type
/// can be used directly. Otherwise, the buffer types can only differ in their
/// layout map and a cast must be inserted.
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
const BufferizationOptions &options, const BufferizationState &state,
SmallVector<Value> &invocationStack) {
// Determine the buffer type of the init_arg.
auto initArgBufferType = bufferization::detail::asMemRefType(
bufferization::getBufferType(initArg, options, state, invocationStack));
auto initArgBufferType =
bufferization::getBufferType(initArg, options, state, invocationStack);
if (failed(initArgBufferType))
return failure();

@@ -547,16 +547,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
}

// Compute the buffer type of the yielded value.
BaseMemRefType yieldedValueBufferType;
BufferLikeType yieldedValueBufferType;
if (isa<BaseMemRefType>(yieldedValue.getType())) {
// scf.yield was already bufferized.
yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
} else {
// Note: This typically triggers a recursive call for the buffer type of
// the iter_arg.
auto maybeBufferType =
bufferization::detail::asMemRefType(bufferization::getBufferType(
yieldedValue, options, state, invocationStack));
auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
state, invocationStack);
if (failed(maybeBufferType))
return failure();
yieldedValueBufferType = *maybeBufferType;
@@ -584,8 +583,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
"expected same shape");
}
#endif // NDEBUG
return getMemRefTypeWithFullyDynamicLayout(
iterTensorType, yieldedBufferType.getMemorySpace());
return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
iterTensorType, yieldedBufferType.getMemorySpace()));
}

/// Return `true` if the given loop may have 0 iterations.
@@ -708,7 +707,7 @@ struct ForOpInterface
return success();
}

FailureOr<BaseMemRefType>
FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -719,12 +718,8 @@ struct ForOpInterface
if (auto opResult = dyn_cast<OpResult>(value)) {
// The type of an OpResult must match the corresponding iter_arg type.
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
auto bufferType =
bufferization::getBufferType(bbArg, options, state, invocationStack);
if (failed(bufferType))
return failure();
assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
return cast<BaseMemRefType>(*bufferType);
return bufferization::getBufferType(bbArg, options, state,
invocationStack);
}

// Compute result/argument number.
Expand Down Expand Up @@ -1047,7 +1042,7 @@ struct WhileOpInterface
return success();
}

FailureOr<BaseMemRefType>
FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -1081,10 +1076,10 @@
Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
if (!isa<TensorType>(conditionYieldedVal.getType())) {
// scf.condition was already bufferized.
return cast<BaseMemRefType>(conditionYieldedVal.getType());
return cast<BufferLikeType>(conditionYieldedVal.getType());
}
return bufferization::detail::asMemRefType(bufferization::getBufferType(
conditionYieldedVal, options, state, invocationStack));
return bufferization::getBufferType(conditionYieldedVal, options, state,
invocationStack);
}

/// Assert that yielded values of an scf.while op are equivalent to their
@@ -1303,7 +1298,7 @@ struct ForallOpInterface
return success();
}

FailureOr<BaseMemRefType>
FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
Expand All @@ -1312,15 +1307,15 @@ struct ForallOpInterface
if (auto bbArg = dyn_cast<BlockArgument>(value))
// A tensor block argument has the same bufferized type as the
// corresponding output operand.
return bufferization::detail::asMemRefType(
bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
options, state, invocationStack));
return bufferization::getBufferType(
forallOp.getTiedOpOperand(bbArg)->get(), options, state,
invocationStack);

// The bufferized result type is the same as the bufferized type of the
// corresponding output operand.
return bufferization::detail::asMemRefType(bufferization::getBufferType(
return bufferization::getBufferType(
forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
state, invocationStack));
state, invocationStack);
}

bool isRepetitiveRegion(Operation *op, unsigned index) const {