Skip to content

Commit f603610

Browse files
authored
Merge pull request #72 from kngwyu/err-refine
Refactor IntoPyErr
2 parents 430e022 + d7aabfd commit f603610

File tree

5 files changed

+45
-19
lines changed

5 files changed

+45
-19
lines changed

README.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ numpy = "0.3"
4545
``` rust
4646
extern crate numpy;
4747
extern crate pyo3;
48-
use numpy::{IntoPyResult, PyArray1, get_array_module};
48+
use numpy::{PyArray1, get_array_module};
4949
use pyo3::prelude::{ObjectProtocol, PyDict, PyResult, Python};
5050

5151
fn main() -> Result<(), ()> {
@@ -65,7 +65,7 @@ fn main_<'py>(py: Python<'py>) -> PyResult<()> {
6565
let pyarray: &PyArray1<i32> = py
6666
.eval("np.array([1, 2, 3], dtype='int32')", Some(&dict), None)?
6767
.extract()?;
68-
let slice = pyarray.as_slice().into_pyresult("Array Cast failed")?;
68+
let slice = pyarray.as_slice()?;
6969
assert_eq!(slice, &[1, 2, 3]);
7070
Ok(())
7171
}
@@ -118,21 +118,24 @@ fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
118118
x: &PyArrayDyn<f64>,
119119
y: &PyArrayDyn<f64>,
120120
) -> PyResult<PyArrayDyn<f64>> {
121-
let x = x.as_array().into_pyresult("x must be f64 array")?;
122-
let y = y.as_array().into_pyresult("y must be f64 array")?;
121+
// you can convert numpy error into PyErr via ?
122+
let x = x.as_array()?;
123+
// you can also specify your error context, via closure
124+
let y = y.as_array().into_pyresult_with(|| "y must be f64 array")?;
123125
Ok(axpy(a, x, y).to_pyarray(py).to_owned(py))
124126
}
125127

126128
// wrapper of `mult`
127129
#[pyfn(m, "mult")]
128130
fn mult_py(_py: Python, a: f64, x: &PyArrayDyn<f64>) -> PyResult<()> {
129-
let x = x.as_array_mut().into_pyresult("x must be f64 array")?;
131+
let x = x.as_array_mut()?;
130132
mult(a, x);
131133
Ok(())
132134
}
133135

134136
Ok(())
135137
}
138+
136139
```
137140

138141
Contribution

example/extensions/src/lib.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,17 @@ fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
2626
x: &PyArrayDyn<f64>,
2727
y: &PyArrayDyn<f64>,
2828
) -> PyResult<PyArrayDyn<f64>> {
29-
let x = x.as_array().into_pyresult("x must be f64 array")?;
30-
let y = y.as_array().into_pyresult("y must be f64 array")?;
29+
// you can convert numpy error into PyErr via ?
30+
let x = x.as_array()?;
31+
// you can also specify your error context, via closure
32+
let y = y.as_array().into_pyresult_with(|| "y must be f64 array")?;
3133
Ok(axpy(a, x, y).to_pyarray(py).to_owned(py))
3234
}
3335

3436
// wrapper of `mult`
3537
#[pyfn(m, "mult")]
3638
fn mult_py(_py: Python, a: f64, x: &PyArrayDyn<f64>) -> PyResult<()> {
37-
let x = x.as_array_mut().into_pyresult("x must be f64 array")?;
39+
let x = x.as_array_mut()?;
3840
mult(a, x);
3941
Ok(())
4042
}

src/array.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use std::os::raw::c_int;
1010
use std::ptr;
1111

1212
use convert::{NpyIndex, ToNpyDims};
13-
use error::{ErrorKind, IntoPyErr};
13+
use error::{ErrorKind, IntoPyResult};
1414
use types::{NpyDataType, TypeNum};
1515

1616
/// A safe, static-typed interface for
@@ -130,7 +130,7 @@ impl<'a, T: TypeNum, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> {
130130
array
131131
.type_check()
132132
.map(|_| array)
133-
.map_err(|err| err.into_pyerr("FromPyObject::extract typecheck failed"))
133+
.into_pyresult_with(|| "FromPyObject::extract typecheck failed")
134134
}
135135
}
136136

src/error.rs

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,27 @@ use std::error;
77
use std::fmt;
88
use types::{NpyDataType, TypeNum};
99

10-
pub trait IntoPyErr {
11-
fn into_pyerr(self, msg: &str) -> PyErr;
10+
pub trait IntoPyErr: Into<PyErr> {
11+
fn into_pyerr(self) -> PyErr;
12+
fn into_pyerr_with<D: fmt::Display>(self, impl FnOnce() -> D) -> PyErr;
1213
}
1314

1415
pub trait IntoPyResult {
1516
type ValueType;
16-
fn into_pyresult(self, msg: &str) -> PyResult<Self::ValueType>;
17+
fn into_pyresult(self) -> PyResult<Self::ValueType>;
18+
fn into_pyresult_with<D: fmt::Display>(self, impl FnOnce() -> D) -> PyResult<Self::ValueType>;
1719
}
1820

1921
impl<T, E: IntoPyErr> IntoPyResult for Result<T, E> {
2022
type ValueType = T;
21-
fn into_pyresult(self, msg: &str) -> PyResult<T> {
22-
self.map_err(|e| e.into_pyerr(msg))
23+
fn into_pyresult(self) -> PyResult<Self::ValueType> {
24+
self.map_err(|e| e.into())
25+
}
26+
fn into_pyresult_with<D: fmt::Display>(
27+
self,
28+
msg: impl FnOnce() -> D,
29+
) -> PyResult<Self::ValueType> {
30+
self.map_err(|e| e.into_pyerr_with(msg))
2331
}
2432
}
2533

@@ -117,11 +125,24 @@ impl fmt::Display for ErrorKind {
117125

118126
impl error::Error for ErrorKind {}
119127

128+
impl From<ErrorKind> for PyErr {
129+
fn from(err: ErrorKind) -> PyErr {
130+
match err {
131+
ErrorKind::PyToRust { .. } | ErrorKind::FromVec { .. } | ErrorKind::PyToPy(_) => {
132+
PyErr::new::<exc::TypeError, _>(format!("{}", err))
133+
}
134+
}
135+
}
136+
}
137+
120138
impl IntoPyErr for ErrorKind {
121-
fn into_pyerr(self, msg: &str) -> PyErr {
139+
fn into_pyerr(self) -> PyErr {
140+
Into::into(self)
141+
}
142+
fn into_pyerr_with<D: fmt::Display>(self, msg: impl FnOnce() -> D) -> PyErr {
122143
match self {
123144
ErrorKind::PyToRust { .. } | ErrorKind::FromVec { .. } | ErrorKind::PyToPy(_) => {
124-
PyErr::new::<exc::TypeError, _>(format!("{}, msg: {}", self, msg))
145+
PyErr::new::<exc::TypeError, _>(format!("{} msg: {}", self, msg()))
125146
}
126147
}
127148
}

src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ pub use array::{
5151
PyArrayDyn,
5252
};
5353
pub use convert::{NpyIndex, ToNpyDims, ToPyArray};
54-
pub use error::{IntoPyErr, IntoPyResult, ArrayFormat, ErrorKind};
54+
pub use error::{ArrayFormat, ErrorKind, IntoPyErr, IntoPyResult};
55+
pub use ndarray::{Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
5556
pub use npyffi::{PY_ARRAY_API, PY_UFUNC_API};
5657
pub use types::{c32, c64, NpyDataType, TypeNum};
57-
pub use ndarray::{Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};

0 commit comments

Comments
 (0)