Skip to content

Commit dc4fe8e

Browse files
committed
Make SROA expand assignments.
1 parent 0843acb commit dc4fe8e

6 files changed

+188
-32
lines changed

compiler/rustc_mir_transform/src/sroa.rs

+65-23
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,15 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
7878
rvalue: &Rvalue<'tcx>,
7979
location: Location,
8080
) {
81-
if lvalue.as_local().is_some() && let Rvalue::Aggregate(..) = rvalue {
82-
// Aggregate assignments are expanded in run_pass.
83-
self.visit_rvalue(rvalue, location);
84-
return;
81+
if lvalue.as_local().is_some() {
82+
match rvalue {
83+
// Aggregate assignments are expanded in run_pass.
84+
Rvalue::Aggregate(..) | Rvalue::Use(..) => {
85+
self.visit_rvalue(rvalue, location);
86+
return;
87+
}
88+
_ => {}
89+
}
8590
}
8691
self.super_assign(lvalue, rvalue, location)
8792
}
@@ -195,10 +200,9 @@ fn replace_flattened_locals<'tcx>(
195200
return;
196201
}
197202

198-
let mut fragments = IndexVec::new();
203+
let mut fragments = IndexVec::<_, Option<Vec<_>>>::from_elem(None, &body.local_decls);
199204
for (k, v) in &replacements.fields {
200-
fragments.ensure_contains_elem(k.local, || Vec::new());
201-
fragments[k.local].push((k.projection, *v));
205+
fragments[k.local].get_or_insert_default().push((k.projection, *v));
202206
}
203207
debug!(?fragments);
204208

@@ -235,17 +239,17 @@ struct ReplacementVisitor<'tcx, 'll> {
235239
all_dead_locals: BitSet<Local>,
236240
/// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
237241
/// and deinit statement and debuginfo.
238-
fragments: IndexVec<Local, Vec<(&'tcx [PlaceElem<'tcx>], Local)>>,
242+
fragments: IndexVec<Local, Option<Vec<(&'tcx [PlaceElem<'tcx>], Local)>>>,
239243
patch: MirPatch<'tcx>,
240244
}
241245

242246
impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> {
243247
fn gather_debug_info_fragments(
244248
&self,
245249
place: PlaceRef<'tcx>,
246-
) -> Vec<VarDebugInfoFragment<'tcx>> {
250+
) -> Option<Vec<VarDebugInfoFragment<'tcx>>> {
247251
let mut fragments = Vec::new();
248-
let parts = &self.fragments[place.local];
252+
let Some(parts) = &self.fragments[place.local] else { return None };
249253
for (proj, replacement_local) in parts {
250254
if proj.starts_with(place.projection) {
251255
fragments.push(VarDebugInfoFragment {
@@ -254,7 +258,7 @@ impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> {
254258
});
255259
}
256260
}
257-
fragments
261+
Some(fragments)
258262
}
259263

