Skip to content

Commit 8e05ab0

Browse files
committed
Run SROA to fixpoint.
1 parent 42c9514 commit 8e05ab0

File tree

4 files changed

+78
-67
lines changed

4 files changed

+78
-67
lines changed

compiler/rustc_mir_dataflow/src/value_analysis.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,7 @@ pub fn iter_fields<'tcx>(
824824
}
825825

826826
/// Returns all locals with projections that have their reference or address taken.
827-
fn excluded_locals(body: &Body<'_>) -> IndexVec<Local, bool> {
827+
pub fn excluded_locals(body: &Body<'_>) -> IndexVec<Local, bool> {
828828
struct Collector {
829829
result: IndexVec<Local, bool>,
830830
}

compiler/rustc_mir_transform/src/sroa.rs

+33-41
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rustc_middle::mir::patch::MirPatch;
66
use rustc_middle::mir::visit::*;
77
use rustc_middle::mir::*;
88
use rustc_middle::ty::TyCtxt;
9-
use rustc_mir_dataflow::value_analysis::iter_fields;
9+
use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
1010

1111
pub struct ScalarReplacementOfAggregates;
1212

@@ -18,26 +18,38 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
1818
#[instrument(level = "debug", skip(self, tcx, body))]
1919
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
2020
debug!(def_id = ?body.source.def_id());
21-
let escaping = escaping_locals(&*body);
22-
debug!(?escaping);
23-
let replacements = compute_flattening(tcx, body, escaping);
24-
debug!(?replacements);
25-
replace_flattened_locals(tcx, body, replacements);
21+
let mut excluded = excluded_locals(body);
22+
loop {
23+
debug!(?excluded);
24+
let escaping = escaping_locals(&excluded, body);
25+
debug!(?escaping);
26+
let replacements = compute_flattening(tcx, body, escaping);
27+
debug!(?replacements);
28+
let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
29+
if !all_dead_locals.is_empty() && tcx.sess.mir_opt_level() >= 4 {
30+
for local in excluded.indices() {
31+
excluded[local] |= all_dead_locals.contains(local) ;
32+
}
33+
excluded.raw.resize(body.local_decls.len(), false);
34+
} else {
35+
break
36+
}
37+
}
2638
}
2739
}
2840

2941
/// Identify all locals that are not eligible for SROA.
3042
///
3143
/// There are 3 cases:
32-
/// - the aggegated local is used or passed to other code (function parameters and arguments);
44+
/// - the aggregated local is used or passed to other code (function parameters and arguments);
3345
/// - the locals is a union or an enum;
3446
/// - the local's address is taken, and thus the relative addresses of the fields are observable to
3547
/// client code.
36-
fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
48+
fn escaping_locals(excluded: &IndexVec<Local, bool>, body: &Body<'_>) -> BitSet<Local> {
3749
let mut set = BitSet::new_empty(body.local_decls.len());
3850
set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
3951
for (local, decl) in body.local_decls().iter_enumerated() {
40-
if decl.ty.is_union() || decl.ty.is_enum() {
52+
if decl.ty.is_union() || decl.ty.is_enum() || excluded[local] {
4153
set.insert(local);
4254
}
4355
}
@@ -62,17 +74,6 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
6274
self.super_place(place, context, location);
6375
}
6476

65-
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
66-
if let Rvalue::AddressOf(.., place) | Rvalue::Ref(.., place) = rvalue {
67-
if !place.is_indirect() {
68-
// Raw pointers may be used to access anything inside the enclosing place.
69-
self.set.insert(place.local);
70-
return;
71-
}
72-
}
73-
self.super_rvalue(rvalue, location)
74-
}
75-
7677
fn visit_assign(
7778
&mut self,
7879
lvalue: &Place<'tcx>,
@@ -102,21 +103,6 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
102103
}
103104
}
104105

