1
1
from typing_extensions import Optional
2
2
import hugr .model as model
3
3
from hugr .hugr .base import Hugr , Node
4
- from hugr .ops import DFG , Input , Output , Custom , ExtOp , Conditional , TailLoop , FuncDefn , FuncDecl , Call , CallIndirect , LoadFunc , AliasDecl , AliasDefn , LoadConst , Const , CFG , ExitBlock , DataflowBlock
4
+ from hugr .ops import DFG , Input , Output , Custom , AsExtOp , Conditional , TailLoop , FuncDefn , FuncDecl , Call , CallIndirect , LoadFunc , AliasDecl , AliasDefn , LoadConst , Const , CFG , ExitBlock , DataflowBlock , Tag
5
5
from hugr .hugr .node_port import InPort , OutPort
6
6
from hugr .tys import FunctionKind , ConstKind , TypeParam , Type , TypeTypeParam , TypeBound
7
7
from typing import cast , Sequence
8
8
9
- class Export :
9
+ class ModelExport :
10
10
def __init__ (self , hugr : Hugr ):
11
11
self .hugr = hugr
12
12
self .link_ports : UnionFind [InPort | OutPort ] = UnionFind ()
@@ -19,11 +19,11 @@ def link_name(self, port):
19
19
root = self .link_ports [port ]
20
20
21
21
if root in self .link_names :
22
+ return self .link_names [root ]
23
+ else :
22
24
index = str (len (self .link_names ))
23
25
self .link_names [root ] = index
24
26
return index
25
- else :
26
- return self .link_names [root ]
27
27
28
28
def export_node (self , node : Node ) -> Optional [model .Node ]:
29
29
node_data = self .hugr [node ]
@@ -55,9 +55,9 @@ def export_node(self, node: Node) -> Optional[model.Node]:
55
55
outputs = outputs
56
56
)
57
57
58
- case ExtOp () as op :
58
+ case AsExtOp () as op :
59
59
name = op .op_def ().qualified_name ()
60
- args = cast (list [model .Term ], [arg .to_model () for arg in op .args ])
60
+ args = cast (list [model .Term ], [arg .to_model () for arg in op .type_args () ])
61
61
signature = op .outer_signature ().to_model ()
62
62
63
63
return model .Node (
@@ -241,6 +241,62 @@ def export_node(self, node: Node) -> Optional[model.Node]:
241
241
regions = [region ]
242
242
)
243
243
244
+ case DataflowBlock () as op :
245
+ region = self .export_region_dfg (node )
246
+
247
+ input_types = model .List ([
248
+ model .Apply ("core.ctrl" , [
249
+ model .List ([ type .to_model () for type in op .inputs ])
250
+ ])
251
+ ])
252
+
253
+ other_output_types = [ type .to_model () for type in op .other_outputs ]
254
+ output_types = model .List ([
255
+ model .Apply ("core.ctrl" , [model .List ([
256
+ * [type .to_model () for type in row ],
257
+ * other_output_types
258
+ ])])
259
+ for row in op .sum_ty .variant_rows
260
+ ])
261
+
262
+ signature = model .Apply ("core.fn" , [input_types , output_types , model .ExtSet ()])
263
+
264
+ return model .Node (
265
+ operation = model .Block (),
266
+ inputs = inputs ,
267
+ outputs = outputs ,
268
+ regions = [region ],
269
+ signature = signature
270
+ )
271
+
272
+ case Tag () as op :
273
+ variants = model .List ([
274
+ model .List ([type .to_model () for type in row ])
275
+ for row in op .sum_ty .variant_rows
276
+ ])
277
+
278
+ types = model .List ([
279
+ type .to_model ()
280
+ for type in op .sum_ty .variant_rows [op .tag ]
281
+ ])
282
+
283
+ tag = model .Literal (op .tag )
284
+ signature = op .outer_signature ().to_model ()
285
+
286
+ return model .Node (
287
+ operation = model .CustomOp (model .Apply ("core.make_adt" , [
288
+ variants ,
289
+ types ,
290
+ tag
291
+ ])),
292
+ inputs = inputs ,
293
+ outputs = outputs ,
294
+ signature = signature
295
+ )
296
+
297
+ case op :
298
+ raise ValueError (f"Unknown operation: { op } " )
299
+
244
300
def export_region_module (self , node : Node ) -> model .Region :
245
301
node_data = self .hugr [node ]
246
302
children = []
@@ -368,13 +424,13 @@ def find_func_input(self, node: Node) -> Optional[str]:
368
424
func_node = next (
369
425
out_port .node
370
426
for (in_port , out_ports ) in self .hugr .incoming_links (node )
371
- if isinstance (self .hugr .port_kind , FunctionKind )
427
+ if isinstance (self .hugr .port_kind ( in_port ) , FunctionKind )
372
428
for out_port in out_ports
373
429
)
374
430
except StopIteration :
375
431
return None
376
432
377
- match self .hugr [func_node ]:
433
+ match self .hugr [func_node ]. op :
378
434
case FuncDecl () as func_op :
379
435
return func_op .f_name
380
436
case FuncDefn () as func_op :
@@ -387,16 +443,16 @@ def find_const_input(self, node: Node) -> Optional[model.Term]:
387
443
const_node = next (
388
444
out_port .node
389
445
for (in_port , out_ports ) in self .hugr .incoming_links (node )
390
- if isinstance (self .hugr .port_kind , ConstKind )
446
+ if isinstance (self .hugr .port_kind ( in_port ) , ConstKind )
391
447
for out_port in out_ports
392
448
)
393
449
except StopIteration :
394
450
return None
395
451
396
- match self .hugr [const_node ]:
452
+ match self .hugr [const_node ]. op :
397
453
case Const () as op :
398
454
return op .val .to_model ()
399
- case _ :
455
+ case op :
400
456
return None
401
457
402
458
class UnionFind [T ]:
0 commit comments