Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 388c784

Browse files
committedFeb 1, 2025
Adds a Numpy-equivalent meshgrid function
1 parent 93dfb38 commit 388c784

File tree

1 file changed

+405
-0
lines changed

1 file changed

+405
-0
lines changed
 

‎src/free_functions.rs

Lines changed: 405 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
use alloc::vec;
1010
#[cfg(not(feature = "std"))]
1111
use alloc::vec::Vec;
12+
use meshgrid_impl::Meshgrid;
1213
#[allow(unused_imports)]
1314
use std::compile_error;
1415
use std::mem::{forget, size_of};
@@ -336,3 +337,407 @@ pub fn rcarr3<A: Clone, const N: usize, const M: usize>(xs: &[[[A; M]; N]]) -> A
336337
{
337338
arr3(xs).into_shared()
338339
}
340+
341+
/// The indexing order for [`meshgrid`]; see there for more details.
342+
///
343+
/// Controls whether the first argument to `meshgrid` will fill the rows or columns of the outputs.
344+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
345+
pub enum MeshIndex
346+
{
347+
/// Cartesian indexing.
348+
///
349+
/// The first argument of `meshgrid` will repeat over the columns of the output.
350+
///
351+
/// Note: this is the default in `numpy`.
352+
XY,
353+
/// Matrix indexing.
354+
///
355+
/// The first argument of `meshgrid` will repeat over the rows of the output.
356+
IJ,
357+
}
358+
359+
mod meshgrid_impl
360+
{
361+
use super::MeshIndex;
362+
use crate::extension::nonnull::nonnull_debug_checked_from_ptr;
363+
use crate::{
364+
ArrayBase,
365+
ArrayRef1,
366+
ArrayView,
367+
ArrayView2,
368+
ArrayView3,
369+
ArrayView4,
370+
ArrayView5,
371+
ArrayView6,
372+
Axis,
373+
Data,
374+
Dim,
375+
IntoDimension,
376+
Ix1,
377+
LayoutRef1,
378+
};
379+
380+
/// Construct the correct strides for the `idx`-th entry into meshgrid
381+
fn construct_strides<A, const N: usize>(
382+
arr: &LayoutRef1<A>, idx: usize, indexing: MeshIndex,
383+
) -> <[usize; N] as IntoDimension>::Dim
384+
where [usize; N]: IntoDimension
385+
{
386+
let mut ret = [0; N];
387+
if idx < 2 && indexing == MeshIndex::XY {
388+
ret[1 - idx] = arr.stride_of(Axis(0)) as usize;
389+
} else {
390+
ret[idx] = arr.stride_of(Axis(0)) as usize;
391+
}
392+
Dim(ret)
393+
}
394+
395+
/// Construct the correct shape for the `idx`-th entry into meshgrid
396+
fn construct_shape<A, const N: usize>(
397+
arrays: [&LayoutRef1<A>; N], indexing: MeshIndex,
398+
) -> <[usize; N] as IntoDimension>::Dim
399+
where [usize; N]: IntoDimension
400+
{
401+
let mut ret = arrays.map(|a| a.len());
402+
if indexing == MeshIndex::XY {
403+
ret.swap(0, 1);
404+
}
405+
Dim(ret)
406+
}
407+
408+
/// A trait to encapsulate static dispatch for [`meshgrid`](super::meshgrid); see there for more details.
409+
///
410+
/// The inputs should always be some sort of 1D array.
411+
/// The outputs should always be ND arrays where N is the number of inputs.
412+
///
413+
/// Where possible, this trait tries to return array views rather than allocating additional memory.
414+
pub trait Meshgrid
415+
{
416+
type Output;
417+
418+
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output;
419+
}
420+
421+
macro_rules! meshgrid_body {
422+
($count:literal, $indexing:expr, $(($arr:expr, $idx:literal)),+) => {
423+
{
424+
let shape = construct_shape([$($arr),+], $indexing);
425+
(
426+
$({
427+
let strides = construct_strides::<_, $count>($arr, $idx, $indexing);
428+
unsafe { ArrayView::new(nonnull_debug_checked_from_ptr($arr.as_ptr() as *mut A), shape, strides) }
429+
}),+
430+
)
431+
}
432+
};
433+
}
434+
435+
impl<'a, 'b, A> Meshgrid for (&'a ArrayRef1<A>, &'b ArrayRef1<A>)
436+
{
437+
type Output = (ArrayView2<'a, A>, ArrayView2<'b, A>);
438+
439+
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
440+
{
441+
meshgrid_body!(2, indexing, (arrays.0, 0), (arrays.1, 1))
442+
}
443+
}
444+
445+
impl<'a, 'b, S1, S2, A: 'b + 'a> Meshgrid for (&'a ArrayBase<S1, Ix1>, &'b ArrayBase<S2, Ix1>)
446+
where
447+
S1: Data<Elem = A>,
448+
S2: Data<Elem = A>,
449+
{
450+
type Output = (ArrayView2<'a, A>, ArrayView2<'b, A>);
451+
452+
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
453+
{
454+
Meshgrid::meshgrid((&**arrays.0, &**arrays.1), indexing)
455+
}
456+
}
457+
458+
impl<'a, 'b, 'c, A> Meshgrid for (&'a ArrayRef1<A>, &'b ArrayRef1<A>, &'c ArrayRef1<A>)
459+
{
460+
type Output = (ArrayView3<'a, A>, ArrayView3<'b, A>, ArrayView3<'c, A>);
461+
462+
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
463+
{
464+
meshgrid_body!(3, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2))
465+
}
466+
}
467+
468+
impl<'a, 'b, 'c, S1, S2, S3, A: 'b + 'a + 'c> Meshgrid
469+
for (&'a ArrayBase<S1, Ix1>, &'b ArrayBase<S2, Ix1>, &'c ArrayBase<S3, Ix1>)
470+
where
471+
S1: Data<Elem = A>,
472+
S2: Data<Elem = A>,
473+
S3: Data<Elem = A>,
474+
{
475+
type Output = (ArrayView3<'a, A>, ArrayView3<'b, A>, ArrayView3<'c, A>);
476+
477+
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
478+
{
479+
Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2), indexing)
480+
}
481+
}
482+
483+
impl<'a, 'b, 'c, 'd, A> Meshgrid for (&'a ArrayRef1<A>, &'b ArrayRef1<A>, &'c ArrayRef1<A>, &'d ArrayRef1<A>)
484+
{
485+
type Output = (ArrayView4<'a, A>, ArrayView4<'b, A>, ArrayView4<'c, A>, ArrayView4<'d, A>);
486+
487+
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
488+
{
489+
meshgrid_body!(4, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3))
490+
}
491+
}
492+
493+
impl<'a, 'b, 'c, 'd, S1, S2, S3, S4, A: 'a + 'b + 'c + 'd> Meshgrid
494+
for (&'a ArrayBase<S1, Ix1>, &'b ArrayBase<S2, Ix1>, &'c ArrayBase<S3, Ix1>, &'d ArrayBase<S4, Ix1>)
495+
where
496+
S1: Data<Elem = A>,
497+
S2: Data<Elem = A>,
498+
S3: Data<Elem = A>,
499+
S4: Data<Elem = A>,
500+
{
501+
type Output = (ArrayView4<'a, A>, ArrayView4<'b, A>, ArrayView4<'c, A>, ArrayView4<'d, A>);
502+
503+
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
504+
{
505+
Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3), indexing)
506+
}
507+
}
508+
509+
impl<'a, 'b, 'c, 'd, 'e, A> Meshgrid
510+
for (&'a ArrayRef1<A>, &'b ArrayRef1<A>, &'c ArrayRef1<A>, &'d ArrayRef1<A>, &'e ArrayRef1<A>)
511+
{
512+
type Output = (ArrayView5<'a, A>, ArrayView5<'b, A>, ArrayView5<'c, A>, ArrayView5<'d, A>, ArrayView5<'e, A>);
513+
514+
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
515+
{
516+
meshgrid_body!(5, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3), (arrays.4, 4))
517+
}
518+
}
519+
520+
impl<'a, 'b, 'c, 'd, 'e, S1, S2, S3, S4, S5, A: 'a + 'b + 'c + 'd + 'e> Meshgrid
521+
for (
522+
&'a ArrayBase<S1, Ix1>,
523+
&'b ArrayBase<S2, Ix1>,
524+
&'c ArrayBase<S3, Ix1>,
525+
&'d ArrayBase<S4, Ix1>,
526+
&'e ArrayBase<S5, Ix1>,
527+
)
528+
where
529+
S1: Data<Elem = A>,
530+
S2: Data<Elem = A>,
531+
S3: Data<Elem = A>,
532+
S4: Data<Elem = A>,
533+
S5: Data<Elem = A>,
534+
{
535+
type Output = (ArrayView5<'a, A>, ArrayView5<'b, A>, ArrayView5<'c, A>, ArrayView5<'d, A>, ArrayView5<'e, A>);
536+
537+
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
538+
{
539+
Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3, &**arrays.4), indexing)
540+
}
541+
}
542+
543+
impl<'a, 'b, 'c, 'd, 'e, 'f, A> Meshgrid
544+
for (
545+
&'a ArrayRef1<A>,
546+
&'b ArrayRef1<A>,
547+
&'c ArrayRef1<A>,
548+
&'d ArrayRef1<A>,
549+
&'e ArrayRef1<A>,
550+
&'f ArrayRef1<A>,
551+
)
552+
{
553+
type Output = (
554+
ArrayView6<'a, A>,
555+
ArrayView6<'b, A>,
556+
ArrayView6<'c, A>,
557+
ArrayView6<'d, A>,
558+
ArrayView6<'e, A>,
559+
ArrayView6<'f, A>,
560+
);
561+
562+
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
563+
{
564+
meshgrid_body!(6, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3), (arrays.4, 4), (arrays.5, 5))
565+
}
566+
}
567+
568+
impl<'a, 'b, 'c, 'd, 'e, 'f, S1, S2, S3, S4, S5, S6, A: 'a + 'b + 'c + 'd + 'e + 'f> Meshgrid
569+
for (
570+
&'a ArrayBase<S1, Ix1>,
571+
&'b ArrayBase<S2, Ix1>,
572+
&'c ArrayBase<S3, Ix1>,
573+
&'d ArrayBase<S4, Ix1>,
574+
&'e ArrayBase<S5, Ix1>,
575+
&'f ArrayBase<S6, Ix1>,
576+
)
577+
where
578+
S1: Data<Elem = A>,
579+
S2: Data<Elem = A>,
580+
S3: Data<Elem = A>,
581+
S4: Data<Elem = A>,
582+
S5: Data<Elem = A>,
583+
S6: Data<Elem = A>,
584+
{
585+
type Output = (
586+
ArrayView6<'a, A>,
587+
ArrayView6<'b, A>,
588+
ArrayView6<'c, A>,
589+
ArrayView6<'d, A>,
590+
ArrayView6<'e, A>,
591+
ArrayView6<'f, A>,
592+
);
593+
594+
fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output
595+
{
596+
Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3, &**arrays.4, &**arrays.5), indexing)
597+
}
598+
}
599+
}
600+
601+
/// Create coordinate matrices from coordinate vectors.
602+
///
603+
/// Given an N-tuple of 1D coordinate vectors, return an N-tuple of ND coordinate arrays.
604+
/// This is particularly useful for computing the outputs of functions with N arguments over
605+
/// regularly spaced grids.
606+
///
607+
/// The `indexing` argument can be controlled by [`MeshIndex`] to support both Cartesian and
608+
/// matrix indexing. In the two-dimensional case, inputs of length `N` and `M` will create
609+
/// output arrays of size `(M, N)` when using [`MeshIndex::XY`] and size `(N, M)` when using
610+
/// [`MeshIndex::IJ`].
611+
///
612+
/// # Example
613+
/// ```
614+
/// use ndarray::{array, meshgrid, MeshIndex};
615+
///
616+
/// let arr1 = array![1, 2];
617+
/// let arr2 = array![3, 4];
618+
/// let arr3 = array![5, 6];
619+
///
620+
/// // Cartesian indexing
621+
/// let (res1, res2) = meshgrid((&arr1, &arr2), MeshIndex::XY);
622+
/// assert_eq!(res1, array![
623+
/// [1, 2],
624+
/// [1, 2],
625+
/// ]);
626+
/// assert_eq!(res2, array![
627+
/// [3, 3],
628+
/// [4, 4],
629+
/// ]);
630+
///
631+
/// // Matrix indexing
632+
/// let (res1, res2) = meshgrid((&arr1, &arr2), MeshIndex::IJ);
633+
/// assert_eq!(res1, array![
634+
/// [1, 1],
635+
/// [2, 2],
636+
/// ]);
637+
/// assert_eq!(res2, array![
638+
/// [3, 4],
639+
/// [3, 4],
640+
/// ]);
641+
///
642+
/// let (_, _, res3) = meshgrid((&arr1, &arr2, &arr3), MeshIndex::XY);
643+
/// assert_eq!(res3, array![
644+
/// [[5, 6],
645+
/// [5, 6]],
646+
/// [[5, 6],
647+
/// [5, 6]],
648+
/// ]);
649+
/// ```
650+
pub fn meshgrid<T: Meshgrid>(arrays: T, indexing: MeshIndex) -> T::Output
651+
{
652+
Meshgrid::meshgrid(arrays, indexing)
653+
}
654+
655+
#[cfg(test)]
656+
mod tests
657+
{
658+
use crate::{meshgrid, Axis, MeshIndex};
659+
660+
use super::s;
661+
662+
#[test]
663+
fn test_meshgrid2()
664+
{
665+
let x = array![1, 2, 3];
666+
let y = array![4, 5, 6, 7];
667+
let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY);
668+
assert_eq!(xx, array![[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]);
669+
assert_eq!(yy, array![[4, 4, 4], [5, 5, 5], [6, 6, 6], [7, 7, 7]]);
670+
671+
let (xx, yy) = meshgrid((&x, &y), MeshIndex::IJ);
672+
assert_eq!(xx, array![[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]);
673+
assert_eq!(yy, array![[4, 5, 6, 7], [4, 5, 6, 7], [4, 5, 6, 7]]);
674+
}
675+
676+
#[test]
677+
fn test_meshgrid3()
678+
{
679+
let x = array![1, 2, 3];
680+
let y = array![4, 5, 6, 7];
681+
let z = array![-1, -2];
682+
let (xx, yy, zz) = meshgrid((&x, &y, &z), MeshIndex::XY);
683+
assert_eq!(xx, array![
684+
[[1, 1], [2, 2], [3, 3]],
685+
[[1, 1], [2, 2], [3, 3]],
686+
[[1, 1], [2, 2], [3, 3]],
687+
[[1, 1], [2, 2], [3, 3]],
688+
]);
689+
assert_eq!(yy, array![
690+
[[4, 4], [4, 4], [4, 4]],
691+
[[5, 5], [5, 5], [5, 5]],
692+
[[6, 6], [6, 6], [6, 6]],
693+
[[7, 7], [7, 7], [7, 7]],
694+
]);
695+
assert_eq!(zz, array![
696+
[[-1, -2], [-1, -2], [-1, -2]],
697+
[[-1, -2], [-1, -2], [-1, -2]],
698+
[[-1, -2], [-1, -2], [-1, -2]],
699+
[[-1, -2], [-1, -2], [-1, -2]],
700+
]);
701+
702+
let (xx, yy, zz) = meshgrid((&x, &y, &z), MeshIndex::IJ);
703+
assert_eq!(xx, array![
704+
[[1, 1], [1, 1], [1, 1], [1, 1]],
705+
[[2, 2], [2, 2], [2, 2], [2, 2]],
706+
[[3, 3], [3, 3], [3, 3], [3, 3]],
707+
]);
708+
assert_eq!(yy, array![
709+
[[4, 4], [5, 5], [6, 6], [7, 7]],
710+
[[4, 4], [5, 5], [6, 6], [7, 7]],
711+
[[4, 4], [5, 5], [6, 6], [7, 7]],
712+
]);
713+
assert_eq!(zz, array![
714+
[[-1, -2], [-1, -2], [-1, -2], [-1, -2]],
715+
[[-1, -2], [-1, -2], [-1, -2], [-1, -2]],
716+
[[-1, -2], [-1, -2], [-1, -2], [-1, -2]],
717+
]);
718+
}
719+
720+
#[test]
721+
fn test_meshgrid_from_offset()
722+
{
723+
let x = array![1, 2, 3];
724+
let x = x.slice(s![1..]);
725+
let y = array![4, 5, 6];
726+
let y = y.slice(s![1..]);
727+
let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY);
728+
assert_eq!(xx, array![[2, 3], [2, 3]]);
729+
assert_eq!(yy, array![[5, 5], [6, 6]]);
730+
}
731+
732+
#[test]
733+
fn test_meshgrid_neg_stride()
734+
{
735+
let x = array![1, 2, 3];
736+
let x = x.slice(s![..;-1]);
737+
assert!(x.stride_of(Axis(0)) < 0); // Setup for test
738+
let y = array![4, 5, 6];
739+
let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY);
740+
assert_eq!(xx, array![[3, 2, 1], [3, 2, 1], [3, 2, 1]]);
741+
assert_eq!(yy, array![[4, 4, 4], [5, 5, 5], [6, 6, 6]]);
742+
}
743+
}

0 commit comments

Comments
 (0)
Please sign in to comment.