Skip to content

Commit 2baebda

Browse files
authored
Tweaks to --strict-equality based on user feedback (#6674)
Fixes #6607 Fixes #6608
1 parent 01afc3b commit 2baebda

File tree

4 files changed

+141
-15
lines changed

4 files changed

+141
-15
lines changed

docs/source/command_line.rst

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -396,15 +396,17 @@ of the above sections.
396396

397397
.. code-block:: python
398398
399-
from typing import Text
399+
from typing import List, Text
400400
401-
text: Text
402-
if b'some bytes' in text: # Error: non-overlapping check!
401+
items: List[int]
402+
if 'some string' in items: # Error: non-overlapping container check!
403403
...
404-
if text != b'other bytes': # Error: non-overlapping check!
404+
405+
text: Text
406+
if text != b'other bytes': # Error: non-overlapping equality check!
405407
...
406408
407-
assert text is not None # OK, this special case is allowed.
409+
assert text is not None # OK, check against None is allowed as a special case.
408410
409411
``--strict``
410412
This flag mode enables all optional error checking flags. You can see the

mypy/checkexpr.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1938,7 +1938,8 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
19381938
self.msg.unsupported_operand_types('in', left_type, right_type, e)
19391939
# Only show dangerous overlap if there are no other errors.
19401940
elif (not local_errors.is_errors() and cont_type and
1941-
self.dangerous_comparison(left_type, cont_type)):
1941+
self.dangerous_comparison(left_type, cont_type,
1942+
original_container=right_type)):
19421943
self.msg.dangerous_comparison(left_type, cont_type, 'container', e)
19431944
else:
19441945
self.msg.add_errors(local_errors)
@@ -1951,8 +1952,13 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
19511952
# testCustomEqCheckStrictEquality for an example.
19521953
if self.msg.errors.total_errors() == err_count and operator in ('==', '!='):
19531954
right_type = self.accept(right)
1954-
if self.dangerous_comparison(left_type, right_type):
1955-
self.msg.dangerous_comparison(left_type, right_type, 'equality', e)
1955+
if (not custom_equality_method(left_type) and
1956+
not custom_equality_method(right_type)):
1957+
# We suppress the error if there is a custom __eq__() method on either
1958+
# side. User defined (or even standard library) classes can define this
1959+
# to return True for comparisons between non-overlapping types.
1960+
if self.dangerous_comparison(left_type, right_type):
1961+
self.msg.dangerous_comparison(left_type, right_type, 'equality', e)
19561962

19571963
elif operator == 'is' or operator == 'is not':
19581964
right_type = self.accept(right) # validate the right operand
@@ -1974,9 +1980,13 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
19741980
assert result is not None
19751981
return result
19761982

1977-
def dangerous_comparison(self, left: Type, right: Type) -> bool:
1983+
def dangerous_comparison(self, left: Type, right: Type,
1984+
original_container: Optional[Type] = None) -> bool:
19781985
"""Check for dangerous non-overlapping comparisons like 42 == 'no'.
19791986
1987+
The original_container is the original container type for 'in' checks
1988+
(and None for equality checks).
1989+
19801990
Rules:
19811991
* X and None are overlapping even in strict-optional mode. This is to allow
19821992
'assert x is not None' for x defined as 'x = None # type: str' in class body
@@ -1985,9 +1995,7 @@ def dangerous_comparison(self, left: Type, right: Type) -> bool:
19851995
non-overlapping, although technically None is overlap, it is most
19861996
likely an error.
19871997
* Any overlaps with everything, i.e. always safe.
1988-
* Promotions are ignored, so both 'abc' == b'abc' and 1 == 1.0
1989-
are errors. This is mostly needed for bytes vs unicode, and
1990-
int vs float are added just for consistency.
1998+
* Special case: b'abc' in b'cde' is safe.
19911999
"""
19922000
if not self.chk.options.strict_equality:
19932001
return False
@@ -1996,7 +2004,12 @@ def dangerous_comparison(self, left: Type, right: Type) -> bool:
19962004
if isinstance(left, UnionType) and isinstance(right, UnionType):
19972005
left = remove_optional(left)
19982006
right = remove_optional(right)
1999-
return not is_overlapping_types(left, right, ignore_promotions=True)
2007+
if (original_container and has_bytes_component(original_container) and
2008+
has_bytes_component(left)):
2009+
# We need to special case bytes, because both 97 in b'abc' and b'a' in b'abc'
2010+
# return True (and we want to show the error only if the check can _never_ be True).
2011+
return False
2012+
return not is_overlapping_types(left, right, ignore_promotions=False)
20002013

20012014
def get_operator_method(self, op: str) -> str:
20022015
if op == '/' and self.chk.options.python_version[0] == 2:
@@ -3809,3 +3822,33 @@ def is_expr_literal_type(node: Expression) -> bool:
38093822
underlying = node.node
38103823
return isinstance(underlying, TypeAlias) and isinstance(underlying.target, LiteralType)
38113824
return False
3825+
3826+
3827+
def custom_equality_method(typ: Type) -> bool:
3828+
"""Does this type have a custom __eq__() method?"""
3829+
if isinstance(typ, Instance):
3830+
method = typ.type.get_method('__eq__')
3831+
if method and method.info:
3832+
return not method.info.fullname().startswith('builtins.')
3833+
return False
3834+
if isinstance(typ, UnionType):
3835+
return any(custom_equality_method(t) for t in typ.items)
3836+
if isinstance(typ, TupleType):
3837+
return custom_equality_method(tuple_fallback(typ))
3838+
if isinstance(typ, CallableType) and typ.is_type_obj():
3839+
# Look up __eq__ on the metaclass for class objects.
3840+
return custom_equality_method(typ.fallback)
3841+
if isinstance(typ, AnyType):
3842+
# Avoid false positives in uncertain cases.
3843+
return True
3844+
# TODO: support other types (see ExpressionChecker.has_member())?
3845+
return False
3846+
3847+
3848+
def has_bytes_component(typ: Type) -> bool:
3849+
"""Is this the builtin bytes type, or a union that contains it?"""
3850+
if isinstance(typ, UnionType):
3851+
return any(has_bytes_component(t) for t in typ.items)
3852+
if isinstance(typ, Instance) and typ.type.fullname() == 'builtins.bytes':
3853+
return True
3854+
return False

