Skip to content

Commit 992ba52

Browse files
committed
[python] Remove unused MLIR components
We don't need to take everything from MLIR for our python bindings. This change cherry picks the upstream components our compiler depends on. The commit also cleans up some unnecessary code that ends up registering dialects more than once, and surfaces the `register_all_dialects` function to a less obscure location. Signed-off-by: boschmitt <[email protected]>
1 parent d893046 commit 992ba52

File tree

6 files changed

+48
-66
lines changed

6 files changed

+48
-66
lines changed

python/cudaq/kernel/ast_bridge.py

-2
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,6 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs):
123123
else:
124124
self.ctx = Context()
125125
register_all_dialects(self.ctx)
126-
quake.register_dialect(self.ctx)
127-
cc.register_dialect(self.ctx)
128126
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)
129127
self.loc = Location.unknown(context=self.ctx)
130128
self.module = Module.create(loc=self.loc)

python/cudaq/kernel/kernel_builder.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737

3838
# We need static initializers to run in the CAPI `ExecutionEngine`,
3939
# so here we run a simple JIT compile at global scope
40-
with Context():
40+
with Context() as ctx:
41+
register_all_dialects(ctx)
4142
module = Module.parse(r"""
4243
llvm.func @none() {
4344
llvm.return
@@ -246,8 +247,6 @@ class PyKernel(object):
246247
def __init__(self, argTypeList):
247248
self.ctx = Context()
248249
register_all_dialects(self.ctx)
249-
quake.register_dialect(self.ctx)
250-
cc.register_dialect(self.ctx)
251250
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)
252251

253252
self.metadata = {'conditionalOnMeasure': False}

python/cudaq/mlir/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# ============================================================================ #
2+
# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. #
3+
# All rights reserved. #
4+
# #
5+
# This source code and the accompanying materials are made available under #
6+
# the terms of the Apache License 2.0 which accompanies this distribution. #
7+
# ============================================================================ #
8+
9+
from ._mlir_libs._quakeDialects import register_all_dialects

python/extension/CMakeLists.txt

+15-7
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,14 @@ add_mlir_python_common_capi_library(CUDAQuantumPythonCAPI
119119
RELATIVE_INSTALL_ROOT "../.."
120120
DECLARED_SOURCES
121121
CUDAQuantumPythonSources
122-
# TODO: Remove this in favor of showing fine grained registration once
123-
# available.
124-
MLIRPythonExtension.RegisterEverything
125122
MLIRPythonSources.Core
123+
MLIRPythonSources.Dialects.arith
124+
MLIRPythonSources.Dialects.builtin
125+
MLIRPythonSources.Dialects.cf
126+
MLIRPythonSources.Dialects.complex
127+
MLIRPythonSources.Dialects.func
128+
MLIRPythonSources.Dialects.math
129+
MLIRPythonSources.ExecutionEngine
126130
)
127131

128132
################################################################################
@@ -134,10 +138,14 @@ add_mlir_python_modules(CUDAQuantumPythonModules
134138
INSTALL_PREFIX "cudaq/mlir"
135139
DECLARED_SOURCES
136140
CUDAQuantumPythonSources
137-
# TODO: Remove this in favor of showing fine grained registration once
138-
# available.
139-
MLIRPythonExtension.RegisterEverything
140-
MLIRPythonSources
141+
MLIRPythonSources.Core
142+
MLIRPythonSources.Dialects.arith
143+
MLIRPythonSources.Dialects.builtin
144+
MLIRPythonSources.Dialects.cf
145+
MLIRPythonSources.Dialects.complex
146+
MLIRPythonSources.Dialects.func
147+
MLIRPythonSources.Dialects.math
148+
MLIRPythonSources.ExecutionEngine
141149
COMMON_CAPI_LINK_LIBS
142150
CUDAQuantumPythonCAPI
143151
)

python/runtime/mlir/py_register_dialects.cpp

+20-53
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,15 @@
66
* the terms of the Apache License 2.0 which accompanies this distribution. *
77
******************************************************************************/
88

9-
#include "mlir/Bindings/Python/PybindAdaptors.h"
10-
9+
#include "py_register_dialects.h"
1110
#include "cudaq/Optimizer/Builder/Intrinsics.h"
12-
#include "cudaq/Optimizer/CAPI/Dialects.h"
1311
#include "cudaq/Optimizer/CodeGen/Passes.h"
14-
#include "cudaq/Optimizer/CodeGen/Pipelines.h"
15-
#include "cudaq/Optimizer/Dialect/CC/CCDialect.h"
16-
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
1712
#include "cudaq/Optimizer/Dialect/CC/CCTypes.h"
18-
#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
1913
#include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
14+
#include "cudaq/Optimizer/InitAllDialects.h"
2015
#include "cudaq/Optimizer/Transforms/Passes.h"
21-
#include "mlir/InitAllDialects.h"
16+
#include "mlir/Bindings/Python/PybindAdaptors.h"
17+
#include "mlir/CAPI/IR.h"
2218
#include <fmt/core.h>
2319
#include <pybind11/complex.h>
2420
#include <pybind11/stl.h>
@@ -27,33 +23,9 @@ namespace py = pybind11;
2723
using namespace mlir::python::adaptors;
2824
using namespace mlir;
2925

30-
namespace cudaq {
31-
static bool registered = false;
32-
33-
void registerQuakeDialectAndTypes(py::module &m) {
26+
static void registerQuakeTypes(py::module &m) {
3427
auto quakeMod = m.def_submodule("quake");
3528

36-
quakeMod.def(
37-
"register_dialect",
38-
[](MlirContext context, bool load) {
39-
MlirDialectHandle handle = mlirGetDialectHandle__quake__();
40-
mlirDialectHandleRegisterDialect(handle, context);
41-
if (load) {
42-
mlirDialectHandleLoadDialect(handle, context);
43-
}
44-
45-
if (!registered) {
46-
cudaq::opt::registerOptCodeGenPasses();
47-
cudaq::opt::registerOptTransformsPasses();
48-
cudaq::opt::registerAggressiveEarlyInlining();
49-
cudaq::opt::registerUnrollingPipeline();
50-
cudaq::opt::registerTargetPipelines();
51-
cudaq::opt::registerMappingPipeline();
52-
registered = true;
53-
}
54-
},
55-
py::arg("context") = py::none(), py::arg("load") = true);
56-
5729
mlir_type_subclass(quakeMod, "RefType", [](MlirType type) {
5830
return unwrap(type).isa<quake::RefType>();
5931
}).def_classmethod("get", [](py::object cls, MlirContext ctx) {
@@ -144,21 +116,10 @@ void registerQuakeDialectAndTypes(py::module &m) {
144116
});
145117
}
146118

147-
void registerCCDialectAndTypes(py::module &m) {
119+
static void registerCCTypes(py::module &m) {
148120

149121
auto ccMod = m.def_submodule("cc");
150122

151-
ccMod.def(
152-
"register_dialect",
153-
[](MlirContext context, bool load) {
154-
MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__();
155-
mlirDialectHandleRegisterDialect(ccHandle, context);
156-
if (load) {
157-
mlirDialectHandleLoadDialect(ccHandle, context);
158-
}
159-
},
160-
py::arg("context") = py::none(), py::arg("load") = true);
161-
162123
mlir_type_subclass(ccMod, "CharspanType", [](MlirType type) {
163124
return unwrap(type).isa<cudaq::cc::CharspanType>();
164125
}).def_classmethod("get", [](py::object cls, MlirContext ctx) {
@@ -298,10 +259,7 @@ void registerCCDialectAndTypes(py::module &m) {
298259
});
299260
}
300261

301-
void bindRegisterDialects(py::module &mod) {
302-
registerQuakeDialectAndTypes(mod);
303-
registerCCDialectAndTypes(mod);
304-
262+
void cudaq::bindRegisterDialects(py::module &mod) {
305263
mod.def("load_intrinsic", [](MlirModule module, std::string name) {
306264
auto unwrapped = unwrap(module);
307265
cudaq::IRBuilder builder = IRBuilder::atBlockEnd(unwrapped.getBody());
@@ -311,14 +269,24 @@ void bindRegisterDialects(py::module &mod) {
311269

312270
mod.def("register_all_dialects", [](MlirContext context) {
313271
DialectRegistry registry;
314-
registry.insert<quake::QuakeDialect, cudaq::cc::CCDialect>();
272+
cudaq::registerAllDialects(registry);
315273
cudaq::opt::registerCodeGenDialect(registry);
316-
registerAllDialects(registry);
317-
auto *mlirContext = unwrap(context);
274+
MLIRContext *mlirContext = unwrap(context);
318275
mlirContext->appendDialectRegistry(registry);
319276
mlirContext->loadAllAvailableDialects();
320277
});
321278

279+
// Register type and passes once, when the module is loaded.
280+
registerQuakeTypes(mod);
281+
registerCCTypes(mod);
282+
283+
cudaq::opt::registerOptCodeGenPasses();
284+
cudaq::opt::registerOptTransformsPasses();
285+
cudaq::opt::registerAggressiveEarlyInlining();
286+
cudaq::opt::registerUnrollingPipeline();
287+
cudaq::opt::registerTargetPipelines();
288+
cudaq::opt::registerMappingPipeline();
289+
322290
mod.def("gen_vector_of_complex_constant", [](MlirLocation loc,
323291
MlirModule module,
324292
std::string name,
@@ -330,4 +298,3 @@ void bindRegisterDialects(py::module &mod) {
330298
builder.genVectorOfConstants(unwrap(loc), modOp, name, newValues);
331299
});
332300
}
333-
} // namespace cudaq

python/tests/mlir/bare.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88

99
# RUN: PYTHONPATH=../../ python3 %s | FileCheck %s
1010

11+
from cudaq.mlir import register_all_dialects
1112
from cudaq.mlir.ir import *
1213
from cudaq.mlir.dialects import quake
1314
from cudaq.mlir.dialects import builtin, func, arith
1415

1516
with Context() as ctx:
16-
quake.register_dialect()
17+
register_all_dialects(ctx)
1718
m = Module.create(loc=Location.unknown())
1819
with InsertionPoint(m.body), Location.unknown():
1920
f = func.FuncOp('main', ([], []))

0 commit comments

Comments
 (0)