@@ -825,6 +825,19 @@ impl<A, D: Dimension> AxisIterCore<A, D> {
825
825
} ;
826
826
( left, right)
827
827
}
828
+
829
+ /// Does the same thing as `.next()` but also returns the index of the item
830
+ /// relative to the start of the axis.
831
+ fn next_with_index ( & mut self ) -> Option < ( usize , * mut A ) > {
832
+ let index = self . index ;
833
+ self . next ( ) . map ( |ptr| ( index, ptr) )
834
+ }
835
+
836
+ /// Does the same thing as `.next_back()` but also returns the index of the
837
+ /// item relative to the start of the axis.
838
+ fn next_back_with_index ( & mut self ) -> Option < ( usize , * mut A ) > {
839
+ self . next_back ( ) . map ( |ptr| ( self . end , ptr) )
840
+ }
828
841
}
829
842
830
843
impl < A , D > Iterator for AxisIterCore < A , D >
@@ -1182,9 +1195,13 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
1182
1195
/// See [`.axis_chunks_iter()`](../struct.ArrayBase.html#method.axis_chunks_iter) for more information.
1183
1196
pub struct AxisChunksIter < ' a , A , D > {
1184
1197
iter : AxisIterCore < A , D > ,
1185
- n_whole_chunks : usize ,
1186
- /// Dimension of the last (and possibly uneven) chunk
1187
- last_dim : D ,
1198
+ /// Index of the partial chunk (the chunk smaller than the specified chunk
1199
+ /// size due to the axis length not being evenly divisible). If the axis
1200
+ /// length is evenly divisible by the chunk size, this index is larger than
1201
+ /// the maximum valid index.
1202
+ partial_chunk_index : usize ,
1203
+ /// Dimension of the partial chunk.
1204
+ partial_chunk_dim : D ,
1188
1205
life : PhantomData < & ' a A > ,
1189
1206
}
1190
1207
@@ -1193,10 +1210,10 @@ clone_bounds!(
1193
1210
AxisChunksIter [ ' a, A , D ] {
1194
1211
@copy {
1195
1212
life,
1196
- n_whole_chunks ,
1213
+ partial_chunk_index ,
1197
1214
}
1198
1215
iter,
1199
- last_dim ,
1216
+ partial_chunk_dim ,
1200
1217
}
1201
1218
) ;
1202
1219
@@ -1233,12 +1250,9 @@ fn chunk_iter_parts<A, D: Dimension>(
1233
1250
let mut inner_dim = v. dim . clone ( ) ;
1234
1251
inner_dim[ axis] = size;
1235
1252
1236
- let mut last_dim = v. dim ;
1237
- last_dim[ axis] = if chunk_remainder == 0 {
1238
- size
1239
- } else {
1240
- chunk_remainder
1241
- } ;
1253
+ let mut partial_chunk_dim = v. dim ;
1254
+ partial_chunk_dim[ axis] = chunk_remainder;
1255
+ let partial_chunk_index = n_whole_chunks;
1242
1256
1243
1257
let iter = AxisIterCore {
1244
1258
index : 0 ,
@@ -1249,16 +1263,16 @@ fn chunk_iter_parts<A, D: Dimension>(
1249
1263
ptr : v. ptr ,
1250
1264
} ;
1251
1265
1252
- ( iter, n_whole_chunks , last_dim )
1266
+ ( iter, partial_chunk_index , partial_chunk_dim )
1253
1267
}
1254
1268
1255
1269
impl < ' a , A , D : Dimension > AxisChunksIter < ' a , A , D > {
1256
1270
pub ( crate ) fn new ( v : ArrayView < ' a , A , D > , axis : Axis , size : usize ) -> Self {
1257
- let ( iter, n_whole_chunks , last_dim ) = chunk_iter_parts ( v, axis, size) ;
1271
+ let ( iter, partial_chunk_index , partial_chunk_dim ) = chunk_iter_parts ( v, axis, size) ;
1258
1272
AxisChunksIter {
1259
1273
iter,
1260
- n_whole_chunks ,
1261
- last_dim ,
1274
+ partial_chunk_index ,
1275
+ partial_chunk_dim ,
1262
1276
life : PhantomData ,
1263
1277
}
1264
1278
}
@@ -1270,30 +1284,49 @@ macro_rules! chunk_iter_impl {
1270
1284
where
1271
1285
D : Dimension ,
1272
1286
{
1273
- fn get_subview(
1274
- & self ,
1275
- iter_item: Option <* mut A >,
1276
- is_uneven: bool ,
1277
- ) -> Option <$array<' a, A , D >> {
1278
- iter_item. map( |ptr| {
1279
- if !is_uneven {
1280
- unsafe {
1281
- $array:: new_(
1282
- ptr,
1283
- self . iter. inner_dim. clone( ) ,
1284
- self . iter. inner_strides. clone( ) ,
1285
- )
1286
- }
1287
- } else {
1288
- unsafe {
1289
- $array:: new_(
1290
- ptr,
1291
- self . last_dim. clone( ) ,
1292
- self . iter. inner_strides. clone( ) ,
1293
- )
1294
- }
1287
+ fn get_subview( & self , index: usize , ptr: * mut A ) -> $array<' a, A , D > {
1288
+ if index != self . partial_chunk_index {
1289
+ unsafe {
1290
+ $array:: new_(
1291
+ ptr,
1292
+ self . iter. inner_dim. clone( ) ,
1293
+ self . iter. inner_strides. clone( ) ,
1294
+ )
1295
+ }
1296
+ } else {
1297
+ unsafe {
1298
+ $array:: new_(
1299
+ ptr,
1300
+ self . partial_chunk_dim. clone( ) ,
1301
+ self . iter. inner_strides. clone( ) ,
1302
+ )
1295
1303
}
1296
- } )
1304
+ }
1305
+ }
1306
+
1307
+ /// Splits the iterator at index, yielding two disjoint iterators.
1308
+ ///
1309
+ /// `index` is relative to the current state of the iterator (which is not
1310
+ /// necessarily the start of the axis).
1311
+ ///
1312
+ /// **Panics** if `index` is strictly greater than the iterator's remaining
1313
+ /// length.
1314
+ pub fn split_at( self , index: usize ) -> ( Self , Self ) {
1315
+ let ( left, right) = self . iter. split_at( index) ;
1316
+ (
1317
+ Self {
1318
+ iter: left,
1319
+ partial_chunk_index: self . partial_chunk_index,
1320
+ partial_chunk_dim: self . partial_chunk_dim. clone( ) ,
1321
+ life: self . life,
1322
+ } ,
1323
+ Self {
1324
+ iter: right,
1325
+ partial_chunk_index: self . partial_chunk_index,
1326
+ partial_chunk_dim: self . partial_chunk_dim,
1327
+ life: self . life,
1328
+ } ,
1329
+ )
1297
1330
}
1298
1331
}
1299
1332
@@ -1304,9 +1337,9 @@ macro_rules! chunk_iter_impl {
1304
1337
type Item = $array<' a, A , D >;
1305
1338
1306
1339
fn next( & mut self ) -> Option <Self :: Item > {
1307
- let res = self . iter. next ( ) ;
1308
- let is_uneven = self . iter . index > self . n_whole_chunks ;
1309
- self . get_subview( res , is_uneven )
1340
+ self . iter
1341
+ . next_with_index ( )
1342
+ . map ( | ( index , ptr ) | self . get_subview( index , ptr ) )
1310
1343
}
1311
1344
1312
1345
fn size_hint( & self ) -> ( usize , Option <usize >) {
@@ -1319,9 +1352,9 @@ macro_rules! chunk_iter_impl {
1319
1352
D : Dimension ,
1320
1353
{
1321
1354
fn next_back( & mut self ) -> Option <Self :: Item > {
1322
- let is_uneven = self . iter. end > self . n_whole_chunks ;
1323
- let res = self . iter . next_back ( ) ;
1324
- self . get_subview( res , is_uneven )
1355
+ self . iter
1356
+ . next_back_with_index ( )
1357
+ . map ( | ( index , ptr ) | self . get_subview( index , ptr ) )
1325
1358
}
1326
1359
}
1327
1360
@@ -1342,18 +1375,19 @@ macro_rules! chunk_iter_impl {
1342
1375
/// for more information.
1343
1376
pub struct AxisChunksIterMut < ' a , A , D > {
1344
1377
iter : AxisIterCore < A , D > ,
1345
- n_whole_chunks : usize ,
1346
- last_dim : D ,
1378
+ partial_chunk_index : usize ,
1379
+ partial_chunk_dim : D ,
1347
1380
life : PhantomData < & ' a mut A > ,
1348
1381
}
1349
1382
1350
1383
impl < ' a , A , D : Dimension > AxisChunksIterMut < ' a , A , D > {
1351
1384
pub ( crate ) fn new ( v : ArrayViewMut < ' a , A , D > , axis : Axis , size : usize ) -> Self {
1352
- let ( iter, len, last_dim) = chunk_iter_parts ( v. into_view ( ) , axis, size) ;
1385
+ let ( iter, partial_chunk_index, partial_chunk_dim) =
1386
+ chunk_iter_parts ( v. into_view ( ) , axis, size) ;
1353
1387
AxisChunksIterMut {
1354
1388
iter,
1355
- n_whole_chunks : len ,
1356
- last_dim ,
1389
+ partial_chunk_index ,
1390
+ partial_chunk_dim ,
1357
1391
life : PhantomData ,
1358
1392
}
1359
1393
}
0 commit comments