Skip to content

[mlir][bufferization] Return BufferLikeType in BufferizableOpInterface #144867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

andrey-golubev
Copy link
Contributor

@andrey-golubev andrey-golubev commented Jun 19, 2025

Support custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize custom tensors into custom buffers. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType.

Affected implementors of the interface are updated accordingly.

Relates to ee070d0.

Support custom types (2/N): allow value-owning operations (e.g.
allocation ops) to bufferize into custom types. This requires
BufferizableOpInterface::getBufferType() to return BufferLikeType
instead of BaseMemRefType.

Affected implementors of the interface are update accordingly.
@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: Andrei Golubev (andrey-golubev)

Changes

Support custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize custom tensors into custom buffers. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType.

Affected implementors of the interface are updated accordingly.


Patch is 32.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144867.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+3-3)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h (+1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+2-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+4-4)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+9-6)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+8-7)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+28-33)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+28-24)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+21-2)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+34)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+53)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c1529a36465ac..6245f88db3d19 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// This is the default implementation of
 /// BufferizableOpInterface::getBufferType. Should not be called from other
 /// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
                      const BufferizationState &state,
                      SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index cafe05fe5f189..246ae77f327cf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           Note: This interface method should never be called directly from user
           code. Always use `bufferization::getBufferType`.
         }],
-        /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+        /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
         /*methodName=*/"getBufferType",
         /*args=*/(ins "::mlir::Value":$value,
                       "const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 32c53ea9c494a..f175b15c8770f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -111,7 +111,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     AliasingValueList getAliasingValues(
         OpOperand &opOperand, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state,
         SmallVector<Value> &invocationStack);
@@ -478,10 +478,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     bool isWritable(Value value, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state, SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
+      return getBuffer().getType();
     }
   }];
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index cbb6054fcf886..da7fee4b4a220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Types.h"
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index f56c10555f02c..e8a81c74bd77a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
 struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     if (!bufferType)
       return op->emitOpError("could not infer buffer type of block argument");
 
-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 
 protected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 85d1b5ac73bf4..afee162053bea 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -181,7 +181,7 @@ struct SelectOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -196,17 +196,17 @@ struct SelectOpInterface
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)
-      return *trueType;
+      return cast<BufferLikeType>(*trueType);
     if (trueType->getMemorySpace() != falseType->getMemorySpace())
       return op->emitError("inconsistent memory space on true/false operands");
 
     // If the buffers have different types, they differ only in their layout
     // map.
     auto memrefType = llvm::cast<MemRefType>(*trueType);
-    return getMemRefTypeWithFullyDynamicLayout(
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
         RankedTensorType::get(memrefType.getShape(),
                               memrefType.getElementType()),
-        memrefType.getMemorySpace());
+        memrefType.getMemorySpace()));
   }
 };
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 2ab182c9b7b2e..55784ac20d353 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -945,7 +945,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
   return AliasingOpOperandList(std::move(result));
 }
 
-FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
     Value value, const BufferizationOptions &options,
     const BufferizationState &bufferizationState,
     SmallVector<Value> &invocationStack) {
@@ -953,8 +953,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   auto tensorType = cast<TensorType>(value.getType());
 
   // No further analysis is possible for a block argument.
-  if (llvm::isa<BlockArgument>(value))
-    return bufferization::getMemRefType(tensorType, options);
+  if (llvm::isa<BlockArgument>(value)) {
+    return cast<BufferLikeType>(
+        bufferization::getMemRefType(tensorType, options));
+  }
 
   // Value is an OpResult.
   Operation *op = getOwnerOfValue(value);
@@ -966,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
-    return asMemRefType(getBufferType(equivalentOperand, options,
-                                      bufferizationState, invocationStack));
+    return getBufferType(equivalentOperand, options, bufferizationState,
+                         invocationStack);
   }
 
   // If we do not know the memory space and there is no default memory space,
@@ -977,7 +979,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
+  return cast<BufferLikeType>(
+      getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
 }
 
 bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 9bd87d66c7d36..66949c96798de 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
   return {};
 }
 
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     return getOperation()->emitError("could not infer memory space");
   }
 
