Skip to content

Commit 7b74759

Browse files
committed
When HIR auto-refs a comparison operator, clean it up by dereffing in MIR
Today, if you're comparing `&&T`s, it ends up auto-reffing in HIR. So the MIR ends up calling `PartialEq/Cmp` with `&&&T`, and the MIR inliner can only get that down to `&T`: <https://rust.godbolt.org/z/hje6jd4Yf>. So this adds an always-run pass to look at `Call`s in MIR with `from_hir_call: false` to just call the correct `Partial{Eq,Cmp}` implementation directly, even if it's debug and we're not running the inliner, to avoid needing to ever monomorphize a bunch of useless forwarding impls. This hopes to avoid ever needing something like rust-lang#108372 where we'd tell people to manually dereference the sides of their comparisons.
1 parent e4b9f86 commit 7b74759

11 files changed

+910
-213
lines changed

compiler/rustc_middle/src/mir/mod.rs

+7-3
Original file line numberDiff line numberDiff line change
@@ -1931,9 +1931,13 @@ impl<'tcx> Operand<'tcx> {
19311931
///
19321932
/// While this is unlikely in general, it's the normal case of what you'll
19331933
/// find as the `func` in a [`TerminatorKind::Call`].
1934-
pub fn const_fn_def(&self) -> Option<(DefId, SubstsRef<'tcx>)> {
1935-
let const_ty = self.constant()?.literal.ty();
1936-
if let ty::FnDef(def_id, substs) = *const_ty.kind() { Some((def_id, substs)) } else { None }
1934+
pub fn const_fn_def(&self) -> Option<(DefId, SubstsRef<'tcx>, Span)> {
1935+
let constant = self.constant()?;
1936+
if let ty::FnDef(def_id, substs) = *constant.literal.ty().kind() {
1937+
Some((def_id, substs, constant.span))
1938+
} else {
1939+
None
1940+
}
19371941
}
19381942
}
19391943

compiler/rustc_mir_transform/src/instcombine.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ impl<'tcx> InstCombineContext<'tcx, '_> {
168168
else { return };
169169

170170
// Only bother looking more if it's easy to know what we're calling
171-
let Some((fn_def_id, fn_substs)) = func.const_fn_def()
171+
let Some((fn_def_id, fn_substs, _span)) = func.const_fn_def()
172172
else { return };
173173

174174
// Clone needs one subst, so we can cheaply rule out other stuff

compiler/rustc_mir_transform/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ mod ssa;
9494
pub mod simplify;
9595
mod simplify_branches;
9696
mod simplify_comparison_integral;
97+
mod simplify_ref_comparisons;
9798
mod sroa;
9899
mod uninhabited_enum_branching;
99100
mod unreachable_prop;
@@ -497,6 +498,8 @@ fn run_analysis_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
497498
&cleanup_post_borrowck::CleanupPostBorrowck,
498499
&remove_noop_landing_pads::RemoveNoopLandingPads,
499500
&simplify::SimplifyCfg::new("early-opt"),
501+
// Adds more `Deref`s, so needs to be before `Derefer`.
502+
&simplify_ref_comparisons::SimplifyRefComparisons,
500503
&deref_separator::Derefer,
501504
];
502505

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
use crate::MirPass;
2+
use rustc_middle::mir::*;
3+
use rustc_middle::ty::{self, Ty, TyCtxt};
4+
5+
/// This pass replaces `x OP y` with `*x OP *y` when `OP` is a comparison operator.
6+
///
7+
/// The goal is to make is so that it's never better for the user to write
8+
/// `***x == ***y` than to write the obvious `x == y` (when `x` and `y` are
9+
/// references and thus those do the same thing). This is particularly
10+
/// important because the type-checker will auto-ref any comparison that's not
11+
/// done directly on a primitive. That means that `a_ref == b_ref` doesn't
12+
/// become `PartialEq::eq(a_ref, b_ref)`, even though that would work, but rather
13+
/// ```no_run
14+
/// # fn foo(a_ref: &i32, b_ref: &i32) -> bool {
15+
/// let temp1 = &a_ref;
16+
/// let temp2 = &b_ref;
17+
/// PartialEq::eq(temp1, temp2)
18+
/// # }
19+
/// ```
20+
/// Thus this pass means it directly calls the *interesting* `impl` directly,
21+
/// rather than needing to monomorphize and/or inline it later. (And when this
22+
/// comment was written in March 2023, the MIR inliner seemed to only inline
23+
/// one level of `==`, so if the comparison is on something like `&&i32` the
24+
/// extra forwarding impls needed to be monomorphized even in an optimized build.)
25+
///
26+
/// Make sure this runs before the `Derefer`, since it might add multiple levels
27+
/// of dereferences in the `Operand`s that are arguments to the `Call`.
28+
pub struct SimplifyRefComparisons;
29+
30+
impl<'tcx> MirPass<'tcx> for SimplifyRefComparisons {
31+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
32+
// Despite the method name, this is `PartialEq`, not `Eq`.
33+
let Some(partial_eq) = tcx.lang_items().eq_trait() else { return };
34+
let Some(partial_ord) = tcx.lang_items().partial_ord_trait() else { return };
35+
36+
for block in body.basic_blocks.as_mut() {
37+
let terminator = block.terminator.as_mut().unwrap();
38+
let TerminatorKind::Call { func, args, from_hir_call: false, .. } =
39+
&mut terminator.kind
40+
else { continue };
41+
42+
// Quickly skip unary operators
43+
if args.len() != 2 {
44+
continue;
45+
}
46+
let (Some(left_place), Some(right_place)) = (args[0].place(), args[1].place())
47+
else { continue };
48+
49+
let (fn_def, fn_substs, fn_span) =
50+
func.const_fn_def().expect("HIR operators to always call the traits directly");
51+
let substs =
52+
fn_substs.try_as_type_list().expect("HIR operators only have type parameters");
53+
let [left_ty, right_ty] = *substs.as_slice() else { continue };
54+
let (depth, new_left_ty, new_right_ty) = find_ref_depth(left_ty, right_ty);
55+
if depth == 0 {
56+
// Already dereffed as far as possible.
57+
continue;
58+
}
59+
60+
// Check it's a comparison, not `+`/`&`/etc.
61+
let trait_def = tcx.trait_of_item(fn_def);
62+
if trait_def != Some(partial_eq) && trait_def != Some(partial_ord) {
63+
continue;
64+
}
65+
66+
let derefs = vec![ProjectionElem::Deref; depth];
67+
let new_substs = [new_left_ty.into(), new_right_ty.into()];
68+
69+
*func = Operand::function_handle(tcx, fn_def, new_substs, fn_span);
70+
args[0] = Operand::Copy(left_place.project_deeper(&derefs, tcx));
71+
args[1] = Operand::Copy(right_place.project_deeper(&derefs, tcx));
72+
}
73+
}
74+
}
75+
76+
fn find_ref_depth<'tcx>(mut left: Ty<'tcx>, mut right: Ty<'tcx>) -> (usize, Ty<'tcx>, Ty<'tcx>) {
77+
let mut depth = 0;
78+
while let (ty::Ref(_, new_left, Mutability::Not), ty::Ref(_, new_right, Mutability::Not)) =
79+
(left.kind(), right.kind())
80+
{
81+
depth += 1;
82+
(left, right) = (*new_left, *new_right);
83+
}
84+
85+
(depth, left, right)
86+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
- // MIR for `multi_ref_prim` before SimplifyRefComparisons
2+
+ // MIR for `multi_ref_prim` after SimplifyRefComparisons
3+
4+
fn multi_ref_prim(_1: &&&i32, _2: &&&i32) -> () {
5+
debug x => _1; // in scope 0 at $DIR/simplify_cmp.rs:+0:23: +0:24
6+
debug y => _2; // in scope 0 at $DIR/simplify_cmp.rs:+0:34: +0:35
7+
let mut _0: (); // return place in scope 0 at $DIR/simplify_cmp.rs:+0:45: +0:45
8+
let _3: bool; // in scope 0 at $DIR/simplify_cmp.rs:+1:9: +1:11
9+
let mut _4: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:15
10+
let mut _5: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20
11+
let mut _7: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+2:14: +2:15
12+
let mut _8: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+2:19: +2:20
13+
let mut _10: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+3:14: +3:15
14+
let mut _11: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+3:18: +3:19
15+
let _12: &&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+3:18: +3:19
16+
let mut _14: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+4:14: +4:15
17+
let mut _15: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+4:19: +4:20
18+
let _16: &&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+4:19: +4:20
19+
let mut _18: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+5:14: +5:15
20+
let mut _19: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+5:18: +5:19
21+
let _20: &&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+5:18: +5:19
22+
let mut _22: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+6:14: +6:15
23+
let mut _23: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+6:19: +6:20
24+
let _24: &&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+6:19: +6:20
25+
scope 1 {
26+
debug _a => _3; // in scope 1 at $DIR/simplify_cmp.rs:+1:9: +1:11
27+
let _6: bool; // in scope 1 at $DIR/simplify_cmp.rs:+2:9: +2:11
28+
scope 2 {
29+
debug _b => _6; // in scope 2 at $DIR/simplify_cmp.rs:+2:9: +2:11
30+
let _9: bool; // in scope 2 at $DIR/simplify_cmp.rs:+3:9: +3:11
31+
scope 3 {
32+
debug _c => _9; // in scope 3 at $DIR/simplify_cmp.rs:+3:9: +3:11
33+
let _13: bool; // in scope 3 at $DIR/simplify_cmp.rs:+4:9: +4:11
34+
scope 4 {
35+
debug _d => _13; // in scope 4 at $DIR/simplify_cmp.rs:+4:9: +4:11
36+
let _17: bool; // in scope 4 at $DIR/simplify_cmp.rs:+5:9: +5:11
37+
scope 5 {
38+
debug _e => _17; // in scope 5 at $DIR/simplify_cmp.rs:+5:9: +5:11
39+
let _21: bool; // in scope 5 at $DIR/simplify_cmp.rs:+6:9: +6:11
40+
scope 6 {
41+
debug _f => _21; // in scope 6 at $DIR/simplify_cmp.rs:+6:9: +6:11
42+
}
43+
}
44+
}
45+
}
46+
}
47+
}
48+
49+
bb0: {
50+
StorageLive(_3); // scope 0 at $DIR/simplify_cmp.rs:+1:9: +1:11
51+
StorageLive(_4); // scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:15
52+
_4 = &_1; // scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:15
53+
StorageLive(_5); // scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20
54+
_5 = &_2; // scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20
55+
- _3 = <&&&i32 as PartialEq>::eq(move _4, move _5) -> bb1; // scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:20
56+
+ _3 = <i32 as PartialEq>::eq((*(*(*_4))), (*(*(*_5)))) -> bb1; // scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:20
57+
// mir::Constant
58+
// + span: $DIR/simplify_cmp.rs:18:14: 18:20
59+
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialEq>::eq}, val: Value(<ZST>) }
60+
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialEq>::eq}, val: Value(<ZST>) }
61+
}
62+
63+
bb1: {
64+
StorageDead(_5); // scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20
65+
StorageDead(_4); // scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20
66+
StorageLive(_6); // scope 1 at $DIR/simplify_cmp.rs:+2:9: +2:11
67+
StorageLive(_7); // scope 1 at $DIR/simplify_cmp.rs:+2:14: +2:15
68+
_7 = &_1; // scope 1 at $DIR/simplify_cmp.rs:+2:14: +2:15
69+
StorageLive(_8); // scope 1 at $DIR/simplify_cmp.rs:+2:19: +2:20
70+
_8 = &_2; // scope 1 at $DIR/simplify_cmp.rs:+2:19: +2:20
71+
- _6 = <&&&i32 as PartialEq>::ne(move _7, move _8) -> bb2; // scope 1 at $DIR/simplify_cmp.rs:+2:14: +2:20
72+
+ _6 = <i32 as PartialEq>::ne((*(*(*_7))), (*(*(*_8)))) -> bb2; // scope 1 at $DIR/simplify_cmp.rs:+2:14: +2:20
73+
// mir::Constant
74+
// + span: $DIR/simplify_cmp.rs:19:14: 19:20
75+
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialEq>::ne}, val: Value(<ZST>) }
76+
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialEq>::ne}, val: Value(<ZST>) }
77+
}
78+
79+
bb2: {
80+
StorageDead(_8); // scope 1 at $DIR/simplify_cmp.rs:+2:19: +2:20
81+
StorageDead(_7); // scope 1 at $DIR/simplify_cmp.rs:+2:19: +2:20
82+
StorageLive(_9); // scope 2 at $DIR/simplify_cmp.rs:+3:9: +3:11
83+
StorageLive(_10); // scope 2 at $DIR/simplify_cmp.rs:+3:14: +3:15
84+
_10 = &_1; // scope 2 at $DIR/simplify_cmp.rs:+3:14: +3:15
85+
StorageLive(_11); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
86+
StorageLive(_12); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
87+
_12 = &(*_2); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
88+
_11 = &_12; // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
89+
- _9 = <&&&i32 as PartialOrd>::lt(move _10, move _11) -> bb3; // scope 2 at $DIR/simplify_cmp.rs:+3:14: +3:19
90+
+ _9 = <i32 as PartialOrd>::lt((*(*(*_10))), (*(*(*_11)))) -> bb3; // scope 2 at $DIR/simplify_cmp.rs:+3:14: +3:19
91+
// mir::Constant
92+
// + span: $DIR/simplify_cmp.rs:20:14: 20:19
93+
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialOrd>::lt}, val: Value(<ZST>) }
94+
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialOrd>::lt}, val: Value(<ZST>) }
95+
}
96+
97+
bb3: {
98+
StorageDead(_11); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
99+
StorageDead(_10); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
100+
StorageDead(_12); // scope 2 at $DIR/simplify_cmp.rs:+3:19: +3:20
101+
StorageLive(_13); // scope 3 at $DIR/simplify_cmp.rs:+4:9: +4:11
102+
StorageLive(_14); // scope 3 at $DIR/simplify_cmp.rs:+4:14: +4:15
103+
_14 = &_1; // scope 3 at $DIR/simplify_cmp.rs:+4:14: +4:15
104+
StorageLive(_15); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
105+
StorageLive(_16); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
106+
_16 = &(*_2); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
107+
_15 = &_16; // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
108+
- _13 = <&&&i32 as PartialOrd>::le(move _14, move _15) -> bb4; // scope 3 at $DIR/simplify_cmp.rs:+4:14: +4:20
109+
+ _13 = <i32 as PartialOrd>::le((*(*(*_14))), (*(*(*_15)))) -> bb4; // scope 3 at $DIR/simplify_cmp.rs:+4:14: +4:20
110+
// mir::Constant
111+
// + span: $DIR/simplify_cmp.rs:21:14: 21:20
112+
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialOrd>::le}, val: Value(<ZST>) }
113+
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialOrd>::le}, val: Value(<ZST>) }
114+
}
115+
116+
bb4: {
117+
StorageDead(_15); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
118+
StorageDead(_14); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
119+
StorageDead(_16); // scope 3 at $DIR/simplify_cmp.rs:+4:20: +4:21
120+
StorageLive(_17); // scope 4 at $DIR/simplify_cmp.rs:+5:9: +5:11
121+
StorageLive(_18); // scope 4 at $DIR/simplify_cmp.rs:+5:14: +5:15
122+
_18 = &_1; // scope 4 at $DIR/simplify_cmp.rs:+5:14: +5:15
123+
StorageLive(_19); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
124+
StorageLive(_20); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
125+
_20 = &(*_2); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
126+
_19 = &_20; // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
127+
- _17 = <&&&i32 as PartialOrd>::gt(move _18, move _19) -> bb5; // scope 4 at $DIR/simplify_cmp.rs:+5:14: +5:19
128+
+ _17 = <i32 as PartialOrd>::gt((*(*(*_18))), (*(*(*_19)))) -> bb5; // scope 4 at $DIR/simplify_cmp.rs:+5:14: +5:19
129+
// mir::Constant
130+
// + span: $DIR/simplify_cmp.rs:22:14: 22:19
131+
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialOrd>::gt}, val: Value(<ZST>) }
132+
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialOrd>::gt}, val: Value(<ZST>) }
133+
}
134+
135+
bb5: {
136+
StorageDead(_19); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
137+
StorageDead(_18); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
138+
StorageDead(_20); // scope 4 at $DIR/simplify_cmp.rs:+5:19: +5:20
139+
StorageLive(_21); // scope 5 at $DIR/simplify_cmp.rs:+6:9: +6:11
140+
StorageLive(_22); // scope 5 at $DIR/simplify_cmp.rs:+6:14: +6:15
141+
_22 = &_1; // scope 5 at $DIR/simplify_cmp.rs:+6:14: +6:15
142+
StorageLive(_23); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
143+
StorageLive(_24); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
144+
_24 = &(*_2); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
145+
_23 = &_24; // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
146+
- _21 = <&&&i32 as PartialOrd>::ge(move _22, move _23) -> bb6; // scope 5 at $DIR/simplify_cmp.rs:+6:14: +6:20
147+
+ _21 = <i32 as PartialOrd>::ge((*(*(*_22))), (*(*(*_23)))) -> bb6; // scope 5 at $DIR/simplify_cmp.rs:+6:14: +6:20
148+
// mir::Constant
149+
// + span: $DIR/simplify_cmp.rs:23:14: 23:20
150+
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialOrd>::ge}, val: Value(<ZST>) }
151+
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialOrd>::ge}, val: Value(<ZST>) }
152+
}
153+
154+
bb6: {
155+
StorageDead(_23); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
156+
StorageDead(_22); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
157+
StorageDead(_24); // scope 5 at $DIR/simplify_cmp.rs:+6:20: +6:21
158+
_0 = const (); // scope 0 at $DIR/simplify_cmp.rs:+0:45: +7:2
159+
StorageDead(_21); // scope 5 at $DIR/simplify_cmp.rs:+7:1: +7:2
160+
StorageDead(_17); // scope 4 at $DIR/simplify_cmp.rs:+7:1: +7:2
161+
StorageDead(_13); // scope 3 at $DIR/simplify_cmp.rs:+7:1: +7:2
162+
StorageDead(_9); // scope 2 at $DIR/simplify_cmp.rs:+7:1: +7:2
163+
StorageDead(_6); // scope 1 at $DIR/simplify_cmp.rs:+7:1: +7:2
164+
StorageDead(_3); // scope 0 at $DIR/simplify_cmp.rs:+7:1: +7:2
165+
return; // scope 0 at $DIR/simplify_cmp.rs:+7:2: +7:2
166+
}
167+
}
168+

0 commit comments

Comments
 (0)