Skip to content

Commit 2edf1f3

Browse files
authored
Merge pull request #143 from kngwyu/element
Refactor Element trait
2 parents f156397 + 9fa4571 commit 2edf1f3

File tree

8 files changed

+200
-198
lines changed

8 files changed

+200
-198
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
- Remove `ErrorKind` and introduce some concrete error types
44
- `PyArray::as_slice`, `PyArray::as_slice_mut`, `PyArray::as_array`, and `PyArray::as_array_mut` is now unsafe.
55
- Introduce `PyArray::as_cell_slice`, `PyArray::to_vec`, and `PyArray::to_owned_array`
6+
- Rename `TypeNum` trait `Element`, and `NpyDataType` `DataType`
67

78
- v0.9.0
89
- Update PyO3 to 0.10.0

src/array.rs

Lines changed: 87 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use std::{iter::ExactSizeIterator, marker::PhantomData};
1212
use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
1313
use crate::error::{FromVecError, NotContiguousError, ShapeError};
1414
use crate::slice_box::SliceBox;
15-
use crate::types::{NpyDataType, TypeNum};
15+
use crate::types::Element;
1616

1717
/// A safe, static-typed interface for
1818
/// [NumPy ndarray](https://numpy.org/doc/stable/reference/arrays.ndarray.html).
@@ -46,7 +46,7 @@ use crate::types::{NpyDataType, TypeNum};
4646
/// `PyArray` has 2 type parametes `T` and `D`. `T` represents its data type like
4747
/// [`f32`](https://doc.rust-lang.org/std/primitive.f32.html), and `D` represents its dimension.
4848
///
49-
/// All data types you can use implements [TypeNum](../types/trait.TypeNum.html).
49+
/// All data types you can use implements [Element](../types/trait.Element.html).
5050
///
5151
/// Dimensions are represented by ndarray's
5252
/// [Dimension](https://docs.rs/ndarray/latest/ndarray/trait.Dimension.html) trait.
@@ -117,7 +117,7 @@ impl<'a, T, D> std::convert::From<&'a PyArray<T, D>> for &'a PyAny {
117117
}
118118
}
119119

