Skip to content

Commit

Permalink
Support not writing top mip for HiZ.
Browse files Browse the repository at this point in the history
It's very unlikely that it'll matter in practice and it saves a lot of
GPU time to not have to write out full-res.
  • Loading branch information
Themaister committed Dec 17, 2024
1 parent 2a95dcb commit 3e5428a
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 37 deletions.
6 changes: 5 additions & 1 deletion assets/shaders/inc/meshlet_render.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ layout(set = MESHLET_RENDER_DESCRIPTOR_SET, binding = MESHLET_RENDER_FRUSTUM_BIN
mat4 view;
vec4 viewport_scale_bias;
ivec2 hiz_resolution;
int hiz_min_lod;
int hiz_max_lod;
} frustum;

Expand Down Expand Up @@ -114,13 +115,16 @@ bool hiz_cull(vec2 view_range_x, vec2 view_range_y, float closest_z)
// We need to sample from a LOD where where there is at most one texel delta
// between lo/hi values.
int max_delta = max(ix.y - ix.x, iy.y - iy.x);
int lod = min(findMSB(max_delta - 1) + 1, frustum.hiz_max_lod);
int lod = clamp(findMSB(max_delta - 1) + 1, frustum.hiz_min_lod, frustum.hiz_max_lod);
ivec2 lod_max_coord = max(frustum.hiz_resolution >> lod, ivec2(1)) - 1;
ix = min(ix >> lod, lod_max_coord.xx);
iy = min(iy >> lod, lod_max_coord.yy);

ivec2 hiz_coord = ivec2(ix.x, iy.x);

// We didn't write the top LOD.
lod -= frustum.hiz_min_lod;

float d = texelFetch(uHiZDepth, hiz_coord, lod).x;
bool nx = ix.y != ix.x;
bool ny = iy.y != iy.x;
Expand Down
55 changes: 42 additions & 13 deletions assets/shaders/post/hiz.comp
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#version 450
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_shuffle : require
#extension GL_KHR_shader_subgroup_quad : require

// A rewrite of SPD to support HiZ correctly and moar wave ops for good measure.

layout(local_size_x = 256) in;

layout(set = 0, binding = 0, r32f) uniform writeonly image2D uImageTop;
#if defined(WRITE_TOP_LEVEL) && WRITE_TOP_LEVEL
layout(set = 0, binding = 0, r32f) coherent writeonly uniform image2D uImageTop;
#endif
layout(set = 0, binding = 1, r32f) coherent uniform image2D uImages[12];
layout(set = 1, binding = 0) uniform sampler2D uTexture;
layout(set = 1, binding = 1) buffer Counter
Expand Down Expand Up @@ -66,13 +67,12 @@ void write_image(ivec2 coord, int mip, float v)
imageStore(uImages[mip - 1], coord, vec4(v));
}

void write_image4_top(ivec2 coord, int mip, vec4 v)
#if defined(WRITE_TOP_LEVEL) && WRITE_TOP_LEVEL
void write_image_top(ivec2 coord, float v)
{
imageStore(uImageTop, coord + ivec2(0, 0), v.xxxx);
imageStore(uImageTop, coord + ivec2(1, 0), v.yyyy);
imageStore(uImageTop, coord + ivec2(0, 1), v.zzzz);
imageStore(uImageTop, coord + ivec2(1, 1), v.wwww);
imageStore(uImageTop, coord, vec4(v, 0, 0, 0));
}
#endif

const int SHARED_WIDTH = 32;
const int SHARED_HEIGHT = 32;
Expand Down Expand Up @@ -120,11 +120,35 @@ float fetch_image_mip6(ivec2 coord)
return imageLoad(uImages[5], coord).x;
}

