Skip to content

Commit 2976158

Browse files
committed
[naga msl-out hlsl-out] Improve workaround for infinite loops causing undefined behaviour
We must ensure that all loops emitted by the naga backends will terminate, in order to avoid undefined behaviour. This was previously implemented for the msl backend in 6545. However, the usage of `volatile` prevents the compiler from making other important optimizations. This patch improves the msl workaround and additionally implements it for hlsl. The spv implementation will be left for a follow up. Rather than using volatile, this patch increments a counter on every loop iteration, breaking from the loop after 2^64 iterations. This ensures the compiler treats the loop as finite thereby avoiding undefined behaviour, whilst at the same time allowing for other optimizations and in reality not actually affecting execution.
1 parent 04e40dd commit 2976158

20 files changed

+223
-95
lines changed

naga/src/back/hlsl/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,9 @@ pub struct Options {
211211
pub zero_initialize_workgroup_memory: bool,
212212
/// Should we restrict indexing of vectors, matrices and arrays?
213213
pub restrict_indexing: bool,
214+
/// If set, loops will have code injected into them, forcing the compiler
215+
/// to think the number of iterations is bounded.
216+
pub force_loop_bounding: bool,
214217
}
215218

216219
impl Default for Options {
@@ -223,6 +226,7 @@ impl Default for Options {
223226
push_constants_target: None,
224227
zero_initialize_workgroup_memory: true,
225228
restrict_indexing: true,
229+
force_loop_bounding: true,
226230
}
227231
}
228232
}

naga/src/back/hlsl/writer.rs

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,33 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
131131
self.need_bake_expressions.clear();
132132
}
133133

134+
/// Generates statements to be inserted immediately before and at the very
135+
/// start of the body of each loop, to defeat infinite loop reasoning.
136+
/// The 0th item of the returned tuple should be inserted immediately prior
137+
/// to the loop and the 1st item should be inserted at the very start of
138+
/// the loop body.
139+
///
140+
/// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
141+
fn gen_force_bounded_loop_statements(
142+
&mut self,
143+
level: back::Level,
144+
) -> Option<(String, String)> {
145+
if !self.options.force_loop_bounding {
146+
return None;
147+
}
148+
149+
let loop_bound_name = self.namer.call("loop_bound");
150+
let decl = format!("{level}uint2 {loop_bound_name} = uint2(0u, 0u);");
151+
let level = level.next();
152+
let max = u32::MAX;
153+
let break_and_inc = format!(
154+
"{level}if (all({loop_bound_name} == uint2({max}u, {max}u))) {{ break; }}
155+
{level}{loop_bound_name} += uint2({loop_bound_name}.y == {max}u, 1u);"
156+
);
157+
158+
Some((decl, break_and_inc))
159+
}
160+
134161
/// Helper method used to find which expressions of a given function require baking
135162
///
136163
/// # Notes
@@ -2048,12 +2075,24 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
20482075
ref continuing,
20492076
break_if,
20502077
} => {
2078+
let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2079+
let gate_name = (!continuing.is_empty() || break_if.is_some())
2080+
.then(|| self.namer.call("loop_init"));
2081+
2082+
if let Some((ref decl, _)) = force_loop_bound_statements {
2083+
writeln!(self.out, "{decl}")?;
2084+
}
2085+
if let Some(ref gate_name) = gate_name {
2086+
writeln!(self.out, "{level}bool {gate_name} = true;")?;
2087+
}
2088+
20512089
self.continue_ctx.enter_loop();
2090+
writeln!(self.out, "{level}while(true) {{")?;
2091+
if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2092+
writeln!(self.out, "{break_and_inc}")?;
2093+
}
20522094
let l2 = level.next();
2053-
if !continuing.is_empty() || break_if.is_some() {
2054-
let gate_name = self.namer.call("loop_init");
2055-
writeln!(self.out, "{level}bool {gate_name} = true;")?;
2056-
writeln!(self.out, "{level}while(true) {{")?;
2095+
if let Some(gate_name) = gate_name {
20572096
writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
20582097
let l3 = l2.next();
20592098
for sta in continuing.iter() {
@@ -2068,13 +2107,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
20682107
}
20692108
writeln!(self.out, "{l2}}}")?;
20702109
writeln!(self.out, "{l2}{gate_name} = false;")?;
2071-
} else {
2072-
writeln!(self.out, "{level}while(true) {{")?;
20732110
}
20742111

20752112
for sta in body.iter() {
20762113
self.write_stmt(module, sta, func_ctx, l2)?;
20772114
}
2115+
20782116
writeln!(self.out, "{level}}}")?;
20792117
self.continue_ctx.exit_loop();
20802118
}