120-
impl<'a, T: TypeNum, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> {
120+
impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> {
121121
// here we do type-check three times
122122
// 1. Checks if the object is PyArray
123123
// 2. Checks if the data type of the array is T
@@ -292,34 +292,27 @@ impl<T, D> PyArray<T, D> {
292292
self.len() == 0
293293
}
294294

295-
fn typenum(&self) -> i32 {
295+
fn strides_usize(&self) -> &[usize] {
296+
let n = self.ndim();
297+
let ptr = self.as_array_ptr();
296298
unsafe {
297-
let descr = (*self.as_array_ptr()).descr;
298-
(*descr).type_num
299+
let p = (*ptr).strides;
300+
slice::from_raw_parts(p as *const _, n)
299301
}
300302
}
301303

302304
/// Returns the pointer to the first element of the inner array.
303305
pub(crate) unsafe fn data(&self) -> *mut T {
304306
let ptr = self.as_array_ptr();
305-
(*ptr).data as *mut T
307+
(*ptr).data as *mut _
306308
}
307309

308310
pub(crate) unsafe fn copy_ptr(&self, other: *const T, len: usize) {
309311
ptr::copy_nonoverlapping(other, self.data(), len)
310312
}
311-
312-
fn strides_usize(&self) -> &[usize] {
313-
let n = self.ndim();
314-
let ptr = self.as_array_ptr();
315-
unsafe {
316-
let p = (*ptr).strides;
317-
slice::from_raw_parts(p as *const _, n)
318-
}
319-
}
320313
}
321314

322-
impl<T: TypeNum, D: Dimension> PyArray<T, D> {
315+
impl<T: Element, D: Dimension> PyArray<T, D> {
323316
/// Same as [shape](#method.shape), but returns `D`
324317
#[inline(always)]
325318
pub fn dims(&self) -> D {
@@ -369,7 +362,7 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
369362
PY_ARRAY_API.get_type_object(npyffi::ArrayType::PyArray_Type),
370363
dims.ndim_cint(),
371364
dims.as_dims_ptr(),
372-
T::typenum_default(),
365+
T::ffi_dtype() as i32,
373366
strides as *mut _, // strides
374367
ptr::null_mut(), // data
375368
0, // itemsize
@@ -398,7 +391,7 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
398391
PY_ARRAY_API.get_type_object(npyffi::ArrayType::PyArray_Type),
399392
dims.ndim_cint(),
400393
dims.as_dims_ptr(),
401-
T::typenum_default(),
394+
T::ffi_dtype() as i32,
402395
strides as *mut _, // strides
403396
data_ptr as _, // data
404397
mem::size_of::<T>() as i32, // itemsize
@@ -430,7 +423,7 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
430423
{
431424
let dims = dims.into_dimension();
432425
unsafe {
433-
let descr = PY_ARRAY_API.PyArray_DescrFromType(T::typenum_default());
426+
let descr = PY_ARRAY_API.PyArray_DescrFromType(T::ffi_dtype() as i32);
434427
let ptr = PY_ARRAY_API.PyArray_Zeros(
435428
dims.ndim_cint(),
436429
dims.as_dims_ptr(),
@@ -480,26 +473,6 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
480473
}
481474
}
482475

483-
/// Construct PyArray from `ndarray::ArrayBase`.
484-
///
485-
/// This method allocates memory in Python's heap via numpy api, and then copies all elements
486-
/// of the array there.
487-
///
488-
/// # Example
489-
/// ```
490-
/// # #[macro_use] extern crate ndarray;
491-
/// use numpy::PyArray;
492-
/// let gil = pyo3::Python::acquire_gil();
493-
/// let pyarray = PyArray::from_array(gil.python(), &array![[1, 2], [3, 4]]);
494-
/// assert_eq!(pyarray.readonly().as_array(), array![[1, 2], [3, 4]]);
495-
/// ```
496-
pub fn from_array<'py, S>(py: Python<'py>, arr: &ArrayBase<S, D>) -> &'py Self
497-
where
498-
S: Data<Elem = T>,
499-
{
500-
ToPyArray::to_pyarray(arr, py)
501-
}
502-
503476
/// Construct PyArray from
504477
/// [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html).
505478
///
@@ -518,25 +491,6 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
518491
IntoPyArray::into_pyarray(arr, py)
519492
}
520493

521-
/// Get the immutable view of the internal data of `PyArray`, as
522-
/// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
523-
///
524-
/// # Safety
525-
/// If the internal array is not readonly and can be mutated from Python code,
526-
/// holding the `ArrayView` might cause undefined behavior.
527-
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
528-
ArrayView::from_shape_ptr(self.ndarray_shape(), self.data())
529-
}
530-
531-
/// Returns the internal array as `ArrayViewMut`. See also [`as_array`](#method.as_array).
532-
///
533-
/// # Safety
534-
/// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
535-
/// it might cause undefined behavior.
536-
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
537-
ArrayViewMut::from_shape_ptr(self.ndarray_shape(), self.data())
538-
}
539-
540494
/// Get an immutable reference of a specified element, with checking the passed index is valid.
541495
///
542496
/// See [NpyIndex](../convert/trait.NpyIndex.html) for what types you can use as index.
@@ -608,7 +562,7 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
608562
Idx: NpyIndex<Dim = D>,
609563
{
610564
let offset = index.get_unchecked::<T>(self.strides());
611-
&mut *(self.data().offset(offset) as *mut T)
565+
&mut *(self.data().offset(offset) as *mut _)
612566
}
613567

614568
/// Get dynamic dimensioned array from fixed dimension array.
@@ -620,35 +574,14 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
620574
}
621575

