Skip to content

Commit d6be597

Browse files
committed
Use SpanlessEq for in trait_bounds lints
1 parent db1bda3 commit d6be597

File tree

5 files changed

+239
-81
lines changed

5 files changed

+239
-81
lines changed

clippy_lints/src/trait_bounds.rs

+49-58
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,17 @@ use clippy_utils::source::{SpanRangeExt, snippet, snippet_with_applicability};
55
use clippy_utils::{SpanlessEq, SpanlessHash, is_from_proc_macro};
66
use core::hash::{Hash, Hasher};
77
use itertools::Itertools;
8-
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
8+
use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap, IndexEntry};
99
use rustc_data_structures::unhash::UnhashMap;
1010
use rustc_errors::Applicability;
1111
use rustc_hir::def::Res;
1212
use rustc_hir::{
13-
GenericArg, GenericBound, Generics, Item, ItemKind, LangItem, Node, Path, PathSegment, PredicateOrigin, QPath,
13+
GenericBound, Generics, Item, ItemKind, LangItem, Node, Path, PathSegment, PredicateOrigin, QPath,
1414
TraitBoundModifier, TraitItem, TraitRef, Ty, TyKind, WherePredicate,
1515
};
1616
use rustc_lint::{LateContext, LateLintPass};
1717
use rustc_session::impl_lint_pass;
1818
use rustc_span::{BytePos, Span};
19-
use std::collections::hash_map::Entry;
2019

