Skip to content

Commit e6fd81e

Browse files
committed
Refine how overload selection handles *args, **kwargs, and Any
This pull request implements the changes discussed in python#5124. Specifically... 1. When two overload alternatives match due to Any, we return the last matching return type if it's a supertype of all of the previous ones. If it's not a supertype, we give up and return 'Any' as before. 2. If a user calls an overload with a starred expression, we try matching alternatives with a starred arg or kwarg first, even if those alternatives do not appear first in the list. If none of the starred alternatives are a valid match, we fall back to checking the other remaining alternatives in order.
1 parent aaafd15 commit e6fd81e

File tree

2 files changed

+154
-21
lines changed

2 files changed

+154
-21
lines changed

mypy/checkexpr.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1195,18 +1195,27 @@ def plausible_overload_call_targets(self,
11951195
arg_kinds: List[int],
11961196
arg_names: Optional[Sequence[Optional[str]]],
11971197
overload: Overloaded) -> List[CallableType]:
1198-
"""Returns all overload call targets that having matching argument counts."""
1198+
"""Returns all overload call targets that having matching argument counts.
1199+
1200+
If the given args contains a star-arg (*arg or **kwarg argument), this method
1201+
will ensure all star-arg overloads appear at the start of the list, instead
1202+
of their usual location."""
11991203
matches = [] # type: List[CallableType]
1204+
star_matches = [] # type: List[CallableType]
1205+
args_have_star = ARG_STAR in arg_kinds or ARG_STAR2 in arg_kinds
12001206
for typ in overload.items():
12011207
formal_to_actual = map_actuals_to_formals(arg_kinds, arg_names,
12021208
typ.arg_kinds, typ.arg_names,
12031209
lambda i: arg_types[i])
12041210

12051211
if self.check_argument_count(typ, arg_types, arg_kinds, arg_names,
12061212
formal_to_actual, None, None):
1207-
matches.append(typ)
1213+
if args_have_star and (typ.is_var_arg or typ.is_kw_arg):
1214+
star_matches.append(typ)
1215+
else:
1216+
matches.append(typ)
12081217

1209-
return matches
1218+
return star_matches + matches
12101219

