Skip to content

Commit b8c84cc

Browse files
committed
Add support for operators with union operands
This pull request resolves python#2128 -- it modifies how we check operators to add support for operations like `Union[int, float] + Union[int, float]`. This approach basically iterates over all possible variations of the left and right operands when they're unions and uses the union of the resulting inferred type as the type of the overall expression. Some implementation notes: 1. I attempting "destructuring" just the left operand, which is basically the approach proposed here: python#2128 (comment) Unfortunately, I discovered it became necessary to also destructure the right operand to handle certain edge cases -- see the testOperatorDoubleUnionInterwovenUnionAdd test case. 2. This algorithm varies slightly from what we do for union math in that we don't attempt to "preserve" the union/we always destructure both operands. I'm fairly confident that this is type-safe; I plan on testing this pull request against some internal code bases to help us gain more confidence.
1 parent a5392e0 commit b8c84cc

11 files changed

+292
-73
lines changed

mypy/checker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2525,7 +2525,7 @@ def visit_operator_assignment_stmt(self,
25252525
if inplace:
25262526
# There is __ifoo__, treat as x = x.__ifoo__(y)
25272527
rvalue_type, method_type = self.expr_checker.check_op(
2528-
method, lvalue_type, s.rvalue, s)
2528+
method, s.lvalue, lvalue_type, s.rvalue, s)
25292529
if not is_subtype(rvalue_type, lvalue_type):
25302530
self.msg.incompatible_operator_assignment(s.op, s)
25312531
else:

mypy/checkexpr.py

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1697,7 +1697,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
16971697

16981698
if e.op in nodes.op_methods:
16991699
method = self.get_operator_method(e.op)
1700-
result, method_type = self.check_op(method, left_type, e.right, e,
1700+
result, method_type = self.check_op(method, e.left, left_type, e.right, e,
17011701
allow_reverse=True)
17021702
e.method_type = method_type
17031703
return result
@@ -1749,7 +1749,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
17491749
sub_result = self.bool_type()
17501750
elif operator in nodes.op_methods:
17511751
method = self.get_operator_method(operator)
1752-
sub_result, method_type = self.check_op(method, left_type, right, e,
1752+
sub_result, method_type = self.check_op(method, left, left_type, right, e,
17531753
allow_reverse=True)
17541754

17551755
elif operator == 'is' or operator == 'is not':
@@ -1820,19 +1820,12 @@ def check_op_reversible(self,
18201820
left_expr: Expression,
18211821
right_type: Type,
18221822
right_expr: Expression,
1823-
context: Context) -> Tuple[Type, Type]:
1824-
# Note: this kludge exists mostly to maintain compatibility with
1825-
# existing error messages. Apparently, if the left-hand-side is a
1826-
# union and we have a type mismatch, we print out a special,
1827-
# abbreviated error message. (See messages.unsupported_operand_types).
1828-
unions_present = isinstance(left_type, UnionType)
1829-
1823+
context: Context,
1824+
msg: MessageBuilder) -> Tuple[Type, Type]:
18301825
def make_local_errors() -> MessageBuilder:
18311826
"""Creates a new MessageBuilder object."""
1832-
local_errors = self.msg.clean_copy()
1827+
local_errors = msg.clean_copy()
18331828
local_errors.disable_count = 0
1834-
if unions_present:
1835-
local_errors.disable_type_names += 1
18361829
return local_errors
18371830

18381831
def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]:
@@ -2009,9 +2002,9 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]:
20092002
# TODO: Remove this extra case
20102003
return result
20112004

