-
Notifications
You must be signed in to change notification settings - Fork 14.2k
[MLIR] Add ComplexTOROCDL pass #144926
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[MLIR] Add ComplexTOROCDL pass #144926
Conversation
This patch adds a new ComplexToROCDL pass to convert complex.abs operations to __ocml_cabs_f32/__ocml_cabs_f64 calls.
@llvm/pr-subscribers-flang-codegen @llvm/pr-subscribers-flang-fir-hlfir Author: Akash Banerjee (TIFitis) ChangesThis patch adds a new ComplexToROCDL pass to convert complex.abs operations to __ocml_cabs_f32/__ocml_cabs_f64 calls. Full diff: https://github.com/llvm/llvm-project/pull/144926.diff 9 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 980307db315d9..8b4ac18fba527 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -40,6 +40,7 @@ add_flang_library(FIRCodeGen
MLIRMathToLLVM
MLIRMathToLibm
MLIRMathToROCDL
+ MLIRComplexToROCDL
MLIROpenMPToLLVM
MLIROpenACCDialect
MLIRBuiltinToLLVMIRTranslation
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index a3de3ae9d116a..f721b6232b0fb 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -33,6 +33,7 @@
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
@@ -4105,8 +4106,10 @@ class FIRToLLVMLowering
// GPU library calls, the rest can be converted to LLVM intrinsics, which
// is handled in the mathToLLVM conversion. The lowering to libm calls is
// not needed since all math operations are handled this way.
- if (isAMDGCN)
+ if (isAMDGCN) {
mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
+ mathConvertionPM.addPass(mlir::createConvertComplexToROCDL());
+ }
// Convert math::FPowI operations to inline implementation
// only if the exponent's width is greater than 32, otherwise,
diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
new file mode 100644
index 0000000000000..ed65be9980408
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
@@ -0,0 +1,19 @@
+#ifndef MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+#define MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTCOMPLEXTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Complex to ROCDL
+/// calls.
+void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index c9d2a54433736..67e8f5b99b67b 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
#include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b496ee0114910..8ad2341f93a15 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -312,6 +312,18 @@ def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> {
let dependentDialects = ["func::FuncDialect"];
}
+//===----------------------------------------------------------------------===//
+// ComplexToROCDL
+//===----------------------------------------------------------------------===//
+
+def ConvertComplexToROCDL : Pass<"convert-complex-to-rocdl", "ModuleOp"> {
+ let summary = "Convert Complex dialect to ROCDL calls";
+ let description = [{
+ This pass converts supported Complex ops to calls to the AMD device library.
+ }];
+ let dependentDialects = ["func::FuncDialect"];
+}
+
//===----------------------------------------------------------------------===//
// ComplexToSPIRV
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e4b4974600577..4ad81553a4fa8 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -13,6 +13,7 @@ add_subdirectory(AsyncToLLVM)
add_subdirectory(BufferizationToMemRef)
add_subdirectory(ComplexCommon)
add_subdirectory(ComplexToLibm)
+add_subdirectory(ComplexToROCDL)
add_subdirectory(ComplexToLLVM)
add_subdirectory(ComplexToSPIRV)
add_subdirectory(ComplexToStandard)
diff --git a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
new file mode 100644
index 0000000000000..54607250083d7
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRComplexToROCDL
+ ComplexToROCDL.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToROCDL
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRComplexDialect
+ MLIRFuncDialect
+ MLIRPass
+ MLIRTransformUtils
+ )
diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
new file mode 100644
index 0000000000000..cdfe2a6dfe874
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
@@ -0,0 +1,95 @@
+//===-- ComplexToROCDL.cpp - conversion from Complex to ROCDL calls -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
+
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct FloatTypeResolver {
+ std::optional<bool> operator()(Type type) const {
+ auto elementType = cast<FloatType>(type);
+ if (!isa<Float32Type, Float64Type>(elementType))
+ return {};
+ return elementType.getIntOrFloatBitWidth() == 64;
+ }
+};
+
+template <typename Op, typename TypeResolver = FloatTypeResolver>
+struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
+ using OpRewritePattern<Op>::OpRewritePattern;
+ ScalarOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
+ StringRef doubleFunc, PatternBenefit benefit)
+ : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+ doubleFunc(doubleFunc) {}
+
+ LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
+ auto module = SymbolTable::getNearestSymbolTable(op);
+ auto isDouble = TypeResolver()(op.getType());
+ if (!isDouble.has_value())
+ return failure();
+
+ auto name = *isDouble ? doubleFunc : floatFunc;
+
+ auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
+ SymbolTable::lookupSymbolIn(module, name));
+ if (!opFunc) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(&module->getRegion(0).front());
+ auto funcTy = FunctionType::get(
+ rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
+ opFunc =
+ rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, funcTy);
+ opFunc.setPrivate();
+ }
+ rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
+ op->getOperands());
+ return success();
+ }
+
+private:
+ std::string floatFunc, doubleFunc;
+};
+} // namespace
+
+void mlir::populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ patterns.add<ScalarOpToROCDLCall<complex::AbsOp>>(
+ patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64", benefit);
+}
+
+namespace {
+struct ConvertComplexToROCDLPass
+ : public impl::ConvertComplexToROCDLBase<ConvertComplexToROCDLPass> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertComplexToROCDLPass::runOnOperation() {
+ auto module = getOperation();
+
+ RewritePatternSet patterns(&getContext());
+ populateComplexToROCDLConversionPatterns(patterns, /*benefit=*/1);
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect<func::FuncDialect>();
+ target.addIllegalOp<complex::AbsOp>();
+ if (failed(applyPartialConversion(module, target, std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
new file mode 100644
index 0000000000000..618e9c238378c
--- /dev/null
+++ b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -convert-complex-to-rocdl -canonicalize | FileCheck %s
+
+// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
+// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+
+func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
+ // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%[[F:.*]])
+ %rf = complex.abs %f : complex<f32>
+ // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%[[D:.*]])
+ %rd = complex.abs %d : complex<f64>
+ // CHECK: return %[[RF]], %[[RD]]
+ return %rf, %rd : f32, f64
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR introduces a new MLIR pass to convert complex.abs operations into corresponding ROCDL library calls (i.e. __ocml_cabs_f32/__ocml_cabs_f64), along with the necessary test and build system support.
- Added a test to verify proper lowering in mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
- Implemented the conversion pass in mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp and integrated it in the MLIR and Flang build systems
- Updated CMake configurations and pass declarations to support the new pass
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.
Show a summary per file
File | Description |
---|---|
mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir | Adds a test verifying the conversion of complex.abs to ROCDL calls |
mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp | Implements the conversion pass from Complex dialect to ROCDL calls |
mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt | Configures build for the new conversion library |
mlir/lib/Conversion/CMakeLists.txt | Registers the ComplexToROCDL subdirectory |
mlir/include/mlir/Conversion/Passes.td | Adds the pass definitions and documentation for ComplexToROCDL |
mlir/include/mlir/Conversion/Passes.h | Updates header inclusions to support the new pass |
mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h | Declares the conversion population function |
flang/lib/Optimizer/CodeGen/CodeGen.cpp | Integrates the new conversion pass into the code generation pipeline for AMDGCN targets |
flang/lib/Optimizer/CodeGen/CMakeLists.txt | Adds MLIRComplexToROCDL as a dependency for code generation |
mathConvertionPM.addPass(mlir::createConvertMathToROCDL()); | ||
mathConvertionPM.addPass(mlir::createConvertComplexToROCDL()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable name 'mathConvertionPM' appears to be misspelled. Consider renaming it to 'mathConversionPM' for clarity.
mathConvertionPM.addPass(mlir::createConvertMathToROCDL()); | |
mathConvertionPM.addPass(mlir::createConvertComplexToROCDL()); | |
mathConversionPM.addPass(mlir::createConvertMathToROCDL()); | |
mathConversionPM.addPass(mlir::createConvertComplexToROCDL()); |
Copilot uses AI. Check for mistakes.
@llvm/pr-subscribers-mlir Author: Akash Banerjee (TIFitis) ChangesThis patch adds a new ComplexToROCDL pass to convert complex.abs operations to __ocml_cabs_f32/__ocml_cabs_f64 calls. Full diff: https://github.com/llvm/llvm-project/pull/144926.diff 9 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 980307db315d9..8b4ac18fba527 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -40,6 +40,7 @@ add_flang_library(FIRCodeGen
MLIRMathToLLVM
MLIRMathToLibm
MLIRMathToROCDL
+ MLIRComplexToROCDL
MLIROpenMPToLLVM
MLIROpenACCDialect
MLIRBuiltinToLLVMIRTranslation
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index a3de3ae9d116a..f721b6232b0fb 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -33,6 +33,7 @@
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
@@ -4105,8 +4106,10 @@ class FIRToLLVMLowering
// GPU library calls, the rest can be converted to LLVM intrinsics, which
// is handled in the mathToLLVM conversion. The lowering to libm calls is
// not needed since all math operations are handled this way.
- if (isAMDGCN)
+ if (isAMDGCN) {
mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
+ mathConvertionPM.addPass(mlir::createConvertComplexToROCDL());
+ }
// Convert math::FPowI operations to inline implementation
// only if the exponent's width is greater than 32, otherwise,
diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
new file mode 100644
index 0000000000000..ed65be9980408
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
@@ -0,0 +1,19 @@
+#ifndef MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+#define MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTCOMPLEXTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Complex to ROCDL
+/// calls.
+void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index c9d2a54433736..67e8f5b99b67b 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
#include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b496ee0114910..8ad2341f93a15 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -312,6 +312,18 @@ def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> {
let dependentDialects = ["func::FuncDialect"];
}
+//===----------------------------------------------------------------------===//
+// ComplexToROCDL
+//===----------------------------------------------------------------------===//
+
+def ConvertComplexToROCDL : Pass<"convert-complex-to-rocdl", "ModuleOp"> {
+ let summary = "Convert Complex dialect to ROCDL calls";
+ let description = [{
+ This pass converts supported Complex ops to calls to the AMD device library.
+ }];
+ let dependentDialects = ["func::FuncDialect"];
+}
+
//===----------------------------------------------------------------------===//
// ComplexToSPIRV
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e4b4974600577..4ad81553a4fa8 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -13,6 +13,7 @@ add_subdirectory(AsyncToLLVM)
add_subdirectory(BufferizationToMemRef)
add_subdirectory(ComplexCommon)
add_subdirectory(ComplexToLibm)
+add_subdirectory(ComplexToROCDL)
add_subdirectory(ComplexToLLVM)
add_subdirectory(ComplexToSPIRV)
add_subdirectory(ComplexToStandard)
diff --git a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
new file mode 100644
index 0000000000000..54607250083d7
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRComplexToROCDL
+ ComplexToROCDL.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToROCDL
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRComplexDialect
+ MLIRFuncDialect
+ MLIRPass
+ MLIRTransformUtils
+ )
diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
new file mode 100644
index 0000000000000..cdfe2a6dfe874
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
@@ -0,0 +1,95 @@
+//===-- ComplexToROCDL.cpp - conversion from Complex to ROCDL calls -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
+
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct FloatTypeResolver {
+ std::optional<bool> operator()(Type type) const {
+ auto elementType = cast<FloatType>(type);
+ if (!isa<Float32Type, Float64Type>(elementType))
+ return {};
+ return elementType.getIntOrFloatBitWidth() == 64;
+ }
+};
+
+template <typename Op, typename TypeResolver = FloatTypeResolver>
+struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
+ using OpRewritePattern<Op>::OpRewritePattern;
+ ScalarOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
+ StringRef doubleFunc, PatternBenefit benefit)
+ : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+ doubleFunc(doubleFunc) {}
+
+ LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
+ auto module = SymbolTable::getNearestSymbolTable(op);
+ auto isDouble = TypeResolver()(op.getType());
+ if (!isDouble.has_value())
+ return failure();
+
+ auto name = *isDouble ? doubleFunc : floatFunc;
+
+ auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
+ SymbolTable::lookupSymbolIn(module, name));
+ if (!opFunc) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(&module->getRegion(0).front());
+ auto funcTy = FunctionType::get(
+ rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
+ opFunc =
+ rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, funcTy);
+ opFunc.setPrivate();
+ }
+ rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
+ op->getOperands());
+ return success();
+ }
+
+private:
+ std::string floatFunc, doubleFunc;
+};
+} // namespace
+
+void mlir::populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ patterns.add<ScalarOpToROCDLCall<complex::AbsOp>>(
+ patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64", benefit);
+}
+
+namespace {
+struct ConvertComplexToROCDLPass
+ : public impl::ConvertComplexToROCDLBase<ConvertComplexToROCDLPass> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertComplexToROCDLPass::runOnOperation() {
+ auto module = getOperation();
+
+ RewritePatternSet patterns(&getContext());
+ populateComplexToROCDLConversionPatterns(patterns, /*benefit=*/1);
+
+ ConversionTarget target(getContext());
+ target.addLegalDialect<func::FuncDialect>();
+ target.addIllegalOp<complex::AbsOp>();
+ if (failed(applyPartialConversion(module, target, std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
new file mode 100644
index 0000000000000..618e9c238378c
--- /dev/null
+++ b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -convert-complex-to-rocdl -canonicalize | FileCheck %s
+
+// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
+// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+
+func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
+ // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%[[F:.*]])
+ %rf = complex.abs %f : complex<f32>
+ // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%[[D:.*]])
+ %rd = complex.abs %d : complex<f64>
+ // CHECK: return %[[RF]], %[[RD]]
+ return %rf, %rd : f32, f64
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The name of the pass should clearly indicate what it does: convert operations to function calls. The current name suggests that it converts operations to operations from the ROCDL dialect.
LINK_COMPONENTS | ||
Core |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This shouldn't be needed.
using namespace mlir; | ||
|
||
namespace { | ||
struct FloatTypeResolver { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please document top-level entities.
doubleFunc(doubleFunc) {} | ||
|
||
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final { | ||
auto module = SymbolTable::getNearestSymbolTable(op); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Expand auto
unless the type is obvious from statement-level context or impossible to spell.
Also avoid using module
as a name since it will become a reserved keyword with a C++ version bump.
} | ||
}; | ||
|
||
template <typename Op, typename TypeResolver = FloatTypeResolver> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The complexity with a template parameter looks unnecessary to perform a dynamic bitwidth check looks unnecessary.
@@ -0,0 +1,13 @@ | |||
// RUN: mlir-opt %s -convert-complex-to-rocdl -canonicalize | FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please make tests minimal by not running passes that are not being tested, such as the canonicalizer.
Also note that the NVGPU pipeline is doing something similar and it makes sense to align with that and reuse utilities like those from mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h. |
This patch adds a new ComplexToROCDL pass to convert complex.abs operations to __ocml_cabs_f32/__ocml_cabs_f64 calls.