Skip to content

Commit e6e9b93

Browse files
committed
feat!: rename shape len to order
1 parent eebe35f commit e6e9b93

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

src/lib.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
7171

7272
// // Reduction operations
7373
pub fn sum(&self, axes: Axes) -> Tensor<T> {
74-
let all_axes = (0..self.shape.len()).collect::<Vec<_>>();
74+
let all_axes = (0..self.shape.order()).collect::<Vec<_>>();
7575
let remaining_axes = all_axes.clone().into_iter().filter(|&i| !axes.contains(&i)).collect::<Vec<_>>();
7676
let remaining_dims = remaining_axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
7777
let removing_dims = axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
@@ -126,7 +126,7 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
126126
}).collect();
127127
let n = removing_dims_t.iter().fold(T::one(), |acc, x| acc * *x);
128128

129-
let all_axes = (0..self.shape.len()).collect::<Vec<_>>();
129+
let all_axes = (0..self.shape.order()).collect::<Vec<_>>();
130130
let remaining_axes = all_axes.clone().into_iter().filter(|&i| !axes.contains(&i)).collect::<Vec<_>>();
131131
let remaining_dims = remaining_axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
132132
let removing_dims = axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
@@ -162,7 +162,7 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
162162
}
163163

164164
pub fn max(&self, axes: Axes) -> Tensor<T> {
165-
let all_axes = (0..self.shape.len()).collect::<Vec<_>>();
165+
let all_axes = (0..self.shape.order()).collect::<Vec<_>>();
166166
let remaining_axes = all_axes.clone().into_iter().filter(|&i| !axes.contains(&i)).collect::<Vec<_>>();
167167
let remaining_dims = remaining_axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
168168
let removing_dims = axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
@@ -196,7 +196,7 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
196196
}
197197

198198
pub fn min(&self, axes: Axes) -> Tensor<T> {
199-
let all_axes = (0..self.shape.len()).collect::<Vec<_>>();
199+
let all_axes = (0..self.shape.order()).collect::<Vec<_>>();
200200
let remaining_axes = all_axes.clone().into_iter().filter(|&i| !axes.contains(&i)).collect::<Vec<_>>();
201201
let remaining_dims = remaining_axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
202202
let removing_dims = axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
@@ -249,16 +249,16 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
249249
/// For the maths see: https://bit.ly/3KQjPa3
250250
fn calculate_index(&self, indices: &[usize]) -> usize {
251251
let mut index = 0;
252-
for k in 0..self.shape.len() {
252+
for k in 0..self.shape.order() {
253253
let stride = self.shape.dims[k+1..].iter().product::<usize>();
254254
index += indices[k] * stride;
255255
}
256256
index
257257
}
258258

259259
fn assert_indices(&self, indices: &[usize]) -> Result<(), ShapeError> {
260-
if indices.len() != self.shape.len() {
261-
let msg = format!("incorrect order ({} vs {}).", indices.len(), self.shape.len());
260+
if indices.len() != self.shape.order() {
261+
let msg = format!("incorrect order ({} vs {}).", indices.len(), self.shape.order());
262262
return Err(ShapeError::new(msg.as_str()));
263263
}
264264
for (i, &index) in indices.iter().enumerate() {
@@ -288,15 +288,15 @@ impl<T: Num + PartialOrd + Copy> Mul<Tensor<T>> for Tensor<T> {
288288
type Output = Tensor<T>;
289289

290290
fn mul(self, rhs: Tensor<T>) -> Tensor<T> {
291-
if self.shape.len() == 1 && rhs.shape.len() == 1 {
291+
if self.shape.order() == 1 && rhs.shape.order() == 1 {
292292
// Vector-Vector multiplication (dot product)
293293
assert!(self.shape[0] == rhs.shape[0], "Vectors must be of the same length for dot product.");
294294
let mut result = T::zero();
295295
for i in 0..self.shape[0] {
296296
result = result + self.data[i] * rhs.data[i];
297297
}
298298
Tensor::new(&shape![1].unwrap(), &vec![result]).unwrap()
299-
} else if self.shape.len() == 1 && rhs.shape.len() == 2 {
299+
} else if self.shape.order() == 1 && rhs.shape.order() == 2 {
300300
// Vector-Matrix multiplication
301301
assert!(self.shape[0] == rhs.shape[0], "The length of the vector must be equal to the number of rows in the matrix.");
302302
let mut result = Tensor::zeros(&shape![rhs.shape[1]].unwrap());
@@ -308,7 +308,7 @@ impl<T: Num + PartialOrd + Copy> Mul<Tensor<T>> for Tensor<T> {
308308
result.data[j] = sum;
309309
}
310310
result
311-
} else if self.shape.len() == 2 && rhs.shape.len() == 1 {
311+
} else if self.shape.order() == 2 && rhs.shape.order() == 1 {
312312
// Matrix-Vector multiplication
313313
assert!(self.shape[1] == rhs.shape[0], "The number of columns in the matrix must be equal to the length of the vector.");
314314
let mut result = Tensor::zeros(&shape![self.shape[0]].unwrap());
@@ -320,7 +320,7 @@ impl<T: Num + PartialOrd + Copy> Mul<Tensor<T>> for Tensor<T> {
320320
result.data[i] = sum;
321321
}
322322
result
323-
} else if self.shape.len() == 2 && rhs.shape.len() == 2 {
323+
} else if self.shape.order() == 2 && rhs.shape.order() == 2 {
324324
// Matrix-Matrix multiplication
325325
assert!(self.shape[1] == rhs.shape[0], "The number of columns in the first matrix must be equal to the number of rows in the second matrix.");
326326
let mut result = Tensor::zeros(&shape![self.shape[0], rhs.shape[1]].unwrap());

src/shape.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ impl Shape {
2020
self.dims.iter().product()
2121
}
2222

23-
pub fn len(&self) -> usize {
23+
pub fn order(&self) -> usize {
2424
self.dims.len()
2525
}
2626
}

0 commit comments

Comments
 (0)