622576
fn type_check(&self) -> Result<(), ShapeError> {
623-
let truth = self.typenum();
577+
let truth = unsafe { (*(*self.as_array_ptr()).descr).type_num };
624578
let dim = self.shape().len();
625-
let dim_ok = D::NDIM.map(|n| n == dim).unwrap_or(true);
626-
if T::is_same_type(truth) && dim_ok {
579+
if T::is_same_type(truth) && D::NDIM.map(|n| n == dim).unwrap_or(true) {
627580
Ok(())
628581
} else {
629-
Err(ShapeError::new(truth, dim, T::npy_data_type(), D::NDIM))
582+
Err(ShapeError::new(truth, dim, T::DATA_TYPE, D::NDIM))
630583
}
631584
}
632-
}
633-
634-
impl<T: Clone + TypeNum, D: Dimension> PyArray<T, D> {
635-
/// Get a copy of `PyArray` as
636-
/// [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html).
637-
///
638-
/// # Example
639-
/// ```
640-
/// # #[macro_use] extern crate ndarray;
641-
/// use numpy::PyArray;
642-
/// let gil = pyo3::Python::acquire_gil();
643-
/// let py_array = PyArray::arange(gil.python(), 0, 4, 1).reshape([2, 2]).unwrap();
644-
/// assert_eq!(
645-
/// py_array.to_owned_array(),
646-
/// array![[0, 1], [2, 3]]
647-
/// )
648-
/// ```
649-
pub fn to_owned_array(&self) -> Array<T, D> {
650-
unsafe { self.as_array() }.to_owned()
651-
}
652585

653586
/// Returns the copy of the internal data of `PyArray` to `Vec`.
654587
///
@@ -672,9 +605,66 @@ impl<T: Clone + TypeNum, D: Dimension> PyArray<T, D> {
672605
pub fn to_vec(&self) -> Result<Vec<T>, NotContiguousError> {
673606
unsafe { self.as_slice() }.map(ToOwned::to_owned)
674607
}
608+
609+
/// Construct PyArray from `ndarray::ArrayBase`.
610+
///
611+
/// This method allocates memory in Python's heap via numpy api, and then copies all elements
612+
/// of the array there.
613+
///
614+
/// # Example
615+
/// ```
616+
/// # #[macro_use] extern crate ndarray;
617+
/// use numpy::PyArray;
618+
/// let gil = pyo3::Python::acquire_gil();
619+
/// let pyarray = PyArray::from_array(gil.python(), &array![[1, 2], [3, 4]]);
620+
/// assert_eq!(pyarray.readonly().as_array(), array![[1, 2], [3, 4]]);
621+
/// ```
622+
pub fn from_array<'py, S>(py: Python<'py>, arr: &ArrayBase<S, D>) -> &'py Self
623+
where
624+
S: Data<Elem = T>,
625+
{
626+
ToPyArray::to_pyarray(arr, py)
627+
}
628+
629+
/// Get the immutable view of the internal data of `PyArray`, as
630+
/// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
631+
///
632+
/// # Safety
633+
/// If the internal array is not readonly and can be mutated from Python code,
634+
/// holding the `ArrayView` might cause undefined behavior.
635+
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
636+
ArrayView::from_shape_ptr(self.ndarray_shape(), self.data())
637+
}
638+
639+
/// Returns the internal array as `ArrayViewMut`. See also [`as_array`](#method.as_array).
640+
///
641+
/// # Safety
642+
/// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
643+
/// it might cause undefined behavior.
644+
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
645+
ArrayViewMut::from_shape_ptr(self.ndarray_shape(), self.data())
646+
}
647+
648+
/// Get a copy of `PyArray` as
649+
/// [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html).
650+
///
651+
/// # Example
652+
/// ```
653+
/// # #[macro_use] extern crate ndarray;
654+
/// use numpy::PyArray;
655+
/// let gil = pyo3::Python::acquire_gil();
656+
/// let py_array = PyArray::arange(gil.python(), 0, 4, 1).reshape([2, 2]).unwrap();
657+
/// assert_eq!(
658+
/// py_array.to_owned_array(),
659+
/// array![[0, 1], [2, 3]]
660+
/// )
661+
/// ```
662+
pub fn to_owned_array(&self) -> Array<T, D> {
663+
unsafe { self.as_array() }.to_owned()
664+
}
675665
}
676666

