Skip to content

Commit a156bdc

Browse files
committed
[python] Remove unused MLIR components
We don't need to take everything from MLIR for our python bindings. This change cherry picks the upstream components our compiler depends on. The commit also cleans up some unnecessary code that ends up registering dialects more than once, and surfaces the `register_all_dialects` function to a less obscure location. Signed-off-by: boschmitt <[email protected]>
1 parent 3b0f04c commit a156bdc

File tree

6 files changed

+54
-65
lines changed

6 files changed

+54
-65
lines changed

python/cudaq/kernel/ast_bridge.py

-2
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,6 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs):
123123
else:
124124
self.ctx = Context()
125125
register_all_dialects(self.ctx)
126-
quake.register_dialect(self.ctx)
127-
cc.register_dialect(self.ctx)
128126
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)
129127
self.loc = Location.unknown(context=self.ctx)
130128
self.module = Module.create(loc=self.loc)

python/cudaq/kernel/kernel_builder.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737

3838
# We need static initializers to run in the CAPI `ExecutionEngine`,
3939
# so here we run a simple JIT compile at global scope
40-
with Context():
40+
with Context() as ctx:
41+
register_all_dialects(ctx)
4142
module = Module.parse(r"""
4243
llvm.func @none() {
4344
llvm.return
@@ -246,8 +247,6 @@ class PyKernel(object):
246247
def __init__(self, argTypeList):
247248
self.ctx = Context()
248249
register_all_dialects(self.ctx)
249-
quake.register_dialect(self.ctx)
250-
cc.register_dialect(self.ctx)
251250
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)
252251

253252
self.metadata = {'conditionalOnMeasure': False}

python/cudaq/mlir/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# ============================================================================ #
2+
# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. #
3+
# All rights reserved. #
4+
# #
5+
# This source code and the accompanying materials are made available under #
6+
# the terms of the Apache License 2.0 which accompanies this distribution. #
7+
# ============================================================================ #
8+
9+
from ._mlir_libs._quakeDialects import register_all_dialects

python/extension/CMakeLists.txt

+15-7
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,14 @@ add_mlir_python_common_capi_library(CUDAQuantumPythonCAPI
119119
RELATIVE_INSTALL_ROOT "../.."
120120
DECLARED_SOURCES
121121
CUDAQuantumPythonSources
122-
# TODO: Remove this in favor of showing fine grained registration once
123-
# available.
124-
MLIRPythonExtension.RegisterEverything
125122
MLIRPythonSources.Core
123+
MLIRPythonSources.Dialects.arith
124+
MLIRPythonSources.Dialects.builtin
125+
MLIRPythonSources.Dialects.cf
126+
MLIRPythonSources.Dialects.complex
127+
MLIRPythonSources.Dialects.func
128+
MLIRPythonSources.Dialects.math
129+
MLIRPythonSources.ExecutionEngine
126130
)
127131

128132
################################################################################
@@ -134,10 +138,14 @@ add_mlir_python_modules(CUDAQuantumPythonModules
134138
INSTALL_PREFIX "cudaq/mlir"
135139
DECLARED_SOURCES
136140
CUDAQuantumPythonSources
137-
# TODO: Remove this in favor of showing fine grained registration once
138-
# available.
139-
MLIRPythonExtension.RegisterEverything
140-
MLIRPythonSources
141+
MLIRPythonSources.Core
142+
MLIRPythonSources.Dialects.arith
143+
MLIRPythonSources.Dialects.builtin
144+
MLIRPythonSources.Dialects.cf
145+
MLIRPythonSources.Dialects.complex
146+
MLIRPythonSources.Dialects.func
147+
MLIRPythonSources.Dialects.math
148+
MLIRPythonSources.ExecutionEngine
141149
COMMON_CAPI_LINK_LIBS
142150
CUDAQuantumPythonCAPI
143151
)

python/runtime/mlir/py_register_dialects.cpp

+26-52
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,16 @@
66
* the terms of the Apache License 2.0 which accompanies this distribution. *
77
******************************************************************************/
88

9-
#include "mlir/Bindings/Python/PybindAdaptors.h"
10-
9+
#include "py_register_dialects.h"
1110
#include "cudaq/Optimizer/Builder/Intrinsics.h"
12-
#include "cudaq/Optimizer/CAPI/Dialects.h"
1311
#include "cudaq/Optimizer/CodeGen/Passes.h"
14-
#include "cudaq/Optimizer/CodeGen/Pipelines.h"
15-
#include "cudaq/Optimizer/Dialect/CC/CCDialect.h"
16-
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
1712
#include "cudaq/Optimizer/Dialect/CC/CCTypes.h"
18-
#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
1913
#include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
14+
#include "cudaq/Optimizer/InitAllDialects.h"
2015
#include "cudaq/Optimizer/Transforms/Passes.h"
21-
#include "mlir/InitAllDialects.h"
16+
#include "mlir/Bindings/Python/PybindAdaptors.h"
17+
#include "mlir/CAPI/IR.h"
18+
#include "mlir/Transforms/Passes.h"
2219
#include <fmt/core.h>
2320
#include <pybind11/complex.h>
2421
#include <pybind11/stl.h>
@@ -27,33 +24,9 @@ namespace py = pybind11;
2724
using namespace mlir::python::adaptors;
2825
using namespace mlir;
2926

30-
namespace cudaq {
31-
static bool registered = false;
32-
33-
void registerQuakeDialectAndTypes(py::module &m) {
27+
static void registerQuakeTypes(py::module &m) {
3428
auto quakeMod = m.def_submodule("quake");
3529

36-
quakeMod.def(
37-
"register_dialect",
38-
[](MlirContext context, bool load) {
39-
MlirDialectHandle handle = mlirGetDialectHandle__quake__();
40-
mlirDialectHandleRegisterDialect(handle, context);
41-
if (load) {
42-
mlirDialectHandleLoadDialect(handle, context);
43-
}
44-
45-
if (!registered) {
46-
cudaq::opt::registerOptCodeGenPasses();
47-
cudaq::opt::registerOptTransformsPasses();
48-
cudaq::opt::registerAggressiveEarlyInlining();
49-
cudaq::opt::registerUnrollingPipeline();
50-
cudaq::opt::registerTargetPipelines();
51-
cudaq::opt::registerMappingPipeline();
52-
registered = true;
53-
}
54-
},
55-
py::arg("context") = py::none(), py::arg("load") = true);
56-
5730
mlir_type_subclass(quakeMod, "RefType", [](MlirType type) {
5831
return unwrap(type).isa<quake::RefType>();
5932
}).def_classmethod("get", [](py::object cls, MlirContext ctx) {
@@ -144,21 +117,10 @@ void registerQuakeDialectAndTypes(py::module &m) {
144117
});
145118
}
146119

147-
void registerCCDialectAndTypes(py::module &m) {
120+
static void registerCCTypes(py::module &m) {
148121

149122
auto ccMod = m.def_submodule("cc");
150123

151-
ccMod.def(
152-
"register_dialect",
153-
[](MlirContext context, bool load) {
154-
MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__();
155-
mlirDialectHandleRegisterDialect(ccHandle, context);
156-
if (load) {
157-
mlirDialectHandleLoadDialect(ccHandle, context);
158-
}
159-
},
160-
py::arg("context") = py::none(), py::arg("load") = true);
161-
162124
mlir_type_subclass(ccMod, "CharspanType", [](MlirType type) {
163125
return unwrap(type).isa<cudaq::cc::CharspanType>();
164126
}).def_classmethod("get", [](py::object cls, MlirContext ctx) {
@@ -298,10 +260,9 @@ void registerCCDialectAndTypes(py::module &m) {
298260
});
299261
}
300262

301-
void bindRegisterDialects(py::module &mod) {
302-
registerQuakeDialectAndTypes(mod);
303-
registerCCDialectAndTypes(mod);
263+
static bool registered = false;
304264

265+
void cudaq::bindRegisterDialects(py::module &mod) {
305266
mod.def("load_intrinsic", [](MlirModule module, std::string name) {
306267
auto unwrapped = unwrap(module);
307268
cudaq::IRBuilder builder = IRBuilder::atBlockEnd(unwrapped.getBody());
@@ -311,14 +272,28 @@ void bindRegisterDialects(py::module &mod) {
311272

312273
mod.def("register_all_dialects", [](MlirContext context) {
313274
DialectRegistry registry;
314-
registry.insert<quake::QuakeDialect, cudaq::cc::CCDialect>();
275+
cudaq::registerAllDialects(registry);
315276
cudaq::opt::registerCodeGenDialect(registry);
316-
registerAllDialects(registry);
317-
auto *mlirContext = unwrap(context);
277+
MLIRContext *mlirContext = unwrap(context);
318278
mlirContext->appendDialectRegistry(registry);
319279
mlirContext->loadAllAvailableDialects();
320280
});
321281

282+
// Register type and passes once, when the module is loaded.
283+
registerQuakeTypes(mod);
284+
registerCCTypes(mod);
285+
286+
if (!registered) {
287+
mlir::registerTransformsPasses();
288+
cudaq::opt::registerOptCodeGenPasses();
289+
cudaq::opt::registerOptTransformsPasses();
290+
cudaq::opt::registerAggressiveEarlyInlining();
291+
cudaq::opt::registerUnrollingPipeline();
292+
cudaq::opt::registerTargetPipelines();
293+
cudaq::opt::registerMappingPipeline();
294+
registered = true;
295+
}
296+
322297
mod.def("gen_vector_of_complex_constant", [](MlirLocation loc,
323298
MlirModule module,
324299
std::string name,
@@ -330,4 +305,3 @@ void bindRegisterDialects(py::module &mod) {
330305
builder.genVectorOfConstants(unwrap(loc), modOp, name, newValues);
331306
});
332307
}
333-
} // namespace cudaq

python/tests/mlir/bare.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88

99
# RUN: PYTHONPATH=../../ python3 %s | FileCheck %s
1010

11+
from cudaq.mlir import register_all_dialects
1112
from cudaq.mlir.ir import *
1213
from cudaq.mlir.dialects import quake
1314
from cudaq.mlir.dialects import builtin, func, arith
1415

1516
with Context() as ctx:
16-
quake.register_dialect()
17+
register_all_dialects(ctx)
1718
m = Module.create(loc=Location.unknown())
1819
with InsertionPoint(m.body), Location.unknown():
1920
f = func.FuncOp('main', ([], []))

0 commit comments

Comments
 (0)