|
9 | 9 | use alloc::vec;
|
10 | 10 | #[cfg(not(feature = "std"))]
|
11 | 11 | use alloc::vec::Vec;
|
| 12 | +use meshgrid_impl::Meshgrid; |
12 | 13 | #[allow(unused_imports)]
|
13 | 14 | use std::compile_error;
|
14 | 15 | use std::mem::{forget, size_of};
|
@@ -336,3 +337,407 @@ pub fn rcarr3<A: Clone, const N: usize, const M: usize>(xs: &[[[A; M]; N]]) -> A
|
336 | 337 | {
|
337 | 338 | arr3(xs).into_shared()
|
338 | 339 | }
|
| 340 | + |
| 341 | +/// The indexing order for [`meshgrid`]; see there for more details. |
| 342 | +/// |
| 343 | +/// Controls whether the first argument to `meshgrid` will fill the rows or columns of the outputs. |
| 344 | +#[derive(Debug, Clone, Copy, PartialEq, Eq)] |
| 345 | +pub enum MeshIndex |
| 346 | +{ |
| 347 | + /// Cartesian indexing. |
| 348 | + /// |
| 349 | + /// The first argument of `meshgrid` will repeat over the columns of the output. |
| 350 | + /// |
| 351 | + /// Note: this is the default in `numpy`. |
| 352 | + XY, |
| 353 | + /// Matrix indexing. |
| 354 | + /// |
| 355 | + /// The first argument of `meshgrid` will repeat over the rows of the output. |
| 356 | + IJ, |
| 357 | +} |
| 358 | + |
| 359 | +mod meshgrid_impl |
| 360 | +{ |
| 361 | + use super::MeshIndex; |
| 362 | + use crate::extension::nonnull::nonnull_debug_checked_from_ptr; |
| 363 | + use crate::{ |
| 364 | + ArrayBase, |
| 365 | + ArrayRef1, |
| 366 | + ArrayView, |
| 367 | + ArrayView2, |
| 368 | + ArrayView3, |
| 369 | + ArrayView4, |
| 370 | + ArrayView5, |
| 371 | + ArrayView6, |
| 372 | + Axis, |
| 373 | + Data, |
| 374 | + Dim, |
| 375 | + IntoDimension, |
| 376 | + Ix1, |
| 377 | + LayoutRef1, |
| 378 | + }; |
| 379 | + |
| 380 | + /// Construct the correct strides for the `idx`-th entry into meshgrid |
| 381 | + fn construct_strides<A, const N: usize>( |
| 382 | + arr: &LayoutRef1<A>, idx: usize, indexing: MeshIndex, |
| 383 | + ) -> <[usize; N] as IntoDimension>::Dim |
| 384 | + where [usize; N]: IntoDimension |
| 385 | + { |
| 386 | + let mut ret = [0; N]; |
| 387 | + if idx < 2 && indexing == MeshIndex::XY { |
| 388 | + ret[1 - idx] = arr.stride_of(Axis(0)) as usize; |
| 389 | + } else { |
| 390 | + ret[idx] = arr.stride_of(Axis(0)) as usize; |
| 391 | + } |
| 392 | + Dim(ret) |
| 393 | + } |
| 394 | + |
| 395 | + /// Construct the correct shape for the `idx`-th entry into meshgrid |
| 396 | + fn construct_shape<A, const N: usize>( |
| 397 | + arrays: [&LayoutRef1<A>; N], indexing: MeshIndex, |
| 398 | + ) -> <[usize; N] as IntoDimension>::Dim |
| 399 | + where [usize; N]: IntoDimension |
| 400 | + { |
| 401 | + let mut ret = arrays.map(|a| a.len()); |
| 402 | + if indexing == MeshIndex::XY { |
| 403 | + ret.swap(0, 1); |
| 404 | + } |
| 405 | + Dim(ret) |
| 406 | + } |
| 407 | + |
| 408 | + /// A trait to encapsulate static dispatch for [`meshgrid`](super::meshgrid); see there for more details. |
| 409 | + /// |
| 410 | + /// The inputs should always be some sort of 1D array. |
| 411 | + /// The outputs should always be ND arrays where N is the number of inputs. |
| 412 | + /// |
| 413 | + /// Where possible, this trait tries to return array views rather than allocating additional memory. |
| 414 | + pub trait Meshgrid |
| 415 | + { |
| 416 | + type Output; |
| 417 | + |
| 418 | + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output; |
| 419 | + } |
| 420 | + |
| 421 | + macro_rules! meshgrid_body { |
| 422 | + ($count:literal, $indexing:expr, $(($arr:expr, $idx:literal)),+) => { |
| 423 | + { |
| 424 | + let shape = construct_shape([$($arr),+], $indexing); |
| 425 | + ( |
| 426 | + $({ |
| 427 | + let strides = construct_strides::<_, $count>($arr, $idx, $indexing); |
| 428 | + unsafe { ArrayView::new(nonnull_debug_checked_from_ptr($arr.as_ptr() as *mut A), shape, strides) } |
| 429 | + }),+ |
| 430 | + ) |
| 431 | + } |
| 432 | + }; |
| 433 | + } |
| 434 | + |
| 435 | + impl<'a, 'b, A> Meshgrid for (&'a ArrayRef1<A>, &'b ArrayRef1<A>) |
| 436 | + { |
| 437 | + type Output = (ArrayView2<'a, A>, ArrayView2<'b, A>); |
| 438 | + |
| 439 | + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output |
| 440 | + { |
| 441 | + meshgrid_body!(2, indexing, (arrays.0, 0), (arrays.1, 1)) |
| 442 | + } |
| 443 | + } |
| 444 | + |
| 445 | + impl<'a, 'b, S1, S2, A: 'b + 'a> Meshgrid for (&'a ArrayBase<S1, Ix1>, &'b ArrayBase<S2, Ix1>) |
| 446 | + where |
| 447 | + S1: Data<Elem = A>, |
| 448 | + S2: Data<Elem = A>, |
| 449 | + { |
| 450 | + type Output = (ArrayView2<'a, A>, ArrayView2<'b, A>); |
| 451 | + |
| 452 | + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output |
| 453 | + { |
| 454 | + Meshgrid::meshgrid((&**arrays.0, &**arrays.1), indexing) |
| 455 | + } |
| 456 | + } |
| 457 | + |
| 458 | + impl<'a, 'b, 'c, A> Meshgrid for (&'a ArrayRef1<A>, &'b ArrayRef1<A>, &'c ArrayRef1<A>) |
| 459 | + { |
| 460 | + type Output = (ArrayView3<'a, A>, ArrayView3<'b, A>, ArrayView3<'c, A>); |
| 461 | + |
| 462 | + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output |
| 463 | + { |
| 464 | + meshgrid_body!(3, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2)) |
| 465 | + } |
| 466 | + } |
| 467 | + |
| 468 | + impl<'a, 'b, 'c, S1, S2, S3, A: 'b + 'a + 'c> Meshgrid |
| 469 | + for (&'a ArrayBase<S1, Ix1>, &'b ArrayBase<S2, Ix1>, &'c ArrayBase<S3, Ix1>) |
| 470 | + where |
| 471 | + S1: Data<Elem = A>, |
| 472 | + S2: Data<Elem = A>, |
| 473 | + S3: Data<Elem = A>, |
| 474 | + { |
| 475 | + type Output = (ArrayView3<'a, A>, ArrayView3<'b, A>, ArrayView3<'c, A>); |
| 476 | + |
| 477 | + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output |
| 478 | + { |
| 479 | + Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2), indexing) |
| 480 | + } |
| 481 | + } |
| 482 | + |
| 483 | + impl<'a, 'b, 'c, 'd, A> Meshgrid for (&'a ArrayRef1<A>, &'b ArrayRef1<A>, &'c ArrayRef1<A>, &'d ArrayRef1<A>) |
| 484 | + { |
| 485 | + type Output = (ArrayView4<'a, A>, ArrayView4<'b, A>, ArrayView4<'c, A>, ArrayView4<'d, A>); |
| 486 | + |
| 487 | + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output |
| 488 | + { |
| 489 | + meshgrid_body!(4, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3)) |
| 490 | + } |
| 491 | + } |
| 492 | + |
| 493 | + impl<'a, 'b, 'c, 'd, S1, S2, S3, S4, A: 'a + 'b + 'c + 'd> Meshgrid |
| 494 | + for (&'a ArrayBase<S1, Ix1>, &'b ArrayBase<S2, Ix1>, &'c ArrayBase<S3, Ix1>, &'d ArrayBase<S4, Ix1>) |
| 495 | + where |
| 496 | + S1: Data<Elem = A>, |
| 497 | + S2: Data<Elem = A>, |
| 498 | + S3: Data<Elem = A>, |
| 499 | + S4: Data<Elem = A>, |
| 500 | + { |
| 501 | + type Output = (ArrayView4<'a, A>, ArrayView4<'b, A>, ArrayView4<'c, A>, ArrayView4<'d, A>); |
| 502 | + |
| 503 | + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output |
| 504 | + { |
| 505 | + Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3), indexing) |
| 506 | + } |
| 507 | + } |
| 508 | + |
| 509 | + impl<'a, 'b, 'c, 'd, 'e, A> Meshgrid |
| 510 | + for (&'a ArrayRef1<A>, &'b ArrayRef1<A>, &'c ArrayRef1<A>, &'d ArrayRef1<A>, &'e ArrayRef1<A>) |
| 511 | + { |
| 512 | + type Output = (ArrayView5<'a, A>, ArrayView5<'b, A>, ArrayView5<'c, A>, ArrayView5<'d, A>, ArrayView5<'e, A>); |
| 513 | + |
| 514 | + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output |
| 515 | + { |
| 516 | + meshgrid_body!(5, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3), (arrays.4, 4)) |
| 517 | + } |
| 518 | + } |
| 519 | + |
| 520 | + impl<'a, 'b, 'c, 'd, 'e, S1, S2, S3, S4, S5, A: 'a + 'b + 'c + 'd + 'e> Meshgrid |
| 521 | + for ( |
| 522 | + &'a ArrayBase<S1, Ix1>, |
| 523 | + &'b ArrayBase<S2, Ix1>, |
| 524 | + &'c ArrayBase<S3, Ix1>, |
| 525 | + &'d ArrayBase<S4, Ix1>, |
| 526 | + &'e ArrayBase<S5, Ix1>, |
| 527 | + ) |
| 528 | + where |
| 529 | + S1: Data<Elem = A>, |
| 530 | + S2: Data<Elem = A>, |
| 531 | + S3: Data<Elem = A>, |
| 532 | + S4: Data<Elem = A>, |
| 533 | + S5: Data<Elem = A>, |
| 534 | + { |
| 535 | + type Output = (ArrayView5<'a, A>, ArrayView5<'b, A>, ArrayView5<'c, A>, ArrayView5<'d, A>, ArrayView5<'e, A>); |
| 536 | + |
| 537 | + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output |
| 538 | + { |
| 539 | + Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3, &**arrays.4), indexing) |
| 540 | + } |
| 541 | + } |
| 542 | + |
| 543 | + impl<'a, 'b, 'c, 'd, 'e, 'f, A> Meshgrid |
| 544 | + for ( |
| 545 | + &'a ArrayRef1<A>, |
| 546 | + &'b ArrayRef1<A>, |
| 547 | + &'c ArrayRef1<A>, |
| 548 | + &'d ArrayRef1<A>, |
| 549 | + &'e ArrayRef1<A>, |
| 550 | + &'f ArrayRef1<A>, |
| 551 | + ) |
| 552 | + { |
| 553 | + type Output = ( |
| 554 | + ArrayView6<'a, A>, |
| 555 | + ArrayView6<'b, A>, |
| 556 | + ArrayView6<'c, A>, |
| 557 | + ArrayView6<'d, A>, |
| 558 | + ArrayView6<'e, A>, |
| 559 | + ArrayView6<'f, A>, |
| 560 | + ); |
| 561 | + |
| 562 | + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output |
| 563 | + { |
| 564 | + meshgrid_body!(6, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3), (arrays.4, 4), (arrays.5, 5)) |
| 565 | + } |
| 566 | + } |
| 567 | + |
| 568 | + impl<'a, 'b, 'c, 'd, 'e, 'f, S1, S2, S3, S4, S5, S6, A: 'a + 'b + 'c + 'd + 'e + 'f> Meshgrid |
| 569 | + for ( |
| 570 | + &'a ArrayBase<S1, Ix1>, |
| 571 | + &'b ArrayBase<S2, Ix1>, |
| 572 | + &'c ArrayBase<S3, Ix1>, |
| 573 | + &'d ArrayBase<S4, Ix1>, |
| 574 | + &'e ArrayBase<S5, Ix1>, |
| 575 | + &'f ArrayBase<S6, Ix1>, |
| 576 | + ) |
| 577 | + where |
| 578 | + S1: Data<Elem = A>, |
| 579 | + S2: Data<Elem = A>, |
| 580 | + S3: Data<Elem = A>, |
| 581 | + S4: Data<Elem = A>, |
| 582 | + S5: Data<Elem = A>, |
| 583 | + S6: Data<Elem = A>, |
| 584 | + { |
| 585 | + type Output = ( |
| 586 | + ArrayView6<'a, A>, |
| 587 | + ArrayView6<'b, A>, |
| 588 | + ArrayView6<'c, A>, |
| 589 | + ArrayView6<'d, A>, |
| 590 | + ArrayView6<'e, A>, |
| 591 | + ArrayView6<'f, A>, |
| 592 | + ); |
| 593 | + |
| 594 | + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output |
| 595 | + { |
| 596 | + Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3, &**arrays.4, &**arrays.5), indexing) |
| 597 | + } |
| 598 | + } |
| 599 | +} |
| 600 | + |
| 601 | +/// Create coordinate matrices from coordinate vectors. |
| 602 | +/// |
| 603 | +/// Given an N-tuple of 1D coordinate vectors, return an N-tuple of ND coordinate arrays. |
| 604 | +/// This is particularly useful for computing the outputs of functions with N arguments over |
| 605 | +/// regularly spaced grids. |
| 606 | +/// |
| 607 | +/// The `indexing` argument can be controlled by [`MeshIndex`] to support both Cartesian and |
| 608 | +/// matrix indexing. In the two-dimensional case, inputs of length `N` and `M` will create |
| 609 | +/// output arrays of size `(M, N)` when using [`MeshIndex::XY`] and size `(N, M)` when using |
| 610 | +/// [`MeshIndex::IJ`]. |
| 611 | +/// |
| 612 | +/// # Example |
| 613 | +/// ``` |
| 614 | +/// use ndarray::{array, meshgrid, MeshIndex}; |
| 615 | +/// |
| 616 | +/// let arr1 = array![1, 2]; |
| 617 | +/// let arr2 = array![3, 4]; |
| 618 | +/// let arr3 = array![5, 6]; |
| 619 | +/// |
| 620 | +/// // Cartesian indexing |
| 621 | +/// let (res1, res2) = meshgrid((&arr1, &arr2), MeshIndex::XY); |
| 622 | +/// assert_eq!(res1, array![ |
| 623 | +/// [1, 2], |
| 624 | +/// [1, 2], |
| 625 | +/// ]); |
| 626 | +/// assert_eq!(res2, array![ |
| 627 | +/// [3, 3], |
| 628 | +/// [4, 4], |
| 629 | +/// ]); |
| 630 | +/// |
| 631 | +/// // Matrix indexing |
| 632 | +/// let (res1, res2) = meshgrid((&arr1, &arr2), MeshIndex::IJ); |
| 633 | +/// assert_eq!(res1, array![ |
| 634 | +/// [1, 1], |
| 635 | +/// [2, 2], |
| 636 | +/// ]); |
| 637 | +/// assert_eq!(res2, array![ |
| 638 | +/// [3, 4], |
| 639 | +/// [3, 4], |
| 640 | +/// ]); |
| 641 | +/// |
| 642 | +/// let (_, _, res3) = meshgrid((&arr1, &arr2, &arr3), MeshIndex::XY); |
| 643 | +/// assert_eq!(res3, array![ |
| 644 | +/// [[5, 6], |
| 645 | +/// [5, 6]], |
| 646 | +/// [[5, 6], |
| 647 | +/// [5, 6]], |
| 648 | +/// ]); |
| 649 | +/// ``` |
| 650 | +pub fn meshgrid<T: Meshgrid>(arrays: T, indexing: MeshIndex) -> T::Output |
| 651 | +{ |
| 652 | + Meshgrid::meshgrid(arrays, indexing) |
| 653 | +} |
| 654 | + |
| 655 | +#[cfg(test)] |
| 656 | +mod tests |
| 657 | +{ |
| 658 | + use crate::{meshgrid, Axis, MeshIndex}; |
| 659 | + |
| 660 | + use super::s; |
| 661 | + |
| 662 | + #[test] |
| 663 | + fn test_meshgrid2() |
| 664 | + { |
| 665 | + let x = array![1, 2, 3]; |
| 666 | + let y = array![4, 5, 6, 7]; |
| 667 | + let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY); |
| 668 | + assert_eq!(xx, array![[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]); |
| 669 | + assert_eq!(yy, array![[4, 4, 4], [5, 5, 5], [6, 6, 6], [7, 7, 7]]); |
| 670 | + |
| 671 | + let (xx, yy) = meshgrid((&x, &y), MeshIndex::IJ); |
| 672 | + assert_eq!(xx, array![[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]); |
| 673 | + assert_eq!(yy, array![[4, 5, 6, 7], [4, 5, 6, 7], [4, 5, 6, 7]]); |
| 674 | + } |
| 675 | + |
| 676 | + #[test] |
| 677 | + fn test_meshgrid3() |
| 678 | + { |
| 679 | + let x = array![1, 2, 3]; |
| 680 | + let y = array![4, 5, 6, 7]; |
| 681 | + let z = array![-1, -2]; |
| 682 | + let (xx, yy, zz) = meshgrid((&x, &y, &z), MeshIndex::XY); |
| 683 | + assert_eq!(xx, array![ |
| 684 | + [[1, 1], [2, 2], [3, 3]], |
| 685 | + [[1, 1], [2, 2], [3, 3]], |
| 686 | + [[1, 1], [2, 2], [3, 3]], |
| 687 | + [[1, 1], [2, 2], [3, 3]], |
| 688 | + ]); |
| 689 | + assert_eq!(yy, array![ |
| 690 | + [[4, 4], [4, 4], [4, 4]], |
| 691 | + [[5, 5], [5, 5], [5, 5]], |
| 692 | + [[6, 6], [6, 6], [6, 6]], |
| 693 | + [[7, 7], [7, 7], [7, 7]], |
| 694 | + ]); |
| 695 | + assert_eq!(zz, array![ |
| 696 | + [[-1, -2], [-1, -2], [-1, -2]], |
| 697 | + [[-1, -2], [-1, -2], [-1, -2]], |
| 698 | + [[-1, -2], [-1, -2], [-1, -2]], |
| 699 | + [[-1, -2], [-1, -2], [-1, -2]], |
| 700 | + ]); |
| 701 | + |
| 702 | + let (xx, yy, zz) = meshgrid((&x, &y, &z), MeshIndex::IJ); |
| 703 | + assert_eq!(xx, array![ |
| 704 | + [[1, 1], [1, 1], [1, 1], [1, 1]], |
| 705 | + [[2, 2], [2, 2], [2, 2], [2, 2]], |
| 706 | + [[3, 3], [3, 3], [3, 3], [3, 3]], |
| 707 | + ]); |
| 708 | + assert_eq!(yy, array![ |
| 709 | + [[4, 4], [5, 5], [6, 6], [7, 7]], |
| 710 | + [[4, 4], [5, 5], [6, 6], [7, 7]], |
| 711 | + [[4, 4], [5, 5], [6, 6], [7, 7]], |
| 712 | + ]); |
| 713 | + assert_eq!(zz, array![ |
| 714 | + [[-1, -2], [-1, -2], [-1, -2], [-1, -2]], |
| 715 | + [[-1, -2], [-1, -2], [-1, -2], [-1, -2]], |
| 716 | + [[-1, -2], [-1, -2], [-1, -2], [-1, -2]], |
| 717 | + ]); |
| 718 | + } |
| 719 | + |
| 720 | + #[test] |
| 721 | + fn test_meshgrid_from_offset() |
| 722 | + { |
| 723 | + let x = array![1, 2, 3]; |
| 724 | + let x = x.slice(s![1..]); |
| 725 | + let y = array![4, 5, 6]; |
| 726 | + let y = y.slice(s![1..]); |
| 727 | + let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY); |
| 728 | + assert_eq!(xx, array![[2, 3], [2, 3]]); |
| 729 | + assert_eq!(yy, array![[5, 5], [6, 6]]); |
| 730 | + } |
| 731 | + |
| 732 | + #[test] |
| 733 | + fn test_meshgrid_neg_stride() |
| 734 | + { |
| 735 | + let x = array![1, 2, 3]; |
| 736 | + let x = x.slice(s![..;-1]); |
| 737 | + assert!(x.stride_of(Axis(0)) < 0); // Setup for test |
| 738 | + let y = array![4, 5, 6]; |
| 739 | + let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY); |
| 740 | + assert_eq!(xx, array![[3, 2, 1], [3, 2, 1], [3, 2, 1]]); |
| 741 | + assert_eq!(yy, array![[4, 4, 4], [5, 5, 5], [6, 6, 6]]); |
| 742 | + } |
| 743 | +} |
0 commit comments