-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[NFC] Improve readability of AttrHelper usage #135873
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-gpu Author: Simon Waters (sjw36) ChangesFull diff: https://github.com/llvm/llvm-project/pull/135873.diff 5 Files Affected:
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index f22ad1fd70db2..1b4ea6b1164ec 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -194,7 +194,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
// Ensure we don't lose information if the function is lowered before its
// surrounding context.
- auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
+ auto *gpuDialect = gpu::GPUDialect::getLoaded(gpuFuncOp);
if (knownBlockSize)
attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
knownBlockSize);
diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
index 1f158b271e5c6..d7aa5f70d984a 100644
--- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -74,21 +74,18 @@ struct OpLowering : public ConvertOpToLLVMPattern<Op> {
// 3. Discardable attributes on a surrounding function of any kind
// The below code handles these in reverse order so that more important
// sources overwrite less important ones.
+ auto *gpuDialect = gpu::GPUDialect::getLoaded(op);
DenseI32ArrayAttr funcBounds = nullptr;
if (auto funcOp = op->template getParentOfType<FunctionOpInterface>()) {
switch (indexKind) {
case IndexKind::Block: {
- auto blockHelper =
- gpu::GPUDialect::KnownBlockSizeAttrHelper(op.getContext());
- if (blockHelper.isAttrPresent(funcOp))
- funcBounds = blockHelper.getAttr(funcOp);
+ auto blockHelper = gpuDialect->getKnownBlockSizeAttrHelper();
+ funcBounds = blockHelper.getAttr(funcOp);
break;
}
case IndexKind::Grid: {
- auto gridHelper =
- gpu::GPUDialect::KnownGridSizeAttrHelper(op.getContext());
- if (gridHelper.isAttrPresent(funcOp))
- funcBounds = gridHelper.getAttr(funcOp);
+ auto gridHelper = gpuDialect->getKnownGridSizeAttrHelper();
+ funcBounds = gridHelper.getAttr(funcOp);
break;
}
case IndexKind::Other:
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index c6c695b442b4f..4a4c97dfc7bc0 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -326,7 +326,7 @@ struct LowerGpuOpsToROCDLOpsPass final
configureGpuToROCDLConversionLegality(target);
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();
- auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
+ auto *rocdlDialect = ROCDL::ROCDLDialect::getLoaded(getContext());
auto reqdWorkGroupSizeAttrHelper =
rocdlDialect->getReqdWorkGroupSizeAttrHelper();
auto flatWorkGroupSizeAttrHelper =
@@ -374,8 +374,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
using gpu::index_lowering::IndexKind;
using gpu::index_lowering::IntrType;
using mlir::gpu::amd::Runtime;
- auto *rocdlDialect =
- converter.getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
+ auto *rocdlDialect = ROCDL::ROCDLDialect::getLoaded(converter.getContext());
populateWithGenerated(patterns);
patterns.add<
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 88a9d4c2a7ef2..abc46bf0e25f1 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -76,7 +76,7 @@ class ROCDLDialectLLVMIRTranslationInterface
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
- auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
+ auto *dialect = ROCDL::ROCDLDialect::getLoaded(op);
llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index 6cf71d2bb0174..bbd67d274e851 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -110,15 +110,23 @@ tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) {
/// {2}: The dialect parent class.
static const char *const dialectDeclBeginStr = R"(
class {0} : public ::mlir::{2} {
+ typedef {0} DialectType;
explicit {0}(::mlir::MLIRContext *context);
void initialize();
friend class ::mlir::MLIRContext;
public:
~{0}() override;
- static constexpr ::llvm::StringLiteral getDialectNamespace() {
+ static constexpr ::llvm::StringLiteral getDialectNamespace() {{
return ::llvm::StringLiteral("{1}");
}
+ static const DialectType *getLoaded(::mlir::MLIRContext &context) {{
+ return context.getLoadedDialect<DialectType>();
+ }
+ static const DialectType *getLoaded(::mlir::MLIRContext *context) {{
+ return getLoaded(*context);
+ }
+ static const DialectType *getLoaded(::mlir::Operation *operation);
)";
/// Registration for a single dependent dialect: to be inserted in the ctor
@@ -206,28 +214,28 @@ static const char *const discardableAttrHelperDecl = R"(
static constexpr ::llvm::StringLiteral getNameStr() {{
return "{4}.{1}";
}
- constexpr ::mlir::StringAttr getName() {{
+ constexpr ::mlir::StringAttr getName() const {{
return name;
}
{0}AttrHelper(::mlir::MLIRContext *ctx)
: name(::mlir::StringAttr::get(ctx, getNameStr())) {{}
- {2} getAttr(::mlir::Operation *op) {{
- return op->getAttrOfType<{2}>(name);
- }
- void setAttr(::mlir::Operation *op, {2} val) {{
- op->setAttr(name, val);
- }
- bool isAttrPresent(::mlir::Operation *op) {{
- return op->hasAttrOfType<{2}>(name);
- }
- void removeAttr(::mlir::Operation *op) {{
- assert(op->hasAttrOfType<{2}>(name));
- op->removeAttr(name);
- }
+ {2} getAttr(::mlir::Operation *op) const {{
+ return op->getAttrOfType<{2}>(name);
+ }
+ void setAttr(::mlir::Operation *op, {2} val) const {{
+ op->setAttr(name, val);
+ }
+ bool isAttrPresent(::mlir::Operation *op) const {{
+ return op->hasAttrOfType<{2}>(name);
+ }
+ void removeAttr(::mlir::Operation *op) const {{
+ assert(op->hasAttrOfType<{2}>(name));
+ op->removeAttr(name);
+ }
};
- {0}AttrHelper get{0}AttrHelper() {
+ const {0}AttrHelper get{0}AttrHelper() const {
return {3}AttrName;
}
private:
@@ -341,7 +349,17 @@ static const char *const dialectDestructorStr = R"(
{0}::~{0}() = default;
)";
+
+/// The code block to generate a member funcs.
+///
+/// {0}: The name of the dialect class.
+static const char *const dialectStaticMemberDefs = R"(
+const {0} *{0}::getLoaded(::mlir::Operation *operation) {{
+ return getLoaded(*operation->getContext());
+}
+)";
+
static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
raw_ostream &os) {
std::string cppClassName = dialect.getCppClassName();
@@ -388,6 +406,9 @@ static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
discardableAttributesInit);
if (!dialect.hasNonDefaultDestructor())
os << llvm::formatv(dialectDestructorStr, cppClassName);
+
+ // Emit member function definitions.
+ os << llvm::formatv(dialectStaticMemberDefs, cppClassName);
}
static bool emitDialectDefs(const RecordKeeper &records, raw_ostream &os) {
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
/// | ||
/// {0}: The name of the dialect class. | ||
static const char *const dialectStaticMemberDefs = R"( | ||
const {0} *{0}::getLoaded(::mlir::Operation *operation) {{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not convinced of the utility of this method over just casting from the operations dialect. This method is going to be significantly slower than calling getDialect on the operation itself. If we wanted to make this easier, we should rather consider auto-generating a templated version of getDialect
for Ops/Attrs/Types that does the casting for you.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This also seems like an unrelated change compared to the description of this commit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The dialect of the operation may not be the same dialect that owns the attribute though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method seems ad-hoc to me and does not provide much convenience compared to the current code.
What's wrong with getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nothing really, just adding for convenience. I was asked on a dialect in a separate project to make usage of AttrHelpers cleaner.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This particular method was added in the cpp file gen so Operation.h would not require inclusion by all Dialects. But it is just an extra convenience, happy to remove.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(marking review)
return ::llvm::StringLiteral("{1}"); | ||
} | ||
static const DialectType *getLoaded(::mlir::MLIRContext &context) {{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not understanding the reasoning behind these changes. When sending a commit for review, please add context as to "why" a change is being proposed.
This change primarily adds some static helper methods to a Dialect to look up the context Loaded version, which can be used to get AttrHelpers.