Skip to content

Commit 74db557

Browse files
committed
Remove duplication in from_shape_ptr and split_at
1 parent 95d21b2 commit 74db557

File tree

2 files changed

+71
-78
lines changed

2 files changed

+71
-78
lines changed

src/impl_raw_views.rs

+48-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use dimension;
1+
use dimension::{self, stride_offset};
22
use imp_prelude::*;
33
use {is_aligned, StrideShape};
44

@@ -87,6 +87,33 @@ where
8787
pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> {
8888
ArrayView::new_(self.ptr, self.dim, self.strides)
8989
}
90+
91+
/// Split the array view along `axis` and return one array pointer strictly
92+
/// before the split and one array pointer after the split.
93+
///
94+
/// **Panics** if `axis` or `index` is out of bounds.
95+
pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
96+
assert!(index <= self.len_of(axis));
97+
let left_ptr = self.ptr;
98+
let right_ptr = if index == self.len_of(axis) {
99+
self.ptr
100+
} else {
101+
let offset = stride_offset(index, self.strides.axis(axis));
102+
// The `.offset()` is safe due to the guarantees of `DataRaw`.
103+
unsafe { self.ptr.offset(offset) }
104+
};
105+
106+
let mut dim_left = self.dim.clone();
107+
dim_left.set_axis(axis, index);
108+
let left = unsafe { Self::new_(left_ptr, dim_left, self.strides.clone()) };
109+
110+
let mut dim_right = self.dim;
111+
let right_len = dim_right.axis(axis) - index;
112+
dim_right.set_axis(axis, right_len);
113+
let right = unsafe { Self::new_(right_ptr, dim_right, self.strides) };
114+
115+
(left, right)
116+
}
90117
}
91118

92119
impl<A, D> RawArrayViewMut<A, D>
@@ -155,6 +182,12 @@ where
155182
RawArrayViewMut::new_(ptr, dim, strides)
156183
}
157184

185+
/// Converts to a non-mutable `RawArrayView`.
186+
#[inline]
187+
pub(crate) fn into_raw_view(self) -> RawArrayView<A, D> {
188+
unsafe { RawArrayView::new_(self.ptr, self.dim, self.strides) }
189+
}
190+
158191
/// Return a read-only view of the array
159192
///
160193
/// **Warning** from a safety standpoint, this is equivalent to
@@ -194,4 +227,18 @@ where
194227
pub unsafe fn deref_into_view_mut<'a>(self) -> ArrayViewMut<'a, A, D> {
195228
ArrayViewMut::new_(self.ptr, self.dim, self.strides)
196229
}
230+
231+
/// Split the array view along `axis` and return one array pointer strictly
232+
/// before the split and one array pointer after the split.
233+
///
234+
/// **Panics** if `axis` or `index` is out of bounds.
235+
pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
236+
let (left, right) = self.into_raw_view().split_at(axis, index);
237+
unsafe {
238+
(
239+
Self::new_(left.ptr, left.dim, left.strides),
240+
Self::new_(right.ptr, right.dim, right.strides),
241+
)
242+
}
243+
}
197244
}

src/impl_views.rs

+23-77
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
use std::slice;
1010

1111
use imp_prelude::*;
12-
use dimension::{self, stride_offset};
12+
use dimension;
1313
use error::ShapeError;
1414
use arraytraits::array_out_of_bounds;
1515
use {is_aligned, NdIndex, StrideShape};
@@ -111,15 +111,7 @@ impl<'a, A, D> ArrayView<'a, A, D>
111111
pub unsafe fn from_shape_ptr<Sh>(shape: Sh, ptr: *const A) -> Self
112112
where Sh: Into<StrideShape<D>>
113113
{
114-
let shape = shape.into();
115-
let dim = shape.dim;
116-
let strides = shape.strides;
117-
if cfg!(debug_assertions) {
118-
assert!(!ptr.is_null(), "The pointer must be non-null.");
119-
assert!(is_aligned(ptr), "The pointer must be aligned.");
120-
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
121-
}
122-
ArrayView::new_(ptr, dim, strides)
114+
RawArrayView::from_shape_ptr(shape, ptr).deref_into_view()
123115
}
124116

