Skip to content

Commit 9cea3c7

Browse files
committed
FEAT: Use Baseiter optimizations in some places where it's possible
1 parent d587b2e commit 9cea3c7

File tree

10 files changed

+58
-68
lines changed

10 files changed

+58
-68
lines changed

src/array_serde.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use alloc::vec::Vec;
1717
use crate::imp_prelude::*;
1818

1919
use super::arraytraits::ARRAY_FORMAT_VERSION;
20-
use super::Iter;
20+
use super::iter::Iter;
2121
use crate::IntoDimension;
2222

2323
/// Verifies that the version of the deserialized array matches the current

src/dimension/mod.rs

+4-31
Original file line numberDiff line numberDiff line change
@@ -728,36 +728,6 @@ where
728728
}
729729
}
730730

731-
/// Move the axis which has the smallest absolute stride and a length
732-
/// greater than one to be the last axis.
733-
pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
734-
where
735-
D: Dimension,
736-
{
737-
debug_assert_eq!(dim.ndim(), strides.ndim());
738-
match dim.ndim() {
739-
0 | 1 => {}
740-
2 => {
741-
if dim[1] <= 1
742-
|| dim[0] > 1 && (strides[0] as isize).abs() < (strides[1] as isize).abs()
743-
{
744-
dim.slice_mut().swap(0, 1);
745-
strides.slice_mut().swap(0, 1);
746-
}
747-
}
748-
n => {
749-
if let Some(min_stride_axis) = (0..n)
750-
.filter(|&ax| dim[ax] > 1)
751-
.min_by_key(|&ax| (strides[ax] as isize).abs())
752-
{
753-
let last = n - 1;
754-
dim.slice_mut().swap(last, min_stride_axis);
755-
strides.slice_mut().swap(last, min_stride_axis);
756-
}
757-
}
758-
}
759-
}
760-
761731
/// Remove axes with length one, except never removing the last axis.
762732
pub(crate) fn squeeze<D>(dim: &mut D, strides: &mut D)
763733
where
@@ -801,14 +771,17 @@ pub(crate) fn sort_axes_to_standard<D>(dim: &mut D, strides: &mut D)
801771
where
802772
D: Dimension,
803773
{
804-
debug_assert!(dim.ndim() > 1);
774+
if dim.ndim() <= 1 {
775+
return;
776+
}
805777
debug_assert_eq!(dim.ndim(), strides.ndim());
806778
// bubble sort axes
807779
let mut changed = true;
808780
while changed {
809781
changed = false;
810782
for i in 0..dim.ndim() - 1 {
811783
// make sure higher stride axes sort before.
784+
debug_assert!(strides.get_stride(Axis(i)) >= 0);
812785
if strides.get_stride(Axis(i)).abs() < strides.get_stride(Axis(i + 1)).abs() {
813786
changed = true;
814787
dim.slice_mut().swap(i, i + 1);

src/impl_methods.rs

+5-9
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::argument_traits::AssignElem;
1919
use crate::dimension;
2020
use crate::dimension::IntoDimension;
2121
use crate::dimension::{
22-
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
22+
abs_index, axes_of, do_slice, merge_axes,
2323
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
2424
};
2525
use crate::dimension::broadcast::co_broadcast;
@@ -316,7 +316,7 @@ where
316316
where
317317
S: Data,
318318
{
319-
IndexedIter::new(self.view().into_elements_base())
319+
IndexedIter::new(self.view().into_elements_base_keep_dims())
320320
}
321321

322322
/// Return an iterator of indexes and mutable references to the elements of the array.
@@ -329,7 +329,7 @@ where
329329
where
330330
S: DataMut,
331331
{
332-
IndexedIterMut::new(self.view_mut().into_elements_base())
332+
IndexedIterMut::new(self.view_mut().into_elements_base_keep_dims())
333333
}
334334

335335
/// Return a sliced view of the array.
@@ -2196,9 +2196,7 @@ where
21962196
if let Some(slc) = self.as_slice_memory_order() {
21972197
slc.iter().fold(init, f)
21982198
} else {
2199-
let mut v = self.view();
2200-
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
2201-
v.into_elements_base().fold(init, f)
2199+
self.view().into_elements_base_any_order().fold(init, f)
22022200
}
22032201
}
22042202

@@ -2312,9 +2310,7 @@ where
23122310
match self.try_as_slice_memory_order_mut() {
23132311
Ok(slc) => slc.iter_mut().for_each(f),
23142312
Err(arr) => {
2315-
let mut v = arr.view_mut();
2316-
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
2317-
v.into_elements_base().for_each(f);
2313+
arr.view_mut().into_elements_base_any_order().for_each(f);
23182314
}
23192315
}
23202316
}

src/impl_views/conversions.rs

+34-11
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ use alloc::slice;
1010

1111
use crate::imp_prelude::*;
1212

13-
use crate::{Baseiter, ElementsBase, ElementsBaseMut, Iter, IterMut};
14-
15-
use crate::iter::{self, AxisIter, AxisIterMut};
13+
use crate::iter::{self, Iter, IterMut, AxisIter, AxisIterMut};
14+
use crate::iterators::base::{Baseiter, ElementsBase, ElementsBaseMut, OrderOption, PreserveOrder,
15+
ArbitraryOrder, NoOptimization};
1616
use crate::math_cell::MathCell;
1717
use crate::IndexLonger;
1818

@@ -140,14 +140,25 @@ impl<'a, A, D> ArrayView<'a, A, D>
140140
where
141141
D: Dimension,
142142
{
143+
/// Create a base iter fromt the view with the given order option
144+
#[inline]
145+
pub(crate) fn into_base_iter<F: OrderOption>(self) -> Baseiter<A, D> {
146+
unsafe { Baseiter::new_with_order::<F>(self.ptr.as_ptr(), self.dim, self.strides) }
147+
}
148+
149+
#[inline]
150+
pub(crate) fn into_elements_base_keep_dims(self) -> ElementsBase<'a, A, D> {
151+
ElementsBase::new::<NoOptimization>(self)
152+
}
153+
143154
#[inline]
144-
pub(crate) fn into_base_iter(self) -> Baseiter<A, D> {
145-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
155+
pub(crate) fn into_elements_base_preserve_order(self) -> ElementsBase<'a, A, D> {
156+
ElementsBase::new::<PreserveOrder>(self)
146157
}
147158

148159
#[inline]
149-
pub(crate) fn into_elements_base(self) -> ElementsBase<'a, A, D> {
150-
ElementsBase::new(self)
160+
pub(crate) fn into_elements_base_any_order(self) -> ElementsBase<'a, A, D> {
161+
ElementsBase::new::<ArbitraryOrder>(self)
151162
}
152163

153164
pub(crate) fn into_iter_(self) -> Iter<'a, A, D> {
@@ -179,16 +190,28 @@ where
179190
unsafe { RawArrayViewMut::new(self.ptr, self.dim, self.strides) }
180191
}
181192

