Skip to content

Commit cae10b9

Browse files
authored
Allow to nest structs of arrays (#41)
1 parent 404a443 commit cae10b9

File tree

12 files changed

+873
-249
lines changed

12 files changed

+873
-249
lines changed

README.md

+35
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,41 @@ for (name, smell, color) in soa_zip!(vec, [name, mut smell, color]) {
138138
}
139139
```
140140

141+
## Nested Struct of Arrays
142+
143+
In order to nest a struct of arrays inside another struct of arrays, one can use the `#[nested_soa]` attribute.
144+
145+
For example, the following code
146+
147+
```rust
148+
#[derive(StructOfArray)]
149+
pub struct Point {
150+
x: f32,
151+
y: f32,
152+
}
153+
#[derive(StructOfArray)]
154+
pub struct Particle {
155+
#[nested_soa]
156+
point: Point,
157+
mass: f32,
158+
}
159+
```
160+
161+
will generate structs that looks like this:
162+
163+
```rust
164+
pub struct PointVec {
165+
x: Vec<f32>,
166+
y: Vec<f32>,
167+
}
168+
pub struct ParticleVec {
169+
point: PointVec, // rather than Vec<Point>
170+
mass: Vec<f32>
171+
}
172+
```
173+
174+
All helper structs will be also nested, for example `PointSlice` will be nested in `ParticleSlice`.
175+
141176
## Documentation
142177

143178
Please see http://lumol.org/soa-derive/soa_derive_example/ for a small

example/lib.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
//! the code is generated by a single file:
33
//!
44
//! ```no_run
5-
//! #[macro_use]
65
//! extern crate soa_derive;
7-
//! # fn main() {
6+
//! # mod particle {
7+
//! #[macro_use]
8+
//! use soa_derive::StructOfArray;
89
//!
910
//! /// A basic Particle type
1011
//! #[derive(Debug, PartialEq, StructOfArray)]
11-
//! #[soa_derive = "Debug, PartialEq"]
12+
//! #[soa_derive(Debug, PartialEq)]
1213
//! pub struct Particle {
1314
//! /// Mass of the particle
1415
//! pub mass: f64,

soa-derive-internal/src/index.rs

+82-39
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,78 @@
11
use proc_macro2::TokenStream;
22
use quote::quote;
33

4-
use crate::input::Input;
4+
use crate::input::{Input, TokenStreamIterator};
55

66
pub fn derive(input: &Input) -> TokenStream {
77
let vec_name = &input.vec_name();
8-
let slice_name = &input.slice_name();
9-
let slice_mut_name = &input.slice_mut_name();
10-
let ref_name = &input.ref_name();
11-
let ref_mut_name = &input.ref_mut_name();
8+
let slice_name = Input::slice_name(&input.name);
9+
let slice_mut_name = Input::slice_mut_name(&input.name);
10+
let ref_name = Input::ref_name(&input.name);
11+
let ref_mut_name = Input::ref_mut_name(&input.name);
1212
let fields_names = input.fields.iter()
1313
.map(|field| field.ident.clone().unwrap())
1414
.collect::<Vec<_>>();
15-
let fields_names_1 = &fields_names;
16-
let fields_names_2 = &fields_names;
15+
16+
let get_unchecked = input.iter_fields().map(
17+
|(field_ident, _, is_nested)| {
18+
if is_nested {
19+
quote! {
20+
#field_ident: self.clone().get_unchecked(slice.#field_ident),
21+
}
22+
}
23+
else {
24+
quote! {
25+
#field_ident: slice.#field_ident.get_unchecked(self.clone()),
26+
}
27+
}
28+
},
29+
).concat();
30+
31+
let get_unchecked_mut = input.iter_fields().map(
32+
|(field_ident, _, is_nested)| {
33+
if is_nested {
34+
quote! {
35+
#field_ident: self.clone().get_unchecked_mut(slice.#field_ident),
36+
}
37+
}
38+
else {
39+
quote! {
40+
#field_ident: slice.#field_ident.get_unchecked_mut(self.clone()),
41+
}
42+
}
43+
},
44+
).concat();
45+
46+
let index = input.iter_fields().map(
47+
|(field_ident, _, is_nested)| {
48+
if is_nested {
49+
quote! {
50+
#field_ident: self.clone().index(slice.#field_ident),
51+
}
52+
}
53+
else {
54+
quote! {
55+
#field_ident: & slice.#field_ident[self.clone()],
56+
}
57+
}
58+
},
59+
).concat();
60+
61+
let index_mut = input.iter_fields().map(
62+
|(field_ident, _, is_nested)| {
63+
if is_nested {
64+
quote! {
65+
#field_ident: self.clone().index_mut(slice.#field_ident),
66+
}
67+
}
68+
else {
69+
quote! {
70+
#field_ident: &mut slice.#field_ident[self.clone()],
71+
}
72+
}
73+
},
74+
).concat();
75+
1776
let first_field_name = &fields_names[0];
1877

1978
quote!{
@@ -32,16 +91,12 @@ pub fn derive(input: &Input) -> TokenStream {
3291

3392
#[inline]
3493
unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput {
35-
#ref_name {
36-
#(#fields_names_1: soa.#fields_names_2.get_unchecked(self),)*
37-
}
94+
self.get_unchecked(soa.as_slice())
3895
}
3996

4097
#[inline]
4198
fn index(self, soa: &'a #vec_name) -> Self::RefOutput {
42-
#ref_name {
43-
#(#fields_names_1: & soa.#fields_names_2[self],)*
44-
}
99+
self.index(soa.as_slice())
45100
}
46101
}
47102

@@ -59,16 +114,12 @@ pub fn derive(input: &Input) -> TokenStream {
59114

60115
#[inline]
61116
unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
62-
#ref_mut_name {
63-
#(#fields_names_1: soa.#fields_names_2.get_unchecked_mut(self),)*
64-
}
117+
self.get_unchecked_mut(soa.as_mut_slice())
65118
}
66119

