6
6
* the terms of the Apache License 2.0 which accompanies this distribution. *
7
7
******************************************************************************/
8
8
9
- #include " mlir/Bindings/Python/PybindAdaptors.h"
10
-
9
+ #include " py_register_dialects.h"
11
10
#include " cudaq/Optimizer/Builder/Intrinsics.h"
12
- #include " cudaq/Optimizer/CAPI/Dialects.h"
13
11
#include " cudaq/Optimizer/CodeGen/Passes.h"
14
- #include " cudaq/Optimizer/CodeGen/Pipelines.h"
15
- #include " cudaq/Optimizer/Dialect/CC/CCDialect.h"
16
- #include " cudaq/Optimizer/Dialect/CC/CCOps.h"
17
12
#include " cudaq/Optimizer/Dialect/CC/CCTypes.h"
18
- #include " cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
19
13
#include " cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
14
+ #include " cudaq/Optimizer/InitAllDialects.h"
20
15
#include " cudaq/Optimizer/Transforms/Passes.h"
21
- #include " mlir/InitAllDialects.h"
16
+ #include " mlir/Bindings/Python/PybindAdaptors.h"
17
+ #include " mlir/CAPI/IR.h"
18
+ #include " mlir/Transforms/Passes.h"
22
19
#include < fmt/core.h>
23
20
#include < pybind11/complex.h>
24
21
#include < pybind11/stl.h>
@@ -27,33 +24,9 @@ namespace py = pybind11;
27
24
using namespace mlir ::python::adaptors;
28
25
using namespace mlir ;
29
26
30
- namespace cudaq {
31
- static bool registered = false ;
32
-
33
- void registerQuakeDialectAndTypes (py::module &m) {
27
+ static void registerQuakeTypes (py::module &m) {
34
28
auto quakeMod = m.def_submodule (" quake" );
35
29
36
- quakeMod.def (
37
- " register_dialect" ,
38
- [](MlirContext context, bool load) {
39
- MlirDialectHandle handle = mlirGetDialectHandle__quake__ ();
40
- mlirDialectHandleRegisterDialect (handle, context);
41
- if (load) {
42
- mlirDialectHandleLoadDialect (handle, context);
43
- }
44
-
45
- if (!registered) {
46
- cudaq::opt::registerOptCodeGenPasses ();
47
- cudaq::opt::registerOptTransformsPasses ();
48
- cudaq::opt::registerAggressiveEarlyInlining ();
49
- cudaq::opt::registerUnrollingPipeline ();
50
- cudaq::opt::registerTargetPipelines ();
51
- cudaq::opt::registerMappingPipeline ();
52
- registered = true ;
53
- }
54
- },
55
- py::arg (" context" ) = py::none (), py::arg (" load" ) = true );
56
-
57
30
mlir_type_subclass (quakeMod, " RefType" , [](MlirType type) {
58
31
return unwrap (type).isa <quake::RefType>();
59
32
}).def_classmethod (" get" , [](py::object cls, MlirContext ctx) {
@@ -144,21 +117,10 @@ void registerQuakeDialectAndTypes(py::module &m) {
144
117
});
145
118
}
146
119
147
- void registerCCDialectAndTypes (py::module &m) {
120
+ static void registerCCTypes (py::module &m) {
148
121
149
122
auto ccMod = m.def_submodule (" cc" );
150
123
151
- ccMod.def (
152
- " register_dialect" ,
153
- [](MlirContext context, bool load) {
154
- MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__ ();
155
- mlirDialectHandleRegisterDialect (ccHandle, context);
156
- if (load) {
157
- mlirDialectHandleLoadDialect (ccHandle, context);
158
- }
159
- },
160
- py::arg (" context" ) = py::none (), py::arg (" load" ) = true );
161
-
162
124
mlir_type_subclass (ccMod, " CharspanType" , [](MlirType type) {
163
125
return unwrap (type).isa <cudaq::cc::CharspanType>();
164
126
}).def_classmethod (" get" , [](py::object cls, MlirContext ctx) {
@@ -298,10 +260,9 @@ void registerCCDialectAndTypes(py::module &m) {
298
260
});
299
261
}
300
262
301
- void bindRegisterDialects (py::module &mod) {
302
- registerQuakeDialectAndTypes (mod);
303
- registerCCDialectAndTypes (mod);
263
+ static bool registered = false ;
304
264
265
+ void cudaq::bindRegisterDialects (py::module &mod) {
305
266
mod.def (" load_intrinsic" , [](MlirModule module, std::string name) {
306
267
auto unwrapped = unwrap (module);
307
268
cudaq::IRBuilder builder = IRBuilder::atBlockEnd (unwrapped.getBody ());
@@ -311,14 +272,28 @@ void bindRegisterDialects(py::module &mod) {
311
272
312
273
mod.def (" register_all_dialects" , [](MlirContext context) {
313
274
DialectRegistry registry;
314
- registry. insert <quake::QuakeDialect, cudaq::cc::CCDialect>( );
275
+ cudaq::registerAllDialects (registry );
315
276
cudaq::opt::registerCodeGenDialect (registry);
316
- registerAllDialects (registry);
317
- auto *mlirContext = unwrap (context);
277
+ MLIRContext *mlirContext = unwrap (context);
318
278
mlirContext->appendDialectRegistry (registry);
319
279
mlirContext->loadAllAvailableDialects ();
320
280
});
321
281
282
+ // Register type and passes once, when the module is loaded.
283
+ registerQuakeTypes (mod);
284
+ registerCCTypes (mod);
285
+
286
+ if (!registered) {
287
+ mlir::registerTransformsPasses ();
288
+ cudaq::opt::registerOptCodeGenPasses ();
289
+ cudaq::opt::registerOptTransformsPasses ();
290
+ cudaq::opt::registerAggressiveEarlyInlining ();
291
+ cudaq::opt::registerUnrollingPipeline ();
292
+ cudaq::opt::registerTargetPipelines ();
293
+ cudaq::opt::registerMappingPipeline ();
294
+ registered = true ;
295
+ }
296
+
322
297
mod.def (" gen_vector_of_complex_constant" , [](MlirLocation loc,
323
298
MlirModule module,
324
299
std::string name,
@@ -330,4 +305,3 @@ void bindRegisterDialects(py::module &mod) {
330
305
builder.genVectorOfConstants (unwrap (loc), modOp, name, newValues);
331
306
});
332
307
}
333
- } // namespace cudaq
0 commit comments