193+
/// Create a base iter fromt the view with the given order option
182194
#[inline]
183-
pub(crate) fn into_base_iter(self) -> Baseiter<A, D> {
184-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
195+
pub(crate) fn into_base_iter<F: OrderOption>(self) -> Baseiter<A, D> {
196+
unsafe { Baseiter::new_with_order::<F>(self.ptr.as_ptr(), self.dim, self.strides) }
185197
}
186198

187199
#[inline]
188-
pub(crate) fn into_elements_base(self) -> ElementsBaseMut<'a, A, D> {
189-
ElementsBaseMut::new(self)
200+
pub(crate) fn into_elements_base_keep_dims(self) -> ElementsBaseMut<'a, A, D> {
201+
ElementsBaseMut::new::<NoOptimization>(self)
190202
}
191203

204+
#[inline]
205+
pub(crate) fn into_elements_base_preserve_order(self) -> ElementsBaseMut<'a, A, D> {
206+
ElementsBaseMut::new::<PreserveOrder>(self)
207+
}
208+
209+
#[inline]
210+
pub(crate) fn into_elements_base_any_order(self) -> ElementsBaseMut<'a, A, D> {
211+
ElementsBaseMut::new::<ArbitraryOrder>(self)
212+
}
213+
214+
192215
/// Return the array’s data as a slice, if it is contiguous and in standard order.
193216
/// Otherwise return self in the Err branch of the result.
194217
pub(crate) fn try_into_slice(self) -> Result<&'a mut [A], Self> {

src/iterators/base.rs

+6-8
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,7 @@ impl<A, D: Dimension> Baseiter<A, D> {
7171

7272
/// Return the iter strides
7373
pub(crate) fn raw_strides(&self) -> D { self.strides.clone() }
74-
}
7574

