Skip to content

[MLIR] Add ComplexTOROCDL pass #144926

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

Conversation

TIFitis
Copy link
Member

@TIFitis TIFitis commented Jun 19, 2025

This patch adds a new ComplexToROCDL pass to convert complex.abs operations to __ocml_cabs_f32/__ocml_cabs_f64 calls.

This patch adds a new ComplexToROCDL pass to convert complex.abs operations to __ocml_cabs_f32/__ocml_cabs_f64 calls.
@TIFitis TIFitis requested review from krzysz00, jsjodin and Copilot June 19, 2025 16:37
@llvmbot llvmbot added mlir flang Flang issues not falling into any other category flang:fir-hlfir flang:codegen labels Jun 19, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-flang-codegen

@llvm/pr-subscribers-flang-fir-hlfir

Author: Akash Banerjee (TIFitis)

Changes

This patch adds a new ComplexToROCDL pass to convert complex.abs operations to __ocml_cabs_f32/__ocml_cabs_f64 calls.


Full diff: https://github.com/llvm/llvm-project/pull/144926.diff

9 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/CMakeLists.txt (+1)
  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+4-1)
  • (added) mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h (+19)
  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+12)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt (+18)
  • (added) mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp (+95)
  • (added) mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir (+13)
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 980307db315d9..8b4ac18fba527 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -40,6 +40,7 @@ add_flang_library(FIRCodeGen
   MLIRMathToLLVM
   MLIRMathToLibm
   MLIRMathToROCDL
+  MLIRComplexToROCDL
   MLIROpenMPToLLVM
   MLIROpenACCDialect
   MLIRBuiltinToLLVMIRTranslation
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index a3de3ae9d116a..f721b6232b0fb 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -33,6 +33,7 @@
 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
@@ -4105,8 +4106,10 @@ class FIRToLLVMLowering
     // GPU library calls, the rest can be converted to LLVM intrinsics, which
     // is handled in the mathToLLVM conversion. The lowering to libm calls is
     // not needed since all math operations are handled this way.
-    if (isAMDGCN)
+    if (isAMDGCN) {
       mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
+      mathConvertionPM.addPass(mlir::createConvertComplexToROCDL());
+    }
 
     // Convert math::FPowI operations to inline implementation
     // only if the exponent's width is greater than 32, otherwise,
diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
new file mode 100644
index 0000000000000..ed65be9980408
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
@@ -0,0 +1,19 @@
+#ifndef MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+#define MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTCOMPLEXTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Complex to ROCDL
+/// calls.
+void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index c9d2a54433736..67e8f5b99b67b 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
 #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
 #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b496ee0114910..8ad2341f93a15 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -312,6 +312,18 @@ def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> {
   let dependentDialects = ["func::FuncDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ComplexToROCDL
+//===----------------------------------------------------------------------===//
+
+def ConvertComplexToROCDL : Pass<"convert-complex-to-rocdl", "ModuleOp"> {
+  let summary = "Convert Complex dialect to ROCDL calls";
+  let description = [{
+    This pass converts supported Complex ops to calls to the AMD device library.
+  }];
+  let dependentDialects = ["func::FuncDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ComplexToSPIRV
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e4b4974600577..4ad81553a4fa8 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -13,6 +13,7 @@ add_subdirectory(AsyncToLLVM)
 add_subdirectory(BufferizationToMemRef)
 add_subdirectory(ComplexCommon)
 add_subdirectory(ComplexToLibm)
+add_subdirectory(ComplexToROCDL)
 add_subdirectory(ComplexToLLVM)
 add_subdirectory(ComplexToSPIRV)
 add_subdirectory(ComplexToStandard)
diff --git a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
new file mode 100644
index 0000000000000..54607250083d7
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRComplexToROCDL
+  ComplexToROCDL.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToROCDL
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRComplexDialect
+  MLIRFuncDialect
+  MLIRPass
+  MLIRTransformUtils
+  )
diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
new file mode 100644
index 0000000000000..cdfe2a6dfe874
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
@@ -0,0 +1,95 @@
+//===-- ComplexToROCDL.cpp - conversion from Complex to ROCDL calls -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
+
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct FloatTypeResolver {
+  std::optional<bool> operator()(Type type) const {
+    auto elementType = cast<FloatType>(type);
+    if (!isa<Float32Type, Float64Type>(elementType))
+      return {};
+    return elementType.getIntOrFloatBitWidth() == 64;
+  }
+};
+
+template <typename Op, typename TypeResolver = FloatTypeResolver>
+struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
+  using OpRewritePattern<Op>::OpRewritePattern;
+  ScalarOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
+                      StringRef doubleFunc, PatternBenefit benefit)
+      : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+        doubleFunc(doubleFunc) {}
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
+    auto module = SymbolTable::getNearestSymbolTable(op);
+    auto isDouble = TypeResolver()(op.getType());
+    if (!isDouble.has_value())
+      return failure();
+
+    auto name = *isDouble ? doubleFunc : floatFunc;
+
+    auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
+        SymbolTable::lookupSymbolIn(module, name));
+    if (!opFunc) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&module->getRegion(0).front());
+      auto funcTy = FunctionType::get(
+          rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
+      opFunc =
+          rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, funcTy);
+      opFunc.setPrivate();
+    }
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
+                                              op->getOperands());
+    return success();
+  }
+
+private:
+  std::string floatFunc, doubleFunc;
+};
+} // namespace
+
+void mlir::populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  patterns.add<ScalarOpToROCDLCall<complex::AbsOp>>(
+      patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64", benefit);
+}
+
+namespace {
+struct ConvertComplexToROCDLPass
+    : public impl::ConvertComplexToROCDLBase<ConvertComplexToROCDLPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertComplexToROCDLPass::runOnOperation() {
+  auto module = getOperation();
+
+  RewritePatternSet patterns(&getContext());
+  populateComplexToROCDLConversionPatterns(patterns, /*benefit=*/1);
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<func::FuncDialect>();
+  target.addIllegalOp<complex::AbsOp>();
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
new file mode 100644
index 0000000000000..618e9c238378c
--- /dev/null
+++ b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -convert-complex-to-rocdl -canonicalize | FileCheck %s
+
+// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
+// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+
+func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
+  // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%[[F:.*]])
+  %rf = complex.abs %f : complex<f32>
+  // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%[[D:.*]])
+  %rd = complex.abs %d : complex<f64>
+  // CHECK: return %[[RF]], %[[RD]]
+  return %rf, %rd : f32, f64
+}

Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR introduces a new MLIR pass to convert complex.abs operations into corresponding ROCDL library calls (i.e. __ocml_cabs_f32/__ocml_cabs_f64), along with the necessary test and build system support.

  • Added a test to verify proper lowering in mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
  • Implemented the conversion pass in mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp and integrated it in the MLIR and Flang build systems
  • Updated CMake configurations and pass declarations to support the new pass

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir Adds a test verifying the conversion of complex.abs to ROCDL calls
mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp Implements the conversion pass from Complex dialect to ROCDL calls
mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt Configures build for the new conversion library
mlir/lib/Conversion/CMakeLists.txt Registers the ComplexToROCDL subdirectory
mlir/include/mlir/Conversion/Passes.td Adds the pass definitions and documentation for ComplexToROCDL
mlir/include/mlir/Conversion/Passes.h Updates header inclusions to support the new pass
mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h Declares the conversion population function
flang/lib/Optimizer/CodeGen/CodeGen.cpp Integrates the new conversion pass into the code generation pipeline for AMDGCN targets
flang/lib/Optimizer/CodeGen/CMakeLists.txt Adds MLIRComplexToROCDL as a dependency for code generation

Comment on lines 4110 to +4111
mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
mathConvertionPM.addPass(mlir::createConvertComplexToROCDL());
Copy link
Preview

Copilot AI Jun 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable name 'mathConvertionPM' appears to be misspelled. Consider renaming it to 'mathConversionPM' for clarity.

Suggested change
mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
mathConvertionPM.addPass(mlir::createConvertComplexToROCDL());
mathConversionPM.addPass(mlir::createConvertMathToROCDL());
mathConversionPM.addPass(mlir::createConvertComplexToROCDL());

Copilot uses AI. Check for mistakes.

@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-mlir

Author: Akash Banerjee (TIFitis)

Changes

This patch adds a new ComplexToROCDL pass to convert complex.abs operations to __ocml_cabs_f32/__ocml_cabs_f64 calls.


Full diff: https://github.com/llvm/llvm-project/pull/144926.diff

9 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/CMakeLists.txt (+1)
  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+4-1)
  • (added) mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h (+19)
  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+12)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt (+18)
  • (added) mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp (+95)
  • (added) mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir (+13)
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 980307db315d9..8b4ac18fba527 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -40,6 +40,7 @@ add_flang_library(FIRCodeGen
   MLIRMathToLLVM
   MLIRMathToLibm
   MLIRMathToROCDL
+  MLIRComplexToROCDL
   MLIROpenMPToLLVM
   MLIROpenACCDialect
   MLIRBuiltinToLLVMIRTranslation
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index a3de3ae9d116a..f721b6232b0fb 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -33,6 +33,7 @@
 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
@@ -4105,8 +4106,10 @@ class FIRToLLVMLowering
     // GPU library calls, the rest can be converted to LLVM intrinsics, which
     // is handled in the mathToLLVM conversion. The lowering to libm calls is
     // not needed since all math operations are handled this way.
-    if (isAMDGCN)
+    if (isAMDGCN) {
       mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
+      mathConvertionPM.addPass(mlir::createConvertComplexToROCDL());
+    }
 
     // Convert math::FPowI operations to inline implementation
     // only if the exponent's width is greater than 32, otherwise,
diff --git a/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
new file mode 100644
index 0000000000000..ed65be9980408
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h
@@ -0,0 +1,19 @@
+#ifndef MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+#define MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTCOMPLEXTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Complex to ROCDL
+/// calls.
+void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index c9d2a54433736..67e8f5b99b67b 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
 #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
 #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b496ee0114910..8ad2341f93a15 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -312,6 +312,18 @@ def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> {
   let dependentDialects = ["func::FuncDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ComplexToROCDL
+//===----------------------------------------------------------------------===//
+
+def ConvertComplexToROCDL : Pass<"convert-complex-to-rocdl", "ModuleOp"> {
+  let summary = "Convert Complex dialect to ROCDL calls";
+  let description = [{
+    This pass converts supported Complex ops to calls to the AMD device library.
+  }];
+  let dependentDialects = ["func::FuncDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ComplexToSPIRV
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e4b4974600577..4ad81553a4fa8 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -13,6 +13,7 @@ add_subdirectory(AsyncToLLVM)
 add_subdirectory(BufferizationToMemRef)
 add_subdirectory(ComplexCommon)
 add_subdirectory(ComplexToLibm)
+add_subdirectory(ComplexToROCDL)
 add_subdirectory(ComplexToLLVM)
 add_subdirectory(ComplexToSPIRV)
 add_subdirectory(ComplexToStandard)
diff --git a/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
new file mode 100644
index 0000000000000..54607250083d7
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRComplexToROCDL
+  ComplexToROCDL.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToROCDL
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRComplexDialect
+  MLIRFuncDialect
+  MLIRPass
+  MLIRTransformUtils
+  )
diff --git a/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
new file mode 100644
index 0000000000000..cdfe2a6dfe874
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp
@@ -0,0 +1,95 @@
+//===-- ComplexToROCDL.cpp - conversion from Complex to ROCDL calls -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
+
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct FloatTypeResolver {
+  std::optional<bool> operator()(Type type) const {
+    auto elementType = cast<FloatType>(type);
+    if (!isa<Float32Type, Float64Type>(elementType))
+      return {};
+    return elementType.getIntOrFloatBitWidth() == 64;
+  }
+};
+
+template <typename Op, typename TypeResolver = FloatTypeResolver>
+struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
+  using OpRewritePattern<Op>::OpRewritePattern;
+  ScalarOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
+                      StringRef doubleFunc, PatternBenefit benefit)
+      : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+        doubleFunc(doubleFunc) {}
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
+    auto module = SymbolTable::getNearestSymbolTable(op);
+    auto isDouble = TypeResolver()(op.getType());
+    if (!isDouble.has_value())
+      return failure();
+
+    auto name = *isDouble ? doubleFunc : floatFunc;
+
+    auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
+        SymbolTable::lookupSymbolIn(module, name));
+    if (!opFunc) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&module->getRegion(0).front());
+      auto funcTy = FunctionType::get(
+          rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
+      opFunc =
+          rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, funcTy);
+      opFunc.setPrivate();
+    }
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
+                                              op->getOperands());
+    return success();
+  }
+
+private:
+  std::string floatFunc, doubleFunc;
+};
+} // namespace
+
+void mlir::populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  patterns.add<ScalarOpToROCDLCall<complex::AbsOp>>(
+      patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64", benefit);
+}
+
+namespace {
+struct ConvertComplexToROCDLPass
+    : public impl::ConvertComplexToROCDLBase<ConvertComplexToROCDLPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertComplexToROCDLPass::runOnOperation() {
+  auto module = getOperation();
+
+  RewritePatternSet patterns(&getContext());
+  populateComplexToROCDLConversionPatterns(patterns, /*benefit=*/1);
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<func::FuncDialect>();
+  target.addIllegalOp<complex::AbsOp>();
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
new file mode 100644
index 0000000000000..618e9c238378c
--- /dev/null
+++ b/mlir/test/Conversion/ComplexToROCDL/complex-to-rocdl.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -convert-complex-to-rocdl -canonicalize | FileCheck %s
+
+// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
+// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+
+func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
+  // CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%[[F:.*]])
+  %rf = complex.abs %f : complex<f32>
+  // CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%[[D:.*]])
+  %rd = complex.abs %d : complex<f64>
+  // CHECK: return %[[RF]], %[[RD]]
+  return %rf, %rd : f32, f64
+}

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name of the pass should clearly indicate what it does: convert operations to function calls. The current name suggests that it converts operations to operations from the ROCDL dialect.

Comment on lines +10 to +11
LINK_COMPONENTS
Core
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be needed.

using namespace mlir;

namespace {
struct FloatTypeResolver {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please document top-level entities.

doubleFunc(doubleFunc) {}

LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
auto module = SymbolTable::getNearestSymbolTable(op);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expand auto unless the type is obvious from statement-level context or impossible to spell.

Also avoid using module as a name since it will become a reserved keyword with a C++ version bump.

}
};

template <typename Op, typename TypeResolver = FloatTypeResolver>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The complexity with a template parameter looks unnecessary to perform a dynamic bitwidth check looks unnecessary.

@@ -0,0 +1,13 @@
// RUN: mlir-opt %s -convert-complex-to-rocdl -canonicalize | FileCheck %s
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make tests minimal by not running passes that are not being tested, such as the canonicalizer.

@ftynse
Copy link
Member

ftynse commented Jun 19, 2025

Also note that the NVGPU pipeline is doing something similar and it makes sense to align with that and reuse utilities like those from mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:codegen flang:fir-hlfir flang Flang issues not falling into any other category mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants