6
6
* the terms of the Apache License 2.0 which accompanies this distribution. *
7
7
******************************************************************************/
8
8
9
- #include " mlir/Bindings/Python/PybindAdaptors.h"
10
-
9
+ #include " py_register_dialects.h"
11
10
#include " cudaq/Optimizer/Builder/Intrinsics.h"
12
- #include " cudaq/Optimizer/CAPI/Dialects.h"
13
11
#include " cudaq/Optimizer/CodeGen/Passes.h"
14
- #include " cudaq/Optimizer/CodeGen/Pipelines.h"
15
- #include " cudaq/Optimizer/Dialect/CC/CCDialect.h"
16
- #include " cudaq/Optimizer/Dialect/CC/CCOps.h"
17
12
#include " cudaq/Optimizer/Dialect/CC/CCTypes.h"
18
- #include " cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
19
13
#include " cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
14
+ #include " cudaq/Optimizer/InitAllDialects.h"
20
15
#include " cudaq/Optimizer/Transforms/Passes.h"
21
- #include " mlir/InitAllDialects.h"
16
+ #include " mlir/Bindings/Python/PybindAdaptors.h"
17
+ #include " mlir/CAPI/IR.h"
22
18
#include < fmt/core.h>
23
19
#include < pybind11/complex.h>
24
20
#include < pybind11/stl.h>
@@ -27,33 +23,9 @@ namespace py = pybind11;
27
23
using namespace mlir ::python::adaptors;
28
24
using namespace mlir ;
29
25
30
- namespace cudaq {
31
- static bool registered = false ;
32
-
33
- void registerQuakeDialectAndTypes (py::module &m) {
26
+ static void registerQuakeTypes (py::module &m) {
34
27
auto quakeMod = m.def_submodule (" quake" );
35
28
36
- quakeMod.def (
37
- " register_dialect" ,
38
- [](MlirContext context, bool load) {
39
- MlirDialectHandle handle = mlirGetDialectHandle__quake__ ();
40
- mlirDialectHandleRegisterDialect (handle, context);
41
- if (load) {
42
- mlirDialectHandleLoadDialect (handle, context);
43
- }
44
-
45
- if (!registered) {
46
- cudaq::opt::registerOptCodeGenPasses ();
47
- cudaq::opt::registerOptTransformsPasses ();
48
- cudaq::opt::registerAggressiveEarlyInlining ();
49
- cudaq::opt::registerUnrollingPipeline ();
50
- cudaq::opt::registerTargetPipelines ();
51
- cudaq::opt::registerMappingPipeline ();
52
- registered = true ;
53
- }
54
- },
55
- py::arg (" context" ) = py::none (), py::arg (" load" ) = true );
56
-
57
29
mlir_type_subclass (quakeMod, " RefType" , [](MlirType type) {
58
30
return unwrap (type).isa <quake::RefType>();
59
31
}).def_classmethod (" get" , [](py::object cls, MlirContext ctx) {
@@ -144,21 +116,10 @@ void registerQuakeDialectAndTypes(py::module &m) {
144
116
});
145
117
}
146
118
147
- void registerCCDialectAndTypes (py::module &m) {
119
+ static void registerCCTypes (py::module &m) {
148
120
149
121
auto ccMod = m.def_submodule (" cc" );
150
122
151
- ccMod.def (
152
- " register_dialect" ,
153
- [](MlirContext context, bool load) {
154
- MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__ ();
155
- mlirDialectHandleRegisterDialect (ccHandle, context);
156
- if (load) {
157
- mlirDialectHandleLoadDialect (ccHandle, context);
158
- }
159
- },
160
- py::arg (" context" ) = py::none (), py::arg (" load" ) = true );
161
-
162
123
mlir_type_subclass (ccMod, " CharspanType" , [](MlirType type) {
163
124
return unwrap (type).isa <cudaq::cc::CharspanType>();
164
125
}).def_classmethod (" get" , [](py::object cls, MlirContext ctx) {
@@ -298,10 +259,7 @@ void registerCCDialectAndTypes(py::module &m) {
298
259
});
299
260
}
300
261
301
- void bindRegisterDialects (py::module &mod) {
302
- registerQuakeDialectAndTypes (mod);
303
- registerCCDialectAndTypes (mod);
304
-
262
+ void cudaq::bindRegisterDialects (py::module &mod) {
305
263
mod.def (" load_intrinsic" , [](MlirModule module, std::string name) {
306
264
auto unwrapped = unwrap (module);
307
265
cudaq::IRBuilder builder = IRBuilder::atBlockEnd (unwrapped.getBody ());
@@ -311,14 +269,24 @@ void bindRegisterDialects(py::module &mod) {
311
269
312
270
mod.def (" register_all_dialects" , [](MlirContext context) {
313
271
DialectRegistry registry;
314
- registry. insert <quake::QuakeDialect, cudaq::cc::CCDialect>( );
272
+ cudaq::registerAllDialects (registry );
315
273
cudaq::opt::registerCodeGenDialect (registry);
316
- registerAllDialects (registry);
317
- auto *mlirContext = unwrap (context);
274
+ MLIRContext *mlirContext = unwrap (context);
318
275
mlirContext->appendDialectRegistry (registry);
319
276
mlirContext->loadAllAvailableDialects ();
320
277
});
321
278
279
+ // Register type and passes once, when the module is loaded.
280
+ registerQuakeTypes (mod);
281
+ registerCCTypes (mod);
282
+
283
+ cudaq::opt::registerOptCodeGenPasses ();
284
+ cudaq::opt::registerOptTransformsPasses ();
285
+ cudaq::opt::registerAggressiveEarlyInlining ();
286
+ cudaq::opt::registerUnrollingPipeline ();
287
+ cudaq::opt::registerTargetPipelines ();
288
+ cudaq::opt::registerMappingPipeline ();
289
+
322
290
mod.def (" gen_vector_of_complex_constant" , [](MlirLocation loc,
323
291
MlirModule module,
324
292
std::string name,
@@ -330,4 +298,3 @@ void bindRegisterDialects(py::module &mod) {
330
298
builder.genVectorOfConstants (unwrap (loc), modOp, name, newValues);
331
299
});
332
300
}
333
- } // namespace cudaq
0 commit comments