Skip to content

Commit d2f7783

Browse files
committed
Benchmark Workaround
1 parent b233d20 commit d2f7783

File tree

4 files changed

+255
-2
lines changed

4 files changed

+255
-2
lines changed

benches/benches/loop_workaround.rs

+231
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
use std::{collections::VecDeque, time::Duration};
2+
3+
use criterion::{criterion_group, Criterion};
4+
use std::sync::LazyLock;
5+
use wgpu::{ComputePassTimestampWrites, ComputePipeline, PipelineCompilationOptions};
6+
7+
use crate::DeviceState;
8+
9+
/// Number of query sets / readback buffers, and the maximum number of
/// submissions queued before the oldest one is waited on.
const ITERATIONS_IN_FLIGHT: usize = 5;
/// Number of workgroups launched by each dispatch.
const WORKGROUPS_PER_DISPATCH: u32 = 1024;
/// Invocations per workgroup. Must be kept in sync with `@workgroup_size(64)`
/// in `loop_workaround.wgsl`.
const WORKGROUP_SIZE: u32 = 64;
/// Total invocations per dispatch; also the number of `u32` elements in the
/// storage buffer the shader reads and writes.
const INVOCATIONS_PER_DISPATCH: u32 = WORKGROUP_SIZE * WORKGROUPS_PER_DISPATCH;
12+
13+
struct LoopWorkaroundState {
14+
device_state: DeviceState,
15+
pipeline: ComputePipeline,
16+
bg: wgpu::BindGroup,
17+
query_sets: Vec<wgpu::QuerySet>,
18+
resolve_buffers: Vec<wgpu::Buffer>,
19+
readback_buffers: Vec<wgpu::Buffer>,
20+
}
21+
22+
impl LoopWorkaroundState {
23+
/// Create and prepare all the resources needed for the renderpass benchmark.
24+
fn new() -> Self {
25+
let device_state = DeviceState::new();
26+
27+
let shader_module = unsafe {
28+
device_state.device.create_shader_module_trusted(
29+
wgpu::ShaderModuleDescriptor {
30+
label: Some("loop_workaround.wgsl"),
31+
source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Owned(
32+
std::fs::read_to_string(concat!(
33+
env!("CARGO_MANIFEST_DIR"),
34+
"/benches/loop_workaround.wgsl"
35+
))
36+
.unwrap(),
37+
)),
38+
},
39+
wgpu::ShaderRuntimeChecks {
40+
bounds_checks: true,
41+
force_loop_bounding: true,
42+
},
43+
)
44+
};
45+
46+
let pipeline =
47+
device_state
48+
.device
49+
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
50+
label: Some("Loop Workaround Pipeline"),
51+
layout: None,
52+
module: &shader_module,
53+
entry_point: None,
54+
compilation_options: PipelineCompilationOptions::default(),
55+
cache: None,
56+
});
57+
58+
let bind_group_layout = pipeline.get_bind_group_layout(0);
59+
60+
let buffer = device_state.device.create_buffer(&wgpu::BufferDescriptor {
61+
label: Some("Loop Workaround Buffer"),
62+
size: INVOCATIONS_PER_DISPATCH as u64 * std::mem::size_of::<u32>() as u64,
63+
usage: wgpu::BufferUsages::STORAGE,
64+
mapped_at_creation: false,
65+
});
66+
67+
let bg = device_state
68+
.device
69+
.create_bind_group(&wgpu::BindGroupDescriptor {
70+
label: Some("Loop Workaround Bind Group"),
71+
layout: &bind_group_layout,
72+
entries: &[wgpu::BindGroupEntry {
73+
binding: 0,
74+
resource: buffer.as_entire_binding(),
75+
}],
76+
});
77+
78+
let query_sets = (0..ITERATIONS_IN_FLIGHT)
79+
.map(|_| {
80+
device_state
81+
.device
82+
.create_query_set(&wgpu::QuerySetDescriptor {
83+
label: Some("Loop Workaround Query Set"),
84+
ty: wgpu::QueryType::Timestamp,
85+
count: 2,
86+
})
87+
})
88+
.collect();
89+
90+
let resolve_buffers = (0..ITERATIONS_IN_FLIGHT)
91+
.map(|_| {
92+
device_state.device.create_buffer(&wgpu::BufferDescriptor {
93+
label: Some("Loop Workaround Resolve Buffer"),
94+
size: 2 * std::mem::size_of::<u64>() as u64,
95+
usage: wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::QUERY_RESOLVE,
96+
mapped_at_creation: false,
97+
})
98+
})
99+
.collect();
100+
101+
let readback_buffers = (0..ITERATIONS_IN_FLIGHT)
102+
.map(|_| {
103+
device_state.device.create_buffer(&wgpu::BufferDescriptor {
104+
label: Some("Loop Workaround Readback Buffer"),
105+
size: 2 * std::mem::size_of::<u64>() as u64,
106+
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
107+
mapped_at_creation: false,
108+
})
109+
})
110+
.collect();
111+
112+
Self {
113+
device_state,
114+
pipeline,
115+
bg,
116+
query_sets,
117+
resolve_buffers,
118+
readback_buffers,
119+
}
120+
}
121+
}
122+
123+
fn run_bench(ctx: &mut Criterion) {
124+
let state = LazyLock::new(LoopWorkaroundState::new);
125+
126+
if !std::env::var("NEXTEST").is_ok() {
127+
LazyLock::force(&state);
128+
}
129+
130+
ctx.bench_function("Loop Workaround", |b| {
131+
b.iter_custom(|iters| {
132+
let queue_period = state.device_state.queue.get_timestamp_period() as f64;
133+
let mut in_flight_submissions = VecDeque::new();
134+
135+
let mut total_duration_spent = Duration::ZERO;
136+
137+
for iter in 0..iters {
138+
let iter_in_flight = iter % ITERATIONS_IN_FLIGHT as u64;
139+
140+
let query_set = &state.query_sets[iter_in_flight as usize];
141+
let resolve_buffer = &state.resolve_buffers[iter_in_flight as usize];
142+
let readback_buffer = &state.readback_buffers[iter_in_flight as usize];
143+
144+
let mut encoder = state
145+
.device_state
146+
.device
147+
.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
148+
149+
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
150+
label: None,
151+
timestamp_writes: Some(ComputePassTimestampWrites {
152+
query_set,
153+
beginning_of_pass_write_index: Some(0),
154+
end_of_pass_write_index: Some(1),
155+
}),
156+
});
157+
158+
cpass.set_pipeline(&state.pipeline);
159+
cpass.set_bind_group(0, &state.bg, &[]);
160+
cpass.dispatch_workgroups(WORKGROUPS_PER_DISPATCH, 1, 1);
161+
162+
drop(cpass);
163+
164+
encoder.resolve_query_set(&query_set, 0..2, &resolve_buffer, 0);
165+
166+
encoder.copy_buffer_to_buffer(
167+
&resolve_buffer,
168+
0,
169+
&readback_buffer,
170+
0,
171+
2 * std::mem::size_of::<u64>() as u64,
172+
);
173+
174+
let submission_index = state.device_state.queue.submit(Some(encoder.finish()));
175+
in_flight_submissions.push_back((iter_in_flight, submission_index));
176+
177+
readback_buffer
178+
.slice(..)
179+
.map_async(wgpu::MapMode::Read, |_| {});
180+
181+
let last_iteration = iter as u64 == iters - 1;
182+
let five_iterations_in_flight = in_flight_submissions.len() == ITERATIONS_IN_FLIGHT;
183+
184+
if five_iterations_in_flight || last_iteration {
185+
let iterations_to_purge = if last_iteration {
186+
in_flight_submissions.len()
187+
} else {
188+
1
189+
};
190+
191+
for _ in 0..iterations_to_purge {
192+
let (buffer_idx, submission) = in_flight_submissions.pop_front().unwrap();
193+
194+
state
195+
.device_state
196+
.device
197+
.poll(wgpu::Maintain::WaitForSubmissionIndex(submission));
198+
199+
let readback_buffer = &state.readback_buffers[buffer_idx as usize];
200+
201+
let query_range = readback_buffer.slice(..).get_mapped_range();
202+
let query_data: &[u64] = bytemuck::cast_slice(&*query_range);
203+
204+
let diff = query_data[1] - query_data[0];
205+
let time = diff as f64 * queue_period;
206+
207+
total_duration_spent += Duration::from_secs_f64(time / 1_000_000_000.0);
208+
209+
drop(query_range);
210+
readback_buffer.unmap();
211+
}
212+
}
213+
}
214+
215+
println!(
216+
"{:?}: {} {:?} per",
217+
total_duration_spent,
218+
iters,
219+
total_duration_spent / iters as u32
220+
);
221+
222+
total_duration_spent
223+
});
224+
});
225+
}
226+
227+
// Registers `run_bench` as the `loop_workaround` criterion group.
criterion_group! {
    name = loop_workaround;
    // 20 second measurement window, 10 samples.
    config = Criterion::default()
        .measurement_time(Duration::from_secs(20))
        .sample_size(10);
    targets = run_bench,
}