-  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+  return cast<BufferLikeType>(
+      getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
 }
 
 LogicalResult AllocTensorOp::verify() {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 453ed43bcadd2..bd2aebca68079 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -211,7 +211,7 @@ struct CallOpInterface
     return result;
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -229,12 +229,13 @@ struct CallOpInterface
     Type resultType =
         funcType.getResult(cast<OpResult>(value).getResultNumber());
     if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
-      return bufferizedType;
+      return cast<BufferLikeType>(bufferizedType);
 
     // Otherwise, call the type converter to compute the bufferized type.
     auto tensorType = cast<TensorType>(resultType);
-    return options.functionArgTypeConverterFn(
-        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+    return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+        options));
   }
 
   /// All function arguments are writable. It is the responsibility of the
@@ -396,7 +397,7 @@ struct FuncOpInterface
     return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -405,8 +406,8 @@ struct FuncOpInterface
 
     // Function arguments are special.
     if (bbArg.getOwner() == &funcOp.getBody().front())
-      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
-                                          options);
+      return cast<BufferLikeType>(
+          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
 
     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
         getBufferType(op, value, options, state, invocationStack);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 58562536be61f..d36d91249ed36 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -274,7 +274,7 @@ struct IfOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -313,15 +313,15 @@ struct IfOpInterface
 
     // Best case: Both branches have the exact same buffer type.
     if (thenBufferType == elseBufferType)
-      return thenBufferType;
+      return cast<BufferLikeType>(thenBufferType);
 
     // Memory space mismatch.
     if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
       return op->emitError("inconsistent memory space on then/else branches");
 
     // Layout maps are different: Promote to fully dynamic layout map.
-    return getMemRefTypeWithFullyDynamicLayout(
-        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
   }
 };
 
@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -436,7 +436,7 @@ struct IndexSwitchOpInterface
           cast<TensorType>(value.getType()), bufferType.getMemorySpace());
     }
 
-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 };
 
@@ -522,13 +522,13 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
 /// If both buffer types are equal, no casts are needed the computed buffer type
 /// can be used directly. Otherwise, the buffer types can only differ in their
 /// layout map and a cast must be inserted.
-static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
     Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
     const BufferizationOptions &options, const BufferizationState &state,
     SmallVector<Value> &invocationStack) {
   // Determine the buffer type of the init_arg.
-  auto initArgBufferType = bufferization::detail::asMemRefType(
-      bufferization::getBufferType(initArg, options, state, invocationStack));
+  auto initArgBufferType =
+      bufferization::getBufferType(initArg, options, state, invocationStack);
   if (failed(initArgBufferType))
     return failure();
 
@@ -547,16 +547,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
   }
 
   // Compute the buffer type of the yielded value.
-  BaseMemRefType yieldedValueBufferType;
+  BufferLikeType yieldedValueBufferType;
   if (isa<BaseMemRefType>(yieldedValue.getType())) {
     // scf.yield was already bufferized.
-    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
+    yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
   } else {
     // Note: This typically triggers a recursive call for the buffer type of
     // the iter_arg.
-    auto maybeBufferType =
-        bufferization::detail::asMemRefType(bufferization::getBufferType(
-            yieldedValue, options, state, invocationStack));
+    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
+                                                        state, invocationStack);
     if (failed(maybeBufferType))
       return failure();
     yieldedValueBufferType = *maybeBufferType;
@@ -584,8 +583,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
         "expected same shape");
   }
 #endif // NDEBUG
-  return getMemRefTypeWithFullyDynamicLayout(
-      iterTensorType, yieldedBufferType.getMemorySpace());
+  return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+      iterTensorType, yieldedBufferType.getMemorySpace()));
 }
 
 /// Return `true` if the given loop may have 0 iterations.
@@ -708,7 +707,7 @@ struct ForOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -719,12 +718,8 @@ struct ForOpInterface
     if (auto opResult = dyn_cast<OpResult>(value)) {
       // The type of an OpResult must match the corresponding iter_arg type.
       BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
-      auto bufferType =
-          bufferization::getBufferType(bbArg, options, state, invocationStack);
-      if (failed(bufferType))
-        return failure();
-      assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
-      return cast<BaseMemRefType>(*bufferType);
+      return bufferization::getBufferType(bbArg, options, state,
+                                          invocationStack);
     }
 
     // Compute result/argument number.
