Skip to content

Commit 1ee50df

Browse files
committed
[naga spv-out] Ensure loops generated by SPIRV backend are bounded
If it is undefined behaviour for loops to be infinite, then, when encountering an infinite loop, downstream compilers are able to make certain optimizations that may be unsafe. For example, omitting bounds checks. To prevent this, we must ensure that any loops emitted by our backends are provably bounded. We already do this for both the MSL and HLSL backends. This patch makes us do so for SPIRV as well. The construct used is the same as for HLSL and MSL backends: use a vec2<u32> to emulate a 64-bit counter, which is incremented every iteration and breaks after 2^64 iterations. While the implementation is fairly verbose for the SPIRV backend, the logic is simple enough. The one point of note is that SPIRV requires `OpVariable` instructions with a `Function` storage class to be located at the start of the first block of the function. We therefore remember the IDs generated for each loop counter variable in a function whilst generating the function body's code. The instructions to declare these variables are then emitted in `Function::to_words()` prior to emitting the function's body. As this may negatively impact shader performance, this workaround can be disabled using the same mechanism as for other backends: eg calling Device::create_shader_module_trusted() and setting the ShaderRuntimeChecks::force_loop_bounding flag to false.
1 parent 189c97c commit 1ee50df

20 files changed

+2942
-2220
lines changed

CHANGELOG.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,19 @@ Bottom level categories:
4040

4141
## Unreleased
4242

