Skip to content

[match-case] fix matching against typing.Callable and Protocol types. #19471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7994,11 +7994,13 @@ def conditional_types(
) -> tuple[Type | None, Type | None]:
"""Takes in the current type and a proposed type of an expression.

Returns a 2-tuple: The first element is the proposed type, if the expression
can be the proposed type. The second element is the type it would hold
if it was not the proposed type, if any. UninhabitedType means unreachable.
None means no new information can be inferred. If default is set it is returned
instead."""
Returns a 2-tuple:
The first element is the proposed type, if the expression can be the proposed type.
The second element is the type it would hold if it was not the proposed type, if any.
UninhabitedType means unreachable.
None means no new information can be inferred.
If default is set it is returned instead.
"""
if proposed_type_ranges:
if len(proposed_type_ranges) == 1:
target = proposed_type_ranges[0].item
Expand All @@ -8010,14 +8012,26 @@ def conditional_types(
current_type = try_expanding_sum_type_to_union(current_type, enum_name)
proposed_items = [type_range.item for type_range in proposed_type_ranges]
proposed_type = make_simplified_union(proposed_items)
if isinstance(proposed_type, AnyType):
current_type = get_proper_type(current_type)
if isinstance(current_type, AnyType):
return proposed_type, current_type
elif isinstance(proposed_type, AnyType):
# We don't really know much about the proposed type, so we shouldn't
# attempt to narrow anything. Instead, we broaden the expr to Any to
# avoid false positives
return proposed_type, default
elif not any(
type_range.is_upper_bound for type_range in proposed_type_ranges
) and is_proper_subtype(current_type, proposed_type, ignore_promotions=True):
) and ( # concrete subtypes
is_proper_subtype(current_type, proposed_type, ignore_promotions=True)
or ( # structural subtypes
is_subtype(current_type, proposed_type, ignore_promotions=True)
and (
isinstance(proposed_type, CallableType)
or (isinstance(proposed_type, Instance) and proposed_type.type.is_protocol)
)
)
):
# Expression is always of one of the types in proposed_type_ranges
return default, UninhabitedType()
elif not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):
Expand Down
10 changes: 10 additions & 0 deletions mypy/checkpattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
UninhabitedType,
UnionType,
UnpackType,
callable_with_ellipsis,
find_unpack_in_list,
get_proper_type,
split_with_prefix_and_suffix,
Expand Down Expand Up @@ -553,6 +554,15 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
and isinstance(get_proper_type(type_info.type), AnyType)
):
typ = type_info.type
elif (
isinstance(type_info, Var)
and type_info.type is not None
and type_info.fullname == "typing.Callable"
):
# Create a `Callable[..., Any]`
fallback = self.chk.named_type("builtins.function")
any_type = AnyType(TypeOfAny.unannotated)
typ = callable_with_ellipsis(any_type, ret_type=any_type, fallback=fallback)
else:
if isinstance(type_info, Var) and type_info.type is not None:
name = type_info.type.str_with_options(self.options)
Expand Down
32 changes: 32 additions & 0 deletions test-data/unit/check-generic-alias.test
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,38 @@ t23: collections.abc.ValuesView[str]
# reveal_type(t23) # Nx Revealed type is "collections.abc.ValuesView[builtins.str]"
[builtins fixtures/tuple.pyi]

[case testGenericAliasIsinstanceUnreachable]
# flags: --warn-unreachable --python-version 3.10
from collections.abc import Iterable

class A: ...

def test(dependencies: list[A] | None) -> None:
if dependencies is None:
dependencies = []
elif not isinstance(dependencies, Iterable):
dependencies = [dependencies] # E: Statement is unreachable

[builtins fixtures/isinstancelist.pyi]
[typing fixtures/typing-full.pyi]

[case testGenericAliasRedundantExprCompoundIfExpr]
# flags: --warn-unreachable --enable-error-code=redundant-expr --python-version 3.10

from typing import Any, reveal_type
from collections.abc import Iterable

def test_example(x: Iterable[Any]) -> None:
if isinstance(x, Iterable) and not isinstance(x, str): # E: Left operand of "and" is always true
reveal_type(x) # N: Revealed type is "typing.Iterable[Any]"

def test_counterexample(x: Any) -> None:
if isinstance(x, Iterable) and not isinstance(x, str):
reveal_type(x) # N: Revealed type is "typing.Iterable[Any]"

[builtins fixtures/isinstancelist.pyi]
[typing fixtures/typing-full.pyi]


[case testGenericBuiltinTupleTyping]
from typing import Tuple
Expand Down
5 changes: 3 additions & 2 deletions test-data/unit/check-protocols.test
Original file line number Diff line number Diff line change
Expand Up @@ -1506,11 +1506,12 @@ class C: pass
def f(x: P1) -> int: ...
@overload
def f(x: P2) -> str: ...
def f(x):
def f(x: object) -> object:
if isinstance(x, P1):
return P1.attr1
if isinstance(x, P2): # E: Only @runtime_checkable protocols can be used with instance and class checks
return P1.attr2
return P2.attr2
return None

reveal_type(f(C1())) # N: Revealed type is "builtins.int"
reveal_type(f(C2())) # N: Revealed type is "builtins.str"
Expand Down
155 changes: 155 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,30 @@ match m:
-- Literal Pattern --

[case testMatchLiteralPatternNarrows]
# flags: --warn-unreachable
m: object

match m:
case 1:
reveal_type(m) # N: Revealed type is "Literal[1]"
case 2:
reveal_type(m) # N: Revealed type is "Literal[2]"
case other:
reveal_type(other) # N: Revealed type is "builtins.object"

[case testMatchLiteralPatternNarrows2]
# flags: --warn-unreachable
from typing import Any

m: Any

match m:
case 1:
reveal_type(m) # N: Revealed type is "Literal[1]"
case 2:
reveal_type(m) # N: Revealed type is "Literal[2]"
case other:
reveal_type(other) # N: Revealed type is "Any"

[case testMatchLiteralPatternAlreadyNarrower-skip]
m: bool
Expand Down Expand Up @@ -1069,6 +1088,142 @@ match m:
case Foo():
pass

[case testMatchClassPatternCallable]
# flags: --warn-unreachable
from typing import Callable, Any

class FnImpl:
def __call__(self, x: object, /) -> int: ...

def test_any(x: Any) -> None:
match x:
case Callable() as fn:
reveal_type(fn) # N: Revealed type is "def (*Any, **Any) -> Any"
case other:
reveal_type(other) # N: Revealed type is "Any"

def test_object(x: object) -> None:
match x:
case Callable() as fn:
reveal_type(fn) # N: Revealed type is "def (*Any, **Any) -> Any"
case other:
reveal_type(other) # N: Revealed type is "builtins.object"

def test_impl(x: FnImpl) -> None:
match x:
case Callable() as fn:
reveal_type(fn) # N: Revealed type is "__main__.FnImpl"
case other:
reveal_type(other) # E: Statement is unreachable

def test_callable(x: Callable[[object], int]) -> None:
match x:
case Callable() as fn:
reveal_type(fn) # N: Revealed type is "def (builtins.object) -> builtins.int"
case other:
reveal_type(other) # E: Statement is unreachable

[case testMatchClassPatternCallbackProtocol]
# flags: --warn-unreachable
from typing import Any, Callable
from typing_extensions import Protocol, runtime_checkable

@runtime_checkable
class FnProto(Protocol):
def __call__(self, x: int, /) -> object: ...

class FnImpl:
def __call__(self, x: object, /) -> int: ...

def test_any(x: Any) -> None:
match x:
case FnProto() as fn:
reveal_type(fn) # N: Revealed type is "__main__.FnProto"
case other:
reveal_type(other) # N: Revealed type is "Any"

def test_object(x: object) -> None:
match x:
case FnProto() as fn:
reveal_type(fn) # N: Revealed type is "__main__.FnProto"
case other:
reveal_type(other) # N: Revealed type is "builtins.object"

def test_impl(x: FnImpl) -> None:
match x:
case FnProto() as fn:
reveal_type(fn) # N: Revealed type is "__main__.FnImpl"
case other:
reveal_type(other) # E: Statement is unreachable

def test_callable(x: Callable[[object], int]) -> None:
match x:
case FnProto() as fn:
reveal_type(fn) # N: Revealed type is "def (builtins.object) -> builtins.int"
case other:
reveal_type(other) # E: Statement is unreachable

[builtins fixtures/dict.pyi]

[case testMatchClassPatternAnyCallableProtocol]
# flags: --warn-unreachable
from typing import Any, Callable
from typing_extensions import Protocol, runtime_checkable

@runtime_checkable
class AnyCallable(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...

class FnImpl:
def __call__(self, x: object, /) -> int: ...

def test_object(x: object) -> None:
match x:
case AnyCallable() as fn:
reveal_type(fn) # N: Revealed type is "__main__.AnyCallable"
case other:
reveal_type(other) # N: Revealed type is "builtins.object"

def test_impl(x: FnImpl) -> None:
match x:
case AnyCallable() as fn:
reveal_type(fn) # N: Revealed type is "__main__.FnImpl"
case other:
reveal_type(other) # E: Statement is unreachable

def test_callable(x: Callable[[object], int]) -> None:
match x:
case AnyCallable() as fn:
reveal_type(fn) # N: Revealed type is "def (builtins.object) -> builtins.int"
case other:
reveal_type(other) # E: Statement is unreachable

[builtins fixtures/dict.pyi]


[case testMatchClassPatternProtocol]
from typing import Any
from typing_extensions import Protocol, runtime_checkable

class Proto(Protocol):
def foo(self, x: int, /) -> object: ...

class Impl:
def foo(self, x: object, /) -> int: ...

def test_object(x: object) -> None:
match x:
case Proto() as y:
reveal_type(y) # N: Revealed type is "__main__.Proto"

def test_impl(x: Impl) -> None:
match x:
case Proto() as y:
reveal_type(y) # N: Revealed type is "__main__.Impl"

[builtins fixtures/dict.pyi]


[case testMatchClassPatternNestedGenerics]
# From cpython test_patma.py
x = [[{0: 0}]]
Expand Down