vec4 write_mip0_transformed(vec4 v, ivec2 base_coord)
vec4 write_mip0_transformed(vec4 v, ivec2 base_coord, ivec2 local_coord)
{
v = transform_z(v);

#if defined(WRITE_TOP_LEVEL) && WRITE_TOP_LEVEL
// Ensure that top-level image is written with full cache lines per warp.
// Writing in the strided 2x2 pattern is noticably bad for L2 performance.
// Taking extra time on the shader cores to reshuffle data is actually beneficial since we're fully bandwidth bound
// in these shaders, so we should give the memory system all the help it can get.
store_shared(2 * local_coord + ivec2(0, 0), v.x);
store_shared(2 * local_coord + ivec2(1, 0), v.y);
store_shared(2 * local_coord + ivec2(0, 1), v.z);
store_shared(2 * local_coord + ivec2(1, 1), v.w);

barrier();

// Write out transformed LOD 0
write_image4_top(base_coord, 0, v);
for (int y = 0; y < 2; y++)
{
for (int x = 0; x < 2; x++)
{
ivec2 tile_offset = ivec2(x, y) * 16;
write_image_top(base_coord + tile_offset + local_coord, load_shared(local_coord + tile_offset));
}
}

barrier();
#endif

return v;
}

Expand Down Expand Up @@ -241,10 +265,14 @@ void main()
// It seems like we need to be super careful about memory access patterns to get optimal bandwidth.

// LOD 0 feedback with transform.
vec4 tile00 = write_mip0_transformed(fetch_2x2_texture(base_coord_00), base_coord_00);
vec4 tile10 = write_mip0_transformed(fetch_2x2_texture(base_coord_10), base_coord_10);
vec4 tile01 = write_mip0_transformed(fetch_2x2_texture(base_coord_01), base_coord_01);
vec4 tile11 = write_mip0_transformed(fetch_2x2_texture(base_coord_11), base_coord_11);
vec4 tile00 = write_mip0_transformed(
fetch_2x2_texture(base_coord_00), ivec2(gl_WorkGroupID.xy * 64u) + ivec2(0, 0), ivec2(local_coord));
vec4 tile10 = write_mip0_transformed(
fetch_2x2_texture(base_coord_10), ivec2(gl_WorkGroupID.xy * 64u) + ivec2(32, 0), ivec2(local_coord));
vec4 tile01 = write_mip0_transformed(
fetch_2x2_texture(base_coord_01), ivec2(gl_WorkGroupID.xy * 64u) + ivec2(0, 32), ivec2(local_coord));
vec4 tile11 = write_mip0_transformed(
fetch_2x2_texture(base_coord_11), ivec2(gl_WorkGroupID.xy * 64u) + ivec2(32, 32), ivec2(local_coord));
if (registers.mips <= 1)
return;

Expand Down Expand Up @@ -275,6 +303,7 @@ void main()
store_shared(local_coord_shared + ivec2(0, 8), reduced01);
store_shared(local_coord_shared + ivec2(8, 8), reduced11);
}

barrier();

// Write LOD 3
Expand Down
4 changes: 2 additions & 2 deletions tests/hiz.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,14 @@ int main()
dev.begin_renderdoc_capture();

auto cmd = dev.request_command_buffer();
cmd->set_program("builtin://shaders/post/hiz.comp");
cmd->set_program("builtin://shaders/post/hiz.comp", {{ "WRITE_TOP_LEVEL", 1 }});
for (unsigned i = 0; i < 13; i++)
cmd->set_storage_texture(0, i, *views[i < push.mips ? i : (push.mips - 1)]);
cmd->set_texture(1, 0, img->get_view(), StockSampler::NearestClamp);
cmd->set_storage_buffer(1, 1, *counter_buffer);
cmd->push_constants(&push, 0, sizeof(push));
cmd->enable_subgroup_size_control(true);
cmd->set_subgroup_size_log2(true, 4, 7);
cmd->set_subgroup_size_log2(true, 2, 7);
cmd->dispatch(wg_x, wg_y, 1);
dev.submit(cmd);