105-
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
106-
// Drop implicitly calls `drop_in_place`, which takes a `&mut`.
107-
// This implies that `Drop` implicitly takes the address of the place.
108-
if let TerminatorKind::Drop { place, .. }
109-
| TerminatorKind::DropAndReplace { place, .. } = terminator.kind
110-
{
111-
if !place.is_indirect() {
112-
// Raw pointers may be used to access anything inside the enclosing place.
113-
self.set.insert(place.local);
114-
return;
115-
}
116-
}
117-
self.super_terminator(terminator, location);
118-
}
119-
120106
// We ignore anything that happens in debuginfo, since we expand it using
121107
// `VarDebugInfoContents::Composite`.
122108
fn visit_var_debug_info(&mut self, _: &VarDebugInfo<'tcx>) {}
@@ -198,14 +184,14 @@ fn replace_flattened_locals<'tcx>(
198184
tcx: TyCtxt<'tcx>,
199185
body: &mut Body<'tcx>,
200186
replacements: ReplacementMap<'tcx>,
201-
) {
187+
) -> BitSet<Local> {
202188
let mut all_dead_locals = BitSet::new_empty(body.local_decls.len());
203189
for p in replacements.fields.keys() {
204190
all_dead_locals.insert(p.local);
205191
}
206192
debug!(?all_dead_locals);
207193
if all_dead_locals.is_empty() {
208-
return;
194+
return all_dead_locals;
209195
}
210196

211197
let mut visitor = ReplacementVisitor {
@@ -227,7 +213,9 @@ fn replace_flattened_locals<'tcx>(
227213
for var_debug_info in &mut body.var_debug_info {
228214
visitor.visit_var_debug_info(var_debug_info);
229215
}
230-
visitor.patch.apply(body);
216+
let ReplacementVisitor { patch, all_dead_locals, .. } = visitor;
217+
patch.apply(body);
218+
all_dead_locals
231219
}
232220

233221
struct ReplacementVisitor<'tcx, 'll> {
@@ -361,6 +349,7 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
361349
}
362350
}
363351

352+
#[instrument(level = "trace", skip(self))]
364353
fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) {
365354
match &mut var_debug_info.value {
366355
VarDebugInfoContents::Place(ref mut place) => {
@@ -375,11 +364,12 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
375364
}
376365
VarDebugInfoContents::Composite { ty: _, ref mut fragments } => {
377366
let mut new_fragments = Vec::new();
367+
debug!(?fragments);
378368
fragments
379369
.drain_filter(|fragment| {
380370
if let Some(repl) = self.replace_place(fragment.contents.as_ref()) {
381371
fragment.contents = repl;
382-
true
372+
false
383373
} else if let Some(frg) = self
384374
.replacements
385375
.gather_debug_info_fragments(fragment.contents.as_ref())
@@ -388,12 +378,14 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
388378
f.projection.splice(0..0, fragment.projection.iter().copied());
389379
f
390380
}));
391-
false
392-
} else {
393381
true
382+
} else {
383+
false
394384
}
395385
})
396386
.for_each(drop);
387+
debug!(?fragments);
388+
debug!(?new_fragments);
397389
fragments.extend(new_fragments);
398390
}
399391
VarDebugInfoContents::Const(_) => {}

tests/mir-opt/const_prop/mutable_variable_aggregate.main.ConstProp.diff

+12-15
Original file line numberDiff line numberDiff line change
@@ -3,30 +3,27 @@
33

