Skip to content

Commit 7765f55

Browse files
committed
feat: Add TypeArg/Type::used_extensions
1 parent fe98ba1 commit 7765f55

File tree

2 files changed

+116
-1
lines changed

2 files changed

+116
-1
lines changed

hugr-py/src/hugr/ext.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,14 @@ class ExtensionExists(Exception):
415415

416416
extension_id: ExtensionId
417417

418+
def ids(self) -> set[ExtensionId]:
419+
"""Get the set of extension IDs in the registry.
420+
421+
Returns:
422+
Set of extension IDs.
423+
"""
424+
return set(self.extensions.keys())
425+
418426
def add_extension(self, extension: Extension) -> Extension:
419427
"""Add an extension to the registry.
420428
@@ -448,3 +456,20 @@ def get_extension(self, name: ExtensionId) -> Extension:
448456
return self.extensions[name]
449457
except KeyError as e:
450458
raise self.ExtensionNotFound(name) from e
459+
460+
def extend(self, other: ExtensionRegistry) -> None:
461+
"""Add a registry of extensions to this registry.
462+
463+
If an extension with the same name already exists, the one with the
464+
higher version is kept.
465+
466+
Args:
467+
other: The extension registry to add.
468+
"""
469+
for name, ext in other.extensions.items():
470+
if name in self.extensions and self.extensions[name].version >= ext.version:
471+
continue
472+
self.extensions[name] = ext
473+
474+
def __str__(self) -> str:
475+
return "ExtensionRegistry(" + ", ".join(self.extensions.keys()) + ")"

hugr-py/src/hugr/tys.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from collections.abc import Iterable, Sequence
1414

1515
from hugr import ext
16+
from hugr.ext import ExtensionRegistry
1617

1718

1819
ExtensionId = stys.ExtensionId
@@ -55,6 +56,21 @@ def to_model(self) -> model.Term | model.Splice:
5556
"""Convert the type argument to a model Term."""
5657
raise NotImplementedError(self)
5758

59+
def used_extensions(self) -> ExtensionRegistry:
60+
"""Get the set of extensions required to define this type argument.
61+
62+
Raises:
63+
UnknownTypeExtensionError: if a type argument contains is a
64+
:class:`Opaque` type that has not been resolved.
65+
66+
Example:
67+
>>> TypeTypeArg(ty=Qubit).used_extensions().ids()
68+
{'prelude'}
69+
"""
70+
from hugr.ext import ExtensionRegistry
71+
72+
return ExtensionRegistry()
73+
5874

5975
@runtime_checkable
6076
class Type(Protocol):
@@ -95,10 +111,34 @@ def to_model(self) -> model.Term | model.Splice:
95111
"""Convert the type to a model Term."""
96112
raise NotImplementedError(self)
97113

114+
def used_extensions(self) -> ExtensionRegistry:
115+
"""Get the set of extensions required to define this type.
116+
117+
Note that :class:`Opaque` types do not know their extension, so they
118+
will raise an error. Use :meth:`resolve` to get the actual type
119+
and then call this method.
120+
121+
Raises:
122+
UnknownTypeExtensionError: if the type is an :class:`Opaque` type
123+
and has not been resolved.
124+
125+
Example:
126+
>>> Qubit.used_extensions().ids()
127+
{'prelude'}
128+
"""
129+
from hugr.ext import ExtensionRegistry
130+
131+
return ExtensionRegistry()
132+
98133

99134
#: Row of types.
100135
TypeRow = list[Type]
101136

137+
138+
class UnknownTypeExtensionError(Exception):
139+
"""Exception raised when querying the extension of an :method:`Opaque` type."""
140+
141+
102142
# --------------------------------------------
103143
# --------------- TypeParam ------------------
104144
# --------------------------------------------
@@ -211,6 +251,9 @@ def __str__(self) -> str:
211251
def to_model(self) -> model.Term | model.Splice:
212252
return self.ty.to_model()
213253

254+
def used_extensions(self) -> ExtensionRegistry:
255+
return self.ty.used_extensions()
256+
214257

215258
@dataclass(frozen=True)
216259
class BoundedNatArg(TypeArg):
@@ -264,6 +307,12 @@ def to_model(self) -> model.Term:
264307
# For now we assume that this is a list.
265308
return model.List([elem.to_model() for elem in self.elems])
266309

310+
def used_extensions(self) -> ExtensionRegistry:
311+
reg = super().used_extensions()
312+
for arg in self.elems:
313+
reg.extend(arg.used_extensions())
314+
return reg
315+
267316

268317
@dataclass(frozen=True)
269318
class VariableArg(TypeArg):
@@ -324,6 +373,13 @@ def to_model(self) -> model.Term:
324373
)
325374
return model.Apply("core.adt", [variants])
326375

376+
def used_extensions(self) -> ExtensionRegistry:
377+
types = [ty for row in self.variant_rows for ty in row]
378+
reg = super().used_extensions()
379+
for ty in types:
380+
reg.extend(ty.used_extensions())
381+
return reg
382+
327383

328384
@dataclass(eq=False)
329385
class UnitSum(Sum):
@@ -457,6 +513,13 @@ def __repr__(self) -> str:
457513
def to_model(self) -> model.Term:
458514
return model.Apply("prelude.usize")
459515

516+
def used_extensions(self) -> ExtensionRegistry:
517+
from hugr.std.prelude import PRELUDE_EXTENSION
518+
519+
reg = super().used_extensions()
520+
reg.add_extension(PRELUDE_EXTENSION)
521+
return reg
522+
460523

461524
@dataclass(frozen=True)
462525
class Alias(Type):
@@ -543,6 +606,14 @@ def to_model(self) -> model.Term:
543606
outputs = model.List([output.to_model() for output in self.output])
544607
return model.Apply("core.fn", [inputs, outputs])
545608

609+
def used_extensions(self) -> ExtensionRegistry:
610+
reg = super().used_extensions()
611+
for ty in self.input:
612+
reg.extend(ty.used_extensions())
613+
for ty in self.output:
614+
reg.extend(ty.used_extensions())
615+
return reg
616+
546617

547618
@dataclass(frozen=True)
548619
class PolyFuncType(Type):
@@ -587,6 +658,9 @@ def to_model(self) -> model.Term:
587658
error = "PolyFuncType used as a Type"
588659
raise TypeError(error)
589660

661+
def used_extensions(self) -> ExtensionRegistry:
662+
return self.body.used_extensions()
663+
590664

591665
@dataclass
592666
class ExtType(Type):
@@ -632,7 +706,7 @@ def __eq__(self, value):
632706
return super().__eq__(value)
633707

634708
def to_model(self) -> model.Term:
635-
# This cast is only neccessary because `Type` can both be an
709+
# This cast is only necessary because `Type` can both be an
636710
# actual type or a row variable.
637711
args = [cast(model.Term, arg.to_model()) for arg in self.args]
638712

@@ -642,6 +716,11 @@ def to_model(self) -> model.Term:
642716

643717
return model.Apply(name, args)
644718

719+
def used_extensions(self) -> ExtensionRegistry:
720+
reg = super().used_extensions()
721+
reg.add_extension(self.type_def.get_extension())
722+
return reg
723+
645724

646725
def _type_str(name: str, args: Sequence[TypeArg]) -> str:
647726
if len(args) == 0:
@@ -693,6 +772,10 @@ def to_model(self) -> model.Term:
693772

694773
return model.Apply(self.id, args)
695774

775+
def used_extensions(self) -> ExtensionRegistry:
776+
msg = "Opaque types do not know their extension. Call `resolve` first."
777+
raise UnknownTypeExtensionError(msg)
778+
696779

697780
@dataclass
698781
class _QubitDef(Type):
@@ -708,6 +791,13 @@ def __repr__(self) -> str:
708791
def to_model(self) -> model.Term:
709792
return model.Apply("prelude.qubit", [])
710793

794+
def used_extensions(self) -> ExtensionRegistry:
795+
from hugr.std.prelude import PRELUDE_EXTENSION
796+
797+
reg = super().used_extensions()
798+
reg.add_extension(PRELUDE_EXTENSION)
799+
return reg
800+
711801

712802
#: Qubit type.
713803
Qubit = _QubitDef()

0 commit comments

Comments
 (0)