Skip to content

Commit 024c197

Browse files
committed
ray query: validation, better test
1 parent 18710fe commit 024c197

File tree

5 files changed

+284
-124
lines changed

5 files changed

+284
-124
lines changed

src/valid/expression.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ pub enum ExpressionError {
3535
InvalidPointerType(Handle<crate::Expression>),
3636
#[error("Array length of {0:?} can't be done")]
3737
InvalidArrayType(Handle<crate::Expression>),
38+
#[error("Get intersection of {0:?} can't be done")]
39+
InvalidRayQueryType(Handle<crate::Expression>),
3840
#[error("Splatting {0:?} can't be done")]
3941
InvalidSplatType(Handle<crate::Expression>),
4042
#[error("Swizzling {0:?} can't be done")]
@@ -1427,7 +1429,26 @@ impl super::Validator {
14271429
return Err(ExpressionError::InvalidArrayType(expr));
14281430
}
14291431
},
1430-
E::RayQueryProceedResult | E::RayQueryGetIntersection { .. } => ShaderStages::all(),
1432+
E::RayQueryProceedResult => ShaderStages::all(),
1433+
E::RayQueryGetIntersection {
1434+
query,
1435+
committed: _,
1436+
} => match resolver[query] {
1437+
Ti::Pointer {
1438+
base,
1439+
space: crate::AddressSpace::Function,
1440+
} => match resolver.types[base].inner {
1441+
Ti::RayQuery => ShaderStages::all(),
1442+
ref other => {
1443+
log::error!("Intersection result of a pointer to {:?}", other);
1444+
return Err(ExpressionError::InvalidRayQueryType(query));
1445+
}
1446+
},
1447+
ref other => {
1448+
log::error!("Intersection result of {:?}", other);
1449+
return Err(ExpressionError::InvalidRayQueryType(query));
1450+
}
1451+
},
14311452
};
14321453
Ok(stages)
14331454
}

src/valid/function.rs

Lines changed: 78 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@ pub enum AtomicError {
4747
InvalidPointer(Handle<crate::Expression>),
4848
#[error("Operand {0:?} has invalid type.")]
4949
InvalidOperand(Handle<crate::Expression>),
50-
#[error("Result expression {0:?} has already been introduced earlier")]
51-
ResultAlreadyInScope(Handle<crate::Expression>),
5250
#[error("Result type for {0:?} doesn't match the statement")]
5351
ResultTypeMismatch(Handle<crate::Expression>),
5452
}
@@ -131,6 +129,14 @@ pub enum FunctionError {
131129
},
132130
#[error("Atomic operation is invalid")]
133131
InvalidAtomic(#[from] AtomicError),
132+
#[error("Ray Query {0:?} is not a local variable")]
133+
InvalidRayQueryExpression(Handle<crate::Expression>),
134+
#[error("Acceleration structure {0:?} is not a matching expression")]
135+
InvalidAccelerationStructure(Handle<crate::Expression>),
136+
#[error("Ray descriptor {0:?} is not a matching expression")]
137+
InvalidRayDescriptor(Handle<crate::Expression>),
138+
#[error("Ray Query {0:?} does not have a matching type")]
139+
InvalidRayQueryType(Handle<crate::Type>),
134140
#[error(
135141
"Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
136142
)]
@@ -169,8 +175,10 @@ struct BlockContext<'a> {
169175
info: &'a FunctionInfo,
170176
expressions: &'a Arena<crate::Expression>,
171177
types: &'a UniqueArena<crate::Type>,
178+
local_vars: &'a Arena<crate::LocalVariable>,
172179
global_vars: &'a Arena<crate::GlobalVariable>,
173180
functions: &'a Arena<crate::Function>,
181+
special_types: &'a crate::SpecialTypes,
174182
prev_infos: &'a [FunctionInfo],
175183
return_type: Option<Handle<crate::Type>>,
176184
}
@@ -188,8 +196,10 @@ impl<'a> BlockContext<'a> {
188196
info,
189197
expressions: &fun.expressions,
190198
types: &module.types,
199+
local_vars: &fun.local_variables,
191200
global_vars: &module.global_variables,
192201
functions: &module.functions,
202+
special_types: &module.special_types,
193203
prev_infos,
194204
return_type: fun.result.as_ref().map(|fr| fr.ty),
195205
}
@@ -299,6 +309,21 @@ impl super::Validator {
299309
Ok(callee_info.available_stages)
300310
}
301311

