Skip to content

Commit 49fd266

Browse files
committed
Fixed various bugs in python -> model conversion.
1 parent cacb86e commit 49fd266

File tree

7 files changed

+99
-17
lines changed

7 files changed

+99
-17
lines changed

hugr-model/src/v0/ast/python.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@ impl<'py> pyo3::FromPyObject<'py> for Term {
3939
let region = term.getattr("region")?.extract()?;
4040
Self::Func(Arc::new(region))
4141
}
42-
_ => return Err(PyTypeError::new_err("Unknown Term type.")),
42+
_ => {
43+
return Err(PyTypeError::new_err(format!(
44+
"Unknown Term type: {}.",
45+
name.to_str()?
46+
)))
47+
}
4348
})
4449
}
4550
}
2.96 MB
Binary file not shown.

hugr-py/src/hugr/hugr/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from hugr.utils import BiMap
2323
from hugr.val import Value
2424
import hugr.model as model
25-
import hugr.model.export as model_export
2625

2726
from .node_port import (
2827
Direction,
@@ -732,7 +731,8 @@ def to_model(self) -> model.Module:
732731
return model.Module(self.to_model_region())
733732

734733
def to_model_region(self) -> model.Region:
735-
export = model_export.Export(self)
734+
from hugr.model.export import ModelExport
735+
export = ModelExport(self)
736736
return export.export_region_module(self.root)
737737

738738
@classmethod

hugr-py/src/hugr/model/export.py

+74-11
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,43 @@
11
from typing_extensions import Optional
22
import hugr.model as model
33
from hugr.hugr.base import Hugr, Node
4-
from hugr.ops import DFG, Input, Output, Custom, ExtOp, Conditional, TailLoop, FuncDefn, FuncDecl, Call, CallIndirect, LoadFunc, AliasDecl, AliasDefn, LoadConst, Const, CFG, ExitBlock, DataflowBlock
4+
from hugr.ops import DFG, Input, Output, Custom, AsExtOp, Conditional, TailLoop, FuncDefn, FuncDecl, Call, CallIndirect, LoadFunc, AliasDecl, AliasDefn, LoadConst, Const, CFG, ExitBlock, DataflowBlock, Tag
55
from hugr.hugr.node_port import InPort, OutPort
66
from hugr.tys import FunctionKind, ConstKind, TypeParam, Type, TypeTypeParam, TypeBound
77
from typing import cast, Sequence
88

9-
class Export:
9+
class ModelExport:
1010
def __init__(self, hugr: Hugr):
1111
self.hugr = hugr
1212
self.link_ports: UnionFind[InPort | OutPort] = UnionFind()
1313
self.link_names: dict[InPort | OutPort, str] = {}
1414

1515
for (a, b) in self.hugr.links():
16+
print(f"Link {a} and {b}")
1617
self.link_ports.union(a, b)
1718

1819
def link_name(self, port):
1920
root = self.link_ports[port]
21+
# print(f"Port {port} -> {root}")
2022

2123
if root in self.link_names:
24+
# print(f"Known name {self.link_names[root]} for {root}")
25+
return self.link_names[root]
26+
else:
2227
index = str(len(self.link_names))
28+
# print(f"New name {index} for {root}")
2329
self.link_names[root] = index
2430
return index
25-
else:
26-
return self.link_names[root]
2731

2832
def export_node(self, node: Node) -> Optional[model.Node]:
2933
node_data = self.hugr[node]
3034

3135
inputs = [self.link_name(InPort(node, i)) for i in range(node_data._num_inps)]
3236
outputs = [self.link_name(OutPort(node, i)) for i in range(node_data._num_outs)]
3337

38+
print(f"node {node} {inputs} {outputs}")
39+
print(node_data.op)
40+
3441
match node_data.op:
3542
case DFG() as op:
3643
region = self.export_region_dfg(node)
@@ -55,9 +62,9 @@ def export_node(self, node: Node) -> Optional[model.Node]:
5562
outputs = outputs
5663
)
5764

58-
case ExtOp() as op:
65+
case AsExtOp() as op:
5966
name = op.op_def().qualified_name()
60-
args = cast(list[model.Term], [arg.to_model() for arg in op.args])
67+
args = cast(list[model.Term], [arg.to_model() for arg in op.type_args()])
6168
signature = op.outer_signature().to_model()
6269

6370
return model.Node(
@@ -241,6 +248,62 @@ def export_node(self, node: Node) -> Optional[model.Node]:
241248
regions = [region]
242249
)
243250

