Skip to content

Commit 6b2fd28

Browse files
authored
Merge pull request #2190 from bstatcomp/return_expressions2
Let prim functions return expressions
2 parents 0d706ef + 8dfd5b0 commit 6b2fd28

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+303
-189
lines changed

stan/math/fwd.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
#include <stan/math/fwd/core.hpp>
77
#include <stan/math/fwd/meta.hpp>
8-
#include <stan/math/prim.hpp>
9-
108
#include <stan/math/fwd/fun.hpp>
119
#include <stan/math/fwd/functor.hpp>
1210

11+
#include <stan/math/prim.hpp>
12+
1313
#endif

stan/math/fwd/fun/sum.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#ifndef STAN_MATH_FWD_FUN_SUM_HPP
22
#define STAN_MATH_FWD_FUN_SUM_HPP
33

4-
#include <stan/math/fwd/core.hpp>
4+
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/fun/Eigen.hpp>
66
#include <stan/math/prim/fun/sum.hpp>
7+
#include <stan/math/fwd/core.hpp>
78
#include <vector>
89

910
namespace stan {

stan/math/mix.hpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,19 @@
22
#define STAN_MATH_MIX_HPP
33

44
#include <stan/math/mix/meta.hpp>
5-
6-
#include <stan/math/prim.hpp>
7-
#include <stan/math/fwd.hpp>
8-
#include <stan/math/rev.hpp>
95
#include <stan/math/mix/fun.hpp>
106
#include <stan/math/mix/functor.hpp>
117

8+
#include <stan/math/rev/core.hpp>
9+
#include <stan/math/rev/meta.hpp>
10+
#include <stan/math/rev/fun.hpp>
11+
#include <stan/math/rev/functor.hpp>
12+
13+
#include <stan/math/fwd/core.hpp>
14+
#include <stan/math/fwd/meta.hpp>
15+
#include <stan/math/fwd/fun.hpp>
16+
#include <stan/math/fwd/functor.hpp>
17+
18+
#include <stan/math/prim.hpp>
19+
1220
#endif

stan/math/prim/err/elementwise_check.hpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,9 @@ inline void elementwise_check(const F& is_good, const char* function,
142142
*/
143143
template <typename F, typename T, typename... Indexings,
144144
require_eigen_t<T>* = nullptr,
145-
std::enable_if_t<static_cast<bool>(Eigen::internal::traits<T>::Flags&(
146-
Eigen::LinearAccessBit | Eigen::DirectAccessBit))>* = nullptr>
145+
std::enable_if_t<(Eigen::internal::traits<T>::Flags
146+
& Eigen::LinearAccessBit)
147+
|| T::IsVectorAtCompileTime>* = nullptr>
147148
inline void elementwise_check(const F& is_good, const char* function,
148149
const char* name, const T& x, const char* must_be,
149150
const Indexings&... indexings) {
@@ -189,13 +190,13 @@ inline void elementwise_check(const F& is_good, const char* function,
189190
* @throws `std::domain_error` if `is_good` returns `false` for the value
190191
* of any element in `x`
191192
*/
192-
template <
193-
typename F, typename T, typename... Indexings,
194-
require_eigen_t<T>* = nullptr,
195-
std::enable_if_t<!(Eigen::internal::traits<T>::Flags
196-
& (Eigen::LinearAccessBit | Eigen::DirectAccessBit))
197-
&& !(Eigen::internal::traits<T>::Flags
198-
& Eigen::RowMajorBit)>* = nullptr>
193+
template <typename F, typename T, typename... Indexings,
194+
require_eigen_t<T>* = nullptr,
195+
std::enable_if_t<!(Eigen::internal::traits<T>::Flags
196+
& Eigen::LinearAccessBit)
197+
&& !T::IsVectorAtCompileTime
198+
&& !(Eigen::internal::traits<T>::Flags
199+
& Eigen::RowMajorBit)>* = nullptr>
199200
inline void elementwise_check(const F& is_good, const char* function,
200201
const char* name, const T& x, const char* must_be,
201202
const Indexings&... indexings) {
@@ -230,13 +231,13 @@ inline void elementwise_check(const F& is_good, const char* function,
230231
* @throws `std::domain_error` if `is_good` returns `false` for the value
231232
* of any element in `x`
232233
*/
233-
template <
234-
typename F, typename T, typename... Indexings,
235-
require_eigen_t<T>* = nullptr,
236-
std::enable_if_t<!(Eigen::internal::traits<T>::Flags
237-
& (Eigen::LinearAccessBit | Eigen::DirectAccessBit))
238-
&& static_cast<bool>(Eigen::internal::traits<T>::Flags
239-
& Eigen::RowMajorBit)>* = nullptr>
234+
template <typename F, typename T, typename... Indexings,
235+
require_eigen_t<T>* = nullptr,
236+
std::enable_if_t<
237+
!(Eigen::internal::traits<T>::Flags & Eigen::LinearAccessBit)
238+
&& !T::IsVectorAtCompileTime
239+
&& static_cast<bool>(Eigen::internal::traits<T>::Flags
240+
& Eigen::RowMajorBit)>* = nullptr>
240241
inline void elementwise_check(const F& is_good, const char* function,
241242
const char* name, const T& x, const char* must_be,
242243
const Indexings&... indexings) {

stan/math/prim/fun/add.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ template <typename Mat1, typename Mat2,
4242
require_all_not_st_var<Mat1, Mat2>* = nullptr>
4343
inline auto add(const Mat1& m1, const Mat2& m2) {
4444
check_matching_dims("add", "m1", m1, "m2", m2);
45-
return (m1 + m2).eval();
45+
return m1 + m2;
4646
}
4747

4848
/**
@@ -58,7 +58,7 @@ template <typename Mat, typename Scal, require_eigen_t<Mat>* = nullptr,
5858
require_stan_scalar_t<Scal>* = nullptr,
5959
require_all_not_st_var<Mat, Scal>* = nullptr>
6060
inline auto add(const Mat& m, const Scal c) {
61-
return (m.array() + c).matrix().eval();
61+
return (m.array() + c).matrix();
6262
}
6363

6464
/**
@@ -74,7 +74,7 @@ template <typename Scal, typename Mat, require_stan_scalar_t<Scal>* = nullptr,
7474
require_eigen_t<Mat>* = nullptr,
7575
require_all_not_st_var<Scal, Mat>* = nullptr>
7676
inline auto add(const Scal c, const Mat& m) {
77-
return (c + m.array()).matrix().eval();
77+
return (c + m.array()).matrix();
7878
}
7979

8080
} // namespace math

stan/math/prim/fun/block.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ namespace math {
1919
* @throw std::out_of_range if either index is out of range.
2020
*/
2121
template <typename T, require_eigen_t<T>* = nullptr>
22-
inline Eigen::Matrix<value_type_t<T>, Eigen::Dynamic, Eigen::Dynamic> block(
23-
const T& m, size_t i, size_t j, size_t nrows, size_t ncols) {
22+
inline auto block(const T& m, size_t i, size_t j, size_t nrows, size_t ncols) {
2423
check_row_index("block", "i", m, i);
2524
check_row_index("block", "i+nrows-1", m, i + nrows - 1);
2625
check_column_index("block", "j", m, j);

stan/math/prim/fun/col.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace math {
2323
template <typename T, typename = require_eigen_t<T>>
2424
inline auto col(const T& m, size_t j) {
2525
check_column_index("col", "j", m, j);
26-
return m.col(j - 1).eval();
26+
return m.col(j - 1);
2727
}
2828

2929
} // namespace math

stan/math/prim/fun/corr_matrix_free.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,16 @@ Eigen::Matrix<value_type_t<T>, Eigen::Dynamic, 1> corr_matrix_free(const T& y) {
4040

4141
Eigen::Index k = y.rows();
4242
Eigen::Index k_choose_2 = (k * (k - 1)) / 2;
43-
Array<value_type_t<T>, Dynamic, 1> x(k_choose_2);
43+
Eigen::Matrix<value_type_t<T>, Dynamic, 1> x(k_choose_2);
4444
Array<value_type_t<T>, Dynamic, 1> sds(k);
45-
bool successful = factor_cov_matrix(y, x, sds);
45+
bool successful = factor_cov_matrix(y, x.array(), sds);
4646
if (!successful) {
4747
throw_domain_error("corr_matrix_free", "factor_cov_matrix failed on y", y,
4848
"");
4949
}
5050
check_bounded("corr_matrix_free", "log(sd)", sds, -CONSTRAINT_TOLERANCE,
5151
CONSTRAINT_TOLERANCE);
52-
return x.matrix();
52+
return x;
5353
}
5454

5555
} // namespace math

stan/math/prim/fun/diag_post_multiply.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ template <typename T1, typename T2, require_eigen_t<T1>* = nullptr,
1212
auto diag_post_multiply(const T1& m1, const T2& m2) {
1313
check_size_match("diag_post_multiply", "m2.size()", m2.size(), "m1.cols()",
1414
m1.cols());
15-
return (m1 * m2.asDiagonal()).eval();
15+
return m1 * m2.asDiagonal();
1616
}
1717

1818
} // namespace math

stan/math/prim/fun/diag_pre_multiply.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ template <typename T1, typename T2, require_eigen_vector_t<T1>* = nullptr,
1212
auto diag_pre_multiply(const T1& m1, const T2& m2) {
1313
check_size_match("diag_pre_multiply", "m1.size()", m1.size(), "m2.rows()",
1414
m2.rows());
15-
return (m1.asDiagonal() * m2).eval();
15+
return m1.asDiagonal() * m2;
1616
}
1717

1818
} // namespace math

0 commit comments

Comments
 (0)