44
fn main() -> () {
55
let mut _0: (); // return place in scope 0 at $DIR/mutable_variable_aggregate.rs:+0:11: +0:11
6-
let mut _1: (i32, i32); // in scope 0 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14
6+
let mut _3: i32; // in scope 0 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14
7+
let mut _4: i32; // in scope 0 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14
78
scope 1 {
8-
debug x => _1; // in scope 1 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14
9+
debug x => (i32, i32){ .0 => _3, .1 => _4, }; // in scope 1 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14
10+
let _1: i32; // in scope 1 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10
911
let _2: i32; // in scope 1 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10
10-
let _3: i32; // in scope 1 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10
1112
scope 2 {
12-
debug y => (i32, i32){ .0 => _2, .1 => _3, }; // in scope 2 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10
13+
debug y => (i32, i32){ .0 => _3, .1 => _2, }; // in scope 2 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10
1314
}
1415
}
1516

1617
bb0: {
17-
StorageLive(_1); // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14
18-
- _1 = (const 42_i32, const 43_i32); // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:17: +1:25
19-
+ _1 = const (42_i32, 43_i32); // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:17: +1:25
20-
(_1.1: i32) = const 99_i32; // scope 1 at $DIR/mutable_variable_aggregate.rs:+2:5: +2:13
18+
StorageLive(_4); // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14
19+
_3 = const 42_i32; // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:17: +1:25
20+
_4 = const 43_i32; // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:17: +1:25
21+
_4 = const 99_i32; // scope 1 at $DIR/mutable_variable_aggregate.rs:+2:5: +2:13
2122
StorageLive(_2); // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10
22-
StorageLive(_3); // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10
23-
- _2 = (_1.0: i32); // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14
24-
- _3 = (_1.1: i32); // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14
25-
+ _2 = const 42_i32; // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14
26-
+ _3 = const 99_i32; // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14
23+
- _2 = _4; // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14
24+
+ _2 = const 99_i32; // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14
2725
StorageDead(_2); // scope 1 at $DIR/mutable_variable_aggregate.rs:+4:1: +4:2
28-
StorageDead(_3); // scope 1 at $DIR/mutable_variable_aggregate.rs:+4:1: +4:2
29-
StorageDead(_1); // scope 0 at $DIR/mutable_variable_aggregate.rs:+4:1: +4:2
26+
StorageDead(_4); // scope 0 at $DIR/mutable_variable_aggregate.rs:+4:1: +4:2
3027
return; // scope 0 at $DIR/mutable_variable_aggregate.rs:+4:2: +4:2
3128
}
3229
}

tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff

