Skip to content

Commit

Permalink
Experiment with vertex ID passing.
Browse files Browse the repository at this point in the history
  • Loading branch information
Themaister committed Jan 6, 2024
1 parent 3207736 commit 9ecbdae
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 30 deletions.
62 changes: 46 additions & 16 deletions tests/assets/shaders/meshlet_debug.mesh.frag
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,64 @@
#extension GL_EXT_nonuniform_qualifier : require
#extension GL_EXT_fragment_shader_barycentric : require

#if MESHLET_VERTEX_ID
layout(location = 0) pervertexEXT in uint vVertexID[];
layout(location = 1) perprimitiveEXT flat in uint vTransformIndex;
layout(location = 2) perprimitiveEXT flat in uint vDrawID;

struct TexturedAttr
{
uint n;
uint t;
vec2 uv;
};

layout(set = 0, binding = 2, std430) readonly buffer VBOATTR
{
TexturedAttr data[];
} attr;

layout(set = 0, binding = 5, std430) readonly buffer Transforms
{
mat4 data[];
} transforms;
#else
layout(location = 0) in mediump vec3 vNormal;
layout(location = 1) in mediump vec4 vTangent;
layout(location = 2) in vec2 vUV;
layout(location = 3) perprimitiveEXT flat in uint vDrawID;
#endif

layout(location = 0) out vec3 FragColor;

void main()
mediump vec3 decode_rgb10a2(uint v)
{
vec3 dd = fwidth(gl_BaryCoordEXT);
float d = max(max(dd.x, dd.y), dd.z);
float l = min(min(gl_BaryCoordEXT.x, gl_BaryCoordEXT.y), gl_BaryCoordEXT.z);
mediump ivec3 iv;
iv.x = bitfieldExtract(int(v), 0, 10);
iv.y = bitfieldExtract(int(v), 10, 10);
iv.z = bitfieldExtract(int(v), 20, 10);
return vec3(iv) / 511.0;
}

float pixels_from_edge = l / max(d, 0.0001);
float highlight = 1.0 - smoothstep(0.25, 0.75, pixels_from_edge);
void main()
{
#if MESHLET_VERTEX_ID
uint va = vVertexID[0];
uint vb = vVertexID[1];
uint vc = vVertexID[2];
uint na = attr.data[va].n;
uint nb = attr.data[vb].n;
uint nc = attr.data[vc].n;

mediump vec3 normal = gl_BaryCoordEXT.x * decode_rgb10a2(na) +
gl_BaryCoordEXT.y * decode_rgb10a2(nb) +
gl_BaryCoordEXT.z * decode_rgb10a2(nc);
normal = mat3(transforms.data[vTransformIndex]) * normal;
normal = normalize(normal);
#else
vec3 normal = normalize(vNormal);
vec3 tangent = normalize(vTangent.xyz);

FragColor = 0.3 * (0.5 * (normal * tangent * vTangent.w) + 0.5);
FragColor.rg += 0.05 * highlight;
FragColor.rg += vUV * 0.02;
#endif

FragColor = clamp(0.5 * normal + 0.5, vec3(0.0), vec3(1.0));
FragColor = pow(FragColor, vec3(4.0));

//uint hashed = vDrawID ^ (vDrawID * 23423465);
//FragColor.r += 0.1 * float(hashed % 19) / 19.0;
//FragColor.g += 0.1 * float(hashed % 29) / 29.0;
//FragColor.b += 0.1 * float(hashed % 131) / 131.0;
}
29 changes: 18 additions & 11 deletions tests/assets/shaders/meshlet_debug_plain.mesh
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@
#error "Must define MESHLET_SIZE"
#endif

#if MESHLET_SIZE > 32
#if MESHLET_SIZE > 32 && !MESHLET_VERTEX_ID
shared uint shared_attr_index[MESHLET_SIZE];
shared vec4 shared_clip_pos[MESHLET_SIZE];
//#define MESHLET_PRIMITIVE_CULL_SHARED_INDEX shared_primitive
//shared u8vec3 shared_primitive[MESHLET_SIZE];
#endif

layout(max_primitives = MESHLET_SIZE, max_vertices = MESHLET_SIZE, triangles) out;
Expand All @@ -28,10 +25,16 @@ layout(local_size_x = 32, local_size_y_id = 0) in;
#include "meshlet_render_types.h"
#include "meshlet_primitive_cull.h"

#if MESHLET_VERTEX_ID
layout(location = 0) flat out uint vVertexID[];
layout(location = 1) perprimitiveEXT out uint vTransformIndex[];
layout(location = 2) perprimitiveEXT out uint vDrawID[];
#else
layout(location = 0) out mediump vec3 vNormal[];
layout(location = 1) out mediump vec4 vTangent[];
layout(location = 2) out vec2 vUV[];
layout(location = 3) perprimitiveEXT out uint vDrawID[];
#endif