naga/src/back/msl/writer.rs

Lines changed: 59 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -383,11 +383,6 @@ pub struct Writer<W> {
383383
/// Set of (struct type, struct field index) denoting which fields require
384384
/// padding inserted **before** them (i.e. between fields at index - 1 and index)
385385
struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
386-
387-
/// Name of the force-bounded-loop macro.
388-
///
389-
/// See `emit_force_bounded_loop_macro` for details.
390-
force_bounded_loop_macro_name: String,
391386
}
392387

393388
impl crate::Scalar {
@@ -601,7 +596,7 @@ struct ExpressionContext<'a> {
601596
/// accesses. These may need to be cached in temporary variables. See
602597
/// `index::find_checked_indexes` for details.
603598
guarded_indices: HandleSet<crate::Expression>,
604-
/// See [`Writer::emit_force_bounded_loop_macro`] for details.
599+
/// See [`Writer::gen_force_bounded_loop_statements`] for details.
605600
force_loop_bounding: bool,
606601
}
607602

@@ -685,7 +680,6 @@ impl<W: Write> Writer<W> {
685680
#[cfg(test)]
686681
put_block_stack_pointers: Default::default(),
687682
struct_member_pads: FastHashSet::default(),
688-
force_bounded_loop_macro_name: String::default(),
689683
}
690684
}
691685

@@ -696,17 +690,11 @@ impl<W: Write> Writer<W> {
696690
self.out
697691
}
698692

