@@ -169,3 +169,94 @@ where
         self.mapv(|a| num_traits::clamp(a, min.clone(), max.clone()))
     }
 }
+
+#[cfg(feature = "std")]
+impl<A, S, D> ArrayBase<S, D>
+where
+    A: Float + 'static,
+    S: Data<Elem = A>,
+    D: RemoveAxis,
+{
+    /// Compute the softmax function along the specified axis.
+    ///
+    /// The softmax function is defined as:
+    /// ```text
+    /// softmax(x_i) = exp(x_i) / sum(exp(x_j) for j in axis)
+    /// ```
+    ///
+    /// This function is commonly used in machine learning to normalize the output of a neural network to a probability
+    /// distribution.
+    /// ```
+    /// use ndarray::{array, Axis};
+    ///
+    /// let a = array![[1., 2., 3.], [4., 5., 6.0_f32]];
+    /// let b = a.softmax(Axis(0)).mapv(|x| (x * 100.0).round() / 100.0);
+    /// assert_eq!(b, array![[0.05, 0.05, 0.05], [0.95, 0.95, 0.95]]);
+    /// let c = a.softmax(Axis(1)).mapv(|x| (x * 100.0).round() / 100.0);
+    /// assert_eq!(c, array![[0.09, 0.24, 0.67], [0.09, 0.24, 0.67]]);
+    /// ```
+    ///
+    /// # Arguments
+    ///
+    /// * `axis`: The axis along which to compute the softmax function (so every lane along the axis will sum to 1).
+    pub fn softmax(&self, axis: Axis) -> Array<A, D>
+    {
+        let mut res = Array::uninit(self.raw_dim());
+        for (arr, mut res) in self.lanes(axis).into_iter().zip(res.lanes_mut(axis)) {
+            let max = arr
+                .iter()
+                // If we have NaN and the comparison fails, the max can be arbitrary as the sum and the whole result
+                // will be NaN anyway, so we use an arbitrary ordering.
+                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
+            let max = match max {
+                Some(max) => *max,
+                None => continue,
+            };
+            let sum = arr.fold(A::zero(), |sum, x| sum + (*x - max).exp());
+            for (i, x) in res.indexed_iter_mut() {
+                x.write((arr[i] - max).exp() / sum);
+            }
+        }
+        // Safety: we wrote to every single element of the array.
+        unsafe { res.assume_init() }
+    }
+}
+
+#[cfg(test)]
+mod tests
+{
+    #[cfg(feature = "std")]
+    #[test]
+    fn test_softmax()
+    {
+        use super::*;
+        use crate::array;
+
+        let a = array![[1., 2., 3.], [4., 5., 6.0_f32]];
+        let b = a.softmax(Axis(0)).mapv(|x| (x * 100.0).round() / 100.0);
+        assert_eq!(b, array![[0.05, 0.05, 0.05], [0.95, 0.95, 0.95]]);
+        let c = a.softmax(Axis(1)).mapv(|x| (x * 100.0).round() / 100.0);
+        assert_eq!(c, array![[0.09, 0.24, 0.67], [0.09, 0.24, 0.67]]);
+
+        #[cfg(feature = "approx")]
+        {
+            // Examples copied from the scipy softmax documentation.
+
+            use approx::assert_relative_eq;
+
+            let x = array![[1., 0.5, 0.2, 3.], [1., -1., 7., 3.], [2., 12., 13., 3.]];
+
+            let m = x.softmax(Axis(0));
+            let y = array![[0.211942, 0.00001013, 0.00000275, 0.333333],
+                           [0.211942, 0.00000226, 0.00247262, 0.333333],
+                           [0.576117, 0.999988, 0.997525, 0.333333]];
+            assert_relative_eq!(m, y, epsilon = 1e-5);
+
+            let m = x.softmax(Axis(1));
+            let y = array![[1.05877e-01, 6.42177e-02, 4.75736e-02, 7.82332e-01],
+                           [2.42746e-03, 3.28521e-04, 9.79307e-01, 1.79366e-02],
+                           [1.22094e-05, 2.68929e-01, 7.31025e-01, 3.31885e-05]];
+            assert_relative_eq!(m, y, epsilon = 1e-5);
+        }
+    }
+}
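The core of the patch is the standard max-subtraction trick: each lane is shifted by its maximum before exponentiating, so `exp` cannot overflow, and the `None => continue` branch simply skips empty lanes. Below is a minimal standalone sketch of the same computation over a plain `f32` slice, without ndarray; the `softmax_1d` name and the free-function form are purely illustrative and not part of this PR.

```rust
/// Numerically stable softmax over a 1-D slice (illustrative only, not the PR's API).
fn softmax_1d(xs: &[f32]) -> Vec<f32> {
    if xs.is_empty() {
        // Mirrors the `None => continue` branch in the PR: nothing to normalize.
        return Vec::new();
    }
    // Shift by the maximum so the largest exponent is exp(0) = 1 and exp() cannot overflow.
    let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / sum).collect()
}

fn main() {
    let probs = softmax_1d(&[1.0, 2.0, 3.0]);
    // Each lane of a softmax sums to 1; values are roughly [0.09, 0.24, 0.67], as in the doctest above.
    assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-6);
    println!("{:?}", probs);
}
```

The method in the diff does the same work once per lane via `lanes(axis)`/`lanes_mut(axis)`, writing each result into an uninitialized output array and calling `assume_init` only after every element has been written.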