Skip to content

Commit 021676e

Browse files
authored
Merge pull request #119 from kngwyu/slice_to_pyarray
Fix to_pyarray for not-contiguous array
2 parents 312f00a + 6e8425d commit 021676e

File tree

5 files changed

+149
-23
lines changed

5 files changed

+149
-23
lines changed

src/array.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ impl<T, D> PyArray<T, D> {
296296
}
297297

298298
/// Returns the pointer to the first element of the inner array.
299-
unsafe fn data(&self) -> *mut T {
299+
pub(crate) unsafe fn data(&self) -> *mut T {
300300
let ptr = self.as_array_ptr();
301301
(*ptr).data as *mut T
302302
}
@@ -355,7 +355,7 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
355355
pub(crate) unsafe fn new_<'py, ID>(
356356
py: Python<'py>,
357357
dims: ID,
358-
strides: *mut npy_intp,
358+
strides: *const npy_intp,
359359
flag: c_int,
360360
) -> &'py Self
361361
where
@@ -367,7 +367,7 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
367367
dims.ndim_cint(),
368368
dims.as_dims_ptr(),
369369
T::typenum_default(),
370-
strides, // strides
370+
strides as *mut _, // strides
371371
ptr::null_mut(), // data
372372
0, // itemsize
373373
flag, // flag

src/convert.rs

Lines changed: 111 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ where
5858
type Item = A;
5959
type Dim = D;
6060
fn into_pyarray<'py>(self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
61-
let strides = npy_strides(&self);
61+
let strides = self.npy_strides();
6262
let dim = self.raw_dim();
6363
let boxed = self.into_raw_vec().into_boxed_slice();
6464
unsafe { PyArray::from_boxed_slice(py, dim, strides.as_ptr(), boxed) }
@@ -71,12 +71,27 @@ where
7171
/// elements there**.
7272
/// # Example
7373
/// ```
74-
/// # fn main() {
7574
/// use numpy::{PyArray, ToPyArray};
7675
/// let gil = pyo3::Python::acquire_gil();
7776
/// let py_array = vec![1, 2, 3].to_pyarray(gil.python());
7877
/// assert_eq!(py_array.as_slice().unwrap(), &[1, 2, 3]);
79-
/// # }
78+
/// ```
79+
///
80+
/// This method converts a not-contiguous array to C-order contiguous array.
81+
/// # Example
82+
/// ```
83+
/// use numpy::{PyArray, ToPyArray};
84+
/// use ndarray::{arr3, s};
85+
/// let gil = pyo3::Python::acquire_gil();
86+
/// let py = gil.python();
87+
/// let a = arr3(&[[[ 1, 2, 3], [ 4, 5, 6]],
88+
/// [[ 7, 8, 9], [10, 11, 12]]]);
89+
/// let slice = a.slice(s![.., 0..1, ..]);
90+
/// let sliced = arr3(&[[[ 1, 2, 3]],
91+
/// [[ 7, 8, 9]]]);
92+
/// let py_slice = slice.to_pyarray(py);
93+
/// assert_eq!(py_slice.as_array(), sliced);
94+
/// pyo3::py_run!(py, py_slice, "assert py_slice.flags['C_CONTIGUOUS']");
8095
/// ```
8196
pub trait ToPyArray {
8297
type Item: TypeNum;
@@ -102,26 +117,107 @@ where
102117
type Dim = D;
103118
fn to_pyarray<'py>(&self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
104119
let len = self.len();
105-
let mut strides = npy_strides(self);
106-
unsafe {
107-
let array = PyArray::new_(py, self.raw_dim(), strides.as_mut_ptr() as *mut npy_intp, 0);
108-
array.copy_ptr(self.as_ptr(), len);
109-
array
120+
if let Some(order) = self.order() {
121+
// if the array is contiguous, copy it by `copy_ptr`.
122+
let strides = self.npy_strides();
123+
unsafe {
124+
let array = PyArray::new_(py, self.raw_dim(), strides.as_ptr(), order.to_flag());
125+
array.copy_ptr(self.as_ptr(), len);
126+
array
127+
}
128+
} else {
129+
// if the array is not contiguous, copy all elements by `ArrayBase::iter`.
130+
let dim = self.raw_dim();
131+
let strides = NpyStrides::from_dim(&dim, mem::size_of::<A>());
132+
unsafe {
133+
let array = PyArray::<A, _>::new_(py, dim, strides.as_ptr(), 0);
134+
let data_ptr = array.data();
135+
for (i, item) in self.iter().enumerate() {
136+
data_ptr.offset(i as isize).write(*item);
137+
}
138+
array
139+
}
110140
}
111141
}
112142
}
113143

