Skip to content

Commit 123bc40

Browse files
committed
Make 'check_overload_call' try union math first, not second
This commit rearranges the logic so we try performing union math first, not second. See the discussion in python#4063 for details/justification about this change.
1 parent 718b9de commit 123bc40

File tree

2 files changed

+153
-22
lines changed

2 files changed

+153
-22
lines changed

mypy/checkexpr.py

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,35 +1121,65 @@ def check_overload_call(self,
11211121
plausible_targets = self.plausible_overload_call_targets(arg_types, arg_kinds,
11221122
arg_names, callee)
11231123

1124-
# Step 2: Attempt to find a matching overload
1124+
# Step 2: If the arguments contain a union, we try performing union math first.
1125+
erased_targets = None # type: Optional[List[CallableType]]
1126+
unioned_result = None # type: Optional[Tuple[Type, Type]]
1127+
unioned_errors = None # type: Optional[MessageBuilder]
1128+
if any(isinstance(arg, UnionType) for arg in arg_types):
1129+
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types,
1130+
arg_kinds, arg_names, context)
1131+
unioned_callable = self.union_overload_matches(erased_targets)
1132+
1133+
if unioned_callable is not None:
1134+
unioned_errors = arg_messages.clean_copy()
1135+
unioned_result = self.check_call(unioned_callable, args, arg_kinds,
1136+
context, arg_names,
1137+
arg_messages=unioned_errors,
1138+
callable_name=callable_name,
1139+
object_type=object_type)
1140+
if not unioned_errors.is_errors():
1141+
# Success! Stop early.
1142+
return unioned_result
1143+
1144+
# Step 3: If the union math fails, or if there was no union in the argument types,
1145+
# we fall back to checking each branch one-by-one.
11251146
inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types,
11261147
arg_kinds, arg_names, callable_name,
11271148
object_type, context, arg_messages)
11281149
if inferred_result is not None:
11291150
# Success! Stop early.
11301151
return inferred_result
11311152

1132-
# Step 3: At this point, we know none of the overload alternatives exactly match.
1133-
# We fall back to using the erased types to help with union math/help us
1134-
# produce a better error message.
1135-
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types,
1136-
arg_kinds, arg_names, context)
1137-
1138-
# Step 4: Try and infer a second-best alternative.
1139-
if len(erased_targets) == 0:
1140-
# Step 4a: There are no viable targets, even if we relax our constraints. Give up.
1153+
# Step 4: Failure. At this point, we know there is no match. We fall back to trying
1154+
# to find a somewhat plausible overload target using the erased types
1155+
# so we can produce a nice error message.
1156+
#
1157+
# For example, suppose the user passes a value of type 'List[str]' into an
1158+
# overload with signatures f(x: int) -> int and f(x: List[int]) -> List[int].
1159+
#
1160+
# Neither alternative matches, but we can guess the user probably wants the
1161+
# second one.
1162+
if erased_targets is None:
1163+
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types,
1164+
arg_kinds, arg_names, context)
1165+
1166+
# Step 5: We try and infer a second-best alternative if possible. If not, fall back
1167+
# to using 'Any'.
1168+
if unioned_result is not None:
1169+
# When possible, return the error messages generated from the union-math attempt:
1170+
# they tend to be a little nicer.
1171+
assert unioned_errors is not None
1172+
arg_messages.add_errors(unioned_errors)
1173+
return unioned_result
1174+
elif len(erased_targets) > 0:
1175+
# Pick the first plausible erased target as the fallback
1176+
# TODO: Adjust the error message here to make it clear there was no match.
1177+
target = erased_targets[0] # type: Type
1178+
else:
1179+
# There was no plausible match: give up
11411180
if not self.chk.should_suppress_optional_error(arg_types):
11421181
arg_messages.no_variant_matches_arguments(callee, arg_types, context)
1143-
target = AnyType(TypeOfAny.from_error) # type: Type
1144-
elif any(isinstance(arg, UnionType) for arg in arg_types):
1145-
# Step 4b: Try performing union math
1146-
unioned_callable = self.union_overload_matches(erased_targets)
1147-
target = unioned_callable if unioned_callable is not None else erased_targets[0]
1148-
else:
1149-
# Step 4c: Use the first matching erased target: it won't match, but at
1150-
# least we can have a nicer error message.
1151-
# TODO: Adjust the error message here to make it clear there was no match.
1152-
target = erased_targets[0]
1182+
target = AnyType(TypeOfAny.from_error)
11531183

