Skip to content

Add median contraction method #23

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,62 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
t
}

pub fn median(&self, axes: Axes) -> Tensor<T> {
let all_axes = (0..self.shape.order()).collect::<Vec<_>>();
let remaining_axes = all_axes
.clone()
.into_iter()
.filter(|&i| !axes.contains(&i))
.collect::<Vec<_>>();
let remaining_dims = remaining_axes
.iter()
.map(|&i| self.shape[i])
.collect::<Vec<_>>();
let removing_dims = axes.iter().map(|&i| self.shape[i]).collect::<Vec<_>>();

// We resolve to a scalar value
if axes.is_empty() || remaining_dims.is_empty() {
let mut data = self.data.iter().copied().collect::<Vec<T>>();
data.sort_by(|a, b| a.partial_cmp(b).unwrap());
let mid = data.len() / 2;
let median = if data.len() % 2 == 0 {
let two = T::one() + T::one();
(data[mid - 1] + data[mid]) / two
} else {
data[mid]
};
return Tensor::new(&Shape::new(vec![1]).unwrap(), &[median]).unwrap();
}

// Create new tensor with right shape
let new_shape = Shape::new(remaining_dims).unwrap();
let remove_shape = Shape::new(removing_dims).unwrap();
let mut t: Tensor<T> = Tensor::zeros(&new_shape);

for target in IndexIterator::new(&new_shape) {
let mut values = Vec::new();
let median_iter = IndexIterator::new(&remove_shape);
for median_index in median_iter {
let mut indices = target.clone();
for (i, &axis) in axes.iter().enumerate() {
indices = indices.insert(axis, median_index[i]);
}
values.push(*self.get(&indices).unwrap());
}
values.sort_by(|a, b| a.partial_cmp(b).unwrap());
let mid = values.len() / 2;
let median = if values.len() % 2 == 0 {
let two = T::one() + T::one();
(values[mid - 1] + values[mid]) / two
} else {
values[mid]
};
let _ = t.set(&target, median);
}

t
}

// Tensor Product
// Consistent with numpy.tensordot(a, b, axis=0)
pub fn prod(&self, other: &Tensor<T>) -> Tensor<T> {
Expand Down Expand Up @@ -985,6 +1041,56 @@ mod tests {
assert_eq!(result.data, DynamicStorage::new(vec![-10.0, -8.0, -12.0]));
}

#[test]
fn test_tensor_median_no_axis_1d_odd() {
let shape = shape![5].unwrap();
let data = vec![1.0, -2.0, 3.0, -4.0, 5.0];
let tensor = Tensor::new(&shape, &data).unwrap();

let result = tensor.median(vec![]);

assert_eq!(result.shape(), &shape![1].unwrap());
assert_eq!(result.data, DynamicStorage::new(vec![1.0]));
}

#[test]
fn test_tensor_median_no_axis_1d_even() {
let shape = shape![6].unwrap();
let data = vec![1.0, -2.0, 3.0, -4.0, 5.0, 6.0];
let tensor = Tensor::new(&shape, &data).unwrap();

let result = tensor.median(vec![]);

assert_eq!(result.shape(), &shape![1].unwrap());
assert_eq!(result.data, DynamicStorage::new(vec![2.0]));
}

#[test]
fn test_tensor_median_one_axis_2d() {
let shape = shape![2, 3].unwrap();
let data = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0];
let tensor = Tensor::new(&shape, &data).unwrap();

let result = tensor.median(vec![0]);

assert_eq!(result.shape(), &shape![3].unwrap());
assert_eq!(result.data, DynamicStorage::new(vec![-1.5, 1.5, -1.5]));
}

#[test]
fn test_tensor_median_multiple_axes_3d() {
let shape = shape![2, 2, 3].unwrap();
let data = vec![
1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0, 11.0, -12.0,
];
let tensor = Tensor::new(&shape, &data).unwrap();

let result = tensor.median(vec![0, 1]);

assert_eq!(result.shape(), &shape![3].unwrap());
assert_eq!(result.data, DynamicStorage::new(vec![-1.5, 1.5, -1.5]));
}

#[test]
fn test_tensor_prod_1d_1d() {
let shape1 = shape![3].unwrap();
Expand Down