Skip to content

Commit 591f6ce

Browse files
authored
Merge pull request #49 from kngwyu/scalar
Add i8/u8/i16/u16 support
2 parents 2139a7e + e4cda62 commit 591f6ce

File tree

8 files changed

+106
-49
lines changed

8 files changed

+106
-49
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ keywords = ["numpy", "python", "binding"]
1010
license-file = "LICENSE"
1111

1212
[dependencies]
13+
cfg-if = "0.1.5"
1314
libc = "0.2"
1415
num-complex = "0.1"
1516
ndarray = "0.11"

appveyor.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ install:
1414
- rustc -V
1515
- cargo -V
1616
- set RUST_BACKTRACE=1
17+
- pip install numpy
1718

1819
build_script:
1920
- cargo build --verbose

src/array.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//! Untyped safe interface for NumPy ndarray
1+
//! Safe interface for NumPy ndarray
22
33
use ndarray::*;
44
use npyffi;
@@ -42,7 +42,7 @@ impl<'a, T: TypeNum> FromPyObject<'a> for &'a PyArray<T> {
4242
array
4343
.type_check()
4444
.map(|_| array)
45-
.map_err(|err| err.into_pyerr("FromPyObject::extract failed"))
45+
.map_err(|err| err.into_pyerr("FromPyObject::extract typecheck failed"))
4646
}
4747
}
4848

@@ -324,12 +324,11 @@ impl<T: TypeNum> PyArray<T> {
324324
}
325325