@@ -1047,7 +1042,7 @@ struct WhileOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1081,10 +1076,10 @@ struct WhileOpInterface
     Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
     if (!isa<TensorType>(conditionYieldedVal.getType())) {
       // scf.condition was already bufferized.
-      return cast<BaseMemRefType>(conditionYieldedVal.getType());
+      return cast<BufferLikeType>(conditionYieldedVal.getType());
     }
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
-        conditionYieldedVal, options, state, invocationStack));
+    return bufferization::getBufferType(conditionYieldedVal, options, state,
+                                        invocationStack);
   }
 
   /// Assert that yielded values of an scf.while op are equivalent to their
@@ -1303,7 +1298,7 @@ struct ForallOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1312,15 +1307,15 @@ struct ForallOpInterface
     if (auto bbArg = dyn_cast<BlockArgument>(value))
       // A tensor block argument has the same bufferized type as the
       // corresponding output operand.
-      return bufferization::detail::asMemRefType(
-          bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
-                                       options, state, invocationStack));
+      return bufferization::getBufferType(
+          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
+          invocationStack);
 
     // The bufferized result type is the same as the bufferized type of the
     // corresponding output operand.
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
+    return bufferization::getBufferType(
         forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
-        state, invocationStack));
+        state, invocationStack);
   }
 
   bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 729c048db4560..829b2ab92ac24 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -49,7 +49,7 @@ struct CastOpInterface
     return {{op->getResult(0), BufferRelation::Equivalent}};
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Opera...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-mlir-tensor

Author: Andrei Golubev (andrey-golubev)

Changes

Support custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize custom tensors into custom buffers. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType.

Affected implementors of the interface are updated accordingly.


Patch is 32.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144867.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+3-3)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h (+1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+2-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+4-4)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+9-6)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+8-7)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+28-33)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+28-24)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+21-2)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+34)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+53)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c1529a36465ac..6245f88db3d19 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// This is the default implementation of
 /// BufferizableOpInterface::getBufferType. Should not be called from other
 /// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
                      const BufferizationState &state,
                      SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index cafe05fe5f189..246ae77f327cf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           Note: This interface method should never be called directly from user
           code. Always use `bufferization::getBufferType`.
         }],
-        /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+        /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
         /*methodName=*/"getBufferType",
         /*args=*/(ins "::mlir::Value":$value,
                       "const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 32c53ea9c494a..f175b15c8770f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -111,7 +111,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     AliasingValueList getAliasingValues(
         OpOperand &opOperand, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state,
         SmallVector<Value> &invocationStack);
@@ -478,10 +478,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     bool isWritable(Value value, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state, SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
+      return getBuffer().getType();
     }
   }];
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index cbb6054fcf886..da7fee4b4a220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Types.h"
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index f56c10555f02c..e8a81c74bd77a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
 struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     if (!bufferType)
       return op->emitOpError("could not infer buffer type of block argument");
 
-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 
 protected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 85d1b5ac73bf4..afee162053bea 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -181,7 +181,7 @@ struct SelectOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -196,17 +196,17 @@ struct SelectOpInterface
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)
-      return *trueType;
+      return cast<BufferLikeType>(*trueType);
     if (trueType->getMemorySpace() != falseType->getMemorySpace())
       return op->emitError("inconsistent memory space on true/false operands");
 
     // If the buffers have different types, they differ only in their layout
     // map.
     auto memrefType = llvm::cast<MemRefType>(*trueType);
-    return getMemRefTypeWithFullyDynamicLayout(
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
         RankedTensorType::get(memrefType.getShape(),
                               memrefType.getElementType()),
-        memrefType.getMemorySpace());
+        memrefType.getMemorySpace()));
   }
 };
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 2ab182c9b7b2e..55784ac20d353 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -945,7 +945,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
   return AliasingOpOperandList(std::move(result));
 }
 
-FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
     Value value, const BufferizationOptions &options,
     const BufferizationState &bufferizationState,
     SmallVector<Value> &invocationStack) {
@@ -953,8 +953,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   auto tensorType = cast<TensorType>(value.getType());
 
   // No further analysis is possible for a block argument.
-  if (llvm::isa<BlockArgument>(value))
-    return bufferization::getMemRefType(tensorType, options);
+  if (llvm::isa<BlockArgument>(value)) {
+    return cast<BufferLikeType>(
+        bufferization::getMemRefType(tensorType, options));
+  }
 
   // Value is an OpResult.
   Operation *op = getOwnerOfValue(value);
@@ -966,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
-    return asMemRefType(getBufferType(equivalentOperand, options,
-                                      bufferizationState, invocationStack));
+    return getBufferType(equivalentOperand, options, bufferizationState,
+                         invocationStack);
   }
 
   // If we do not know the memory space and there is no default memory space,