260264
fn replace_place(&self, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
@@ -276,8 +280,7 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
276280
fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
277281
match statement.kind {
278282
StatementKind::StorageLive(l) => {
279-
if self.all_dead_locals.contains(l) {
280-
let final_locals = &self.fragments[l];
283+
if let Some(final_locals) = &self.fragments[l] {
281284
for &(_, fl) in final_locals {
282285
self.patch.add_statement(location, StatementKind::StorageLive(fl));
283286
}
@@ -286,8 +289,7 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
286289
return;
287290
}
288291
StatementKind::StorageDead(l) => {
289-
if self.all_dead_locals.contains(l) {
290-
let final_locals = &self.fragments[l];
292+
if let Some(final_locals) = &self.fragments[l] {
291293
for &(_, fl) in final_locals {
292294
self.patch.add_statement(location, StatementKind::StorageDead(fl));
293295
}
@@ -297,9 +299,8 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
297299
}
298300
StatementKind::Deinit(box ref place) => {
299301
if let Some(local) = place.as_local()
300-
&& self.all_dead_locals.contains(local)
302+
&& let Some(final_locals) = &self.fragments[local]
301303
{
302-
let final_locals = &self.fragments[local];
303304
for &(_, fl) in final_locals {
304305
self.patch.add_statement(
305306
location,
@@ -313,9 +314,8 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
313314

314315
StatementKind::Assign(box (ref place, Rvalue::Aggregate(_, ref operands))) => {
315316
if let Some(local) = place.as_local()
316-
&& self.all_dead_locals.contains(local)
317+
&& let Some(final_locals) = &self.fragments[local]
317318
{
318-
let final_locals = &self.fragments[local];
319319
for &(projection, fl) in final_locals {
320320
let &[PlaceElem::Field(index, _)] = projection else { bug!() };
321321
let index = index.as_usize();
@@ -330,6 +330,48 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
330330
}
331331
}
332332

333+
StatementKind::Assign(box (ref place, Rvalue::Use(Operand::Constant(_)))) => {
334+
if let Some(local) = place.as_local()
335+
&& let Some(final_locals) = &self.fragments[local]
336+
{
337+
for &(projection, fl) in final_locals {
338+
let rvalue = Rvalue::Use(Operand::Move(place.project_deeper(projection, self.tcx)));
339+
self.patch.add_statement(
340+
location,
341+
StatementKind::Assign(Box::new((fl.into(), rvalue))),
342+
);
343+
}
344+
self.all_dead_locals.remove(local);
345+
return;
346+
}
347+
}
348+
349+
StatementKind::Assign(box (ref lhs, Rvalue::Use(ref op))) => {
350+
let (rplace, copy) = match op {
351+
Operand::Copy(rplace) => (rplace, true),
352+
Operand::Move(rplace) => (rplace, false),
353+
Operand::Constant(_) => bug!(),
354+
};
355+
if let Some(local) = lhs.as_local()
356+
&& let Some(final_locals) = &self.fragments[local]
357+
{
358+
for &(projection, fl) in final_locals {
359+
let rplace = rplace.project_deeper(projection, self.tcx);
360+
let rvalue = if copy {
361+
Rvalue::Use(Operand::Copy(rplace))
362+
} else {
363+
Rvalue::Use(Operand::Move(rplace))
364+
};
365+
self.patch.add_statement(
366+
location,
367+
StatementKind::Assign(Box::new((fl.into(), rvalue))),
368+
);
369+
}
370+
statement.make_nop();
371+
return;
372+
}
373+
}
374+
333375
_ => {}
334376
}
335377
self.super_statement(statement, location)
@@ -348,9 +390,8 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
348390
VarDebugInfoContents::Place(ref mut place) => {
349391
if let Some(repl) = self.replace_place(place.as_ref()) {
350392
*place = repl;
351-
} else if self.all_dead_locals.contains(place.local) {
393+
} else if let Some(fragments) = self.gather_debug_info_fragments(place.as_ref()) {
352394
let ty = place.ty(self.local_decls, self.tcx).ty;
353-
let fragments = self.gather_debug_info_fragments(place.as_ref());
354395
var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments };
355396
}
356397
}
@@ -361,8 +402,9 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
361402
if let Some(repl) = self.replace_place(fragment.contents.as_ref()) {
362403
fragment.contents = repl;
363404
true
364-
} else if self.all_dead_locals.contains(fragment.contents.local) {
365-
let frg = self.gather_debug_info_fragments(fragment.contents.as_ref());
405+
} else if let Some(frg) =
406+
self.gather_debug_info_fragments(fragment.contents.as_ref())
407+
{
366408
new_fragments.extend(frg.into_iter().map(|mut f| {
367409
f.projection.splice(0..0, fragment.projection.iter().copied());
368410
f
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
- // MIR for `copies` before ScalarReplacementOfAggregates
2+
+ // MIR for `copies` after ScalarReplacementOfAggregates
3+
4+
fn copies(_1: Foo) -> () {
5+
debug x => _1; // in scope 0 at $DIR/sroa.rs:+0:11: +0:12
6+
let mut _0: (); // return place in scope 0 at $DIR/sroa.rs:+0:19: +0:19
7+
let _2: Foo; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
8+
+ let _5: u8; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
9+
+ let _6: &str; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
10+
scope 1 {
11+
- debug y => _2; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10
12+
+ debug y => Foo{ .0 => _5, .2 => _6, }; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10
13+
let _3: u8; // in scope 1 at $DIR/sroa.rs:+2:9: +2:10
14+
scope 2 {
15+
debug t => _3; // in scope 2 at $DIR/sroa.rs:+2:9: +2:10
16+
let _4: &str; // in scope 2 at $DIR/sroa.rs:+3:9: +3:10
17+
scope 3 {
18+
debug u => _4; // in scope 3 at $DIR/sroa.rs:+3:9: +3:10
19+
}
20+
}
21+
}
22+
23+
bb0: {
24+
- StorageLive(_2); // scope 0 at $DIR/sroa.rs:+1:9: +1:10
25+
- _2 = _1; // scope 0 at $DIR/sroa.rs:+1:13: +1:14
26+
+ StorageLive(_5); // scope 0 at $DIR/sroa.rs:+1:9: +1:10
27+
+ StorageLive(_6); // scope 0 at $DIR/sroa.rs:+1:9: +1:10
28+
+ nop; // scope 0 at $DIR/sroa.rs:+1:9: +1:10
29+
+ _5 = (_1.0: u8); // scope 0 at $DIR/sroa.rs:+1:13: +1:14
30+
+ _6 = (_1.2: &str); // scope 0 at $DIR/sroa.rs:+1:13: +1:14
31+
+ nop; // scope 0 at $DIR/sroa.rs:+1:13: +1:14
32+
StorageLive(_3); // scope 1 at $DIR/sroa.rs:+2:9: +2:10
33+
- _3 = (_2.0: u8); // scope 1 at $DIR/sroa.rs:+2:13: +2:16
34+
+ _3 = _5; // scope 1 at $DIR/sroa.rs:+2:13: +2:16
35+
StorageLive(_4); // scope 2 at $DIR/sroa.rs:+3:9: +3:10
36+
- _4 = (_2.2: &str); // scope 2 at $DIR/sroa.rs:+3:13: +3:16
37+
+ _4 = _6; // scope 2 at $DIR/sroa.rs:+3:13: +3:16
38+
_0 = const (); // scope 0 at $DIR/sroa.rs:+0:19: +4:2
39+
StorageDead(_4); // scope 2 at $DIR/sroa.rs:+4:1: +4:2
40+
StorageDead(_3); // scope 1 at $DIR/sroa.rs:+4:1: +4:2
41+
- StorageDead(_2); // scope 0 at $DIR/sroa.rs:+4:1: +4:2
42+
+ StorageDead(_5); // scope 0 at $DIR/sroa.rs:+4:1: +4:2
43+
+ StorageDead(_6); // scope 0 at $DIR/sroa.rs:+4:1: +4:2
44+
+ nop; // scope 0 at $DIR/sroa.rs:+4:1: +4:2
45+
return; // scope 0 at $DIR/sroa.rs:+4:2: +4:2
46+
}
47+
}
48+

tests/mir-opt/sroa.escaping.ScalarReplacementOfAggregates.diff

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
StorageLive(_5); // scope 0 at $DIR/sroa.rs:+2:34: +2:37
1818
_5 = g() -> bb1; // scope 0 at $DIR/sroa.rs:+2:34: +2:37
1919
// mir::Constant
20-
// + span: $DIR/sroa.rs:78:34: 78:35
20+
// + span: $DIR/sroa.rs:73:34: 73:35
2121
// + literal: Const { ty: fn() -> u32 {g}, val: Value(<ZST>) }
2222
}
2323

@@ -28,7 +28,7 @@
2828
_2 = &raw const (*_3); // scope 0 at $DIR/sroa.rs:+2:7: +2:41
2929
_1 = f(move _2) -> bb2; // scope 0 at $DIR/sroa.rs:+2:5: +2:42
3030
// mir::Constant
31-
// + span: $DIR/sroa.rs:78:5: 78:6
31+
// + span: $DIR/sroa.rs:73:5: 73:6
3232
// + literal: Const { ty: fn(*const u32) {f}, val: Value(<ZST>) }
3333
}
3434

tests/mir-opt/sroa.flat.ScalarReplacementOfAggregates.diff

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
+ _9 = move _6; // scope 0 at $DIR/sroa.rs:+1:30: +1:70
4646
+ _10 = const "a"; // scope 0 at $DIR/sroa.rs:+1:30: +1:70
4747
// mir::Constant
48-
// + span: $DIR/sroa.rs:57:52: 57:55
48+
// + span: $DIR/sroa.rs:53:52: 53:55
4949
// + literal: Const { ty: &str, val: Value(Slice(..)) }
5050
+ _11 = move _7; // scope 0 at $DIR/sroa.rs:+1:30: +1:70
5151
+ nop; // scope 0 at $DIR/sroa.rs:+1:30: +1:70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
- // MIR for `ref_copies` before ScalarReplacementOfAggregates
2+
+ // MIR for `ref_copies` after ScalarReplacementOfAggregates
3+
4+
fn ref_copies(_1: &Foo) -> () {
5+
debug x => _1; // in scope 0 at $DIR/sroa.rs:+0:15: +0:16
6+
let mut _0: (); // return place in scope 0 at $DIR/sroa.rs:+0:24: +0:24
7+
let _2: Foo; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
8+
+ let _5: u8; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
9+
+ let _6: &str; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
10+
scope 1 {
11+
- debug y => _2; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10
12+
+ debug y => Foo{ .0 => _5, .2 => _6, }; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10
13+
let _3: u8; // in scope 1 at $DIR/sroa.rs:+2:9: +2:10
14+
scope 2 {
15+
debug t => _3; // in scope 2 at $DIR/sroa.rs:+2:9: +2:10
16+
let _4: &str; // in scope 2 at $DIR/sroa.rs:+3:9: +3:10
17+
scope 3 {
18+
debug u => _4; // in scope 3 at $DIR/sroa.rs:+3:9: +3:10
19+
}
20+
}
21+
}
22+
23+
bb0: {
24+
- StorageLive(_2); // scope 0 at $DIR/sroa.rs:+1:9: +1:10
25+
- _2 = (*_1); // scope 0 at $DIR/sroa.rs:+1:13: +1:15
26+
+ StorageLive(_5); // scope 0 at $DIR/sroa.rs:+1:9: +1:10
27+
+ StorageLive(_6); // scope 0 at $DIR/sroa.rs:+1:9: +1:10
28+
+ nop; // scope 0 at $DIR/sroa.rs:+1:9: +1:10
29+
+ _5 = ((*_1).0: u8); // scope 0 at $DIR/sroa.rs:+1:13: +1:15
30+
+ _6 = ((*_1).2: &str); // scope 0 at $DIR/sroa.rs:+1:13: +1:15
31+
+ nop; // scope 0 at $DIR/sroa.rs:+1:13: +1:15
32+
StorageLive(_3); // scope 1 at $DIR/sroa.rs:+2:9: +2:10
33+
- _3 = (_2.0: u8); // scope 1 at $DIR/sroa.rs:+2:13: +2:16
34+
+ _3 = _5; // scope 1 at $DIR/sroa.rs:+2:13: +2:16
35+
StorageLive(_4); // scope 2 at $DIR/sroa.rs:+3:9: +3:10
36+
- _4 = (_2.2: &str); // scope 2 at $DIR/sroa.rs:+3:13: +3:16
37+
+ _4 = _6; // scope 2 at $DIR/sroa.rs:+3:13: +3:16
38+
_0 = const (); // scope 0 at $DIR/sroa.rs:+0:24: +4:2
39+
StorageDead(_4); // scope 2 at $DIR/sroa.rs:+4:1: +4:2
40+
StorageDead(_3); // scope 1 at $DIR/sroa.rs:+4:1: +4:2
41+
- StorageDead(_2); // scope 0 at $DIR/sroa.rs:+4:1: +4:2
42+
+ StorageDead(_5); // scope 0 at $DIR/sroa.rs:+4:1: +4:2
43+
+ StorageDead(_6); // scope 0 at $DIR/sroa.rs:+4:1: +4:2
44+
+ nop; // scope 0 at $DIR/sroa.rs:+4:1: +4:2
45+
return; // scope 0 at $DIR/sroa.rs:+4:2: +4:2
46+
}
47+
}
48+

tests/mir-opt/sroa.rs

+24-6
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,14 @@ impl Drop for Tag {
1212
fn drop(&mut self) {}
1313
}
1414

15-
// EMIT_MIR sroa.dropping.ScalarReplacementOfAggregates.diff
1615
pub fn dropping() {
1716
S(Tag(0), Tag(1), Tag(2)).1;
1817
}
1918

20-
// EMIT_MIR sroa.enums.ScalarReplacementOfAggregates.diff
2119
pub fn enums(a: usize) -> usize {
2220
if let Some(a) = Some(a) { a } else { 0 }
2321
}
2422

25-
// EMIT_MIR sroa.structs.ScalarReplacementOfAggregates.diff
2623
pub fn structs(a: f32) -> f32 {
2724
struct U {
2825
_foo: usize,
@@ -32,7 +29,6 @@ pub fn structs(a: f32) -> f32 {
3229
U { _foo: 0, a }.a
3330
}
3431

35-
// EMIT_MIR sroa.unions.ScalarReplacementOfAggregates.diff
3632
pub fn unions(a: f32) -> u32 {
3733
union Repr {
3834
f: f32,
@@ -41,6 +37,7 @@ pub fn unions(a: f32) -> u32 {
4137
unsafe { Repr { f: a }.u }
4238
}
4339

40+
#[derive(Copy, Clone)]
4441
struct Foo {
4542
a: u8,
4643
b: (),
@@ -52,7 +49,6 @@ fn g() -> u32 {
5249
3
5350
}
5451

55-
// EMIT_MIR sroa.flat.ScalarReplacementOfAggregates.diff
5652
pub fn flat() {
5753
let Foo { a, b, c, d } = Foo { a: 5, b: (), c: "a", d: Some(-4) };
5854
let _ = a;
@@ -72,17 +68,39 @@ fn f(a: *const u32) {
7268
println!("{}", unsafe { *a.add(2) });
7369
}
7470

75-
// EMIT_MIR sroa.escaping.ScalarReplacementOfAggregates.diff
7671
pub fn escaping() {
7772
// Verify this struct is not flattened.
7873
f(&Escaping { a: 1, b: 2, c: g() }.a);
7974
}
8075

76+
fn copies(x: Foo) {
77+
let y = x;
78+
let t = y.a;
79+
let u = y.c;
80+
}
81+
82+
fn ref_copies(x: &Foo) {
83+
let y = *x;
84+
let t = y.a;
85+
let u = y.c;
86+
}
87+
8188
fn main() {
8289
dropping();
8390
enums(5);
8491
structs(5.);
8592
unions(5.);
8693
flat();
8794
escaping();
95+
copies(Foo { a: 5, b: (), c: "a", d: Some(-4) });
96+
ref_copies(&Foo { a: 5, b: (), c: "a", d: Some(-4) });
8897
}
98+
99+
// EMIT_MIR sroa.dropping.ScalarReplacementOfAggregates.diff
100+
// EMIT_MIR sroa.enums.ScalarReplacementOfAggregates.diff
101+
// EMIT_MIR sroa.structs.ScalarReplacementOfAggregates.diff
102+
// EMIT_MIR sroa.unions.ScalarReplacementOfAggregates.diff
103+
// EMIT_MIR sroa.flat.ScalarReplacementOfAggregates.diff
104+
// EMIT_MIR sroa.escaping.ScalarReplacementOfAggregates.diff
105+
// EMIT_MIR sroa.copies.ScalarReplacementOfAggregates.diff
106+
// EMIT_MIR sroa.ref_copies.ScalarReplacementOfAggregates.diff

0 commit comments

Comments
 (0)