|
2 | 2 | #define STAN_MATH_PRIM_ERR_CHECK_FINITE_HPP
|
3 | 3 |
|
4 | 4 | #include <stan/math/prim/meta.hpp>
|
5 |
| -#include <stan/math/prim/err/is_scal_finite.hpp> |
6 |
| -#include <stan/math/prim/err/throw_domain_error.hpp> |
7 |
| -#include <stan/math/prim/err/throw_domain_error_vec.hpp> |
| 5 | +#include <stan/math/prim/err/elementwise_check.hpp> |
8 | 6 | #include <stan/math/prim/fun/Eigen.hpp>
|
9 | 7 | #include <stan/math/prim/fun/get.hpp>
|
10 | 8 | #include <stan/math/prim/fun/size.hpp>
|
|
14 | 12 |
|
15 | 13 | namespace stan {
|
16 | 14 | namespace math {
|
17 |
| -namespace internal { |
18 |
| -/** |
19 |
| - * Return true if y is finite |
20 |
| - * |
21 |
| - * @tparam T_y type of y |
22 |
| - * @param y parameter to check |
23 |
| - * @return boolean |
24 |
| - */ |
25 |
| -template <typename T_y> |
26 |
| -bool is_finite(const T_y& y) { |
27 |
| - return is_scal_finite(y); |
28 |
| -} |
29 |
| - |
30 |
| -/** |
31 |
| - * Return true if every element of the matrix y is finite |
32 |
| - * |
33 |
| - * @tparam T_y type of elements y |
34 |
| - * @param y matrix to check |
35 |
| - * @return boolean |
36 |
| - */ |
37 |
| -template <typename T_y, int R, int C> |
38 |
| -bool is_finite(const Eigen::Matrix<T_y, R, C>& y) { |
39 |
| - bool all = true; |
40 |
| - for (size_t n = 0; n < y.size(); ++n) { |
41 |
| - all &= is_finite(y(n)); |
42 |
| - } |
43 |
| - return all; |
44 |
| -} |
45 |
| - |
46 |
| -/** |
47 |
| - * Return true if every element of the vector y is finite |
48 |
| - * |
49 |
| - * @tparam T_y type of elements y |
50 |
| - * @param y vector to check |
51 |
| - * @return boolean |
52 |
| - */ |
53 |
| -template <typename T_y> |
54 |
| -bool is_finite(const std::vector<T_y>& y) { |
55 |
| - bool all = true; |
56 |
| - for (size_t n = 0; n < stan::math::size(y); ++n) { |
57 |
| - all &= is_finite(y[n]); |
58 |
| - } |
59 |
| - return all; |
60 |
| -} |
61 |
| -} // namespace internal |
62 |
| - |
63 |
| -/** |
64 |
| - * Check if <code>y</code> is finite. |
65 |
| - * This function is vectorized and will check each element of |
66 |
| - * <code>y</code>. |
67 |
| - * @tparam T_y Type of y |
68 |
| - * @param function Function name (for error messages) |
69 |
| - * @param name Variable name (for error messages) |
70 |
| - * @param y Variable to check |
71 |
| - * @throw <code>domain_error</code> if y is infinity, -infinity, or NaN |
72 |
| - */ |
73 |
| -template <typename T_y, require_stan_scalar_t<T_y>* = nullptr> |
74 |
| -inline void check_finite(const char* function, const char* name, const T_y& y) { |
75 |
| - if (!internal::is_finite(y)) { |
76 |
| - throw_domain_error(function, name, y, "is ", ", but must be finite!"); |
77 |
| - } |
78 |
| -} |
79 |
| - |
80 |
| -/** |
81 |
| - * Return <code>true</code> if all values in the std::vector are finite. |
82 |
| - * |
83 |
| - * @tparam T_y type of elements in the std::vector |
84 |
| - * |
85 |
| - * @param function name of function (for error messages) |
86 |
| - * @param name variable name (for error messages) |
87 |
| - * @param y std::vector to test |
88 |
| - * @return <code>true</code> if all values are finite |
89 |
| - **/ |
90 |
| -template <typename T_y, require_stan_scalar_t<T_y>* = nullptr> |
91 |
| -inline void check_finite(const char* function, const char* name, |
92 |
| - const std::vector<T_y>& y) { |
93 |
| - for (size_t n = 0; n < stan::math::size(y); n++) { |
94 |
| - if (!internal::is_finite(stan::get(y, n))) { |
95 |
| - throw_domain_error_vec(function, name, y, n, "is ", |
96 |
| - ", but must be finite!"); |
97 |
| - } |
98 |
| - } |
99 |
| -} |
100 |
| - |
101 |
| -/** |
102 |
| - * Return <code>true</code> is the specified matrix is finite. |
103 |
| - * |
104 |
| - * @tparam Derived Eigen derived type |
105 |
| - * |
106 |
| - * @param function name of function (for error messages) |
107 |
| - * @param name variable name (for error messages) |
108 |
| - * @param y matrix to test |
109 |
| - * @return <code>true</code> if the matrix is finite |
110 |
| - **/ |
111 |
| -template <typename Mat, require_matrix_t<Mat>* = nullptr> |
112 |
| -inline void check_finite(const char* function, const char* name, const Mat& y) { |
113 |
| - if (!value_of(y).allFinite()) { |
114 |
| - for (int n = 0; n < y.size(); ++n) { |
115 |
| - if (!std::isfinite(value_of_rec(y(n)))) { |
116 |
| - throw_domain_error_vec(function, name, value_of(y), n, "is ", |
117 |
| - ", but must be finite!"); |
118 |
| - } |
119 |
| - } |
120 |
| - } |
121 |
| -} |
122 | 15 |
|
123 | 16 | /**
|
124 |
| - * Return <code>true</code> if all values in the std::vector are finite. |
| 17 | + * Return <code>true</code> if all values in `y` are finite. `y` can be a |
| 18 | + *scalar, `std::vector` or Eigen type. |
125 | 19 | *
|
126 |
| - * @tparam T_y type of elements in the std::vector |
| 20 | + * @tparam T_y type of `y` |
127 | 21 | *
|
128 | 22 | * @param function name of function (for error messages)
|
129 | 23 | * @param name variable name (for error messages)
|
130 |
| - * @param y std::vector to test |
| 24 | + * @param y scalar or container to test |
131 | 25 | * @return <code>true</code> if all values are finite
|
132 | 26 | **/
|
133 |
| -template <typename T_y, require_not_stan_scalar_t<T_y>* = nullptr> |
134 |
| -inline void check_finite(const char* function, const char* name, |
135 |
| - const std::vector<T_y>& y) { |
136 |
| - for (size_t n = 0; n < stan::math::size(y); n++) { |
137 |
| - if (!internal::is_finite(stan::get(y, n))) { |
138 |
| - throw_domain_error(function, name, "", "", "is not finite!"); |
139 |
| - } |
140 |
| - } |
| 27 | +template <typename T_y> |
| 28 | +inline void check_finite(const char* function, const char* name, const T_y& y) { |
| 29 | + elementwise_check([](double x) { return std::isfinite(x); }, function, name, |
| 30 | + y, "finite"); |
141 | 31 | }
|
142 | 32 |
|
143 | 33 | } // namespace math
|
|
0 commit comments