76-
impl<A, D: Dimension> Baseiter<A, D> {
7775
/// Creating a Baseiter is unsafe because shape and stride parameters need
7876
/// to be correct to avoid performing an unsafe pointer offset while
7977
/// iterating.
@@ -252,9 +250,9 @@ clone_bounds!(
252250
);
253251

254252
impl<'a, A, D: Dimension> ElementsBase<'a, A, D> {
255-
pub fn new(v: ArrayView<'a, A, D>) -> Self {
253+
pub fn new<F: OrderOption>(v: ArrayView<'a, A, D>) -> Self {
256254
ElementsBase {
257-
inner: v.into_base_iter(),
255+
inner: v.into_base_iter::<F>(),
258256
life: PhantomData,
259257
}
260258
}
@@ -338,7 +336,7 @@ where
338336
inner: if let Some(slc) = self_.to_slice() {
339337
ElementsRepr::Slice(slc.iter())
340338
} else {
341-
ElementsRepr::Counted(self_.into_elements_base())
339+
ElementsRepr::Counted(self_.into_elements_base_preserve_order())
342340
},
343341
}
344342
}
@@ -352,7 +350,7 @@ where
352350
IterMut {
353351
inner: match self_.try_into_slice() {
354352
Ok(x) => ElementsRepr::Slice(x.iter_mut()),
355-
Err(self_) => ElementsRepr::Counted(self_.into_elements_base()),
353+
Err(self_) => ElementsRepr::Counted(self_.into_elements_base_preserve_order()),
356354
},
357355
}
358356
}
@@ -397,9 +395,9 @@ pub(crate) struct ElementsBaseMut<'a, A, D> {
397395
}
398396

399397
impl<'a, A, D: Dimension> ElementsBaseMut<'a, A, D> {
400-
pub fn new(v: ArrayViewMut<'a, A, D>) -> Self {
398+
pub fn new<F: OrderOption>(v: ArrayViewMut<'a, A, D>) -> Self {
401399
ElementsBaseMut {
402-
inner: v.into_base_iter(),
400+
inner: v.into_base_iter::<F>(),
403401
life: PhantomData,
404402
}
405403
}

src/iterators/chunks.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ where
7979
type IntoIter = ExactChunksIter<'a, A, D>;
8080
fn into_iter(self) -> Self::IntoIter {
8181
ExactChunksIter {
82-
iter: self.base.into_elements_base(),
82+
iter: self.base.into_elements_base_any_order(),
8383
chunk: self.chunk,
8484
inner_strides: self.inner_strides,
8585
}
@@ -169,7 +169,7 @@ where
169169
type IntoIter = ExactChunksIterMut<'a, A, D>;
170170
fn into_iter(self) -> Self::IntoIter {
171171
ExactChunksIterMut {
172-
iter: self.base.into_elements_base(),
172+
iter: self.base.into_elements_base_any_order(),
173173
chunk: self.chunk,
174174
inner_strides: self.inner_strides,
175175
}

src/iterators/lanes.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::marker::PhantomData;
33
use crate::imp_prelude::*;
44
use crate::{Layout, NdProducer};
55
use crate::iterators::Baseiter;
6+
use crate::iterators::base::NoOptimization;
67

78
impl_ndproducer! {
89
['a, A, D: Dimension]
@@ -83,7 +84,7 @@ where
8384
type IntoIter = LanesIter<'a, A, D>;
8485
fn into_iter(self) -> Self::IntoIter {
8586
LanesIter {
86-
iter: self.base.into_base_iter(),
87+
iter: self.base.into_base_iter::<NoOptimization>(),
8788
inner_len: self.inner_len,
8889
inner_stride: self.inner_stride,
8990
life: PhantomData,
@@ -134,7 +135,7 @@ where
134135
type IntoIter = LanesIterMut<'a, A, D>;
135136
fn into_iter(self) -> Self::IntoIter {
136137
LanesIterMut {
137-
iter: self.base.into_base_iter(),
138+
iter: self.base.into_base_iter::<NoOptimization>(),
138139
inner_len: self.inner_len,
139140
inner_stride: self.inner_stride,
140141
life: PhantomData,

src/iterators/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
mod macros;
1111

1212
mod axis;
13-
mod base;
13+
pub(crate) mod base;
1414
mod chunks;
1515
mod into_iter;
1616
pub mod iter;

src/iterators/windows.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ where
7777
type IntoIter = WindowsIter<'a, A, D>;
7878
fn into_iter(self) -> Self::IntoIter {
7979
WindowsIter {
80-
iter: self.base.into_elements_base(),
80+
iter: self.base.into_elements_base_preserve_order(),
8181
window: self.window,
8282
strides: self.strides,
8383
}

src/lib.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,7 @@ pub use crate::slice::{
150150
MultiSliceArg, NewAxis, Slice, SliceArg, SliceInfo, SliceInfoElem, SliceNextDim,
151151
};
152152

153-
use crate::iterators::Baseiter;
154-
use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut};
153+
use crate::iterators::{ElementsBase, ElementsBaseMut};
155154

156155
pub use crate::arraytraits::AsArray;
157156
#[cfg(feature = "std")]

0 commit comments

Comments
 (0)