Skip to content

Add partition(similar to numpy.partition) #1498

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 163 additions & 0 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3184,6 +3184,90 @@ impl<A, D: Dimension> ArrayRef<A, D>
f(&*prev, &mut *curr)
});
}

/// Return a partitioned copy of the array.
///
/// Creates a copy of the array and partially sorts it around the k-th element along the given axis.
/// The k-th element will be in its sorted position, with:
/// - All elements smaller than the k-th element to its left
/// - All elements equal or greater than the k-th element to its right
/// - The ordering within each partition is undefined
///
/// # Parameters
///
/// * `kth` - Index to partition by. The k-th element will be in its sorted position.
/// * `axis` - Axis along which to partition.
///
/// # Returns
///
/// A new array of the same shape and type as the input array, with elements partitioned.
///
/// # Examples
///
/// ```
/// use ndarray::prelude::*;
///
/// let a = array![7, 1, 5, 2, 6, 0, 3, 4];
/// let p = a.partition(3, Axis(0));
///
/// // The element at position 3 is now 3, with smaller elements to the left
/// // and greater elements to the right
/// assert_eq!(p[3], 3);
/// assert!(p.slice(s![..3]).iter().all(|&x| x <= 3));
/// assert!(p.slice(s![4..]).iter().all(|&x| x >= 3));
/// ```
pub fn partition(&self, kth: usize, axis: Axis) -> Array<A, D>
where A: Clone + Ord
{
// Check if axis is valid
if axis.index() >= self.ndim() {
panic!("axis {} is out of bounds for array of dimension {}", axis.index(), self.ndim());
}

// Check if kth is valid
if kth >= self.len_of(axis) {
panic!("kth {} is out of bounds for axis {} with length {}", kth, axis.index(), self.len_of(axis));
}

// If the array is empty, return a copy
if self.is_empty() {
return self.to_owned();
}

// If the array is 1D, handle as a special case
if self.ndim() == 1 {
let mut result = self.to_owned();
if let Some(slice) = result.as_slice_mut() {
slice.select_nth_unstable(kth);
}
return result;
}

// For multi-dimensional arrays, partition along the specified axis
let mut result = self.to_owned();

// Process each lane with partitioning
Zip::from(result.lanes_mut(axis)).for_each(|mut lane| {
// For each lane, we need a contiguous slice to partition
if let Some(slice) = lane.as_slice_mut() {
// If the lane's memory is contiguous, use select_nth_unstable directly
slice.select_nth_unstable(kth);
} else {
// For non-contiguous memory, create a temporary vector
let mut values = lane.iter().cloned().collect::<Vec<_>>();

// Partition the vector
values.select_nth_unstable(kth);

// Copy values back to the lane
Zip::from(&mut lane).and(&values).for_each(|dest, src| {
*dest = src.clone();
});
}
});

result
}
}

/// Transmute from A to B.
Expand Down Expand Up @@ -3277,4 +3361,83 @@ mod tests
let _a2 = a.clone();
assert_first!(a);
}

#[test]
fn test_partition_1d()
{
let a = array![7, 1, 5, 2, 6, 0, 3, 4];
let kth = 3;
let p = a.partition(kth, Axis(0));

// The element at position kth is in its sorted position
assert_eq!(p[kth], 3);

// All elements to the left are less than or equal to the kth element
for i in 0..kth {
assert!(p[i] <= p[kth]);
}

// All elements to the right are greater than or equal to the kth element
for i in (kth + 1)..p.len() {
assert!(p[i] >= p[kth]);
}
}

#[test]
fn test_partition_2d()
{
let a = array![[7, 1, 5], [2, 6, 0], [3, 4, 8]];

// Partition along axis 0 (rows)
let p_axis0 = a.partition(1, Axis(0));

// For each column, the middle row should be in its sorted position
for col in 0..3 {
assert!(p_axis0[[0, col]] <= p_axis0[[1, col]]);
assert!(p_axis0[[2, col]] >= p_axis0[[1, col]]);
}

// Partition along axis 1 (columns)
let p_axis1 = a.partition(1, Axis(1));

// For each row, the middle column should be in its sorted position
for row in 0..3 {
assert!(p_axis1[[row, 0]] <= p_axis1[[row, 1]]);
assert!(p_axis1[[row, 2]] >= p_axis1[[row, 1]]);
}
}

#[test]
fn test_partition_3d()
{
let a = arr3(&[[[9, 2], [3, 4]], [[5, 6], [7, 8]]]);

// Partition along the last axis
let p = a.partition(0, Axis(2));

// Check the partitioning along the last axis
for i in 0..2 {
for j in 0..2 {
assert!(p[[i, j, 0]] <= p[[i, j, 1]]);
}
}
}

#[test]
#[should_panic]
fn test_partition_invalid_kth()
{
let a = array![1, 2, 3, 4];
// This should panic because kth=4 is out of bounds
let _ = a.partition(4, Axis(0));
}

#[test]
#[should_panic]
fn test_partition_invalid_axis()
{
let a = array![1, 2, 3, 4];
// This should panic because axis=1 is out of bounds for a 1D array
let _ = a.partition(0, Axis(1));
}
}
Loading