Skip to content

Commit 5b4059b

Browse files
committed
[python] Remove unused MLIR components
We don't need to take everything from MLIR for our python bindings. This change cherry picks the upstream components our compiler depends on. The commit also cleans up some unnecessary code that ends up registering dialects more than once, and surfaces the `register_all_dialects` function to a less obscure location. Signed-off-by: boschmitt <[email protected]>
1 parent c9d0e4c commit 5b4059b

File tree

6 files changed

+41
-65
lines changed

6 files changed

+41
-65
lines changed

python/cudaq/kernel/ast_bridge.py

-2
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,6 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs):
123123
else:
124124
self.ctx = Context()
125125
register_all_dialects(self.ctx)
126-
quake.register_dialect(self.ctx)
127-
cc.register_dialect(self.ctx)
128126
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)
129127
self.loc = Location.unknown(context=self.ctx)
130128
self.module = Module.create(loc=self.loc)

python/cudaq/kernel/kernel_builder.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737

3838
# We need static initializers to run in the CAPI `ExecutionEngine`,
3939
# so here we run a simple JIT compile at global scope
40-
with Context():
40+
with Context() as ctx:
41+
register_all_dialects(ctx)
4142
module = Module.parse(r"""
4243
llvm.func @none() {
4344
llvm.return
@@ -246,8 +247,6 @@ class PyKernel(object):
246247
def __init__(self, argTypeList):
247248
self.ctx = Context()
248249
register_all_dialects(self.ctx)
249-
quake.register_dialect(self.ctx)
250-
cc.register_dialect(self.ctx)
251250
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)
252251

253252
self.metadata = {'conditionalOnMeasure': False}

python/cudaq/mlir/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# ============================================================================ #
2+
# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. #
3+
# All rights reserved. #
4+
# #
5+
# This source code and the accompanying materials are made available under #
6+
# the terms of the Apache License 2.0 which accompanies this distribution. #
7+
# ============================================================================ #
8+
9+
from ._mlir_libs._quakeDialects import register_all_dialects

python/extension/CMakeLists.txt

+15-7
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,14 @@ add_mlir_python_common_capi_library(CUDAQuantumPythonCAPI
119119
RELATIVE_INSTALL_ROOT "../.."
120120
DECLARED_SOURCES
121121
CUDAQuantumPythonSources
122-
# TODO: Remove this in favor of showing fine grained registration once
123-
# available.
124-
MLIRPythonExtension.RegisterEverything
125122
MLIRPythonSources.Core
123+
MLIRPythonSources.Dialects.arith
124+
MLIRPythonSources.Dialects.builtin
125+
MLIRPythonSources.Dialects.cf
126+
MLIRPythonSources.Dialects.complex
127+
MLIRPythonSources.Dialects.func
128+
MLIRPythonSources.Dialects.math
129+
MLIRPythonSources.ExecutionEngine
126130
)
127131

128132
################################################################################
@@ -134,10 +138,14 @@ add_mlir_python_modules(CUDAQuantumPythonModules
134138
INSTALL_PREFIX "cudaq/mlir"
135139
DECLARED_SOURCES
136140
CUDAQuantumPythonSources
137-
# TODO: Remove this in favor of showing fine grained registration once
138-
# available.
139-
MLIRPythonExtension.RegisterEverything
140-
MLIRPythonSources
141+
MLIRPythonSources.Core
142+
MLIRPythonSources.Dialects.arith
143+
MLIRPythonSources.Dialects.builtin
144+
MLIRPythonSources.Dialects.cf
145+
MLIRPythonSources.Dialects.complex
146+
MLIRPythonSources.Dialects.func
147+
MLIRPythonSources.Dialects.math
148+
MLIRPythonSources.ExecutionEngine
141149
COMMON_CAPI_LINK_LIBS
142150
CUDAQuantumPythonCAPI
143151
)

python/runtime/mlir/py_register_dialects.cpp

+13-52
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,13 @@
66
* the terms of the Apache License 2.0 which accompanies this distribution. *
77
******************************************************************************/
88

9-
#include "mlir/Bindings/Python/PybindAdaptors.h"
10-
119
#include "cudaq/Optimizer/Builder/Intrinsics.h"
12-
#include "cudaq/Optimizer/CAPI/Dialects.h"
13-
#include "cudaq/Optimizer/CodeGen/Passes.h"
14-
#include "cudaq/Optimizer/CodeGen/Pipelines.h"
15-
#include "cudaq/Optimizer/Dialect/CC/CCDialect.h"
16-
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
1710
#include "cudaq/Optimizer/Dialect/CC/CCTypes.h"
18-
#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
1911
#include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
20-
#include "cudaq/Optimizer/Transforms/Passes.h"
21-
#include "mlir/InitAllDialects.h"
12+
#include "cudaq/Optimizer/InitAllDialects.h"
13+
#include "cudaq/Optimizer/InitAllPasses.h"
14+
#include "mlir/Bindings/Python/PybindAdaptors.h"
15+
#include "mlir/CAPI/IR.h"
2216
#include <fmt/core.h>
2317
#include <pybind11/complex.h>
2418
#include <pybind11/stl.h>
@@ -28,32 +22,10 @@ using namespace mlir::python::adaptors;
2822
using namespace mlir;
2923

3024
namespace cudaq {
31-
static bool registered = false;
3225

33-
void registerQuakeDialectAndTypes(py::module &m) {
26+
void registerQuakeTypes(py::module &m) {
3427
auto quakeMod = m.def_submodule("quake");
3528

36-
quakeMod.def(
37-
"register_dialect",
38-
[](MlirContext context, bool load) {
39-
MlirDialectHandle handle = mlirGetDialectHandle__quake__();
40-
mlirDialectHandleRegisterDialect(handle, context);
41-
if (load) {
42-
mlirDialectHandleLoadDialect(handle, context);
43-
}
44-
45-
if (!registered) {
46-
cudaq::opt::registerOptCodeGenPasses();
47-
cudaq::opt::registerOptTransformsPasses();
48-
cudaq::opt::registerAggressiveEarlyInlining();
49-
cudaq::opt::registerUnrollingPipeline();
50-
cudaq::opt::registerTargetPipelines();
51-
cudaq::opt::registerMappingPipeline();
52-
registered = true;
53-
}
54-
},
55-
py::arg("context") = py::none(), py::arg("load") = true);
56-
5729
mlir_type_subclass(quakeMod, "RefType", [](MlirType type) {
5830
return unwrap(type).isa<quake::RefType>();
5931
}).def_classmethod("get", [](py::object cls, MlirContext ctx) {
@@ -143,21 +115,10 @@ void registerQuakeDialectAndTypes(py::module &m) {
143115
});
144116
}
145117

146-
void registerCCDialectAndTypes(py::module &m) {
118+
void registerCCTypes(py::module &m) {
147119

148120
auto ccMod = m.def_submodule("cc");
149121

150-
ccMod.def(
151-
"register_dialect",
152-
[](MlirContext context, bool load) {
153-
MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__();
154-
mlirDialectHandleRegisterDialect(ccHandle, context);
155-
if (load) {
156-
mlirDialectHandleLoadDialect(ccHandle, context);
157-
}
158-
},
159-
py::arg("context") = py::none(), py::arg("load") = true);
160-
161122
mlir_type_subclass(ccMod, "CharspanType", [](MlirType type) {
162123
return unwrap(type).isa<cudaq::cc::CharspanType>();
163124
}).def_classmethod("get", [](py::object cls, MlirContext ctx) {
@@ -298,9 +259,6 @@ void registerCCDialectAndTypes(py::module &m) {
298259
}
299260

300261
void bindRegisterDialects(py::module &mod) {
301-
registerQuakeDialectAndTypes(mod);
302-
registerCCDialectAndTypes(mod);
303-
304262
mod.def("load_intrinsic", [](MlirModule module, std::string name) {
305263
auto unwrapped = unwrap(module);
306264
cudaq::IRBuilder builder = IRBuilder::atBlockEnd(unwrapped.getBody());
@@ -310,14 +268,17 @@ void bindRegisterDialects(py::module &mod) {
310268

311269
mod.def("register_all_dialects", [](MlirContext context) {
312270
DialectRegistry registry;
313-
registry.insert<quake::QuakeDialect, cudaq::cc::CCDialect>();
314-
cudaq::opt::registerCodeGenDialect(registry);
315-
registerAllDialects(registry);
316-
auto *mlirContext = unwrap(context);
271+
cudaq::registerAllDialects(registry);
272+
MLIRContext *mlirContext = unwrap(context);
317273
mlirContext->appendDialectRegistry(registry);
318274
mlirContext->loadAllAvailableDialects();
319275
});
320276

277+
// Register type as passes once, when the module is loaded.
278+
registerQuakeTypes(mod);
279+
registerCCTypes(mod);
280+
cudaq::registerAllPasses();
281+
321282
mod.def("gen_vector_of_complex_constant", [](MlirLocation loc,
322283
MlirModule module,
323284
std::string name,

python/tests/mlir/bare.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88

99
# RUN: PYTHONPATH=../../ python3 %s | FileCheck %s
1010

11+
from cudaq.mlir import register_all_dialects
1112
from cudaq.mlir.ir import *
1213
from cudaq.mlir.dialects import quake
1314
from cudaq.mlir.dialects import builtin, func, arith
1415

1516
with Context() as ctx:
16-
quake.register_dialect()
17+
register_all_dialects(ctx)
1718
m = Module.create(loc=Location.unknown())
1819
with InsertionPoint(m.body), Location.unknown():
1920
f = func.FuncOp('main', ([], []))

0 commit comments

Comments
 (0)