699-
/// Define a macro to invoke at the bottom of each loop body, to
700-
/// defeat MSL infinite loop reasoning.
701-
///
702-
/// If we haven't done so already, emit the definition of a preprocessor
703-
/// macro to be invoked at the end of each loop body in the generated MSL,
704-
/// to ensure that the MSL compiler's optimizations do not remove bounds
705-
/// checks.
706-
///
707-
/// Only the first call to this function for a given module actually causes
708-
/// the macro definition to be written. Subsequent loops can simply use the
709-
/// prior macro definition, since macros aren't block-scoped.
693+
/// Generates statements to be inserted immediately before and at the very
694+
/// start of the body of each loop, to defeat MSL infinite loop reasoning.
695+
/// The 0th item of the returned tuple should be inserted immediately prior
696+
/// to the loop and the 1st item should be inserted at the very start of
697+
/// the loop body.
710698
///
711699
/// # What is this trying to solve?
712700
///
@@ -774,7 +762,8 @@ impl<W: Write> Writer<W> {
774762
/// but which in fact generates no instructions. Unfortunately, inline
775763
/// assembly is not handled correctly by some Metal device drivers.
776764
///
777-
/// Instead, we add the following code to the bottom of every loop:
765+
/// A previously used approach was to add the following code to the bottom
766+
/// of every loop:
778767
///
779768
/// ```ignore
780769
/// if (volatile bool unpredictable = false; unpredictable)
@@ -785,37 +774,47 @@ impl<W: Write> Writer<W> {
785774
/// the `volatile` qualifier prevents the compiler from assuming this. Thus,
786775
/// it must assume that the `break` might be reached, and hence that the
787776
/// loop is not unbounded. This prevents the range analysis impact described
788-
/// above.
777+
/// above. Unfortunately this prevented the compiler from making important,
778+
/// and safe, optimizations such as loop unrolling and was observed to
779+
/// significantly hurt performance.
789780
///
790-
/// Unfortunately, what makes this a kludge, not a hack, is that this
791-
/// solution leaves the GPU executing a pointless conditional branch, at
792-
/// runtime, in every iteration of the loop. There's no part of the system
793-
/// that has a global enough view to be sure that `unpredictable` is true,
794-
/// and remove it from the code. Adding the branch also affects
795-
/// optimization: for example, it's impossible to unroll this loop. This
796-
/// transformation has been observed to significantly hurt performance.
781+
/// Our current approach declares a counter before every loop and
782+
/// increments it every iteration, breaking after 2^64 iterations:
783+
///
784+
/// ```ignore
785+
/// uint2 loop_bound = uint2(0);
786+
/// while (true) {
787+
/// if (metal::all(loop_bound == uint2(4294967295))) { break; }
788+
/// loop_bound += uint2(loop_bound.y == 4294967295, 1);
789+
/// }
790+
/// ```
797791
///
798-
/// To make our output a bit more legible, we pull the condition out into a
799-
/// preprocessor macro defined at the top of the module.
792+
/// This convinces the compiler that the loop is finite and therefore may
793+
/// execute, whilst at the same time allowing optimizations such as loop
794+
/// unrolling. Furthermore the 64-bit counter is large enough it seems
795+
/// implausible that it would affect the execution of any shader.
800796
///
801797
/// This approach is also used by Chromium WebGPU's Dawn shader compiler:
802-
/// <https://dawn.googlesource.com/dawn/+/a37557db581c2b60fb1cd2c01abdb232927dd961/src/tint/lang/msl/writer/printer/printer.cc#222>
803-
fn emit_force_bounded_loop_macro(&mut self) -> BackendResult {
804-
if !self.force_bounded_loop_macro_name.is_empty() {
805-
return Ok(());
798+
/// <https://dawn.googlesource.com/dawn/+/d9e2d1f718678ebee0728b999830576c410cce0a/src/tint/lang/core/ir/transform/prevent_infinite_loops.cc>
799+
fn gen_force_bounded_loop_statements(
800+
&mut self,
801+
level: back::Level,
802+
context: &StatementContext,
803+
) -> Option<(String, String)> {
804+
if !context.expression.force_loop_bounding {
805+
return None;
806806
}
807807

808-
self.force_bounded_loop_macro_name = self.namer.call("LOOP_IS_BOUNDED");
809-
let loop_bounded_volatile_name = self.namer.call("unpredictable_break_from_loop");
810-
writeln!(
811-
self.out,
812-
"#define {} {{ volatile bool {} = false; if ({}) break; }}",
813-
self.force_bounded_loop_macro_name,
814-
loop_bounded_volatile_name,
815-
loop_bounded_volatile_name,
816-
)?;
808+
let loop_bound_name = self.namer.call("loop_bound");
809+
let decl = format!("{level}uint2 {loop_bound_name} = uint2(0u);");
810+
let level = level.next();
811+
let max = u32::MAX;
812+
let break_and_inc = format!(
813+
"{level}if ({NAMESPACE}::all({loop_bound_name} == uint2({max}u))) {{ break; }}
814+
{level}{loop_bound_name} += uint2({loop_bound_name}.y == {max}u, 1u);"
815+
);
817816

818-
Ok(())
817+
Some((decl, break_and_inc))
819818
}
820819

821820
fn put_call_parameters(
@@ -3201,10 +3200,23 @@ impl<W: Write> Writer<W> {
32013200
ref continuing,
32023201
break_if,
32033202
} => {
3204-
if !continuing.is_empty() || break_if.is_some() {
3205-
let gate_name = self.namer.call("loop_init");
3203+
let force_loop_bound_statements =
3204+
self.gen_force_bounded_loop_statements(level, context);
3205+
let gate_name = (!continuing.is_empty() || break_if.is_some())
3206+
.then(|| self.namer.call("loop_init"));
3207+
3208+
if let Some((ref decl, _)) = force_loop_bound_statements {
3209+
writeln!(self.out, "{decl}")?;
3210+
}
3211+
if let Some(ref gate_name) = gate_name {
32063212
writeln!(self.out, "{level}bool {gate_name} = true;")?;
3207-
writeln!(self.out, "{level}while(true) {{",)?;
3213+
}
3214+
3215+
writeln!(self.out, "{level}while(true) {{",)?;
3216+
if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3217+
writeln!(self.out, "{break_and_inc}")?;
3218+
}
3219+
if let Some(ref gate_name) = gate_name {
32083220
let lif = level.next();
32093221
let lcontinuing = lif.next();
32103222
writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
@@ -3218,19 +3230,9 @@ impl<W: Write> Writer<W> {
32183230
}
32193231
writeln!(self.out, "{lif}}}")?;
32203232
writeln!(self.out, "{lif}{gate_name} = false;")?;
3221-
} else {
3222-
writeln!(self.out, "{level}while(true) {{",)?;
32233233
}
32243234
self.put_block(level.next(), body, context)?;
3225-
if context.expression.force_loop_bounding {
3226-
self.emit_force_bounded_loop_macro()?;
3227-
writeln!(
3228-
self.out,
3229-
"{}{}",
3230-
level.next(),
3231-
self.force_bounded_loop_macro_name
3232-
)?;
3233-
}
3235+
32343236
writeln!(self.out, "{level}}}")?;
32353237
}
32363238
crate::Statement::Break => {
@@ -3724,7 +3726,6 @@ impl<W: Write> Writer<W> {
37243726
&[CLAMPED_LOD_LOAD_PREFIX],
37253727
&mut self.names,
37263728
);
3727-
self.force_bounded_loop_macro_name.clear();
37283729
self.struct_member_pads.clear();
37293730

37303731
writeln!(

naga/tests/out/hlsl/boids.hlsl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,11 @@ void main(uint3 global_invocation_id : SV_DispatchThreadID)
4141
vPos = _e8;
4242
float2 _e14 = asfloat(particlesSrc.Load2(8+index*16+0));
4343
vVel = _e14;
44+
uint2 loop_bound = uint2(0u, 0u);
4445
bool loop_init = true;
4546
while(true) {
47+
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
48+
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
4649
if (!loop_init) {
4750
uint _e91 = i;
4851
i = (_e91 + 1u);

naga/tests/out/hlsl/break-if.hlsl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
void breakIfEmpty()
22
{
3+
uint2 loop_bound = uint2(0u, 0u);
34
bool loop_init = true;
45
while(true) {
6+
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
7+
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
58
if (!loop_init) {
69
if (true) {
710
break;
@@ -17,8 +20,11 @@ void breakIfEmptyBody(bool a)
1720
bool b = (bool)0;
1821
bool c = (bool)0;
1922

23+
uint2 loop_bound_1 = uint2(0u, 0u);
2024
bool loop_init_1 = true;
2125
while(true) {
26+
if (all(loop_bound_1 == uint2(4294967295u, 4294967295u))) { break; }
27+
loop_bound_1 += uint2(loop_bound_1.y == 4294967295u, 1u);
2228
if (!loop_init_1) {
2329
b = a;
2430
bool _e2 = b;
@@ -38,8 +44,11 @@ void breakIf(bool a_1)
3844
bool d = (bool)0;
3945
bool e = (bool)0;
4046

47+
uint2 loop_bound_2 = uint2(0u, 0u);
4148
bool loop_init_2 = true;
4249
while(true) {
50+
if (all(loop_bound_2 == uint2(4294967295u, 4294967295u))) { break; }
51+
loop_bound_2 += uint2(loop_bound_2.y == 4294967295u, 1u);
4352
if (!loop_init_2) {
4453
bool _e5 = e;
4554
if ((a_1 == _e5)) {
@@ -58,8 +67,11 @@ void breakIfSeparateVariable()
5867
{
5968
uint counter = 0u;
6069

70+
uint2 loop_bound_3 = uint2(0u, 0u);
6171
bool loop_init_3 = true;
6272
while(true) {
73+
if (all(loop_bound_3 == uint2(4294967295u, 4294967295u))) { break; }
74+
loop_bound_3 += uint2(loop_bound_3.y == 4294967295u, 1u);
6375
if (!loop_init_3) {
6476
uint _e5 = counter;
6577
if ((_e5 == 5u)) {

naga/tests/out/hlsl/collatz.hlsl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ uint collatz_iterations(uint n_base)
66
uint i = 0u;
77

88
n = n_base;
9+
uint2 loop_bound = uint2(0u, 0u);
910
while(true) {
11+
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
12+
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
1013
uint _e4 = n;
1114
if ((_e4 > 1u)) {
1215
} else {

0 commit comments

Comments
 (0)