Commit db87d41

Add softmax function
1 parent c7391e9 commit db87d41

1 file changed: +97 -0 lines changed

src/numeric/impl_float_maths.rs

@@ -169,3 +169,100 @@ where
        self.mapv(|a| num_traits::clamp(a, min.clone(), max.clone()))
    }
}

#[cfg(feature = "std")]
impl<A, S, D> ArrayBase<S, D>
where
    A: Float + 'static,
    S: Data<Elem = A>,
    D: RemoveAxis,
{
    /// Compute the softmax function along the specified axis.
    ///
    /// The softmax function is defined as:
    /// ```text
    /// softmax(x_i) = exp(x_i) / sum(exp(x_j) for j in axis)
    /// ```
    ///
    /// This function is usually used in machine learning to normalize the output of a neural network to a probability
    /// distribution.
    ///
    /// ```
    /// use ndarray::{array, Axis};
    ///
    /// let a = array![[1., 2., 3.], [4., 5., 6.0_f32]];
    /// let b = a.softmax(Axis(0)).mapv(|x| (x * 100.0).round() / 100.0);
    /// assert_eq!(b, array![[0.05, 0.05, 0.05], [0.95, 0.95, 0.95]]);
    /// let c = a.softmax(Axis(1)).mapv(|x| (x * 100.0).round() / 100.0);
    /// assert_eq!(c, array![[0.09, 0.24, 0.67], [0.09, 0.24, 0.67]]);
    /// ```
    ///
    /// # Arguments
    ///
    /// * `axis`: The axis along which to compute the softmax function (so every slice along the axis will sum to 1).
    pub fn softmax(&self, axis: Axis) -> Array<A, D>
    {
        let mut res = Array::uninit(self.raw_dim());
        for (arr, mut res) in self.lanes(axis).into_iter().zip(res.lanes_mut(axis)) {
            let max = arr
                .iter()
                // If we have NaN and the comparison fails, the max can be arbitrary as the sum and the whole result
                // will be NaN anyway, so we use an arbitrary ordering.
                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            let max = match max {
                Some(max) => *max,
                None => continue,
            };
            let mut sum = A::zero();
            for (i, x) in res.indexed_iter_mut() {
                let v = (arr[i] - max).exp();
                sum = sum + v;
                x.write(v);
            }
            for x in res.iter_mut() {
                // Safety: we wrote to every single element of the `res` array in the previous loop.
                x.write(*unsafe { x.assume_init_ref() } / sum);
            }
        }
        // Safety: we wrote to every single element of the array.
        unsafe { res.assume_init() }
    }
}

#[cfg(test)]
mod tests
{
    #[cfg(feature = "std")]
    #[test]
    fn test_softmax()
    {
        use super::*;
        use crate::array;

        let a = array![[1., 2., 3.], [4., 5., 6.0_f32]];
        let b = a.softmax(Axis(0)).mapv(|x| (x * 100.0).round() / 100.0);
        assert_eq!(b, array![[0.05, 0.05, 0.05], [0.95, 0.95, 0.95]]);
        let c = a.softmax(Axis(1)).mapv(|x| (x * 100.0).round() / 100.0);
        assert_eq!(c, array![[0.09, 0.24, 0.67], [0.09, 0.24, 0.67]]);

        #[cfg(feature = "approx")]
        {
            // examples copied from scipy softmax documentation

            use approx::assert_relative_eq;

            let x = array![[1., 0.5, 0.2, 3.], [1., -1., 7., 3.], [2., 12., 13., 3.]];

            let m = x.softmax(Axis(0));
            let y = array![[0.211942, 0.00001013, 0.00000275, 0.333333],
                           [0.211942, 0.00000226, 0.00247262, 0.333333],
                           [0.576117, 0.999988, 0.997525, 0.333333]];
            assert_relative_eq!(m, y, epsilon = 1e-5);

            let m = x.softmax(Axis(1));
            let y = array![[1.05877e-01, 6.42177e-02, 4.75736e-02, 7.82332e-01],
                           [2.42746e-03, 3.28521e-04, 9.79307e-01, 1.79366e-02],
                           [1.22094e-05, 2.68929e-01, 7.31025e-01, 3.31885e-05]];
            assert_relative_eq!(m, y, epsilon = 1e-5);
        }
    }
}
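A side note on why the loop above computes `(arr[i] - max).exp()` instead of `arr[i].exp()`: subtracting the per-lane maximum before exponentiating is the standard numerical-stability trick for softmax. The shift cancels out mathematically, since exp(x_i - m) / sum(exp(x_j - m)) = exp(x_i) / sum(exp(x_j)), but it keeps every exponent at or below zero so the intermediate values cannot overflow. A minimal standalone sketch of the same idea on a plain slice follows; the helper name `stable_softmax` and the demo are illustrative only, not part of ndarray or this commit.

// Sketch: the max-subtraction trick on a plain slice. `stable_softmax` is a
// hypothetical helper for illustration, not an ndarray API.
fn stable_softmax(xs: &[f32]) -> Vec<f32> {
    // Per-lane maximum; NaN handling is ignored here for brevity.
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    // Every exponent is <= 0, so exp() stays in [0, 1] and cannot overflow,
    // even when the raw inputs are large.
    let exps: Vec<f32> = xs.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / sum).collect()
}

fn main() {
    // Naive exp(1002.0) overflows f32 to infinity; the shifted form does not.
    let probs = stable_softmax(&[1000.0, 1001.0, 1002.0]);
    // Same relative weights as softmax of [0.0, 1.0, 2.0]: roughly [0.09, 0.24, 0.67].
    println!("{:?}", probs);
    assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-6);
}

The same invariant is what the doctest and tests above rely on: every lane of `a.softmax(Axis(1))` sums to 1 up to floating-point rounding.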

0 commit comments
