@@ -14,6 +14,7 @@ use super::{apply_rewrite_rules, id};
14
14
use rspirv:: dr:: { Block , Function , Instruction , ModuleHeader , Operand } ;
15
15
use rspirv:: spirv:: { Op , Word } ;
16
16
use rustc_data_structures:: fx:: { FxHashMap , FxHashSet } ;
17
+ use rustc_middle:: bug;
17
18
use std:: collections:: hash_map;
18
19
19
20
pub fn mem2reg (
@@ -27,14 +28,21 @@ pub fn mem2reg(
27
28
let preds = compute_preds ( & func. blocks , & reachable) ;
28
29
let idom = compute_idom ( & preds, & reachable) ;
29
30
let dominance_frontier = compute_dominance_frontier ( & preds, & idom) ;
30
- while insert_phis_all (
31
- header,
32
- types_global_values,
33
- pointer_to_pointee,
34
- constants,
35
- & mut func. blocks ,
36
- & dominance_frontier,
37
- ) { }
31
+ loop {
32
+ let changed = insert_phis_all (
33
+ header,
34
+ types_global_values,
35
+ pointer_to_pointee,
36
+ constants,
37
+ & mut func. blocks ,
38
+ & dominance_frontier,
39
+ ) ;
40
+ if !changed {
41
+ break ;
42
+ }
43
+ // mem2reg produces minimal SSA form, not pruned, so DCE the dead ones
44
+ super :: dce:: dce_phi ( func) ;
45
+ }
38
46
}
39
47
40
48
fn label_to_index ( blocks : & [ Block ] , id : Word ) -> usize {
@@ -171,6 +179,9 @@ fn insert_phis_all(
171
179
if var_maps_and_types. is_empty ( ) {
172
180
return false ;
173
181
}
182
+ for ( var_map, _) in & var_maps_and_types {
183
+ split_copy_memory ( header, blocks, var_map) ;
184
+ }
174
185
for & ( ref var_map, base_var_type) in & var_maps_and_types {
175
186
let blocks_with_phi = insert_phis ( blocks, dominance_frontier, var_map) ;
176
187
let mut renamer = Renamer {
@@ -244,7 +255,10 @@ fn collect_access_chains(
244
255
match inst. class . opcode {
245
256
// Only allow store if pointer is the lhs, not rhs
246
257
Op :: Store if index == 0 => { }
247
- Op :: Load | Op :: AccessChain | Op :: InBoundsAccessChain => { }
258
+ Op :: Load
259
+ | Op :: AccessChain
260
+ | Op :: InBoundsAccessChain
261
+ | Op :: CopyMemory => { }
248
262
_ => return None ,
249
263
}
250
264
}
@@ -271,6 +285,76 @@ fn collect_access_chains(
271
285
Some ( variables)
272
286
}
273
287
288
+ // Splits an OpCopyMemory into an OpLoad followed by an OpStore. This is because we want to be able
289
+ // to mem2reg variables used in OpCopyMemory, but analysis becomes very difficult: we only analyze
290
+ // one variable at a time, but OpCopyMemory can copy between two local variables (both of which are
291
+ // getting mem2reg'd), requiring cross-analysis shenanigans. So, if we know at least one side of
292
+ // the OpCopyMemory is getting mem2reg'd, we can safely split it into a load/store pair: at least
293
+ // one side of the pair is going to evaporate in the subsequent rewrite. Then, we can only deal
294
+ // with one side of a pair at a time, treating the other side as opaque (but possibly rewriting
295
+ // both sides).
296
+ //
297
+ // This means that an OpCopyMemory between two local variables will completely disappear, while an
298
+ // OpCopyMemory from a global to a local will turn into an OpLoad, and local to global will turn
299
+ // into an OpStore.
300
+ //
301
+ // Note that while we only look at a single var map in this function, if an OpCopyMemory contains
302
+ // variables from two var maps, the second pass won't do anything since the first pass will already
303
+ // have split it (but that's fine, it would have done the same thing anyway).
304
+ //
305
+ // Finally, an edge case to keep in mind is that an OpCopyMemory can happen between two vars in the
306
+ // same var map (e.g. `s.x = s.y;`).
307
+ fn split_copy_memory (
308
+ header : & mut ModuleHeader ,
309
+ blocks : & mut [ Block ] ,
310
+ var_map : & FxHashMap < Word , VarInfo > ,
311
+ ) {
312
+ for block in blocks {
313
+ let mut inst_index = 0 ;
314
+ while inst_index < block. instructions . len ( ) {
315
+ let inst = & block. instructions [ inst_index] ;
316
+ if inst. class . opcode == Op :: CopyMemory {
317
+ let target = inst. operands [ 0 ] . id_ref_any ( ) . unwrap ( ) ;
318
+ let source = inst. operands [ 1 ] . id_ref_any ( ) . unwrap ( ) ;
319
+ if inst. operands . len ( ) > 2 {
320
+ // TODO: Copy the memory operands to the load/store
321
+ bug ! ( "mem2reg OpCopyMemory doesn't support memory operands yet" ) ;
322
+ }
323
+ let ty = match ( var_map. get ( & target) , var_map. get ( & source) ) {
324
+ ( None , None ) => {
325
+ inst_index += 1 ;
326
+ continue ;
327
+ }
328
+ ( Some ( target) , None ) => target. ty ,
329
+ ( None , Some ( source) ) => source. ty ,
330
+ ( Some ( target) , Some ( source) ) => {
331
+ assert_eq ! ( target. ty, source. ty) ;
332
+ target. ty
333
+ }
334
+ } ;
335
+ let temp_id = id ( header) ;
336
+ block. instructions [ inst_index] = Instruction :: new (
337
+ Op :: Load ,
338
+ Some ( ty) ,
339
+ Some ( temp_id) ,
340
+ vec ! [ Operand :: IdRef ( source) ] ,
341
+ ) ;
342
+ inst_index += 1 ;
343
+ block. instructions . insert (
344
+ inst_index,
345
+ Instruction :: new (
346
+ Op :: Store ,
347
+ None ,
348
+ None ,
349
+ vec ! [ Operand :: IdRef ( target) , Operand :: IdRef ( temp_id) ] ,
350
+ ) ,
351
+ ) ;
352
+ }
353
+ inst_index += 1 ;
354
+ }
355
+ }
356
+ }
357
+
274
358
fn has_store ( block : & Block , var_map : & FxHashMap < Word , VarInfo > ) -> bool {
275
359
block. instructions . iter ( ) . any ( |inst| {
276
360
let ptr = match inst. class . opcode {
0 commit comments