Skip to content

Commit cc293e9

Browse files
committed
Make mypy happy
1 parent efd53a9 commit cc293e9

File tree

3 files changed

+48
-43
lines changed

3 files changed

+48
-43
lines changed

hugr-py/src/hugr/model/__init__.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def from_str(s: str) -> "Param":
7676
@dataclass
7777
class Symbol:
7878
name: str
79-
params: list[Param] = field(default_factory = list)
80-
constraints: list[Term] = field(default_factory = list)
79+
params: Sequence[Param] = field(default_factory = list)
80+
constraints: Sequence[Term] = field(default_factory = list)
8181
signature: Term = field(default_factory = Wildcard)
8282

8383
def __str__(self):
@@ -150,10 +150,10 @@ class Import(Op):
150150
@dataclass
151151
class Node:
152152
operation: Op = field(default_factory=lambda: InvalidOp())
153-
inputs: list[str] = field(default_factory=list)
154-
outputs: list[str] = field(default_factory=list)
155-
regions: list["Region"] = field(default_factory=list)
156-
meta: list[Term] = field(default_factory=list)
153+
inputs: Sequence[str] = field(default_factory=list)
154+
outputs: Sequence[str] = field(default_factory=list)
155+
regions: Sequence["Region"] = field(default_factory=list)
156+
meta: Sequence[Term] = field(default_factory=list)
157157
signature: Optional[Term] = None
158158

159159
def __str__(self) -> str:
@@ -171,10 +171,10 @@ class ScopeClosure(Enum):
171171
@dataclass
172172
class Region:
173173
kind: RegionKind = RegionKind.DATA_FLOW
174-
sources: list[str] = field(default_factory=list)
175-
targets: list[str] = field(default_factory=list)
176-
children: list[Node] = field(default_factory=list)
177-
meta: list[Term] = field(default_factory=list)
174+
sources: Sequence[str] = field(default_factory=list)
175+
targets: Sequence[str] = field(default_factory=list)
176+
children: Sequence[Node] = field(default_factory=list)
177+
meta: Sequence[Term] = field(default_factory=list)
178178
signature: Optional[Term] = None
179179

180180
def __str__(self):

hugr-py/src/hugr/model/export.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from hugr.ops import DFG, Input, Output, Custom, AsExtOp, Conditional, TailLoop, FuncDefn, FuncDecl, Call, CallIndirect, LoadFunc, AliasDecl, AliasDefn, LoadConst, Const, CFG, ExitBlock, DataflowBlock, Tag
55
from hugr.hugr.node_port import InPort, OutPort
66
from hugr.tys import FunctionKind, ConstKind, TypeParam, Type, TypeTypeParam, TypeBound
7-
from typing import cast, Sequence
7+
from typing import cast, Sequence, TypeVar, Generic
88

99
class ModelExport:
1010
def __init__(self, hugr: Hugr):
@@ -127,15 +127,15 @@ def export_node(self, node: Node) -> Optional[model.Node]:
127127
signature = model.Apply("core.type")
128128
)
129129

130-
value = cast(model.Term, op.definition.to_model())
130+
alias_value = cast(model.Term, op.definition.to_model())
131131

132132
return model.Node(
133-
operation = model.DefineAlias(symbol, value)
133+
operation = model.DefineAlias(symbol, alias_value)
134134
)
135135

136136
case Call() as op:
137-
input_types = model.List([type.to_model() for type in op.instantiation.input ])
138-
output_types = model.List([type.to_model() for type in op.instantiation.output ])
137+
input_types = [type.to_model() for type in op.instantiation.input ]
138+
output_types = [type.to_model() for type in op.instantiation.output ]
139139
signature = op.instantiation.to_model()
140140
func_args = cast(list[model.Term], [type.to_model() for type in op.type_args])
141141
func_name = self.find_func_input(node)
@@ -146,15 +146,12 @@ def export_node(self, node: Node) -> Optional[model.Node]:
146146
func = model.Apply(func_name, func_args)
147147

