Skip to content

Commit 269865c

Browse files
committed
Clarify behavior of AxisChunksIter/Mut
IMO, it's easier to understand and work with the implementation of these iterators using `partial_chunk_index` and `partial_chunk_dim` than `n_whole_chunks` and `last_dim`.
1 parent 1443df8 commit 269865c

File tree

1 file changed

+44
-49
lines changed

1 file changed

+44
-49
lines changed

src/iterators/mod.rs

Lines changed: 44 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,9 +1182,13 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
11821182
/// See [`.axis_chunks_iter()`](../struct.ArrayBase.html#method.axis_chunks_iter) for more information.
11831183
pub struct AxisChunksIter<'a, A, D> {
11841184
iter: AxisIterCore<A, D>,
1185-
n_whole_chunks: usize,
1186-
/// Dimension of the last (and possibly uneven) chunk
1187-
last_dim: D,
1185+
/// Index of the partial chunk (the chunk smaller than the specified chunk
1186+
/// size due to the axis length not being evenly divisible). If the axis
1187+
/// length is evenly divisible by the chunk size, this index is larger than
1188+
/// the maximum valid index.
1189+
partial_chunk_index: usize,
1190+
/// Dimension of the partial chunk.
1191+
partial_chunk_dim: D,
11881192
life: PhantomData<&'a A>,
11891193
}
11901194

@@ -1193,10 +1197,10 @@ clone_bounds!(
11931197
AxisChunksIter['a, A, D] {
11941198
@copy {
11951199
life,
1196-
n_whole_chunks,
1200+
partial_chunk_index,
11971201
}
11981202
iter,
1199-
last_dim,
1203+
partial_chunk_dim,
12001204
}
12011205
);
12021206

@@ -1233,12 +1237,9 @@ fn chunk_iter_parts<A, D: Dimension>(
12331237
let mut inner_dim = v.dim.clone();
12341238
inner_dim[axis] = size;
12351239

1236-
let mut last_dim = v.dim;
1237-
last_dim[axis] = if chunk_remainder == 0 {
1238-
size
1239-
} else {
1240-
chunk_remainder
1241-
};
1240+
let mut partial_chunk_dim = v.dim;
1241+
partial_chunk_dim[axis] = chunk_remainder;
1242+
let partial_chunk_index = n_whole_chunks;
12421243

12431244
let iter = AxisIterCore {
12441245
index: 0,
@@ -1249,16 +1250,16 @@ fn chunk_iter_parts<A, D: Dimension>(
12491250
ptr: v.ptr,
12501251
};
12511252

1252-
(iter, n_whole_chunks, last_dim)
1253+
(iter, partial_chunk_index, partial_chunk_dim)
12531254
}
12541255

12551256
impl<'a, A, D: Dimension> AxisChunksIter<'a, A, D> {
12561257
pub(crate) fn new(v: ArrayView<'a, A, D>, axis: Axis, size: usize) -> Self {
1257-
let (iter, n_whole_chunks, last_dim) = chunk_iter_parts(v, axis, size);
1258+
let (iter, partial_chunk_index, partial_chunk_dim) = chunk_iter_parts(v, axis, size);
12581259
AxisChunksIter {
12591260
iter,
1260-
n_whole_chunks,
1261-
last_dim,
1261+
partial_chunk_index,
1262+
partial_chunk_dim,
12621263
life: PhantomData,
12631264
}
12641265
}
@@ -1270,30 +1271,24 @@ macro_rules! chunk_iter_impl {
12701271
where
12711272
D: Dimension,
12721273
{
1273-
fn get_subview(
1274-
&self,
1275-
iter_item: Option<*mut A>,
1276-
is_uneven: bool,
1277-
) -> Option<$array<'a, A, D>> {
1278-
iter_item.map(|ptr| {
1279-
if !is_uneven {
1280-
unsafe {
1281-
$array::new_(
1282-
ptr,
1283-
self.iter.inner_dim.clone(),
1284-
self.iter.inner_strides.clone(),
1285-
)
1286-
}
1287-
} else {
1288-
unsafe {
1289-
$array::new_(
1290-
ptr,
1291-
self.last_dim.clone(),
1292-
self.iter.inner_strides.clone(),
1293-
)
1294-
}
1274+
fn get_subview(&self, index: usize, ptr: *mut A) -> $array<'a, A, D> {
1275+
if index != self.partial_chunk_index {
1276+
unsafe {
1277+
$array::new_(
1278+
ptr,
1279+
self.iter.inner_dim.clone(),
1280+
self.iter.inner_strides.clone(),
1281+
)
12951282
}
1296-
})
1283+
} else {
1284+
unsafe {
1285+
$array::new_(
1286+
ptr,
1287+
self.partial_chunk_dim.clone(),
1288+
self.iter.inner_strides.clone(),
1289+
)
1290+
}
1291+
}
12971292
}
12981293
}
12991294

@@ -1304,9 +1299,8 @@ macro_rules! chunk_iter_impl {
13041299
type Item = $array<'a, A, D>;
13051300

13061301
fn next(&mut self) -> Option<Self::Item> {
1307-
let res = self.iter.next();
1308-
let is_uneven = self.iter.index > self.n_whole_chunks;
1309-
self.get_subview(res, is_uneven)
1302+
let index = self.iter.index;
1303+
self.iter.next().map(|ptr| self.get_subview(index, ptr))
13101304
}
13111305

13121306
fn size_hint(&self) -> (usize, Option<usize>) {
@@ -1319,9 +1313,9 @@ macro_rules! chunk_iter_impl {
13191313
D: Dimension,
13201314
{
13211315
fn next_back(&mut self) -> Option<Self::Item> {
1322-
let is_uneven = self.iter.end > self.n_whole_chunks;
1323-
let res = self.iter.next_back();
1324-
self.get_subview(res, is_uneven)
1316+
self.iter
1317+
.next_back()
1318+
.map(|ptr| self.get_subview(self.iter.end, ptr))
13251319
}
13261320
}
13271321

@@ -1342,18 +1336,19 @@ macro_rules! chunk_iter_impl {
13421336
/// for more information.
13431337
pub struct AxisChunksIterMut<'a, A, D> {
13441338
iter: AxisIterCore<A, D>,
1345-
n_whole_chunks: usize,
1346-
last_dim: D,
1339+
partial_chunk_index: usize,
1340+
partial_chunk_dim: D,
13471341
life: PhantomData<&'a mut A>,
13481342
}
13491343

13501344
impl<'a, A, D: Dimension> AxisChunksIterMut<'a, A, D> {
13511345
pub(crate) fn new(v: ArrayViewMut<'a, A, D>, axis: Axis, size: usize) -> Self {
1352-
let (iter, len, last_dim) = chunk_iter_parts(v.into_view(), axis, size);
1346+
let (iter, partial_chunk_index, partial_chunk_dim) =
1347+
chunk_iter_parts(v.into_view(), axis, size);
13531348
AxisChunksIterMut {
13541349
iter,
1355-
n_whole_chunks: len,
1356-
last_dim,
1350+
partial_chunk_index,
1351+
partial_chunk_dim,
13571352
life: PhantomData,
13581353
}
13591354
}

0 commit comments

Comments
 (0)