benches/benches/loop_workaround.wgsl

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// Storage array indexed by the invocation's global x id; each invocation
// reads and writes only its own element.
@group(0) @binding(0) var<storage, read_write> data: array<u32>;

// Runs a fixed 100000-step loop that repeatedly pushes the element through
// `sin`, then stores the result back in place.
@compute @workgroup_size(64)
fn addABunch(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let slot = global_id.x;
    var acc: u32 = data[slot];
    for (var step = 1u; step <= 100000u; step++) {
        acc = u32(sin(f32(acc * 120u)));
    }
    data[slot] = acc;
}

benches/benches/root.rs

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use pollster::block_on;
33

44
mod bind_groups;
55
mod computepass;
6+
mod loop_workaround;
67
mod renderpass;
78
mod resource_creation;
89
mod shader;
@@ -62,6 +63,7 @@ criterion_main!(
6263
bind_groups::bind_groups,
6364
renderpass::renderpass,
6465
computepass::computepass,
66+
loop_workaround::loop_workaround,
6567
resource_creation::resource_creation,
6668
shader::shader
6769
);

wgpu-hal/src/vulkan/command.rs

+12-2
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,12 @@ impl crate::CommandEncoder for super::CommandEncoder {
792792
if let Some(timestamp_writes) = desc.timestamp_writes.as_ref() {
793793
if let Some(index) = timestamp_writes.beginning_of_pass_write_index {
794794
unsafe {
795-
self.write_timestamp(timestamp_writes.query_set, index);
795+
self.device.raw.cmd_write_timestamp(
796+
self.active,
797+
vk::PipelineStageFlags::TOP_OF_PIPE,
798+
timestamp_writes.query_set.raw,
799+
index,
800+
)
796801
}
797802
}
798803
self.end_of_pass_timer_query = timestamp_writes
@@ -1111,7 +1116,12 @@ impl crate::CommandEncoder for super::CommandEncoder {
11111116
if let Some(timestamp_writes) = desc.timestamp_writes.as_ref() {
11121117
if let Some(index) = timestamp_writes.beginning_of_pass_write_index {
11131118
unsafe {
1114-
self.write_timestamp(timestamp_writes.query_set, index);
1119+
self.device.raw.cmd_write_timestamp(
1120+
self.active,
1121+
vk::PipelineStageFlags::TOP_OF_PIPE,
1122+
timestamp_writes.query_set.raw,
1123+
index,
1124+
)
11151125
}
11161126
}
11171127
self.end_of_pass_timer_query = timestamp_writes

0 commit comments

Comments
 (0)