@@ -32,19 +32,12 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
32
32
}
33
33
34
34
pub fn fill ( shape : & Shape , value : T ) -> Tensor < T > {
35
- let total_size = shape. size ( ) ;
36
- let mut vec = Vec :: with_capacity ( total_size) ;
37
- for _ in 0 ..total_size { vec. push ( value) ; }
35
+ let mut vec = Vec :: with_capacity ( shape. size ( ) ) ;
36
+ for _ in 0 ..shape. size ( ) { vec. push ( value) ; }
38
37
Tensor :: new ( shape, & vec) . unwrap ( )
39
38
}
40
-
41
- pub fn zeros ( shape : & Shape ) -> Tensor < T > {
42
- Tensor :: fill ( shape, T :: zero ( ) )
43
- }
44
-
45
- pub fn ones ( shape : & Shape ) -> Tensor < T > {
46
- Tensor :: fill ( shape, T :: one ( ) )
47
- }
39
+ pub fn zeros ( shape : & Shape ) -> Tensor < T > { Tensor :: fill ( shape, T :: zero ( ) ) }
40
+ pub fn ones ( shape : & Shape ) -> Tensor < T > { Tensor :: fill ( shape, T :: one ( ) ) }
48
41
49
42
// Properties
50
43
pub fn shape ( & self ) -> & Shape { & self . shape }
@@ -64,8 +57,8 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
64
57
pub fn sum ( & self , axes : Axes ) -> Tensor < T > {
65
58
let all_axes = ( 0 ..self . shape . order ( ) ) . collect :: < Vec < _ > > ( ) ;
66
59
let remaining_axes = all_axes. clone ( ) . into_iter ( ) . filter ( |& i| !axes. contains ( & i) ) . collect :: < Vec < _ > > ( ) ;
67
- let remaining_dims = remaining_axes. iter ( ) . map ( |& i| self . shape . dims [ i] ) . collect :: < Vec < _ > > ( ) ;
68
- let removing_dims = axes. iter ( ) . map ( |& i| self . shape . dims [ i] ) . collect :: < Vec < _ > > ( ) ;
60
+ let remaining_dims = remaining_axes. iter ( ) . map ( |& i| self . shape [ i] ) . collect :: < Vec < _ > > ( ) ;
61
+ let removing_dims = axes. iter ( ) . map ( |& i| self . shape [ i] ) . collect :: < Vec < _ > > ( ) ;
69
62
70
63
// We resolve to a scalar value
71
64
if axes. is_empty ( ) | ( remaining_dims. len ( ) == 0 ) {
@@ -95,7 +88,7 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
95
88
}
96
89
97
90
pub fn mean ( & self , axes : Axes ) -> Tensor < T > {
98
- let removing_dims = axes. iter ( ) . map ( |& i| self . shape . dims [ i] ) . collect :: < Vec < _ > > ( ) ;
91
+ let removing_dims = axes. iter ( ) . map ( |& i| self . shape [ i] ) . collect :: < Vec < _ > > ( ) ;
99
92
let removing_dims_t: Vec < T > = removing_dims. iter ( ) . map ( |& dim| {
100
93
let mut result = T :: zero ( ) ;
101
94
for _ in 0 ..dim {
@@ -108,7 +101,7 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
108
101
}
109
102
110
103
pub fn var ( & self , axes : Axes ) -> Tensor < T > {
111
- let removing_dims = axes. iter ( ) . map ( |& i| self . shape . dims [ i] ) . collect :: < Vec < _ > > ( ) ;
104
+ let removing_dims = axes. iter ( ) . map ( |& i| self . shape [ i] ) . collect :: < Vec < _ > > ( ) ;
112
105
let removing_dims_t: Vec < T > = removing_dims. iter ( ) . map ( |& dim| {
113
106
let mut result = T :: zero ( ) ;
114
107
for _ in 0 ..dim {
@@ -120,8 +113,8 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
120
113
121
114
let all_axes = ( 0 ..self . shape . order ( ) ) . collect :: < Vec < _ > > ( ) ;
122
115
let remaining_axes = all_axes. clone ( ) . into_iter ( ) . filter ( |& i| !axes. contains ( & i) ) . collect :: < Vec < _ > > ( ) ;
123
- let remaining_dims = remaining_axes. iter ( ) . map ( |& i| self . shape . dims [ i] ) . collect :: < Vec < _ > > ( ) ;
124
- let removing_dims = axes. iter ( ) . map ( |& i| self . shape . dims [ i] ) . collect :: < Vec < _ > > ( ) ;
116
+ let remaining_dims = remaining_axes. iter ( ) . map ( |& i| self . shape [ i] ) . collect :: < Vec < _ > > ( ) ;
117
+ let removing_dims = axes. iter ( ) . map ( |& i| self . shape [ i] ) . collect :: < Vec < _ > > ( ) ;
125
118
126
119
// We resolve to a scalar value
127
120
if axes. is_empty ( ) | ( remaining_dims. len ( ) == 0 ) {
@@ -157,8 +150,8 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
157
150
pub fn max ( & self , axes : Axes ) -> Tensor < T > {
158
151
let all_axes = ( 0 ..self . shape . order ( ) ) . collect :: < Vec < _ > > ( ) ;
159
152
let remaining_axes = all_axes. clone ( ) . into_iter ( ) . filter ( |& i| !axes. contains ( & i) ) . collect :: < Vec < _ > > ( ) ;
160
- let remaining_dims = remaining_axes. iter ( ) . map ( |& i| self . shape . dims [ i] ) . collect :: < Vec < _ > > ( ) ;
161
- let removing_dims = axes. iter ( ) . map ( |& i| self . shape . dims [ i] ) . collect :: < Vec < _ > > ( ) ;
153
+ let remaining_dims = remaining_axes. iter ( ) . map ( |& i| self . shape [ i] ) . collect :: < Vec < _ > > ( ) ;
154
+ let removing_dims = axes. iter ( ) . map ( |& i| self . shape [ i] ) . collect :: < Vec < _ > > ( ) ;
162
155
163
156
// We resolve to a scalar value
164
157
if axes. is_empty ( ) | ( remaining_dims. len ( ) == 0 ) {
@@ -192,8 +185,8 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
192
185
pub fn min ( & self , axes : Axes ) -> Tensor < T > {
193
186
let all_axes = ( 0 ..self . shape . order ( ) ) . collect :: < Vec < _ > > ( ) ;
194
187
let remaining_axes = all_axes. clone ( ) . into_iter ( ) . filter ( |& i| !axes. contains ( & i) ) . collect :: < Vec < _ > > ( ) ;
195
- let remaining_dims = remaining_axes. iter ( ) . map ( |& i| self . shape . dims [ i] ) . collect :: < Vec < _ > > ( ) ;
196
- let removing_dims = axes. iter ( ) . map ( |& i| self . shape . dims [ i] ) . collect :: < Vec < _ > > ( ) ;
188
+ let remaining_dims = remaining_axes. iter ( ) . map ( |& i| self . shape [ i] ) . collect :: < Vec < _ > > ( ) ;
189
+ let removing_dims = axes. iter ( ) . map ( |& i| self . shape [ i] ) . collect :: < Vec < _ > > ( ) ;
197
190
198
191
// We resolve to a scalar value
199
192
if axes. is_empty ( ) | ( remaining_dims. len ( ) == 0 ) {
@@ -227,9 +220,7 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
227
220
// Tensor Product
228
221
// Consistent with numpy.tensordot(a, b, axis=0)
229
222
pub fn prod ( & self , other : & Tensor < T > ) -> Tensor < T > {
230
- let mut new_dims = self . shape . dims . clone ( ) ;
231
- new_dims. extend ( & other. shape . dims ) ;
232
- let new_shape = Shape :: new ( new_dims) . unwrap ( ) ;
223
+ let new_shape = self . shape . stack ( & other. shape ) ;
233
224
234
225
let mut new_data = Vec :: with_capacity ( self . size ( ) * other. size ( ) ) ;
235
226
for & a in & self . data {
0 commit comments