Skip to content

Let prim functions return expressions #2190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 37 commits into from
Dec 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
cfb04b8
first pass
t4c1 Nov 12, 2020
db03168
added apply_... stuff
t4c1 Nov 12, 2020
c1df65d
fixed forwarding references in unit_vector_constrain.hpp
t4c1 Nov 12, 2020
7eded17
Merge commit '1b3feb7acaaf7b8cba88e30608f48b218dcd5142' into HEAD
yashikno Nov 12, 2020
0f85ba4
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Nov 12, 2020
c92d5f7
fix cpplint
t4c1 Nov 12, 2020
26d745e
added a header back
t4c1 Nov 12, 2020
c08b000
bugfixed apply return types
t4c1 Nov 13, 2020
4b830fc
removed redundand forwarding and to_forwarding_ref
t4c1 Nov 13, 2020
3d13a2a
Merge commit '495dc91dbfbca3586cb16a3fcc44e79af48d8050' into HEAD
yashikno Nov 13, 2020
20eb3cd
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Nov 13, 2020
14388eb
kick jenkins awake
t4c1 Nov 16, 2020
eb7a24f
added missing header
t4c1 Nov 16, 2020
1765307
bugfixed apply_vector_unary and improved holder
t4c1 Nov 17, 2020
46016f4
bugfix binary_scalar_tester
t4c1 Nov 17, 2020
55bf3b2
Merge commit '34a8f8b2692d8a5c26af650d922042341b9a567d' into HEAD
yashikno Nov 17, 2020
873458c
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Nov 17, 2020
a27e08e
bugfix holder
t4c1 Nov 17, 2020
a0091a2
fix serializer
t4c1 Nov 18, 2020
c3dbb27
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Nov 18, 2020
fa88821
revert diag_matrix
t4c1 Nov 18, 2020
0b15e7b
Merge branch 'return_expressions2' of https://github.com/bstatcomp/ma…
t4c1 Nov 18, 2020
435a2b0
added include to poisson
t4c1 Nov 19, 2020
b2dbc29
fix includes
SteveBronder Nov 19, 2020
c7ab13d
more include order fixes
t4c1 Nov 20, 2020
f84d6b2
improved apply_vector_unary
t4c1 Nov 23, 2020
b4d6208
Merge branch 'develop' into return_expressions2
t4c1 Dec 1, 2020
62490e1
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Dec 1, 2020
b316e1e
bugfix apply_vector_unary
t4c1 Dec 1, 2020
ccb1ad0
Merge branch 'return_expressions2' of https://github.com/bstatcomp/ma…
t4c1 Dec 1, 2020
2dbe8eb
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Dec 1, 2020
4575f35
bugfix rev apply_vector_unary
t4c1 Dec 1, 2020
4524299
change how we detect if linear indexing is supported
t4c1 Dec 3, 2020
257b6fd
Merge commit '089db8e9b3ebc2ae65ef4321ebd0652205f31b82' into HEAD
yashikno Dec 3, 2020
a0523cc
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Dec 3, 2020
bd33bbe
addressed review comments
t4c1 Dec 7, 2020
8dfd5b0
make quad_form_sym use make_holder
t4c1 Dec 7, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions stan/math/fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

#include <stan/math/fwd/core.hpp>
#include <stan/math/fwd/meta.hpp>
#include <stan/math/prim.hpp>

#include <stan/math/fwd/fun.hpp>
#include <stan/math/fwd/functor.hpp>

#include <stan/math/prim.hpp>

#endif
3 changes: 2 additions & 1 deletion stan/math/fwd/fun/sum.hpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#ifndef STAN_MATH_FWD_FUN_SUM_HPP
#define STAN_MATH_FWD_FUN_SUM_HPP

#include <stan/math/fwd/core.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/sum.hpp>
#include <stan/math/fwd/core.hpp>
#include <vector>

