1
1
from typing_extensions import Optional
2
2
import hugr .model as model
3
3
from hugr .hugr .base import Hugr , Node
4
- from hugr .ops import DFG , Input , Output , Custom , ExtOp , Conditional , TailLoop , FuncDefn , FuncDecl , Call , CallIndirect , LoadFunc , AliasDecl , AliasDefn , LoadConst , Const , CFG , ExitBlock , DataflowBlock
4
+ from hugr .ops import DFG , Input , Output , Custom , AsExtOp , Conditional , TailLoop , FuncDefn , FuncDecl , Call , CallIndirect , LoadFunc , AliasDecl , AliasDefn , LoadConst , Const , CFG , ExitBlock , DataflowBlock , Tag
5
5
from hugr .hugr .node_port import InPort , OutPort
6
6
from hugr .tys import FunctionKind , ConstKind , TypeParam , Type , TypeTypeParam , TypeBound
7
7
from typing import cast , Sequence
8
8
9
- class Export :
9
+ class ModelExport :
10
10
def __init__ (self , hugr : Hugr ):
11
11
self .hugr = hugr
12
12
self .link_ports : UnionFind [InPort | OutPort ] = UnionFind ()
13
13
self .link_names : dict [InPort | OutPort , str ] = {}
14
14
15
15
for (a , b ) in self .hugr .links ():
16
+ print (f"Link { a } and { b } " )
16
17
self .link_ports .union (a , b )
17
18
18
19
def link_name (self , port ):
19
20
root = self .link_ports [port ]
21
+ # print(f"Port {port} -> {root}")
20
22
21
23
if root in self .link_names :
24
+ # print(f"Known name {self.link_names[root]} for {root}")
25
+ return self .link_names [root ]
26
+ else :
22
27
index = str (len (self .link_names ))
28
+ # print(f"New name {index} for {root}")
23
29
self .link_names [root ] = index
24
30
return index
25
- else :
26
- return self .link_names [root ]
27
31
28
32
def export_node (self , node : Node ) -> Optional [model .Node ]:
29
33
node_data = self .hugr [node ]
30
34
31
35
inputs = [self .link_name (InPort (node , i )) for i in range (node_data ._num_inps )]
32
36
outputs = [self .link_name (OutPort (node , i )) for i in range (node_data ._num_outs )]
33
37
38
+ print (f"node { node } { inputs } { outputs } " )
39
+ print (node_data .op )
40
+
34
41
match node_data .op :
35
42
case DFG () as op :
36
43
region = self .export_region_dfg (node )
@@ -55,9 +62,9 @@ def export_node(self, node: Node) -> Optional[model.Node]:
55
62
outputs = outputs
56
63
)
57
64
58
- case ExtOp () as op :
65
+ case AsExtOp () as op :
59
66
name = op .op_def ().qualified_name ()
60
- args = cast (list [model .Term ], [arg .to_model () for arg in op .args ])
67
+ args = cast (list [model .Term ], [arg .to_model () for arg in op .type_args () ])
61
68
signature = op .outer_signature ().to_model ()
62
69
63
70
return model .Node (
@@ -241,6 +248,62 @@ def export_node(self, node: Node) -> Optional[model.Node]:
241
248
regions = [region ]
242
249
)
243
250
251
+ case DataflowBlock () as op :
252
+ region = self .export_region_dfg (node )
253
+
254
+ input_types = model .List ([
255
+ model .Apply ("core.ctrl" , [
256
+ model .List ([ type .to_model () for type in op .inputs ])
257
+ ])
258
+ ])
259
+
260
+ other_output_types = [ type .to_model () for type in op .other_outputs ]
261
+ output_types = model .List ([
262
+ model .Apply ("core.ctrl" , [model .List ([
263
+ * [type .to_model () for type in row ],
264
+ * other_output_types
265
+ ])])
266
+ for row in op .sum_ty .variant_rows
267
+ ])
268
+
269
+ signature = model .Apply ("core.fn" , [input_types , output_types , model .ExtSet ()])
270
+
271
+ return model .Node (
272
+ operation = model .Block (),
273
+ inputs = inputs ,
274
+ outputs = outputs ,
275
+ regions = [region ],
276
+ signature = signature
277
+ )
278
+
279
+ case Tag () as op :
280
+ variants = model .List ([
281
+ model .List ([type .to_model () for type in row ])
282
+ for row in op .sum_ty .variant_rows
283
+ ])
284
+
285
+ types = model .List ([
286
+ type .to_model ()
287
+ for type in op .sum_ty .variant_rows [op .tag ]
288
+ ])
289
+
290
+ tag = model .Literal (op .tag )
291
+ signature = op .outer_signature ().to_model ()
292
+
293
+ return model .Node (
294
+ operation = model .CustomOp (model .Apply ("core.make_adt" , [
295
+ variants ,
296
+ types ,
297
+ tag
298
+ ])),
299
+ inputs = inputs ,
300
+ outputs = outputs ,
301
+ signature = signature
302
+ )
303
+
304
+ case op :
305
+ raise ValueError (f"Unknown operation: { op } " )
306
+
244
307
def export_region_module (self , node : Node ) -> model .Region :
245
308
node_data = self .hugr [node ]
246
309
children = []
@@ -368,13 +431,13 @@ def find_func_input(self, node: Node) -> Optional[str]:
368
431
func_node = next (
369
432
out_port .node
370
433
for (in_port , out_ports ) in self .hugr .incoming_links (node )
371
- if isinstance (self .hugr .port_kind , FunctionKind )
434
+ if isinstance (self .hugr .port_kind ( in_port ) , FunctionKind )
372
435
for out_port in out_ports
373
436
)
374
437
except StopIteration :
375
438
return None
376
439
377
- match self .hugr [func_node ]:
440
+ match self .hugr [func_node ]. op :
378
441
case FuncDecl () as func_op :
379
442
return func_op .f_name
380
443
case FuncDefn () as func_op :
@@ -387,16 +450,16 @@ def find_const_input(self, node: Node) -> Optional[model.Term]:
387
450
const_node = next (
388
451
out_port .node
389
452
for (in_port , out_ports ) in self .hugr .incoming_links (node )
390
- if isinstance (self .hugr .port_kind , ConstKind )
453
+ if isinstance (self .hugr .port_kind ( in_port ) , ConstKind )
391
454
for out_port in out_ports
392
455
)
393
456
except StopIteration :
394
457
return None
395
458
396
- match self .hugr [const_node ]:
459
+ match self .hugr [const_node ]. op :
397
460
case Const () as op :
398
461
return op .val .to_model ()
399
- case _ :
462
+ case op :
400
463
return None
401
464
402
465
class UnionFind [T ]:
0 commit comments