Skip to content

Commit a5e0b28

Browse files
authored
Merge pull request #173 from PyO3/einsum
Add bindings to inner, dot and einsum
2 parents fdf7c0a + eb99fe7 commit a5e0b28

File tree

5 files changed

+231
-11
lines changed

5 files changed

+231
-11
lines changed

src/array.rs

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,9 @@ use crate::slice_box::SliceBox;
6464
/// # Example
6565
/// ```
6666
/// # #[macro_use] extern crate ndarray;
67-
/// use pyo3::{GILGuard, Python};
6867
/// use numpy::PyArray;
6968
/// use ndarray::Array;
70-
/// Python::with_gil(|py| {
69+
/// pyo3::Python::with_gil(|py| {
7170
/// let pyarray = PyArray::arange(py, 0., 4., 1.).reshape([2, 2]).unwrap();
7271
/// let array = array![[3., 4.], [5., 6.]];
7372
/// assert_eq!(
@@ -78,6 +77,8 @@ use crate::slice_box::SliceBox;
7877
/// ```
7978
pub struct PyArray<T, D>(PyAny, PhantomData<T>, PhantomData<D>);
8079

80+
/// Zero-dimensional array.
81+
pub type PyArray0<T> = PyArray<T, Ix0>;
8182
/// One-dimensional array.
8283
pub type PyArray1<T> = PyArray<T, Ix1>;
8384
/// Two-dimensional array.
@@ -218,10 +219,9 @@ impl<T, D> PyArray<T, D> {
218219
///
219220
/// # Example
220221
/// ```
221-
/// use pyo3::{GILGuard, Python, Py};
222222
/// use numpy::PyArray1;
223-
/// fn return_py_array() -> Py<PyArray1<i32>> {
224-
/// Python::with_gil(|py| PyArray1::zeros(py, [5], false).to_owned())
223+
/// fn return_py_array() -> pyo3::Py<PyArray1<i32>> {
224+
/// pyo3::Python::with_gil(|py| PyArray1::zeros(py, [5], false).to_owned())
225225
/// }
226226
/// let array = return_py_array();
227227
/// pyo3::Python::with_gil(|py| {
@@ -594,8 +594,6 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
594594
}
595595

596596
/// Get dynamic dimensioned array from fixed dimension array.
597-
///
598-
/// See [get](#method.get) for usage.
599597
pub fn to_dyn(&self) -> &PyArray<T, IxDyn> {
600598
let python = self.py();
601599
unsafe { PyArray::from_borrowed_ptr(python, self.as_ptr()) }
@@ -710,6 +708,15 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
710708
}
711709
}
712710

