Skip to content

Commit ac4c1c7

Browse files
committed
Merge branch 'master' of github.com:wesselb/plum
2 parents 1e5de72 + 300487f commit ac4c1c7

File tree

2 files changed

+146
-3
lines changed

2 files changed

+146
-3
lines changed

plum/parametric.py

+114-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union
1+
from typing import Type, TypeVar, Union
22

33
import beartype.door
44
from beartype.roar import BeartypeDoorNonpepException
@@ -13,11 +13,14 @@
1313
"CovariantMeta",
1414
"parametric",
1515
"type_parameter",
16+
"type_unparametrized",
1617
"kind",
1718
"Kind",
1819
"Val",
1920
]
2021

22+
T = TypeVar("T")
23+
2124

2225
_dispatch = Dispatcher()
2326

@@ -274,11 +277,82 @@ def class_new(cls, *args, **kw_args):
274277
cls.__new__ = class_new
275278
super(original_class, cls).__init_subclass__(**kw_args)
276279

280+
def __class_nonparametric__(cls):
281+
"""Return the non-parametric type of an object.
282+
283+
:mod:`plum.parametric` produces parametric subtypes of classes. This
284+
method can be used to get the non-parametric type of an object.
285+
286+
Examples
287+
--------
288+
>>> from plum import parametric
289+
>>> @parametric
290+
... class Obj:
291+
... @classmethod
292+
... def __infer_type_parameter__(cls, *arg):
293+
... return type(arg[0])
294+
... def __init__(self, x):
295+
... self.x = x
296+
... def __repr__(self):
297+
... return f"Obj({self.x})"
298+
299+
>>> obj = Obj(1)
300+
>>> obj
301+
Obj(1)
302+
303+
>>> type(obj).mro()
304+
[Obj[int], Obj, object]
305+
306+
>>> obj.__class_nonparametric__().mro()
307+
[Obj, object]
308+
"""
309+
return original_class
310+
311+
def __class_unparametrized__(cls):
312+
"""Return the unparametrized type of an object.
313+
314+
:mod:`plum.parametric` produces parametric subtypes of classes. This
315+
method can be used to get the un-parametrized type of an object.
316+
317+
Examples
318+
--------
319+
>>> from plum import parametric
320+
>>> @parametric
321+
... class Obj:
322+
... @classmethod
323+
... def __infer_type_parameter__(cls, *arg):
324+
... return type(arg[0])
325+
... def __init__(self, x):
326+
... self.x = x
327+
... def __repr__(self):
328+
... return f"Obj({self.x})"
329+
330+
>>> obj = Obj(1)
331+
>>> obj
332+
Obj(1)
333+
334+
>>> type(obj).__name__
335+
Obj[int]
336+
337+
>>> obj.__class_unparametrized__().mro()
338+
[Obj, Obj, object]
339+
340+
Note that this is still NOT the 'original' non-`parametric`-wrapped
341+
type. This is the type that is wrapped by :mod:`plum.parametric`, but
342+
without the inferred type parameter(s).
343+
"""
344+
return parametric_class
345+
277346
# Create parametric class.
278347
parametric_class = meta(
279348
original_class.__name__,
280349
(original_class,),
281-
{"__new__": __new__, "__init_subclass__": __init_subclass__},
350+
{
351+
"__new__": __new__,
352+
"__init_subclass__": __init_subclass__,
353+
"__class_nonparametric__": __class_nonparametric__,
354+
"__class_unparametrized__": __class_unparametrized__,
355+
},
282356
)
283357
parametric_class._parametric = True
284358
parametric_class._concrete = False
@@ -356,6 +430,44 @@ def type_parameter(x):
356430
)
357431

358432

433+
def type_unparametrized(q: T) -> Type[T]:
434+
"""Return the unparametrized type of an object.
435+
436+
:mod:`plum.parametric` produces parametric subtypes of classes. This
437+
function can be used to get the un-parametrized type of an object.
438+
This function also works for normal, :mod:`plum.parametric`-wrapped classes.
439+
440+
Examples
441+
--------
442+
>>> from plum import parametric
443+
>>> @parametric
444+
... class Obj:
445+
... @classmethod
446+
... def __infer_type_parameter__(cls, *arg):
447+
... return type(arg[0])
448+
... def __init__(self, x):
449+
... self.x = x
450+
... def __repr__(self):
451+
... return f"Obj({self.x})"
452+
453+
>>> obj = Obj(1)
454+
>>> obj
455+
Obj(1)
456+
457+
>>> type(obj).__name__
458+
Obj[int]
459+
460+
>>> type_unparametrized(obj).__name__
461+
Obj
462+
463+
Note that this is still NOT the 'original' non-`parametric`-wrapped type.
464+
This is the type that is wrapped by :mod:`plum.parametric`, but without the
465+
inferred type parameter(s).
466+
"""
467+
typ = type(q)
468+
return q.__class_unparametrized__() if isinstance(typ, ParametricTypeMeta) else typ
469+
470+
359471
def kind(SuperClass=object):
360472
"""Create a parametric wrapper type for dispatch purposes.
361473

tests/test_parametric.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
parametric,
1616
type_parameter,
1717
)
18-
from plum.parametric import CovariantMeta, is_concrete, is_type
18+
from plum.parametric import CovariantMeta, is_concrete, is_type, type_unparametrized
1919

2020

2121
def test_covariantmeta():
@@ -60,6 +60,15 @@ class A(Base1, metaclass=metaclass):
6060
assert issubclass(type(a1), Base1)
6161
assert not issubclass(type(a1), Base2)
6262

63+
assert a1.__class_unparametrized__() is A
64+
assert a2.__class_unparametrized__() is A
65+
66+
# Here we are testing that the class returned by `__class_nonparametric__`
67+
# is the 'original' class that the @parametric decorator was applied to.
68+
assert a1.__class_nonparametric__() is A.mro()[1]
69+
assert issubclass(a2.__class_nonparametric__(), Base1)
70+
assert a2.__class_nonparametric__() is not Base1
71+
6372
# Test multiple type parameters.
6473
assert A[1, 2] == A[1, 2]
6574

@@ -575,3 +584,25 @@ class Wrapper(Pytree):
575584

576585
Wrapper[int]
577586
assert Wrapper[int] in register
587+
588+
589+
def test_type_unparametrized():
590+
"""Test the `type_unparametrized` function."""
591+
592+
@parametric
593+
class Obj:
594+
@classmethod
595+
def __infer_type_parameter__(cls, *arg):
596+
return type(arg[0])
597+
598+
def __init__(self, x):
599+
self.x = x
600+
601+
def __repr__(self):
602+
return f"Obj({self.x})"
603+
604+
pobj = Obj(1)
605+
606+
assert type(pobj) is Obj[int]
607+
assert type_unparametrized(pobj) is not Obj[int]
608+
assert type_unparametrized(pobj) is Obj

0 commit comments

Comments
 (0)