Skip to content

Commit 56f35eb

Browse files
authored
Merge pull request #2147 from stan-dev/revert-2007-feature/elementwise_checks_revert_revert
Revert "Feature/elementwise checks revert revert"
2 parents fb6ae11 + c8340b9 commit 56f35eb

32 files changed

+344
-687
lines changed

stan/math/prim/err.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include <stan/math/prim/err/check_corr_matrix.hpp>
1414
#include <stan/math/prim/err/check_cov_matrix.hpp>
1515
#include <stan/math/prim/err/check_finite.hpp>
16-
#include <stan/math/prim/err/check_finite_screen.hpp>
1716
#include <stan/math/prim/err/check_flag_sundials.hpp>
1817
#include <stan/math/prim/err/check_greater.hpp>
1918
#include <stan/math/prim/err/check_greater_or_equal.hpp>
@@ -27,7 +26,6 @@
2726
#include <stan/math/prim/err/check_nonnegative.hpp>
2827
#include <stan/math/prim/err/check_nonzero_size.hpp>
2928
#include <stan/math/prim/err/check_not_nan.hpp>
30-
#include <stan/math/prim/err/check_not_nan_screen.hpp>
3129
#include <stan/math/prim/err/check_ordered.hpp>
3230
#include <stan/math/prim/err/check_sorted.hpp>
3331
#include <stan/math/prim/err/check_pos_definite.hpp>

stan/math/prim/err/check_finite.hpp

Lines changed: 126 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,143 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_FINITE_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_FINITE_HPP
33

4-
#include <stan/math/prim/err/elementwise_check.hpp>
5-
#include <stan/math/prim/err/check_finite_screen.hpp>
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err/is_scal_finite.hpp>
6+
#include <stan/math/prim/err/throw_domain_error.hpp>
7+
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
8+
#include <stan/math/prim/fun/Eigen.hpp>
9+
#include <stan/math/prim/fun/get.hpp>
10+
#include <stan/math/prim/fun/size.hpp>
11+
#include <stan/math/prim/fun/value_of.hpp>
12+
#include <stan/math/prim/fun/value_of_rec.hpp>
13+
#include <cmath>
614

