6
6
* the terms of the Apache License 2.0 which accompanies this distribution. *
7
7
******************************************************************************/
8
8
9
+ #include " py_register_dialects.h"
9
10
#include " cudaq/Optimizer/Builder/Intrinsics.h"
10
- #include " cudaq/Optimizer/CAPI/Dialects.h"
11
11
#include " cudaq/Optimizer/CodeGen/Passes.h"
12
- #include " cudaq/Optimizer/CodeGen/Pipelines.h"
13
- #include " cudaq/Optimizer/Dialect/CC/CCDialect.h"
14
- #include " cudaq/Optimizer/Dialect/CC/CCOps.h"
15
12
#include " cudaq/Optimizer/Dialect/CC/CCTypes.h"
16
- #include " cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
17
13
#include " cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
14
+ #include " cudaq/Optimizer/InitAllDialects.h"
18
15
#include " cudaq/Optimizer/InitAllPasses.h"
19
- #include " cudaq/Optimizer/Transforms/Passes.h"
20
16
#include " mlir/Bindings/Python/PybindAdaptors.h"
21
- #include " mlir/InitAllDialects .h"
17
+ #include " mlir/CAPI/IR .h"
22
18
#include < fmt/core.h>
23
19
#include < pybind11/complex.h>
24
20
#include < pybind11/stl.h>
@@ -27,27 +23,9 @@ namespace py = pybind11;
27
23
using namespace mlir ::python::adaptors;
28
24
using namespace mlir ;
29
25
30
- namespace cudaq {
31
- static bool registered = false ;
32
-
33
- void registerQuakeDialectAndTypes (py::module &m) {
26
+ static void registerQuakeTypes (py::module &m) {
34
27
auto quakeMod = m.def_submodule (" quake" );
35
28
36
- quakeMod.def (
37
- " register_dialect" ,
38
- [](MlirContext context, bool load) {
39
- MlirDialectHandle handle = mlirGetDialectHandle__quake__ ();
40
- mlirDialectHandleRegisterDialect (handle, context);
41
- if (load)
42
- mlirDialectHandleLoadDialect (handle, context);
43
-
44
- if (!registered) {
45
- cudaq::registerCudaqPassesAndPipelines ();
46
- registered = true ;
47
- }
48
- },
49
- py::arg (" context" ) = py::none (), py::arg (" load" ) = true );
50
-
51
29
mlir_type_subclass (quakeMod, " RefType" , [](MlirType type) {
52
30
return unwrap (type).isa <quake::RefType>();
53
31
}).def_classmethod (" get" , [](py::object cls, MlirContext ctx) {
@@ -138,21 +116,10 @@ void registerQuakeDialectAndTypes(py::module &m) {
138
116
});
139
117
}
140
118
141
- void registerCCDialectAndTypes (py::module &m) {
119
+ static void registerCCTypes (py::module &m) {
142
120
143
121
auto ccMod = m.def_submodule (" cc" );
144
122
145
- ccMod.def (
146
- " register_dialect" ,
147
- [](MlirContext context, bool load) {
148
- MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__ ();
149
- mlirDialectHandleRegisterDialect (ccHandle, context);
150
- if (load) {
151
- mlirDialectHandleLoadDialect (ccHandle, context);
152
- }
153
- },
154
- py::arg (" context" ) = py::none (), py::arg (" load" ) = true );
155
-
156
123
mlir_type_subclass (ccMod, " CharspanType" , [](MlirType type) {
157
124
return unwrap (type).isa <cudaq::cc::CharspanType>();
158
125
}).def_classmethod (" get" , [](py::object cls, MlirContext ctx) {
@@ -292,10 +259,9 @@ void registerCCDialectAndTypes(py::module &m) {
292
259
});
293
260
}
294
261
295
- void bindRegisterDialects (py::module &mod) {
296
- registerQuakeDialectAndTypes (mod);
297
- registerCCDialectAndTypes (mod);
262
+ static bool registered = false ;
298
263
264
+ void cudaq::bindRegisterDialects (py::module &mod) {
299
265
mod.def (" load_intrinsic" , [](MlirModule module, std::string name) {
300
266
auto unwrapped = unwrap (module);
301
267
cudaq::IRBuilder builder = IRBuilder::atBlockEnd (unwrapped.getBody ());
@@ -305,14 +271,22 @@ void bindRegisterDialects(py::module &mod) {
305
271
306
272
mod.def (" register_all_dialects" , [](MlirContext context) {
307
273
DialectRegistry registry;
308
- registry. insert <quake::QuakeDialect, cudaq::cc::CCDialect>( );
274
+ cudaq::registerAllDialects (registry );
309
275
cudaq::opt::registerCodeGenDialect (registry);
310
- registerAllDialects (registry);
311
- auto *mlirContext = unwrap (context);
276
+ MLIRContext *mlirContext = unwrap (context);
312
277
mlirContext->appendDialectRegistry (registry);
313
278
mlirContext->loadAllAvailableDialects ();
314
279
});
315
280
281
+ // Register type and passes once, when the module is loaded.
282
+ registerQuakeTypes (mod);
283
+ registerCCTypes (mod);
284
+
285
+ if (!registered) {
286
+ cudaq::registerAllPasses ();
287
+ registered = true ;
288
+ }
289
+
316
290
mod.def (" gen_vector_of_complex_constant" , [](MlirLocation loc,
317
291
MlirModule module,
318
292
std::string name,
@@ -324,4 +298,3 @@ void bindRegisterDialects(py::module &mod) {
324
298
builder.genVectorOfConstants (unwrap (loc), modOp, name, newValues);
325
299
});
326
300
}
327
- } // namespace cudaq
0 commit comments