@@ -4,15 +4,18 @@ use clippy_utils::ty::{implements_trait, implements_trait_with_env, is_copy};
4
4
use clippy_utils:: { is_lint_allowed, match_def_path} ;
5
5
use if_chain:: if_chain;
6
6
use rustc_errors:: Applicability ;
7
+ use rustc_hir:: def_id:: DefId ;
7
8
use rustc_hir:: intravisit:: { walk_expr, walk_fn, walk_item, FnKind , Visitor } ;
8
9
use rustc_hir:: {
9
- self as hir, BlockCheckMode , BodyId , Expr , ExprKind , FnDecl , HirId , Impl , Item , ItemKind , UnsafeSource , Unsafety ,
10
+ self as hir, BlockCheckMode , BodyId , Constness , Expr , ExprKind , FnDecl , HirId , Impl , Item , ItemKind , UnsafeSource ,
11
+ Unsafety ,
10
12
} ;
11
13
use rustc_lint:: { LateContext , LateLintPass } ;
12
14
use rustc_middle:: hir:: nested_filter;
13
- use rustc_middle:: ty :: subst :: GenericArg ;
15
+ use rustc_middle:: traits :: Reveal ;
14
16
use rustc_middle:: ty:: {
15
- self , BoundConstness , ImplPolarity , ParamEnv , PredicateKind , TraitPredicate , TraitRef , Ty , Visibility ,
17
+ self , Binder , BoundConstness , GenericParamDefKind , ImplPolarity , ParamEnv , PredicateKind , TraitPredicate , TraitRef ,
18
+ Ty , TyCtxt , Visibility ,
16
19
} ;
17
20
use rustc_session:: { declare_lint_pass, declare_tool_lint} ;
18
21
use rustc_span:: source_map:: Span ;
@@ -463,49 +466,16 @@ fn check_partial_eq_without_eq<'tcx>(cx: &LateContext<'tcx>, span: Span, trait_r
463
466
if let ty:: Adt ( adt, substs) = ty. kind( ) ;
464
467
if cx. tcx. visibility( adt. did( ) ) == Visibility :: Public ;
465
468
if let Some ( eq_trait_def_id) = cx. tcx. get_diagnostic_item( sym:: Eq ) ;
466
- if let Some ( peq_trait_def_id) = cx. tcx. get_diagnostic_item( sym:: PartialEq ) ;
467
469
if let Some ( def_id) = trait_ref. trait_def_id( ) ;
468
470
if cx. tcx. is_diagnostic_item( sym:: PartialEq , def_id) ;
469
- // New `ParamEnv` replacing `T: PartialEq` with `T: Eq`
470
- let param_env = ParamEnv :: new(
471
- cx. tcx. mk_predicates( cx. param_env. caller_bounds( ) . iter( ) . map( |p| {
472
- let kind = p. kind( ) ;
473
- match kind. skip_binder( ) {
474
- PredicateKind :: Trait ( p)
475
- if p. trait_ref. def_id == peq_trait_def_id
476
- && p. trait_ref. substs. get( 0 ) == p. trait_ref. substs. get( 1 )
477
- && matches!( p. trait_ref. self_ty( ) . kind( ) , ty:: Param ( _) )
478
- && p. constness == BoundConstness :: NotConst
479
- && p. polarity == ImplPolarity :: Positive =>
480
- {
481
- cx. tcx. mk_predicate( kind. rebind( PredicateKind :: Trait ( TraitPredicate {
482
- trait_ref: TraitRef :: new(
483
- eq_trait_def_id,
484
- cx. tcx. mk_substs( [ GenericArg :: from( p. trait_ref. self_ty( ) ) ] . into_iter( ) ) ,
485
- ) ,
486
- constness: BoundConstness :: NotConst ,
487
- polarity: ImplPolarity :: Positive ,
488
- } ) ) )
489
- } ,
490
- _ => p,
491
- }
492
- } ) ) ,
493
- cx. param_env. reveal( ) ,
494
- cx. param_env. constness( ) ,
495
- ) ;
496
- if !implements_trait_with_env( cx. tcx, param_env, ty, eq_trait_def_id, substs) ;
471
+ let param_env = param_env_for_derived_eq( cx. tcx, adt. did( ) , eq_trait_def_id) ;
472
+ if !implements_trait_with_env( cx. tcx, param_env, ty, eq_trait_def_id, & [ ] ) ;
473
+ // If all of our fields implement `Eq`, we can implement `Eq` too
474
+ if adt
475
+ . all_fields( )
476
+ . map( |f| f. ty( cx. tcx, substs) )
477
+ . all( |ty| implements_trait_with_env( cx. tcx, param_env, ty, eq_trait_def_id, & [ ] ) ) ;
497
478
then {
498
- // If all of our fields implement `Eq`, we can implement `Eq` too
499
- for variant in adt. variants( ) {
500
- for field in & variant. fields {
501
- let ty = field. ty( cx. tcx, substs) ;
502
-
503
- if !implements_trait( cx, ty, eq_trait_def_id, substs) {
504
- return ;
505
- }
506
- }
507
- }
508
-
509
479
span_lint_and_sugg(
510
480
cx,
511
481
DERIVE_PARTIAL_EQ_WITHOUT_EQ ,
@@ -518,3 +488,41 @@ fn check_partial_eq_without_eq<'tcx>(cx: &LateContext<'tcx>, span: Span, trait_r
518
488
}
519
489
}
520
490
}
491
+
492
+ /// Creates the `ParamEnv` used for the give type's derived `Eq` impl.
493
+ fn param_env_for_derived_eq ( tcx : TyCtxt < ' _ > , did : DefId , eq_trait_id : DefId ) -> ParamEnv < ' _ > {
494
+ // Initial map from generic index to param def.
495
+ // Vec<(param_def, needs_eq)>
496
+ let mut params = tcx
497
+ . generics_of ( did)
498
+ . params
499
+ . iter ( )
500
+ . map ( |p| ( p, matches ! ( p. kind, GenericParamDefKind :: Type { .. } ) ) )
501
+ . collect :: < Vec < _ > > ( ) ;
502
+
503
+ let ty_predicates = tcx. predicates_of ( did) . predicates ;
504
+ for ( p, _) in ty_predicates {
505
+ if let PredicateKind :: Trait ( p) = p. kind ( ) . skip_binder ( )
506
+ && p. trait_ref . def_id == eq_trait_id
507
+ && let ty:: Param ( self_ty) = p. trait_ref . self_ty ( ) . kind ( )
508
+ && p. constness == BoundConstness :: NotConst
509
+ {
510
+ // Flag types which already have an `Eq` bound.
511
+ params[ self_ty. index as usize ] . 1 = false ;
512
+ }
513
+ }
514
+
515
+ ParamEnv :: new (
516
+ tcx. mk_predicates ( ty_predicates. iter ( ) . map ( |& ( p, _) | p) . chain (
517
+ params. iter ( ) . filter ( |& & ( _, needs_eq) | needs_eq) . map ( |& ( param, _) | {
518
+ tcx. mk_predicate ( Binder :: dummy ( PredicateKind :: Trait ( TraitPredicate {
519
+ trait_ref : TraitRef :: new ( eq_trait_id, tcx. mk_substs ( [ tcx. mk_param_from_def ( param) ] . into_iter ( ) ) ) ,
520
+ constness : BoundConstness :: NotConst ,
521
+ polarity : ImplPolarity :: Positive ,
522
+ } ) ) )
523
+ } ) ,
524
+ ) ) ,
525
+ Reveal :: UserFacing ,
526
+ Constness :: NotConst ,
527
+ )
528
+ }
0 commit comments