312+
#[cfg(feature = "validate")]
313+
fn emit_expression(
314+
&mut self,
315+
handle: Handle<crate::Expression>,
316+
context: &BlockContext,
317+
) -> Result<(), WithSpan<FunctionError>> {
318+
if self.valid_expression_set.insert(handle.index()) {
319+
self.valid_expression_list.push(handle);
320+
Ok(())
321+
} else {
322+
Err(FunctionError::ExpressionAlreadyInScope(handle)
323+
.with_span_handle(handle, context.expressions))
324+
}
325+
}
326+
302327
#[cfg(feature = "validate")]
303328
fn validate_atomic(
304329
&mut self,
@@ -347,13 +372,7 @@ impl super::Validator {
347372
}
348373
}
349374

350-
if self.valid_expression_set.insert(result.index()) {
351-
self.valid_expression_list.push(result);
352-
} else {
353-
return Err(AtomicError::ResultAlreadyInScope(result)
354-
.with_span_handle(result, context.expressions)
355-
.into_other());
356-
}
375+
self.emit_expression(result, context)?;
357376
match context.expressions[result] {
358377
crate::Expression::AtomicResult { ty, comparison }
359378
if {
@@ -401,12 +420,7 @@ impl super::Validator {
401420
match *statement {
402421
S::Emit(ref range) => {
403422
for handle in range.clone() {
404-
if self.valid_expression_set.insert(handle.index()) {
405-
self.valid_expression_list.push(handle);
406-
} else {
407-
return Err(FunctionError::ExpressionAlreadyInScope(handle)
408-
.with_span_handle(handle, context.expressions));
409-
}
423+
self.emit_expression(handle, context)?;
410424
}
411425
}
412426
S::Block(ref block) => {
@@ -807,8 +821,55 @@ impl super::Validator {
807821
} => {
808822
self.validate_atomic(pointer, fun, value, result, context)?;
809823
}
810-
S::RayQuery { query: _, fun: _ } => {
811-
//TODO
824+
S::RayQuery { query, ref fun } => {
825+
let query_var = match *context.get_expression(query) {
826+
crate::Expression::LocalVariable(var) => &context.local_vars[var],
827+
ref other => {
828+
log::error!("Unexpected ray query expression {other:?}");
829+
return Err(FunctionError::InvalidRayQueryExpression(query)
830+
.with_span_static(span, "invalid query expression"));
831+
}
832+
};
833+
match context.types[query_var.ty].inner {
834+
Ti::RayQuery => {}
835+
ref other => {
836+
log::error!("Unexpected ray query type {other:?}");
837+
return Err(FunctionError::InvalidRayQueryType(query_var.ty)
838+
.with_span_static(span, "invalid query type"));
839+
}
840+
}
841+
match *fun {
842+
crate::RayQueryFunction::Initialize {
843+
acceleration_structure,
844+
descriptor,
845+
} => {
846+
match *context
847+
.resolve_type(acceleration_structure, &self.valid_expression_set)?
848+
{
849+
Ti::AccelerationStructure => {}
850+
_ => {
851+
return Err(FunctionError::InvalidAccelerationStructure(
852+
acceleration_structure,
853+
)
854+
.with_span_static(span, "invalid acceleration structure"))
855+
}
856+
}
857+
let desc_ty_given =
858+
context.resolve_type(descriptor, &self.valid_expression_set)?;
859+
let desc_ty_expected = context
860+
.special_types
861+
.ray_desc
862+
.map(|handle| &context.types[handle].inner);
863+
if Some(desc_ty_given) != desc_ty_expected {
864+
return Err(FunctionError::InvalidRayDescriptor(descriptor)
865+
.with_span_static(span, "invalid ray descriptor"));
866+
}
867+
}
868+
crate::RayQueryFunction::Proceed { result } => {
869+
self.emit_expression(result, context)?;
870+
}
871+
crate::RayQueryFunction::Terminate => {}
872+
}
812873
}
813874
}
814875
}

tests/in/ray-query.wgsl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,29 @@ struct RayIntersection {
4545

4646
struct Output {
4747
visible: u32,
48+
normal: vec3<f32>,
4849
}
4950

5051
@group(0) @binding(1)
5152
var<storage, read_write> output: Output;
5253

54+
fn get_torus_normal(world_point: vec3<f32>, intersection: RayIntersection) -> vec3<f32> {
55+
let local_point = intersection.world_to_object * vec4<f32>(world_point, 1.0);
56+
let point_on_guiding_line = normalize(local_point.xy) * 2.4;
57+
let world_point_on_guiding_line = intersection.object_to_world * vec4<f32>(point_on_guiding_line, 0.0, 1.0);
58+
return normalize(world_point - world_point_on_guiding_line);
59+
}
60+
5361
@compute @workgroup_size(1)
5462
fn main() {
5563
var rq: ray_query;
5664

57-
rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3<f32>(0.0), vec3<f32>(0.0, 1.0, 0.0)));
65+
let dir = vec3<f32>(0.0, 1.0, 0.0);
66+
rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3<f32>(0.0), dir));
5867

59-
rayQueryProceed(&rq);
68+
while (rayQueryProceed(&rq)) {}
6069

6170
let intersection = rayQueryGetCommittedIntersection(&rq);
6271
output.visible = u32(intersection.kind == RAY_QUERY_INTERSECTION_NONE);
72+
output.normal = get_torus_normal(dir * intersection.t, intersection);
6373
}

tests/out/msl/ray-query.msl

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,8 @@ constexpr metal::uint _map_intersection_type(const metal::raytracing::intersecti
1515

1616
struct Output {
1717
uint visible_;
18-
};
19-
struct RayDesc {
20-
uint flags;
21-
uint cull_mask;
22-
float tmin;
23-
float tmax;
24-
metal::float3 origin;
25-
metal::float3 dir;
18+
char _pad1[12];
19+
metal::float3 normal;
2620
};
2721
struct RayIntersection {
2822
uint kind;
@@ -38,21 +32,48 @@ struct RayIntersection {
3832
metal::float4x3 object_to_world;
3933
metal::float4x3 world_to_object;
4034
};
35+
struct RayDesc {
36+
uint flags;
37+
uint cull_mask;
38+
float tmin;
39+
float tmax;
40+
metal::float3 origin;
41+
metal::float3 dir;
42+
};
43+
44+
metal::float3 get_torus_normal(
45+
metal::float3 world_point,
46+
RayIntersection intersection
47+
) {
48+
metal::float3 local_point = intersection.world_to_object * metal::float4(world_point, 1.0);
49+
metal::float2 point_on_guiding_line = metal::normalize(local_point.xy) * 2.4000000953674316;
50+
metal::float3 world_point_on_guiding_line = intersection.object_to_world * metal::float4(point_on_guiding_line, 0.0, 1.0);
51+
return metal::normalize(world_point - world_point_on_guiding_line);
52+
}
4153

4254
kernel void main_(
4355
metal::raytracing::instance_acceleration_structure acc_struct [[user(fake0)]]
4456
, device Output& output [[user(fake0)]]
4557
) {
4658
_RayQuery rq = {};
47-
RayDesc _e12 = RayDesc {4u, 255u, 0.10000000149011612, 100.0, metal::float3(0.0), metal::float3(0.0, 1.0, 0.0)};
59+
metal::float3 dir = metal::float3(0.0, 1.0, 0.0);
60+
RayDesc _e12 = RayDesc {4u, 255u, 0.10000000149011612, 100.0, metal::float3(0.0), dir};
4861
rq.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle);
4962
rq.intersector.set_opacity_cull_mode((_e12.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (_e12.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none);
5063
rq.intersector.force_opacity((_e12.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e12.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none);
5164
rq.intersector.accept_any_intersection((_e12.flags & 4) != 0);
5265
rq.intersection = rq.intersector.intersect(metal::raytracing::ray(_e12.origin, _e12.dir, _e12.tmin, _e12.tmax), acc_struct, _e12.cull_mask); rq.ready = true;
53-
bool _e13 = rq.ready;
54-
rq.ready = false;
55-
RayIntersection intersection = RayIntersection {_map_intersection_type(rq.intersection.type), rq.intersection.distance, rq.intersection.user_instance_id, rq.intersection.instance_id, {}, rq.intersection.geometry_id, rq.intersection.primitive_id, rq.intersection.triangle_barycentric_coord, rq.intersection.triangle_front_facing, {}, rq.intersection.object_to_world_transform, rq.intersection.world_to_object_transform};
56-
output.visible_ = static_cast<uint>(intersection.kind == 0u);
66+
while(true) {
67+
bool _e13 = rq.ready;
68+
rq.ready = false;
69+
if (_e13) {
70+
} else {
71+
break;
72+
}
73+
}
74+
RayIntersection intersection_1 = RayIntersection {_map_intersection_type(rq.intersection.type), rq.intersection.distance, rq.intersection.user_instance_id, rq.intersection.instance_id, {}, rq.intersection.geometry_id, rq.intersection.primitive_id, rq.intersection.triangle_barycentric_coord, rq.intersection.triangle_front_facing, {}, rq.intersection.object_to_world_transform, rq.intersection.world_to_object_transform};
75+
output.visible_ = static_cast<uint>(intersection_1.kind == 0u);
76+
metal::float3 _e25 = get_torus_normal(dir * intersection_1.t, intersection_1);
77+
output.normal = _e25;
5778
return;
5879
}

0 commit comments

Comments
 (0)