43+
### Major changes
44+
45+
#### Naga
46+
47+
##### Ensure loops generated by SPIR-V and HLSL Naga backends are bounded
48+
49+
Make sure that all loops in shaders generated by these naga backends are bounded
50+
to avoid undefined behaviour due to infinite loops. Note that this may have a
51+
performance cost. As with the existing implementation for the MSL backend this
52+
can be disabled by using `Device::create_shader_module_trusted()`.
53+
54+
By @jamienicol in [#6929](https://github.com/gfx-rs/wgpu/pull/6929) and [#7080](https://github.com/gfx-rs/wgpu/pull/7080).
55+
4356
### New Features
4457

4558
#### Naga

naga/src/back/spv/block.rs

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,155 @@ impl Writer {
260260
}
261261

262262
impl BlockContext<'_> {
263+
/// Generates code to ensure that a loop is bounded. Should be called immediately
264+
/// after adding the OpLoopMerge instruction to `block`. This function will
265+
/// [`consume()`](crate::back::spv::Function::consume) `block` and append its
266+
/// instructions to a new [`Block`], which will be returned to the caller for it to
267+
/// consumed prior to writing the loop body.
268+
///
269+
/// Additionally this function will populate [`force_loop_bounding_vars`](crate::back::spv::Function::force_loop_bounding_vars),
270+
/// ensuring that [`Function::to_words()`](crate::back::spv::Function::to_words) will
271+
/// declare the required variables.
272+
///
273+
/// See [`crate::back::msl::Writer::gen_force_bounded_loop_statements`] for details
274+
/// of why this is required.
275+
fn write_force_bounded_loop_instructions(&mut self, mut block: Block, merge_id: Word) -> Block {
276+
let uint_type_id = self.writer.get_uint_type_id();
277+
let uint2_type_id = self.writer.get_uint2_type_id();
278+
let uint2_ptr_type_id = self
279+
.writer
280+
.get_uint2_pointer_type_id(spirv::StorageClass::Function);
281+
let bool_type_id = self.writer.get_bool_type_id();
282+
let bool2_type_id = self.writer.get_bool2_type_id();
283+
let zero_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
284+
let zero_uint2_const_id = self.writer.get_constant_composite(
285+
LookupType::Local(LocalType::Numeric(NumericType::Vector {
286+
size: crate::VectorSize::Bi,
287+
scalar: crate::Scalar::U32,
288+
})),
289+
&[zero_uint_const_id, zero_uint_const_id],
290+
);
291+
let one_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(1));
292+
let max_uint_const_id = self
293+
.writer
294+
.get_constant_scalar(crate::Literal::U32(u32::MAX));
295+
let max_uint2_const_id = self.writer.get_constant_composite(
296+
LookupType::Local(LocalType::Numeric(NumericType::Vector {
297+
size: crate::VectorSize::Bi,
298+
scalar: crate::Scalar::U32,
299+
})),
300+
&[max_uint_const_id, max_uint_const_id],
301+
);
302+
303+
let loop_counter_var_id = self.gen_id();
304+
if self.writer.flags.contains(WriterFlags::DEBUG) {
305+
self.writer
306+
.debugs
307+
.push(Instruction::name(loop_counter_var_id, "loop_bound"));
308+
}
309+
let var = super::LocalVariable {
310+
id: loop_counter_var_id,
311+
instruction: Instruction::variable(
312+
uint2_ptr_type_id,
313+
loop_counter_var_id,
314+
spirv::StorageClass::Function,
315+
Some(zero_uint2_const_id),
316+
),
317+
};
318+
self.function.force_loop_bounding_vars.push(var);
319+
320+
let break_if_block = self.gen_id();
321+
322+
self.function
323+
.consume(block, Instruction::branch(break_if_block));
324+
block = Block::new(break_if_block);
325+
326+
// Load the current loop counter value from its variable. We use a vec2<u32> to
327+
// simulate a 64-bit counter.
328+
let load_id = self.gen_id();
329+
block.body.push(Instruction::load(
330+
uint2_type_id,
331+
load_id,
332+
loop_counter_var_id,
333+
None,
334+
));
335+
336+
// If both the high and low u32s have reached u32::MAX then break. ie
337+
// if (all(eq(loop_counter, vec2(u32::MAX)))) { break; }
338+
let eq_id = self.gen_id();
339+
block.body.push(Instruction::binary(
340+
spirv::Op::IEqual,
341+
bool2_type_id,
342+
eq_id,
343+
max_uint2_const_id,
344+
load_id,
345+
));
346+
let all_eq_id = self.gen_id();
347+
block.body.push(Instruction::relational(
348+
spirv::Op::All,
349+
bool_type_id,
350+
all_eq_id,
351+
eq_id,
352+
));
353+
354+
let inc_counter_block_id = self.gen_id();
355+
block.body.push(Instruction::selection_merge(
356+
inc_counter_block_id,
357+
spirv::SelectionControl::empty(),
358+
));
359+
self.function.consume(
360+
block,
361+
Instruction::branch_conditional(all_eq_id, merge_id, inc_counter_block_id),
362+
);
363+
block = Block::new(inc_counter_block_id);
364+
365+
// To simulate a 64-bit counter we always increment the low u32, and increment
366+
// the high u32 when the low u32 overflows. ie
367+
// counter += vec2(select(0u, 1u, counter.y == u32::MAX), 1u);
368+
let low_id = self.gen_id();
369+
block.body.push(Instruction::composite_extract(
370+
uint_type_id,
371+
low_id,
372+
load_id,
373+
&[1],
374+
));
375+
let low_overflow_id = self.gen_id();
376+
block.body.push(Instruction::binary(
377+
spirv::Op::IEqual,
378+
bool_type_id,
379+
low_overflow_id,
380+
low_id,
381+
max_uint_const_id,
382+
));
383+
let carry_bit_id = self.gen_id();
384+
block.body.push(Instruction::select(
385+
uint_type_id,
386+
carry_bit_id,
387+
low_overflow_id,
388+
one_uint_const_id,
389+
zero_uint_const_id,
390+
));
391+
let increment_id = self.gen_id();
392+
block.body.push(Instruction::composite_construct(
393+
uint2_type_id,
394+
increment_id,
395+
&[carry_bit_id, one_uint_const_id],
396+
));
397+
let result_id = self.gen_id();
398+
block.body.push(Instruction::binary(
399+
spirv::Op::IAdd,
400+
uint2_type_id,
401+
result_id,
402+
load_id,
403+
increment_id,
404+
));
405+
block
406+
.body
407+
.push(Instruction::store(loop_counter_var_id, result_id, None));
408+
409+
block
410+
}
411+
263412
/// Cache an expression for a value.
264413
pub(super) fn cache_expression_value(
265414
&mut self,
@@ -2531,6 +2680,10 @@ impl BlockContext<'_> {
25312680
continuing_id,
25322681
spirv::SelectionControl::NONE,
25332682
));
2683+
2684+
if self.force_loop_bounding {
2685+
block = self.write_force_bounded_loop_instructions(block, merge_id);
2686+
}
25342687
self.function.consume(block, Instruction::branch(body_id));
25352688