715
namespace stan {
816
namespace math {
17+
namespace internal {
18+
/**
19+
* Return true if y is finite
20+
*
21+
* @tparam T_y type of y
22+
* @param y parameter to check
23+
* @return boolean
24+
*/
25+
template <typename T_y>
26+
bool is_finite(const T_y& y) {
27+
return is_scal_finite(y);
28+
}
29+
30+
/**
31+
* Return true if every element of the matrix y is finite
32+
*
33+
* @tparam T_y type of elements y
34+
* @param y matrix to check
35+
* @return boolean
36+
*/
37+
template <typename T_y, int R, int C>
38+
bool is_finite(const Eigen::Matrix<T_y, R, C>& y) {
39+
bool all = true;
40+
for (size_t n = 0; n < y.size(); ++n) {
41+
all &= is_finite(y(n));
42+
}
43+
return all;
44+
}
45+
46+
/**
47+
* Return true if every element of the vector y is finite
48+
*
49+
* @tparam T_y type of elements y
50+
* @param y vector to check
51+
* @return boolean
52+
*/
53+
template <typename T_y>
54+
bool is_finite(const std::vector<T_y>& y) {
55+
bool all = true;
56+
for (size_t n = 0; n < stan::math::size(y); ++n) {
57+
all &= is_finite(y[n]);
58+
}
59+
return all;
60+
}
61+
} // namespace internal
962

1063
/**
1164
* Check if <code>y</code> is finite.
1265
* This function is vectorized and will check each element of
1366
* <code>y</code>.
14-
* @tparam T_y type of y
15-
* @param function function name (for error messages)
16-
* @param name variable name (for error messages)
17-
* @param y variable to check
67+
* @tparam T_y Type of y
68+
* @param function Function name (for error messages)
69+
* @param name Variable name (for error messages)
70+
* @param y Variable to check
1871
* @throw <code>domain_error</code> if y is infinity, -infinity, or NaN
1972
*/
20-
template <typename T_y>
73+
template <typename T_y, require_stan_scalar_t<T_y>* = nullptr>
2174
inline void check_finite(const char* function, const char* name, const T_y& y) {
22-
if (check_finite_screen(y)) {
23-
auto is_good = [](const auto& y) { return std::isfinite(y); };
24-
elementwise_check(is_good, function, name, y, ", but must be finite!");
75+
if (!internal::is_finite(y)) {
76+
throw_domain_error(function, name, y, "is ", ", but must be finite!");
77+
}
78+
}
79+
80+
/**
81+
* Return <code>true</code> if all values in the std::vector are finite.
82+
*
83+
* @tparam T_y type of elements in the std::vector
84+
*
85+
* @param function name of function (for error messages)
86+
* @param name variable name (for error messages)
87+
* @param y std::vector to test
88+
* @return <code>true</code> if all values are finite
89+
**/
90+
template <typename T_y, require_stan_scalar_t<T_y>* = nullptr>
91+
inline void check_finite(const char* function, const char* name,
92+
const std::vector<T_y>& y) {
93+
for (size_t n = 0; n < stan::math::size(y); n++) {
94+
if (!internal::is_finite(stan::get(y, n))) {
95+
throw_domain_error_vec(function, name, y, n, "is ",
96+
", but must be finite!");
97+
}
98+
}
99+
}
100+
101+
/**
102+
* Return <code>true</code> is the specified matrix is finite.
103+
*
104+
* @tparam Derived Eigen derived type
105+
*
106+
* @param function name of function (for error messages)
107+
* @param name variable name (for error messages)
108+
* @param y matrix to test
109+
* @return <code>true</code> if the matrix is finite
110+
**/
111+
template <typename EigMat, require_eigen_t<EigMat>* = nullptr>
112+
inline void check_finite(const char* function, const char* name,
113+
const EigMat& y) {
114+
if (!value_of(y).allFinite()) {
115+
for (int n = 0; n < y.size(); ++n) {
116+
if (!std::isfinite(value_of_rec(y(n)))) {
117+
throw_domain_error_vec(function, name, y, n, "is ",
118+
", but must be finite!");
119+
}
120+
}
121+
}
122+
}
123+
124+
/**
125+
* Return <code>true</code> if all values in the std::vector are finite.
126+
*
127+
* @tparam T_y type of elements in the std::vector
128+
*
129+
* @param function name of function (for error messages)
130+
* @param name variable name (for error messages)
131+
* @param y std::vector to test
132+
* @return <code>true</code> if all values are finite
133+
**/
134+
template <typename T_y, require_not_stan_scalar_t<T_y>* = nullptr>
135+
inline void check_finite(const char* function, const char* name,
136+
const std::vector<T_y>& y) {
137+
for (size_t n = 0; n < stan::math::size(y); n++) {
138+
if (!internal::is_finite(stan::get(y, n))) {
139+
throw_domain_error(function, name, "", "", "is not finite!");
140+
}
25141
}
26142
}
27143

stan/math/prim/err/check_finite_screen.hpp

Lines changed: 0 additions & 43 deletions
This file was deleted.

stan/math/prim/err/check_nonnegative.hpp

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,42 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_NONNEGATIVE_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_NONNEGATIVE_HPP
33

4-
#include <stan/math/prim/err/elementwise_check.hpp>
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err/throw_domain_error.hpp>
6+
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
7+
#include <stan/math/prim/fun/get.hpp>
8+
#include <stan/math/prim/fun/size.hpp>
9+
#include <type_traits>
510

611
namespace stan {
712
namespace math {
813

14+
namespace internal {
15+
template <typename T_y, bool is_vec>
16+
struct nonnegative {
17+
static void check(const char* function, const char* name, const T_y& y) {
18+
// have to use not is_unsigned. is_signed will be false for
19+
// floating point types that have no unsigned versions.
20+
if (!std::is_unsigned<T_y>::value && !(y >= 0)) {
21+
throw_domain_error(function, name, y, "is ", ", but must be >= 0!");
22+
}
23+
}
24+
};
25+
26+
template <typename T_y>
27+
struct nonnegative<T_y, true> {
28+
static void check(const char* function, const char* name, const T_y& y) {
29+
for (size_t n = 0; n < stan::math::size(y); n++) {
30+
if (!std::is_unsigned<typename value_type<T_y>::type>::value
31+
&& !(stan::get(y, n) >= 0)) {
32+
throw_domain_error_vec(function, name, y, n, "is ",
33+
", but must be >= 0!");
34+
}
35+
}
36+
}
37+
};
38+
} // namespace internal
39+
940
/**
1041
* Check if <code>y</code> is non-negative.
1142
* This function is vectorized and will check each element of <code>y</code>.
@@ -19,10 +50,9 @@ namespace math {
1950
template <typename T_y>
2051
inline void check_nonnegative(const char* function, const char* name,
2152
const T_y& y) {
22-
auto is_good = [](const auto& y) { return y >= 0; };
23-
elementwise_check(is_good, function, name, y, ", but must be >= 0!");
53+
internal::nonnegative<T_y, is_vector_like<T_y>::value>::check(function, name,
54+
y);
2455
}
25-
2656
} // namespace math
2757
} // namespace stan
2858
#endif

stan/math/prim/err/check_not_nan.hpp

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,40 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_NOT_NAN_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_NOT_NAN_HPP
33

4-
#include <stan/math/prim/err/elementwise_check.hpp>
5-
#include <stan/math/prim/err/check_not_nan_screen.hpp>
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err/throw_domain_error.hpp>
6+
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
7+
#include <stan/math/prim/fun/get.hpp>
8+
#include <stan/math/prim/fun/is_nan.hpp>
9+
#include <stan/math/prim/fun/size.hpp>
10+
#include <stan/math/prim/fun/value_of_rec.hpp>
611

712
namespace stan {
813
namespace math {
914

15+
namespace internal {
16+
template <typename T_y, bool is_vec>
17+
struct not_nan {
18+
static void check(const char* function, const char* name, const T_y& y) {
19+
if (is_nan(value_of_rec(y))) {
20+
throw_domain_error(function, name, y, "is ", ", but must not be nan!");
21+
}
22+
}
23+
};
24+
25+
template <typename T_y>
26+
struct not_nan<T_y, true> {
27+
static void check(const char* function, const char* name, const T_y& y) {
28+
for (size_t n = 0; n < stan::math::size(y); n++) {
29+
if (is_nan(value_of_rec(stan::get(y, n)))) {
30+
throw_domain_error_vec(function, name, y, n, "is ",
31+
", but must not be nan!");
32+
}
33+
}
34+
}
35+
};
36+
} // namespace internal
37+
1038
/**
1139
* Check if <code>y</code> is not <code>NaN</code>.
1240
* This function is vectorized and will check each element of
@@ -21,10 +49,7 @@ namespace math {
2149
template <typename T_y>
2250
inline void check_not_nan(const char* function, const char* name,
2351
const T_y& y) {
24-
if (check_not_nan_screen(y)) {
25-
auto is_good = [](const auto& y) { return !std::isnan(y); };
26-
elementwise_check(is_good, function, name, y, ", but must not be nan!");
27-
}
52+
internal::not_nan<T_y, is_vector_like<T_y>::value>::check(function, name, y);
2853
}
2954

3055
} // namespace math

stan/math/prim/err/check_not_nan_screen.hpp

Lines changed: 0 additions & 43 deletions
This file was deleted.

0 commit comments

Comments
 (0)