@@ -977,7 +979,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
+  return cast<BufferLikeType>(
+      getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
 }
 
 bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 9bd87d66c7d36..66949c96798de 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
   return {};
 }
 
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     return getOperation()->emitError("could not infer memory space");
   }
 
-  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+  return cast<BufferLikeType>(
+      getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
 }
 
 LogicalResult AllocTensorOp::verify() {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 453ed43bcadd2..bd2aebca68079 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -211,7 +211,7 @@ struct CallOpInterface
     return result;
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -229,12 +229,13 @@ struct CallOpInterface
     Type resultType =
         funcType.getResult(cast<OpResult>(value).getResultNumber());
     if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
-      return bufferizedType;
+      return cast<BufferLikeType>(bufferizedType);
 
     // Otherwise, call the type converter to compute the bufferized type.
     auto tensorType = cast<TensorType>(resultType);
-    return options.functionArgTypeConverterFn(
-        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+    return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+        options));
   }
 
   /// All function arguments are writable. It is the responsibility of the
@@ -396,7 +397,7 @@ struct FuncOpInterface
     return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -405,8 +406,8 @@ struct FuncOpInterface
 
     // Function arguments are special.
     if (bbArg.getOwner() == &funcOp.getBody().front())
-      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
-                                          options);
+      return cast<BufferLikeType>(
+          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
 
     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
         getBufferType(op, value, options, state, invocationStack);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 58562536be61f..d36d91249ed36 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -274,7 +274,7 @@ struct IfOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -313,15 +313,15 @@ struct IfOpInterface
 
     // Best case: Both branches have the exact same buffer type.
     if (thenBufferType == elseBufferType)
-      return thenBufferType;
+      return cast<BufferLikeType>(thenBufferType);
 
     // Memory space mismatch.
     if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
       return op->emitError("inconsistent memory space on then/else branches");
 
     // Layout maps are different: Promote to fully dynamic layout map.
-    return getMemRefTypeWithFullyDynamicLayout(
-        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
   }
 };
 
@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -436,7 +436,7 @@ struct IndexSwitchOpInterface
           cast<TensorType>(value.getType()), bufferType.getMemorySpace());
     }
 
-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 };
 
@@ -522,13 +522,13 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
 /// If both buffer types are equal, no casts are needed the computed buffer type
 /// can be used directly. Otherwise, the buffer types can only differ in their
 /// layout map and a cast must be inserted.
-static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
     Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
     const BufferizationOptions &options, const BufferizationState &state,
     SmallVector<Value> &invocationStack) {
   // Determine the buffer type of the init_arg.
-  auto initArgBufferType = bufferization::detail::asMemRefType(
-      bufferization::getBufferType(initArg, options, state, invocationStack));
+  auto initArgBufferType =
+      bufferization::getBufferType(initArg, options, state, invocationStack);
   if (failed(initArgBufferType))
     return failure();
 
@@ -547,16 +547,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
   }
 
   // Compute the buffer type of the yielded value.
-  BaseMemRefType yieldedValueBufferType;
+  BufferLikeType yieldedValueBufferType;
   if (isa<BaseMemRefType>(yieldedValue.getType())) {
     // scf.yield was already bufferized.
-    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
+    yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
   } else {
     // Note: This typically triggers a recursive call for the buffer type of
     // the iter_arg.
-    auto maybeBufferType =
-        bufferization::detail::asMemRefType(bufferization::getBufferType(
-            yieldedValue, options, state, invocationStack));
+    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
+                                                        state, invocationStack);
     if (failed(maybeBufferType))
       return failure();
     yieldedValueBufferType = *maybeBufferType;
@@ -584,8 +583,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
         "expected same shape");
   }
 #endif // NDEBUG
