Skip to content

[NFC] Improve readability of AttrHelper usage #135873

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
// Ensure we don't lose information if the function is lowered before its
// surrounding context.
auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
auto *gpuDialect = gpu::GPUDialect::getLoaded(gpuFuncOp);
if (knownBlockSize)
attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
knownBlockSize);
Expand Down
13 changes: 5 additions & 8 deletions mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,18 @@ struct OpLowering : public ConvertOpToLLVMPattern<Op> {
// 3. Discardable attributes on a surrounding function of any kind
// The below code handles these in reverse order so that more important
// sources overwrite less important ones.
auto *gpuDialect = gpu::GPUDialect::getLoaded(op);
DenseI32ArrayAttr funcBounds = nullptr;
if (auto funcOp = op->template getParentOfType<FunctionOpInterface>()) {
switch (indexKind) {
case IndexKind::Block: {
auto blockHelper =
gpu::GPUDialect::KnownBlockSizeAttrHelper(op.getContext());
if (blockHelper.isAttrPresent(funcOp))
funcBounds = blockHelper.getAttr(funcOp);
auto blockHelper = gpuDialect->getKnownBlockSizeAttrHelper();
funcBounds = blockHelper.getAttr(funcOp);
break;
}
case IndexKind::Grid: {
auto gridHelper =
gpu::GPUDialect::KnownGridSizeAttrHelper(op.getContext());
if (gridHelper.isAttrPresent(funcOp))
funcBounds = gridHelper.getAttr(funcOp);
auto gridHelper = gpuDialect->getKnownGridSizeAttrHelper();
funcBounds = gridHelper.getAttr(funcOp);
break;
}
case IndexKind::Other:
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ struct LowerGpuOpsToROCDLOpsPass final
configureGpuToROCDLConversionLegality(target);
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();
auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
auto *rocdlDialect = ROCDL::ROCDLDialect::getLoaded(getContext());
auto reqdWorkGroupSizeAttrHelper =
rocdlDialect->getReqdWorkGroupSizeAttrHelper();
auto flatWorkGroupSizeAttrHelper =
Expand Down Expand Up @@ -374,8 +374,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
using gpu::index_lowering::IndexKind;
using gpu::index_lowering::IntrType;
using mlir::gpu::amd::Runtime;
auto *rocdlDialect =
converter.getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
auto *rocdlDialect = ROCDL::ROCDLDialect::getLoaded(converter.getContext());
populateWithGenerated(patterns);
patterns.add<
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class ROCDLDialectLLVMIRTranslationInterface
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
auto *dialect = ROCDL::ROCDLDialect::getLoaded(op);
llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
Expand Down
53 changes: 37 additions & 16 deletions mlir/tools/mlir-tblgen/DialectGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,23 @@ tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) {
/// {2}: The dialect parent class.
static const char *const dialectDeclBeginStr = R"(
class {0} : public ::mlir::{2} {
typedef {0} DialectType;
explicit {0}(::mlir::MLIRContext *context);

void initialize();
friend class ::mlir::MLIRContext;
public:
~{0}() override;
static constexpr ::llvm::StringLiteral getDialectNamespace() {
static constexpr ::llvm::StringLiteral getDialectNamespace() {{
return ::llvm::StringLiteral("{1}");
}
static const DialectType *getLoaded(::mlir::MLIRContext &context) {{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not understanding the reasoning behind these changes. When sending a commit for review, please add context as to "why" a change is being proposed.

return context.getLoadedDialect<DialectType>();
}
static const DialectType *getLoaded(::mlir::MLIRContext *context) {{
return getLoaded(*context);
}
static const DialectType *getLoaded(::mlir::Operation *operation);
)";

/// Registration for a single dependent dialect: to be inserted in the ctor
Expand Down Expand Up @@ -206,28 +214,28 @@ static const char *const discardableAttrHelperDecl = R"(
static constexpr ::llvm::StringLiteral getNameStr() {{
return "{4}.{1}";
}
constexpr ::mlir::StringAttr getName() {{
constexpr ::mlir::StringAttr getName() const {{
return name;
}

{0}AttrHelper(::mlir::MLIRContext *ctx)
: name(::mlir::StringAttr::get(ctx, getNameStr())) {{}

{2} getAttr(::mlir::Operation *op) {{
return op->getAttrOfType<{2}>(name);
}
void setAttr(::mlir::Operation *op, {2} val) {{
op->setAttr(name, val);
}
bool isAttrPresent(::mlir::Operation *op) {{
return op->hasAttrOfType<{2}>(name);
}
void removeAttr(::mlir::Operation *op) {{
assert(op->hasAttrOfType<{2}>(name));
op->removeAttr(name);
}
{2} getAttr(::mlir::Operation *op) const {{
return op->getAttrOfType<{2}>(name);
}
void setAttr(::mlir::Operation *op, {2} val) const {{
op->setAttr(name, val);
}
bool isAttrPresent(::mlir::Operation *op) const {{
return op->hasAttrOfType<{2}>(name);
}
void removeAttr(::mlir::Operation *op) const {{
assert(op->hasAttrOfType<{2}>(name));
op->removeAttr(name);
}
};
{0}AttrHelper get{0}AttrHelper() {
const {0}AttrHelper get{0}AttrHelper() const {
return {3}AttrName;
}
private:
Expand Down Expand Up @@ -342,6 +350,16 @@ static const char *const dialectDestructorStr = R"(

)";

/// The code block to generate a member funcs.
///
/// {0}: The name of the dialect class.
static const char *const dialectStaticMemberDefs = R"(
const {0} *{0}::getLoaded(::mlir::Operation *operation) {{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not convinced of the utility of this method over just casting from the operations dialect. This method is going to be significantly slower than calling getDialect on the operation itself. If we wanted to make this easier, we should rather consider auto-generating a templated version of getDialect for Ops/Attrs/Types that does the casting for you.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also seems like an unrelated change compared to the description of this commit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dialect of the operation may not be the same dialect that owns the attribute though.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method seems ad-hoc to me and does not provide much convenience compared to the current code.
What's wrong with getContext().getLoadedDialect<ROCDL::ROCDLDialect>(); ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing really, just adding for convenience. I was asked on a dialect in a separate project to make usage of AttrHelpers cleaner.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This particular method was added in the cpp file gen so Operation.h would not require inclusion by all Dialects. But it is just an extra convenience, happy to remove.

return getLoaded(*operation->getContext());
}

)";

static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
raw_ostream &os) {
std::string cppClassName = dialect.getCppClassName();
Expand Down Expand Up @@ -388,6 +406,9 @@ static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
discardableAttributesInit);
if (!dialect.hasNonDefaultDestructor())
os << llvm::formatv(dialectDestructorStr, cppClassName);

// Emit member function definitions.
os << llvm::formatv(dialectStaticMemberDefs, cppClassName);
}

static bool emitDialectDefs(const RecordKeeper &records, raw_ostream &os) {
Expand Down
Loading