layout(set = 1, binding = 0) uniform UBO
{
Expand Down Expand Up @@ -134,14 +137,16 @@ void main()

if (meshlet_lane_has_active_vert())
{
uint vert_id = meshlet.vertex_offset + linear_index + 32u * base_chunk_index;
uint out_vert_index = meshlet_compacted_vertex_output();
#if MESHLET_SIZE > 32
shared_attr_index[out_vert_index] = meshlet.vertex_offset + linear_index + 32u * base_chunk_index;
shared_clip_pos[out_vert_index] = clip_pos;
#else
gl_MeshVerticesEXT[out_vert_index].gl_Position = clip_pos;
TexturedAttr a = attr.data[meshlet.vertex_offset + linear_index + 32u * base_chunk_index];

#if MESHLET_VERTEX_ID
vVertexID[out_vert_index] = vert_id;
#elif MESHLET_SIZE > 32
shared_attr_index[out_vert_index] = vert_id;
#else
TexturedAttr a = attr.data[vert_id];
mediump vec3 n = unpack_bgr10a2(a.n).xyz;
mediump vec4 t = unpack_bgr10a2(a.t);
vUV[out_vert_index] = a.uv;
Expand All @@ -150,12 +155,11 @@ void main()
#endif
}

#if MESHLET_SIZE > 32
#if MESHLET_SIZE > 32 && !MESHLET_VERTEX_ID
barrier();

if (gl_LocalInvocationIndex < shared_active_vert_count_total)
{
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = shared_clip_pos[gl_LocalInvocationIndex];
TexturedAttr a = attr.data[shared_attr_index[gl_LocalInvocationIndex]];
mediump vec3 n = unpack_bgr10a2(a.n).xyz;
mediump vec4 t = unpack_bgr10a2(a.t);
Expand All @@ -169,6 +173,9 @@ void main()
{
#ifdef MESHLET_PRIMITIVE_CULL_SHARED_INDEX
gl_PrimitiveTriangleIndicesEXT[gl_LocalInvocationIndex] = uvec3(shared_primitive[gl_LocalInvocationIndex]);
#endif
#if MESHLET_VERTEX_ID
vTransformIndex[gl_LocalInvocationIndex] = task.node_offset;
#endif
vDrawID[gl_LocalInvocationIndex] = task.meshlet_index;
}
Expand Down
12 changes: 9 additions & 3 deletions tests/meshlet_viewer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ struct MeshletViewerApplication : Granite::Application, Granite::EventHandler //
bool use_hierarchical;
bool use_preculling;
bool use_occlusion_cull;
bool use_vertex_id;
} ui = {};

void render(CommandBuffer *cmd, const RenderPassInfo &rp, const ImageView *hiz)
Expand Down Expand Up @@ -561,13 +562,15 @@ struct MeshletViewerApplication : Granite::Application, Granite::EventHandler //

ui.supports_wave32 = Util::get_environment_bool("WAVE32", ui.supports_wave32);
ui.use_hierarchical = Util::get_environment_bool("HIER_TASK", ui.use_hierarchical);
ui.use_vertex_id = !use_encoded && Util::get_environment_int("VERTEX_ID", 0) != 0;

bool supports_wg32 = ui.supports_wave32 && ui.target_meshlet_workgroup_size == 32;

if (ui.use_preculling)
{
cmd->set_program("", mesh_path, "assets://shaders/meshlet_debug.mesh.frag",
{ { "MESHLET_SIZE", int(ui.target_meshlet_workgroup_size) } });
{ { "MESHLET_SIZE", int(ui.target_meshlet_workgroup_size) },
{ "MESHLET_VERTEX_ID", int(ui.use_vertex_id) } });
}
else
{
Expand All @@ -577,6 +580,7 @@ struct MeshletViewerApplication : Granite::Application, Granite::EventHandler //
{ "MESHLET_RENDER_TASK_HIERARCHICAL", int(ui.use_hierarchical) },
{ "MESHLET_RENDER_PHASE", render_phase },
{ "MESHLET_PRIMITIVE_CULL_WG32", int(supports_wg32) },
{ "MESHLET_VERTEX_ID", int(ui.use_vertex_id) },
{ "MESHLET_PRIMITIVE_CULL_WAVE32", int(ui.supports_wave32) } });

cmd->set_storage_buffer(0, 6, *aabb_buffer);
Expand Down Expand Up @@ -999,9 +1003,10 @@ struct MeshletViewerApplication : Granite::Application, Granite::EventHandler //
if (start_timestamps[readback_index] && start_timestamps[readback_index]->is_signalled() &&
end_timestamps[readback_index] && end_timestamps[readback_index]->is_signalled())
{
last_frame_time = device.convert_device_timestamp_delta(
auto next_frame_time = device.convert_device_timestamp_delta(
start_timestamps[readback_index]->get_timestamp_ticks(),
end_timestamps[readback_index]->get_timestamp_ticks());
last_frame_time = 0.999 * last_frame_time + 0.001 * next_frame_time;
}

auto encoding = device.get_resource_manager().get_mesh_encoding();
Expand Down Expand Up @@ -1100,12 +1105,13 @@ Application *application_create(int argc, char **argv)
cbs.add("--hier-task", [](Util::CLIParser &parser) { Util::set_environment("HIER_TASK", parser.next_string()); });
cbs.add("--wave32", [](Util::CLIParser &parser) { Util::set_environment("WAVE32", parser.next_string()); });
cbs.add("--precull", [](Util::CLIParser &parser) { Util::set_environment("PRECULL", parser.next_string()); });
cbs.add("--vertex-id", [](Util::CLIParser &parser) { Util::set_environment("VERTEX_ID", parser.next_string()); });
cbs.default_handler = [&](const char *arg) { path = arg; };

Util::CLIParser parser(std::move(cbs), argc - 1, argv + 1);
if (!parser.parse() || parser.is_ended_state() || !path)
{
LOGE("Usage: meshlet-viewer path.msh1\n");
LOGE("Usage: meshlet-viewer path.msh2\n");
return nullptr;
}

Expand Down

0 comments on commit 9ecbdae

Please sign in to comment.