Skip to content

Commit 3f22a89

Browse files
authored
Handle OpCopyMemory in mem2reg (#772)
* Handle OpCopyMemory in mem2reg * Update tests
1 parent 8d019e4 commit 3f22a89

File tree

8 files changed

+114
-58
lines changed

8 files changed

+114
-58
lines changed

crates/rustc_codegen_spirv/src/linker/mem2reg.rs

+93-9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use super::{apply_rewrite_rules, id};
1414
use rspirv::dr::{Block, Function, Instruction, ModuleHeader, Operand};
1515
use rspirv::spirv::{Op, Word};
1616
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
17+
use rustc_middle::bug;
1718
use std::collections::hash_map;
1819

1920
pub fn mem2reg(
@@ -27,14 +28,21 @@ pub fn mem2reg(
2728
let preds = compute_preds(&func.blocks, &reachable);
2829
let idom = compute_idom(&preds, &reachable);
2930
let dominance_frontier = compute_dominance_frontier(&preds, &idom);
30-
while insert_phis_all(
31-
header,
32-
types_global_values,
33-
pointer_to_pointee,
34-
constants,
35-
&mut func.blocks,
36-
&dominance_frontier,
37-
) {}
31+
loop {
32+
let changed = insert_phis_all(
33+
header,
34+
types_global_values,
35+
pointer_to_pointee,
36+
constants,
37+
&mut func.blocks,
38+
&dominance_frontier,
39+
);
40+
if !changed {
41+
break;
42+
}
43+
// mem2reg produces minimal SSA form, not pruned, so DCE the dead ones
44+
super::dce::dce_phi(func);
45+
}
3846
}
3947

4048
fn label_to_index(blocks: &[Block], id: Word) -> usize {
@@ -171,6 +179,9 @@ fn insert_phis_all(
171179
if var_maps_and_types.is_empty() {
172180
return false;
173181
}
182+
for (var_map, _) in &var_maps_and_types {
183+
split_copy_memory(header, blocks, var_map);
184+
}
174185
for &(ref var_map, base_var_type) in &var_maps_and_types {
175186
let blocks_with_phi = insert_phis(blocks, dominance_frontier, var_map);
176187
let mut renamer = Renamer {
@@ -244,7 +255,10 @@ fn collect_access_chains(
244255
match inst.class.opcode {
245256
// Only allow store if pointer is the lhs, not rhs
246257
Op::Store if index == 0 => {}
247-
Op::Load | Op::AccessChain | Op::InBoundsAccessChain => {}
258+
Op::Load
259+
| Op::AccessChain
260+
| Op::InBoundsAccessChain
261+
| Op::CopyMemory => {}
248262
_ => return None,
249263
}
250264
}
@@ -271,6 +285,76 @@ fn collect_access_chains(
271285
Some(variables)
272286
}
273287

288+
// Splits an OpCopyMemory into an OpLoad followed by an OpStore. This is because we want to be able
289+
// to mem2reg variables used in OpCopyMemory, but analysis becomes very difficult: we only analyze
290+
// one variable at a time, but OpCopyMemory can copy between two local variables (both of which are
291+
// getting mem2reg'd), requiring cross-analysis shenanigans. So, if we know at least one side of
292+
// the OpCopyMemory is getting mem2reg'd, we can safely split it into a load/store pair: at least
293+
// one side of the pair is going to evaporate in the subsequent rewrite. Then, we can only deal
294+
// with one side of a pair at a time, treating the other side as opaque (but possibly rewriting
295+
// both sides).
296+
//
297+
// This means that an OpCopyMemory between two local variables will completely disappear, while an
298+
// OpCopyMemory from a global to a local will turn into an OpLoad, and local to global will turn
299+
// into an OpStore.
300+
//
301+
// Note that while we only look at a single var map in this function, if an OpCopyMemory contains
302+
// variables from two var maps, the second pass won't do anything since the first pass will already
303+
// have split it (but that's fine, it would have done the same thing anyway).
304+
//
305+
// Finally, an edge case to keep in mind is that an OpCopyMemory can happen between two vars in the
306+
// same var map (e.g. `s.x = s.y;`).
307+
fn split_copy_memory(
308+
header: &mut ModuleHeader,
309+
blocks: &mut [Block],
310+
var_map: &FxHashMap<Word, VarInfo>,
311+
) {
312+
for block in blocks {
313+
let mut inst_index = 0;
314+
while inst_index < block.instructions.len() {
315+
let inst = &block.instructions[inst_index];
316+
if inst.class.opcode == Op::CopyMemory {
317+
let target = inst.operands[0].id_ref_any().unwrap();
318+
let source = inst.operands[1].id_ref_any().unwrap();
319+
if inst.operands.len() > 2 {
320+
// TODO: Copy the memory operands to the load/store
321+
bug!("mem2reg OpCopyMemory doesn't support memory operands yet");
322+
}
323+
let ty = match (var_map.get(&target), var_map.get(&source)) {
324+
(None, None) => {
325+
inst_index += 1;
326+
continue;
327+
}
328+
(Some(target), None) => target.ty,
329+
(None, Some(source)) => source.ty,
330+
(Some(target), Some(source)) => {
331+
assert_eq!(target.ty, source.ty);
332+
target.ty
333+
}
334+
};
335+
let temp_id = id(header);
336+
block.instructions[inst_index] = Instruction::new(
337+
Op::Load,
338+
Some(ty),
339+
Some(temp_id),
340+
vec![Operand::IdRef(source)],
341+
);
342+
inst_index += 1;
343+
block.instructions.insert(
344+
inst_index,
345+
Instruction::new(
346+
Op::Store,
347+
None,
348+
None,
349+
vec![Operand::IdRef(target), Operand::IdRef(temp_id)],
350+
),
351+
);
352+
}
353+
inst_index += 1;
354+
}
355+
}
356+
}
357+
274358
fn has_store(block: &Block, var_map: &FxHashMap<Word, VarInfo>) -> bool {
275359
block.instructions.iter().any(|inst| {
276360
let ptr = match inst.class.opcode {

crates/rustc_codegen_spirv/src/linker/mod.rs

-2
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,6 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
256256
&constants,
257257
func,
258258
);
259-
// mem2reg produces minimal SSA form, not pruned, so DCE the dead ones
260-
dce::dce_phi(func);
261259
destructure_composites::destructure_composites(func);
262260
}
263261
}

tests/ui/dis/ptr_read.stderr

+5-10
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,10 @@
22
%4 = OpFunctionParameter %5
33
%6 = OpFunctionParameter %5
44
%7 = OpLabel
5-
%8 = OpVariable %5 Function
6-
OpLine %9 319 5
7-
OpStore %8 %10
8-
OpLine %11 699 8
9-
OpCopyMemory %8 %4
10-
OpLine %11 700 8
11-
%12 = OpLoad %13 %8
12-
OpLine %14 7 13
13-
OpStore %6 %12
14-
OpLine %14 8 1
5+
OpLine %8 699 8
6+
%9 = OpLoad %10 %4
7+
OpLine %11 7 13
8+
OpStore %6 %9
9+
OpLine %11 8 1
1510
OpReturn
1611
OpFunctionEnd

tests/ui/dis/ptr_read_method.stderr

+5-10
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,10 @@
22
%4 = OpFunctionParameter %5
33
%6 = OpFunctionParameter %5
44
%7 = OpLabel
5-
%8 = OpVariable %5 Function
6-
OpLine %9 319 5
7-
OpStore %8 %10
8-
OpLine %11 699 8
9-
OpCopyMemory %8 %4
10-
OpLine %11 700 8
11-
%12 = OpLoad %13 %8
12-
OpLine %14 7 13
13-
OpStore %6 %12
14-
OpLine %14 8 1
5+
OpLine %8 699 8
6+
%9 = OpLoad %10 %4
7+
OpLine %11 7 13
8+
OpStore %6 %9
9+
OpLine %11 8 1
1510
OpReturn
1611
OpFunctionEnd

tests/ui/dis/ptr_write.stderr

+5-8
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,10 @@
22
%4 = OpFunctionParameter %5
33
%6 = OpFunctionParameter %5
44
%7 = OpLabel
5-
%8 = OpVariable %5 Function
6-
OpLine %9 7 35
7-
%10 = OpLoad %11 %4
8-
OpLine %9 7 13
9-
OpStore %8 %10
10-
OpLine %12 890 8
11-
OpCopyMemory %6 %8
12-
OpLine %9 8 1
5+
OpLine %8 7 35
6+
%9 = OpLoad %10 %4
7+
OpLine %11 890 8
8+
OpStore %6 %9
9+
OpLine %8 8 1
1310
OpReturn
1411
OpFunctionEnd

tests/ui/dis/ptr_write_method.stderr

+5-8
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,10 @@
22
%4 = OpFunctionParameter %5
33
%6 = OpFunctionParameter %5
44
%7 = OpLabel
5-
%8 = OpVariable %5 Function
6-
OpLine %9 7 37
7-
%10 = OpLoad %11 %4
8-
OpLine %12 1013 17
9-
OpStore %8 %10
10-
OpLine %13 890 8
11-
OpCopyMemory %6 %8
12-
OpLine %9 8 1
5+
OpLine %8 7 37
6+
%9 = OpLoad %10 %4
7+
OpLine %11 890 8
8+
OpStore %6 %9
9+
OpLine %8 8 1
1310
OpReturn
1411
OpFunctionEnd

tests/ui/lang/core/ops/range-contains.rs

+1-4
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22
// in `core` (see https://github.com/rust-lang/rust/pull/87723), cannot
33
// cause a fatal error, but at most a zombie or SPIR-V validation error.
44

5-
// build-fail
6-
7-
// HACK(eddyb) this allows CI (older?) `spirv-val` output to also work.
8-
// normalize-stderr-test " %\d+ = OpVariable %\w+ Function\n\n" -> ""
5+
// build-pass
96

107
use spirv_std as _;
118

tests/ui/lang/core/ops/range-contains.stderr

-7
This file was deleted.

0 commit comments

Comments
 (0)