Skip to content

Commit f63829c

Browse files
authored
Fix subclassing builtin protocols on older Python versions (#650)
This was recently fixed in CPython repo (while adding `Protocol` there). This PR backports the fix to all the older versions including now Python 2 version in `typing`. I also removed the unnecessary conditional import in Python 2 tests since I stumbled at it while adding the test.
1 parent 537a104 commit f63829c

File tree

4 files changed

+70
-27
lines changed

4 files changed

+70
-27
lines changed

python2/test_typing.py

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,7 @@
2626
from typing import Pattern, Match
2727
import typing
2828
import weakref
29-
try:
30-
import collections.abc as collections_abc
31-
except ImportError:
32-
import collections as collections_abc # Fallback for PY3.2.
29+
import collections
3330

3431

3532
class BaseTestCase(TestCase):
@@ -1055,7 +1052,7 @@ def test_supports_index(self):
10551052
self.assertIsSubclass(int, typing.SupportsIndex)
10561053
self.assertNotIsSubclass(str, typing.SupportsIndex)
10571054

1058-
def test_protocol_instance_type_error(self):
1055+
def test_protocol_instance_works(self):
10591056
self.assertIsInstance(0, typing.SupportsAbs)
10601057
self.assertNotIsInstance('no', typing.SupportsAbs)
10611058
class C1(typing.SupportsInt):
@@ -1066,6 +1063,21 @@ class C2(C1):
10661063
c = C2()
10671064
self.assertIsInstance(c, C1)
10681065

1066+
def test_collections_protocols_allowed(self):
1067+
@runtime_checkable
1068+
class Custom(collections.Iterable, Protocol):
1069+
def close(self): pass
1070+
1071+
class A(object): pass
1072+
class B(object):
1073+
def __iter__(self):
1074+
return []
1075+
def close(self):
1076+
return 0
1077+
1078+
self.assertIsSubclass(B, Custom)
1079+
self.assertNotIsSubclass(A, Custom)
1080+
10691081

10701082
class GenericTests(BaseTestCase):
10711083

@@ -1243,17 +1255,17 @@ def __len__(self):
12431255
return 0
12441256
# this should just work
12451257
MM().update()
1246-
self.assertIsInstance(MM(), collections_abc.MutableMapping)
1258+
self.assertIsInstance(MM(), collections.MutableMapping)
12471259
self.assertIsInstance(MM(), MutableMapping)
12481260
self.assertNotIsInstance(MM(), List)
12491261
self.assertNotIsInstance({}, MM)
12501262

12511263
def test_multiple_bases(self):
1252-
class MM1(MutableMapping[str, str], collections_abc.MutableMapping):
1264+
class MM1(MutableMapping[str, str], collections.MutableMapping):
12531265
pass
12541266
with self.assertRaises(TypeError):
12551267
# consistent MRO not possible
1256-
class MM2(collections_abc.MutableMapping, MutableMapping[str, str]):
1268+
class MM2(collections.MutableMapping, MutableMapping[str, str]):
12571269
pass
12581270

12591271
def test_orig_bases(self):
@@ -1426,9 +1438,9 @@ def __call__(self):
14261438

14271439
self.assertEqual(repr(C1[int]).split('.')[-1], 'C1[int]')
14281440
self.assertEqual(C2.__parameters__, ())
1429-
self.assertIsInstance(C2(), collections_abc.Callable)
1430-
self.assertIsSubclass(C2, collections_abc.Callable)
1431-
self.assertIsSubclass(C1, collections_abc.Callable)
1441+
self.assertIsInstance(C2(), collections.Callable)
1442+
self.assertIsSubclass(C2, collections.Callable)
1443+
self.assertIsSubclass(C1, collections.Callable)
14321444
self.assertIsInstance(T1(), tuple)
14331445
self.assertIsSubclass(T2, tuple)
14341446
self.assertIsSubclass(Tuple[int, ...], typing.Sequence)
@@ -2287,11 +2299,11 @@ def __len__(self):
22872299
self.assertIsSubclass(MMC, typing.Mapping)
22882300

22892301
self.assertIsInstance(MMB[KT, VT](), typing.Mapping)
2290-
self.assertIsInstance(MMB[KT, VT](), collections_abc.Mapping)
2302+
self.assertIsInstance(MMB[KT, VT](), collections.Mapping)
22912303

2292-
self.assertIsSubclass(MMA, collections_abc.Mapping)
2293-
self.assertIsSubclass(MMB, collections_abc.Mapping)
2294-
self.assertIsSubclass(MMC, collections_abc.Mapping)
2304+
self.assertIsSubclass(MMA, collections.Mapping)
2305+
self.assertIsSubclass(MMB, collections.Mapping)
2306+
self.assertIsSubclass(MMC, collections.Mapping)
22952307

22962308
self.assertIsSubclass(MMB[str, str], typing.Mapping)
22972309
self.assertIsSubclass(MMC, MMA)
@@ -2303,9 +2315,9 @@ class G(typing.Generator[int, int, int]): pass
23032315
def g(): yield 0
23042316
self.assertIsSubclass(G, typing.Generator)
23052317
self.assertIsSubclass(G, typing.Iterable)
2306-
if hasattr(collections_abc, 'Generator'):
2307-
self.assertIsSubclass(G, collections_abc.Generator)
2308-
self.assertIsSubclass(G, collections_abc.Iterable)
2318+
if hasattr(collections, 'Generator'):
2319+
self.assertIsSubclass(G, collections.Generator)
2320+
self.assertIsSubclass(G, collections.Iterable)
23092321
self.assertNotIsSubclass(type(g), G)
23102322

23112323
def test_subclassing_subclasshook(self):
@@ -2341,23 +2353,23 @@ class D: pass
23412353
self.assertIsSubclass(D, B)
23422354

23432355
class M(): pass
2344-
collections_abc.MutableMapping.register(M)
2356+
collections.MutableMapping.register(M)
23452357
self.assertIsSubclass(M, typing.Mapping)
23462358

23472359
def test_collections_as_base(self):
23482360

2349-
class M(collections_abc.Mapping): pass
2361+
class M(collections.Mapping): pass
23502362
self.assertIsSubclass(M, typing.Mapping)
23512363
self.assertIsSubclass(M, typing.Iterable)
23522364

2353-
class S(collections_abc.MutableSequence): pass
2365+
class S(collections.MutableSequence): pass
23542366
self.assertIsSubclass(S, typing.MutableSequence)
23552367
self.assertIsSubclass(S, typing.Iterable)
23562368

2357-
class I(collections_abc.Iterable): pass
2369+
class I(collections.Iterable): pass
23582370
self.assertIsSubclass(I, typing.Iterable)
23592371

2360-
class A(collections_abc.Mapping): pass
2372+
class A(collections.Mapping): pass
23612373
class B: pass
23622374
A.register(B)
23632375
self.assertIsSubclass(B, typing.Mapping)

python2/typing.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1759,6 +1759,11 @@ def utf8(value):
17591759
return _overload_dummy
17601760

17611761

1762+
_PROTO_WHITELIST = ['Callable', 'Iterable', 'Iterator',
1763+
'Hashable', 'Sized', 'Container', 'Collection',
1764+
'Reversible', 'ContextManager']
1765+
1766+
17621767
class _ProtocolMeta(GenericMeta):
17631768
"""Internal metaclass for Protocol.
17641769
@@ -1774,7 +1779,8 @@ def __init__(cls, *args, **kwargs):
17741779
for b in cls.__bases__)
17751780
if cls._is_protocol:
17761781
for base in cls.__mro__[1:]:
1777-
if not (base in (object, Generic, Callable) or
1782+
if not (base in (object, Generic) or
1783+
base.__module__ == '_abcoll' and base.__name__ in _PROTO_WHITELIST or
17781784
isinstance(base, TypingMeta) and base._is_protocol or
17791785
isinstance(base, GenericMeta) and base.__origin__ is Generic):
17801786
raise TypeError('Protocols can only inherit from other protocols,'

typing_extensions/src_py3/test_typing_extensions.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from typing import no_type_check
1515
from typing_extensions import NoReturn, ClassVar, Final, IntVar, Literal, Type, NewType, TypedDict
1616
try:
17-
from typing_extensions import Protocol, runtime
17+
from typing_extensions import Protocol, runtime, runtime_checkable
1818
except ImportError:
1919
pass
2020
try:
@@ -1391,6 +1391,21 @@ class E:
13911391
x = 1
13921392
self.assertIsInstance(E(), D)
13931393

1394+
def test_collections_protocols_allowed(self):
1395+
@runtime_checkable
1396+
class Custom(collections.abc.Iterable, Protocol):
1397+
def close(self): pass
1398+
1399+
class A: ...
1400+
class B:
1401+
def __iter__(self):
1402+
return []
1403+
def close(self):
1404+
return 0
1405+
1406+
self.assertIsSubclass(B, Custom)
1407+
self.assertNotIsSubclass(A, Custom)
1408+
13941409

13951410
class TypedDictTests(BaseTestCase):
13961411

typing_extensions/src_py3/typing_extensions.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,12 @@ def _next_in_mro(cls):
10801080
return next_in_mro
10811081

10821082

1083+
_PROTO_WHITELIST = ['Callable', 'Awaitable',
1084+
'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator',
1085+
'Hashable', 'Sized', 'Container', 'Collection', 'Reversible',
1086+
'ContextManager', 'AsyncContextManager']
1087+
1088+
10831089
def _get_protocol_attrs(cls):
10841090
attrs = set()
10851091
for base in cls.__mro__[:-1]: # without object
@@ -1187,7 +1193,9 @@ def __init__(cls, *args, **kwargs):
11871193
for b in cls.__bases__)
11881194
if cls._is_protocol:
11891195
for base in cls.__mro__[1:]:
1190-
if not (base in (object, Generic, Callable) or
1196+
if not (base in (object, Generic) or
1197+
base.__module__ == 'collections.abc' and
1198+
base.__name__ in _PROTO_WHITELIST or
11911199
isinstance(base, TypingMeta) and base._is_protocol or
11921200
isinstance(base, GenericMeta) and
11931201
base.__origin__ is Generic):
@@ -1513,7 +1521,9 @@ def _proto_hook(other):
15131521

15141522
# Check consistency of bases.
15151523
for base in cls.__bases__:
1516-
if not (base in (object, Generic, Callable) or
1524+
if not (base in (object, Generic) or
1525+
base.__module__ == 'collections.abc' and
1526+
base.__name__ in _PROTO_WHITELIST or
15171527
isinstance(base, _ProtocolMeta) and base._is_protocol):
15181528
raise TypeError('Protocols can only inherit from other'
15191529
' protocols, got %r' % base)

0 commit comments

Comments
 (0)