Skip to content

Commit cf04b51

Browse files
committed
Use IntoPyArray in example
1 parent 5a77972 commit cf04b51

File tree

3 files changed

+10
-22
lines changed

3 files changed

+10
-22
lines changed

README.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ extern crate numpy;
9595
extern crate pyo3;
9696

9797
use ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
98-
use numpy::{IntoPyResult, PyArrayDyn, ToPyArray};
98+
use numpy::{IntoPyArray, IntoPyResult, PyArrayDyn};
9999
use pyo3::prelude::{pymodinit, PyModule, PyResult, Python};
100100

101101
#[pymodinit]
@@ -122,7 +122,7 @@ fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
122122
let x = x.as_array()?;
123123
// you can also specify your error context, via closure
124124
let y = y.as_array().into_pyresult_with(|| "y must be f64 array")?;
125-
Ok(axpy(a, x, y).to_pyarray(py).to_owned(py))
125+
Ok(axpy(a, x, y).into_pyarray(py).to_owned(py))
126126
}
127127

128128
// wrapper of `mult`
@@ -135,7 +135,6 @@ fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
135135

136136
Ok(())
137137
}
138-
139138
```
140139

141140
Contribution

example/extensions/src/lib.rs

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ extern crate numpy;
33
extern crate pyo3;
44

55
use ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
6-
use numpy::{IntoPyArray, IntoPyResult, PyArray1, PyArrayDyn, ToPyArray};
6+
use numpy::{IntoPyArray, IntoPyResult, PyArrayDyn};
77
use pyo3::prelude::{pymodinit, PyModule, PyResult, Python};
88

99
#[pymodinit]
@@ -30,7 +30,7 @@ fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
3030
let x = x.as_array()?;
3131
// you can also specify your error context, via closure
3232
let y = y.as_array().into_pyresult_with(|| "y must be f64 array")?;
33-
Ok(axpy(a, x, y).to_pyarray(py).to_owned(py))
33+
Ok(axpy(a, x, y).into_pyarray(py).to_owned(py))
3434
}
3535

3636
// wrapper of `mult`
@@ -41,15 +41,5 @@ fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
4141
Ok(())
4242
}
4343

44-
#[pyfn(m, "get_vec")]
45-
fn get_vec(py: Python, size: usize) -> PyResult<&PyArray1<f32>> {
46-
Ok(vec![0.0; size].into_pyarray(py))
47-
}
48-
// use numpy::slice_box::SliceBox;
49-
// #[pyfn(m, "get_slice")]
50-
// fn get_slice(py: Python, size: usize) -> PyResult<SliceBox<f32>> {
51-
// let sbox = numpy::slice_box::SliceBox::new(vec![0.0; size].into_boxed_slice());
52-
// Ok(sbox)
53-
// }
5444
Ok(())
5545
}

example/tests/test_ext.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import numpy as np
2-
from rust_ext import axpy, mult, get_vec
2+
from rust_ext import axpy, mult
33
import unittest
44

5+
56
class TestExt(unittest.TestCase):
67
"""Test class for rust functions
78
"""
@@ -11,18 +12,16 @@ def test_axpy(self):
1112
y = np.array([3.0, 3.0, 3.0])
1213
z = axpy(3.0, x, y)
1314
np.testing.assert_array_almost_equal(z, np.array([6.0, 9.0, 12.0]))
15+
x = np.array([*x, 4.0])
16+
y = np.array([*y, 3.0])
17+
z = axpy(3.0, x, y)
18+
np.testing.assert_array_almost_equal(z, np.array([6.0, 9.0, 12.0, 15.0]))
1419

1520
def test_mult(self):
1621
x = np.array([1.0, 2.0, 3.0])
1722
mult(3.0, x)
1823
np.testing.assert_array_almost_equal(x, np.array([3.0, 6.0, 9.0]))
1924

20-
def test_into_pyarray(self):
21-
x = get_vec(1000)
22-
np.testing.assert_array_almost_equal(x, np.zeros(1000))
23-
2425

2526
if __name__ == "__main__":
2627
unittest.main()
27-
28-

0 commit comments

Comments
 (0)