326326
fn type_check(&self) -> Result<(), ArrayCastError> {
327-
let test = T::typenum();
328327
let truth = self.typenum();
329-
if test == truth {
328+
if T::is_same_type(truth) {
330329
Ok(())
331330
} else {
332-
Err(ArrayCastError::to_rust(truth, test))
331+
Err(ArrayCastError::to_rust(truth, T::npy_data_type()))
333332
}
334333
}
335334

@@ -377,7 +376,7 @@ impl<T: TypeNum> PyArray<T> {
377376
np.get_type_object(npyffi::ArrayType::PyArray_Type),
378377
dims.len() as i32,
379378
dims.as_ptr() as *mut npy_intp,
380-
T::typenum(),
379+
T::typenum_default(),
381380
strides,
382381
data,
383382
0, // itemsize
@@ -397,8 +396,8 @@ impl<T: TypeNum> PyArray<T> {
397396
/// use numpy::{PyArray, PyArrayModule};
398397
/// let gil = pyo3::Python::acquire_gil();
399398
/// let np = PyArrayModule::import(gil.python()).unwrap();
400-
/// let pyarray = PyArray::new(gil.python(), &np, &[2, 2]);
401-
/// assert_eq!(pyarray.as_array().unwrap(), array![[0, 0], [0, 0]].into_dyn());
399+
/// let pyarray = PyArray::<i32>::new(gil.python(), &np, &[4, 5, 6]);
400+
/// assert_eq!(pyarray.shape(), &[4, 5, 6]);
402401
/// # }
403402
/// ```
404403
pub fn new(py: Python, np: &PyArrayModule, dims: &[usize]) -> Self {
@@ -423,7 +422,7 @@ impl<T: TypeNum> PyArray<T> {
423422
pub fn zeros(py: Python, np: &PyArrayModule, dims: &[usize], is_fortran: bool) -> Self {
424423
let dims: Vec<npy_intp> = dims.iter().map(|d| *d as npy_intp).collect();
425424
unsafe {
426-
let descr = np.PyArray_DescrFromType(T::typenum());
425+
let descr = np.PyArray_DescrFromType(T::typenum_default());
427426
let ptr = np.PyArray_Zeros(
428427
dims.len() as i32,
429428
dims.as_ptr() as *mut npy_intp,
@@ -452,7 +451,7 @@ impl<T: TypeNum> PyArray<T> {
452451
/// # }
453452
pub fn arange(py: Python, np: &PyArrayModule, start: f64, stop: f64, step: f64) -> Self {
454453
unsafe {
455-
let ptr = np.PyArray_Arange(start, stop, step, T::typenum());
454+
let ptr = np.PyArray_Arange(start, stop, step, T::typenum_default());
456455
Self::from_owned_ptr(py, ptr)
457456
}
458457
}

src/convert.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ macro_rules! array_impls {
6363
$(
6464
impl<T: TypeNum> IntoPyArray for [T; $N] {
6565
type Item = T;
66-
fn into_pyarray(mut self, py: Python, np: &PyArrayModule) -> PyArray<T> {
66+
fn into_pyarray(self, py: Python, np: &PyArrayModule) -> PyArray<T> {
6767
let dims = [$N];
68-
let ptr = &mut self as *mut [T; $N];
68+
let ptr = Box::into_raw(Box::new(self));
6969
unsafe {
7070
PyArray::new_(py, np, &dims, null_mut(), ptr as *mut c_void)
7171
}

src/error.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,16 @@ impl<T, E: IntoPyErr> IntoPyResult for Result<T, E> {
2525
#[derive(Debug)]
2626
pub enum ArrayCastError {
2727
/// Error for casting `PyArray` into `ArrayView` or `ArrayViewMut`
28-
ToRust {
29-
from: NpyDataType,
30-
to: NpyDataType,
31-
},
28+
ToRust { from: NpyDataType, to: NpyDataType },
3229
/// Error for casting rust's `Vec` into numpy array.
3330
FromVec,
3431
}
3532

3633
impl ArrayCastError {
37-
pub(crate) fn to_rust(from: i32, to: i32) -> Self {
34+
pub(crate) fn to_rust(from: i32, to: NpyDataType) -> Self {
3835
ArrayCastError::ToRust {
3936
from: NpyDataType::from_i32(from),
40-
to: NpyDataType::from_i32(to),
37+
to,
4138
}
4239
}
4340
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#![feature(specialization)]
22

3+
#[macro_use]
4+
extern crate cfg_if;
35
extern crate libc;
46
extern crate ndarray;
57
extern crate num_complex;

src/types.rs

Lines changed: 65 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,12 @@ use super::npyffi::NPY_TYPES;
1313
#[derive(Clone, Debug, Eq, PartialEq)]
1414
pub enum NpyDataType {
1515
Bool,
16+
Int8,
17+
Int16,
1618
Int32,
1719
Int64,
20+
Uint8,
21+
Uint16,
1822
Uint32,
1923
Uint64,
2024
Float32,
@@ -28,46 +32,86 @@ impl NpyDataType {
2832
pub(crate) fn from_i32(npy_t: i32) -> Self {
2933
match npy_t {
3034
x if x == NPY_TYPES::NPY_BOOL as i32 => NpyDataType::Bool,
35+
x if x == NPY_TYPES::NPY_BYTE as i32 => NpyDataType::Int8,
36+
x if x == NPY_TYPES::NPY_SHORT as i32 => NpyDataType::Int16,
3137
x if x == NPY_TYPES::NPY_INT as i32 => NpyDataType::Int32,
32-
x if x == NPY_TYPES::NPY_LONG as i32 => NpyDataType::Int64,
38+
x if x == NPY_TYPES::NPY_LONG as i32 => NpyDataType::from_clong(false),
39+
x if x == NPY_TYPES::NPY_LONGLONG as i32 => NpyDataType::Int64,
40+
x if x == NPY_TYPES::NPY_UBYTE as i32 => NpyDataType::Uint8,
41+
x if x == NPY_TYPES::NPY_USHORT as i32 => NpyDataType::Uint16,
3342
x if x == NPY_TYPES::NPY_UINT as i32 => NpyDataType::Uint32,
34-
x if x == NPY_TYPES::NPY_ULONG as i32 => NpyDataType::Uint64,
43+
x if x == NPY_TYPES::NPY_ULONG as i32 => NpyDataType::from_clong(true),
44+
x if x == NPY_TYPES::NPY_ULONGLONG as i32 => NpyDataType::Uint64,
3545
x if x == NPY_TYPES::NPY_FLOAT as i32 => NpyDataType::Float32,
3646
x if x == NPY_TYPES::NPY_DOUBLE as i32 => NpyDataType::Float64,
3747
x if x == NPY_TYPES::NPY_CFLOAT as i32 => NpyDataType::Complex32,
3848
x if x == NPY_TYPES::NPY_CDOUBLE as i32 => NpyDataType::Complex64,
3949
_ => NpyDataType::Unsupported,
4050
}
4151
}
52+
#[inline(always)]
53+
fn from_clong(is_usize: bool) -> NpyDataType {
54+
if cfg!(any(target_pointer_width = "32", windows)) {
55+
if is_usize {
56+
NpyDataType::Uint32
57+
} else {
58+
NpyDataType::Int32
59+
}
60+
} else if cfg!(all(target_pointer_width = "64", not(windows))) {
61+
if is_usize {
62+
NpyDataType::Uint64
63+
} else {
64+
NpyDataType::Int64
65+
}
66+
} else {
67+
NpyDataType::Unsupported
68+
}
69+
}
4270
}
4371

4472
pub trait TypeNum: Clone {
45-
fn typenum_enum() -> NPY_TYPES;
46-
fn typenum() -> i32 {
47-
Self::typenum_enum() as i32
48-
}
49-
fn to_npy_data_type(self) -> NpyDataType;
73+
fn is_same_type(other: i32) -> bool;
74+
fn npy_data_type() -> NpyDataType;
75+
fn typenum_default() -> i32;
5076
}
5177

5278
macro_rules! impl_type_num {
53-
($t:ty, $npy_t:ident, $npy_dat_t:ident) => {
79+
($t:ty, $npy_dat_t:ident $(,$npy_types: ident)+) => {
5480
impl TypeNum for $t {
55-
fn typenum_enum() -> NPY_TYPES {
56-
NPY_TYPES::$npy_t
81+
fn is_same_type(other: i32) -> bool {
82+
$(other == NPY_TYPES::$npy_types as i32 ||)+ false
5783
}
58-
fn to_npy_data_type(self) -> NpyDataType {
84+
fn npy_data_type() -> NpyDataType {
5985
NpyDataType::$npy_dat_t
6086
}
87+
fn typenum_default() -> i32 {
88+
let t = ($(NPY_TYPES::$npy_types, )+);
89+
t.0 as i32
90+
}
6191
}
6292
};
63-
} // impl_type_num!
93+
}
94+
95+
impl_type_num!(bool, Bool, NPY_BOOL);
96+
impl_type_num!(i8, Int8, NPY_BYTE);
97+
impl_type_num!(i16, Int16, NPY_SHORT);
98+
impl_type_num!(u8, Uint8, NPY_UBYTE);
99+
impl_type_num!(u16, Uint16, NPY_USHORT);
100+
impl_type_num!(f32, Float32, NPY_FLOAT);
101+
impl_type_num!(f64, Float64, NPY_DOUBLE);
102+
impl_type_num!(c32, Complex32, NPY_CFLOAT);
103+
impl_type_num!(c64, Complex64, NPY_CDOUBLE);
64104

65-
impl_type_num!(bool, NPY_BOOL, Bool);
66-
impl_type_num!(i32, NPY_INT, Int32);
67-
impl_type_num!(i64, NPY_LONG, Int64);
68-
impl_type_num!(u32, NPY_UINT, Uint32);
69-
impl_type_num!(u64, NPY_ULONG, Uint64);
70-
impl_type_num!(f32, NPY_FLOAT, Float32);
71-
impl_type_num!(f64, NPY_DOUBLE, Float64);
72-
impl_type_num!(c32, NPY_CFLOAT, Complex32);
73-
impl_type_num!(c64, NPY_CDOUBLE, Complex64);
105+
cfg_if! {
106+
if #[cfg(any(target_pointer_width = "32", windows))] {
107+
impl_type_num!(i32, Int32, NPY_INT, NPY_LONG);
108+
impl_type_num!(u32, Uint32, NPY_UINT, NPY_ULONG);
109+
impl_type_num!(i64, Int64, NPY_LONGLONG);
110+
impl_type_num!(u64, Uint64, NPY_ULONGLONG);
111+
} else if #[cfg(all(target_pointer_width = "64", not(windows)))] {
112+
impl_type_num!(i32, Int32, NPY_INT);
113+
impl_type_num!(u32, Uint32, NPY_UINT);
114+
impl_type_num!(i64, Int64, NPY_LONG, NPY_LONGLONG);
115+
impl_type_num!(u64, Uint64, NPY_LONG, NPY_ULONGLONG);
116+
}
117+
}

tests/array.rs

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,10 @@ fn iter_to_pyarray() {
9292
let gil = pyo3::Python::acquire_gil();
9393
let np = PyArrayModule::import(gil.python()).unwrap();
9494
let arr = PyArray::from_iter(gil.python(), &np, (0..10).map(|x| x * x));
95-
assert_eq!(arr.as_slice().unwrap(), &[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]);
95+
assert_eq!(
96+
arr.as_slice().unwrap(),
97+
&[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
98+
);
9699
}
97100

98101
#[test]
@@ -130,15 +133,6 @@ fn from_vec3() {
130133
);
131134
}
132135

133-
#[test]
134-
fn from_small_array() {
135-
let gil = pyo3::Python::acquire_gil();
136-
let np = PyArrayModule::import(gil.python()).unwrap();
137-
let array: [i32; 5] = [1, 2, 3, 4, 5];
138-
let pyarray = array.into_pyarray(gil.python(), &np);
139-
assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]);
140-
}
141-
142136
#[test]
143137
fn from_eval() {
144138
let gil = pyo3::Python::acquire_gil();
@@ -168,3 +162,22 @@ fn from_eval_fail() {
168162
assert!(converted.is_err());
169163
}
170164

165+
macro_rules! small_array_test {
166+
($($t: ident)+) => {
167+
#[test]
168+
fn from_small_array() {
169+
let gil = pyo3::Python::acquire_gil();
170+
let np = PyArrayModule::import(gil.python()).unwrap();
171+
$({
172+
let array: [$t; 2] = [$t::min_value(), $t::max_value()];
173+
let pyarray = array.into_pyarray(gil.python(), &np);
174+
assert_eq!(
175+
pyarray.as_slice().unwrap(),
176+
&[$t::min_value(), $t::max_value()]
177+
);
178+
})+
179+
}
180+
};
181+
}
182+
183+
small_array_test!(i8 u8 i16 u16 i32 u32 i64 u64);

0 commit comments

Comments
 (0)