Skip to content

Commit 70acd08

Browse files
committed
Feat: add OptiX (hardware rt) support to toy path tracer
1 parent afcc557 commit 70acd08

File tree

23 files changed

+956
-46
lines changed

23 files changed

+956
-46
lines changed

crates/cuda_builder/src/lib.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ pub struct CudaBuilder {
138138
pub override_libm: bool,
139139
/// Whether to generate any debug info and what level of info to generate.
140140
pub debug: DebugInfo,
141+
/// Additional arguments passed to cargo during `cargo build`.
142+
pub build_args: Vec<String>,
141143
}
142144

143145
impl CudaBuilder {
@@ -158,9 +160,17 @@ impl CudaBuilder {
158160
optix: false,
159161
override_libm: true,
160162
debug: DebugInfo::None,
163+
build_args: vec![],
161164
}
162165
}
163166

167+
/// Additional arguments passed to cargo during `cargo build`.
168+
pub fn build_args(mut self, args: &[impl AsRef<str>]) -> Self {
169+
self.build_args
170+
.extend(args.iter().map(|s| s.as_ref().to_owned()));
171+
self
172+
}
173+
164174
/// Whether to generate any debug info and what level of info to generate.
165175
pub fn debug(mut self, debug: DebugInfo) -> Self {
166176
self.debug = debug;
@@ -433,6 +443,8 @@ fn invoke_rustc(builder: &CudaBuilder) -> Result<PathBuf, CudaBuilderError> {
433443
target,
434444
]);
435445

446+
cargo.args(&builder.build_args);
447+
436448
cargo.env(dylib_path_envvar(), new_path);
437449

