Commit 61c18c9

Add softmax function

1 parent c7391e9

File tree

src/numeric/impl_float_maths.rs (1 file changed: +91, -0 lines)
````rust
@@ -169,3 +169,94 @@ where
        self.mapv(|a| num_traits::clamp(a, min.clone(), max.clone()))
    }
}

#[cfg(feature = "std")]
impl<A, S, D> ArrayBase<S, D>
where
    A: Float + 'static,
    S: Data<Elem = A>,
    D: RemoveAxis,
{
    /// Compute the softmax function along the specified axis.
    ///
    /// The softmax function is defined as:
    /// ```text
    /// softmax(x_i) = exp(x_i) / sum(exp(x_j) for j in axis)
    /// ```
    ///
    /// This function is usually used in machine learning to normalize the output of a neural network to a probability
    /// distribution.
    /// ```
    /// use ndarray::{array, Axis};
    ///
    /// let a = array![[1., 2., 3.], [4., 5., 6.0_f32]];
    /// let b = a.softmax(Axis(0)).mapv(|x| (x * 100.0).round() / 100.0);
    /// assert_eq!(b, array![[0.05, 0.05, 0.05], [0.95, 0.95, 0.95]]);
    /// let c = a.softmax(Axis(1)).mapv(|x| (x * 100.0).round() / 100.0);
    /// assert_eq!(c, array![[0.09, 0.24, 0.67], [0.09, 0.24, 0.67]]);
    /// ```
    ///
    /// # Arguments
    ///
    /// * `axis`: The axis along which to compute the softmax function (so every slice along the axis will sum to 1).
    pub fn softmax(&self, axis: Axis) -> Array<A, D>
    {
        let mut res = Array::uninit(self.raw_dim());
        for (arr, mut res) in self.lanes(axis).into_iter().zip(res.lanes_mut(axis)) {
            let max = arr
                .iter()
                // If we have NaN and the comparison fails, the max can be arbitrary as the sum and the whole result
                // will be NaN anyway, so we use an arbitrary ordering.
                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            let max = match max {
                Some(max) => *max,
                None => continue,
            };
            let sum = arr.fold(A::zero(), |sum, x| sum + (*x - max).exp());
            for (i, x) in res.indexed_iter_mut() {
                x.write((arr[i] - max).exp() / sum);
            }
        }
        // Safety: we wrote to every single element of the array.
        unsafe { res.assume_init() }
    }
}
````
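The key numerical detail above is the `- max` shift: each lane's maximum is subtracted before exponentiating. The result is mathematically unchanged, since exp(x_i - m) / sum_j exp(x_j - m) = exp(x_i) / sum_j exp(x_j) (the exp(-m) factor cancels), but the shifted exponents are all <= 0, so `exp` can no longer overflow for large inputs. A minimal standalone sketch of the difference (plain `f32` slices, not part of the commit):

```rust
fn main() {
    let x: [f32; 3] = [1000.0, 1001.0, 1002.0];

    // Naive form: exp(1000.0_f32) overflows to +inf, so the normalizing sum
    // is inf and every quotient would become inf/inf = NaN.
    let naive_sum: f32 = x.iter().map(|v| v.exp()).sum();
    println!("naive sum = {naive_sum}"); // inf

    // Shifted form: exp(x_i - max) is at most exp(0) = 1, so nothing overflows.
    let max = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let shifted: Vec<f32> = x.iter().map(|v| (v - max).exp()).collect();
    let sum: f32 = shifted.iter().sum();
    let softmax: Vec<f32> = shifted.iter().map(|e| e / sum).collect();
    println!("softmax = {softmax:?}"); // approximately [0.090, 0.245, 0.665]
}
```

The same hunk also adds tests for the new method: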
```rust
#[cfg(test)]
mod tests
{
    #[cfg(feature = "std")]
    #[test]
    fn test_softmax()
    {
        use super::*;
        use crate::array;

        let a = array![[1., 2., 3.], [4., 5., 6.0_f32]];
        let b = a.softmax(Axis(0)).mapv(|x| (x * 100.0).round() / 100.0);
        assert_eq!(b, array![[0.05, 0.05, 0.05], [0.95, 0.95, 0.95]]);
        let c = a.softmax(Axis(1)).mapv(|x| (x * 100.0).round() / 100.0);
        assert_eq!(c, array![[0.09, 0.24, 0.67], [0.09, 0.24, 0.67]]);

        #[cfg(feature = "approx")]
        {
            // examples copied from scipy softmax documentation

            use approx::assert_relative_eq;

            let x = array![[1., 0.5, 0.2, 3.], [1., -1., 7., 3.], [2., 12., 13., 3.]];

            let m = x.softmax(Axis(0));
            let y = array![[0.211942, 0.00001013, 0.00000275, 0.333333],
                           [0.211942, 0.00000226, 0.00247262, 0.333333],
                           [0.576117, 0.999988,   0.997525,   0.333333]];
            assert_relative_eq!(m, y, epsilon = 1e-5);

            let m = x.softmax(Axis(1));
            let y = array![[1.05877e-01, 6.42177e-02, 4.75736e-02, 7.82332e-01],
                           [2.42746e-03, 3.28521e-04, 9.79307e-01, 1.79366e-02],
                           [1.22094e-05, 2.68929e-01, 7.31025e-01, 3.31885e-05]];
            assert_relative_eq!(m, y, epsilon = 1e-5);
        }
    }
}
```
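A note on orientation for reading the doc example and the tests: `softmax(Axis(k))` normalizes the lanes of axis k, i.e. the 1-D slices that run in the direction of that axis, so on a 2-D array `softmax(Axis(0))` makes each column sum to 1 and `softmax(Axis(1))` each row. A small sketch of the lane layout (assumes the ndarray crate as a dependency; the commented line additionally assumes a build that contains this commit):

```rust
use ndarray::{array, Axis};

fn main() {
    let a = array![[1., 2., 3.], [4., 5., 6.0_f32]];

    // Lanes along Axis(0) run down the columns: [1, 4], [2, 5], [3, 6],
    // so softmax(Axis(0)) normalizes each column to sum to 1.
    for lane in a.lanes(Axis(0)) {
        println!("{lane}");
    }

    // Lanes along Axis(1) run across the rows: [1, 2, 3], [4, 5, 6].
    for lane in a.lanes(Axis(1)) {
        println!("{lane}");
    }

    // With this commit applied, each column of the result sums to 1:
    // let sums = a.softmax(Axis(0)).sum_axis(Axis(0)); // approximately [1., 1., 1.]
}
```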