2120
declare_clippy_lint! {
2221
/// ### What it does
@@ -153,7 +152,10 @@ impl<'tcx> LateLintPass<'tcx> for TraitBounds {
153152
.filter_map(get_trait_info_from_bound)
154153
.for_each(|(trait_item_res, trait_item_segments, span)| {
155154
if let Some(self_segments) = self_bounds_map.get(&trait_item_res) {
156-
if SpanlessEq::new(cx).eq_path_segments(self_segments, trait_item_segments) {
155+
if SpanlessEq::new(cx)
156+
.paths_by_resolution()
157+
.eq_path_segments(self_segments, trait_item_segments)
158+
{
157159
span_lint_and_help(
158160
cx,
159161
TRAIT_DUPLICATION_IN_BOUNDS,
@@ -302,7 +304,7 @@ impl TraitBounds {
302304
}
303305
}
304306

305-
fn check_trait_bound_duplication(cx: &LateContext<'_>, generics: &'_ Generics<'_>) {
307+
fn check_trait_bound_duplication<'tcx>(cx: &LateContext<'tcx>, generics: &'_ Generics<'tcx>) {
306308
if generics.span.from_expansion() {
307309
return;
308310
}
@@ -314,6 +316,7 @@ fn check_trait_bound_duplication(cx: &LateContext<'_>, generics: &'_ Generics<'_
314316
// |
315317
// collects each of these where clauses into a set keyed by generic name and comparable trait
316318
// eg. (T, Clone)
319+
#[expect(clippy::mutable_key_type)]
317320
let where_predicates = generics
318321
.predicates
319322
.iter()
@@ -367,11 +370,27 @@ fn check_trait_bound_duplication(cx: &LateContext<'_>, generics: &'_ Generics<'_
367370
}
368371
}
369372

370-
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
371-
struct ComparableTraitRef(Res, Vec<Res>);
372-
impl Default for ComparableTraitRef {
373-
fn default() -> Self {
374-
Self(Res::Err, Vec::new())
373+
struct ComparableTraitRef<'a, 'tcx> {
374+
cx: &'a LateContext<'tcx>,
375+
trait_ref: &'tcx TraitRef<'tcx>,
376+
modifier: TraitBoundModifier,
377+
}
378+
379+
impl PartialEq for ComparableTraitRef<'_, '_> {
380+
fn eq(&self, other: &Self) -> bool {
381+
self.modifier == other.modifier
382+
&& SpanlessEq::new(self.cx)
383+
.paths_by_resolution()
384+
.eq_path(self.trait_ref.path, other.trait_ref.path)
385+
}
386+
}
387+
impl Eq for ComparableTraitRef<'_, '_> {}
388+
impl Hash for ComparableTraitRef<'_, '_> {
389+
fn hash<H: Hasher>(&self, state: &mut H) {
390+
let mut s = SpanlessHash::new(self.cx).paths_by_resolution();
391+
s.hash_path(self.trait_ref.path);
392+
state.write_u64(s.finish());
393+
self.modifier.hash(state);
375394
}
376395
}
377396

@@ -392,69 +411,41 @@ fn get_trait_info_from_bound<'a>(bound: &'a GenericBound<'_>) -> Option<(Res, &'
392411
}
393412
}
394413

395-
fn get_ty_res(ty: Ty<'_>) -> Option<Res> {
396-
match ty.kind {
397-
TyKind::Path(QPath::Resolved(_, path)) => Some(path.res),
398-
TyKind::Path(QPath::TypeRelative(ty, _)) => get_ty_res(*ty),
399-
_ => None,
400-
}
401-
}
402-
403-
// FIXME: ComparableTraitRef does not support nested bounds needed for associated_type_bounds
404-
fn into_comparable_trait_ref(trait_ref: &TraitRef<'_>) -> ComparableTraitRef {
405-
ComparableTraitRef(
406-
trait_ref.path.res,
407-
trait_ref
408-
.path
409-
.segments
410-
.iter()
411-
.filter_map(|segment| {
412-
// get trait bound type arguments
413-
Some(segment.args?.args.iter().filter_map(|arg| {
414-
if let GenericArg::Type(ty) = arg {
415-
return get_ty_res(**ty);
416-
}
417-
None
418-
}))
419-
})
420-
.flatten()
421-
.collect(),
422-
)
423-
}
424-
425-
fn rollup_traits(
426-
cx: &LateContext<'_>,
427-
bounds: &[GenericBound<'_>],
414+
fn rollup_traits<'cx, 'tcx>(
415+
cx: &'cx LateContext<'tcx>,
416+
bounds: &'tcx [GenericBound<'tcx>],
428417
msg: &'static str,
429-
) -> Vec<(ComparableTraitRef, Span)> {
430-
let mut map = FxHashMap::default();
418+
) -> Vec<(ComparableTraitRef<'cx, 'tcx>, Span)> {
419+
// Source order is needed for joining spans
420+
let mut map = FxIndexMap::default();
431421
let mut repeated_res = false;
432422

433-
let only_comparable_trait_refs = |bound: &GenericBound<'_>| {
434-
if let GenericBound::Trait(t, _) = bound {
435-
Some((into_comparable_trait_ref(&t.trait_ref), t.span))
423+
let only_comparable_trait_refs = |bound: &'tcx GenericBound<'tcx>| {
424+
if let GenericBound::Trait(t, modifier) = bound {
425+
Some((
426+
ComparableTraitRef {
427+
cx,
428+
trait_ref: &t.trait_ref,
429+
modifier: *modifier,
430+
},
431+
t.span,
432+
))
436433
} else {
437434
None
438435
}
439436
};
440437

441-
let mut i = 0usize;
442438
for bound in bounds.iter().filter_map(only_comparable_trait_refs) {
443439
let (comparable_bound, span_direct) = bound;
444440
match map.entry(comparable_bound) {
445-
Entry::Occupied(_) => repeated_res = true,
446-
Entry::Vacant(e) => {
447-
e.insert((span_direct, i));
448-
i += 1;
441+
IndexEntry::Occupied(_) => repeated_res = true,
442+
IndexEntry::Vacant(e) => {
443+
e.insert(span_direct);
449444
},
450445
}
451446
}
452447

453-
// Put bounds in source order
454-
let mut comparable_bounds = vec![Default::default(); map.len()];
455-
for (k, (v, i)) in map {
456-
comparable_bounds[i] = (k, v);
457-
}
448+
let comparable_bounds: Vec<_> = map.into_iter().collect();
458449

459450
if repeated_res && let [first_trait, .., last_trait] = bounds {
460451
let all_trait_span = first_trait.span().to(last_trait.span());

clippy_utils/src/hir_utils.rs

+106-13
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::tokenize_with_text;
55
use rustc_ast::ast::InlineAsmTemplatePiece;
66
use rustc_data_structures::fx::FxHasher;
77
use rustc_hir::MatchSource::TryDesugar;
8-
use rustc_hir::def::Res;
8+
use rustc_hir::def::{DefKind, Res};
99
use rustc_hir::{
1010
ArrayLen, AssocItemConstraint, BinOpKind, BindingMode, Block, BodyId, Closure, ConstArg, ConstArgKind, Expr,
1111
ExprField, ExprKind, FnRetTy, GenericArg, GenericArgs, HirId, HirIdMap, InlineAsmOperand, LetExpr, Lifetime,
@@ -17,11 +17,33 @@ use rustc_middle::ty::TypeckResults;
1717
use rustc_span::{BytePos, ExpnKind, MacroKind, Symbol, SyntaxContext, sym};
1818
use std::hash::{Hash, Hasher};
1919
use std::ops::Range;
20+
use std::slice;
2021

2122
/// Callback that is called when two expressions are not equal in the sense of `SpanlessEq`, but
2223
/// other conditions would make them equal.
2324
type SpanlessEqCallback<'a> = dyn FnMut(&Expr<'_>, &Expr<'_>) -> bool + 'a;
2425

26+
/// Determines how paths are hashed and compared for equality.
27+
#[derive(Copy, Clone, Debug, Default)]
28+
pub enum PathCheck {
29+
/// Paths must match exactly and are hashed by their exact HIR tree.
30+
///
31+
/// Thus, `std::iter::Iterator` and `Iterator` are not considered equal even though they refer
32+
/// to the same item.
33+
#[default]
34+
Exact,
35+
/// Paths are compared and hashed based on their resolution.
36+
///
37+
/// They can appear different in the HIR tree but are still considered equal
38+
/// and have equal hashes as long as they refer to the same item.
39+
///
40+
/// Note that this is currently only partially implemented specifically for paths that are
41+
/// resolved before type-checking, i.e. the final segment must have a non-error resolution.
42+
/// If a path with an error resolution is encountered, it falls back to the default exact
43+
/// matching behavior.
44+
Resolution,
45+
}
46+
2547
/// Type used to check whether two ast are the same. This is different from the
2648
/// operator `==` on ast types as this operator would compare true equality with
2749
/// ID and span.
@@ -33,6 +55,7 @@ pub struct SpanlessEq<'a, 'tcx> {
3355
maybe_typeck_results: Option<(&'tcx TypeckResults<'tcx>, &'tcx TypeckResults<'tcx>)>,
3456
allow_side_effects: bool,
3557
expr_fallback: Option<Box<SpanlessEqCallback<'a>>>,
58+
path_check: PathCheck,
3659
}
3760

3861
impl<'a, 'tcx> SpanlessEq<'a, 'tcx> {
@@ -42,6 +65,7 @@ impl<'a, 'tcx> SpanlessEq<'a, 'tcx> {
4265
maybe_typeck_results: cx.maybe_typeck_results().map(|x| (x, x)),
4366
allow_side_effects: true,
4467
expr_fallback: None,
68+
path_check: PathCheck::default(),
4569
}
4670
}
4771

@@ -54,6 +78,16 @@ impl<'a, 'tcx> SpanlessEq<'a, 'tcx> {
5478
}
5579
}
5680

81+
/// Check paths by their resolution instead of exact equality. See [`PathCheck`] for more
82+
/// details.
83+
#[must_use]
84+
pub fn paths_by_resolution(self) -> Self {
85+
Self {
86+
path_check: PathCheck::Resolution,
87+
..self
88+
}
89+
}
90+
5791
#[must_use]
5892
pub fn expr_fallback(self, expr_fallback: impl FnMut(&Expr<'_>, &Expr<'_>) -> bool + 'a) -> Self {
5993
Self {
@@ -498,7 +532,7 @@ impl HirEqInterExpr<'_, '_, '_> {
498532
match (left.res, right.res) {
499533
(Res::Local(l), Res::Local(r)) => l == r || self.locals.get(&l) == Some(&r),
500534
(Res::Local(_), _) | (_, Res::Local(_)) => false,
501-
_ => over(left.segments, right.segments, |l, r| self.eq_path_segment(l, r)),
535+
_ => self.eq_path_segments(left.segments, right.segments),
502536
}
503537
}
504538

@@ -511,17 +545,39 @@ impl HirEqInterExpr<'_, '_, '_> {
511545
}
512546
}
513547

514-
pub fn eq_path_segments(&mut self, left: &[PathSegment<'_>], right: &[PathSegment<'_>]) -> bool {
515-
left.len() == right.len() && left.iter().zip(right).all(|(l, r)| self.eq_path_segment(l, r))
548+
pub fn eq_path_segments<'tcx>(
549+
&mut self,
550+
mut left: &'tcx [PathSegment<'tcx>],
551+
mut right: &'tcx [PathSegment<'tcx>],
552+
) -> bool {
553+
if let PathCheck::Resolution = self.inner.path_check
554+
&& let Some(left_seg) = generic_path_segments(left)
555+
&& let Some(right_seg) = generic_path_segments(right)
556+
{
557+
// If we compare by resolution, then only check the last segments that could possibly have generic
558+
// arguments
559+
left = left_seg;
560+
right = right_seg;
561+
}
562+
563+
over(left, right, |l, r| self.eq_path_segment(l, r))
516564
}
517565

518566
pub fn eq_path_segment(&mut self, left: &PathSegment<'_>, right: &PathSegment<'_>) -> bool {
519-
// The == of idents doesn't work with different contexts,
520-
// we have to be explicit about hygiene
521-
left.ident.name == right.ident.name
522-
&& both(left.args.as_ref(), right.args.as_ref(), |l, r| {
523-
self.eq_path_parameters(l, r)
524-
})
567+
if !self.eq_path_parameters(left.args(), right.args()) {
568+
return false;
569+
}
570+
571+
if let PathCheck::Resolution = self.inner.path_check
572+
&& left.res != Res::Err
573+
&& right.res != Res::Err
574+
{
575+
left.res == right.res
576+
} else {
577+
// The == of idents doesn't work with different contexts,
578+
// we have to be explicit about hygiene
579+
left.ident.name == right.ident.name
580+
}
525581
}
526582

527583
pub fn eq_ty(&mut self, left: &Ty<'_>, right: &Ty<'_>) -> bool {
@@ -684,6 +740,21 @@ pub fn eq_expr_value(cx: &LateContext<'_>, left: &Expr<'_>, right: &Expr<'_>) ->
684740
SpanlessEq::new(cx).deny_side_effects().eq_expr(left, right)
685741
}
686742

743+
/// Returns the segments of a path that might have generic parameters.
744+
/// Usually just the last segment for free items, except for when the path resolves to an associated
745+
/// item, in which case it is the last two
746+
fn generic_path_segments<'tcx>(segments: &'tcx [PathSegment<'tcx>]) -> Option<&'tcx [PathSegment<'tcx>]> {
747+
match segments.last()?.res {
748+
Res::Def(DefKind::AssocConst | DefKind::AssocFn | DefKind::AssocTy, _) => {
749+
// <Ty as module::Trait<T>>::assoc::<U>
750+
// ^^^^^^^^^^^^^^^^ ^^^^^^^^^^ segments: [module, Trait<T>, assoc<U>]
751+
Some(&segments[segments.len().checked_sub(2)?..])
752+
},
753+
Res::Err => None,
754+
_ => Some(slice::from_ref(segments.last()?)),
755+
}
756+
}
757+
687758
/// Type used to hash an ast element. This is different from the `Hash` trait
688759
/// on ast types as this
689760
/// trait would consider IDs and spans.
@@ -694,17 +765,29 @@ pub struct SpanlessHash<'a, 'tcx> {
694765
cx: &'a LateContext<'tcx>,
695766
maybe_typeck_results: Option<&'tcx TypeckResults<'tcx>>,
696767
s: FxHasher,
768+
path_check: PathCheck,
697769
}
698770

699771
impl<'a, 'tcx> SpanlessHash<'a, 'tcx> {
700772
pub fn new(cx: &'a LateContext<'tcx>) -> Self {
701773
Self {
702774
cx,
703775
maybe_typeck_results: cx.maybe_typeck_results(),
776+
path_check: PathCheck::default(),
704777
s: FxHasher::default(),
705778
}
706779
}
707780

781+
/// Check paths by their resolution instead of exact equality. See [`PathCheck`] for more
782+
/// details.
783+
#[must_use]
784+
pub fn paths_by_resolution(self) -> Self {
785+
Self {
786+
path_check: PathCheck::Resolution,
787+
..self
788+
}
789+
}
790+
708791
pub fn finish(self) -> u64 {
709792
self.s.finish()
710793
}
@@ -1042,9 +1125,19 @@ impl<'a, 'tcx> SpanlessHash<'a, 'tcx> {
10421125
// even though the binding names are different and they have different `HirId`s.
10431126
Res::Local(_) => 1_usize.hash(&mut self.s),
10441127
_ => {
1045-
for seg in path.segments {
1046-
self.hash_name(seg.ident.name);
1047-
self.hash_generic_args(seg.args().args);
1128+
if let PathCheck::Resolution = self.path_check
1129+
&& let [.., last] = path.segments
1130+
&& let Some(segments) = generic_path_segments(path.segments)
1131+
{
1132+
for seg in segments {
1133+
self.hash_generic_args(seg.args().args);
1134+
}
1135+
last.res.hash(&mut self.s);
1136+
} else {
1137+
for seg in path.segments {
1138+
self.hash_name(seg.ident.name);
1139+
self.hash_generic_args(seg.args().args);
1140+
}
10481141
}
10491142
},
10501143
}

0 commit comments

Comments
 (0)