Skip to content

Commit 00f0ec9

Browse files
authored
Merge pull request #32 from kngwyu/example-fix
Fix example
2 parents d0c873e + 8e78723 commit 00f0ec9

File tree

8 files changed

+82
-29
lines changed

8 files changed

+82
-29
lines changed

example/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
build/
22
*.so
33
*.egg-info/
4+
**/dist/
45
__pycache__

example/extensions/Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ version = "0.1.0"
44
authors = ["Toshiki Teramura <[email protected]>"]
55

66
[lib]
7-
name = "rust_ext"
7+
name = "numpy_rust_ext"
88
crate-type = ["cdylib"]
99

1010
[dependencies]
1111
numpy = { path = "../.." }
1212
ndarray = "0.10"
1313

1414
[dependencies.pyo3]
15-
version = "*"
15+
git = "https://github.com/PyO3/pyo3.git"
16+
rev = "4169b0317826dc62eafcdd0faab7d009f6808c06"
1617
features = ["extension-module"]

example/extensions/src/lib.rs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
1-
#![feature(proc_macro, proc_macro_path_invoc, specialization)]
1+
#![feature(proc_macro, specialization)]
22

33
extern crate ndarray;
44
extern crate numpy;
55
extern crate pyo3;
66

77
use ndarray::*;
88
use numpy::*;
9-
use pyo3::{py, PyModule, PyObject, PyResult, Python};
9+
use pyo3::{py::modinit as pymodinit, PyModule, PyResult, Python};
1010

11-
#[py::modinit(rust_ext)]
11+
#[pymodinit(_rust_ext)]
1212
fn init_module(py: Python, m: &PyModule) -> PyResult<()> {
13+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
14+
// You **must** write this sentence for PyArray type checker working correctly
15+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
16+
let _np = PyArrayModule::import(py)?;
17+
1318
// immutable example
1419
fn axpy(a: f64, x: ArrayViewD<f64>, y: ArrayViewD<f64>) -> ArrayD<f64> {
1520
a * &x + &y
@@ -22,19 +27,19 @@ fn init_module(py: Python, m: &PyModule) -> PyResult<()> {
2227

2328
// wrapper of `axpy`
2429
#[pyfn(m, "axpy")]
25-
fn axpy_py(py: Python, a: f64, x: PyArray, y: PyArray) -> PyResult<PyArray> {
30+
fn axpy_py(py: Python, a: f64, x: &PyArray, y: &PyArray) -> PyResult<PyArray> {
2631
let np = PyArrayModule::import(py)?;
27-
let x = x.as_array().into_pyresult(py, "x must be f64 array")?;
28-
let y = y.as_array().into_pyresult(py, "y must be f64 array")?;
32+
let x = x.as_array().into_pyresult("x must be f64 array")?;
33+
let y = y.as_array().into_pyresult("y must be f64 array")?;
2934
Ok(axpy(a, x, y).into_pyarray(py, &np))
3035
}
3136

3237
// wrapper of `mult`
3338
#[pyfn(m, "mult")]
34-
fn mult_py(py: Python, a: f64, x: PyArray) -> PyResult<PyObject> {
35-
let x = x.as_array_mut().into_pyresult(py, "x must be f64 array")?;
39+
fn mult_py(_py: Python, a: f64, x: &PyArray) -> PyResult<()> {
40+
let x = x.as_array_mut().into_pyresult("x must be f64 array")?;
3641
mult(a, x);
37-
Ok(py.None()) // Python function must returns
42+
Ok(())
3843
}
3944

4045
Ok(())

example/rust_ext/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from ._rust_ext import *
1+
from numpy_rust_ext._rust_ext import *

example/setup.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,38 @@
11

2-
from setuptools import setup
2+
import os
3+
import subprocess
4+
import sys
5+
from setuptools import find_packages, setup
6+
from setuptools.command.test import test as TestCommand
37
from setuptools_rust import RustExtension, Binding
48

5-
setup(name='rust_ext',
6-
version='1.0',
7-
rust_extensions=[
8-
RustExtension('rust_ext._rust_ext', 'extensions/Cargo.toml',
9-
binding=Binding.RustCPython)],
10-
packages=['rust_ext'],
11-
zip_safe=False)
9+
10+
class CmdTest(TestCommand):
11+
def run(self):
12+
self.run_command("test_rust")
13+
test_files = os.listdir('./tests')
14+
ok = 0
15+
for f in test_files:
16+
_, ext = os.path.splitext(f)
17+
if ext == '.py':
18+
res = subprocess.call([sys.executable, f], cwd='./tests')
19+
ok = ok | res
20+
sys.exit(res)
21+
22+
23+
setup_requires = ['setuptools-rust>=0.6.0']
24+
install_requires = ['numpy']
25+
test_requires = install_requires
26+
27+
setup(
28+
name='rust_ext',
29+
version='0.1.0',
30+
description='Example of python-extension using rust-numpy',
31+
rust_extensions=[RustExtension('numpy_rust_ext._rust_ext', 'extensions/Cargo.toml')],
32+
install_requires=install_requires,
33+
setup_requires=setup_requires,
34+
test_requires=test_requires,
35+
packages=find_packages(),
36+
zip_safe=False,
37+
cmdclass=dict(test=CmdTest)
38+
)

example/test.py

Lines changed: 0 additions & 9 deletions
This file was deleted.

example/tests/test_ext.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import numpy as np
2+
from rust_ext import axpy, mult
3+
import unittest
4+
5+
class TestExt(unittest.TestCase):
6+
"""Test class for rust functions
7+
"""
8+
9+
def test_axpy(self):
10+
x = np.array([1.0, 2.0, 3.0])
11+
y = np.array([3.0, 3.0, 3.0])
12+
z = axpy(3.0, x, y)
13+
np.testing.assert_array_almost_equal(z, np.array([6.0, 9.0, 12.0]))
14+
15+
def test_mult(self):
16+
x = np.array([1.0, 2.0, 3.0])
17+
mult(3.0, x)
18+
np.testing.assert_array_almost_equal(x, np.array([3.0, 6.0, 9.0]))
19+
20+
21+
if __name__ == "__main__":
22+
unittest.main()

src/array.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ use super::*;
1414
pub struct PyArray(PyObject);
1515
pyobject_native_type!(PyArray, *npyffi::PyArray_Type_Ptr, npyffi::PyArray_Check);
1616

17+
impl IntoPyObject for PyArray {
18+
fn into_object(self, _py: Python) -> PyObject {
19+
self.0
20+
}
21+
}
22+
1723
impl PyArray {
1824
pub fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject {
1925
self.as_ptr() as _

0 commit comments

Comments
 (0)