namespace stan {
Expand Down
16 changes: 12 additions & 4 deletions stan/math/mix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,19 @@
#define STAN_MATH_MIX_HPP

#include <stan/math/mix/meta.hpp>

#include <stan/math/prim.hpp>
#include <stan/math/fwd.hpp>
#include <stan/math/rev.hpp>
#include <stan/math/mix/fun.hpp>
#include <stan/math/mix/functor.hpp>

#include <stan/math/rev/core.hpp>
#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/fun.hpp>
#include <stan/math/rev/functor.hpp>

#include <stan/math/fwd/core.hpp>
#include <stan/math/fwd/meta.hpp>
#include <stan/math/fwd/fun.hpp>
#include <stan/math/fwd/functor.hpp>

#include <stan/math/prim.hpp>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, there must be more to the story. Do tell!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for rev.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And why the individual files?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@syclik do you have any memories of why prim was originally included before rev and fwd? prim is the most general, so it makes sense to have it last here (see also this comment from this pull), but there could be something else I don't want to accidentally miss.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rev.hpp for example also includes prim, which we dont want included before fwd.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh and I see the conversation back here: #2190 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I'm also more convinced by what I see the only way this is wrong is that mix is included before rev and fwd and mix might reasonably depend on those lol.

Thqat can not really be an issue. If mix needs fwd or rev it is simply included. Same if fwd or rev needs prim.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't remember what exactly was an issue, but it was something, that gcc allows even if it is not standard, while clang is more strict. I know I was reading this: https://clang.llvm.org/compatibility.html#dep_lookup

Copy link
Member

@bbbales2 bbbales2 Dec 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think this change is right.

If everything is template overloads, then we can define things after they are used. Like in this: https://godbolt.org/z/hbT54n

(you'll see both test::myfunc functions get called)

If you add an overload of myfunc, it will only get called if the overload is already available when the compiler hits whatever. For instance, you can add this code above and below whatever and see what happens:

int myfunc(const char *) {
    std::cout << "hey" << std::endl;
    return 5;
}

Since we're depending on overloading and templates in Stan math, we gotta make sure and include the overloads for the rev and fwd types before the more general prim functions or the stuff in prim might not be able to find everything.


#endif
33 changes: 17 additions & 16 deletions stan/math/prim/err/elementwise_check.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,9 @@ inline void elementwise_check(const F& is_good, const char* function,
*/
template <typename F, typename T, typename... Indexings,
require_eigen_t<T>* = nullptr,
std::enable_if_t<static_cast<bool>(Eigen::internal::traits<T>::Flags&(
Eigen::LinearAccessBit | Eigen::DirectAccessBit))>* = nullptr>
std::enable_if_t<(Eigen::internal::traits<T>::Flags
& Eigen::LinearAccessBit)
|| T::IsVectorAtCompileTime>* = nullptr>
inline void elementwise_check(const F& is_good, const char* function,
const char* name, const T& x, const char* must_be,
const Indexings&... indexings) {
Expand Down Expand Up @@ -189,13 +190,13 @@ inline void elementwise_check(const F& is_good, const char* function,
* @throws `std::domain_error` if `is_good` returns `false` for the value
* of any element in `x`
*/
template <
typename F, typename T, typename... Indexings,
require_eigen_t<T>* = nullptr,
std::enable_if_t<!(Eigen::internal::traits<T>::Flags
& (Eigen::LinearAccessBit | Eigen::DirectAccessBit))
&& !(Eigen::internal::traits<T>::Flags
& Eigen::RowMajorBit)>* = nullptr>
template <typename F, typename T, typename... Indexings,
require_eigen_t<T>* = nullptr,
std::enable_if_t<!(Eigen::internal::traits<T>::Flags
& Eigen::LinearAccessBit)
&& !T::IsVectorAtCompileTime
&& !(Eigen::internal::traits<T>::Flags
& Eigen::RowMajorBit)>* = nullptr>
inline void elementwise_check(const F& is_good, const char* function,
const char* name, const T& x, const char* must_be,
const Indexings&... indexings) {
Expand Down Expand Up @@ -230,13 +231,13 @@ inline void elementwise_check(const F& is_good, const char* function,
* @throws `std::domain_error` if `is_good` returns `false` for the value
* of any element in `x`
*/
template <
typename F, typename T, typename... Indexings,
require_eigen_t<T>* = nullptr,
std::enable_if_t<!(Eigen::internal::traits<T>::Flags
& (Eigen::LinearAccessBit | Eigen::DirectAccessBit))
&& static_cast<bool>(Eigen::internal::traits<T>::Flags
& Eigen::RowMajorBit)>* = nullptr>
template <typename F, typename T, typename... Indexings,
require_eigen_t<T>* = nullptr,
std::enable_if_t<
!(Eigen::internal::traits<T>::Flags & Eigen::LinearAccessBit)
&& !T::IsVectorAtCompileTime
&& static_cast<bool>(Eigen::internal::traits<T>::Flags
& Eigen::RowMajorBit)>* = nullptr>
inline void elementwise_check(const F& is_good, const char* function,
const char* name, const T& x, const char* must_be,
const Indexings&... indexings) {
Expand Down
6 changes: 3 additions & 3 deletions stan/math/prim/fun/add.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ template <typename Mat1, typename Mat2,
require_all_not_st_var<Mat1, Mat2>* = nullptr>
inline auto add(const Mat1& m1, const Mat2& m2) {
check_matching_dims("add", "m1", m1, "m2", m2);
return (m1 + m2).eval();
return m1 + m2;
}

/**
Expand All @@ -58,7 +58,7 @@ template <typename Mat, typename Scal, require_eigen_t<Mat>* = nullptr,
require_stan_scalar_t<Scal>* = nullptr,
require_all_not_st_var<Mat, Scal>* = nullptr>
inline auto add(const Mat& m, const Scal c) {
return (m.array() + c).matrix().eval();
return (m.array() + c).matrix();
}

/**
Expand All @@ -74,7 +74,7 @@ template <typename Scal, typename Mat, require_stan_scalar_t<Scal>* = nullptr,
require_eigen_t<Mat>* = nullptr,
require_all_not_st_var<Scal, Mat>* = nullptr>
inline auto add(const Scal c, const Mat& m) {
return (c + m.array()).matrix().eval();
return (c + m.array()).matrix();
}

} // namespace math
Expand Down
3 changes: 1 addition & 2 deletions stan/math/prim/fun/block.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ namespace math {
* @throw std::out_of_range if either index is out of range.
*/
template <typename T, require_eigen_t<T>* = nullptr>
inline Eigen::Matrix<value_type_t<T>, Eigen::Dynamic, Eigen::Dynamic> block(
const T& m, size_t i, size_t j, size_t nrows, size_t ncols) {
inline auto block(const T& m, size_t i, size_t j, size_t nrows, size_t ncols) {
check_row_index("block", "i", m, i);
check_row_index("block", "i+nrows-1", m, i + nrows - 1);
check_column_index("block", "j", m, j);
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/col.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace math {
template <typename T, typename = require_eigen_t<T>>
inline auto col(const T& m, size_t j) {
check_column_index("col", "j", m, j);
return m.col(j - 1).eval();
return m.col(j - 1);
}

} // namespace math
Expand Down
6 changes: 3 additions & 3 deletions stan/math/prim/fun/corr_matrix_free.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ Eigen::Matrix<value_type_t<T>, Eigen::Dynamic, 1> corr_matrix_free(const T& y) {

Eigen::Index k = y.rows();
Eigen::Index k_choose_2 = (k * (k - 1)) / 2;
Array<value_type_t<T>, Dynamic, 1> x(k_choose_2);
Eigen::Matrix<value_type_t<T>, Dynamic, 1> x(k_choose_2);
Array<value_type_t<T>, Dynamic, 1> sds(k);
bool successful = factor_cov_matrix(y, x, sds);
bool successful = factor_cov_matrix(y, x.array(), sds);
if (!successful) {
throw_domain_error("corr_matrix_free", "factor_cov_matrix failed on y", y,
"");
}
check_bounded("corr_matrix_free", "log(sd)", sds, -CONSTRAINT_TOLERANCE,
CONSTRAINT_TOLERANCE);
return x.matrix();
return x;
}

} // namespace math
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/diag_post_multiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ template <typename T1, typename T2, require_eigen_t<T1>* = nullptr,
auto diag_post_multiply(const T1& m1, const T2& m2) {
check_size_match("diag_post_multiply", "m2.size()", m2.size(), "m1.cols()",
m1.cols());
return (m1 * m2.asDiagonal()).eval();
return m1 * m2.asDiagonal();
}

} // namespace math
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/diag_pre_multiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ template <typename T1, typename T2, require_eigen_vector_t<T1>* = nullptr,
auto diag_pre_multiply(const T1& m1, const T2& m2) {
check_size_match("diag_pre_multiply", "m1.size()", m1.size(), "m2.rows()",
m2.rows());
return (m1.asDiagonal() * m2).eval();
return m1.asDiagonal() * m2;
}

} // namespace math
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/diagonal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace math {
*/
template <typename T, typename = require_eigen_t<T>>
inline auto diagonal(const T& m) {
return m.diagonal().eval();
return m.diagonal();
}

} // namespace math
Expand Down
8 changes: 4 additions & 4 deletions stan/math/prim/fun/divide.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace math {
* @return Scalar divided by the scalar.
*/
template <typename Scal1, typename Scal2,
typename = require_all_stan_scalar_t<Scal1, Scal2>>
require_all_stan_scalar_t<Scal1, Scal2>* = nullptr>
inline return_type_t<Scal1, Scal2> divide(const Scal1& x, const Scal2& y) {
return x / y;
}
Expand All @@ -41,10 +41,10 @@ inline int divide(int x, int y) {
* @return matrix divided by the scalar
*/
template <typename Mat, typename Scal, typename = require_eigen_t<Mat>,
typename = require_stan_scalar_t<Scal>,
typename = require_all_not_var_t<scalar_type_t<Mat>, Scal>>
require_stan_scalar_t<Scal>* = nullptr,
require_all_not_var_t<scalar_type_t<Mat>, Scal>* = nullptr>
inline auto divide(const Mat& m, Scal c) {
return (m / c).eval();
return m / c;
}

} // namespace math
Expand Down
12 changes: 6 additions & 6 deletions stan/math/prim/fun/elt_divide.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ template <typename Mat1, typename Mat2,
require_all_not_st_var<Mat1, Mat2>* = nullptr>
auto elt_divide(const Mat1& m1, const Mat2& m2) {
check_matching_dims("elt_divide", "m1", m1, "m2", m2);
return (m1.array() / m2.array()).matrix().eval();
return (m1.array() / m2.array()).matrix();
}

/**
Expand All @@ -38,8 +38,8 @@ auto elt_divide(const Mat1& m1, const Mat2& m2) {
* @param s scalar
* @return Elementwise division of a scalar by matrix.
*/
template <typename Mat, typename Scal, typename = require_matrix_t<Mat>,
typename = require_stan_scalar_t<Scal>>
template <typename Mat, typename Scal, require_matrix_t<Mat>* = nullptr,
require_stan_scalar_t<Scal>* = nullptr>
auto elt_divide(const Mat& m, Scal s) {
return divide(m, s);
}
Expand All @@ -55,10 +55,10 @@ auto elt_divide(const Mat& m, Scal s) {
* @param m matrix or expression
* @return Elementwise division of a scalar by matrix.
*/
template <typename Scal, typename Mat, typename = require_stan_scalar_t<Scal>,
typename = require_eigen_t<Mat>>
template <typename Scal, typename Mat, require_stan_scalar_t<Scal>* = nullptr,
require_eigen_t<Mat>* = nullptr>
auto elt_divide(Scal s, const Mat& m) {
return (s / m.array()).matrix().eval();
return (s / m.array()).matrix();
}

template <typename Scal1, typename Scal2,
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/elt_multiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ template <typename Mat1, typename Mat2,
require_all_not_st_var<Mat1, Mat2>* = nullptr>
auto elt_multiply(const Mat1& m1, const Mat2& m2) {
check_matching_dims("elt_multiply", "m1", m1, "m2", m2);
return m1.cwiseProduct(m2).eval();
return m1.cwiseProduct(m2);
}

/**
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/head.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ inline auto head(const T& v, size_t n) {
if (n != 0) {
check_vector_index("head", "n", v, n);
}
return v.head(n).eval();
return v.head(n);
}

/**
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/identity_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace math {
* @return An identity matrix of size K.
* @throw std::domain_error if K is negative.
*/
inline Eigen::MatrixXd identity_matrix(int K) {
inline auto identity_matrix(int K) {
check_nonnegative("identity_matrix", "size", K);
return Eigen::MatrixXd::Identity(K, K);
}
Expand Down
2 changes: 0 additions & 2 deletions stan/math/prim/fun/inv_square.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/inv.hpp>
#include <stan/math/prim/fun/square.hpp>
#include <stan/math/prim/fun/inv_square.hpp>
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
#include <stan/math/prim/functor/apply_vector_unary.hpp>

namespace stan {
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/linspaced_row_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace math {
* @throw std::domain_error if K is negative, if low is nan or infinite,
* if high is nan or infinite, or if high is less than low.
*/
inline Eigen::RowVectorXd linspaced_row_vector(int K, double low, double high) {
inline auto linspaced_row_vector(int K, double low, double high) {
static const char* function = "linspaced_row_vector";
check_nonnegative(function, "size", K);
check_finite(function, "low", low);
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/linspaced_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace math {
* @throw std::domain_error if K is negative, if low is nan or infinite,
* if high is nan or infinite, or if high is less than low.
*/
inline Eigen::VectorXd linspaced_vector(int K, double low, double high) {
inline auto linspaced_vector(int K, double low, double high) {
static const char* function = "linspaced_vector";
check_nonnegative(function, "size", K);
check_finite(function, "low", low);
Expand Down
1 change: 0 additions & 1 deletion stan/math/prim/fun/log1p.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/is_nan.hpp>
#include <stan/math/prim/fun/log1p.hpp>
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
#include <cmath>

Expand Down
19 changes: 9 additions & 10 deletions stan/math/prim/fun/log_softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/log_sum_exp.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/functor/apply_vector_unary.hpp>

namespace stan {
Expand Down Expand Up @@ -37,18 +38,16 @@ namespace math {
* @tparam Container type of input vector to transform
* @param[in] x vector to transform
* @return log unit simplex result of the softmax transform of the vector.
*
* Note: The return must be evaluated otherwise the Ref object falls out
* of scope
*/
template <typename Container,
require_arithmetic_t<scalar_type_t<Container>>* = nullptr>
template <typename Container, require_st_arithmetic<Container>* = nullptr>
inline auto log_softmax(const Container& x) {
return apply_vector_unary<Container>::apply(x, [](const auto& v) {
const Eigen::Ref<const plain_type_t<decltype(v)>>& v_ref = v;
check_nonzero_size("log_softmax", "v", v_ref);
return (v_ref.array() - log_sum_exp(v_ref)).eval();
});
check_nonzero_size("log_softmax", "v", x);
return make_holder(
[](const auto& a) {
return apply_vector_unary<ref_type_t<Container>>::apply(
a, [](const auto& v) { return v.array() - log_sum_exp(v); });
},
to_ref(x));
}

} // namespace math
Expand Down
12 changes: 8 additions & 4 deletions stan/math/prim/fun/logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
#include <stan/math/prim/functor/apply_vector_unary.hpp>
#include <cmath>
Expand Down Expand Up @@ -102,10 +103,13 @@ inline auto logit(const Container& x) {
template <typename Container,
require_container_st<std::is_arithmetic, Container>* = nullptr>
inline auto logit(const Container& x) {
return apply_vector_unary<Container>::apply(x, [](const auto& v) {
const Eigen::Ref<const plain_type_t<decltype(v)>>& v_ref = v;
return (v_ref.array() / (1 - v_ref.array())).log().eval();
});
return make_holder(
[](const auto& v_ref) {
return apply_vector_unary<ref_type_t<Container>>::apply(
v_ref,
[](const auto& v) { return (v.array() / (1 - v.array())).log(); });
},
to_ref(x));
}

} // namespace math
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/minus.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace math {
* @return Negation of subtrahend.
*/
template <typename T>
inline plain_type_t<T> minus(const T& x) {
inline auto minus(const T& x) {
return -x;
}

Expand Down
Loading