6
6
* the terms of the Apache License 2.0 which accompanies this distribution. *
7
7
******************************************************************************/
8
8
9
- #include " mlir/Bindings/Python/PybindAdaptors.h"
10
-
9
+ #include " py_register_dialects.h"
11
10
#include " cudaq/Optimizer/Builder/Intrinsics.h"
12
- #include " cudaq/Optimizer/CAPI/Dialects.h"
13
11
#include " cudaq/Optimizer/CodeGen/Passes.h"
14
- #include " cudaq/Optimizer/CodeGen/Pipelines.h"
15
- #include " cudaq/Optimizer/Dialect/CC/CCDialect.h"
16
- #include " cudaq/Optimizer/Dialect/CC/CCOps.h"
17
12
#include " cudaq/Optimizer/Dialect/CC/CCTypes.h"
18
- #include " cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
19
13
#include " cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
14
+ #include " cudaq/Optimizer/InitAllDialects.h"
20
15
#include " cudaq/Optimizer/Transforms/Passes.h"
21
- #include " mlir/InitAllDialects.h"
16
+ #include " mlir/Bindings/Python/PybindAdaptors.h"
17
+ #include " mlir/CAPI/IR.h"
18
+ #include " mlir/Transforms/Passes.h"
22
19
#include < fmt/core.h>
23
20
#include < pybind11/complex.h>
24
21
#include < pybind11/stl.h>
@@ -27,33 +24,9 @@ namespace py = pybind11;
27
24
using namespace mlir ::python::adaptors;
28
25
using namespace mlir ;
29
26
30
- namespace cudaq {
31
- static bool registered = false ;
32
-
33
- void registerQuakeDialectAndTypes (py::module &m) {
27
+ static void registerQuakeTypes (py::module &m) {
34
28
auto quakeMod = m.def_submodule (" quake" );
35
29
36
- quakeMod.def (
37
- " register_dialect" ,
38
- [](MlirContext context, bool load) {
39
- MlirDialectHandle handle = mlirGetDialectHandle__quake__ ();
40
- mlirDialectHandleRegisterDialect (handle, context);
41
- if (load) {
42
- mlirDialectHandleLoadDialect (handle, context);
43
- }
44
-
45
- if (!registered) {
46
- cudaq::opt::registerOptCodeGenPasses ();
47
- cudaq::opt::registerOptTransformsPasses ();
48
- cudaq::opt::registerAggressiveEarlyInlining ();
49
- cudaq::opt::registerUnrollingPipeline ();
50
- cudaq::opt::registerTargetPipelines ();
51
- cudaq::opt::registerMappingPipeline ();
52
- registered = true ;
53
- }
54
- },
55
- py::arg (" context" ) = py::none (), py::arg (" load" ) = true );
56
-
57
30
mlir_type_subclass (quakeMod, " RefType" , [](MlirType type) {
58
31
return unwrap (type).isa <quake::RefType>();
59
32
}).def_classmethod (" get" , [](py::object cls, MlirContext ctx) {
@@ -144,21 +117,10 @@ void registerQuakeDialectAndTypes(py::module &m) {
144
117
});
145
118
}
146
119
147
- void registerCCDialectAndTypes (py::module &m) {
120
+ static void registerCCTypes (py::module &m) {
148
121
149
122
auto ccMod = m.def_submodule (" cc" );
150
123
151
- ccMod.def (
152
- " register_dialect" ,
153
- [](MlirContext context, bool load) {
154
- MlirDialectHandle ccHandle = mlirGetDialectHandle__cc__ ();
155
- mlirDialectHandleRegisterDialect (ccHandle, context);
156
- if (load) {
157
- mlirDialectHandleLoadDialect (ccHandle, context);
158
- }
159
- },
160
- py::arg (" context" ) = py::none (), py::arg (" load" ) = true );
161
-
162
124
mlir_type_subclass (ccMod, " CharspanType" , [](MlirType type) {
163
125
return unwrap (type).isa <cudaq::cc::CharspanType>();
164
126
}).def_classmethod (" get" , [](py::object cls, MlirContext ctx) {
@@ -298,10 +260,7 @@ void registerCCDialectAndTypes(py::module &m) {
298
260
});
299
261
}
300
262
301
- void bindRegisterDialects (py::module &mod) {
302
- registerQuakeDialectAndTypes (mod);
303
- registerCCDialectAndTypes (mod);
304
-
263
+ void cudaq::bindRegisterDialects (py::module &mod) {
305
264
mod.def (" load_intrinsic" , [](MlirModule module, std::string name) {
306
265
auto unwrapped = unwrap (module);
307
266
cudaq::IRBuilder builder = IRBuilder::atBlockEnd (unwrapped.getBody ());
@@ -311,14 +270,27 @@ void bindRegisterDialects(py::module &mod) {
311
270
312
271
mod.def (" register_all_dialects" , [](MlirContext context) {
313
272
DialectRegistry registry;
314
- registry. insert <quake::QuakeDialect, cudaq::cc::CCDialect>( );
273
+ cudaq::registerAllDialects (registry );
315
274
cudaq::opt::registerCodeGenDialect (registry);
316
- registerAllDialects (registry);
317
- auto *mlirContext = unwrap (context);
275
+ MLIRContext *mlirContext = unwrap (context);
318
276
mlirContext->appendDialectRegistry (registry);
319
277
mlirContext->loadAllAvailableDialects ();
320
278
});
321
279
280
+ // Register type and passes once, when the module is loaded.
281
+ registerQuakeTypes (mod);
282
+ registerCCTypes (mod);
283
+
284
+ mlir::registerTransformsPasses ();
285
+ cudaq::opt::registerOptCodeGenPasses ();
286
+ cudaq::opt::registerOptTransformsPasses ();
287
+ cudaq::opt::registerAggressiveEarlyInlining ();
288
+ cudaq::opt::registerUnrollingPipeline ();
289
+ cudaq::opt::registerTargetPipelines ();
290
+ cudaq::opt::registerMappingPipeline ();
291
+ cudaq::opt::registerWireSetToProfileQIRPipeline ();
292
+ // cudaq::opt::registerToExecutionManagerCCPipeline();
293
+
322
294
mod.def (" gen_vector_of_complex_constant" , [](MlirLocation loc,
323
295
MlirModule module,
324
296
std::string name,
@@ -330,4 +302,3 @@ void bindRegisterDialects(py::module &mod) {
330
302
builder.genVectorOfConstants (unwrap (loc), modOp, name, newValues);
331
303
});
332
304
}
333
- } // namespace cudaq
0 commit comments