Skip to content

Commit 532d1ad

Browse files
committed
msl: ray query support
1 parent 22e341b commit 532d1ad

13 files changed

+295
-30
lines changed

src/back/mod.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,26 @@ impl crate::Statement {
218218
}
219219
}
220220
}
221+
222+
bitflags::bitflags! {
223+
/// Ray flags.
224+
#[derive(Default)]
225+
pub struct RayFlag: u32 {
226+
const OPAQUE = 0x01;
227+
const NO_OPAQUE = 0x02;
228+
const TERMINATE_ON_FIRST_HIT = 0x04;
229+
const SKIP_CLOSEST_HIT_SHADER = 0x08;
230+
const CULL_FRONT_FACING = 0x10;
231+
const CULL_BACK_FACING = 0x20;
232+
const CULL_OPAQUE = 0x40;
233+
const CULL_NO_OPAQUE = 0x80;
234+
const SKIP_TRIANGLES = 0x100;
235+
const SKIP_AABBS = 0x200;
236+
}
237+
}
238+
239+
#[repr(u32)]
240+
enum RayIntersectionType {
241+
Triangle = 1,
242+
BoundingBox = 4,
243+
}

src/back/msl/mod.rs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -314,10 +314,7 @@ impl Options {
314314
match slot {
315315
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
316316
buffer: Some(slot),
317-
texture: None,
318-
sampler: None,
319-
binding_array_size: None,
320-
mutable: false,
317+
..Default::default()
321318
})),
322319
None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
323320
prefix: "fake",
@@ -338,10 +335,7 @@ impl Options {
338335
match slot {
339336
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
340337
buffer: Some(slot),
341-
texture: None,
342-
sampler: None,
343-
binding_array_size: None,
344-
mutable: false,
338+
..Default::default()
345339
})),
346340
None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
347341
prefix: "fake",

src/back/msl/writer.rs

Lines changed: 186 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ const WRAPPED_ARRAY_FIELD: &str = "inner";
2525
// Some more general handling of pointers is needed to be implemented here.
2626
const ATOMIC_REFERENCE: &str = "&";
2727

