Skip to content

Commit 1f28701

Browse files
committed
Feat: OptiX device 2; hardware boogaloo
1 parent f5379d3 commit 1f28701

File tree

8 files changed

+676
-92
lines changed

8 files changed

+676
-92
lines changed

crates/optix/examples/rust/ex04_mesh_gpu/src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,7 @@ pub unsafe fn __raygen__renderFrame() {
113113
RayType::SurfaceRay as u32,
114114
1,
115115
RayType::SurfaceRay as u32,
116-
&mut p0,
117-
&mut p1,
116+
[&mut p0, &mut p1],
118117
);
119118

120119
let fb_index = i.x + i.y * PARAMS.frame.size.x;

crates/optix_device/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ bitflags = "1.3.2"
99
cuda_std = { version = "0.2", path = "../cuda_std" }
1010
glam = { version = "0.20", features=["cuda", "libm"], default-features=false }
1111
paste = "1.0.6"
12+
seq-macro = "0.3.0"

crates/optix_device/src/hit.rs

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
use cuda_std::gpu_only;
2+
use glam::Vec3;
3+
4+
/// The type of primitive that a ray hit.
5+
#[repr(u32)]
6+
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
7+
pub enum HitKind {
8+
/// A custom primitive, the value is the custom
9+
/// 7-bit hit kind set when reporting an intersection.
10+
Custom(u8),
11+
/// B-spline curve of degree 2 with circular cross-section.
12+
RoundQuadraticBSpline,
13+
/// B-spline curve of degree 3 with circular cross-section.
14+
RoundCubicBSpline,
15+
/// Piecewise linear curve with circular cross-section.
16+
RoundLinear,
17+
/// CatmullRom curve with circular cross-section.
18+
RoundCatmullRom,
19+
/// ▲
20+
Triangle,
21+
}
22+
23+
#[repr(u32)]
24+
#[allow(dead_code)]
25+
enum OptixPrimitiveType {
26+
Custom = 0x2500,
27+
RoundQuadraticBSpline = 0x2501,
28+
RoundCubicBSpline = 0x2502,
29+
RoundLinear = 0x2503,
30+
RoundCatmullRom = 0x2504,
31+
Triangle = 0x2531,
32+
}
33+
34+
#[gpu_only]
35+
fn get_primitive_type(val: u8) -> OptixPrimitiveType {
36+
let raw: u32;
37+
unsafe {
38+
asm!("call ({}), _optix_get_primitive_type_from_hit_kind, ({});", out(reg32) raw, in(reg32) val);
39+
core::mem::transmute(raw)
40+
}
41+
}
42+
43+
impl HitKind {
44+
pub(crate) fn from_raw(val: u8) -> Self {
45+
let kind = get_primitive_type(val);
46+
match kind {
47+
OptixPrimitiveType::Custom => HitKind::Custom(val & 0b0111_1111),
48+
OptixPrimitiveType::RoundQuadraticBSpline => HitKind::RoundQuadraticBSpline,
49+
OptixPrimitiveType::RoundCubicBSpline => HitKind::RoundCubicBSpline,
50+
OptixPrimitiveType::RoundLinear => HitKind::RoundLinear,
51+
OptixPrimitiveType::RoundCatmullRom => HitKind::RoundCatmullRom,
52+
OptixPrimitiveType::Triangle => HitKind::Triangle,
53+
}
54+
}
55+
}
56+
57+
#[gpu_only]
58+
fn get_hit_kind() -> u8 {
59+
let x: u8;
60+
unsafe {
61+
asm!("call ({}), _optix_get_hit_kind, ();", out(reg32) x);
62+
}
63+
x
64+
}
65+
66+
/// Returns the kind of primitive that was hit by the ray.
67+
pub fn hit_kind() -> HitKind {
68+
HitKind::from_raw(get_hit_kind())
69+
}
70+
71+
#[gpu_only]
72+
pub fn is_back_face_hit() -> bool {
73+
let hit_kind = get_hit_kind();
74+
let x: u32;
75+
unsafe {
76+
asm!("call ({}), _optix_get_backface_from_hit_kind, ({});", out(reg32) x, in(reg32) hit_kind);
77+
}
78+
x == 1
79+
}
80+
81+
/// Whether the ray hit a front face.
82+
pub fn is_front_face_hit() -> bool {
83+
!is_back_face_hit()
84+
}
85+
86+
/// Whether the ray hit a triangle.
87+
pub fn is_triangle_hit() -> bool {
88+
is_triangle_front_face_hit() || is_triangle_back_face_hit()
89+
}
90+
91+
const OPTIX_HIT_KIND_TRIANGLE_FRONT_FACE: u8 = 0xFE;
92+
const OPTIX_HIT_KIND_TRIANGLE_BACK_FACE: u8 = 0xFF;
93+
94+
/// Whether the ray hit the front face of a triangle.
95+
pub fn is_triangle_front_face_hit() -> bool {
96+
get_hit_kind() == OPTIX_HIT_KIND_TRIANGLE_FRONT_FACE
97+
}
98+
99+
/// Whether the ray hit the back face of a triangle.
100+
pub fn is_triangle_back_face_hit() -> bool {
101+
get_hit_kind() == OPTIX_HIT_KIND_TRIANGLE_BACK_FACE
102+
}
103+
104+
/// Returns the barycentric coordinates of the hit point on the hit triangle.
105+
#[gpu_only]
106+
pub fn triangle_barycentrics() -> Vec3 {
107+
let x: f32;
108+
let y: f32;
109+
unsafe {
110+
asm!(
111+
"call ({}, {}), _optix_get_triangle_barycentrics, ();",
112+
out(reg32) x,
113+
out(reg32) y
114+
);
115+
}
116+
Vec3::new(x, y, 1.0 - x - y)
117+
}

