Skip to content

Commit c916203

Browse files
authored
Merge pull request #691 from jturner314/split-chunks
Add .split_at() methods for AxisChunksIter/Mut
2 parents f2fd1dc + 4bee214 commit c916203

File tree

2 files changed

+124
-49
lines changed

2 files changed

+124
-49
lines changed

src/iterators/mod.rs

+83-49
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,19 @@ impl<A, D: Dimension> AxisIterCore<A, D> {
825825
};
826826
(left, right)
827827
}
828+
829+
/// Does the same thing as `.next()` but also returns the index of the item
830+
/// relative to the start of the axis.
831+
fn next_with_index(&mut self) -> Option<(usize, *mut A)> {
832+
let index = self.index;
833+
self.next().map(|ptr| (index, ptr))
834+
}
835+
836+
/// Does the same thing as `.next_back()` but also returns the index of the
837+
/// item relative to the start of the axis.
838+
fn next_back_with_index(&mut self) -> Option<(usize, *mut A)> {
839+
self.next_back().map(|ptr| (self.end, ptr))
840+
}
828841
}
829842

830843
impl<A, D> Iterator for AxisIterCore<A, D>
@@ -1182,9 +1195,13 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
11821195
/// See [`.axis_chunks_iter()`](../struct.ArrayBase.html#method.axis_chunks_iter) for more information.
11831196
pub struct AxisChunksIter<'a, A, D> {
11841197
iter: AxisIterCore<A, D>,
1185-
n_whole_chunks: usize,
1186-
/// Dimension of the last (and possibly uneven) chunk
1187-
last_dim: D,
1198+
/// Index of the partial chunk (the chunk smaller than the specified chunk
1199+
/// size due to the axis length not being evenly divisible). If the axis
1200+
/// length is evenly divisible by the chunk size, this index is larger than
1201+
/// the maximum valid index.
1202+
partial_chunk_index: usize,
1203+
/// Dimension of the partial chunk.
1204+
partial_chunk_dim: D,
11881205
life: PhantomData<&'a A>,
11891206
}
11901207

@@ -1193,10 +1210,10 @@ clone_bounds!(
11931210
AxisChunksIter['a, A, D] {
11941211
@copy {
11951212
life,
1196-
n_whole_chunks,
1213+
partial_chunk_index,
11971214
}
11981215
iter,
1199-
last_dim,
1216+
partial_chunk_dim,
12001217
}
12011218
);
12021219