-  return getMemRefTypeWithFullyDynamicLayout(
-      iterTensorType, yieldedBufferType.getMemorySpace());
+  return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+      iterTensorType, yieldedBufferType.getMemorySpace()));
 }
 
 /// Return `true` if the given loop may have 0 iterations.
@@ -708,7 +707,7 @@ struct ForOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -719,12 +718,8 @@ struct ForOpInterface
     if (auto opResult = dyn_cast<OpResult>(value)) {
       // The type of an OpResult must match the corresponding iter_arg type.
       BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
-      auto bufferType =
-          bufferization::getBufferType(bbArg, options, state, invocationStack);
-      if (failed(bufferType))
-        return failure();
-      assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
-      return cast<BaseMemRefType>(*bufferType);
+      return bufferization::getBufferType(bbArg, options, state,
+                                          invocationStack);
     }
 
     // Compute result/argument number.
@@ -1047,7 +1042,7 @@ struct WhileOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1081,10 +1076,10 @@ struct WhileOpInterface
     Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
     if (!isa<TensorType>(conditionYieldedVal.getType())) {
       // scf.condition was already bufferized.
-      return cast<BaseMemRefType>(conditionYieldedVal.getType());
+      return cast<BufferLikeType>(conditionYieldedVal.getType());
     }
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
-        conditionYieldedVal, options, state, invocationStack));
+    return bufferization::getBufferType(conditionYieldedVal, options, state,
+                                        invocationStack);
   }
 
   /// Assert that yielded values of an scf.while op are equivalent to their
@@ -1303,7 +1298,7 @@ struct ForallOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1312,15 +1307,15 @@ struct ForallOpInterface
     if (auto bbArg = dyn_cast<BlockArgument>(value))
       // A tensor block argument has the same bufferized type as the
       // corresponding output operand.
-      return bufferization::detail::asMemRefType(
-          bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
-                                       options, state, invocationStack));
+      return bufferization::getBufferType(
+          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
+          invocationStack);
 
     // The bufferized result type is the same as the bufferized type of the
     // corresponding output operand.
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
+    return bufferization::getBufferType(
         forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
-        state, invocationStack));
+        state, invocationStack);
   }
 
   bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 729c048db4560..829b2ab92ac24 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -49,7 +49,7 @@ struct CastOpInterface
     return {{op->getResult(0), BufferRelation::Equivalent}};
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Opera...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-mlir-bufferization

Author: Andrei Golubev (andrey-golubev)

Changes

Support custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize custom tensors into custom buffers. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType.

Affected implementors of the interface are updated accordingly.


Patch is 32.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144867.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+3-3)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h (+1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+2-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+4-4)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+9-6)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+8-7)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+28-33)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+28-24)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+21-2)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+34)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+53)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c1529a36465ac..6245f88db3d19 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// This is the default implementation of
 /// BufferizableOpInterface::getBufferType. Should not be called from other
 /// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
                      const BufferizationState &state,
                      SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index cafe05fe5f189..246ae77f327cf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           Note: This interface method should never be called directly from user
           code. Always use `bufferization::getBufferType`.
         }],
-        /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+        /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
         /*methodName=*/"getBufferType",
         /*args=*/(ins "::mlir::Value":$value,
                       "const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 32c53ea9c494a..f175b15c8770f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -111,7 +111,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     AliasingValueList getAliasingValues(
         OpOperand &opOperand, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state,
         SmallVector<Value> &invocationStack);
@@ -478,10 +478,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     bool isWritable(Value value, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state, SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
+      return getBuffer().getType();
     }
   }];
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index cbb6054fcf886..da7fee4b4a220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Types.h"
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index f56c10555f02c..e8a81c74bd77a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
 struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     if (!bufferType)
       return op->emitOpError("could not infer buffer type of block argument");
 
-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 
 protected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 85d1b5ac73bf4..afee162053bea 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -181,7 +181,7 @@ struct SelectOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -196,17 +196,17 @@ struct SelectOpInterface
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)
-      return *trueType;
+      return cast<BufferLikeType>(*trueType);
     if (trueType->getMemorySpace() != falseType->getMemorySpace())
       return op->emitError("inconsistent memory space on true/false operands");
 
     // If the buffers have different types, they differ only in their layout
     // map.
     auto memrefType = llvm::cast<MemRefType>(*trueType);
