13
13
from collections .abc import Iterable , Sequence
14
14
15
15
from hugr import ext
16
+ from hugr .ext import ExtensionRegistry
16
17
17
18
18
19
ExtensionId = stys .ExtensionId
@@ -55,6 +56,21 @@ def to_model(self) -> model.Term | model.Splice:
55
56
"""Convert the type argument to a model Term."""
56
57
raise NotImplementedError (self )
57
58
59
+ def used_extensions (self ) -> ExtensionRegistry :
60
+ """Get the set of extensions required to define this type argument.
61
+
62
+ Raises:
63
+ UnknownTypeExtensionError: if a type argument contains is a
64
+ :class:`Opaque` type that has not been resolved.
65
+
66
+ Example:
67
+ >>> TypeTypeArg(ty=Qubit).used_extensions().ids()
68
+ {'prelude'}
69
+ """
70
+ from hugr .ext import ExtensionRegistry
71
+
72
+ return ExtensionRegistry ()
73
+
58
74
59
75
@runtime_checkable
60
76
class Type (Protocol ):
@@ -95,10 +111,34 @@ def to_model(self) -> model.Term | model.Splice:
95
111
"""Convert the type to a model Term."""
96
112
raise NotImplementedError (self )
97
113
114
+ def used_extensions (self ) -> ExtensionRegistry :
115
+ """Get the set of extensions required to define this type.
116
+
117
+ Note that :class:`Opaque` types do not know their extension, so they
118
+ will raise an error. Use :meth:`resolve` to get the actual type
119
+ and then call this method.
120
+
121
+ Raises:
122
+ UnknownTypeExtensionError: if the type is an :class:`Opaque` type
123
+ and has not been resolved.
124
+
125
+ Example:
126
+ >>> Qubit.used_extensions().ids()
127
+ {'prelude'}
128
+ """
129
+ from hugr .ext import ExtensionRegistry
130
+
131
+ return ExtensionRegistry ()
132
+
98
133
99
134
#: Row of types.
100
135
TypeRow = list [Type ]
101
136
137
+
138
+ class UnknownTypeExtensionError (Exception ):
139
+ """Exception raised when querying the extension of an :method:`Opaque` type."""
140
+
141
+
102
142
# --------------------------------------------
103
143
# --------------- TypeParam ------------------
104
144
# --------------------------------------------
@@ -211,6 +251,9 @@ def __str__(self) -> str:
211
251
def to_model (self ) -> model .Term | model .Splice :
212
252
return self .ty .to_model ()
213
253
254
+ def used_extensions (self ) -> ExtensionRegistry :
255
+ return self .ty .used_extensions ()
256
+
214
257
215
258
@dataclass (frozen = True )
216
259
class BoundedNatArg (TypeArg ):
@@ -264,6 +307,12 @@ def to_model(self) -> model.Term:
264
307
# For now we assume that this is a list.
265
308
return model .List ([elem .to_model () for elem in self .elems ])
266
309
310
+ def used_extensions (self ) -> ExtensionRegistry :
311
+ reg = super ().used_extensions ()
312
+ for arg in self .elems :
313
+ reg .extend (arg .used_extensions ())
314
+ return reg
315
+
267
316
268
317
@dataclass (frozen = True )
269
318
class VariableArg (TypeArg ):
@@ -324,6 +373,13 @@ def to_model(self) -> model.Term:
324
373
)
325
374
return model .Apply ("core.adt" , [variants ])
326
375
376
+ def used_extensions (self ) -> ExtensionRegistry :
377
+ types = [ty for row in self .variant_rows for ty in row ]
378
+ reg = super ().used_extensions ()
379
+ for ty in types :
380
+ reg .extend (ty .used_extensions ())
381
+ return reg
382
+
327
383
328
384
@dataclass (eq = False )
329
385
class UnitSum (Sum ):
@@ -457,6 +513,13 @@ def __repr__(self) -> str:
457
513
def to_model (self ) -> model .Term :
458
514
return model .Apply ("prelude.usize" )
459
515
516
+ def used_extensions (self ) -> ExtensionRegistry :
517
+ from hugr .std .prelude import PRELUDE_EXTENSION
518
+
519
+ reg = super ().used_extensions ()
520
+ reg .add_extension (PRELUDE_EXTENSION )
521
+ return reg
522
+
460
523
461
524
@dataclass (frozen = True )
462
525
class Alias (Type ):
@@ -543,6 +606,14 @@ def to_model(self) -> model.Term:
543
606
outputs = model .List ([output .to_model () for output in self .output ])
544
607
return model .Apply ("core.fn" , [inputs , outputs ])
545
608
609
+ def used_extensions (self ) -> ExtensionRegistry :
610
+ reg = super ().used_extensions ()
611
+ for ty in self .input :
612
+ reg .extend (ty .used_extensions ())
613
+ for ty in self .output :
614
+ reg .extend (ty .used_extensions ())
615
+ return reg
616
+
546
617
547
618
@dataclass (frozen = True )
548
619
class PolyFuncType (Type ):
@@ -587,6 +658,9 @@ def to_model(self) -> model.Term:
587
658
error = "PolyFuncType used as a Type"
588
659
raise TypeError (error )
589
660
661
+ def used_extensions (self ) -> ExtensionRegistry :
662
+ return self .body .used_extensions ()
663
+
590
664
591
665
@dataclass
592
666
class ExtType (Type ):
@@ -632,7 +706,7 @@ def __eq__(self, value):
632
706
return super ().__eq__ (value )
633
707
634
708
def to_model (self ) -> model .Term :
635
- # This cast is only neccessary because `Type` can both be an
709
+ # This cast is only necessary because `Type` can both be an
636
710
# actual type or a row variable.
637
711
args = [cast (model .Term , arg .to_model ()) for arg in self .args ]
638
712
@@ -642,6 +716,11 @@ def to_model(self) -> model.Term:
642
716
643
717
return model .Apply (name , args )
644
718
719
+ def used_extensions (self ) -> ExtensionRegistry :
720
+ reg = super ().used_extensions ()
721
+ reg .add_extension (self .type_def .get_extension ())
722
+ return reg
723
+
645
724
646
725
def _type_str (name : str , args : Sequence [TypeArg ]) -> str :
647
726
if len (args ) == 0 :
@@ -693,6 +772,10 @@ def to_model(self) -> model.Term:
693
772
694
773
return model .Apply (self .id , args )
695
774
775
+ def used_extensions (self ) -> ExtensionRegistry :
776
+ msg = "Opaque types do not know their extension. Call `resolve` first."
777
+ raise UnknownTypeExtensionError (msg )
778
+
696
779
697
780
@dataclass
698
781
class _QubitDef (Type ):
@@ -708,6 +791,13 @@ def __repr__(self) -> str:
708
791
def to_model (self ) -> model .Term :
709
792
return model .Apply ("prelude.qubit" , [])
710
793
794
+ def used_extensions (self ) -> ExtensionRegistry :
795
+ from hugr .std .prelude import PRELUDE_EXTENSION
796
+
797
+ reg = super ().used_extensions ()
798
+ reg .add_extension (PRELUDE_EXTENSION )
799
+ return reg
800
+
711
801
712
802
#: Qubit type.
713
803
Qubit = _QubitDef ()
0 commit comments