crates/optix_device/src/intersect.rs

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
use cuda_std::gpu_only;
2+
use paste::paste;
3+
use seq_macro::seq;
4+
5+
pub trait IntersectionPayload {
6+
fn report_intersection(hit_t: f32, hit_kind: u8, payload: Self) -> bool;
7+
}
8+
9+
macro_rules! impl_intersection_payload {
10+
($num:tt) => {
11+
paste! {
12+
impl IntersectionPayload for [u32; $num] {
13+
#[gpu_only]
14+
fn report_intersection(
15+
hit_t: f32,
16+
hit_kind: u8,
17+
payload: Self,
18+
) -> bool {
19+
seq!(
20+
P in 0..$num {{
21+
let [#(p~P,)*] = payload;
22+
let out: u32;
23+
unsafe {
24+
asm!(
25+
concat!("call ({}), _optix_report_intersection_", stringify!($num)),
26+
concat!(", ({}, {}, ", #(concat!("{", stringify!(p~P), "},"),)* ");"),
27+
out(reg32) out,
28+
in(reg32) hit_t,
29+
in(reg32) hit_kind,
30+
#(
31+
p~P = in(reg32) p~P,
32+
)*
33+
);
34+
}
35+
out != 0
36+
}}
37+
)
38+
}
39+
}
40+
}
41+
};
42+
() => {
43+
seq! {
44+
N in 0..=7 {
45+
impl_intersection_payload! { N }
46+
}
47+
}
48+
}
49+
}
50+
51+
impl_intersection_payload! {}
52+
53+
/// Reports an intersection and passes custom attributes to further programs.
54+
///
55+
/// If `tmin <= hit_t <= tmax` then the anyhit program associated with this intersection will be invoked,
56+
/// then the program will do one of three things:
57+
/// - Ignore the intersection; no hit is recorded and this function returns `false`.
58+
/// - Terminate the ray; a hit is recorded and this function does not return. No further traversal occurs and the associated
59+
/// closesthit program is invoked.
60+
/// - Neither; A hit is recorded and this function returns `true`.
61+
///
62+
/// **Only the lower 7 bits of the `hit_kind` should be written, the top 127 values are reserved for hardware primitives.**
63+
pub fn report_intersection<P: IntersectionPayload>(hit_t: f32, hit_kind: u8, payload: P) -> bool {
64+
P::report_intersection(hit_t, hit_kind, payload)
65+
}
66+
67+
/// Records the hit, stops traversal, then proceeds to the closesthit program.
68+
#[gpu_only]
69+
pub fn terminate_ray() {
70+
unsafe {
71+
asm!("call _optix_terminate_ray, ();");
72+
}
73+
}
74+
75+
/// Discards the hit and returns control to the calling intersection program or the built-in intersection hardware.
76+
#[gpu_only]
77+
pub fn ignore_intersection() {
78+
unsafe {
79+
asm!("call _optix_ignore_intersection, ();");
80+
}
81+
}
82+
83+
macro_rules! get_attribute_fns {
84+
($num:tt) => {
85+
paste! {
86+
#[gpu_only]
87+
#[allow(clippy::missing_safety_doc)]
88+
unsafe fn [<get_attribute_ $num>]() -> u32 {
89+
let out: u32;
90+
asm!(
91+
concat!("call ({}), _optix_get_attribute_", stringify!($num), ", ();"),
92+
out(reg32) out
93+
);
94+
out
95+
}
96+
}
97+
};
98+
() => {
99+
seq! {
100+
N in 0..=7 {
101+
get_attribute_fns! { N }
102+
}
103+
}
104+
};
105+
}
106+
107+
get_attribute_fns! {}
108+
109+
/// Retrieves an attribute set by the intersection program when reporting an intersection.
110+
///
111+
/// # Safety
112+
///
113+
/// The attribute must have been set by the intersection program.
114+
///
115+
/// # Panics
116+
///
117+
/// Panics if the idx is over `7`.
118+
pub unsafe fn get_attribute(idx: u8) -> u32 {
119+
match idx {
120+
0 => get_attribute_0(),
121+
1 => get_attribute_1(),
122+
2 => get_attribute_2(),
123+
3 => get_attribute_3(),
124+
4 => get_attribute_4(),
125+
5 => get_attribute_5(),
126+
6 => get_attribute_6(),
127+
7 => get_attribute_7(),
128+
_ => panic!("Invalid attribute index"),
129+
}
130+
}

crates/optix_device/src/lib.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77

88
extern crate alloc;
99

10+
mod hit;
11+
mod intersect;
1012
pub mod misc;
1113
pub mod payload;
14+
mod ray;
1215
pub mod sys;
1316
pub mod trace;
1417
pub mod util;
@@ -53,6 +56,48 @@ pub fn get_launch_dimensions() -> UVec3 {
5356
UVec3::new(x, y, z)
5457
}
5558

59+
/// Functions/items only available in raygen programs (`__raygen__`).
5660
pub mod raygen {
61+
#[doc(inline)]
5762
pub use crate::trace::*;
5863
}
64+
65+
/// Functions/items only available in miss programs (`__miss__`).
66+
pub mod intersection {
67+
#[doc(inline)]
68+
pub use crate::intersect::{get_attribute, report_intersection};
69+
#[doc(inline)]
70+
pub use crate::ray::*;
71+
}
72+
73+
/// Functions/items only available in anyhit programs (`__anyhit__`).
74+
pub mod anyhit {
75+
#[doc(inline)]
76+
pub use crate::hit::*;
77+
#[doc(inline)]
78+
pub use crate::intersect::{get_attribute, ignore_intersection, terminate_ray};
79+
#[doc(inline)]
80+
pub use crate::ray::*;
81+
}
82+
83+
/// Functions/items only available in closesthit programs (`__closesthit__`).
84+
pub mod closesthit {
85+
#[doc(inline)]
86+
pub use crate::hit::*;
87+
#[doc(inline)]
88+
pub use crate::intersect::get_attribute;
89+
#[doc(inline)]
90+
pub use crate::ray::{
91+
ray_flags, ray_time, ray_tmax, ray_tmin, ray_visibility_mask, ray_world_direction,
92+
ray_world_origin,
93+
};
94+
}
95+
96+
/// Functions/items only available in miss programs (`__miss__`).
97+
pub mod miss {
98+
#[doc(inline)]
99+
pub use crate::ray::{
100+
ray_flags, ray_time, ray_tmax, ray_tmin, ray_visibility_mask, ray_world_direction,
101+
ray_world_origin,
102+
};
103+
}

0 commit comments

Comments
 (0)