@@ -169,3 +169,100 @@ where
        self.mapv(|a| num_traits::clamp(a, min.clone(), max.clone()))
    }
}
+
+#[cfg(feature = "std")]
+impl<A, S, D> ArrayBase<S, D>
+where
+    A: Float + 'static,
+    S: Data<Elem = A>,
+    D: RemoveAxis,
+{
+    /// Compute the softmax function along the specified axis.
+    ///
+    /// The softmax function is defined as:
+    /// ```text
+    /// softmax(x_i) = exp(x_i) / sum(exp(x_j) for j in axis)
+    /// ```
+    ///
+    /// This function is commonly used in machine learning to normalize the output of a neural network to a
+    /// probability distribution.
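+    ///
+    /// Internally, the maximum of each lane is subtracted from every element before exponentiation
+    /// (`exp(x_i - max)`); this is mathematically equivalent and avoids overflow for large inputs.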
+    ///
+    /// ```
+    /// use ndarray::{array, Axis};
+    ///
+    /// let a = array![[1., 2., 3.], [4., 5., 6.0_f32]];
+    /// let b = a.softmax(Axis(0)).mapv(|x| (x * 100.0).round() / 100.0);
+    /// assert_eq!(b, array![[0.05, 0.05, 0.05], [0.95, 0.95, 0.95]]);
+    /// let c = a.softmax(Axis(1)).mapv(|x| (x * 100.0).round() / 100.0);
+    /// assert_eq!(c, array![[0.09, 0.24, 0.67], [0.09, 0.24, 0.67]]);
+    /// ```
+    ///
+    /// # Arguments
+    ///
+    /// * `axis`: The axis along which to compute the softmax function (so every lane along the axis will sum to 1).
+    pub fn softmax(&self, axis: Axis) -> Array<A, D>
+    {
+        let mut res = Array::uninit(self.raw_dim());
+        for (arr, mut res) in self.lanes(axis).into_iter().zip(res.lanes_mut(axis)) {
+            let max = arr
+                .iter()
+                // If we have NaN and the comparison fails, the max can be arbitrary as the sum and the whole result
+                // will be NaN anyway, so we use an arbitrary ordering.
+                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
+            let max = match max {
+                Some(max) => *max,
+                None => continue,
+            };
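+            // First pass: exponentiate each element (shifted by the lane maximum) and accumulate the sum.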
+            let mut sum = A::zero();
+            for (i, x) in res.indexed_iter_mut() {
+                let v = (arr[i] - max).exp();
+                sum = sum + v;
+                x.write(v);
+            }
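+            // Second pass: divide every element by the accumulated sum so the lane sums to one.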
+            for x in res.iter_mut() {
+                // Safety: we wrote to every single element of the `res` array in the previous loop.
+                x.write(*unsafe { x.assume_init_ref() } / sum);
+            }
+        }
+        // Safety: we wrote to every single element of the array.
+        unsafe { res.assume_init() }
+    }
+}
+
+#[cfg(test)]
+mod tests
+{
+    #[cfg(feature = "std")]
+    #[test]
+    fn test_softmax()
+    {
+        use super::*;
+        use crate::array;
+
+        let a = array![[1., 2., 3.], [4., 5., 6.0_f32]];
+        let b = a.softmax(Axis(0)).mapv(|x| (x * 100.0).round() / 100.0);
+        assert_eq!(b, array![[0.05, 0.05, 0.05], [0.95, 0.95, 0.95]]);
+        let c = a.softmax(Axis(1)).mapv(|x| (x * 100.0).round() / 100.0);
+        assert_eq!(c, array![[0.09, 0.24, 0.67], [0.09, 0.24, 0.67]]);
+
+        #[cfg(feature = "approx")]
+        {
+            // examples copied from scipy softmax documentation
+
+            use approx::assert_relative_eq;
+
+            let x = array![[1., 0.5, 0.2, 3.], [1., -1., 7., 3.], [2., 12., 13., 3.]];
+
+            let m = x.softmax(Axis(0));
+            let y = array![[0.211942, 0.00001013, 0.00000275, 0.333333],
+                [0.211942, 0.00000226, 0.00247262, 0.333333],
+                [0.576117, 0.999988, 0.997525, 0.333333]];
+            assert_relative_eq!(m, y, epsilon = 1e-5);
+
+            let m = x.softmax(Axis(1));
+            let y = array![[1.05877e-01, 6.42177e-02, 4.75736e-02, 7.82332e-01],
+                [2.42746e-03, 3.28521e-04, 9.79307e-01, 1.79366e-02],
+                [1.22094e-05, 2.68929e-01, 7.31025e-01, 3.31885e-05]];
+            assert_relative_eq!(m, y, epsilon = 1e-5);
+        }
+    }
+}
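
For illustration only (not part of this commit), the same two-pass, max-shifted scheme that `softmax` applies to each lane can be written out for a plain slice; `softmax_1d` below is a hypothetical helper for this sketch, not an ndarray API:

```rust
// Minimal sketch of the per-lane computation, assuming plain f32 slices.
fn softmax_1d(xs: &[f32]) -> Vec<f32> {
    // Shift by the maximum so every exponent is <= 0, which avoids overflow.
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    // First pass: exponentiate and accumulate the sum.
    let exps: Vec<f32> = xs.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    // Second pass: normalize so the result sums to one.
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    let p = softmax_1d(&[1.0, 2.0, 3.0]);
    println!("{:?}", p); // approximately [0.090, 0.245, 0.665]
}
```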