28+
const RT_NAMESPACE: &str = "metal::raytracing";
29+
const RAY_QUERY_TYPE: &str = "_RayQuery";
30+
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
31+
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
32+
const RAY_QUERY_FIELD_READY: &str = "ready";
33+
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";
34+
2835
/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
2936
///
3037
/// The `sizes` slice determines whether this function writes a
@@ -194,8 +201,11 @@ impl<'a> Display for TypeContext<'a> {
194201
crate::TypeInner::Sampler { comparison: _ } => {
195202
write!(out, "{NAMESPACE}::sampler")
196203
}
197-
crate::TypeInner::AccelerationStructure | crate::TypeInner::RayQuery => {
198-
unreachable!("Ray queries are not supported yet");
204+
crate::TypeInner::AccelerationStructure => {
205+
write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
206+
}
207+
crate::TypeInner::RayQuery => {
208+
write!(out, "{RAY_QUERY_TYPE}")
199209
}
200210
crate::TypeInner::BindingArray { base, size } => {
201211
let base_tyname = Self {
@@ -1865,8 +1875,39 @@ impl<W: Write> Writer<W> {
18651875
write!(self.out, ")")?;
18661876
}
18671877
}
1868-
// hot supported yet
1869-
crate::Expression::RayQueryGetIntersection { .. } => unreachable!(),
1878+
crate::Expression::RayQueryGetIntersection { query, committed } => {
1879+
if !committed {
1880+
unimplemented!()
1881+
}
1882+
let ty = context.module.special_types.ray_intersection.unwrap();
1883+
let type_name = &self.names[&NameKey::Type(ty)];
1884+
write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
1885+
self.put_expression(query, context, true)?;
1886+
write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
1887+
let fields = [
1888+
"distance",
1889+
"user_instance_id",
1890+
"instance_id",
1891+
"", // SBT offset
1892+
"geometry_id",
1893+
"primitive_id",
1894+
"triangle_barycentric_coord",
1895+
"triangle_front_facing",
1896+
"", // padding
1897+
"object_to_world_transform",
1898+
"world_to_object_transform",
1899+
];
1900+
for field in fields {
1901+
write!(self.out, ", ")?;
1902+
if field.is_empty() {
1903+
write!(self.out, "{{}}")?;
1904+
} else {
1905+
self.put_expression(query, context, true)?;
1906+
write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
1907+
}
1908+
}
1909+
write!(self.out, "}}")?;
1910+
}
18701911
}
18711912
Ok(())
18721913
}
@@ -2320,13 +2361,24 @@ impl<W: Write> Writer<W> {
23202361
) {
23212362
use crate::Expression;
23222363
self.need_bake_expressions.clear();
2364+
23232365
for (expr_handle, expr) in func.expressions.iter() {
23242366
// Expressions whose reference count is above the
23252367
// threshold should always be stored in temporaries.
23262368
let expr_info = &info[expr_handle];
23272369
let min_ref_count = func.expressions[expr_handle].bake_ref_count();
23282370
if min_ref_count <= expr_info.ref_count {
23292371
self.need_bake_expressions.insert(expr_handle);
2372+
} else {
2373+
match expr_info.ty {
2374+
// force ray desc to be baked: it's used multiple times internally
2375+
TypeResolution::Handle(h)
2376+
if Some(h) == context.module.special_types.ray_desc =>
2377+
{
2378+
self.need_bake_expressions.insert(expr_handle);
2379+
}
2380+
_ => {}
2381+
}
23302382
}
23312383

23322384
if let Expression::Math { fun, arg, arg1, .. } = *expr {
@@ -2338,11 +2390,11 @@ impl<W: Write> Writer<W> {
23382390
// times, once for each component (see `put_dot_product`), so to
23392391
// avoid duplicated evaluation, we must bake integer operands.
23402392

2341-
use crate::TypeInner;
23422393
// check what kind of product this is depending
23432394
// on the resolve type of the Dot function itself
2344-
let inner = context.resolve_type(expr_handle);
2345-
if let TypeInner::Scalar { kind, .. } = *inner {
2395+
if let crate::TypeInner::Scalar { kind, .. } =
2396+
*context.resolve_type(expr_handle)
2397+
{
23462398
match kind {
23472399
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
23482400
self.need_bake_expressions.insert(arg);
@@ -2763,7 +2815,100 @@ impl<W: Write> Writer<W> {
27632815
// done
27642816
writeln!(self.out, ";")?;
27652817
}
2766-
crate::Statement::RayQuery { .. } => unreachable!(),
2818+
crate::Statement::RayQuery { query, ref fun } => {
2819+
match *fun {
2820+
crate::RayQueryFunction::Initialize {
2821+
acceleration_structure,
2822+
descriptor,
2823+
} => {
2824+
//TODO: how to deal with winding?
2825+
write!(self.out, "{level}")?;
2826+
self.put_expression(query, &context.expression, true)?;
2827+
writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
2828+
{
2829+
let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
2830+
let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
2831+
write!(self.out, "{level}")?;
2832+
self.put_expression(query, &context.expression, true)?;
2833+
write!(
2834+
self.out,
2835+
".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
2836+
)?;
2837+
self.put_expression(descriptor, &context.expression, true)?;
2838+
write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
2839+
self.put_expression(descriptor, &context.expression, true)?;
2840+
write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
2841+
writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
2842+
}
2843+
{
2844+
let f_opaque = back::RayFlag::OPAQUE.bits();
2845+
let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
2846+
write!(self.out, "{level}")?;
2847+
self.put_expression(query, &context.expression, true)?;
2848+
write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
2849+
self.put_expression(descriptor, &context.expression, true)?;
2850+
write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
2851+
self.put_expression(descriptor, &context.expression, true)?;
2852+
write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
2853+
writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
2854+
}
2855+
{
2856+
let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
2857+
write!(self.out, "{level}")?;
2858+
self.put_expression(query, &context.expression, true)?;
2859+
write!(
2860+
self.out,
2861+
".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
2862+
)?;
2863+
self.put_expression(descriptor, &context.expression, true)?;
2864+
writeln!(self.out, ".flags & {flag}) != 0);")?;
2865+
}
2866+
2867+
write!(self.out, "{level}")?;
2868+
self.put_expression(query, &context.expression, true)?;
2869+
write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
2870+
self.put_expression(query, &context.expression, true)?;
2871+
write!(
2872+
self.out,
2873+
".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
2874+
)?;
2875+
self.put_expression(descriptor, &context.expression, true)?;
2876+
write!(self.out, ".origin, ")?;
2877+
self.put_expression(descriptor, &context.expression, true)?;
2878+
write!(self.out, ".dir, ")?;
2879+
self.put_expression(descriptor, &context.expression, true)?;
2880+
write!(self.out, ".tmin, ")?;
2881+
self.put_expression(descriptor, &context.expression, true)?;
2882+
write!(self.out, ".tmax), ")?;
2883+
self.put_expression(acceleration_structure, &context.expression, true)?;
2884+
write!(self.out, ", ")?;
2885+
self.put_expression(descriptor, &context.expression, true)?;
2886+
write!(self.out, ".cull_mask);")?;
2887+
2888+
write!(self.out, "{level}")?;
2889+
self.put_expression(query, &context.expression, true)?;
2890+
writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
2891+
}
2892+
crate::RayQueryFunction::Proceed { result } => {
2893+
write!(self.out, "{level}")?;
2894+
let name = format!("{}{}", back::BAKE_PREFIX, result.index());
2895+
self.start_baking_expression(result, &context.expression, &name)?;
2896+
self.named_expressions.insert(result, name);
2897+
self.put_expression(query, &context.expression, true)?;
2898+
writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
2899+
//TODO: actually proceed?
2900+
2901+
write!(self.out, "{level}")?;
2902+
self.put_expression(query, &context.expression, true)?;
2903+
writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
2904+
}
2905+
crate::RayQueryFunction::Terminate => {
2906+
write!(self.out, "{level}")?;
2907+
self.put_expression(query, &context.expression, true)?;
2908+
writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.abort();")?;
2909+
}
2910+
}
2911+
}
27672912
}
27682913
}
27692914

@@ -2875,14 +3020,41 @@ impl<W: Write> Writer<W> {
28753020
writeln!(self.out)?;
28763021
// Work around Metal bug where `uint` is not available by default
28773022
writeln!(self.out, "using {NAMESPACE}::uint;")?;
2878-
writeln!(self.out)?;
28793023

3024+
if module.types.iter().any(|(_, t)| match t.inner {
3025+
crate::TypeInner::RayQuery => true,
3026+
_ => false,
3027+
}) {
3028+
let tab = back::INDENT;
3029+
writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
3030+
let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
3031+
writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
3032+
writeln!(
3033+
self.out,
3034+
"{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
3035+
)?;
3036+
writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
3037+
writeln!(self.out, "}};")?;
3038+
writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
3039+
let v_triangle = back::RayIntersectionType::Triangle as u32;
3040+
let v_bbox = back::RayIntersectionType::BoundingBox as u32;
3041+
writeln!(
3042+
self.out,
3043+
"{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
3044+
)?;
3045+
writeln!(
3046+
self.out,
3047+
"{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
3048+
)?;
3049+
writeln!(self.out, "}}")?;
3050+
}
28803051
if options
28813052
.bounds_check_policies
28823053
.contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
28833054
{
28843055
self.put_default_constructible()?;
28853056
}
3057+
writeln!(self.out)?;
28863058

28873059
{
28883060
let mut indices = vec![];
@@ -2924,11 +3096,12 @@ impl<W: Write> Writer<W> {
29243096
///
29253097
/// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
29263098
fn put_default_constructible(&mut self) -> BackendResult {
3099+
let tab = back::INDENT;
29273100
writeln!(self.out, "struct DefaultConstructible {{")?;
2928-
writeln!(self.out, " template<typename T>")?;
2929-
writeln!(self.out, " operator T() && {{")?;
2930-
writeln!(self.out, " return T {{}};")?;
2931-
writeln!(self.out, " }}")?;
3101+
writeln!(self.out, "{tab}template<typename T>")?;
3102+
writeln!(self.out, "{tab}operator T() && {{")?;
3103+
writeln!(self.out, "{tab}{tab}return T {{}};")?;
3104+
writeln!(self.out, "{tab}}}")?;
29323105
writeln!(self.out, "}};")?;
29333106
Ok(())
29343107
}

tests/in/ray-query.param.ron

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,12 @@
33
spv: (
44
version: (1, 4),
55
),
6+
msl: (
7+
lang_version: (2, 4),
8+
spirv_cross_compatibility: false,
9+
fake_missing_bindings: true,
10+
zero_initialize_workgroup_memory: false,
11+
per_entry_point_map: {},
12+
inline_samplers: [],
13+
),
614
)

tests/in/ray-query.wgsl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,17 @@
22
var acc_struct: acceleration_structure;
33

44
/*
5-
let RAY_FLAG_NONE = 0u;
6-
let RAY_FLAG_TERMINATE_ON_FIRST_HIT = 4u;
5+
let RAY_FLAG_NONE = 0x00u;
6+
let RAY_FLAG_OPAQUE = 0x01u;
7+
let RAY_FLAG_NO_OPAQUE = 0x02u;
8+
let RAY_FLAG_TERMINATE_ON_FIRST_HIT = 0x04u;
9+
let RAY_FLAG_SKIP_CLOSEST_HIT_SHADER = 0x08u;
10+
let RAY_FLAG_CULL_FRONT_FACING = 0x10u;
11+
let RAY_FLAG_CULL_BACK_FACING = 0x20u;
12+
let RAY_FLAG_CULL_OPAQUE = 0x40u;
13+
let RAY_FLAG_CULL_NO_OPAQUE = 0x80u;
14+
let RAY_FLAG_SKIP_TRIANGLES = 0x100u;
15+
let RAY_FLAG_SKIP_AABBS = 0x200u;
716

817
let RAY_QUERY_INTERSECTION_NONE = 0u;
918
let RAY_QUERY_INTERSECTION_TRIANGLE = 1u;

tests/out/msl/binding-arrays.msl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
#include <simd/simd.h>
44

55
using metal::uint;
6-
76
struct DefaultConstructible {
87
template<typename T>
98
operator T() && {
109
return T {};
1110
}
1211
};
12+
1313
struct UniformIndex {
1414
uint index;
1515
};

tests/out/msl/bounds-check-image-rzsw.msl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
#include <simd/simd.h>
44

55
using metal::uint;
6-
76
struct DefaultConstructible {
87
template<typename T>
98
operator T() && {
109
return T {};
1110
}
1211
};
12+
1313
constant metal::int2 const_type_4_ = {0, 0};
1414
constant metal::int3 const_type_7_ = {0, 0, 0};
1515
constant metal::float4 const_type_2_ = {0.0, 0.0, 0.0, 0.0};

0 commit comments

Comments
 (0)