-    return getMemRefTypeWithFullyDynamicLayout(
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
         RankedTensorType::get(memrefType.getShape(),
                               memrefType.getElementType()),
-        memrefType.getMemorySpace());
+        memrefType.getMemorySpace()));
   }
 };
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 2ab182c9b7b2e..55784ac20d353 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -945,7 +945,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
   return AliasingOpOperandList(std::move(result));
 }
 
-FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
     Value value, const BufferizationOptions &options,
     const BufferizationState &bufferizationState,
     SmallVector<Value> &invocationStack) {
@@ -953,8 +953,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   auto tensorType = cast<TensorType>(value.getType());
 
   // No further analysis is possible for a block argument.
-  if (llvm::isa<BlockArgument>(value))
-    return bufferization::getMemRefType(tensorType, options);
+  if (llvm::isa<BlockArgument>(value)) {
+    return cast<BufferLikeType>(
+        bufferization::getMemRefType(tensorType, options));
+  }
 
   // Value is an OpResult.
   Operation *op = getOwnerOfValue(value);
@@ -966,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
-    return asMemRefType(getBufferType(equivalentOperand, options,
-                                      bufferizationState, invocationStack));
+    return getBufferType(equivalentOperand, options, bufferizationState,
+                         invocationStack);
   }
 
   // If we do not know the memory space and there is no default memory space,
@@ -977,7 +979,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
+  return cast<BufferLikeType>(
+      getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
 }
 
 bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 9bd87d66c7d36..66949c96798de 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
   return {};
 }
 
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     return getOperation()->emitError("could not infer memory space");
   }
 
-  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+  return cast<BufferLikeType>(
+      getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
 }
 
 LogicalResult AllocTensorOp::verify() {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 453ed43bcadd2..bd2aebca68079 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -211,7 +211,7 @@ struct CallOpInterface
     return result;
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -229,12 +229,13 @@ struct CallOpInterface
     Type resultType =
         funcType.getResult(cast<OpResult>(value).getResultNumber());
     if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
-      return bufferizedType;
+      return cast<BufferLikeType>(bufferizedType);
 
     // Otherwise, call the type converter to compute the bufferized type.
     auto tensorType = cast<TensorType>(resultType);
-    return options.functionArgTypeConverterFn(
-        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+    return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+        options));
   }
 
   /// All function arguments are writable. It is the responsibility of the
@@ -396,7 +397,7 @@ struct FuncOpInterface
     return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -405,8 +406,8 @@ struct FuncOpInterface
 
     // Function arguments are special.
     if (bbArg.getOwner() == &funcOp.getBody().front())
-      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
-                                          options);
+      return cast<BufferLikeType>(
+          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
 
     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
         getBufferType(op, value, options, state, invocationStack);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 58562536be61f..d36d91249ed36 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -274,7 +274,7 @@ struct IfOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -313,15 +313,15 @@ struct IfOpInterface
 
     // Best case: Both branches have the exact same buffer type.
     if (thenBufferType == elseBufferType)
-      return thenBufferType;
+      return cast<BufferLikeType>(thenBufferType);
 
     // Memory space mismatch.
     if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
       return op->emitError("inconsistent memory space on then/else branches");
 
     // Layout maps are different: Promote to fully dynamic layout map.
-    return getMemRefTypeWithFullyDynamicLayout(
-        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
   }
 };
 
@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -436,7 +436,7 @@ struct IndexSwitchOpInterface
           cast<TensorType>(value.getType()), bufferType.getMemorySpace());
     }
 
-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 };
 
@@ -522,13 +522,13 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
 /// If both buffer types are equal, no casts are needed the computed buffer type
 /// can be used directly. Otherwise, the buffer types can only differ in their
 /// layout map and a cast must be inserted.
-static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
     Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
     const BufferizationOptions &options, const BufferizationState &state,
     SmallVector<Value> &invocationStack) {
   // Determine the buffer type of the init_arg.
-  auto initArgBufferType = bufferization::detail::asMemRefType(
-      bufferization::getBufferType(initArg, options, state, invocationStack));
+  auto initArgBufferType =
+      bufferization::getBufferType(initArg, options, state, invocationStack);
   if (failed(initArgBufferType))
     return failure();
 
@@ -547,16 +547,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
   }
 
   // Compute the buffer type of the yielded value.