@@ -1233,12 +1250,9 @@ fn chunk_iter_parts<A, D: Dimension>(
12331250
let mut inner_dim = v.dim.clone();
12341251
inner_dim[axis] = size;
12351252

1236-
let mut last_dim = v.dim;
1237-
last_dim[axis] = if chunk_remainder == 0 {
1238-
size
1239-
} else {
1240-
chunk_remainder
1241-
};
1253+
let mut partial_chunk_dim = v.dim;
1254+
partial_chunk_dim[axis] = chunk_remainder;
1255+
let partial_chunk_index = n_whole_chunks;
12421256

12431257
let iter = AxisIterCore {
12441258
index: 0,
@@ -1249,16 +1263,16 @@ fn chunk_iter_parts<A, D: Dimension>(
12491263
ptr: v.ptr,
12501264
};
12511265

1252-
(iter, n_whole_chunks, last_dim)
1266+
(iter, partial_chunk_index, partial_chunk_dim)
12531267
}
12541268

12551269
impl<'a, A, D: Dimension> AxisChunksIter<'a, A, D> {
12561270
pub(crate) fn new(v: ArrayView<'a, A, D>, axis: Axis, size: usize) -> Self {
1257-
let (iter, n_whole_chunks, last_dim) = chunk_iter_parts(v, axis, size);
1271+
let (iter, partial_chunk_index, partial_chunk_dim) = chunk_iter_parts(v, axis, size);
12581272
AxisChunksIter {
12591273
iter,
1260-
n_whole_chunks,
1261-
last_dim,
1274+
partial_chunk_index,
1275+
partial_chunk_dim,
12621276
life: PhantomData,
12631277
}
12641278
}
@@ -1270,30 +1284,49 @@ macro_rules! chunk_iter_impl {
12701284
where
12711285
D: Dimension,
12721286
{
1273-
fn get_subview(
1274-
&self,
1275-
iter_item: Option<*mut A>,
1276-
is_uneven: bool,
1277-
) -> Option<$array<'a, A, D>> {
1278-
iter_item.map(|ptr| {
1279-
if !is_uneven {
1280-
unsafe {
1281-
$array::new_(
1282-
ptr,
1283-
self.iter.inner_dim.clone(),
1284-
self.iter.inner_strides.clone(),
1285-
)
1286-
}
1287-
} else {
1288-
unsafe {
1289-
$array::new_(
1290-
ptr,
1291-
self.last_dim.clone(),
1292-
self.iter.inner_strides.clone(),
1293-
)
1294-
}
1287+
fn get_subview(&self, index: usize, ptr: *mut A) -> $array<'a, A, D> {
1288+
if index != self.partial_chunk_index {
1289+
unsafe {
1290+
$array::new_(
1291+
ptr,
1292+
self.iter.inner_dim.clone(),
1293+
self.iter.inner_strides.clone(),
1294+
)
1295+
}
1296+
} else {
1297+
unsafe {
1298+
$array::new_(
1299+
ptr,
1300+
self.partial_chunk_dim.clone(),
1301+
self.iter.inner_strides.clone(),
1302+
)
12951303
}
1296-
})
1304+
}
1305+
}
1306+
1307+
/// Splits the iterator at index, yielding two disjoint iterators.
1308+
///
1309+
/// `index` is relative to the current state of the iterator (which is not
1310+
/// necessarily the start of the axis).
1311+
///
1312+
/// **Panics** if `index` is strictly greater than the iterator's remaining
1313+
/// length.
1314+
pub fn split_at(self, index: usize) -> (Self, Self) {
1315+
let (left, right) = self.iter.split_at(index);
1316+
(
1317+
Self {
1318+
iter: left,
1319+
partial_chunk_index: self.partial_chunk_index,
1320+
partial_chunk_dim: self.partial_chunk_dim.clone(),
1321+
life: self.life,
1322+
},
1323+
Self {
1324+
iter: right,
1325+
partial_chunk_index: self.partial_chunk_index,
1326+
partial_chunk_dim: self.partial_chunk_dim,
1327+
life: self.life,
1328+
},
1329+
)
12971330
}
12981331
}
12991332

@@ -1304,9 +1337,9 @@ macro_rules! chunk_iter_impl {
13041337
type Item = $array<'a, A, D>;
13051338

13061339
fn next(&mut self) -> Option<Self::Item> {
1307-
let res = self.iter.next();
1308-
let is_uneven = self.iter.index > self.n_whole_chunks;
1309-
self.get_subview(res, is_uneven)
1340+
self.iter
1341+
.next_with_index()
1342+
.map(|(index, ptr)| self.get_subview(index, ptr))
13101343
}
13111344

13121345
fn size_hint(&self) -> (usize, Option<usize>) {
@@ -1319,9 +1352,9 @@ macro_rules! chunk_iter_impl {
13191352
D: Dimension,
13201353
{
13211354
fn next_back(&mut self) -> Option<Self::Item> {
1322-
let is_uneven = self.iter.end > self.n_whole_chunks;
1323-
let res = self.iter.next_back();
1324-
self.get_subview(res, is_uneven)
1355+
self.iter
1356+
.next_back_with_index()
1357+
.map(|(index, ptr)| self.get_subview(index, ptr))
13251358
}
13261359
}
13271360

@@ -1342,18 +1375,19 @@ macro_rules! chunk_iter_impl {
13421375
/// for more information.
13431376
pub struct AxisChunksIterMut<'a, A, D> {
13441377
iter: AxisIterCore<A, D>,
1345-
n_whole_chunks: usize,
1346-
last_dim: D,
1378+
partial_chunk_index: usize,
1379+
partial_chunk_dim: D,
13471380
life: PhantomData<&'a mut A>,
13481381
}
13491382

13501383
impl<'a, A, D: Dimension> AxisChunksIterMut<'a, A, D> {
13511384
pub(crate) fn new(v: ArrayViewMut<'a, A, D>, axis: Axis, size: usize) -> Self {
1352-
let (iter, len, last_dim) = chunk_iter_parts(v.into_view(), axis, size);
1385+
let (iter, partial_chunk_index, partial_chunk_dim) =
1386+
chunk_iter_parts(v.into_view(), axis, size);
13531387
AxisChunksIterMut {
13541388
iter,
1355-
n_whole_chunks: len,
1356-
last_dim,
1389+
partial_chunk_index,
1390+
partial_chunk_dim,
13571391
life: PhantomData,
13581392
}
13591393
}

tests/iterators.rs

+41
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,20 @@ use itertools::assert_equal;
1313
use itertools::{enumerate, rev};
1414
use std::iter::FromIterator;
1515

16+
macro_rules! assert_panics {
17+
($body:expr) => {
18+
if let Ok(v) = ::std::panic::catch_unwind(|| $body) {
19+
panic!("assertion failed: should_panic; \
20+
non-panicking result: {:?}", v);
21+
}
22+
};
23+
($body:expr, $($arg:tt)*) => {
24+
if let Ok(_) = ::std::panic::catch_unwind(|| $body) {
25+
panic!($($arg)*);
26+
}
27+
};
28+
}
29+
1630
#[test]
1731
fn double_ended() {
1832
let a = ArcArray::linspace(0., 7., 8);
@@ -585,6 +599,33 @@ fn axis_chunks_iter_zero_axis_len() {
585599
assert!(a.axis_chunks_iter(Axis(0), 5).next().is_none());
586600
}
587601

602+
#[test]
603+
fn axis_chunks_iter_split_at() {
604+
let mut a = Array2::<usize>::zeros((11, 3));
605+
a.iter_mut().enumerate().for_each(|(i, elt)| *elt = i);
606+
for source in &[
607+
a.slice(s![..0, ..]),
608+
a.slice(s![..1, ..]),
609+
a.slice(s![..5, ..]),
610+
a.slice(s![..10, ..]),
611+
a.slice(s![..11, ..]),
612+
a.slice(s![.., ..0]),
613+
] {
614+
let chunks_iter = source.axis_chunks_iter(Axis(0), 5);
615+
let all_chunks: Vec<_> = chunks_iter.clone().collect();
616+
let n_chunks = chunks_iter.len();
617+
assert_eq!(n_chunks, all_chunks.len());
618+
for index in 0..=n_chunks {
619+
let (left, right) = chunks_iter.clone().split_at(index);
620+
assert_eq!(&all_chunks[..index], &left.collect::<Vec<_>>()[..]);
621+
assert_eq!(&all_chunks[index..], &right.collect::<Vec<_>>()[..]);
622+
}
623+
assert_panics!({
624+
chunks_iter.split_at(n_chunks + 1);
625+
});
626+
}
627+
}
628+
588629
#[test]
589630
fn axis_chunks_iter_mut() {
590631
let a = ArcArray::from_iter(0..24);

0 commit comments

Comments
 (0)