711+
impl<T: Copy + Element> PyArray<T, Ix0> {
712+
/// Get the element of zero-dimensional PyArray.
713+
///
714+
/// See [inner](../fn.inner.html) for example.
715+
pub fn item(&self) -> T {
716+
unsafe { *self.data() }
717+
}
718+
}
719+
713720
impl<T: Element> PyArray<T, Ix1> {
714721
/// Construct one-dimension PyArray from slice.
715722
///

src/lib.rs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,11 @@ pub mod npyffi;
4343
pub mod npyiter;
4444
mod readonly;
4545
mod slice_box;
46+
mod sum_products;
4647

4748
pub use crate::array::{
48-
get_array_module, PyArray, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5, PyArray6,
49-
PyArrayDyn,
49+
get_array_module, PyArray, PyArray0, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5,
50+
PyArray6, PyArrayDyn,
5051
};
5152
pub use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
5253
pub use crate::dtype::{c32, c64, DataType, Element, PyArrayDescr};
@@ -59,7 +60,8 @@ pub use crate::readonly::{
5960
PyReadonlyArray, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArray4,
6061
PyReadonlyArray5, PyReadonlyArray6, PyReadonlyArrayDyn,
6162
};
62-
pub use ndarray::{Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
63+
pub use crate::sum_products::{dot, einsum_impl, inner};
64+
pub use ndarray::{array, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
6365

6466
/// Test readme
6567
#[doc(hidden)]
@@ -72,3 +74,29 @@ pub mod doc_test {
7274
}
7375
doc_comment!(include_str!("../README.md"), readme);
7476
}
77+
78+
/// Create a [PyArray](./array/struct.PyArray.html) with one, two or three dimensions.
79+
/// This macro is backed by
80+
/// [`ndarray::array`](https://docs.rs/ndarray/latest/ndarray/macro.array.html).
81+
///
82+
/// # Example
83+
/// ```
84+
/// pyo3::Python::with_gil(|py| {
85+
/// let array = numpy::pyarray![py, [1, 2], [3, 4]];
86+
/// assert_eq!(
87+
/// array.readonly().as_array(),
88+
/// ndarray::array![[1, 2], [3, 4]]
89+
/// );
90+
/// });
91+
#[macro_export]
92+
macro_rules! pyarray {
93+
($py: ident, $([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*) => {{
94+
$crate::IntoPyArray::into_pyarray($crate::array![$([$([$($x,)*],)*],)*], $py)
95+
}};
96+
($py: ident, $([$($x:expr),* $(,)*]),+ $(,)*) => {{
97+
$crate::IntoPyArray::into_pyarray($crate::array![$([$($x,)*],)*], $py)
98+
}};
99+
($py: ident, $($x:expr),* $(,)*) => {{
100+
$crate::IntoPyArray::into_pyarray($crate::array![$($x,)*], $py)
101+
}};
102+
}

src/npyffi/array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ impl PyArrayAPI {
286286
impl_api![273; PyArray_ResultType(narrs: npy_intp, arr: *mut *mut PyArrayObject, ndtypes: npy_intp, dtypes: *mut *mut PyArray_Descr) -> *mut PyArray_Descr];
287287
impl_api![274; PyArray_CanCastArrayTo(arr: *mut PyArrayObject, to: *mut PyArray_Descr, casting: NPY_CASTING) -> npy_bool];
288288
impl_api![275; PyArray_CanCastTypeTo(from: *mut PyArray_Descr, to: *mut PyArray_Descr, casting: NPY_CASTING) -> npy_bool];
289-
impl_api![276; PyArray_EinsteinSum(subscripts: *mut c_char, nop: npy_intp, op_in: *mut *mut PyArrayObject, dtype: *mut PyArray_Descr, order: NPY_ORDER, casting: NPY_CASTING, out: *mut PyArrayObject) -> *mut PyArrayObject];
289+
impl_api![276; PyArray_EinsteinSum(subscripts: *mut c_char, nop: npy_intp, op_in: *mut *mut PyArrayObject, dtype: *mut PyArray_Descr, order: NPY_ORDER, casting: NPY_CASTING, out: *mut PyArrayObject) -> *mut PyObject];
290290
impl_api![277; PyArray_NewLikeArray(prototype: *mut PyArrayObject, order: NPY_ORDER, dtype: *mut PyArray_Descr, subok: c_int) -> *mut PyObject];
291291
impl_api![278; PyArray_GetArrayParamsFromObject(op: *mut PyObject, requested_dtype: *mut PyArray_Descr, writeable: npy_bool, out_dtype: *mut *mut PyArray_Descr, out_ndim: *mut c_int, out_dims: *mut npy_intp, out_arr: *mut *mut PyArrayObject, context: *mut PyObject) -> c_int];
292292
impl_api![279; PyArray_ConvertClipmodeSequence(object: *mut PyObject, modes: *mut NPY_CLIPMODE, n: c_int) -> c_int];

src/sum_products.rs

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
use crate::npyffi::{NPY_CASTING, NPY_ORDER};
2+
use crate::{Element, PyArray, PY_ARRAY_API};
3+
use ndarray::{Dimension, IxDyn};
4+
use pyo3::{AsPyPointer, FromPyPointer, PyAny, PyNativeType, PyResult};
5+
use std::ffi::CStr;
6+
7+
/// Return the inner product of two arrays.
8+
///
9+
/// # Example
10+
/// ```
11+
/// pyo3::Python::with_gil(|py| {
12+
/// let array = numpy::pyarray![py, 1, 2, 3];
13+
/// let inner: &numpy::PyArray0::<_> = numpy::inner(array, array).unwrap();
14+
/// assert_eq!(inner.item(), 14);
15+
/// });
16+
/// ```
17+
pub fn inner<'py, T, DIN1, DIN2, DOUT>(
18+
array1: &'py PyArray<T, DIN1>,
19+
array2: &'py PyArray<T, DIN2>,
20+
) -> PyResult<&'py PyArray<T, DOUT>>
21+
where
22+
DIN1: Dimension,
23+
DIN2: Dimension,
24+
DOUT: Dimension,
25+
T: Element,
26+
{
27+
let obj = unsafe {
28+
let result = PY_ARRAY_API.PyArray_InnerProduct(array1.as_ptr(), array2.as_ptr());
29+
PyAny::from_owned_ptr_or_err(array1.py(), result)?
30+
};
31+
obj.extract()
32+
}
33+
34+
/// Return the dot product of two arrays.
35+
///
36+
/// # Example
37+
/// ```
38+
/// pyo3::Python::with_gil(|py| {
39+
/// let a = numpy::pyarray![py, [1, 0], [0, 1]];
40+
/// let b = numpy::pyarray![py, [4, 1], [2, 2]];
41+
/// let dot: &numpy::PyArray2::<_> = numpy::dot(a, b).unwrap();
42+
/// assert_eq!(
43+
/// dot.readonly().as_array(),
44+
/// ndarray::array![[4, 1], [2, 2]]
45+
/// );
46+
/// });
47+
/// ```
48+
pub fn dot<'py, T, DIN1, DIN2, DOUT>(
49+
array1: &'py PyArray<T, DIN1>,
50+
array2: &'py PyArray<T, DIN2>,
51+
) -> PyResult<&'py PyArray<T, DOUT>>
52+
where
53+
DIN1: Dimension,
54+
DIN2: Dimension,
55+
DOUT: Dimension,
56+
T: Element,
57+
{
58+
let obj = unsafe {
59+
let result = PY_ARRAY_API.PyArray_MatrixProduct(array1.as_ptr(), array2.as_ptr());
60+
PyAny::from_owned_ptr_or_err(array1.py(), result)?
61+
};
62+
obj.extract()
63+
}
64+
65+
/// Return the Einstein summation convention of given tensors.
66+
///
67+
/// We also provide the [einsum macro](./macro.einsum.html).
68+
pub fn einsum_impl<'py, T, DOUT>(
69+
subscripts: &str,
70+
arrays: &[&'py PyArray<T, IxDyn>],
71+
) -> PyResult<&'py PyArray<T, DOUT>>
72+
where
73+
DOUT: Dimension,
74+
T: Element,
75+
{
76+
let subscripts: std::borrow::Cow<CStr> = if subscripts.ends_with("\0") {
77+
CStr::from_bytes_with_nul(subscripts.as_bytes())
78+
.unwrap()
79+
.into()
80+
} else {
81+
std::ffi::CString::new(subscripts).unwrap().into()
82+
};
83+
let obj = unsafe {
84+
let result = PY_ARRAY_API.PyArray_EinsteinSum(
85+
subscripts.as_ptr() as _,
86+
arrays.len() as _,
87+
arrays.as_ptr() as _,
88+
std::ptr::null_mut(),
89+
NPY_ORDER::NPY_KEEPORDER,
90+
NPY_CASTING::NPY_NO_CASTING,
91+
std::ptr::null_mut(),
92+
);
93+
PyAny::from_owned_ptr_or_err(arrays[0].py(), result)?
94+
};
95+
obj.extract()
96+
}
97+
98+
/// Return the Einstein summation convention of given tensors.
99+
///
100+
/// For more about the Einstein summation convention, you may reffer to
101+
/// [the numpy document](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html).
102+
///
103+
/// # Example
104+
/// ```
105+
/// pyo3::Python::with_gil(|py| {
106+
/// let a = numpy::PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
107+
/// let b = numpy::pyarray![py, [20, 30], [40, 50], [60, 70]];
108+
/// let einsum = numpy::einsum!("ijk,ji->ik", a, b).unwrap();
109+
/// assert_eq!(
110+
/// einsum.readonly().as_array(),
111+
/// ndarray::array![[640, 760, 880, 1000], [2560, 2710, 2860, 3010]]
112+
/// );
113+
/// });
114+
/// ```
115+
#[macro_export]
116+
macro_rules! einsum {
117+
($subscripts: literal $(,$array: ident)+ $(,)*) => {{
118+
let arrays = [$($array.to_dyn(),)+];
119+
unsafe { $crate::einsum_impl(concat!($subscripts, "\0"), &arrays) }
120+
}};
121+
}

tests/sum_products.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
use numpy::{array, dot, einsum, inner, pyarray, PyArray1, PyArray2};
2+
3+
#[test]
4+
fn test_dot() {
5+
pyo3::Python::with_gil(|py| {
6+
let a = pyarray![py, [1, 0], [0, 1]];
7+
let b = pyarray![py, [4, 1], [2, 2]];
8+
let c = dot(a, b).unwrap();
9+
assert_eq!(c.readonly().as_array(), array![[4, 1], [2, 2]]);
10+
let a = pyarray![py, 1, 2, 3];
11+
let err: pyo3::PyResult<&PyArray2<_>> = dot(a, b);
12+
let err = err.unwrap_err();
13+
assert!(err.to_string().contains("not aligned"), "{}", err);
14+
})
15+
}
16+
17+
#[test]
18+
fn test_inner() {
19+
pyo3::Python::with_gil(|py| {
20+
let a = pyarray![py, 1, 2, 3];
21+
let b = pyarray![py, 0, 1, 0];
22+
let c = inner(a, b).unwrap();
23+
assert_eq!(c.readonly().as_array(), ndarray::arr0(2));
24+
let a = pyarray![py, [1, 0], [0, 1]];
25+
let b = pyarray![py, [4, 1], [2, 2]];
26+
let c = inner(a, b).unwrap();
27+
assert_eq!(c.readonly().as_array(), array![[4, 2], [1, 2]]);
28+
let a = pyarray![py, 1, 2, 3];
29+
let err: pyo3::PyResult<&PyArray2<_>> = inner(a, b);
30+
let err = err.unwrap_err();
31+
assert!(err.to_string().contains("not aligned"), "{}", err);
32+
})
33+
}
34+
35+
#[test]
36+
fn test_einsum() {
37+
pyo3::Python::with_gil(|py| {
38+
let a = PyArray1::<i32>::arange(py, 0, 25, 1)
39+
.reshape([5, 5])
40+
.unwrap();
41+
let b = pyarray![py, 0, 1, 2, 3, 4];
42+
let c = pyarray![py, [0, 1, 2], [3, 4, 5]];
43+
assert_eq!(
44+
einsum!("ii", a).unwrap().readonly().as_array(),
45+
ndarray::arr0(60)
46+
);
47+
assert_eq!(
48+
einsum!("ii->i", a).unwrap().readonly().as_array(),
49+
array![0, 6, 12, 18, 24],
50+
);
51+
assert_eq!(
52+
einsum!("ij->i", a).unwrap().readonly().as_array(),
53+
array![10, 35, 60, 85, 110],
54+
);
55+
assert_eq!(
56+
einsum!("ji", c).unwrap().readonly().as_array(),
57+
array![[0, 3], [1, 4], [2, 5]],
58+
);
59+
assert_eq!(
60+
einsum!("ij,j", a, b).unwrap().readonly().as_array(),
61+
array![30, 80, 130, 180, 230],
62+
);
63+
})
64+
}

0 commit comments

Comments
 (0)