+32-10
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,13 @@
55
debug x => _1; // in scope 0 at $DIR/sroa.rs:+0:11: +0:12
66
let mut _0: (); // return place in scope 0 at $DIR/sroa.rs:+0:19: +0:19
77
let _2: Foo; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
8+
+ let _11: u8; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
9+
+ let _12: (); // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
10+
+ let _13: &str; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
11+
+ let _14: std::option::Option<isize>; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
812
scope 1 {
9-
debug y => _2; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10
13+
- debug y => _2; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10
14+
+ debug y => Foo{ .0 => _11, .1 => _12, .2 => _13, .3 => _14, }; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10
1015
let _3: u8; // in scope 1 at $DIR/sroa.rs:+2:9: +2:10
1116
scope 2 {
1217
debug t => _3; // in scope 2 at $DIR/sroa.rs:+2:9: +2:10
@@ -31,23 +36,35 @@
3136
}
3237

3338
bb0: {
34-
StorageLive(_2); // scope 0 at $DIR/sroa.rs:+1:9: +1:10
35-
_2 = _1; // scope 0 at $DIR/sroa.rs:+1:13: +1:14
39+
- StorageLive(_2); // scope 0 at $DIR/sroa.rs:+1:9: +1:10
40+
- _2 = _1; // scope 0 at $DIR/sroa.rs:+1:13: +1:14
41+
+ StorageLive(_11); // scope 0 at $DIR/sroa.rs:+1:9: +1:10
42+
+ StorageLive(_12); // scope 0 at $DIR/sroa.rs:+1:9: +1:10
43+
+ StorageLive(_13); // scope 0 at $DIR/sroa.rs:+1:9: +1:10
44+
+ StorageLive(_14); // scope 0 at $DIR/sroa.rs:+1:9: +1:10
45+
+ nop; // scope 0 at $DIR/sroa.rs:+1:9: +1:10
46+
+ _11 = (_1.0: u8); // scope 0 at $DIR/sroa.rs:+1:13: +1:14
47+
+ _12 = (_1.1: ()); // scope 0 at $DIR/sroa.rs:+1:13: +1:14
48+
+ _13 = (_1.2: &str); // scope 0 at $DIR/sroa.rs:+1:13: +1:14
49+
+ _14 = (_1.3: std::option::Option<isize>); // scope 0 at $DIR/sroa.rs:+1:13: +1:14
50+
+ nop; // scope 0 at $DIR/sroa.rs:+1:13: +1:14
3651
StorageLive(_3); // scope 1 at $DIR/sroa.rs:+2:9: +2:10
37-
_3 = (_2.0: u8); // scope 1 at $DIR/sroa.rs:+2:13: +2:16
52+
- _3 = (_2.0: u8); // scope 1 at $DIR/sroa.rs:+2:13: +2:16
53+
+ _3 = _11; // scope 1 at $DIR/sroa.rs:+2:13: +2:16
3854
StorageLive(_4); // scope 2 at $DIR/sroa.rs:+3:9: +3:10
39-
_4 = (_2.2: &str); // scope 2 at $DIR/sroa.rs:+3:13: +3:16
55+
- _4 = (_2.2: &str); // scope 2 at $DIR/sroa.rs:+3:13: +3:16
4056
- StorageLive(_5); // scope 3 at $DIR/sroa.rs:+4:9: +4:10
4157
- _5 = _2; // scope 3 at $DIR/sroa.rs:+4:13: +4:14
58+
+ _4 = _13; // scope 2 at $DIR/sroa.rs:+3:13: +3:16
4259
+ StorageLive(_7); // scope 3 at $DIR/sroa.rs:+4:9: +4:10
4360
+ StorageLive(_8); // scope 3 at $DIR/sroa.rs:+4:9: +4:10
4461
+ StorageLive(_9); // scope 3 at $DIR/sroa.rs:+4:9: +4:10
4562
+ StorageLive(_10); // scope 3 at $DIR/sroa.rs:+4:9: +4:10
4663
+ nop; // scope 3 at $DIR/sroa.rs:+4:9: +4:10
47-
+ _7 = (_2.0: u8); // scope 3 at $DIR/sroa.rs:+4:13: +4:14
48-
+ _8 = (_2.1: ()); // scope 3 at $DIR/sroa.rs:+4:13: +4:14
49-
+ _9 = (_2.2: &str); // scope 3 at $DIR/sroa.rs:+4:13: +4:14
50-
+ _10 = (_2.3: std::option::Option<isize>); // scope 3 at $DIR/sroa.rs:+4:13: +4:14
64+
+ _7 = _11; // scope 3 at $DIR/sroa.rs:+4:13: +4:14
65+
+ _8 = _12; // scope 3 at $DIR/sroa.rs:+4:13: +4:14
66+
+ _9 = _13; // scope 3 at $DIR/sroa.rs:+4:13: +4:14
67+
+ _10 = _14; // scope 3 at $DIR/sroa.rs:+4:13: +4:14
5168
+ nop; // scope 3 at $DIR/sroa.rs:+4:13: +4:14
5269
StorageLive(_6); // scope 4 at $DIR/sroa.rs:+5:9: +5:10
5370
- _6 = (_5.1: ()); // scope 4 at $DIR/sroa.rs:+5:13: +5:16
@@ -62,7 +79,12 @@
6279
+ nop; // scope 3 at $DIR/sroa.rs:+6:1: +6:2
6380
StorageDead(_4); // scope 2 at $DIR/sroa.rs:+6:1: +6:2
6481
StorageDead(_3); // scope 1 at $DIR/sroa.rs:+6:1: +6:2
65-
StorageDead(_2); // scope 0 at $DIR/sroa.rs:+6:1: +6:2
82+
- StorageDead(_2); // scope 0 at $DIR/sroa.rs:+6:1: +6:2
83+
+ StorageDead(_11); // scope 0 at $DIR/sroa.rs:+6:1: +6:2
84+
+ StorageDead(_12); // scope 0 at $DIR/sroa.rs:+6:1: +6:2
85+
+ StorageDead(_13); // scope 0 at $DIR/sroa.rs:+6:1: +6:2
86+
+ StorageDead(_14); // scope 0 at $DIR/sroa.rs:+6:1: +6:2
87+
+ nop; // scope 0 at $DIR/sroa.rs:+6:1: +6:2
6688
return; // scope 0 at $DIR/sroa.rs:+6:2: +6:2
6789
}
6890
}

0 commit comments

Comments
 (0)