114-
fn npy_strides<S, D, A>(array: &ArrayBase<S, D>) -> Vec<npyffi::npy_intp>
144+
enum Order {
145+
Standard,
146+
Fortran,
147+
}
148+
149+
impl Order {
150+
fn to_flag(&self) -> c_int {
151+
match self {
152+
Order::Standard => 0,
153+
Order::Fortran => 1,
154+
}
155+
}
156+
}
157+
158+
trait ArrayExt {
159+
fn npy_strides(&self) -> NpyStrides;
160+
fn order(&self) -> Option<Order>;
161+
}
162+
163+
impl<A, S, D> ArrayExt for ArrayBase<S, D>
115164
where
116165
S: Data<Elem = A>,
117166
D: Dimension,
118-
A: TypeNum,
119167
{
120-
array
121-
.strides()
122-
.into_iter()
123-
.map(|n| n * mem::size_of::<A>() as npyffi::npy_intp)
124-
.collect()
168+
fn npy_strides(&self) -> NpyStrides {
169+
NpyStrides::new(
170+
self.strides().into_iter().map(|&x| x as npyffi::npy_intp),
171+
mem::size_of::<A>(),
172+
)
173+
}
174+
175+
fn order(&self) -> Option<Order> {
176+
if self.is_standard_layout() {
177+
Some(Order::Standard)
178+
} else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() {
179+
Some(Order::Fortran)
180+
} else {
181+
None
182+
}
183+
}
184+
}
185+
186+
/// Numpy strides with short array optimization
187+
enum NpyStrides {
188+
Short([npyffi::npy_intp; 8]),
189+
Long(Vec<npyffi::npy_intp>),
190+
}
191+
192+
impl NpyStrides {
193+
fn as_ptr(&self) -> *const npy_intp {
194+
match self {
195+
NpyStrides::Short(inner) => inner.as_ptr(),
196+
NpyStrides::Long(inner) => inner.as_ptr(),
197+
}
198+
}
199+
fn from_dim<D: Dimension>(dim: &D, type_size: usize) -> Self {
200+
Self::new(
201+
dim.default_strides()
202+
.slice()
203+
.into_iter()
204+
.map(|&x| x as npyffi::npy_intp),
205+
type_size,
206+
)
207+
}
208+
fn new(strides: impl ExactSizeIterator<Item = npyffi::npy_intp>, type_size: usize) -> Self {
209+
let len = strides.len();
210+
let type_size = type_size as npyffi::npy_intp;
211+
if len <= 8 {
212+
let mut res = [0; 8];
213+
for (i, s) in strides.enumerate() {
214+
res[i] = s * type_size;
215+
}
216+
NpyStrides::Short(res)
217+
} else {
218+
NpyStrides::Long(strides.map(|n| n as npyffi::npy_intp * type_size).collect())
219+
}
220+
}
125221
}
126222

127223
/// Utility trait to specify the dimention of array

src/error.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
//! Defines error types.
2-
3-
use crate::array::PyArray;
4-
use crate::convert::ToNpyDims;
5-
use crate::types::{NpyDataType, TypeNum};
2+
use crate::types::NpyDataType;
63
use pyo3::{exceptions as exc, PyErr, PyResult, Python};
74
use std::error;
85
use std::fmt;

src/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ impl NpyDataType {
6868
}
6969
}
7070

71-
pub trait TypeNum {
71+
pub trait TypeNum: std::fmt::Debug + Copy {
7272
fn is_same_type(other: i32) -> bool;
7373
fn npy_data_type() -> NpyDataType;
7474
fn typenum_default() -> i32;

tests/to_py.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,36 @@ fn into_pyarray_cant_resize() {
108108
let arr = a.into_pyarray(gil.python());
109109
assert!(arr.resize(100).is_err())
110110
}
111+
112+
#[test]
113+
fn forder_to_pyarray() {
114+
let gil = pyo3::Python::acquire_gil();
115+
let py = gil.python();
116+
let matrix = Array2::from_shape_vec([4, 2], vec![0, 1, 2, 3, 4, 5, 6, 7]).unwrap();
117+
let fortran_matrix = matrix.reversed_axes();
118+
let fmat_py = fortran_matrix.to_pyarray(py);
119+
assert_eq!(fmat_py.as_array(), array![[0, 2, 4, 6], [1, 3, 5, 7]],);
120+
pyo3::py_run!(py, fmat_py, "assert fmat_py.flags['F_CONTIGUOUS']")
121+
}
122+
123+
#[test]
124+
fn slice_to_pyarray() {
125+
let gil = pyo3::Python::acquire_gil();
126+
let py = gil.python();
127+
let matrix = Array2::from_shape_vec([4, 2], vec![0, 1, 2, 3, 4, 5, 6, 7]).unwrap();
128+
let slice = matrix.slice(s![1..4; -1, ..]);
129+
let slice_py = slice.to_pyarray(py);
130+
assert_eq!(slice_py.as_array(), array![[6, 7], [4, 5], [2, 3]],);
131+
pyo3::py_run!(py, slice_py, "assert slice_py.flags['C_CONTIGUOUS']")
132+
}
133+
134+
#[test]
135+
fn forder_into_pyarray() {
136+
let gil = pyo3::Python::acquire_gil();
137+
let py = gil.python();
138+
let matrix = Array2::from_shape_vec([4, 2], vec![0, 1, 2, 3, 4, 5, 6, 7]).unwrap();
139+
let fortran_matrix = matrix.reversed_axes();
140+
let fmat_py = fortran_matrix.into_pyarray(py);
141+
assert_eq!(fmat_py.as_array(), array![[0, 2, 4, 6], [1, 3, 5, 7]],);
142+
pyo3::py_run!(py, fmat_py, "assert fmat_py.flags['F_CONTIGUOUS']")
143+
}

0 commit comments

Comments
 (0)