-  BaseMemRefType yieldedValueBufferType;
+  BufferLikeType yieldedValueBufferType;
   if (isa<BaseMemRefType>(yieldedValue.getType())) {
     // scf.yield was already bufferized.
-    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
+    yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
   } else {
     // Note: This typically triggers a recursive call for the buffer type of
     // the iter_arg.
-    auto maybeBufferType =
-        bufferization::detail::asMemRefType(bufferization::getBufferType(
-            yieldedValue, options, state, invocationStack));
+    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
+                                                        state, invocationStack);
     if (failed(maybeBufferType))
       return failure();
     yieldedValueBufferType = *maybeBufferType;
@@ -584,8 +583,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
         "expected same shape");
   }
 #endif // NDEBUG
-  return getMemRefTypeWithFullyDynamicLayout(
-      iterTensorType, yieldedBufferType.getMemorySpace());
+  return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+      iterTensorType, yieldedBufferType.getMemorySpace()));
 }
 
 /// Return `true` if the given loop may have 0 iterations.
@@ -708,7 +707,7 @@ struct ForOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -719,12 +718,8 @@ struct ForOpInterface
     if (auto opResult = dyn_cast<OpResult>(value)) {
       // The type of an OpResult must match the corresponding iter_arg type.
       BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
-      auto bufferType =
-          bufferization::getBufferType(bbArg, options, state, invocationStack);
-      if (failed(bufferType))
-        return failure();
-      assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
-      return cast<BaseMemRefType>(*bufferType);
+      return bufferization::getBufferType(bbArg, options, state,
+                                          invocationStack);
     }
 
     // Compute result/argument number.
@@ -1047,7 +1042,7 @@ struct WhileOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1081,10 +1076,10 @@ struct WhileOpInterface
     Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
     if (!isa<TensorType>(conditionYieldedVal.getType())) {
       // scf.condition was already bufferized.
-      return cast<BaseMemRefType>(conditionYieldedVal.getType());
+      return cast<BufferLikeType>(conditionYieldedVal.getType());
     }
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
-        conditionYieldedVal, options, state, invocationStack));
+    return bufferization::getBufferType(conditionYieldedVal, options, state,
+                                        invocationStack);
   }
 
   /// Assert that yielded values of an scf.while op are equivalent to their
@@ -1303,7 +1298,7 @@ struct ForallOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1312,15 +1307,15 @@ struct ForallOpInterface
     if (auto bbArg = dyn_cast<BlockArgument>(value))
       // A tensor block argument has the same bufferized type as the
       // corresponding output operand.
-      return bufferization::detail::asMemRefType(
-          bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
-                                       options, state, invocationStack));
+      return bufferization::getBufferType(
+          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
+          invocationStack);
 
     // The bufferized result type is the same as the bufferized type of the
     // corresponding output operand.
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
+    return bufferization::getBufferType(
         forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
-        state, invocationStack));
+        state, invocationStack);
   }
 
   bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 729c048db4560..829b2ab92ac24 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -49,7 +49,7 @@ struct CastOpInterface
     return {{op->getResult(0), BufferRelation::Equivalent}};
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Opera...
[truncated]

@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
return getOperation()->emitError("could not infer memory space");
}

return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
return cast<BufferLikeType>(
Copy link
Contributor Author

@andrey-golubev andrey-golubev Jun 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: unfortunately, i didn't seem to manage to avoid this even with an addition of an implicit ctor (tried locally but didn't add it in this patch): BufferLikeType(BaseMemRefType). I think this is due to also a FailureOr<> wrapper. I wonder yet again whether it makes sense to push a patch to extend FailureOr<> implicit conversion semantics (out of scope of this PR though).

@@ -196,17 +196,17 @@ struct SelectOpInterface
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
return *trueType;
return cast<BufferLikeType>(*trueType);
if (trueType->getMemorySpace() != falseType->getMemorySpace())
Copy link
Contributor Author

@andrey-golubev andrey-golubev Jun 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getMemorySpace() is the main problem here why I couldn't drop bufferization::detail::asMemRefType() (and thus work with BufferLikeType objects directly).

@matthias-springer is it reasonable to assume all buffers have an associated memory space? I guess I could follow this up by another patch that extends the BufferLikeType and streamlines multiple places around the code-base.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants