Skip to content

Commit f7c4ba4

Browse files
committed
Fixed various bugs in python -> model conversion.
1 parent cacb86e commit f7c4ba4

File tree

7 files changed

+92
-17
lines changed

7 files changed

+92
-17
lines changed

hugr-model/src/v0/ast/python.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@ impl<'py> pyo3::FromPyObject<'py> for Term {
3939
let region = term.getattr("region")?.extract()?;
4040
Self::Func(Arc::new(region))
4141
}
42-
_ => return Err(PyTypeError::new_err("Unknown Term type.")),
42+
_ => {
43+
return Err(PyTypeError::new_err(format!(
44+
"Unknown Term type: {}.",
45+
name.to_str()?
46+
)))
47+
}
4348
})
4449
}
4550
}
2.96 MB
Binary file not shown.

hugr-py/src/hugr/hugr/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from hugr.utils import BiMap
2323
from hugr.val import Value
2424
import hugr.model as model
25-
import hugr.model.export as model_export
2625

2726
from .node_port import (
2827
Direction,
@@ -732,7 +731,8 @@ def to_model(self) -> model.Module:
732731
return model.Module(self.to_model_region())
733732

734733
def to_model_region(self) -> model.Region:
735-
export = model_export.Export(self)
734+
from hugr.model.export import ModelExport
735+
export = ModelExport(self)
736736
return export.export_region_module(self.root)
737737

738738
@classmethod

hugr-py/src/hugr/model/export.py

+67-11
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from typing_extensions import Optional
22
import hugr.model as model
33
from hugr.hugr.base import Hugr, Node
4-
from hugr.ops import DFG, Input, Output, Custom, ExtOp, Conditional, TailLoop, FuncDefn, FuncDecl, Call, CallIndirect, LoadFunc, AliasDecl, AliasDefn, LoadConst, Const, CFG, ExitBlock, DataflowBlock
4+
from hugr.ops import DFG, Input, Output, Custom, AsExtOp, Conditional, TailLoop, FuncDefn, FuncDecl, Call, CallIndirect, LoadFunc, AliasDecl, AliasDefn, LoadConst, Const, CFG, ExitBlock, DataflowBlock, Tag
55
from hugr.hugr.node_port import InPort, OutPort
66
from hugr.tys import FunctionKind, ConstKind, TypeParam, Type, TypeTypeParam, TypeBound
77
from typing import cast, Sequence
88

9-
class Export:
9+
class ModelExport:
1010
def __init__(self, hugr: Hugr):
1111
self.hugr = hugr
1212
self.link_ports: UnionFind[InPort | OutPort] = UnionFind()
@@ -19,11 +19,11 @@ def link_name(self, port):
1919
root = self.link_ports[port]
2020

2121
if root in self.link_names:
22+
return self.link_names[root]
23+
else:
2224
index = str(len(self.link_names))
2325
self.link_names[root] = index
2426
return index
25-
else:
26-
return self.link_names[root]
2727

2828
def export_node(self, node: Node) -> Optional[model.Node]:
2929
node_data = self.hugr[node]
@@ -55,9 +55,9 @@ def export_node(self, node: Node) -> Optional[model.Node]:
5555
outputs = outputs
5656
)
5757

58-
case ExtOp() as op:
58+
case AsExtOp() as op:
5959
name = op.op_def().qualified_name()
60-
args = cast(list[model.Term], [arg.to_model() for arg in op.args])
60+
args = cast(list[model.Term], [arg.to_model() for arg in op.type_args()])
6161
signature = op.outer_signature().to_model()
6262

6363
return model.Node(
@@ -241,6 +241,62 @@ def export_node(self, node: Node) -> Optional[model.Node]:
241241
regions = [region]
242242
)
243243

