Skip to content

Commit 479db23

Browse files
committed
[python] Remove unused MLIR components
We don't need to take everything from MLIR for our python bindings. This change cherry picks the upstream components our compiler depends on. The commit also cleans up some unnecessary code that ends up registering dialects more than once, and surfaces the `register_all_dialects` function to a less obscure location. Signed-off-by: boschmitt <[email protected]>
1 parent 0a8c67e commit 479db23

File tree

6 files changed

+46
-58
lines changed

6 files changed

+46
-58
lines changed

python/cudaq/kernel/ast_bridge.py

-2
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,6 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs):
123123
else:
124124
self.ctx = Context()
125125
register_all_dialects(self.ctx)
126-
quake.register_dialect(self.ctx)
127-
cc.register_dialect(self.ctx)
128126
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)
129127
self.loc = Location.unknown(context=self.ctx)
130128
self.module = Module.create(loc=self.loc)

python/cudaq/kernel/kernel_builder.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737

3838
# We need static initializers to run in the CAPI `ExecutionEngine`,
3939
# so here we run a simple JIT compile at global scope
40-
with Context():
40+
with Context() as ctx:
41+
register_all_dialects(ctx)
4142
module = Module.parse(r"""
4243
llvm.func @none() {
4344
llvm.return
@@ -246,8 +247,6 @@ class PyKernel(object):
246247
def __init__(self, argTypeList):
247248
self.ctx = Context()
248249
register_all_dialects(self.ctx)
249-
quake.register_dialect(self.ctx)
250-
cc.register_dialect(self.ctx)
251250
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)
252251

253252
self.metadata = {'conditionalOnMeasure': False}

python/cudaq/mlir/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# ============================================================================ #
2+
# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. #
3+
# All rights reserved. #
4+
# #
5+
# This source code and the accompanying materials are made available under #
6+
# the terms of the Apache License 2.0 which accompanies this distribution. #
7+
# ============================================================================ #
8+
9+
from ._mlir_libs._quakeDialects import register_all_dialects

python/extension/CMakeLists.txt

+15-7
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,14 @@ add_mlir_python_common_capi_library(CUDAQuantumPythonCAPI
119119
RELATIVE_INSTALL_ROOT "../.."
120120
DECLARED_SOURCES
121121
CUDAQuantumPythonSources
122-
# TODO: Remove this in favor of showing fine grained registration once
123-
# available.
124-
MLIRPythonExtension.RegisterEverything
125122
MLIRPythonSources.Core
123+
MLIRPythonSources.Dialects.arith
124+
MLIRPythonSources.Dialects.builtin
125+
MLIRPythonSources.Dialects.cf
126+
MLIRPythonSources.Dialects.complex
127+
MLIRPythonSources.Dialects.func
128+
MLIRPythonSources.Dialects.math
129+
MLIRPythonSources.ExecutionEngine
126130
)
127131

128132
################################################################################
@@ -134,10 +138,14 @@ add_mlir_python_modules(CUDAQuantumPythonModules
134138
INSTALL_PREFIX "cudaq/mlir"
135139
DECLARED_SOURCES
136140
CUDAQuantumPythonSources
137-
# TODO: Remove this in favor of showing fine grained registration once
138-
# available.
139-
MLIRPythonExtension.RegisterEverything
140-
MLIRPythonSources
141+
MLIRPythonSources.Core
142+
MLIRPythonSources.Dialects.arith
143+
MLIRPythonSources.Dialects.builtin
144+
MLIRPythonSources.Dialects.cf
145+
MLIRPythonSources.Dialects.complex
146+
MLIRPythonSources.Dialects.func
147+
MLIRPythonSources.Dialects.math
148+
MLIRPythonSources.ExecutionEngine
141149
COMMON_CAPI_LINK_LIBS
142150
CUDAQuantumPythonCAPI
143151
)

python/runtime/mlir/py_register_dialects.cpp

+18-45
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,15 @@
66
* the terms of the Apache License 2.0 which accompanies this distribution. *
77
******************************************************************************/
88

9+
#include "py_register_dialects.h"
910
#include "cudaq/Optimizer/Builder/Intrinsics.h"
10-
#include "cudaq/Optimizer/CAPI/Dialects.h"
1111
#include "cudaq/Optimizer/CodeGen/Passes.h"
12-
#include "cudaq/Optimizer/CodeGen/Pipelines.h"
13-
#include "cudaq/Optimizer/Dialect/CC/CCDialect.h"
14-
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
1512
#include "cudaq/Optimizer/Dialect/CC/CCTypes.h"
16-
#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
1713
#include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
14+
#include "cudaq/Optimizer/InitAllDialects.h"
1815
#include "cudaq/Optimizer/InitAllPasses.h"
19-
#include "cudaq/Optimizer/Transforms/Passes.h"
2016
#include "mlir/Bindings/Python/PybindAdaptors.h"
21-
#include "mlir/InitAllDialects.h"
17+
#include "mlir/CAPI/IR.h"
2218
#include <fmt/core.h>
2319
#include <pybind11/complex.h>
2420
#include <pybind11/stl.h>
@@ -27,27 +23,9 @@ namespace py = pybind11;
2723
using namespace mlir::python::adaptors;
2824
using namespace mlir;
2925

30-
namespace cudaq {
31-
static bool registered = false;
32-
33-
void registerQuakeDialectAndTypes(py::module &m) {
26+
static void registerQuakeTypes(py::module &m) {
3427
auto quakeMod = m.def_submodule("quake");
3528

36-
quakeMod.def(
37-
"register_dialect",
38-
[](MlirContext context, bool load) {
39-
MlirDialectHandle handle = mlirGetDialectHandle__quake__();
40-
mlirDialectHandleRegisterDialect(handle, context);
41-
if (load)
42-
mlirDialectHandleLoadDialect(handle, context);
43-
44-
if (!registered) {
45-
cudaq::registerCudaqPassesAndPipelines();
46-
registered = true;
47-
}
48-
},
49-
py::arg("context") = py::none(), py::arg("load") = true);
50-
5129
mlir_type_subclass(quakeMod, "RefType", [](MlirType type) {
5230
return unwrap(type).isa<quake::RefType>();
5331
}).def_classmethod("get", [](py::object cls, MlirContext ctx) {
@@ -138,21 +116,10 @@ void registerQuakeDialectAndTypes(py::module &m) {
138116
});
139117
}
140118

141-
void registerCCDialectAndTypes(py::module &m) {
119+
static void registerCCTypes(py::module &m) {
142120

143121
auto ccMod = m.def_submodule("cc");
144122

145-
ccMod.def(
146-
"register_dialect",
147-
[](MlirContext context, bool load) {
148-
MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__();
149-
mlirDialectHandleRegisterDialect(ccHandle, context);
150-
if (load) {
151-
mlirDialectHandleLoadDialect(ccHandle, context);
152-
}
153-
},
154-
py::arg("context") = py::none(), py::arg("load") = true);
155-
156123
mlir_type_subclass(ccMod, "CharspanType", [](MlirType type) {
157124
return unwrap(type).isa<cudaq::cc::CharspanType>();
158125
}).def_classmethod("get", [](py::object cls, MlirContext ctx) {
@@ -292,10 +259,9 @@ void registerCCDialectAndTypes(py::module &m) {
292259
});
293260
}
294261

295-
void bindRegisterDialects(py::module &mod) {
296-
registerQuakeDialectAndTypes(mod);
297-
registerCCDialectAndTypes(mod);
262+
static bool registered = false;
298263

264+
void cudaq::bindRegisterDialects(py::module &mod) {
299265
mod.def("load_intrinsic", [](MlirModule module, std::string name) {
300266
auto unwrapped = unwrap(module);
301267
cudaq::IRBuilder builder = IRBuilder::atBlockEnd(unwrapped.getBody());
@@ -305,14 +271,22 @@ void bindRegisterDialects(py::module &mod) {
305271

306272
mod.def("register_all_dialects", [](MlirContext context) {
307273
DialectRegistry registry;
308-
registry.insert<quake::QuakeDialect, cudaq::cc::CCDialect>();
274+
cudaq::registerAllDialects(registry);
309275
cudaq::opt::registerCodeGenDialect(registry);
310-
registerAllDialects(registry);
311-
auto *mlirContext = unwrap(context);
276+
MLIRContext *mlirContext = unwrap(context);
312277
mlirContext->appendDialectRegistry(registry);
313278
mlirContext->loadAllAvailableDialects();
314279
});
315280

281+
// Register type and passes once, when the module is loaded.
282+
registerQuakeTypes(mod);
283+
registerCCTypes(mod);
284+
285+
if (!registered) {
286+
cudaq::registerAllPasses();
287+
registered = true;
288+
}
289+
316290
mod.def("gen_vector_of_complex_constant", [](MlirLocation loc,
317291
MlirModule module,
318292
std::string name,
@@ -324,4 +298,3 @@ void bindRegisterDialects(py::module &mod) {
324298
builder.genVectorOfConstants(unwrap(loc), modOp, name, newValues);
325299
});
326300
}
327-
} // namespace cudaq

python/tests/mlir/bare.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88

99
# RUN: PYTHONPATH=../../ python3 %s | FileCheck %s
1010

11+
from cudaq.mlir import register_all_dialects
1112
from cudaq.mlir.ir import *
1213
from cudaq.mlir.dialects import quake
1314
from cudaq.mlir.dialects import builtin, func, arith
1415

1516
with Context() as ctx:
16-
quake.register_dialect()
17+
register_all_dialects(ctx)
1718
m = Module.create(loc=Location.unknown())
1819
with InsertionPoint(m.body), Location.unknown():
1920
f = func.FuncOp('main', ([], []))

0 commit comments

Comments
 (0)