Skip to content

Commit 2abed57

Browse files
committed
StanHeaders compatibility for both Stan 2.26 and 2.21
Introduce `USE_STANC3` macro to branch changes specific to Stan 2.26 that cannot be supported by the older version of `stanc` transpiler. Partially reverts stan-dev#2190. Signed-off-by: Hamada S. Badr <[email protected]>
1 parent 3b510ea commit 2abed57

38 files changed

+436
-2
lines changed

stan/math/prim/err/elementwise_check.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,14 @@ inline void elementwise_check(const F& is_good, const char* function,
144144
*/
145145
template <typename F, typename T, typename... Indexings,
146146
require_eigen_t<T>* = nullptr,
147+
#ifdef USE_STANC3
147148
std::enable_if_t<(Eigen::internal::traits<T>::Flags
148149
& Eigen::LinearAccessBit)
149150
|| T::IsVectorAtCompileTime>* = nullptr>
151+
#else
152+
std::enable_if_t<static_cast<bool>(Eigen::internal::traits<T>::Flags&(
153+
Eigen::LinearAccessBit | Eigen::DirectAccessBit))>* = nullptr>
154+
#endif
150155
inline void elementwise_check(const F& is_good, const char* function,
151156
const char* name, const T& x, const char* must_be,
152157
const Indexings&... indexings) {
@@ -194,13 +199,23 @@ inline void elementwise_check(const F& is_good, const char* function,
194199
* @throws `std::domain_error` if `is_good` returns `false` for the value
195200
* of any element in `x`
196201
*/
202+
#ifdef USE_STANC3
197203
template <typename F, typename T, typename... Indexings,
198204
require_eigen_t<T>* = nullptr,
199205
std::enable_if_t<!(Eigen::internal::traits<T>::Flags
200206
& Eigen::LinearAccessBit)
201207
&& !T::IsVectorAtCompileTime
202208
&& !(Eigen::internal::traits<T>::Flags
203209
& Eigen::RowMajorBit)>* = nullptr>
210+
#else
211+
template <
212+
typename F, typename T, typename... Indexings,
213+
require_eigen_t<T>* = nullptr,
214+
std::enable_if_t<!(Eigen::internal::traits<T>::Flags
215+
& (Eigen::LinearAccessBit | Eigen::DirectAccessBit))
216+
&& !(Eigen::internal::traits<T>::Flags
217+
& Eigen::RowMajorBit)>* = nullptr>
218+
#endif
204219
inline void elementwise_check(const F& is_good, const char* function,
205220
const char* name, const T& x, const char* must_be,
206221
const Indexings&... indexings) {
@@ -237,13 +252,23 @@ inline void elementwise_check(const F& is_good, const char* function,
237252
* @throws `std::domain_error` if `is_good` returns `false` for the value
238253
* of any element in `x`
239254
*/
255+
#ifdef USE_STANC3
240256
template <typename F, typename T, typename... Indexings,
241257
require_eigen_t<T>* = nullptr,
242258
std::enable_if_t<
243259
!(Eigen::internal::traits<T>::Flags & Eigen::LinearAccessBit)
244260
&& !T::IsVectorAtCompileTime
245261
&& static_cast<bool>(Eigen::internal::traits<T>::Flags
246262
& Eigen::RowMajorBit)>* = nullptr>
263+
#else
264+
template <
265+
typename F, typename T, typename... Indexings,
266+
require_eigen_t<T>* = nullptr,
267+
std::enable_if_t<!(Eigen::internal::traits<T>::Flags
268+
& (Eigen::LinearAccessBit | Eigen::DirectAccessBit))
269+
&& static_cast<bool>(Eigen::internal::traits<T>::Flags
270+
& Eigen::RowMajorBit)>* = nullptr>
271+
#endif
247272
inline void elementwise_check(const F& is_good, const char* function,
248273
const char* name, const T& x, const char* must_be,
249274
const Indexings&... indexings) {

stan/math/prim/fun/add.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@ template <typename Mat1, typename Mat2,
4242
require_all_not_st_var<Mat1, Mat2>* = nullptr>
4343
inline auto add(const Mat1& m1, const Mat2& m2) {
4444
check_matching_dims("add", "m1", m1, "m2", m2);
45+
#ifdef USE_STANC3
4546
return m1 + m2;
47+
#else
48+
return (m1 + m2).eval();
49+
#endif
4650
}
4751

4852
/**
@@ -74,7 +78,11 @@ template <typename Scal, typename Mat, require_stan_scalar_t<Scal>* = nullptr,
7478
require_eigen_t<Mat>* = nullptr,
7579
require_all_not_st_var<Scal, Mat>* = nullptr>
7680
inline auto add(const Scal c, const Mat& m) {
81+
#ifdef USE_STANC3
7782
return (c + m.array()).matrix();
83+
#else
84+
return (c + m.array()).matrix().eval();
85+
#endif
7886
}
7987

8088
} // namespace math

stan/math/prim/fun/assign.hpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,12 @@ inline void print_mat_size(std::ostream& o) {
4141
* @param x Left-hand side.
4242
* @param y Right-hand side.
4343
*/
44+
#ifdef USE_STANC3
4445
template <typename T_lhs, typename T_rhs,
4546
require_all_stan_scalar_t<T_lhs, T_rhs>* = nullptr>
47+
#else
48+
template <typename T_lhs, typename T_rhs>
49+
#endif
4650
inline void assign(T_lhs& x, const T_rhs& y) {
4751
x = y;
4852
}
@@ -62,12 +66,59 @@ inline void assign(T_lhs& x, const T_rhs& y) {
6266
* @param y Right-hand side matrix.
6367
* @throw std::invalid_argument if sizes do not match.
6468
*/
69+
#ifdef USE_STANC3
6570
template <typename T_lhs, typename T_rhs,
6671
require_all_eigen_t<T_lhs, T_rhs>* = nullptr>
6772
inline void assign(T_lhs&& x, const T_rhs& y) {
6873
check_matching_dims("assign", "left-hand-side", x, "right-hand-side", y);
6974
x = y.template cast<value_type_t<T_lhs>>();
7075
}
76+
#else
77+
template <typename T_lhs, typename T_rhs, int R, int C>
78+
inline void assign(Eigen::Matrix<T_lhs, R, C>& x,
79+
const Eigen::Matrix<T_rhs, R, C>& y) {
80+
check_matching_dims("assign", "left-hand-side", x, "right-hand-side", y);
81+
for (int i = 0; i < x.size(); ++i) {
82+
assign(x(i), y(i));
83+
}
84+
}
85+
86+
/**
87+
* Copy the right-hand side's value to the left-hand side
88+
* variable.
89+
*
90+
* <p>The <code>assign()</code> function is overloaded. This
91+
* instance will be called for arguments that are both
92+
* <code>Eigen::Matrix</code> types and whose shapes match. The
93+
* shape of the right-hand side matrix is specified in the row and
94+
* column shape template parameters.
95+
*
96+
* <p>The left-hand side is intentionally not a reference, because
97+
* that won't match generally enough; instead, a non-reference is
98+
* used, which still holds onto a reference to the contained
99+
* matrix and thus still updates the appropriate values.
100+
*
101+
* @tparam T_lhs Type of matrix block elements.
102+
* @tparam T Type of right-hand side matrix elements.
103+
* @tparam R Row shape for right-hand side matrix.
104+
* @tparam C Column shape for right-hand side matrix.
105+
* @param x Left-hand side block view of matrix.
106+
* @param y Right-hand side matrix.
107+
* @throw std::invalid_argument if sizes do not match.
108+
*/
109+
template <typename T_lhs, typename T, int R, int C>
110+
inline void assign(Eigen::Block<T_lhs> x, const Eigen::Matrix<T, R, C>& y) {
111+
check_size_match("assign", "left-hand side rows", x.rows(),
112+
"right-hand side rows", y.rows());
113+
check_size_match("assign", "left-hand side cols", x.cols(),
114+
"right-hand side cols", y.cols());
115+
for (int n = 0; n < y.cols(); ++n) {
116+
for (int m = 0; m < y.rows(); ++m) {
117+
assign(x(m, n), y(m, n));
118+
}
119+
}
120+
}
121+
#endif
71122

72123
/**
73124
* Copy the right-hand side's value to the left-hand side

stan/math/prim/fun/col.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@ namespace math {
2323
template <typename T, typename = require_eigen_t<T>>
2424
inline auto col(const T& m, size_t j) {
2525
check_column_index("col", "j", m, j);
26+
#ifdef USE_STANC3
2627
return m.col(j - 1);
28+
#else
29+
return m.col(j - 1).eval();
30+
#endif
2731
}
2832

2933
} // namespace math

stan/math/prim/fun/corr_matrix_free.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,29 @@ Eigen::Matrix<value_type_t<T>, Eigen::Dynamic, 1> corr_matrix_free(const T& y) {
3939
check_nonzero_size("corr_matrix_free", "y", y);
4040

4141
Eigen::Index k = y.rows();
42+
#ifdef USE_STANC3
4243
Eigen::Index k_choose_2 = (k * (k - 1)) / 2;
44+
#else
45+
Array<value_type_t<T>, Dynamic, 1> x(k_choose_2);
46+
#endif
4347
Eigen::Matrix<value_type_t<T>, Dynamic, 1> x(k_choose_2);
4448
Array<value_type_t<T>, Dynamic, 1> sds(k);
49+
#ifdef USE_STANC3
4550
bool successful = factor_cov_matrix(y, x.array(), sds);
51+
#else
52+
bool successful = factor_cov_matrix(y, x, sds);
53+
#endif
4654
if (!successful) {
4755
throw_domain_error("corr_matrix_free", "factor_cov_matrix failed on y", y,
4856
"");
4957
}
5058
check_bounded("corr_matrix_free", "log(sd)", sds, -CONSTRAINT_TOLERANCE,
5159
CONSTRAINT_TOLERANCE);
60+
#ifdef USE_STANC3
5261
return x;
62+
#else
63+
return x.matrix();
64+
#endif
5365
}
5466

5567
} // namespace math

stan/math/prim/fun/diag_post_multiply.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ template <typename T1, typename T2, require_eigen_t<T1>* = nullptr,
2424
auto diag_post_multiply(const T1& m1, const T2& m2) {
2525
check_size_match("diag_post_multiply", "m2.size()", m2.size(), "m1.cols()",
2626
m1.cols());
27+
#ifdef USE_STANC3
2728
return m1 * m2.asDiagonal();
29+
#else
30+
return (m1 * m2.asDiagonal()).eval();
31+
#endif
2832
}
2933

3034
} // namespace math

stan/math/prim/fun/diag_pre_multiply.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ template <typename T1, typename T2, require_eigen_vector_t<T1>* = nullptr,
2424
auto diag_pre_multiply(const T1& m1, const T2& m2) {
2525
check_size_match("diag_pre_multiply", "m1.size()", m1.size(), "m2.rows()",
2626
m2.rows());
27+
#ifdef USE_STANC3
2728
return m1.asDiagonal() * m2;
29+
#else
30+
return (m1.asDiagonal() * m2).eval();
31+
#endif
2832
}
2933

3034
} // namespace math

stan/math/prim/fun/diagonal.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@ namespace math {
1717
*/
1818
template <typename T, typename = require_eigen_t<T>>
1919
inline auto diagonal(const T& m) {
20+
#ifdef USE_STANC3
2021
return m.diagonal();
22+
#else
23+
return m.diagonal().eval();
24+
#endif
2125
}
2226

2327
} // namespace math

stan/math/prim/fun/divide.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@ namespace math {
1919
* @return Scalar divided by the scalar.
2020
*/
2121
template <typename Scal1, typename Scal2,
22+
#ifdef USE_STANC3
2223
require_all_stan_scalar_t<Scal1, Scal2>* = nullptr>
24+
#else
25+
typename = require_all_stan_scalar_t<Scal1, Scal2>>
26+
#endif
2327
inline return_type_t<Scal1, Scal2> divide(const Scal1& x, const Scal2& y) {
2428
return x / y;
2529
}
@@ -41,10 +45,19 @@ inline int divide(int x, int y) {
4145
* @return matrix divided by the scalar
4246
*/
4347
template <typename Mat, typename Scal, typename = require_eigen_t<Mat>,
48+
#ifdef USE_STANC3
4449
require_stan_scalar_t<Scal>* = nullptr,
4550
require_all_not_var_t<scalar_type_t<Mat>, Scal>* = nullptr>
51+
#else
52+
typename = require_stan_scalar_t<Scal>,
53+
typename = require_all_not_var_t<scalar_type_t<Mat>, Scal>>
54+
#endif
4655
inline auto divide(const Mat& m, Scal c) {
56+
#ifdef USE_STANC3
4757
return m / c;
58+
#else
59+
return (m / c).eval();
60+
#endif
4861
}
4962

5063
} // namespace math

stan/math/prim/fun/elt_divide.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ template <typename Mat1, typename Mat2,
2424
require_all_not_st_var<Mat1, Mat2>* = nullptr>
2525
auto elt_divide(const Mat1& m1, const Mat2& m2) {
2626
check_matching_dims("elt_divide", "m1", m1, "m2", m2);
27+
#ifdef USE_STANC3
2728
return (m1.array() / m2.array()).matrix();
29+
#else
30+
return (m1.array() / m2.array()).matrix().eval();
31+
#endif
2832
}
2933

3034
/**
@@ -38,11 +42,19 @@ auto elt_divide(const Mat1& m1, const Mat2& m2) {
3842
* @param s scalar
3943
* @return Elementwise division of a scalar by matrix.
4044
*/
45+
#ifdef USE_STANC3
4146
template <typename Mat, typename Scal, require_matrix_t<Mat>* = nullptr,
4247
require_stan_scalar_t<Scal>* = nullptr>
4348
auto elt_divide(const Mat& m, Scal s) {
4449
return divide(m, s);
4550
}
51+
#else
52+
template <typename Scal, typename Mat, typename = require_stan_scalar_t<Scal>,
53+
typename = require_eigen_t<Mat>>
54+
auto elt_divide(Scal s, const Mat& m) {
55+
return (s / m.array()).matrix().eval();
56+
}
57+
#endif
4658

4759
/**
4860
* Return the elementwise division of the specified scalar

stan/math/prim/fun/elt_multiply.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ template <typename Mat1, typename Mat2,
2424
require_all_not_st_var<Mat1, Mat2>* = nullptr>
2525
auto elt_multiply(const Mat1& m1, const Mat2& m2) {
2626
check_matching_dims("elt_multiply", "m1", m1, "m2", m2);
27+
#ifdef USE_STANC3
2728
return m1.cwiseProduct(m2);
29+
#else
30+
return m1.cwiseProduct(m2).eval();
31+
#endif
2832
}
2933

3034
/**

stan/math/prim/fun/head.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@ inline auto head(const T& v, size_t n) {
2323
if (n != 0) {
2424
check_vector_index("head", "n", v, n);
2525
}
26+
#ifdef USE_STANC3
2627
return v.head(n);
28+
#else
29+
return v.head(n).eval();
30+
#endif
2731
}
2832

2933
/**

stan/math/prim/fun/lb_constrain.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,13 @@ inline return_type_t<T, L> lb_constrain(const T& x, const L& lb) {
5555
* @param[in,out] lp reference to log probability to increment
5656
* @return lower-bound constrained value corresponding to inputs
5757
*/
58+
#ifdef USE_STANC3
5859
template <typename T, typename L, typename S>
5960
inline return_type_t<T, L> lb_constrain(const T& x, const L& lb, S& lp) {
61+
#else
62+
template <typename T, typename L>
63+
inline return_type_t<T, L> lb_constrain(const T& x, const L& lb, T& lp) {
64+
#endif
6065
using std::exp;
6166
if (lb == NEGATIVE_INFTY) {
6267
return identity_constrain(x, lp);

stan/math/prim/fun/log_softmax.hpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
#include <stan/math/prim/err.hpp>
66
#include <stan/math/prim/fun/Eigen.hpp>
77
#include <stan/math/prim/fun/log_sum_exp.hpp>
8+
#ifdef USE_STANC3
89
#include <stan/math/prim/fun/to_ref.hpp>
10+
#endif
911
#include <stan/math/prim/functor/apply_vector_unary.hpp>
1012

1113
namespace stan {
@@ -39,6 +41,7 @@ namespace math {
3941
* @param[in] x vector to transform
4042
* @return log unit simplex result of the softmax transform of the vector.
4143
*/
44+
#ifdef USE_STANC3
4245
template <typename Container, require_st_arithmetic<Container>* = nullptr>
4346
inline auto log_softmax(const Container& x) {
4447
check_nonzero_size("log_softmax", "v", x);
@@ -49,7 +52,21 @@ inline auto log_softmax(const Container& x) {
4952
},
5053
to_ref(x));
5154
}
52-
55+
#else
56+
/**
57+
* Note: The return must be evaluated otherwise the Ref object falls out
58+
* of scope
59+
*/
60+
template <typename Container, require_st_arithmetic<Container>* = nullptr>
61+
inline auto log_softmax(const Container& x) {
62+
check_nonzero_size("log_softmax", "v", x);
63+
return apply_vector_unary<Container>::apply(x, [](const auto& v) {
64+
const Eigen::Ref<const plain_type_t<decltype(v)>>& v_ref = v;
65+
check_nonzero_size("log_softmax", "v", v_ref);
66+
return (v_ref.array() - log_sum_exp(v_ref)).eval();
67+
});
68+
}
69+
#endif
5370
} // namespace math
5471
} // namespace stan
5572
#endif

0 commit comments

Comments
 (0)