244+
case DataflowBlock() as op:
245+
region = self.export_region_dfg(node)
246+
247+
input_types = model.List([
248+
model.Apply("core.ctrl", [
249+
model.List([ type.to_model() for type in op.inputs ])
250+
])
251+
])
252+
253+
other_output_types = [ type.to_model() for type in op.other_outputs ]
254+
output_types = model.List([
255+
model.Apply("core.ctrl", [model.List([
256+
*[type.to_model() for type in row],
257+
*other_output_types
258+
])])
259+
for row in op.sum_ty.variant_rows
260+
])
261+
262+
signature = model.Apply("core.fn", [input_types, output_types, model.ExtSet()])
263+
264+
return model.Node(
265+
operation = model.Block(),
266+
inputs = inputs,
267+
outputs = outputs,
268+
regions = [region],
269+
signature = signature
270+
)
271+
272+
case Tag() as op:
273+
variants = model.List([
274+
model.List([type.to_model() for type in row])
275+
for row in op.sum_ty.variant_rows
276+
])
277+
278+
types = model.List([
279+
type.to_model()
280+
for type in op.sum_ty.variant_rows[op.tag]
281+
])
282+
283+
tag = model.Literal(op.tag)
284+
signature = op.outer_signature().to_model()
285+
286+
return model.Node(
287+
operation = model.CustomOp(model.Apply("core.make_adt", [
288+
variants,
289+
types,
290+
tag
291+
])),
292+
inputs = inputs,
293+
outputs = outputs,
294+
signature = signature
295+
)
296+
297+
case op:
298+
raise ValueError(f"Unknown operation: {op}")
299+
244300
def export_region_module(self, node: Node) -> model.Region:
245301
node_data = self.hugr[node]
246302
children = []
@@ -368,13 +424,13 @@ def find_func_input(self, node: Node) -> Optional[str]:
368424
func_node = next(
369425
out_port.node
370426
for (in_port, out_ports) in self.hugr.incoming_links(node)
371-
if isinstance(self.hugr.port_kind, FunctionKind)
427+
if isinstance(self.hugr.port_kind(in_port), FunctionKind)
372428
for out_port in out_ports
373429
)
374430
except StopIteration:
375431
return None
376432

377-
match self.hugr[func_node]:
433+
match self.hugr[func_node].op:
378434
case FuncDecl() as func_op:
379435
return func_op.f_name
380436
case FuncDefn() as func_op:
@@ -387,16 +443,16 @@ def find_const_input(self, node: Node) -> Optional[model.Term]:
387443
const_node = next(
388444
out_port.node
389445
for (in_port, out_ports) in self.hugr.incoming_links(node)
390-
if isinstance(self.hugr.port_kind, ConstKind)
446+
if isinstance(self.hugr.port_kind(in_port), ConstKind)
391447
for out_port in out_ports
392448
)
393449
except StopIteration:
394450
return None
395451

396-
match self.hugr[const_node]:
452+
match self.hugr[const_node].op:
397453
case Const() as op:
398454
return op.val.to_model()
399-
case _:
455+
case op:
400456
return None
401457

402458
class UnionFind[T]:

hugr-py/src/hugr/std/int.py

+6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from hugr import ext, tys, val
1111
from hugr.ops import AsExtOp, DataflowOp, ExtOp, RegisteredOp
1212
from hugr.std import _load_extension
13+
import hugr.model as model
1314

1415
if TYPE_CHECKING:
1516
from hugr.ops import Command, ComWire
@@ -71,6 +72,11 @@ def to_value(self) -> val.Extension:
7172
def __str__(self) -> str:
7273
return f"{self.v}"
7374

75+
def to_model(self) -> model.Term:
76+
return model.Apply("arithmetic.int.const", [
77+
model.Literal(self.width),
78+
model.Literal(self.v)
79+
])
7480

7581
INT_OPS_EXTENSION = _load_extension("arithmetic.int")
7682

hugr-py/src/hugr/tys.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,8 @@ def __str__(self) -> str:
274274
return f"({comma_sep_str(self.elems)})"
275275

276276
def to_model(self) -> model.Term:
277-
# TODO: Should this be a tuple or a list?
278-
...
277+
# TODO: We should separate lists and tuples. For now we assume that this is a list.
278+
return model.List([elem.to_model() for elem in self.elems])
279279

280280

281281
@dataclass(frozen=True)
@@ -686,7 +686,11 @@ def to_model(self) -> model.Term:
686686
# actual type or a row variable.
687687
args = cast(list[model.Term], args)
688688

689-
return model.Apply(self.type_def.name, args)
689+
extension_name = self.type_def.get_extension().name
690+
type_name = self.type_def.name
691+
name = f"{extension_name}.{type_name}"
692+
693+
return model.Apply(name, args)
690694

691695

692696
def _type_str(name: str, args: Sequence[TypeArg]) -> str:

hugr-py/src/hugr/val.py

+4
Original file line numberDiff line numberDiff line change
@@ -342,3 +342,7 @@ def type_(self) -> tys.Type:
342342

343343
def _to_serial(self) -> sops.CustomValue:
344344
return self.to_value()._to_serial()
345+
346+
def to_model(self) -> model.Term:
347+
# Fallback
348+
return self.to_value().to_model()

0 commit comments

Comments
 (0)