@@ -1938,7 +1938,8 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
1938
1938
self .msg .unsupported_operand_types ('in' , left_type , right_type , e )
1939
1939
# Only show dangerous overlap if there are no other errors.
1940
1940
elif (not local_errors .is_errors () and cont_type and
1941
- self .dangerous_comparison (left_type , cont_type )):
1941
+ self .dangerous_comparison (left_type , cont_type ,
1942
+ original_container = right_type )):
1942
1943
self .msg .dangerous_comparison (left_type , cont_type , 'container' , e )
1943
1944
else :
1944
1945
self .msg .add_errors (local_errors )
@@ -1951,8 +1952,13 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
1951
1952
# testCustomEqCheckStrictEquality for an example.
1952
1953
if self .msg .errors .total_errors () == err_count and operator in ('==' , '!=' ):
1953
1954
right_type = self .accept (right )
1954
- if self .dangerous_comparison (left_type , right_type ):
1955
- self .msg .dangerous_comparison (left_type , right_type , 'equality' , e )
1955
+ if (not custom_equality_method (left_type ) and
1956
+ not custom_equality_method (right_type )):
1957
+ # We suppress the error if there is a custom __eq__() method on either
1958
+ # side. User defined (or even standard library) classes can define this
1959
+ # to return True for comparisons between non-overlapping types.
1960
+ if self .dangerous_comparison (left_type , right_type ):
1961
+ self .msg .dangerous_comparison (left_type , right_type , 'equality' , e )
1956
1962
1957
1963
elif operator == 'is' or operator == 'is not' :
1958
1964
right_type = self .accept (right ) # validate the right operand
@@ -1974,9 +1980,13 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
1974
1980
assert result is not None
1975
1981
return result
1976
1982
1977
- def dangerous_comparison (self , left : Type , right : Type ) -> bool :
1983
+ def dangerous_comparison (self , left : Type , right : Type ,
1984
+ original_container : Optional [Type ] = None ) -> bool :
1978
1985
"""Check for dangerous non-overlapping comparisons like 42 == 'no'.
1979
1986
1987
+ The original_container is the original container type for 'in' checks
1988
+ (and None for equality checks).
1989
+
1980
1990
Rules:
1981
1991
* X and None are overlapping even in strict-optional mode. This is to allow
1982
1992
'assert x is not None' for x defined as 'x = None # type: str' in class body
@@ -1985,9 +1995,7 @@ def dangerous_comparison(self, left: Type, right: Type) -> bool:
1985
1995
non-overlapping, although technically None is overlap, it is most
1986
1996
likely an error.
1987
1997
* Any overlaps with everything, i.e. always safe.
1988
- * Promotions are ignored, so both 'abc' == b'abc' and 1 == 1.0
1989
- are errors. This is mostly needed for bytes vs unicode, and
1990
- int vs float are added just for consistency.
1998
+ * Special case: b'abc' in b'cde' is safe.
1991
1999
"""
1992
2000
if not self .chk .options .strict_equality :
1993
2001
return False
@@ -1996,7 +2004,12 @@ def dangerous_comparison(self, left: Type, right: Type) -> bool:
1996
2004
if isinstance (left , UnionType ) and isinstance (right , UnionType ):
1997
2005
left = remove_optional (left )
1998
2006
right = remove_optional (right )
1999
- return not is_overlapping_types (left , right , ignore_promotions = True )
2007
+ if (original_container and has_bytes_component (original_container ) and
2008
+ has_bytes_component (left )):
2009
+ # We need to special case bytes, because both 97 in b'abc' and b'a' in b'abc'
2010
+ # return True (and we want to show the error only if the check can _never_ be True).
2011
+ return False
2012
+ return not is_overlapping_types (left , right , ignore_promotions = False )
2000
2013
2001
2014
def get_operator_method (self , op : str ) -> str :
2002
2015
if op == '/' and self .chk .options .python_version [0 ] == 2 :
@@ -3809,3 +3822,33 @@ def is_expr_literal_type(node: Expression) -> bool:
3809
3822
underlying = node .node
3810
3823
return isinstance (underlying , TypeAlias ) and isinstance (underlying .target , LiteralType )
3811
3824
return False
3825
+
3826
+
3827
+ def custom_equality_method (typ : Type ) -> bool :
3828
+ """Does this type have a custom __eq__() method?"""
3829
+ if isinstance (typ , Instance ):
3830
+ method = typ .type .get_method ('__eq__' )
3831
+ if method and method .info :
3832
+ return not method .info .fullname ().startswith ('builtins.' )
3833
+ return False
3834
+ if isinstance (typ , UnionType ):
3835
+ return any (custom_equality_method (t ) for t in typ .items )
3836
+ if isinstance (typ , TupleType ):
3837
+ return custom_equality_method (tuple_fallback (typ ))
3838
+ if isinstance (typ , CallableType ) and typ .is_type_obj ():
3839
+ # Look up __eq__ on the metaclass for class objects.
3840
+ return custom_equality_method (typ .fallback )
3841
+ if isinstance (typ , AnyType ):
3842
+ # Avoid false positives in uncertain cases.
3843
+ return True
3844
+ # TODO: support other types (see ExpressionChecker.has_member())?
3845
+ return False
3846
+
3847
+
3848
+ def has_bytes_component (typ : Type ) -> bool :
3849
+ """Is this the builtin bytes type, or a union that contains it?"""
3850
+ if isinstance (typ , UnionType ):
3851
+ return any (has_bytes_component (t ) for t in typ .items )
3852
+ if isinstance (typ , Instance ) and typ .type .fullname () == 'builtins.bytes' :
3853
+ return True
3854
+ return False
0 commit comments