Skip to content

Commit 4c32cfe

Browse files
authored
Closes #1425 Merge pull request #1558 from andrjohns/feature/vec_gen_design
Generalised unary vector function framework
2 parents 13a1b66 + f3b3286 commit 4c32cfe

12 files changed

+419
-242
lines changed

stan/math/fwd/fun/log_softmax.hpp

+35-31
Original file line numberDiff line numberDiff line change
@@ -6,44 +6,48 @@
66
#include <stan/math/prim/fun/Eigen.hpp>
77
#include <stan/math/prim/fun/log_softmax.hpp>
88
#include <stan/math/prim/fun/softmax.hpp>
9+
#include <stan/math/prim/meta.hpp>
10+
#include <stan/math/prim/vectorize/apply_vector_unary.hpp>
911

1012
namespace stan {
1113
namespace math {
1214

13-
template <typename T>
14-
inline Eigen::Matrix<fvar<T>, Eigen::Dynamic, 1> log_softmax(
15-
const Eigen::Matrix<fvar<T>, Eigen::Dynamic, 1>& alpha) {
16-
using Eigen::Dynamic;
17-
using Eigen::Matrix;
18-
19-
Matrix<T, Dynamic, 1> alpha_t(alpha.size());
20-
for (int k = 0; k < alpha.size(); ++k) {
21-
alpha_t(k) = alpha(k).val_;
22-
}
23-
24-
Matrix<T, Dynamic, 1> softmax_alpha_t = softmax(alpha_t);
25-
Matrix<T, Dynamic, 1> log_softmax_alpha_t = log_softmax(alpha_t);
26-
27-
Matrix<fvar<T>, Dynamic, 1> log_softmax_alpha(alpha.size());
28-
for (int k = 0; k < alpha.size(); ++k) {
29-
log_softmax_alpha(k).val_ = log_softmax_alpha_t(k);
30-
log_softmax_alpha(k).d_ = 0;
31-
}
32-
33-
for (int m = 0; m < alpha.size(); ++m) {
34-
T negative_alpha_m_d_times_softmax_alpha_t_m
35-
= -alpha(m).d_ * softmax_alpha_t(m);
36-
for (int k = 0; k < alpha.size(); ++k) {
37-
if (m == k) {
38-
log_softmax_alpha(k).d_
39-
+= alpha(m).d_ + negative_alpha_m_d_times_softmax_alpha_t_m;
40-
} else {
41-
log_softmax_alpha(k).d_ += negative_alpha_m_d_times_softmax_alpha_t_m;
15+
/**
16+
* Return the log softmax of the specified vector or container of vectors.
17+
*
18+
* @tparam T Type of input vector or matrix.
19+
* @param[in] x Unconstrained input vector.
20+
* @return Log softmax of the input.
21+
* @throw std::invalid_argument If the input vector is size 0.
22+
*/
23+
template <typename T, require_t<is_fvar<scalar_type_t<T>>>...>
24+
inline auto log_softmax(const T& x) {
25+
return apply_vector_unary<T>::apply(x, [&](const auto& alpha) {
26+
using T_fvar = value_type_t<decltype(alpha)>;
27+
using T_fvar_inner = typename T_fvar::Scalar;
28+
29+
Eigen::Matrix<T_fvar_inner, -1, 1> alpha_t = alpha.val();
30+
Eigen::Matrix<T_fvar_inner, -1, 1> softmax_alpha_t = softmax(alpha_t);
31+
32+
Eigen::Matrix<T_fvar, -1, 1> log_softmax_alpha(alpha.size());
33+
log_softmax_alpha.val() = log_softmax(alpha_t);
34+
log_softmax_alpha.d().setZero();
35+
36+
for (int m = 0; m < alpha.size(); ++m) {
37+
T_fvar_inner negative_alpha_m_d_times_softmax_alpha_t_m
38+
= -alpha(m).d_ * softmax_alpha_t(m);
39+
for (int k = 0; k < alpha.size(); ++k) {
40+
if (m == k) {
41+
log_softmax_alpha(k).d_
42+
+= alpha(m).d_ + negative_alpha_m_d_times_softmax_alpha_t_m;
43+
} else {
44+
log_softmax_alpha(k).d_ += negative_alpha_m_d_times_softmax_alpha_t_m;
45+
}
4246
}
4347
}
44-
}
4548

46-
return log_softmax_alpha;
49+
return log_softmax_alpha;
50+
});
4751
}
4852

4953
} // namespace math

stan/math/fwd/fun/log_sum_exp.hpp

+27-28
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <stan/math/prim/fun/Eigen.hpp>
77
#include <stan/math/prim/fun/constants.hpp>
88
#include <stan/math/prim/fun/log_sum_exp.hpp>
9+
#include <stan/math/prim/vectorize/apply_vector_unary.hpp>
910
#include <cmath>
1011
#include <vector>
1112

@@ -31,37 +32,35 @@ inline fvar<T> log_sum_exp(double x1, const fvar<T>& x2) {
3132

3233
template <typename T>
3334
inline fvar<T> log_sum_exp(const fvar<T>& x1, double x2) {
34-
using std::exp;
35-
if (x2 == NEGATIVE_INFTY) {
36-
return fvar<T>(x1.val_, x1.d_);
37-
}
38-
return fvar<T>(log_sum_exp(x1.val_, x2), x1.d_ / (1 + exp(x2 - x1.val_)));
39-
}
40-
41-
template <typename T>
42-
fvar<T> log_sum_exp(const std::vector<fvar<T> >& v) {
43-
using std::exp;
44-
std::vector<T> vals(v.size());
45-
for (size_t i = 0; i < v.size(); ++i) {
46-
vals[i] = v[i].val_;
47-
}
48-
T deriv(0.0);
49-
T denominator(0.0);
50-
for (size_t i = 0; i < v.size(); ++i) {
51-
T exp_vi = exp(vals[i]);
52-
denominator += exp_vi;
53-
deriv += v[i].d_ * exp_vi;
54-
}
55-
return fvar<T>(log_sum_exp(vals), deriv / denominator);
35+
return log_sum_exp(x2, x1);
5636
}
5737

58-
template <typename T, int R, int C>
59-
fvar<T> log_sum_exp(const Eigen::Matrix<fvar<T>, R, C>& v) {
60-
Eigen::Matrix<T, R, C> vals = v.val();
61-
Eigen::Matrix<T, R, C> exp_vals = vals.array().exp();
38+
/**
39+
* Return the log of the sum of the exponentiated values of the specified
40+
* matrix of values. The matrix may be a full matrix, a vector,
41+
* a row vector, or a container of these.
42+
*
43+
* The function is defined as follows to prevent overflow in exponential
44+
* calculations.
45+
*
46+
* \f$\log \sum_{n=1}^N \exp(x_n) = \max(x) + \log \sum_{n=1}^N \exp(x_n -
47+
* \max(x))\f$.
48+
*
49+
* @tparam T Type of input vector or matrix.
50+
* @param[in] x Matrix of specified values.
51+
* @return The log of the sum of the exponentiated vector values.
52+
*/
53+
template <typename T, require_t<is_fvar<scalar_type_t<T>>>...>
54+
inline auto log_sum_exp(const T& x) {
55+
return apply_vector_unary<T>::reduce(x, [&](const auto& v) {
56+
using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;
57+
using mat_type = Eigen::Matrix<T_fvar_inner, -1, -1>;
58+
mat_type vals = v.val();
59+
mat_type exp_vals = vals.array().exp();
6260

63-
return fvar<T>(log_sum_exp(vals),
64-
v.d().cwiseProduct(exp_vals).sum() / exp_vals.sum());
61+
return fvar<T_fvar_inner>(
62+
log_sum_exp(vals), v.d().cwiseProduct(exp_vals).sum() / exp_vals.sum());
63+
});
6564
}
6665

6766
} // namespace math

stan/math/prim/fun/log_softmax.hpp

+10-10
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <stan/math/prim/err.hpp>
55
#include <stan/math/prim/fun/Eigen.hpp>
66
#include <stan/math/prim/fun/log_sum_exp.hpp>
7+
#include <stan/math/prim/vectorize/apply_vector_unary.hpp>
78

89
namespace stan {
910
namespace math {
@@ -32,18 +33,17 @@ namespace math {
3233
* \right.
3334
* \f$
3435
*
35-
* @tparam T type of elements in the vector
36-
* @param[in] v Vector to transform.
37-
* @return Unit simplex result of the softmax transform of the vector.
36+
* @tparam T Type of input vector to transform.
37+
* @param[in] x Vector to transform.
38+
* @return Element-wise log of the unit simplex produced by the softmax transform of the vector.
3839
*/
39-
template <typename T>
40-
inline Eigen::Matrix<T, Eigen::Dynamic, 1> log_softmax(
41-
const Eigen::Matrix<T, Eigen::Dynamic, 1>& v) {
42-
check_nonzero_size("log_softmax", "v", v);
43-
return v.array() - log_sum_exp(v);
40+
template <typename T, require_t<std::is_arithmetic<scalar_type_t<T>>>...>
41+
inline auto log_softmax(const T& x) {
42+
return apply_vector_unary<T>::apply(x, [&](const auto& v) {
43+
check_nonzero_size("log_softmax", "v", v);
44+
return (v.array() - log_sum_exp(v)).matrix();
45+
});
4446
}
45-
4647
} // namespace math
4748
} // namespace stan
48-
4949
#endif

stan/math/prim/fun/log_sum_exp.hpp

+15-48
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <stan/math/prim/fun/constants.hpp>
66
#include <stan/math/prim/fun/Eigen.hpp>
77
#include <stan/math/prim/fun/log1p_exp.hpp>
8+
#include <stan/math/prim/vectorize/apply_vector_unary.hpp>
89
#include <cmath>
910
#include <vector>
1011

@@ -62,65 +63,31 @@ inline return_type_t<T1, T2> log_sum_exp(const T2& a, const T1& b) {
6263

6364
/**
6465
* Return the log of the sum of the exponentiated values of the specified
65-
* sequence of values.
66+
* matrix of values. The matrix may be a full matrix, a vector,
67+
* a row vector, or a container of these.
6668
*
6769
* The function is defined as follows to prevent overflow in exponential
6870
* calculations.
6971
*
7072
* \f$\log \sum_{n=1}^N \exp(x_n) = \max(x) + \log \sum_{n=1}^N \exp(x_n -
7173
* \max(x))\f$.
7274
*
73-
* @param[in] x array of specified values
75+
* @tparam T Type of input vector or matrix.
76+
* @param[in] x Matrix of specified values.
7477
* @return The log of the sum of the exponentiated vector values.
7578
*/
76-
inline double log_sum_exp(const std::vector<double>& x) {
77-
using std::exp;
78-
using std::log;
79-
double max = NEGATIVE_INFTY;
80-
for (double xx : x) {
81-
if (xx > max) {
82-
max = xx;
79+
template <typename T, require_t<std::is_arithmetic<scalar_type_t<T>>>...>
80+
inline auto log_sum_exp(const T& x) {
81+
return apply_vector_unary<T>::reduce(x, [&](const auto& v) {
82+
if (v.size() == 0) {
83+
return NEGATIVE_INFTY;
8384
}
84-
}
85-
86-
double sum = 0.0;
87-
for (size_t ii = 0; ii < x.size(); ii++) {
88-
if (x[ii] != NEGATIVE_INFTY) {
89-
sum += exp(x[ii] - max);
85+
const double max = v.maxCoeff();
86+
if (!std::isfinite(max)) {
87+
return max;
9088
}
91-
}
92-
93-
return max + log(sum);
94-
}
95-
96-
/**
97-
* Return the log of the sum of the exponentiated values of the specified
98-
* matrix of values. The matrix may be a full matrix, a vector,
99-
* or a row vector.
100-
*
101-
* The function is defined as follows to prevent overflow in exponential
102-
* calculations.
103-
*
104-
* \f$\log \sum_{n=1}^N \exp(x_n) = \max(x) + \log \sum_{n=1}^N \exp(x_n -
105-
* \max(x))\f$.
106-
*
107-
* @tparam R number of rows, can be Eigen::Dynamic
108-
* @tparam C number of columns, can be Eigen::Dynamic
109-
*
110-
* @param[in] x Matrix of specified values
111-
* @return The log of the sum of the exponentiated vector values.
112-
*/
113-
template <int R, int C>
114-
double log_sum_exp(const Eigen::Matrix<double, R, C>& x) {
115-
if (x.size() == 0) {
116-
return NEGATIVE_INFTY;
117-
}
118-
119-
const double max = x.maxCoeff();
120-
if (!std::isfinite(max)) {
121-
return max;
122-
}
123-
return max + std::log((x.array() - max).exp().sum());
89+
return max + std::log((v.array() - max).exp().sum());
90+
});
12491
}
12592

12693
} // namespace math

0 commit comments

Comments
 (0)