25362689
// We can ignore the `BlockExitDisposition` returned here because,

naga/src/back/spv/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ struct Function {
144144
signature: Option<Instruction>,
145145
parameters: Vec<FunctionArgument>,
146146
variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>,
147+
/// List of local variables used as a counters to ensure that all loops are bounded.
148+
force_loop_bounding_vars: Vec<LocalVariable>,
147149

148150
/// A map taking an expression that yields a composite value (array, matrix)
149151
/// to the temporary variables we have spilled it to, if any. Spilling
@@ -694,6 +696,8 @@ struct BlockContext<'w> {
694696

695697
/// Tracks the constness of `Expression`s residing in `self.ir_function.expressions`
696698
expression_constness: ExpressionConstnessTracker,
699+
700+
force_loop_bounding: bool,
697701
}
698702

699703
impl BlockContext<'_> {
@@ -747,6 +751,7 @@ pub struct Writer {
747751
flags: WriterFlags,
748752
bounds_check_policies: BoundsCheckPolicies,
749753
zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode,
754+
force_loop_bounding: bool,
750755
void_type: Word,
751756
//TODO: convert most of these into vectors, addressable by handle indices
752757
lookup_type: crate::FastHashMap<LookupType, Word>,
@@ -846,6 +851,10 @@ pub struct Options<'a> {
846851
/// Dictates the way workgroup variables should be zero initialized
847852
pub zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode,
848853

854+
/// If set, loops will have code injected into them, forcing the compiler
855+
/// to think the number of iterations is bounded.
856+
pub force_loop_bounding: bool,
857+
849858
pub debug_info: Option<DebugInfo<'a>>,
850859
}
851860

@@ -864,6 +873,7 @@ impl Default for Options<'_> {
864873
capabilities: None,
865874
bounds_check_policies: BoundsCheckPolicies::default(),
866875
zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode::Polyfill,
876+
force_loop_bounding: true,
867877
debug_info: None,
868878
}
869879
}

