Skip to content

Commit a8bd42f

Browse files
committed
feat!: make dims private in Shape
1 parent da4f2f8 commit a8bd42f

File tree

3 files changed

+41
-27
lines changed

3 files changed

+41
-27
lines changed

src/lib.rs

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,12 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
3232
}
3333

3434
pub fn fill(shape: &Shape, value: T) -> Tensor<T> {
35-
let total_size = shape.size();
36-
let mut vec = Vec::with_capacity(total_size);
37-
for _ in 0..total_size { vec.push(value); }
35+
let mut vec = Vec::with_capacity(shape.size());
36+
for _ in 0..shape.size() { vec.push(value); }
3837
Tensor::new(shape, &vec).unwrap()
3938
}
40-
41-
pub fn zeros(shape: &Shape) -> Tensor<T> {
42-
Tensor::fill(shape, T::zero())
43-
}
44-
45-
pub fn ones(shape: &Shape) -> Tensor<T> {
46-
Tensor::fill(shape, T::one())
47-
}
39+
pub fn zeros(shape: &Shape) -> Tensor<T> {Tensor::fill(shape, T::zero())}
40+
pub fn ones(shape: &Shape) -> Tensor<T> {Tensor::fill(shape, T::one())}
4841

4942
// Properties
5043
pub fn shape(&self) -> &Shape { &self.shape }
@@ -64,8 +57,8 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
6457
pub fn sum(&self, axes: Axes) -> Tensor<T> {
6558
let all_axes = (0..self.shape.order()).collect::<Vec<_>>();
6659
let remaining_axes = all_axes.clone().into_iter().filter(|&i| !axes.contains(&i)).collect::<Vec<_>>();
67-
let remaining_dims = remaining_axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
68-
let removing_dims = axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
60+
let remaining_dims = remaining_axes.iter().map(|&i| self.shape[i]).collect::<Vec<_>>();
61+
let removing_dims = axes.iter().map(|&i| self.shape[i]).collect::<Vec<_>>();
6962

7063
// We resolve to a scalar value
7164
if axes.is_empty() | (remaining_dims.len() == 0) {
@@ -95,7 +88,7 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
9588
}
9689

9790
pub fn mean(&self, axes: Axes) -> Tensor<T> {
98-
let removing_dims = axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
91+
let removing_dims = axes.iter().map(|&i| self.shape[i]).collect::<Vec<_>>();
9992
let removing_dims_t: Vec<T> = removing_dims.iter().map(|&dim| {
10093
let mut result = T::zero();
10194
for _ in 0..dim {
@@ -108,7 +101,7 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
108101
}
109102

110103
pub fn var(&self, axes: Axes) -> Tensor<T> {
111-
let removing_dims = axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
104+
let removing_dims = axes.iter().map(|&i| self.shape[i]).collect::<Vec<_>>();
112105
let removing_dims_t: Vec<T> = removing_dims.iter().map(|&dim| {
113106
let mut result = T::zero();
114107
for _ in 0..dim {
@@ -120,8 +113,8 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
120113

121114
let all_axes = (0..self.shape.order()).collect::<Vec<_>>();
122115
let remaining_axes = all_axes.clone().into_iter().filter(|&i| !axes.contains(&i)).collect::<Vec<_>>();
123-
let remaining_dims = remaining_axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
124-
let removing_dims = axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
116+
let remaining_dims = remaining_axes.iter().map(|&i| self.shape[i]).collect::<Vec<_>>();
117+
let removing_dims = axes.iter().map(|&i| self.shape[i]).collect::<Vec<_>>();
125118

126119
// We resolve to a scalar value
127120
if axes.is_empty() | (remaining_dims.len() == 0) {
@@ -157,8 +150,8 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
157150
pub fn max(&self, axes: Axes) -> Tensor<T> {
158151
let all_axes = (0..self.shape.order()).collect::<Vec<_>>();
159152
let remaining_axes = all_axes.clone().into_iter().filter(|&i| !axes.contains(&i)).collect::<Vec<_>>();
160-
let remaining_dims = remaining_axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
161-
let removing_dims = axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
153+
let remaining_dims = remaining_axes.iter().map(|&i| self.shape[i]).collect::<Vec<_>>();
154+
let removing_dims = axes.iter().map(|&i| self.shape[i]).collect::<Vec<_>>();
162155

163156
// We resolve to a scalar value
164157
if axes.is_empty() | (remaining_dims.len() == 0) {
@@ -192,8 +185,8 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
192185
pub fn min(&self, axes: Axes) -> Tensor<T> {
193186
let all_axes = (0..self.shape.order()).collect::<Vec<_>>();
194187
let remaining_axes = all_axes.clone().into_iter().filter(|&i| !axes.contains(&i)).collect::<Vec<_>>();
195-
let remaining_dims = remaining_axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
196-
let removing_dims = axes.iter().map(|&i| self.shape.dims[i]).collect::<Vec<_>>();
188+
let remaining_dims = remaining_axes.iter().map(|&i| self.shape[i]).collect::<Vec<_>>();
189+
let removing_dims = axes.iter().map(|&i| self.shape[i]).collect::<Vec<_>>();
197190

198191
// We resolve to a scalar value
199192
if axes.is_empty() | (remaining_dims.len() == 0) {
@@ -227,9 +220,7 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
227220
// Tensor Product
228221
// Consistent with numpy.tensordot(a, b, axis=0)
229222
pub fn prod(&self, other: &Tensor<T>) -> Tensor<T> {
230-
let mut new_dims = self.shape.dims.clone();
231-
new_dims.extend(&other.shape.dims);
232-
let new_shape = Shape::new(new_dims).unwrap();
223+
let new_shape = self.shape.stack(&other.shape);
233224

234225
let mut new_data = Vec::with_capacity(self.size() * other.size());
235226
for &a in &self.data {

src/shape.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::error::ShapeError;
55

66
#[derive(Debug, Clone, PartialEq)]
77
pub struct Shape {
8-
pub dims: Vec<usize>,
8+
dims: Vec<usize>,
99
}
1010

1111
impl Shape {
@@ -23,6 +23,12 @@ impl Shape {
2323
pub fn order(&self) -> usize {
2424
self.dims.len()
2525
}
26+
27+
pub fn stack(&self, rhs: &Shape) -> Shape {
28+
let mut new_dims = self.dims.clone();
29+
new_dims.extend(rhs.dims.iter());
30+
Shape { dims: new_dims }
31+
}
2632
}
2733

2834
impl Index<usize> for Shape {
@@ -33,6 +39,14 @@ impl Index<usize> for Shape {
3339
}
3440
}
3541

42+
impl Index<std::ops::RangeFrom<usize>> for Shape {
43+
type Output = [usize];
44+
45+
fn index(&self, index: std::ops::RangeFrom<usize>) -> &Self::Output {
46+
&self.dims[index]
47+
}
48+
}
49+
3650
impl fmt::Display for Shape {
3751
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3852
use itertools::Itertools;
@@ -83,4 +97,13 @@ mod tests {
8397
let shape = shape![2, 3, 4].unwrap();
8498
assert_eq!(shape.dims, vec![2, 3, 4]);
8599
}
100+
101+
102+
#[test]
103+
fn test_shape_extend() {
104+
let shape1 = Shape::new(vec![2, 3]).unwrap();
105+
let shape2 = Shape::new(vec![4, 5]).unwrap();
106+
let extended_shape = shape1.stack(&shape2);
107+
assert_eq!(extended_shape.dims, vec![2, 3, 4, 5]);
108+
}
86109
}

src/storage.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ impl<T> DynamicStorage<T> {
2222
}
2323

2424
for (i, &dim) in coord.iter().enumerate() {
25-
if dim >= shape.dims[i] {
25+
if dim >= shape[i] {
2626
return Err(ShapeError::new(format!("out of bounds for dimension {}", i).as_str()));
2727
}
2828
}
2929

3030
let mut index = 0;
3131
for k in 0..shape.order() {
32-
let stride = shape.dims[k+1..].iter().product::<usize>();
32+
let stride = shape[k+1..].iter().product::<usize>();
3333
index += coord[k] * stride;
3434
}
3535
Ok(index)

0 commit comments

Comments (0)