11541184
return self.check_call(target, args, arg_kinds, context, arg_names,
11551185
arg_messages=arg_messages,
@@ -1230,7 +1260,7 @@ def infer_overload_return_type(self,
12301260
matches.append(typ)
12311261
return_types.append(ret_type)
12321262
inferred_types.append(infer_type)
1233-
1263+
12341264
if len(matches) == 0:
12351265
# No match was found
12361266
return None

test-data/unit/check-overloading.test

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2057,8 +2057,109 @@ f3: Optional[Callable[[int], str]]
20572057

20582058
reveal_type(mymap(f1, seq)) # E: Revealed type is 'typing.Iterable[builtins.str*]'
20592059
reveal_type(mymap(f2, seq)) # E: Revealed type is 'typing.Iterable[builtins.int*]'
2060-
reveal_type(mymap(f3, seq)) # E: Revealed type is 'Union[typing.Iterable[builtins.int], typing.Iterable[builtins.str]]'
2060+
reveal_type(mymap(f3, seq)) # E: Revealed type is 'Union[typing.Iterable[builtins.int], typing.Iterable[builtins.str*]]'
20612061

20622062
[builtins fixtures/list.pyi]
20632063
[typing fixtures/typing-full.pyi]
20642064

2065+
[case testOverloadsAndNoReturnNarrowTypeNoStrictOptional]
2066+
# flags: --no-strict-optional
2067+
from typing import overload, Union, TypeVar, NoReturn, Optional
2068+
2069+
@overload
2070+
def narrow_int(x: str) -> NoReturn: ...
2071+
@overload
2072+
def narrow_int(x: int) -> int: ...
2073+
def narrow_int(x):
2074+
assert isinstance(x, int)
2075+
return x
2076+
2077+
T = TypeVar('T')
2078+
@overload
2079+
def narrow_none(x: None) -> NoReturn: ...
2080+
@overload
2081+
def narrow_none(x: T) -> T: ...
2082+
def narrow_none(x):
2083+
assert x is not None
2084+
return x
2085+
2086+
def test_narrow_int() -> None:
2087+
a: Union[int, str]
2088+
a = narrow_int(a)
2089+
reveal_type(a) # E: Revealed type is 'builtins.int'
2090+
2091+
b: int
2092+
b = narrow_int(b)
2093+
reveal_type(b) # E: Revealed type is 'builtins.int'
2094+
2095+
c: str
2096+
c = narrow_int(c)
2097+
reveal_type(c) # Note: branch is now dead, so no type is revealed
2098+
# TODO: maybe we should make mypy report a warning instead?
2099+
2100+
def test_narrow_none() -> None:
2101+
a: Optional[int]
2102+
a = narrow_none(a)
2103+
reveal_type(a) # E: Revealed type is 'Union[builtins.int, None]'
2104+
2105+
b: int
2106+
b = narrow_none(b)
2107+
reveal_type(b) # E: Revealed type is 'builtins.int'
2108+
2109+
c: None
2110+
c = narrow_none(c)
2111+
reveal_type(c) # Note: branch is now dead, so no type is revealed
2112+
2113+
[builtins fixtures/isinstance.pyi]
2114+
[typing fixtures/typing-full.pyi]
2115+
2116+
[case testOverloadsAndNoReturnNarrowTypeWithStrictOptional]
2117+
# flags: --strict-optional
2118+
from typing import overload, Union, TypeVar, NoReturn, Optional
2119+
2120+
@overload
2121+
def narrow_int(x: str) -> NoReturn: ...
2122+
@overload
2123+
def narrow_int(x: int) -> int: ...
2124+
def narrow_int(x):
2125+
assert isinstance(x, int)
2126+
return x
2127+
2128+
T = TypeVar('T')
2129+
@overload
2130+
def narrow_none(x: None) -> NoReturn: ...
2131+
@overload
2132+
def narrow_none(x: T) -> T: ...
2133+
def narrow_none(x):
2134+
assert x is not None
2135+
return x
2136+
2137+
def test_narrow_int() -> None:
2138+
a: Union[int, str]
2139+
a = narrow_int(a)
2140+
reveal_type(a) # E: Revealed type is 'builtins.int'
2141+
2142+
b: int
2143+
b = narrow_int(b)
2144+
reveal_type(b) # E: Revealed type is 'builtins.int'
2145+
2146+
c: str
2147+
c = narrow_int(c)
2148+
reveal_type(c) # Note: branch is now dead, so no type is revealed
2149+
# TODO: maybe we should make mypy report a warning instead?
2150+
2151+
def test_narrow_none() -> None:
2152+
a: Optional[int]
2153+
a = narrow_none(a)
2154+
reveal_type(a) # E: Revealed type is 'builtins.int'
2155+
2156+
b: int
2157+
b = narrow_none(b)
2158+
reveal_type(b) # E: Revealed type is 'builtins.int'
2159+
2160+
c: None
2161+
c = narrow_none(c)
2162+
reveal_type(c) # Branch is now dead
2163+
2164+
[builtins fixtures/isinstance.pyi]
2165+
[typing fixtures/typing-full.pyi]

0 commit comments

Comments
 (0)