67120
#[inline]
68121
fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
69-
#ref_mut_name {
70-
#(#fields_names_1: &mut soa.#fields_names_2[self],)*
71-
}
122+
self.index_mut(soa.as_mut_slice())
72123
}
73124
}
74125

@@ -89,16 +140,12 @@ pub fn derive(input: &Input) -> TokenStream {
89140

90141
#[inline]
91142
unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput {
92-
#slice_name {
93-
#(#fields_names_1: soa.#fields_names_2.get_unchecked(self.clone()),)*
94-
}
143+
self.get_unchecked(soa.as_slice())
95144
}
96145

97146
#[inline]
98147
fn index(self, soa: &'a #vec_name) -> Self::RefOutput {
99-
#slice_name {
100-
#(#fields_names_1: & soa.#fields_names_2[self.clone()],)*
101-
}
148+
self.index(soa.as_slice())
102149
}
103150
}
104151

@@ -116,16 +163,12 @@ pub fn derive(input: &Input) -> TokenStream {
116163

117164
#[inline]
118165
unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
119-
#slice_mut_name {
120-
#(#fields_names_1: soa.#fields_names_2.get_unchecked_mut(self.clone()),)*
121-
}
166+
self.get_unchecked_mut(soa.as_mut_slice())
122167
}
123168

124169
#[inline]
125170
fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
126-
#slice_mut_name {
127-
#(#fields_names_1: &mut soa.#fields_names_2[self.clone()],)*
128-
}
171+
self.index_mut(soa.as_mut_slice())
129172
}
130173
}
131174

@@ -354,14 +397,14 @@ pub fn derive(input: &Input) -> TokenStream {
354397
#[inline]
355398
unsafe fn get_unchecked(self, slice: #slice_name<'a>) -> Self::RefOutput {
356399
#ref_name {
357-
#(#fields_names_1: slice.#fields_names_2.get_unchecked(self),)*
400+
#get_unchecked
358401
}
359402
}
360403

361404
#[inline]
362405
fn index(self, slice: #slice_name<'a>) -> Self::RefOutput {
363406
#ref_name {
364-
#(#fields_names_1: & slice.#fields_names_2[self],)*
407+
#index
365408
}
366409
}
367410
}
@@ -381,14 +424,14 @@ pub fn derive(input: &Input) -> TokenStream {
381424
#[inline]
382425
unsafe fn get_unchecked_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput {
383426
#ref_mut_name {
384-
#(#fields_names_1: slice.#fields_names_2.get_unchecked_mut(self),)*
427+
#get_unchecked_mut
385428
}
386429
}
387430

388431
#[inline]
389432
fn index_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput {
390433
#ref_mut_name {
391-
#(#fields_names_1: &mut slice.#fields_names_2[self],)*
434+
#index_mut
392435
}
393436
}
394437
}
@@ -411,14 +454,14 @@ pub fn derive(input: &Input) -> TokenStream {
411454
#[inline]
412455
unsafe fn get_unchecked(self, slice: #slice_name<'a>) -> Self::RefOutput {
413456
#slice_name {
414-
#(#fields_names_1: slice.#fields_names_2.get_unchecked(self.clone()),)*
457+
#get_unchecked
415458
}
416459
}
417460

418461
#[inline]
419462
fn index(self, slice: #slice_name<'a>) -> Self::RefOutput {
420463
#slice_name {
421-
#(#fields_names_1: & slice.#fields_names_2[self.clone()],)*
464+
#index
422465
}
423466
}
424467
}
@@ -438,14 +481,14 @@ pub fn derive(input: &Input) -> TokenStream {
438481
#[inline]
439482
unsafe fn get_unchecked_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput {
440483
#slice_mut_name {
441-
#(#fields_names_1: slice.#fields_names_2.get_unchecked_mut(self.clone()),)*
484+
#get_unchecked_mut
442485
}
443486
}
444487

445488
#[inline]
446489
fn index_mut(self, slice: #slice_mut_name<'a>) -> Self::MutOutput {
447490
#slice_mut_name {
448-
#(#fields_names_1: &mut slice.#fields_names_2[self.clone()],)*
491+
#index_mut
449492
}
450493
}
451494
}

0 commit comments

Comments
 (0)