438450
if builder.release {

crates/cust/src/memory/device/device_slice.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,15 @@ fn slice_end_index_overflow_fail() -> ! {
464464
panic!("attempted to index slice up to maximum usize");
465465
}
466466

467+
impl<T: DeviceCopy> DeviceSliceIndex<T> for usize {
468+
unsafe fn get_unchecked(self, slice: &DeviceSlice<T>) -> DeviceSlice<T> {
469+
(self..self + 1).get_unchecked(slice)
470+
}
471+
fn index(self, slice: &DeviceSlice<T>) -> DeviceSlice<T> {
472+
slice.index(self..self + 1)
473+
}
474+
}
475+
467476
impl<T: DeviceCopy> DeviceSliceIndex<T> for Range<usize> {
468477
unsafe fn get_unchecked(self, slice: &DeviceSlice<T>) -> DeviceSlice<T> {
469478
DeviceSlice::from_raw_parts(slice.as_device_ptr().add(self.start), self.end - self.start)

crates/optix/examples/rust/ex04_mesh_gpu/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
use cuda_std::kernel;
1010
use optix_device::{
11-
get_launch_index,
11+
closesthit, get_launch_index,
1212
glam::*,
1313
misc::*,
1414
payload,
@@ -65,7 +65,7 @@ fn random_color(i: u32) -> Vec4 {
6565

6666
#[kernel]
6767
pub unsafe fn __closesthit__radiance() {
68-
let prim_id = primitive_index();
68+
let prim_id = closesthit::primitive_index();
6969
let buf = get_color_buf();
7070
*buf = random_color(prim_id);
7171
}

crates/optix/src/acceleration.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,7 @@ pub enum AccelEmitDesc {
877877
/// Used to communicate bounds info to and from OptiX for bounding custom primitives
878878
/// and instances
879879
#[repr(C)]
880-
#[derive(DeviceCopy, Copy, Clone)]
880+
#[derive(Debug, DeviceCopy, Copy, Clone)]
881881
pub struct Aabb {
882882
min: Vector3<f32>,
883883
max: Vector3<f32>,

crates/optix_device/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ cuda_std = { version = "0.2", path = "../cuda_std" }
1010
glam = { version = "0.20", features=["cuda", "libm"], default-features=false }
1111
paste = "1.0.6"
1212
seq-macro = "0.3.0"
13+
cust_core = { version = "0.1", path = "../cust_core" }

crates/optix_device/src/intersect.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@ use cuda_std::gpu_only;
22
use paste::paste;
33
use seq_macro::seq;
44

5+
#[gpu_only]
6+
pub fn primitive_index() -> u32 {
7+
let mut idx: u32;
8+
unsafe {
9+
asm!("call ({}), _optix_read_primitive_idx, ();", out(reg32) idx);
10+
}
11+
idx
12+
}
13+
514
pub trait IntersectionPayload {
615
fn report_intersection(hit_t: f32, hit_kind: u8, payload: Self) -> bool;
716
}
@@ -22,11 +31,15 @@ macro_rules! impl_intersection_payload {
2231
let out: u32;
2332
unsafe {
2433
asm!(
34+
"{{",
35+
".reg .f32 %f0;",
36+
"mov.f32 %f0, {hit_t};",
2537
concat!("call ({}), _optix_report_intersection_", stringify!($num)),
26-
concat!(", ({}, {}, ", #(concat!("{", stringify!(p~P), "},"),)* ");"),
38+
concat!(", (%f0, {}", #(concat!(", {", stringify!(p~P), "}"),)* ");"),
39+
"}}",
2740
out(reg32) out,
28-
in(reg32) hit_t,
2941
in(reg32) hit_kind,
42+
hit_t = in(reg32) hit_t,
3043
#(
3144
p~P = in(reg32) p~P,
3245
)*

crates/optix_device/src/lib.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@ pub mod payload;
1414
mod ray;
1515
pub mod sys;
1616
pub mod trace;
17+
pub mod transform;
1718
pub mod util;
1819

1920
use cuda_std::*;
2021
pub use glam;
2122
use glam::UVec3;
2223

24+
pub use misc::*;
25+
2326
extern "C" {
2427
pub fn vprintf(format: *const u8, valist: *const core::ffi::c_void) -> i32;
2528
}
@@ -65,7 +68,7 @@ pub mod raygen {
6568
/// Functions/items only available in miss programs (`__miss__`).
6669
pub mod intersection {
6770
#[doc(inline)]
68-
pub use crate::intersect::{get_attribute, report_intersection};
71+
pub use crate::intersect::{get_attribute, primitive_index, report_intersection};
6972
#[doc(inline)]
7073
pub use crate::ray::*;
7174
}
@@ -75,7 +78,9 @@ pub mod anyhit {
7578
#[doc(inline)]
7679
pub use crate::hit::*;
7780
#[doc(inline)]
78-
pub use crate::intersect::{get_attribute, ignore_intersection, terminate_ray};
81+
pub use crate::intersect::{
82+
get_attribute, ignore_intersection, primitive_index, terminate_ray,
83+
};
7984
#[doc(inline)]
8085
pub use crate::ray::*;
8186
}
@@ -85,7 +90,7 @@ pub mod closesthit {
8590
#[doc(inline)]
8691
pub use crate::hit::*;
8792
#[doc(inline)]
88-
pub use crate::intersect::get_attribute;
93+
pub use crate::intersect::{get_attribute, primitive_index};
8994
#[doc(inline)]
9095
pub use crate::ray::{
9196
ray_flags, ray_time, ray_tmax, ray_tmin, ray_visibility_mask, ray_world_direction,

crates/optix_device/src/misc.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
use cuda_std::gpu_only;
22

3+
/// Retrieves the data past the SBT header for this particular program
4+
///
5+
/// # Safety
6+
///
7+
/// The type requested must match with what is stored in the SBT.
38
#[gpu_only]
4-
pub fn primitive_index() -> u32 {
5-
let mut idx: u32;
6-
unsafe {
7-
asm!("call ({}), _optix_read_primitive_idx, ();", out(reg32) idx);
8-
}
9-
idx
9+
pub unsafe fn sbt_data<T>() -> &'static T {
10+
let ptr: *const T;
11+
asm!("call ({}), _optix_get_sbt_data_ptr_64, ();", out(reg64) ptr);
12+
&*ptr
1013
}

crates/optix_device/src/ray.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,14 @@ pub fn ray_tmin() -> f32 {
4646
let x: f32;
4747

4848
unsafe {
49-
asm!("call ({}), _optix_get_ray_tmin, ();", out(reg32) x);
49+
asm!(
50+
"{{",
51+
".reg .f32 %f<1>;",
52+
"call (%f0), _optix_get_ray_tmin, ();",
53+
"mov.f32 {}, %f0;",
54+
"}}",
55+
out(reg32) x
56+
);
5057
}
5158

5259
x
@@ -58,7 +65,14 @@ pub fn ray_tmax() -> f32 {
5865
let x: f32;
5966

6067
unsafe {
61-
asm!("call ({}), _optix_get_ray_tmax, ();", out(reg32) x);
68+
asm!(
69+
"{{",
70+
".reg .f32 %f<1>;",
71+
"call (%f0), _optix_get_ray_tmax, ();",
72+
"mov.f32 {}, %f0;",
73+
"}}",
74+
out(reg32) x
75+
);
6276
}
6377

6478
x

crates/optix_device/src/trace.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use crate::sys::*;
2+
use cust_core::DeviceCopy;
23
use glam::Vec3;
34
use paste::paste;
45
use seq_macro::seq;
56

67
/// An opaque handle to a traversable BVH.
78
#[repr(transparent)]
8-
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
9+
#[derive(Clone, Copy, Debug, PartialEq, Eq, DeviceCopy)]
910
pub struct TraversableHandle(pub(crate) u64);
1011

1112
impl TraversableHandle {

0 commit comments

Comments
 (0)