Skip to content

Commit 2ad51aa

Browse files
committed
feat(overloads): Expand finite sum types into a union of possible types
1 parent a794ae3 commit 2ad51aa

File tree

3 files changed

+56
-0
lines changed

3 files changed

+56
-0
lines changed

mypy/checkexpr.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2709,6 +2709,16 @@ def check_overload_call(
27092709
# Normalize unpacked kwargs before checking the call.
27102710
callee = callee.with_unpacked_kwargs()
27112711
arg_types = self.infer_arg_types_in_empty_context(args)
2712+
2713+
# Expand finite sum types into unions
2714+
# See https://github.com/python/mypy/issues/14764#issuecomment-3054510950
2715+
# And https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion
2716+
arg_types = [
2717+
try_expanding_sum_type_to_union(arg_type, arg_type.type.fullname)
2718+
if isinstance(arg_type, Instance) else arg_type
2719+
for arg_type in arg_types
2720+
]
2721+
27122722
# Step 1: Filter call targets to remove ones where the argument counts don't match
27132723
plausible_targets = self.plausible_overload_call_targets(
27142724
arg_types, arg_kinds, arg_names, callee

mypy/typeops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,7 +1050,12 @@ class Status(Enum):
10501050
]
10511051
return make_simplified_union(items, contract_literals=False)
10521052

1053+
10531054
if isinstance(typ, Instance) and typ.type.fullname == target_fullname:
1055+
if isinstance(typ.last_known_value, LiteralType):
1056+
# fallback for Literal[True] and Literal[False]
1057+
return typ
1058+
10541059
if typ.type.fullname == "builtins.bool":
10551060
items = [LiteralType(True, typ), LiteralType(False, typ)]
10561061
return make_simplified_union(items, contract_literals=False)

test-data/unit/check-overloading.test

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5266,6 +5266,47 @@ tmp/lib.pyi:3: error: Name "func" already defined on line 1
52665266
tmp/lib.pyi:3: error: Name "overload" is not defined
52675267
main:3: note: Revealed type is "Any"
52685268

5269+
[case testOverloadCheckExpandsBools]
5270+
from typing import Literal, overload, Union
5271+
5272+
@overload
5273+
def foo(x: Literal[False]) -> None: ...
5274+
@overload
5275+
def foo(x: Literal[True]) -> int: ...
5276+
5277+
def foo(x: bool) -> Union[None, int]: ...
5278+
5279+
reveal_type(foo(True)) # N: Revealed type is "builtins.int"
5280+
reveal_type(foo(False)) # N: Revealed type is "None"
5281+
x: bool
5282+
reveal_type(foo(x)) # N: Revealed type is "Union[builtins.int, None]"
5283+
5284+
[case testOverloadCheckExpandsEnums]
5285+
from typing import Literal, overload, Union
5286+
import enum
5287+
5288+
class Color(enum.Enum):
5289+
RED = 1
5290+
BLUE = 2
5291+
YELLOW = 3
5292+
5293+
@overload
5294+
def foo(x: Literal[Color.RED]) -> None: ...
5295+
@overload
5296+
def foo(x: Literal[Color.BLUE]) -> int: ...
5297+
@overload
5298+
def foo(x: Literal[Color.YELLOW]) -> str: ...
5299+
5300+
def foo(x: Color) -> Union[None, int, str]: ...
5301+
5302+
reveal_type(foo(Color.RED)) # N: Revealed type is "None"
5303+
reveal_type(foo(Color.BLUE)) # N: Revealed type is "builtins.int"
5304+
reveal_type(foo(Color.YELLOW)) # N: Revealed type is "builtins.str"
5305+
5306+
x: Color
5307+
reveal_type(foo(x)) # N: Revealed type is "Union[None, builtins.int, builtins.str]"
5308+
[builtins fixtures/tuple.pyi]
5309+
52695310
[case testLiteralSubtypeOverlap]
52705311
from typing import Literal, overload
52715312

0 commit comments

Comments
 (0)