test-data/unit/check-expressions.test

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2024,7 +2024,23 @@ cb: Union[Container[A], Container[B]]
20242024
[builtins fixtures/bool.pyi]
20252025
[typing fixtures/typing-full.pyi]
20262026

2027-
[case testStrictEqualityNoPromote]
2027+
[case testStrictEqualityBytesSpecial]
2028+
# flags: --strict-equality
2029+
b'abc' in b'abcde'
2030+
[builtins fixtures/primitives.pyi]
2031+
[typing fixtures/typing-full.pyi]
2032+
2033+
[case testStrictEqualityBytesSpecialUnion]
2034+
# flags: --strict-equality
2035+
from typing import Union
2036+
x: Union[bytes, str]
2037+
2038+
b'abc' in x
2039+
x in b'abc'
2040+
[builtins fixtures/primitives.pyi]
2041+
[typing fixtures/typing-full.pyi]
2042+
2043+
[case testStrictEqualityNoPromotePy3]
20282044
# flags: --strict-equality
20292045
'a' == b'a' # E: Non-overlapping equality check (left operand type: "str", right operand type: "bytes")
20302046
b'a' in 'abc' # E: Non-overlapping container check (element type: "bytes", container item type: "str")
@@ -2035,6 +2051,16 @@ x != y # E: Non-overlapping equality check (left operand type: "str", right ope
20352051
[builtins fixtures/primitives.pyi]
20362052
[typing fixtures/typing-full.pyi]
20372053

2054+
[case testStrictEqualityOkPromote]
2055+
# flags: --strict-equality
2056+
from typing import Container
2057+
c: Container[int]
2058+
2059+
1 == 1.0 # OK
2060+
1.0 in c # OK
2061+
[builtins fixtures/primitives.pyi]
2062+
[typing fixtures/typing-full.pyi]
2063+
20382064
[case testStrictEqualityAny]
20392065
# flags: --strict-equality
20402066
from typing import Any, Container
@@ -2086,6 +2112,58 @@ class B:
20862112
A() == B() # E: Unsupported operand types for == ("A" and "B")
20872113
[builtins fixtures/bool.pyi]
20882114

2115+
[case testCustomEqCheckStrictEqualityOKInstance]
2116+
# flags: --strict-equality
2117+
class A:
2118+
def __eq__(self, other: object) -> bool:
2119+
...
2120+
class B:
2121+
def __eq__(self, other: object) -> bool:
2122+
...
2123+
2124+
A() == int() # OK
2125+
int() != B() # OK
2126+
[builtins fixtures/bool.pyi]
2127+
2128+
[case testCustomEqCheckStrictEqualityOKUnion]
2129+
# flags: --strict-equality
2130+
from typing import Union
2131+
class A:
2132+
def __eq__(self, other: object) -> bool:
2133+
...
2134+
2135+
x: Union[A, str]
2136+
x == int()
2137+
[builtins fixtures/bool.pyi]
2138+
2139+
[case testCustomEqCheckStrictEqualityTuple]
2140+
# flags: --strict-equality
2141+
from typing import NamedTuple
2142+
2143+
class Base(NamedTuple):
2144+
attr: int
2145+
2146+
class Custom(Base):
2147+
def __eq__(self, other: object) -> bool: ...
2148+
2149+
Base(int()) == int() # E: Non-overlapping equality check (left operand type: "Base", right operand type: "int")
2150+
Base(int()) == tuple()
2151+
Custom(int()) == int()
2152+
[builtins fixtures/bool.pyi]
2153+
2154+
[case testCustomEqCheckStrictEqualityMeta]
2155+
# flags: --strict-equality
2156+
class CustomMeta(type):
2157+
def __eq__(self, other: object) -> bool: ...
2158+
2159+
class Normal: ...
2160+
class Custom(metaclass=CustomMeta): ...
2161+
2162+
Normal == int() # E: Non-overlapping equality check (left operand type: "Type[Normal]", right operand type: "int")
2163+
Normal == Normal
2164+
Custom == int()
2165+
[builtins fixtures/bool.pyi]
2166+
20892167
[case testCustomContainsCheckStrictEquality]
20902168
# flags: --strict-equality
20912169
class A:

test-data/unit/fixtures/primitives.pyi

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ class str(Sequence[str]):
2323
def __contains__(self, other: object) -> bool: pass
2424
def __getitem__(self, item: int) -> str: pass
2525
def format(self, *args) -> str: pass
26-
class bytes: pass
26+
class bytes(Sequence[int]):
27+
def __iter__(self) -> Iterator[int]: pass
28+
def __contains__(self, other: object) -> bool: pass
29+
def __getitem__(self, item: int) -> int: pass
2730
class bytearray: pass
2831
class tuple(Generic[T]): pass
2932
class function: pass

0 commit comments

Comments
 (0)