Skip to content

Commit c1cc96f

Browse files
authored
Merge pull request #326 from PyO3/extraction-error
Avoid the overhead of creating a PyErr for downcasting.
2 parents a8aac58 + 3795010 commit c1cc96f

File tree

4 files changed

+76
-29
lines changed

4 files changed

+76
-29
lines changed

benches/array.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,27 @@ fn extract_failure(bencher: &mut Bencher) {
3030
});
3131
}
3232

33+
#[bench]
34+
fn downcast_success(bencher: &mut Bencher) {
35+
Python::with_gil(|py| {
36+
let any: &PyAny = PyArray2::<f64>::zeros(py, (10, 10), false);
37+
38+
bencher.iter(|| {
39+
black_box(any).downcast::<PyArray2<f64>>().unwrap();
40+
});
41+
});
42+
}
43+
44+
#[bench]
45+
fn downcast_failure(bencher: &mut Bencher) {
46+
Python::with_gil(|py| {
47+
let any: &PyAny = PyArray2::<i32>::zeros(py, (10, 10), false);
48+
49+
bencher.iter(|| {
50+
black_box(any).downcast::<PyArray2<f64>>().unwrap_err();
51+
});
52+
});
53+
}
3354
struct Iter(Range<usize>);
3455

3556
impl Iterator for Iter {

src/array.rs

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use crate::cold;
2626
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
2727
use crate::dtype::{Element, PyArrayDescr};
2828
use crate::error::{
29-
BorrowError, DimensionalityError, FromVecError, NotContiguousError, TypeError,
29+
BorrowError, DimensionalityError, FromVecError, IgnoreError, NotContiguousError, TypeError,
3030
DIMENSIONALITY_MISMATCH_ERR, MAX_DIMENSIONALITY_ERR,
3131
};
3232
use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API};
@@ -131,7 +131,7 @@ unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
131131
}
132132

133133
fn is_type_of(ob: &PyAny) -> bool {
134-
<&Self>::extract(ob).is_ok()
134+
Self::extract::<IgnoreError>(ob).is_ok()
135135
}
136136
}
137137

@@ -145,30 +145,7 @@ impl<T, D> IntoPy<PyObject> for PyArray<T, D> {
145145

146146
impl<'py, T: Element, D: Dimension> FromPyObject<'py> for &'py PyArray<T, D> {
147147
fn extract(ob: &'py PyAny) -> PyResult<Self> {
148-
// Check if the object is an array.
149-
let array = unsafe {
150-
if npyffi::PyArray_Check(ob.py(), ob.as_ptr()) == 0 {
151-
return Err(PyDowncastError::new(ob, PyArray::<T, D>::NAME).into());
152-
}
153-
&*(ob as *const PyAny as *const PyArray<T, D>)
154-
};
155-
156-
// Check if the dimensionality matches `D`.
157-
let src_ndim = array.ndim();
158-
if let Some(dst_ndim) = D::NDIM {
159-
if src_ndim != dst_ndim {
160-
return Err(DimensionalityError::new(src_ndim, dst_ndim).into());
161-
}
162-
}
163-
164-
// Check if the element type matches `T`.
165-
let src_dtype = array.dtype();
166-
let dst_dtype = T::get_dtype(ob.py());
167-
if !src_dtype.is_equiv_to(dst_dtype) {
168-
return Err(TypeError::new(src_dtype, dst_dtype).into());
169-
}
170-
171-
Ok(array)
148+
PyArray::extract(ob)
172149
}
173150
}
174151

@@ -390,6 +367,36 @@ impl<T, D> PyArray<T, D> {
390367
}
391368

392369
impl<T: Element, D: Dimension> PyArray<T, D> {
370+
fn extract<'py, E>(ob: &'py PyAny) -> Result<&'py Self, E>
371+
where
372+
E: From<PyDowncastError<'py>> + From<DimensionalityError> + From<TypeError<'py>>,
373+
{
374+
// Check if the object is an array.
375+
let array = unsafe {
376+
if npyffi::PyArray_Check(ob.py(), ob.as_ptr()) == 0 {
377+
return Err(PyDowncastError::new(ob, Self::NAME).into());
378+
}
379+
&*(ob as *const PyAny as *const Self)
380+
};
381+
382+
// Check if the dimensionality matches `D`.
383+
let src_ndim = array.ndim();
384+
if let Some(dst_ndim) = D::NDIM {
385+
if src_ndim != dst_ndim {
386+
return Err(DimensionalityError::new(src_ndim, dst_ndim).into());
387+
}
388+
}
389+
390+
// Check if the element type matches `T`.
391+
let src_dtype = array.dtype();
392+
let dst_dtype = T::get_dtype(ob.py());
393+
if !src_dtype.is_equiv_to(dst_dtype) {
394+
return Err(TypeError::new(src_dtype, dst_dtype).into());
395+
}
396+
397+
Ok(array)
398+
}
399+
393400
/// Same as [`shape`][Self::shape], but returns `D` insead of `&[usize]`.
394401
#[inline(always)]
395402
pub fn dims(&self) -> D {

src/dtype.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,12 @@ impl PyArrayDescr {
117117

118118
/// Returns true if two type descriptors are equivalent.
119119
pub fn is_equiv_to(&self, other: &Self) -> bool {
120+
let self_ptr = self.as_dtype_ptr();
121+
let other_ptr = other.as_dtype_ptr();
122+
120123
unsafe {
121-
PY_ARRAY_API.PyArray_EquivTypes(self.py(), self.as_dtype_ptr(), other.as_dtype_ptr())
122-
!= 0
124+
self_ptr == other_ptr
125+
|| PY_ARRAY_API.PyArray_EquivTypes(self.py(), self_ptr, other_ptr) != 0
123126
}
124127
}
125128

@@ -413,7 +416,7 @@ fn npy_int_type_lookup<T, T0, T1, T2>(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES {
413416

414417
fn npy_int_type<T: Bounded + Zero + Sized + PartialEq>() -> NPY_TYPES {
415418
let is_unsigned = T::min_value() == T::zero();
416-
let bit_width = size_of::<T>() << 3;
419+
let bit_width = 8 * size_of::<T>();
417420

418421
match (is_unsigned, bit_width) {
419422
(false, 8) => NPY_TYPES::NPY_BYTE,
@@ -449,6 +452,7 @@ macro_rules! impl_element_scalar {
449452
$(#[$meta])*
450453
unsafe impl Element for $ty {
451454
const IS_COPY: bool = true;
455+
452456
fn get_dtype(py: Python) -> &PyArrayDescr {
453457
PyArrayDescr::from_npy_type(py, $npy_type)
454458
}

src/error.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,18 @@ impl fmt::Display for BorrowError {
162162
}
163163

164164
impl_pyerr!(BorrowError);
165+
166+
/// An internal type used to ignore certain error conditions
167+
///
168+
/// This is beneficial when those errors will never reach a public API anyway
169+
/// but dropping them will improve performance.
170+
pub(crate) struct IgnoreError;
171+
172+
impl<E> From<E> for IgnoreError
173+
where
174+
PyErr: From<E>,
175+
{
176+
fn from(_err: E) -> Self {
177+
Self
178+
}
179+
}

0 commit comments

Comments
 (0)