Skip to content

Apply union expansion when checking ops involving typevars #19455

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4155,10 +4155,9 @@ def check_op(
"""

if allow_reverse:
left_variants = [base_type]
left_variants = self._union_items_from_typevar(base_type)
base_type = get_proper_type(base_type)
if isinstance(base_type, UnionType):
left_variants = list(flatten_nested_unions(base_type.relevant_items()))

right_type = self.accept(arg)

# Step 1: We first try leaving the right arguments alone and destructure
Expand Down Expand Up @@ -4196,13 +4195,17 @@ def check_op(
# We don't do the same for the base expression because it could lead to weird
# type inference errors -- e.g. see 'testOperatorDoubleUnionSum'.
# TODO: Can we use `type_overrides_set()` here?
right_variants = [(right_type, arg)]
right_type = get_proper_type(right_type)
if isinstance(right_type, UnionType):
right_variants: list[tuple[Type, Expression]]
p_right = get_proper_type(right_type)
if isinstance(p_right, (UnionType, TypeVarType)):
right_variants = [
(item, TempNode(item, context=context))
for item in flatten_nested_unions(right_type.relevant_items())
for item in self._union_items_from_typevar(right_type)
]
else:
# Preserve argument identity if we do not intend to modify it
right_variants = [(right_type, arg)]
right_type = p_right

all_results = []
all_inferred = []
Expand Down Expand Up @@ -4252,6 +4255,20 @@ def check_op(
context=context,
)

def _union_items_from_typevar(self, typ: Type) -> list[Type]:
variants = [typ]
typ = get_proper_type(typ)
base_type = typ
if unwrapped := (isinstance(typ, TypeVarType) and not typ.values):
typ = get_proper_type(typ.upper_bound)
if is_union := isinstance(typ, UnionType):
variants = list(flatten_nested_unions(typ.relevant_items()))
if is_union and unwrapped:
# If not a union, keep the original type
assert isinstance(base_type, TypeVarType)
variants = [base_type.copy_modified(upper_bound=item) for item in variants]
return variants

def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
"""Type check a boolean operation ('and' or 'or')."""

Expand Down
26 changes: 26 additions & 0 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,32 @@ if int():
class C:
def __lt__(self, o: object, x: str = "") -> int: ...

[case testReversibleOpOnTypeVarBound]
from typing import TypeVar, Union

class A:
def __lt__(self, a: A) -> bool: ...
def __gt__(self, a: A) -> bool: ...

class B(A):
def __lt__(self, b: B) -> bool: ... # type: ignore[override]
def __gt__(self, b: B) -> bool: ... # type: ignore[override]

_T = TypeVar("_T", bound=Union[A, B])

def check(x: _T, y: _T) -> bool:
return x < y

[case testReversibleOpOnTypeVarBoundPromotion]
from typing import TypeVar, Union

_T = TypeVar("_T", bound=Union[int, float])

def check(x: _T, y: _T) -> bool:
return x < y
[builtins fixtures/ops.pyi]


[case testErrorContextAndBinaryOperators]
import typing
class A:
Expand Down
6 changes: 6 additions & 0 deletions test-data/unit/fixtures/ops.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ class float:
def __rdiv__(self, x: 'float') -> 'float': pass
def __truediv__(self, x: 'float') -> 'float': pass
def __rtruediv__(self, x: 'float') -> 'float': pass
def __eq__(self, x: object) -> bool: pass
def __ne__(self, x: object) -> bool: pass
def __lt__(self, x: 'float') -> bool: pass
def __le__(self, x: 'float') -> bool: pass
def __gt__(self, x: 'float') -> bool: pass
def __ge__(self, x: 'float') -> bool: pass

class complex:
def __add__(self, x: complex) -> complex: pass
Expand Down