12111220
def infer_overload_return_type(self,
12121221
plausible_targets: List[CallableType],
@@ -1270,15 +1279,20 @@ def infer_overload_return_type(self,
12701279
return None
12711280
elif any_causes_overload_ambiguity(matches, return_types, arg_types, arg_kinds, arg_names):
12721281
# An argument of type or containing the type 'Any' caused ambiguity.
1273-
# We infer a type of 'Any'
1274-
return self.check_call(callee=AnyType(TypeOfAny.special_form),
1275-
args=args,
1276-
arg_kinds=arg_kinds,
1277-
arg_names=arg_names,
1278-
context=context,
1279-
arg_messages=arg_messages,
1280-
callable_name=callable_name,
1281-
object_type=object_type)
1282+
if all(is_subtype(ret_type, return_types[-1]) for ret_type in return_types[:-1]):
1283+
# The last match is a supertype of all the previous ones, so it's safe
1284+
# to return that inferred type.
1285+
return return_types[-1], inferred_types[-1]
1286+
else:
1287+
# We give up and return 'Any'.
1288+
return self.check_call(callee=AnyType(TypeOfAny.special_form),
1289+
args=args,
1290+
arg_kinds=arg_kinds,
1291+
arg_names=arg_names,
1292+
context=context,
1293+
arg_messages=arg_messages,
1294+
callable_name=callable_name,
1295+
object_type=object_type)
12821296
else:
12831297
# Success! No ambiguity; return the first match.
12841298
return return_types[0], inferred_types[0]

test-data/unit/check-overloading.test

Lines changed: 128 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,7 +1276,7 @@ def f(x: object) -> object: ...
12761276
def f(x): pass
12771277

12781278
a: Any
1279-
reveal_type(f(a)) # E: Revealed type is 'Any'
1279+
reveal_type(f(a)) # E: Revealed type is 'builtins.object'
12801280

12811281
[case testOverloadWithOverlappingItemsAndAnyArgument2]
12821282
from typing import overload, Any
@@ -1288,7 +1288,7 @@ def f(x: float) -> float: ...
12881288
def f(x): pass
12891289

12901290
a: Any
1291-
reveal_type(f(a)) # E: Revealed type is 'Any'
1291+
reveal_type(f(a)) # E: Revealed type is 'builtins.float'
12921292

12931293
[case testOverloadWithOverlappingItemsAndAnyArgument3]
12941294
from typing import overload, Any
@@ -1313,15 +1313,15 @@ def f(x): pass
13131313

13141314
a: Any
13151315
# Any causes ambiguity
1316-
reveal_type(f(a, 1, '')) # E: Revealed type is 'Any'
1316+
reveal_type(f(a, 1, '')) # E: Revealed type is 'builtins.object'
13171317
# Any causes no ambiguity
13181318
reveal_type(f(1, a, a)) # E: Revealed type is 'builtins.int'
13191319
reveal_type(f('', a, a)) # E: Revealed type is 'builtins.object'
13201320
# Like above, but use keyword arguments.
1321-
reveal_type(f(y=1, z='', x=a)) # E: Revealed type is 'Any'
1321+
reveal_type(f(y=1, z='', x=a)) # E: Revealed type is 'builtins.object'
13221322
reveal_type(f(y=a, z='', x=1)) # E: Revealed type is 'builtins.int'
13231323
reveal_type(f(z='', x=1, y=a)) # E: Revealed type is 'builtins.int'
1324-
reveal_type(f(z='', x=a, y=1)) # E: Revealed type is 'Any'
1324+
reveal_type(f(z='', x=a, y=1)) # E: Revealed type is 'builtins.object'
13251325

13261326
[case testOverloadWithOverlappingItemsAndAnyArgument5]
13271327
from typing import overload, Any, Union
@@ -1333,7 +1333,7 @@ def f(x: Union[int, float]) -> float: ...
13331333
def f(x): pass
13341334

13351335
a: Any
1336-
reveal_type(f(a)) # E: Revealed type is 'Any'
1336+
reveal_type(f(a)) # E: Revealed type is 'builtins.float'
13371337

13381338
[case testOverloadWithOverlappingItemsAndAnyArgument6]
13391339
from typing import overload, Any
@@ -1343,7 +1343,7 @@ def f(x: int, y: int) -> int: ...
13431343
@overload
13441344
def f(x: float, y: int, z: str) -> float: ...
13451345
@overload
1346-
def f(x: object, y: int, z: str, a: None) -> object: ...
1346+
def f(x: object, y: int, z: str, a: None) -> str: ...
13471347
def f(x): pass
13481348

13491349
a: Any
@@ -1352,7 +1352,7 @@ reveal_type(f(*a)) # E: Revealed type is 'Any'
13521352
reveal_type(f(a, *a)) # E: Revealed type is 'Any'
13531353
reveal_type(f(1, *a)) # E: Revealed type is 'Any'
13541354
reveal_type(f(1.1, *a)) # E: Revealed type is 'Any'
1355-
reveal_type(f('', *a)) # E: Revealed type is 'builtins.object'
1355+
reveal_type(f('', *a)) # E: Revealed type is 'builtins.str'
13561356

13571357
[case testOverloadWithOverlappingItemsAndAnyArgument7]
13581358
from typing import overload, Any
@@ -1365,7 +1365,7 @@ def f(x): pass
13651365

13661366
a: Any
13671367
# TODO: We could infer 'int' here
1368-
reveal_type(f(1, *a)) # E: Revealed type is 'Any'
1368+
reveal_type(f(1, *a)) # E: Revealed type is 'builtins.object'
13691369

13701370
[case testOverloadWithOverlappingItemsAndAnyArgument8]
13711371
from typing import overload, Any
@@ -1381,6 +1381,26 @@ a: Any
13811381
reveal_type(f(a, 1, 1)) # E: Revealed type is 'builtins.str'
13821382
reveal_type(f(1, *a)) # E: Revealed type is 'builtins.str'
13831383

1384+
[case testOverloadWithOverlappingItemsAndAnyArgument9]
1385+
from typing import overload, Any, List
1386+
1387+
@overload
1388+
def f(x: List[int]) -> List[int]: ...
1389+
@overload
1390+
def f(x: List[Any]) -> List[Any]: ...
1391+
def f(x): pass
1392+
1393+
a: Any
1394+
b: List[Any]
1395+
c: List[str]
1396+
d: List[int]
1397+
reveal_type(f(a)) # E: Revealed type is 'builtins.list[Any]'
1398+
reveal_type(f(b)) # E: Revealed type is 'builtins.list[Any]'
1399+
reveal_type(f(c)) # E: Revealed type is 'builtins.list[Any]'
1400+
reveal_type(f(d)) # E: Revealed type is 'builtins.list[builtins.int]'
1401+
1402+
[builtins fixtures/list.pyi]
1403+
13841404
[case testOverloadOnOverloadWithType]
13851405
from typing import Any, Type, TypeVar, overload
13861406
from mod import MyInt
@@ -1723,6 +1743,105 @@ def foo2(**kwargs: int) -> str: ...
17231743
def foo2(*args: int) -> int: ... # E: Overloaded function signature 2 will never be matched: function 1's parameter type(s) are the same or broader
17241744
[builtins fixtures/dict.pyi]
17251745

1746+
[case testOverloadVarargInputAndVarargDefinition]
1747+
from typing import overload, List
1748+
1749+
class A: ...
1750+
class B: ...
1751+
class C: ...
1752+
1753+
@overload
1754+
def foo(x: int) -> A: ...
1755+
@overload
1756+
def foo(x: int, y: int) -> B: ...
1757+
@overload
1758+
def foo(x: int, y: int, z: int, *args: int) -> C: ...
1759+
def foo(*args): pass
1760+
1761+
reveal_type(foo(1)) # E: Revealed type is '__main__.A'
1762+
reveal_type(foo(1, 2)) # E: Revealed type is '__main__.B'
1763+
reveal_type(foo(1, 2, 3)) # E: Revealed type is '__main__.C'
1764+
1765+
reveal_type(foo(*[1])) # E: Revealed type is '__main__.C'
1766+
reveal_type(foo(*[1, 2])) # E: Revealed type is '__main__.C'
1767+
reveal_type(foo(*[1, 2, 3])) # E: Revealed type is '__main__.C'
1768+
1769+
x: List[int]
1770+
reveal_type(foo(*x)) # E: Revealed type is '__main__.C'
1771+
1772+
y: List[str]
1773+
foo(*y) # E: No overload variant of "foo" matches argument type "List[str]"
1774+
[builtins fixtures/list.pyi]
1775+
1776+
[case testOverloadMultipleVarargDefinition]
1777+
from typing import overload, List, Any
1778+
1779+
class A: ...
1780+
class B: ...
1781+
class C: ...
1782+
class D: ...
1783+
1784+
@overload
1785+
def foo(x: int) -> A: ...
1786+
@overload
1787+
def foo(x: int, y: int) -> B: ...
1788+
@overload
1789+
def foo(x: int, y: int, z: int, *args: int) -> C: ...
1790+
@overload
1791+
def foo(*x: str) -> D: ...
1792+
def foo(*args): pass
1793+
1794+
reveal_type(foo(*[1, 2])) # E: Revealed type is '__main__.C'
1795+
reveal_type(foo(*["a", "b"])) # E: Revealed type is '__main__.D'
1796+
1797+
x: List[Any]
1798+
reveal_type(foo(*x)) # E: Revealed type is 'Any'
1799+
[builtins fixtures/list.pyi]
1800+
1801+
[case testOverloadMultipleVarargDefinitionComplex]
1802+
from typing import TypeVar, overload, Any, Callable
1803+
1804+
T1 = TypeVar('T1')
1805+
T2 = TypeVar('T2')
1806+
T3 = TypeVar('T3')
1807+
1808+
@overload
1809+
def chain_call(input_value: T1,
1810+
f1: Callable[[T1], T2]) -> T2: ...
1811+
@overload
1812+
def chain_call(input_value: T1,
1813+
f1: Callable[[T1], T2],
1814+
f2: Callable[[T2], T3]) -> T3: ...
1815+
@overload
1816+
def chain_call(input_value: T1,
1817+
*f_rest: Callable[[T1], T1]) -> T1: ...
1818+
@overload
1819+
def chain_call(input_value: T1,
1820+
f1: Callable[[T1], T2],
1821+
f2: Callable[[T2], T3],
1822+
f3: Callable[[T3], Any],
1823+
*f_rest: Callable[[Any], Any]) -> Any: ...
1824+
def chain_call(input_value, *f_rest):
1825+
for function in f_rest:
1826+
input_value = function(input_value)
1827+
return input_value
1828+
1829+
1830+
class A: ...
1831+
class B: ...
1832+
class C: ...
1833+
class D: ...
1834+
1835+
def f(x: A) -> A: ...
1836+
def f1(x: A) -> B: ...
1837+
def f2(x: B) -> C: ...
1838+
def f3(x: C) -> D: ...
1839+
1840+
reveal_type(chain_call(A(), f1, f2)) # E: Revealed type is '__main__.C*'
1841+
reveal_type(chain_call(A(), f1, f2, f3)) # E: Revealed type is 'Any'
1842+
reveal_type(chain_call(A(), f, f, f, f)) # E: Revealed type is '__main__.A'
1843+
[builtins fixtures/list.pyi]
1844+
17261845
[case testOverloadWithPartiallyOverlappingUnions]
17271846
from typing import overload, Union
17281847

0 commit comments

Comments
 (0)