6
6
* the terms of the Apache License 2.0 which accompanies this distribution. *
7
7
******************************************************************************/
8
8
9
- #include " mlir/Bindings/Python/PybindAdaptors.h"
10
-
11
9
#include " cudaq/Optimizer/Builder/Intrinsics.h"
12
- #include " cudaq/Optimizer/CAPI/Dialects.h"
13
- #include " cudaq/Optimizer/CodeGen/Passes.h"
14
- #include " cudaq/Optimizer/CodeGen/Pipelines.h"
15
- #include " cudaq/Optimizer/Dialect/CC/CCDialect.h"
16
- #include " cudaq/Optimizer/Dialect/CC/CCOps.h"
17
10
#include " cudaq/Optimizer/Dialect/CC/CCTypes.h"
18
- #include " cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
19
11
#include " cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
20
- #include " cudaq/Optimizer/Transforms/Passes.h"
21
- #include " mlir/InitAllDialects.h"
12
+ #include " cudaq/Optimizer/InitAllDialects.h"
13
+ #include " cudaq/Optimizer/InitAllPasses.h"
14
+ #include " mlir/Bindings/Python/PybindAdaptors.h"
15
+ #include " mlir/CAPI/IR.h"
22
16
#include < fmt/core.h>
23
17
#include < pybind11/complex.h>
24
18
#include < pybind11/stl.h>
@@ -28,32 +22,10 @@ using namespace mlir::python::adaptors;
28
22
using namespace mlir ;
29
23
30
24
namespace cudaq {
31
- static bool registered = false ;
32
25
33
- void registerQuakeDialectAndTypes (py::module &m) {
26
+ void registerQuakeTypes (py::module &m) {
34
27
auto quakeMod = m.def_submodule (" quake" );
35
28
36
- quakeMod.def (
37
- " register_dialect" ,
38
- [](MlirContext context, bool load) {
39
- MlirDialectHandle handle = mlirGetDialectHandle__quake__ ();
40
- mlirDialectHandleRegisterDialect (handle, context);
41
- if (load) {
42
- mlirDialectHandleLoadDialect (handle, context);
43
- }
44
-
45
- if (!registered) {
46
- cudaq::opt::registerOptCodeGenPasses ();
47
- cudaq::opt::registerOptTransformsPasses ();
48
- cudaq::opt::registerAggressiveEarlyInlining ();
49
- cudaq::opt::registerUnrollingPipeline ();
50
- cudaq::opt::registerTargetPipelines ();
51
- cudaq::opt::registerMappingPipeline ();
52
- registered = true ;
53
- }
54
- },
55
- py::arg (" context" ) = py::none (), py::arg (" load" ) = true );
56
-
57
29
mlir_type_subclass (quakeMod, " RefType" , [](MlirType type) {
58
30
return unwrap (type).isa <quake::RefType>();
59
31
}).def_classmethod (" get" , [](py::object cls, MlirContext ctx) {
@@ -143,21 +115,10 @@ void registerQuakeDialectAndTypes(py::module &m) {
143
115
});
144
116
}
145
117
146
- void registerCCDialectAndTypes (py::module &m) {
118
+ void registerCCTypes (py::module &m) {
147
119
148
120
auto ccMod = m.def_submodule (" cc" );
149
121
150
- ccMod.def (
151
- " register_dialect" ,
152
- [](MlirContext context, bool load) {
153
- MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__ ();
154
- mlirDialectHandleRegisterDialect (ccHandle, context);
155
- if (load) {
156
- mlirDialectHandleLoadDialect (ccHandle, context);
157
- }
158
- },
159
- py::arg (" context" ) = py::none (), py::arg (" load" ) = true );
160
-
161
122
mlir_type_subclass (ccMod, " CharspanType" , [](MlirType type) {
162
123
return unwrap (type).isa <cudaq::cc::CharspanType>();
163
124
}).def_classmethod (" get" , [](py::object cls, MlirContext ctx) {
@@ -298,9 +259,6 @@ void registerCCDialectAndTypes(py::module &m) {
298
259
}
299
260
300
261
void bindRegisterDialects (py::module &mod) {
301
- registerQuakeDialectAndTypes (mod);
302
- registerCCDialectAndTypes (mod);
303
-
304
262
mod.def (" load_intrinsic" , [](MlirModule module, std::string name) {
305
263
auto unwrapped = unwrap (module);
306
264
cudaq::IRBuilder builder = IRBuilder::atBlockEnd (unwrapped.getBody ());
@@ -310,14 +268,17 @@ void bindRegisterDialects(py::module &mod) {
310
268
311
269
mod.def (" register_all_dialects" , [](MlirContext context) {
312
270
DialectRegistry registry;
313
- registry.insert <quake::QuakeDialect, cudaq::cc::CCDialect>();
314
- cudaq::opt::registerCodeGenDialect (registry);
315
- registerAllDialects (registry);
316
- auto *mlirContext = unwrap (context);
271
+ cudaq::registerAllDialects (registry);
272
+ MLIRContext *mlirContext = unwrap (context);
317
273
mlirContext->appendDialectRegistry (registry);
318
274
mlirContext->loadAllAvailableDialects ();
319
275
});
320
276
277
+ // Register type as passes once, when the module is loaded.
278
+ registerQuakeTypes (mod);
279
+ registerCCTypes (mod);
280
+ cudaq::registerAllPasses ();
281
+
321
282
mod.def (" gen_vector_of_complex_constant" , [](MlirLocation loc,
322
283
MlirModule module,
323
284
std::string name,
0 commit comments