naga/src/back/spv/writer.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ impl Function {
3232
for local_var in self.variables.values() {
3333
local_var.instruction.to_words(sink);
3434
}
35+
for local_var in self.force_loop_bounding_vars.iter() {
36+
local_var.instruction.to_words(sink);
37+
}
3538
for internal_var in self.spilled_composites.values() {
3639
internal_var.instruction.to_words(sink);
3740
}
@@ -70,6 +73,7 @@ impl Writer {
7073
flags: options.flags,
7174
bounds_check_policies: options.bounds_check_policies,
7275
zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory,
76+
force_loop_bounding: options.force_loop_bounding,
7377
void_type,
7478
lookup_type: crate::FastHashMap::default(),
7579
lookup_function: crate::FastHashMap::default(),
@@ -109,6 +113,7 @@ impl Writer {
109113
flags: self.flags,
110114
bounds_check_policies: self.bounds_check_policies,
111115
zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory,
116+
force_loop_bounding: self.force_loop_bounding,
112117
capabilities_available: take(&mut self.capabilities_available),
113118
binding_map: take(&mut self.binding_map),
114119

@@ -260,6 +265,14 @@ impl Writer {
260265
self.get_type_id(local_type.into())
261266
}
262267

268+
pub(super) fn get_uint2_type_id(&mut self) -> Word {
269+
let local_type = LocalType::Numeric(NumericType::Vector {
270+
size: crate::VectorSize::Bi,
271+
scalar: crate::Scalar::U32,
272+
});
273+
self.get_type_id(local_type.into())
274+
}
275+
263276
pub(super) fn get_uint3_type_id(&mut self) -> Word {
264277
let local_type = LocalType::Numeric(NumericType::Vector {
265278
size: crate::VectorSize::Tri,
@@ -276,6 +289,17 @@ impl Writer {
276289
self.get_type_id(local_type.into())
277290
}
278291

292+
pub(super) fn get_uint2_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
293+
let local_type = LocalType::LocalPointer {
294+
base: NumericType::Vector {
295+
size: crate::VectorSize::Bi,
296+
scalar: crate::Scalar::U32,
297+
},
298+
class,
299+
};
300+
self.get_type_id(local_type.into())
301+
}
302+
279303
pub(super) fn get_uint3_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
280304
let local_type = LocalType::LocalPointer {
281305
base: NumericType::Vector {
@@ -292,6 +316,14 @@ impl Writer {
292316
self.get_type_id(local_type.into())
293317
}
294318

319+
pub(super) fn get_bool2_type_id(&mut self) -> Word {
320+
let local_type = LocalType::Numeric(NumericType::Vector {
321+
size: crate::VectorSize::Bi,
322+
scalar: crate::Scalar::BOOL,
323+
});
324+
self.get_type_id(local_type.into())
325+
}
326+
295327
pub(super) fn get_bool3_type_id(&mut self) -> Word {
296328
let local_type = LocalType::Numeric(NumericType::Vector {
297329
size: crate::VectorSize::Tri,
@@ -594,6 +626,7 @@ impl Writer {
594626

595627
// Steal the Writer's temp list for a bit.
596628
temp_list: std::mem::take(&mut self.temp_list),
629+
force_loop_bounding: self.force_loop_bounding,
597630
writer: self,
598631
expression_constness: super::ExpressionConstnessTracker::from_arena(
599632
&ir_function.expressions,

naga/tests/out/spv/6220-break-from-loop.spvasm

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
; SPIR-V
22
; Version: 1.1
33
; Generator: rspirv
4-
; Bound: 26
4+
; Bound: 46
55
OpCapability Shader
66
OpCapability Linkage
77
%1 = OpExtInstImport "GLSL.std.450"
@@ -13,31 +13,55 @@ OpMemoryModel Logical GLSL450
1313
%8 = OpConstant %3 4
1414
%9 = OpConstant %3 1
1515
%11 = OpTypePointer Function %3
16-
%18 = OpTypeBool
16+
%17 = OpTypeInt 32 0
17+
%18 = OpTypeVector %17 2
18+
%19 = OpTypePointer Function %18
19+
%20 = OpTypeBool
20+
%21 = OpTypeVector %20 2
21+
%22 = OpConstant %17 0
22+
%23 = OpConstantComposite %18 %22 %22
23+
%24 = OpConstant %17 1
24+
%25 = OpConstant %17 4294967295
25+
%26 = OpConstantComposite %18 %25 %25
1726
%5 = OpFunction %2 None %6
1827
%4 = OpLabel
1928
%10 = OpVariable %11 Function %7
29+
%27 = OpVariable %19 Function %23
2030
OpBranch %12
2131
%12 = OpLabel
2232
OpBranch %13
2333
%13 = OpLabel
2434
OpLoopMerge %14 %16 None
35+
OpBranch %28
36+
%28 = OpLabel
37+
%29 = OpLoad %18 %27
38+
%30 = OpIEqual %21 %26 %29
39+
%31 = OpAll %20 %30
40+
OpSelectionMerge %32 None
41+
OpBranchConditional %31 %14 %32
42+
%32 = OpLabel
43+
%33 = OpCompositeExtract %17 %29 1
44+
%34 = OpIEqual %20 %33 %25
45+
%35 = OpSelect %17 %34 %24 %22
46+
%36 = OpCompositeConstruct %18 %35 %24
47+
%37 = OpIAdd %18 %29 %36
48+
OpStore %27 %37
2549
OpBranch %15
2650
%15 = OpLabel
27-
%17 = OpLoad %3 %10
28-
%19 = OpSLessThan %18 %17 %8
29-
OpSelectionMerge %20 None
30-
OpBranchConditional %19 %20 %21
31-
%21 = OpLabel
51+
%38 = OpLoad %3 %10
52+
%39 = OpSLessThan %20 %38 %8
53+
OpSelectionMerge %40 None
54+
OpBranchConditional %39 %40 %41
55+
%41 = OpLabel
3256
OpBranch %14
33-
%20 = OpLabel
34-
OpBranch %22
35-
%22 = OpLabel
57+
%40 = OpLabel
58+
OpBranch %42
59+
%42 = OpLabel
3660
OpBranch %14
3761
%16 = OpLabel
38-
%24 = OpLoad %3 %10
39-
%25 = OpIAdd %3 %24 %9
40-
OpStore %10 %25
62+
%44 = OpLoad %3 %10
63+
%45 = OpIAdd %3 %44 %9
64+
OpStore %10 %45
4165
OpBranch %13
4266
%14 = OpLabel
4367
OpReturn

0 commit comments

Comments
 (0)