148148
return model.Node(
149-
operation = model.CustomOp(model.Apply(
150-
"core.call",
151-
[
152-
input_types,
153-
output_types,
154-
model.ExtSet(),
155-
func
156-
]
157-
)),
149+
operation = model.CustomOp(model.Apply("core.call", [
150+
model.List(input_types),
151+
model.List(output_types),
152+
model.ExtSet(),
153+
func
154+
])),
158155
signature = signature,
159156
inputs = inputs,
160157
outputs = outputs,
@@ -246,22 +243,26 @@ def export_node(self, node: Node) -> Optional[model.Node]:
246243
case DataflowBlock() as op:
247244
region = self.export_region_dfg(node)
248245

249-
input_types = model.List([
246+
input_types = [
250247
model.Apply("core.ctrl", [
251248
model.List([ type.to_model() for type in op.inputs ])
252249
])
253-
])
250+
]
254251

255252
other_output_types = [ type.to_model() for type in op.other_outputs ]
256-
output_types = model.List([
253+
output_types = [
257254
model.Apply("core.ctrl", [model.List([
258255
*[type.to_model() for type in row],
259256
*other_output_types
260257
])])
261258
for row in op.sum_ty.variant_rows
262-
])
259+
]
263260

264-
signature = model.Apply("core.fn", [input_types, output_types, model.ExtSet()])
261+
signature = model.Apply("core.fn", [
262+
model.List(input_types),
263+
model.List(output_types),
264+
model.ExtSet()
265+
])
265266

266267
return model.Node(
267268
operation = model.Block(),
@@ -317,8 +318,8 @@ def export_region_module(self, node: Node) -> model.Region:
317318
def export_region_dfg(self, node: Node) -> model.Region:
318319
node_data = self.hugr[node]
319320
children: list[model.Node] = []
320-
source_types = model.Wildcard()
321-
target_types = model.Wildcard()
321+
source_types: model.Term = model.Wildcard()
322+
target_types: model.Term = model.Wildcard()
322323
sources = []
323324
targets = []
324325

@@ -355,8 +356,8 @@ def export_region_cfg(self, node: Node) -> model.Region:
355356

356357
source = None
357358
targets = []
358-
source_types = model.Wildcard()
359-
target_types = model.Wildcard()
359+
source_types: model.Term = model.Wildcard()
360+
target_types: model.Term = model.Wildcard()
360361
children = []
361362

362363
for child in node_data.children:
@@ -464,8 +465,10 @@ def _mangle_name(node: Node, name: str) -> str:
464465
# by adding the node id.
465466
return f"_{name}_{node.idx}"
466467

467-
class UnionFind[T]:
468-
def __init__(self):
468+
T = TypeVar('T')
469+
470+
class UnionFind(Generic[T]):
471+
def __init__(self) -> None:
469472
self.parents: dict[T, T] = {}
470473
self.sizes: dict[T, int] = {}
471474

hugr-py/src/hugr/tys.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -682,11 +682,12 @@ def __eq__(self, value):
682682
return super().__eq__(value)
683683

684684
def to_model(self) -> model.Term:
685-
args = [arg.to_model() for arg in self.args]
686-
687-
# TODO: This cast is only neccessary because `Type` can both be an
685+
# This cast is only neccessary because `Type` can both be an
688686
# actual type or a row variable.
689-
args = cast(list[model.Term], args)
687+
args = [
688+
cast(model.Term, arg.to_model())
689+
for arg in self.args
690+
]
690691

691692
extension_name = self.type_def.get_extension().name
692693
type_name = self.type_def.name
@@ -739,11 +740,12 @@ def __str__(self) -> str:
739740
return _type_str(self.id, self.args)
740741

741742
def to_model(self) -> model.Term:
742-
args = [arg.to_model() for arg in self.args]
743-
744-
# TODO: This cast is only neccessary because `Type` can both be an
743+
# This cast is only neccessary because `Type` can both be an
745744
# actual type or a row variable.
746-
args = cast(list[model.Term], args)
745+
args = [
746+
cast(model.Term, arg.to_model())
747+
for arg in self.args
748+
]
747749

748750
return model.Apply(self.id, args)
749751

0 commit comments

Comments
 (0)