2012-
self.msg.add_errors(errors[0])
2005+
msg.add_errors(errors[0])
20132006
if warn_about_uncalled_reverse_operator:
2014-
self.msg.reverse_operator_method_never_called(
2007+
msg.reverse_operator_method_never_called(
20152008
nodes.op_methods_to_symbols[op_name],
20162009
op_name,
20172010
right_type,
@@ -2025,22 +2018,57 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]:
20252018
result = error_any, error_any
20262019
return result
20272020

2028-
def check_op(self, method: str, base_type: Type, arg: Expression,
2029-
context: Context,
2021+
def check_op(self, method: str, base_expr: Expression, base_type: Type,
2022+
arg: Expression, context: Context,
20302023
allow_reverse: bool = False) -> Tuple[Type, Type]:
20312024
"""Type check a binary operation which maps to a method call.
20322025
20332026
Return tuple (result type, inferred operator method type).
20342027
"""
20352028

20362029
if allow_reverse:
2037-
return self.check_op_reversible(
2038-
op_name=method,
2039-
left_type=base_type,
2040-
left_expr=TempNode(base_type),
2041-
right_type=self.accept(arg),
2042-
right_expr=arg,
2043-
context=context)
2030+
# Note: We want to pass in the original 'base_expr' and 'arg' for
2031+
# 'left_expr' and 'right_expr' whenever possible so that plugins
2032+
# and similar things can introspect on the original node if possible.
2033+
left_variants = [(base_type, base_expr)]
2034+
if isinstance(base_type, UnionType):
2035+
left_variants = [(item, TempNode(item)) for item in base_type.relevant_items()]
2036+
2037+
right_type = self.accept(arg)
2038+
right_variants = [(right_type, arg)]
2039+
if isinstance(right_type, UnionType):
2040+
right_variants = [(item, TempNode(item)) for item in right_type.relevant_items()]
2041+
2042+
msg = self.msg.clean_copy()
2043+
msg.disable_count = 0
2044+
all_results = []
2045+
all_inferred = []
2046+
2047+
for left_possible_type, left_expr in left_variants:
2048+
for right_possible_type, right_expr in right_variants:
2049+
result, inferred = self.check_op_reversible(
2050+
op_name=method,
2051+
left_type=left_possible_type,
2052+
left_expr=left_expr,
2053+
right_type=right_possible_type,
2054+
right_expr=right_expr,
2055+
context=context,
2056+
msg=msg)
2057+
all_results.append(result)
2058+
all_inferred.append(inferred)
2059+
2060+
if msg.is_errors():
2061+
self.msg.add_errors(msg)
2062+
if len(left_variants) >= 2 and len(right_variants) >= 2:
2063+
self.msg.warn_both_operands_are_from_unions(context)
2064+
elif len(left_variants) >= 2:
2065+
self.msg.warn_operand_was_from_union("Left", base_type, context)
2066+
elif len(right_variants) >= 2:
2067+
self.msg.warn_operand_was_from_union("Right", right_type, context)
2068+
2069+
results_final = UnionType.make_simplified_union(all_results)
2070+
inferred_final = UnionType.make_simplified_union(all_inferred)
2071+
return results_final, inferred_final
20442072
else:
20452073
return self.check_op_local_by_name(
20462074
method=method,
@@ -2125,7 +2153,7 @@ def check_list_multiply(self, e: OpExpr) -> Type:
21252153
left_type = self.accept(e.left, type_context=self.type_context[-1])
21262154
else:
21272155
left_type = self.accept(e.left)
2128-
result, method_type = self.check_op('__mul__', left_type, e.right, e)
2156+
result, method_type = self.check_op('__mul__', e.left, left_type, e.right, e)
21292157
e.method_type = method_type
21302158
return result
21312159

@@ -2179,7 +2207,7 @@ def visit_index_expr_helper(self, e: IndexExpr) -> Type:
21792207
and left_type.is_type_obj() and left_type.type_object().is_enum):
21802208
return self.visit_enum_index_expr(left_type.type_object(), e.index, e)
21812209
else:
2182-
result, method_type = self.check_op('__getitem__', left_type, e.index, e)
2210+
result, method_type = self.check_op('__getitem__', e.base, left_type, e.index, e)
21832211
e.method_type = method_type
21842212
return result
21852213

mypy/messages.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,6 +1019,12 @@ def reverse_operator_method_never_called(self,
10191019
),
10201020
context=context)
10211021

1022+
def warn_both_operands_are_from_unions(self, context: Context) -> None:
1023+
self.note('Both left and right operands are unions', context)
1024+
1025+
def warn_operand_was_from_union(self, side: str, original: Type, context: Context) -> None:
1026+
self.note('{} operand is of type {}'.format(side, self.format(original)), context)
1027+
10221028
def operator_method_signatures_overlap(
10231029
self, reverse_class: TypeInfo, reverse_method: str, forward_class: Type,
10241030
forward_method: str, context: Context) -> None:

test-data/unit/check-callable.test

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ from typing import Callable, Union
4646
x = 5 # type: Union[int, Callable[[], str], Callable[[], int]]
4747

4848
if callable(x):
49-
y = x() + 2 # E: Unsupported operand types for + (likely involving Union)
49+
y = x() + 2 # E: Unsupported operand types for + ("str" and "int") \
50+
# N: Left operand is of type "Union[str, int]"
5051
else:
5152
z = x + 6
5253

@@ -60,7 +61,8 @@ x = 5 # type: Union[int, str, Callable[[], str]]
6061
if callable(x):
6162
y = x() + 'test'
6263
else:
63-
z = x + 6 # E: Unsupported operand types for + (likely involving Union)
64+
z = x + 6 # E: Unsupported operand types for + ("str" and "int") \
65+
# N: Left operand is of type "Union[int, str]"
6466

6567
[builtins fixtures/callable.pyi]
6668

@@ -153,7 +155,8 @@ x = 5 # type: Union[int, Callable[[], str]]
153155
if callable(x) and x() == 'test':
154156
x()
155157
else:
156-
x + 5 # E: Unsupported left operand type for + (some union)
158+
x + 5 # E: Unsupported left operand type for + ("Callable[[], str]") \
159+
# N: Left operand is of type "Union[int, Callable[[], str]]"
157160

158161
[builtins fixtures/callable.pyi]
159162

test-data/unit/check-classes.test

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2019,6 +2019,136 @@ class FractionChild(Fraction): pass
20192019

20202020
class A(metaclass=Real): pass
20212021

2022+
[case testOperatorDoubleUnionIntFloat]
2023+
from typing import Union
2024+
2025+
a: Union[int, float]
2026+
b: int
2027+
c: float
2028+
2029+
reveal_type(a + a) # E: Revealed type is 'builtins.float'
2030+
reveal_type(a + b) # E: Revealed type is 'builtins.float'
2031+
reveal_type(b + a) # E: Revealed type is 'builtins.float'
2032+
reveal_type(a + c) # E: Revealed type is 'builtins.float'
2033+
reveal_type(c + a) # E: Revealed type is 'builtins.float'
2034+
[builtins fixtures/ops.pyi]
2035+
2036+
[case testOperatorDoubleUnionStandardSubtyping]
2037+
from typing import Union
2038+
2039+
class Parent:
2040+
def __add__(self, x: Parent) -> Parent: pass
2041+
def __radd__(self, x: Parent) -> Parent: pass
2042+
2043+
class Child(Parent):
2044+
def __add__(self, x: Parent) -> Child: pass
2045+
def __radd__(self, x: Parent) -> Child: pass
2046+
2047+
a: Union[Parent, Child]
2048+
b: Parent
2049+
c: Child
2050+
2051+
reveal_type(a + a) # E: Revealed type is '__main__.Parent'
2052+
reveal_type(a + b) # E: Revealed type is '__main__.Parent'
2053+
reveal_type(b + a) # E: Revealed type is '__main__.Parent'
2054+
reveal_type(a + c) # E: Revealed type is '__main__.Child'
2055+
reveal_type(c + a) # E: Revealed type is '__main__.Child'
2056+
2057+
[case testOperatorDoubleUnionNoRelationship1]
2058+
from typing import Union
2059+
2060+
class Foo:
2061+
def __add__(self, x: Foo) -> Foo: pass
2062+
def __radd__(self, x: Foo) -> Foo: pass
2063+
2064+
class Bar:
2065+
def __add__(self, x: Bar) -> Bar: pass
2066+
def __radd__(self, x: Bar) -> Bar: pass
2067+
2068+
a: Union[Foo, Bar]
2069+
b: Foo
2070+
c: Bar
2071+
2072+
a + a # E: Unsupported operand types for + ("Foo" and "Bar") \
2073+
# E: Unsupported operand types for + ("Bar" and "Foo") \
2074+
# N: Both left and right operands are unions
2075+
2076+
a + b # E: Unsupported operand types for + ("Bar" and "Foo") \
2077+
# N: Left operand is of type "Union[Foo, Bar]"
2078+
2079+
b + a # E: Unsupported operand types for + ("Foo" and "Bar") \
2080+
# N: Right operand is of type "Union[Foo, Bar]"
2081+
2082+
a + c # E: Unsupported operand types for + ("Foo" and "Bar") \
2083+
# N: Left operand is of type "Union[Foo, Bar]"
2084+
2085+
c + a # E: Unsupported operand types for + ("Bar" and "Foo") \
2086+
# N: Right operand is of type "Union[Foo, Bar]"
2087+
2088+
[case testOperatorDoubleUnionNoRelationship2]
2089+
from typing import Union
2090+
2091+
class Foo:
2092+
def __add__(self, x: Foo) -> Foo: pass
2093+
def __radd__(self, x: Foo) -> Foo: pass
2094+
2095+
class Bar:
2096+
def __add__(self, x: Union[Foo, Bar]) -> Bar: pass
2097+
def __radd__(self, x: Union[Foo, Bar]) -> Bar: pass
2098+
2099+
a: Union[Foo, Bar]
2100+
b: Foo
2101+
c: Bar
2102+
2103+
reveal_type(a + a) # E: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
2104+
reveal_type(a + b) # E: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
2105+
reveal_type(b + a) # E: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
2106+
reveal_type(a + c) # E: Revealed type is '__main__.Bar'
2107+
reveal_type(c + a) # E: Revealed type is '__main__.Bar'
2108+
2109+
[case testOperatorDoubleUnionNaiveAdd]
2110+
from typing import Union
2111+
2112+
class A: pass
2113+
class B: pass
2114+
class C:
2115+
def __radd__(self, x: A) -> int: pass
2116+
class D:
2117+
def __radd__(self, x: B) -> str: pass
2118+
2119+
x: Union[A, B]
2120+
y: Union[C, D]
2121+
2122+
x + y # E: Unsupported operand types for + ("A" and "D") \
2123+
# E: Unsupported operand types for + ("B" and "C") \
2124+
# N: Both left and right operands are unions
2125+
2126+
[case testOperatorDoubleUnionInterwovenUnionAdd]
2127+
from typing import Union
2128+
2129+
class Out1: pass
2130+
class Out2: pass
2131+
class Out3: pass
2132+
class Out4: pass
2133+
2134+
class A:
2135+
def __add__(self, x: D) -> Out1: pass
2136+
class B:
2137+
def __add__(self, x: C) -> Out2: pass
2138+
class C:
2139+
def __radd__(self, x: A) -> Out3: pass
2140+
class D:
2141+
def __radd__(self, x: B) -> Out4: pass
2142+
2143+
x: Union[A, B]
2144+
y: Union[C, D]
2145+
2146+
reveal_type(x + y) # E: Revealed type is 'Union[__main__.Out3, __main__.Out1, __main__.Out2, __main__.Out4]'
2147+
reveal_type(A() + y) # E: Revealed type is 'Union[__main__.Out3, __main__.Out1]'
2148+
reveal_type(B() + y) # E: Revealed type is 'Union[__main__.Out2, __main__.Out4]'
2149+
reveal_type(x + C()) # E: Revealed type is 'Union[__main__.Out3, __main__.Out2]'
2150+
reveal_type(x + D()) # E: Revealed type is 'Union[__main__.Out1, __main__.Out4]'
2151+
20222152
[case testAbstractReverseOperatorMethod]
20232153
import typing
20242154
from abc import abstractmethod

test-data/unit/check-generics.test

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,8 @@ class Node(Generic[T]):
755755
UNode = Union[int, Node[T]]
756756
x = 1 # type: UNode[int]
757757

758-
x + 1 # E: Unsupported left operand type for + (some union)
758+
x + 1 # E: Unsupported left operand type for + ("Node[int]") \
759+
# N: Left operand is of type "Union[int, Node[int]]"
759760
if not isinstance(x, Node):
760761
x + 1
761762

test-data/unit/check-incremental.test

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3515,7 +3515,8 @@ from typing import Optional
35153515
def foo() -> Optional[int]: return 0
35163516
[out1]
35173517
[out2]
3518-
main:3: error: Unsupported operand types for + ("int" and "Optional[int]")
3518+
main:3: error: Unsupported operand types for + ("int" and "None")
3519+
main:3: note: Right operand is of type "Optional[int]"
35193520

35203521
[case testAttrsIncrementalSubclassingCached]
35213522
from a import A
@@ -4082,7 +4083,8 @@ class Baz:
40824083
return 1
40834084
[out]
40844085
[out2]
4085-
tmp/a.py:3: error: Unsupported operand types for + ("int" and "Optional[int]")
4086+
tmp/a.py:3: error: Unsupported operand types for + ("int" and "None")
4087+
tmp/a.py:3: note: Right operand is of type "Optional[int]"
40864088

40874089
[case testIncrementalMetaclassUpdate]
40884090
import a

0 commit comments

Comments
 (0)