|
1 | 1 | #ifndef STAN_MATH_PRIM_ERR_CHECK_FINITE_HPP
|
2 | 2 | #define STAN_MATH_PRIM_ERR_CHECK_FINITE_HPP
|
3 | 3 |
|
4 |
| -#include <stan/math/prim/err/elementwise_check.hpp> |
5 |
| -#include <stan/math/prim/err/check_finite_screen.hpp> |
| 4 | +#include <stan/math/prim/meta.hpp> |
| 5 | +#include <stan/math/prim/err/is_scal_finite.hpp> |
| 6 | +#include <stan/math/prim/err/throw_domain_error.hpp> |
| 7 | +#include <stan/math/prim/err/throw_domain_error_vec.hpp> |
| 8 | +#include <stan/math/prim/fun/Eigen.hpp> |
| 9 | +#include <stan/math/prim/fun/get.hpp> |
| 10 | +#include <stan/math/prim/fun/size.hpp> |
| 11 | +#include <stan/math/prim/fun/value_of.hpp> |
| 12 | +#include <stan/math/prim/fun/value_of_rec.hpp> |
| 13 | +#include <cmath> |
6 | 14 |
|
7 | 15 | namespace stan {
|
8 | 16 | namespace math {
|
| 17 | +namespace internal { |
| 18 | +/** |
| 19 | + * Return true if y is finite |
| 20 | + * |
| 21 | + * @tparam T_y type of y |
| 22 | + * @param y parameter to check |
| 23 | + * @return boolean |
| 24 | + */ |
| 25 | +template <typename T_y> |
| 26 | +bool is_finite(const T_y& y) { |
| 27 | + return is_scal_finite(y); |
| 28 | +} |
| 29 | + |
| 30 | +/** |
| 31 | + * Return true if every element of the matrix y is finite |
| 32 | + * |
| 33 | + * @tparam T_y type of elements y |
| 34 | + * @param y matrix to check |
| 35 | + * @return boolean |
| 36 | + */ |
| 37 | +template <typename T_y, int R, int C> |
| 38 | +bool is_finite(const Eigen::Matrix<T_y, R, C>& y) { |
| 39 | + bool all = true; |
| 40 | + for (size_t n = 0; n < y.size(); ++n) { |
| 41 | + all &= is_finite(y(n)); |
| 42 | + } |
| 43 | + return all; |
| 44 | +} |
| 45 | + |
| 46 | +/** |
| 47 | + * Return true if every element of the vector y is finite |
| 48 | + * |
| 49 | + * @tparam T_y type of elements y |
| 50 | + * @param y vector to check |
| 51 | + * @return boolean |
| 52 | + */ |
| 53 | +template <typename T_y> |
| 54 | +bool is_finite(const std::vector<T_y>& y) { |
| 55 | + bool all = true; |
| 56 | + for (size_t n = 0; n < stan::math::size(y); ++n) { |
| 57 | + all &= is_finite(y[n]); |
| 58 | + } |
| 59 | + return all; |
| 60 | +} |
| 61 | +} // namespace internal |
9 | 62 |
|
10 | 63 | /**
|
11 | 64 | * Check if <code>y</code> is finite.
|
12 | 65 | * This function is vectorized and will check each element of
|
13 | 66 | * <code>y</code>.
|
14 |
| - * @tparam T_y type of y |
15 |
| - * @param function function name (for error messages) |
16 |
| - * @param name variable name (for error messages) |
17 |
| - * @param y variable to check |
| 67 | + * @tparam T_y Type of y |
| 68 | + * @param function Function name (for error messages) |
| 69 | + * @param name Variable name (for error messages) |
| 70 | + * @param y Variable to check |
18 | 71 | * @throw <code>domain_error</code> if y is infinity, -infinity, or NaN
|
19 | 72 | */
|
20 |
| -template <typename T_y> |
| 73 | +template <typename T_y, require_stan_scalar_t<T_y>* = nullptr> |
21 | 74 | inline void check_finite(const char* function, const char* name, const T_y& y) {
|
22 |
| - if (check_finite_screen(y)) { |
23 |
| - auto is_good = [](const auto& y) { return std::isfinite(y); }; |
24 |
| - elementwise_check(is_good, function, name, y, ", but must be finite!"); |
| 75 | + if (!internal::is_finite(y)) { |
| 76 | + throw_domain_error(function, name, y, "is ", ", but must be finite!"); |
| 77 | + } |
| 78 | +} |
| 79 | + |
| 80 | +/** |
| 81 | + * Return <code>true</code> if all values in the std::vector are finite. |
| 82 | + * |
| 83 | + * @tparam T_y type of elements in the std::vector |
| 84 | + * |
| 85 | + * @param function name of function (for error messages) |
| 86 | + * @param name variable name (for error messages) |
| 87 | + * @param y std::vector to test |
| 88 | + * @return <code>true</code> if all values are finite |
| 89 | + **/ |
| 90 | +template <typename T_y, require_stan_scalar_t<T_y>* = nullptr> |
| 91 | +inline void check_finite(const char* function, const char* name, |
| 92 | + const std::vector<T_y>& y) { |
| 93 | + for (size_t n = 0; n < stan::math::size(y); n++) { |
| 94 | + if (!internal::is_finite(stan::get(y, n))) { |
| 95 | + throw_domain_error_vec(function, name, y, n, "is ", |
| 96 | + ", but must be finite!"); |
| 97 | + } |
| 98 | + } |
| 99 | +} |
| 100 | + |
| 101 | +/** |
| 102 | + * Return <code>true</code> is the specified matrix is finite. |
| 103 | + * |
| 104 | + * @tparam Derived Eigen derived type |
| 105 | + * |
| 106 | + * @param function name of function (for error messages) |
| 107 | + * @param name variable name (for error messages) |
| 108 | + * @param y matrix to test |
| 109 | + * @return <code>true</code> if the matrix is finite |
| 110 | + **/ |
| 111 | +template <typename EigMat, require_eigen_t<EigMat>* = nullptr> |
| 112 | +inline void check_finite(const char* function, const char* name, |
| 113 | + const EigMat& y) { |
| 114 | + if (!value_of(y).allFinite()) { |
| 115 | + for (int n = 0; n < y.size(); ++n) { |
| 116 | + if (!std::isfinite(value_of_rec(y(n)))) { |
| 117 | + throw_domain_error_vec(function, name, y, n, "is ", |
| 118 | + ", but must be finite!"); |
| 119 | + } |
| 120 | + } |
| 121 | + } |
| 122 | +} |
| 123 | + |
| 124 | +/** |
| 125 | + * Return <code>true</code> if all values in the std::vector are finite. |
| 126 | + * |
| 127 | + * @tparam T_y type of elements in the std::vector |
| 128 | + * |
| 129 | + * @param function name of function (for error messages) |
| 130 | + * @param name variable name (for error messages) |
| 131 | + * @param y std::vector to test |
| 132 | + * @return <code>true</code> if all values are finite |
| 133 | + **/ |
| 134 | +template <typename T_y, require_not_stan_scalar_t<T_y>* = nullptr> |
| 135 | +inline void check_finite(const char* function, const char* name, |
| 136 | + const std::vector<T_y>& y) { |
| 137 | + for (size_t n = 0; n < stan::math::size(y); n++) { |
| 138 | + if (!internal::is_finite(stan::get(y, n))) { |
| 139 | + throw_domain_error(function, name, "", "", "is not finite!"); |
| 140 | + } |
25 | 141 | }
|
26 | 142 | }
|
27 | 143 |
|
|
0 commit comments