677-
impl<T: TypeNum> PyArray<T, Ix1> {
667+
impl<T: Element> PyArray<T, Ix1> {
678668
/// Construct one-dimension PyArray from slice.
679669
///
680670
/// # Example
@@ -808,7 +798,7 @@ impl<T: TypeNum> PyArray<T, Ix1> {
808798
}
809799
}
810800

811-
impl<T: TypeNum> PyArray<T, Ix2> {
801+
impl<T: Element> PyArray<T, Ix2> {
812802
/// Construct a two-dimension PyArray from `Vec<Vec<T>>`.
813803
///
814804
/// This function checks all dimension of inner vec, and if there's any vec
@@ -824,10 +814,7 @@ impl<T: TypeNum> PyArray<T, Ix2> {
824814
/// assert_eq!(pyarray.readonly().as_array(), array![[1, 2, 3], [1, 2, 3]]);
825815
/// assert!(PyArray::from_vec2(gil.python(), &[vec![1], vec![2, 3]]).is_err());
826816
/// ```
827-
pub fn from_vec2<'py>(py: Python<'py>, v: &[Vec<T>]) -> Result<&'py Self, FromVecError>
828-
where
829-
T: Clone,
830-
{
817+
pub fn from_vec2<'py>(py: Python<'py>, v: &[Vec<T>]) -> Result<&'py Self, FromVecError> {
831818
let last_len = v.last().map_or(0, |v| v.len());
832819
if v.iter().any(|v| v.len() != last_len) {
833820
return Err(FromVecError::new(v.len(), last_len));
@@ -837,15 +824,15 @@ impl<T: TypeNum> PyArray<T, Ix2> {
837824
unsafe {
838825
for (y, vy) in v.iter().enumerate() {
839826
for (x, vyx) in vy.iter().enumerate() {
840-
*array.uget_mut([y, x]) = *vyx;
827+
*array.uget_mut([y, x]) = vyx.clone();
841828
}
842829
}
843830
}
844831
Ok(array)
845832
}
846833
}
847834

848-
impl<T: TypeNum> PyArray<T, Ix3> {
835+
impl<T: Element> PyArray<T, Ix3> {
849836
/// Construct a three-dimension PyArray from `Vec<Vec<Vec<T>>>`.
850837
///
851838
/// This function checks all dimension of inner vec, and if there's any vec
@@ -864,10 +851,7 @@ impl<T: TypeNum> PyArray<T, Ix3> {
864851
/// );
865852
/// assert!(PyArray::from_vec3(gil.python(), &[vec![vec![1], vec![]]]).is_err());
866853
/// ```
867-
pub fn from_vec3<'py>(py: Python<'py>, v: &[Vec<Vec<T>>]) -> Result<&'py Self, FromVecError>
868-
where
869-
T: Clone,
870-
{
854+
pub fn from_vec3<'py>(py: Python<'py>, v: &[Vec<Vec<T>>]) -> Result<&'py Self, FromVecError> {
871855
let len2 = v.last().map_or(0, |v| v.len());
872856
if v.iter().any(|v| v.len() != len2) {
873857
return Err(FromVecError::new(v.len(), len2));
@@ -882,7 +866,7 @@ impl<T: TypeNum> PyArray<T, Ix3> {
882866
for (z, vz) in v.iter().enumerate() {
883867
for (y, vzy) in vz.iter().enumerate() {
884868
for (x, vzyx) in vzy.iter().enumerate() {
885-
*array.uget_mut([z, y, x]) = *vzyx;
869+
*array.uget_mut([z, y, x]) = vzyx.clone();
886870
}
887871
}
888872
}
@@ -891,12 +875,7 @@ impl<T: TypeNum> PyArray<T, Ix3> {
891875
}
892876
}
893877

894-
impl<T: TypeNum, D> PyArray<T, D> {
895-
/// Returns the scalar type of the array.
896-
pub fn data_type(&self) -> NpyDataType {
897-
NpyDataType::from_i32(self.typenum())
898-
}
899-
878+
impl<T: Element, D> PyArray<T, D> {
900879
/// Copies self into `other`, performing a data-type conversion if necessary.
901880
/// # Example
902881
/// ```
@@ -907,7 +886,7 @@ impl<T: TypeNum, D> PyArray<T, D> {
907886
/// assert!(pyarray_f.copy_to(pyarray_i).is_ok());
908887
/// assert_eq!(pyarray_i.readonly().as_slice().unwrap(), &[2, 3, 4]);
909888
/// ```
910-
pub fn copy_to<U: TypeNum>(&self, other: &PyArray<U, D>) -> PyResult<()> {
889+
pub fn copy_to<U: Element>(&self, other: &PyArray<U, D>) -> PyResult<()> {
911890
let self_ptr = self.as_array_ptr();
912891
let other_ptr = other.as_array_ptr();
913892
let result = unsafe { PY_ARRAY_API.PyArray_CopyInto(other_ptr, self_ptr) };
@@ -926,9 +905,9 @@ impl<T: TypeNum, D> PyArray<T, D> {
926905
/// let pyarray_f = PyArray::arange(gil.python(), 2.0, 5.0, 1.0);
927906
/// let pyarray_i = pyarray_f.cast::<i32>(false).unwrap();
928907
/// assert_eq!(pyarray_i.readonly().as_slice().unwrap(), &[2, 3, 4]);
929-
pub fn cast<'py, U: TypeNum>(&'py self, is_fortran: bool) -> PyResult<&'py PyArray<U, D>> {
908+
pub fn cast<'py, U: Element>(&'py self, is_fortran: bool) -> PyResult<&'py PyArray<U, D>> {
930909
let ptr = unsafe {
931-
let descr = PY_ARRAY_API.PyArray_DescrFromType(U::typenum_default());
910+
let descr = PY_ARRAY_API.PyArray_DescrFromType(U::ffi_dtype() as i32);
932911
PY_ARRAY_API.PyArray_CastToType(
933912
self.as_array_ptr(),
934913
descr,
@@ -995,7 +974,7 @@ impl<T: TypeNum, D> PyArray<T, D> {
995974
}
996975
}
997976

998-
impl<T: TypeNum + AsPrimitive<f64>> PyArray<T, Ix1> {
977+
impl<T: Element + AsPrimitive<f64>> PyArray<T, Ix1> {
999978
/// Return evenly spaced values within a given interval.
1000979
/// Same as [numpy.arange](https://numpy.org/doc/stable/reference/generated/numpy.arange.html).
1001980
///
@@ -1015,7 +994,7 @@ impl<T: TypeNum + AsPrimitive<f64>> PyArray<T, Ix1> {
1015994
start.as_(),
1016995
stop.as_(),
1017996
step.as_(),
1018-
T::typenum_default(),
997+
T::ffi_dtype() as i32,
1019998
);
1020999
Self::from_owned_ptr(py, ptr)
10211000
}

0 commit comments

Comments
 (0)