Skip to content

Commit f6a8985

Browse files
committed
[python] Remove unused MLIR components
We don't need to take everything from MLIR for our python bindings. This change cherry picks the upstream components our compiler depends on. The commit also cleans up some unnecessary code that ends up registering dialects more than once, and surfaces the `register_all_dialects` function to a less obscure location. Signed-off-by: boschmitt <[email protected]>
1 parent 7903c8b commit f6a8985

File tree

6 files changed

+43
-69
lines changed

6 files changed

+43
-69
lines changed

python/cudaq/kernel/ast_bridge.py

-2
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,6 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs):
123123
else:
124124
self.ctx = Context()
125125
register_all_dialects(self.ctx)
126-
quake.register_dialect(self.ctx)
127-
cc.register_dialect(self.ctx)
128126
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)
129127
self.loc = Location.unknown(context=self.ctx)
130128
self.module = Module.create(loc=self.loc)

python/cudaq/kernel/kernel_builder.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737

3838
# We need static initializers to run in the CAPI `ExecutionEngine`,
3939
# so here we run a simple JIT compile at global scope
40-
with Context():
40+
with Context() as ctx:
41+
register_all_dialects(ctx)
4142
module = Module.parse(r"""
4243
llvm.func @none() {
4344
llvm.return
@@ -246,8 +247,6 @@ class PyKernel(object):
246247
def __init__(self, argTypeList):
247248
self.ctx = Context()
248249
register_all_dialects(self.ctx)
249-
quake.register_dialect(self.ctx)
250-
cc.register_dialect(self.ctx)
251250
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)
252251

253252
self.metadata = {'conditionalOnMeasure': False}

python/cudaq/mlir/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# ============================================================================ #
2+
# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. #
3+
# All rights reserved. #
4+
# #
5+
# This source code and the accompanying materials are made available under #
6+
# the terms of the Apache License 2.0 which accompanies this distribution. #
7+
# ============================================================================ #
8+
9+
from ._mlir_libs._quakeDialects import register_all_dialects

python/extension/CMakeLists.txt

+15-7
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,14 @@ add_mlir_python_common_capi_library(CUDAQuantumPythonCAPI
119119
RELATIVE_INSTALL_ROOT "../.."
120120
DECLARED_SOURCES
121121
CUDAQuantumPythonSources
122-
# TODO: Remove this in favor of showing fine grained registration once
123-
# available.
124-
MLIRPythonExtension.RegisterEverything
125122
MLIRPythonSources.Core
123+
MLIRPythonSources.Dialects.arith
124+
MLIRPythonSources.Dialects.builtin
125+
MLIRPythonSources.Dialects.cf
126+
MLIRPythonSources.Dialects.complex
127+
MLIRPythonSources.Dialects.func
128+
MLIRPythonSources.Dialects.math
129+
MLIRPythonSources.ExecutionEngine
126130
)
127131

128132
################################################################################
@@ -134,10 +138,14 @@ add_mlir_python_modules(CUDAQuantumPythonModules
134138
INSTALL_PREFIX "cudaq/mlir"
135139
DECLARED_SOURCES
136140
CUDAQuantumPythonSources
137-
# TODO: Remove this in favor of showing fine grained registration once
138-
# available.
139-
MLIRPythonExtension.RegisterEverything
140-
MLIRPythonSources
141+
MLIRPythonSources.Core
142+
MLIRPythonSources.Dialects.arith
143+
MLIRPythonSources.Dialects.builtin
144+
MLIRPythonSources.Dialects.cf
145+
MLIRPythonSources.Dialects.complex
146+
MLIRPythonSources.Dialects.func
147+
MLIRPythonSources.Dialects.math
148+
MLIRPythonSources.ExecutionEngine
141149
COMMON_CAPI_LINK_LIBS
142150
CUDAQuantumPythonCAPI
143151
)

python/runtime/mlir/py_register_dialects.cpp

+15-56
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,14 @@
66
* the terms of the Apache License 2.0 which accompanies this distribution. *
77
******************************************************************************/
88

9-
#include "mlir/Bindings/Python/PybindAdaptors.h"
10-
9+
#include "py_register_dialects.h"
1110
#include "cudaq/Optimizer/Builder/Intrinsics.h"
12-
#include "cudaq/Optimizer/CAPI/Dialects.h"
13-
#include "cudaq/Optimizer/CodeGen/Passes.h"
14-
#include "cudaq/Optimizer/CodeGen/Pipelines.h"
15-
#include "cudaq/Optimizer/Dialect/CC/CCDialect.h"
16-
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
1711
#include "cudaq/Optimizer/Dialect/CC/CCTypes.h"
18-
#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
1912
#include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
20-
#include "cudaq/Optimizer/Transforms/Passes.h"
21-
#include "mlir/InitAllDialects.h"
13+
#include "cudaq/Optimizer/InitAllDialects.h"
14+
#include "cudaq/Optimizer/InitAllPasses.h"
15+
#include "mlir/Bindings/Python/PybindAdaptors.h"
16+
#include "mlir/CAPI/IR.h"
2217
#include <fmt/core.h>
2318
#include <pybind11/complex.h>
2419
#include <pybind11/stl.h>
@@ -27,33 +22,9 @@ namespace py = pybind11;
2722
using namespace mlir::python::adaptors;
2823
using namespace mlir;
2924

30-
namespace cudaq {
31-
static bool registered = false;
32-
33-
void registerQuakeDialectAndTypes(py::module &m) {
25+
static void registerQuakeTypes(py::module &m) {
3426
auto quakeMod = m.def_submodule("quake");
3527

36-
quakeMod.def(
37-
"register_dialect",
38-
[](MlirContext context, bool load) {
39-
MlirDialectHandle handle = mlirGetDialectHandle__quake__();
40-
mlirDialectHandleRegisterDialect(handle, context);
41-
if (load) {
42-
mlirDialectHandleLoadDialect(handle, context);
43-
}
44-
45-
if (!registered) {
46-
cudaq::opt::registerOptCodeGenPasses();
47-
cudaq::opt::registerOptTransformsPasses();
48-
cudaq::opt::registerAggressiveEarlyInlining();
49-
cudaq::opt::registerUnrollingPipeline();
50-
cudaq::opt::registerTargetPipelines();
51-
cudaq::opt::registerMappingPipeline();
52-
registered = true;
53-
}
54-
},
55-
py::arg("context") = py::none(), py::arg("load") = true);
56-
5728
mlir_type_subclass(quakeMod, "RefType", [](MlirType type) {
5829
return unwrap(type).isa<quake::RefType>();
5930
}).def_classmethod("get", [](py::object cls, MlirContext ctx) {
@@ -144,21 +115,10 @@ void registerQuakeDialectAndTypes(py::module &m) {
144115
});
145116
}
146117

147-
void registerCCDialectAndTypes(py::module &m) {
118+
static void registerCCTypes(py::module &m) {
148119

149120
auto ccMod = m.def_submodule("cc");
150121

151-
ccMod.def(
152-
"register_dialect",
153-
[](MlirContext context, bool load) {
154-
MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__();
155-
mlirDialectHandleRegisterDialect(ccHandle, context);
156-
if (load) {
157-
mlirDialectHandleLoadDialect(ccHandle, context);
158-
}
159-
},
160-
py::arg("context") = py::none(), py::arg("load") = true);
161-
162122
mlir_type_subclass(ccMod, "CharspanType", [](MlirType type) {
163123
return unwrap(type).isa<cudaq::cc::CharspanType>();
164124
}).def_classmethod("get", [](py::object cls, MlirContext ctx) {
@@ -298,10 +258,7 @@ void registerCCDialectAndTypes(py::module &m) {
298258
});
299259
}
300260

301-
void bindRegisterDialects(py::module &mod) {
302-
registerQuakeDialectAndTypes(mod);
303-
registerCCDialectAndTypes(mod);
304-
261+
void cudaq::bindRegisterDialects(py::module &mod) {
305262
mod.def("load_intrinsic", [](MlirModule module, std::string name) {
306263
auto unwrapped = unwrap(module);
307264
cudaq::IRBuilder builder = IRBuilder::atBlockEnd(unwrapped.getBody());
@@ -311,14 +268,17 @@ void bindRegisterDialects(py::module &mod) {
311268

312269
mod.def("register_all_dialects", [](MlirContext context) {
313270
DialectRegistry registry;
314-
registry.insert<quake::QuakeDialect, cudaq::cc::CCDialect>();
315-
cudaq::opt::registerCodeGenDialect(registry);
316-
registerAllDialects(registry);
317-
auto *mlirContext = unwrap(context);
271+
cudaq::registerAllDialects(registry);
272+
MLIRContext *mlirContext = unwrap(context);
318273
mlirContext->appendDialectRegistry(registry);
319274
mlirContext->loadAllAvailableDialects();
320275
});
321276

277+
// Register type as passes once, when the module is loaded.
278+
registerQuakeTypes(mod);
279+
registerCCTypes(mod);
280+
cudaq::registerAllPasses();
281+
322282
mod.def("gen_vector_of_complex_constant", [](MlirLocation loc,
323283
MlirModule module,
324284
std::string name,
@@ -330,4 +290,3 @@ void bindRegisterDialects(py::module &mod) {
330290
builder.genVectorOfConstants(unwrap(loc), modOp, name, newValues);
331291
});
332292
}
333-
} // namespace cudaq

python/tests/mlir/bare.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88

99
# RUN: PYTHONPATH=../../ python3 %s | FileCheck %s
1010

11+
from cudaq.mlir import register_all_dialects
1112
from cudaq.mlir.ir import *
1213
from cudaq.mlir.dialects import quake
1314
from cudaq.mlir.dialects import builtin, func, arith
1415

1516
with Context() as ctx:
16-
quake.register_dialect()
17+
register_all_dialects(ctx)
1718
m = Module.create(loc=Location.unknown())
1819
with InsertionPoint(m.body), Location.unknown():
1920
f = func.FuncOp('main', ([], []))

0 commit comments

Comments
 (0)