4
4
from hugr .ops import DFG , Input , Output , Custom , AsExtOp , Conditional , TailLoop , FuncDefn , FuncDecl , Call , CallIndirect , LoadFunc , AliasDecl , AliasDefn , LoadConst , Const , CFG , ExitBlock , DataflowBlock , Tag
5
5
from hugr .hugr .node_port import InPort , OutPort
6
6
from hugr .tys import FunctionKind , ConstKind , TypeParam , Type , TypeTypeParam , TypeBound
7
- from typing import cast , Sequence
7
+ from typing import cast , Sequence , TypeVar , Generic
8
8
9
9
class ModelExport :
10
10
def __init__ (self , hugr : Hugr ):
@@ -127,15 +127,15 @@ def export_node(self, node: Node) -> Optional[model.Node]:
127
127
signature = model .Apply ("core.type" )
128
128
)
129
129
130
- value = cast (model .Term , op .definition .to_model ())
130
+ alias_value = cast (model .Term , op .definition .to_model ())
131
131
132
132
return model .Node (
133
- operation = model .DefineAlias (symbol , value )
133
+ operation = model .DefineAlias (symbol , alias_value )
134
134
)
135
135
136
136
case Call () as op :
137
- input_types = model . List ( [type .to_model () for type in op .instantiation .input ])
138
- output_types = model . List ( [type .to_model () for type in op .instantiation .output ])
137
+ input_types = [type .to_model () for type in op .instantiation .input ]
138
+ output_types = [type .to_model () for type in op .instantiation .output ]
139
139
signature = op .instantiation .to_model ()
140
140
func_args = cast (list [model .Term ], [type .to_model () for type in op .type_args ])
141
141
func_name = self .find_func_input (node )
@@ -146,15 +146,12 @@ def export_node(self, node: Node) -> Optional[model.Node]:
146
146
func = model .Apply (func_name , func_args )
147
147
148
148
return model .Node (
149
- operation = model .CustomOp (model .Apply (
150
- "core.call" ,
151
- [
152
- input_types ,
153
- output_types ,
154
- model .ExtSet (),
155
- func
156
- ]
157
- )),
149
+ operation = model .CustomOp (model .Apply ("core.call" , [
150
+ model .List (input_types ),
151
+ model .List (output_types ),
152
+ model .ExtSet (),
153
+ func
154
+ ])),
158
155
signature = signature ,
159
156
inputs = inputs ,
160
157
outputs = outputs ,
@@ -246,22 +243,26 @@ def export_node(self, node: Node) -> Optional[model.Node]:
246
243
case DataflowBlock () as op :
247
244
region = self .export_region_dfg (node )
248
245
249
- input_types = model . List ( [
246
+ input_types = [
250
247
model .Apply ("core.ctrl" , [
251
248
model .List ([ type .to_model () for type in op .inputs ])
252
249
])
253
- ])
250
+ ]
254
251
255
252
other_output_types = [ type .to_model () for type in op .other_outputs ]
256
- output_types = model . List ( [
253
+ output_types = [
257
254
model .Apply ("core.ctrl" , [model .List ([
258
255
* [type .to_model () for type in row ],
259
256
* other_output_types
260
257
])])
261
258
for row in op .sum_ty .variant_rows
262
- ])
259
+ ]
263
260
264
- signature = model .Apply ("core.fn" , [input_types , output_types , model .ExtSet ()])
261
+ signature = model .Apply ("core.fn" , [
262
+ model .List (input_types ),
263
+ model .List (output_types ),
264
+ model .ExtSet ()
265
+ ])
265
266
266
267
return model .Node (
267
268
operation = model .Block (),
@@ -317,8 +318,8 @@ def export_region_module(self, node: Node) -> model.Region:
317
318
def export_region_dfg (self , node : Node ) -> model .Region :
318
319
node_data = self .hugr [node ]
319
320
children : list [model .Node ] = []
320
- source_types = model .Wildcard ()
321
- target_types = model .Wildcard ()
321
+ source_types : model . Term = model .Wildcard ()
322
+ target_types : model . Term = model .Wildcard ()
322
323
sources = []
323
324
targets = []
324
325
@@ -355,8 +356,8 @@ def export_region_cfg(self, node: Node) -> model.Region:
355
356
356
357
source = None
357
358
targets = []
358
- source_types = model .Wildcard ()
359
- target_types = model .Wildcard ()
359
+ source_types : model . Term = model .Wildcard ()
360
+ target_types : model . Term = model .Wildcard ()
360
361
children = []
361
362
362
363
for child in node_data .children :
@@ -464,8 +465,10 @@ def _mangle_name(node: Node, name: str) -> str:
464
465
# by adding the node id.
465
466
return f"_{ name } _{ node .idx } "
466
467
467
- class UnionFind [T ]:
468
- def __init__ (self ):
468
+ T = TypeVar ('T' )
469
+
470
+ class UnionFind (Generic [T ]):
471
+ def __init__ (self ) -> None :
469
472
self .parents : dict [T , T ] = {}
470
473
self .sizes : dict [T , int ] = {}
471
474
0 commit comments