125117
/// Convert the view into an `ArrayView<'b, A, D>` where `'b` is a lifetime
@@ -141,35 +133,11 @@ impl<'a, A, D> ArrayView<'a, A, D>
141133
/// an array with shape 3 × 5 × 5.
142134
///
143135
/// <img src="https://rust-ndarray.github.io/ndarray/images/split_at.svg" width="300px" height="271px">
144-
pub fn split_at(self, axis: Axis, index: Ix)
145-
-> (Self, Self)
146-
{
147-
// NOTE: Keep this in sync with the ArrayViewMut version
148-
assert!(index <= self.len_of(axis));
149-
let left_ptr = self.ptr;
150-
let right_ptr = if index == self.len_of(axis) {
151-
self.ptr
152-
} else {
153-
let offset = stride_offset(index, self.strides.axis(axis));
154-
unsafe {
155-
self.ptr.offset(offset)
156-
}
157-
};
158-
159-
let mut dim_left = self.dim.clone();
160-
dim_left.set_axis(axis, index);
161-
let left = unsafe {
162-
Self::new_(left_ptr, dim_left, self.strides.clone())
163-
};
164-
165-
let mut dim_right = self.dim;
166-
let right_len = dim_right.axis(axis) - index;
167-
dim_right.set_axis(axis, right_len);
168-
let right = unsafe {
169-
Self::new_(right_ptr, dim_right, self.strides)
170-
};
171-
172-
(left, right)
136+
pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
137+
unsafe {
138+
let (left, right) = self.into_raw_view().split_at(axis, index);
139+
(left.deref_into_view(), right.deref_into_view())
140+
}
173141
}
174142

175143
/// Return the array’s data as a slice, if it is contiguous and in standard order.
@@ -183,6 +151,11 @@ impl<'a, A, D> ArrayView<'a, A, D>
183151
None
184152
}
185153
}
154+
155+
/// Converts to a raw array view.
156+
pub(crate) fn into_raw_view(self) -> RawArrayView<A, D> {
157+
unsafe { RawArrayView::new_(self.ptr, self.dim, self.strides) }
158+
}
186159
}
187160

188161

@@ -408,15 +381,7 @@ impl<'a, A, D> ArrayViewMut<'a, A, D>
408381
pub unsafe fn from_shape_ptr<Sh>(shape: Sh, ptr: *mut A) -> Self
409382
where Sh: Into<StrideShape<D>>
410383
{
411-
let shape = shape.into();
412-
let dim = shape.dim;
413-
let strides = shape.strides;
414-
if cfg!(debug_assertions) {
415-
assert!(!ptr.is_null(), "The pointer must be non-null.");
416-
assert!(is_aligned(ptr), "The pointer must be aligned.");
417-
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
418-
}
419-
ArrayViewMut::new_(ptr, dim, strides)
384+
RawArrayViewMut::from_shape_ptr(shape, ptr).deref_into_view_mut()
420385
}
421386

422387
/// Convert the view into an `ArrayViewMut<'b, A, D>` where `'b` is a lifetime
@@ -433,35 +398,11 @@ impl<'a, A, D> ArrayViewMut<'a, A, D>
433398
/// before the split and one mutable view after the split.
434399
///
435400
/// **Panics** if `axis` or `index` is out of bounds.
436-
pub fn split_at(self, axis: Axis, index: Ix)
437-
-> (Self, Self)
438-
{
439-
// NOTE: Keep this in sync with the ArrayView version
440-
assert!(index <= self.len_of(axis));
441-
let left_ptr = self.ptr;
442-
let right_ptr = if index == self.len_of(axis) {
443-
self.ptr
444-
} else {
445-
let offset = stride_offset(index, self.strides.axis(axis));
446-
unsafe {
447-
self.ptr.offset(offset)
448-
}
449-
};
450-
451-
let mut dim_left = self.dim.clone();
452-
dim_left.set_axis(axis, index);
453-
let left = unsafe {
454-
Self::new_(left_ptr, dim_left, self.strides.clone())
455-
};
456-
457-
let mut dim_right = self.dim;
458-
let right_len = dim_right.axis(axis) - index;
459-
dim_right.set_axis(axis, right_len);
460-
let right = unsafe {
461-
Self::new_(right_ptr, dim_right, self.strides)
462-
};
463-
464-
(left, right)
401+
pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
402+
unsafe {
403+
let (left, right) = self.into_raw_view_mut().split_at(axis, index);
404+
(left.deref_into_view_mut(), right.deref_into_view_mut())
405+
}
465406
}
466407

467408
/// Return the array’s data as a slice, if it is contiguous and in standard order.
@@ -605,6 +546,11 @@ impl<'a, A, D> ArrayViewMut<'a, A, D>
605546
}
606547
}
607548

549+
/// Converts to a mutable raw array view.
550+
pub(crate) fn into_raw_view_mut(self) -> RawArrayViewMut<A, D> {
551+
unsafe { RawArrayViewMut::new_(self.ptr, self.dim, self.strides) }
552+
}
553+
608554
#[inline]
609555
pub(crate) fn into_base_iter(self) -> Baseiter<A, D> {
610556
unsafe {

0 commit comments

Comments
 (0)