@@ -6,7 +6,7 @@ use rustc_middle::mir::patch::MirPatch;
6
6
use rustc_middle:: mir:: visit:: * ;
7
7
use rustc_middle:: mir:: * ;
8
8
use rustc_middle:: ty:: TyCtxt ;
9
- use rustc_mir_dataflow:: value_analysis:: iter_fields;
9
+ use rustc_mir_dataflow:: value_analysis:: { excluded_locals , iter_fields} ;
10
10
11
11
pub struct ScalarReplacementOfAggregates ;
12
12
@@ -18,26 +18,38 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
18
18
#[ instrument( level = "debug" , skip( self , tcx, body) ) ]
19
19
fn run_pass ( & self , tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
20
20
debug ! ( def_id = ?body. source. def_id( ) ) ;
21
- let escaping = escaping_locals ( & * body) ;
22
- debug ! ( ?escaping) ;
23
- let replacements = compute_flattening ( tcx, body, escaping) ;
24
- debug ! ( ?replacements) ;
25
- replace_flattened_locals ( tcx, body, replacements) ;
21
+ let mut excluded = excluded_locals ( body) ;
22
+ loop {
23
+ debug ! ( ?excluded) ;
24
+ let escaping = escaping_locals ( & excluded, body) ;
25
+ debug ! ( ?escaping) ;
26
+ let replacements = compute_flattening ( tcx, body, escaping) ;
27
+ debug ! ( ?replacements) ;
28
+ let all_dead_locals = replace_flattened_locals ( tcx, body, replacements) ;
29
+ if !all_dead_locals. is_empty ( ) && tcx. sess . mir_opt_level ( ) >= 4 {
30
+ for local in excluded. indices ( ) {
31
+ excluded[ local] |= all_dead_locals. contains ( local) ;
32
+ }
33
+ excluded. raw . resize ( body. local_decls . len ( ) , false ) ;
34
+ } else {
35
+ break
36
+ }
37
+ }
26
38
}
27
39
}
28
40
29
41
/// Identify all locals that are not eligible for SROA.
30
42
///
31
43
/// There are 3 cases:
32
- /// - the aggegated local is used or passed to other code (function parameters and arguments);
44
+ /// - the aggregated local is used or passed to other code (function parameters and arguments);
33
45
/// - the locals is a union or an enum;
34
46
/// - the local's address is taken, and thus the relative addresses of the fields are observable to
35
47
/// client code.
36
- fn escaping_locals ( body : & Body < ' _ > ) -> BitSet < Local > {
48
+ fn escaping_locals ( excluded : & IndexVec < Local , bool > , body : & Body < ' _ > ) -> BitSet < Local > {
37
49
let mut set = BitSet :: new_empty ( body. local_decls . len ( ) ) ;
38
50
set. insert_range ( RETURN_PLACE ..=Local :: from_usize ( body. arg_count ) ) ;
39
51
for ( local, decl) in body. local_decls ( ) . iter_enumerated ( ) {
40
- if decl. ty . is_union ( ) || decl. ty . is_enum ( ) {
52
+ if decl. ty . is_union ( ) || decl. ty . is_enum ( ) || excluded [ local ] {
41
53
set. insert ( local) ;
42
54
}
43
55
}
@@ -62,17 +74,6 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
62
74
self . super_place ( place, context, location) ;
63
75
}
64
76
65
- fn visit_rvalue ( & mut self , rvalue : & Rvalue < ' tcx > , location : Location ) {
66
- if let Rvalue :: AddressOf ( .., place) | Rvalue :: Ref ( .., place) = rvalue {
67
- if !place. is_indirect ( ) {
68
- // Raw pointers may be used to access anything inside the enclosing place.
69
- self . set . insert ( place. local ) ;
70
- return ;
71
- }
72
- }
73
- self . super_rvalue ( rvalue, location)
74
- }
75
-
76
77
fn visit_assign (
77
78
& mut self ,
78
79
lvalue : & Place < ' tcx > ,
@@ -102,21 +103,6 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
102
103
}
103
104
}
104
105
105
- fn visit_terminator ( & mut self , terminator : & Terminator < ' tcx > , location : Location ) {
106
- // Drop implicitly calls `drop_in_place`, which takes a `&mut`.
107
- // This implies that `Drop` implicitly takes the address of the place.
108
- if let TerminatorKind :: Drop { place, .. }
109
- | TerminatorKind :: DropAndReplace { place, .. } = terminator. kind
110
- {
111
- if !place. is_indirect ( ) {
112
- // Raw pointers may be used to access anything inside the enclosing place.
113
- self . set . insert ( place. local ) ;
114
- return ;
115
- }
116
- }
117
- self . super_terminator ( terminator, location) ;
118
- }
119
-
120
106
// We ignore anything that happens in debuginfo, since we expand it using
121
107
// `VarDebugInfoContents::Composite`.
122
108
fn visit_var_debug_info ( & mut self , _: & VarDebugInfo < ' tcx > ) { }
@@ -198,14 +184,14 @@ fn replace_flattened_locals<'tcx>(
198
184
tcx : TyCtxt < ' tcx > ,
199
185
body : & mut Body < ' tcx > ,
200
186
replacements : ReplacementMap < ' tcx > ,
201
- ) {
187
+ ) -> BitSet < Local > {
202
188
let mut all_dead_locals = BitSet :: new_empty ( body. local_decls . len ( ) ) ;
203
189
for p in replacements. fields . keys ( ) {
204
190
all_dead_locals. insert ( p. local ) ;
205
191
}
206
192
debug ! ( ?all_dead_locals) ;
207
193
if all_dead_locals. is_empty ( ) {
208
- return ;
194
+ return all_dead_locals ;
209
195
}
210
196
211
197
let mut visitor = ReplacementVisitor {
@@ -227,7 +213,9 @@ fn replace_flattened_locals<'tcx>(
227
213
for var_debug_info in & mut body. var_debug_info {
228
214
visitor. visit_var_debug_info ( var_debug_info) ;
229
215
}
230
- visitor. patch . apply ( body) ;
216
+ let ReplacementVisitor { patch, all_dead_locals, .. } = visitor;
217
+ patch. apply ( body) ;
218
+ all_dead_locals
231
219
}
232
220
233
221
struct ReplacementVisitor < ' tcx , ' ll > {
@@ -361,6 +349,7 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
361
349
}
362
350
}
363
351
352
+ #[ instrument( level = "trace" , skip( self ) ) ]
364
353
fn visit_var_debug_info ( & mut self , var_debug_info : & mut VarDebugInfo < ' tcx > ) {
365
354
match & mut var_debug_info. value {
366
355
VarDebugInfoContents :: Place ( ref mut place) => {
@@ -375,11 +364,12 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
375
364
}
376
365
VarDebugInfoContents :: Composite { ty : _, ref mut fragments } => {
377
366
let mut new_fragments = Vec :: new ( ) ;
367
+ debug ! ( ?fragments) ;
378
368
fragments
379
369
. drain_filter ( |fragment| {
380
370
if let Some ( repl) = self . replace_place ( fragment. contents . as_ref ( ) ) {
381
371
fragment. contents = repl;
382
- true
372
+ false
383
373
} else if let Some ( frg) = self
384
374
. replacements
385
375
. gather_debug_info_fragments ( fragment. contents . as_ref ( ) )
@@ -388,12 +378,14 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
388
378
f. projection . splice ( 0 ..0 , fragment. projection . iter ( ) . copied ( ) ) ;
389
379
f
390
380
} ) ) ;
391
- false
392
- } else {
393
381
true
382
+ } else {
383
+ false
394
384
}
395
385
} )
396
386
. for_each ( drop) ;
387
+ debug ! ( ?fragments) ;
388
+ debug ! ( ?new_fragments) ;
397
389
fragments. extend ( new_fragments) ;
398
390
}
399
391
VarDebugInfoContents :: Const ( _) => { }
0 commit comments