251+
case DataflowBlock() as op:
252+
region = self.export_region_dfg(node)
253+
254+
input_types = model.List([
255+
model.Apply("core.ctrl", [
256+
model.List([ type.to_model() for type in op.inputs ])
257+
])
258+
])
259+
260+
other_output_types = [ type.to_model() for type in op.other_outputs ]
261+
output_types = model.List([
262+
model.Apply("core.ctrl", [model.List([
263+
*[type.to_model() for type in row],
264+
*other_output_types
265+
])])
266+
for row in op.sum_ty.variant_rows
267+
])
268+
269+
signature = model.Apply("core.fn", [input_types, output_types, model.ExtSet()])
270+
271+
return model.Node(
272+
operation = model.Block(),
273+
inputs = inputs,
274+
outputs = outputs,
275+
regions = [region],
276+
signature = signature
277+
)
278+
279+
case Tag() as op:
280+
variants = model.List([
281+
model.List([type.to_model() for type in row])
282+
for row in op.sum_ty.variant_rows
283+
])
284+
285+
types = model.List([
286+
type.to_model()
287+
for type in op.sum_ty.variant_rows[op.tag]
288+
])
289+
290+
tag = model.Literal(op.tag)
291+
signature = op.outer_signature().to_model()
292+
293+
return model.Node(
294+
operation = model.CustomOp(model.Apply("core.make_adt", [
295+
variants,
296+
types,
297+
tag
298+
])),
299+
inputs = inputs,
300+
outputs = outputs,
301+
signature = signature
302+
)
303+
304+
case op:
305+
raise ValueError(f"Unknown operation: {op}")
306+
244307
def export_region_module(self, node: Node) -> model.Region:
245308
node_data = self.hugr[node]
246309
children = []
@@ -368,13 +431,13 @@ def find_func_input(self, node: Node) -> Optional[str]:
368431
func_node = next(
369432
out_port.node
370433
for (in_port, out_ports) in self.hugr.incoming_links(node)
371-
if isinstance(self.hugr.port_kind, FunctionKind)
434+
if isinstance(self.hugr.port_kind(in_port), FunctionKind)
372435
for out_port in out_ports
373436
)
374437
except StopIteration:
375438
return None
376439

377-
match self.hugr[func_node]:
440+
match self.hugr[func_node].op:
378441
case FuncDecl() as func_op:
379442
return func_op.f_name
380443
case FuncDefn() as func_op:
@@ -387,16 +450,16 @@ def find_const_input(self, node: Node) -> Optional[model.Term]:
387450
const_node = next(
388451
out_port.node
389452
for (in_port, out_ports) in self.hugr.incoming_links(node)
390-
if isinstance(self.hugr.port_kind, ConstKind)
453+
if isinstance(self.hugr.port_kind(in_port), ConstKind)
391454
for out_port in out_ports
392455
)
393456
except StopIteration:
394457
return None
395458

396-
match self.hugr[const_node]:
459+
match self.hugr[const_node].op:
397460
case Const() as op:
398461
return op.val.to_model()
399-
case _:
462+
case op:
400463
return None
401464

402465
class UnionFind[T]:

hugr-py/src/hugr/std/int.py

+6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from hugr import ext, tys, val
1111
from hugr.ops import AsExtOp, DataflowOp, ExtOp, RegisteredOp
1212
from hugr.std import _load_extension
13+
import hugr.model as model
1314

1415
if TYPE_CHECKING:
1516
from hugr.ops import Command, ComWire
@@ -71,6 +72,11 @@ def to_value(self) -> val.Extension:
7172
def __str__(self) -> str:
7273
return f"{self.v}"
7374

75+
def to_model(self) -> model.Term:
76+
return model.Apply("arithmetic.int.const", [
77+
model.Literal(self.width),
78+
model.Literal(self.v)
79+
])
7480

7581
INT_OPS_EXTENSION = _load_extension("arithmetic.int")
7682

hugr-py/src/hugr/tys.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,8 @@ def __str__(self) -> str:
274274
return f"({comma_sep_str(self.elems)})"
275275

276276
def to_model(self) -> model.Term:
277-
# TODO: Should this be a tuple or a list?
278-
...
277+
# TODO: We should separate lists and tuples. For now we assume that this is a list.
278+
return model.List([elem.to_model() for elem in self.elems])
279279

280280

281281
@dataclass(frozen=True)
@@ -686,7 +686,11 @@ def to_model(self) -> model.Term:
686686
# actual type or a row variable.
687687
args = cast(list[model.Term], args)
688688

689-
return model.Apply(self.type_def.name, args)
689+
extension_name = self.type_def.get_extension().name
690+
type_name = self.type_def.name
691+
name = f"{extension_name}.{type_name}"
692+
693+
return model.Apply(name, args)
690694

691695

692696
def _type_str(name: str, args: Sequence[TypeArg]) -> str:

hugr-py/src/hugr/val.py

+4
Original file line numberDiff line numberDiff line change
@@ -342,3 +342,7 @@ def type_(self) -> tys.Type:
342342

343343
def _to_serial(self) -> sops.CustomValue:
344344
return self.to_value()._to_serial()
345+
346+
def to_model(self) -> model.Term:
347+
# Fallback
348+
return self.to_value().to_model()

0 commit comments

Comments
 (0)