Expand Down
35 changes: 14 additions & 21 deletions tests/meshlet_viewer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ struct MeshletViewerApplication : Granite::Application, Granite::EventHandler //
mat4 view;
vec4 viewport_scale_bias;
uvec2 hiz_resolution;
uint hiz_min_lod;
uint hiz_max_lod;
};

Expand All @@ -463,9 +464,10 @@ struct MeshletViewerApplication : Granite::Application, Granite::EventHandler //

ubo->view = render_context.get_render_parameters().view;
ubo->viewport_scale_bias = viewport_scale_bias;
ubo->hiz_resolution.x = hiz->get_view_width();
ubo->hiz_resolution.y = hiz->get_view_height();
ubo->hiz_max_lod = hiz->get_create_info().levels - 1;
ubo->hiz_resolution.x = hiz->get_view_width() * 2;
ubo->hiz_resolution.y = hiz->get_view_height() * 2;
ubo->hiz_min_lod = 1;
ubo->hiz_max_lod = hiz->get_create_info().levels;
}
};

Expand Down Expand Up @@ -805,24 +807,15 @@ struct MeshletViewerApplication : Granite::Application, Granite::EventHandler //
(depth_view.get_view_width() + 63u) & ~63u,
(depth_view.get_view_height() + 63u) & ~63u,
VK_FORMAT_R32_SFLOAT);
info.width /= 2;
info.height /= 2;
info.usage = VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_SAMPLED_BIT;
info.initial_layout = VK_IMAGE_LAYOUT_UNDEFINED;
info.levels = Util::floor_log2(max(depth_view.get_view_width(), depth_view.get_view_height()));
info.levels = Util::floor_log2(max(depth_view.get_view_width(), depth_view.get_view_height())) - 1;
info.misc |= IMAGE_MISC_CREATE_PER_MIP_LEVEL_VIEWS_BIT;

auto hiz = device.create_image(info);

ImageViewHandle views[13];
for (unsigned i = 0; i < info.levels; i++)
{
ImageViewCreateInfo view = {};
view.base_level = i;
view.levels = 1;
view.image = hiz.get();
view.view_type = VK_IMAGE_VIEW_TYPE_2D;
view.aspect = VK_IMAGE_ASPECT_COLOR_BIT;
views[i] = device.create_image_view(view);
}

struct Push
{
mat2 z_transform;
Expand All @@ -846,22 +839,22 @@ struct MeshletViewerApplication : Granite::Application, Granite::EventHandler //

Push push = {};
push.z_transform = inv_z;
push.resolution = uvec2(info.width, info.height);
push.resolution = uvec2(info.width * 2, info.height * 2);
push.inv_resolution = vec2(1.0f / float(depth_view.get_view_width()), 1.0f / float(depth_view.get_view_height()));
push.mips = info.levels;
push.mips = info.levels + 1;

uint32_t wg_x = (push.resolution.x + 63) / 64;
uint32_t wg_y = (push.resolution.y + 63) / 64;
push.target_counter = wg_x * wg_y;

cmd->set_program("builtin://shaders/post/hiz.comp");
for (unsigned i = 0; i < 13; i++)
cmd->set_storage_texture(0, i, *views[i < push.mips ? i : (push.mips - 1)]);
for (unsigned i = 0; i < 12; i++)
cmd->set_storage_texture_level(0, i + 1, hiz->get_view(), i < info.levels ? i : (info.levels - 1));
cmd->set_texture(1, 0, depth_view, StockSampler::NearestClamp);
cmd->set_storage_buffer(1, 1, *counter);
cmd->push_constants(&push, 0, sizeof(push));
cmd->enable_subgroup_size_control(true);
cmd->set_subgroup_size_log2(true, 4, 7);
cmd->set_subgroup_size_log2(true, 2, 7);

auto start_ts = cmd->write_timestamp(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);

Expand Down

0 comments on commit 3e5428a

Please sign in to comment.