Skip to content

Commit 9d3cd07

Browse files
authored
Merge pull request #2205 from bstatcomp/no_linear_indexing_requirement
Remove linear indexing requirements from functions
2 parents 0c1ef70 + e3fbd53 commit 9d3cd07

18 files changed

+322
-398
lines changed

stan/math/prim/err/check_finite.hpp

Lines changed: 9 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
#define STAN_MATH_PRIM_ERR_CHECK_FINITE_HPP
33

44
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/err/is_scal_finite.hpp>
6-
#include <stan/math/prim/err/throw_domain_error.hpp>
7-
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
5+
#include <stan/math/prim/err/elementwise_check.hpp>
86
#include <stan/math/prim/fun/Eigen.hpp>
97
#include <stan/math/prim/fun/get.hpp>
108
#include <stan/math/prim/fun/size.hpp>
@@ -14,130 +12,22 @@
1412

1513
namespace stan {
1614
namespace math {
17-
namespace internal {
18-
/**
19-
* Return true if y is finite
20-
*
21-
* @tparam T_y type of y
22-
* @param y parameter to check
23-
* @return boolean
24-
*/
25-
template <typename T_y>
26-
bool is_finite(const T_y& y) {
27-
return is_scal_finite(y);
28-
}
29-
30-
/**
31-
* Return true if every element of the matrix y is finite
32-
*
33-
* @tparam T_y type of elements y
34-
* @param y matrix to check
35-
* @return boolean
36-
*/
37-
template <typename T_y, int R, int C>
38-
bool is_finite(const Eigen::Matrix<T_y, R, C>& y) {
39-
bool all = true;
40-
for (size_t n = 0; n < y.size(); ++n) {
41-
all &= is_finite(y(n));
42-
}
43-
return all;
44-
}
45-
46-
/**
47-
* Return true if every element of the vector y is finite
48-
*
49-
* @tparam T_y type of elements y
50-
* @param y vector to check
51-
* @return boolean
52-
*/
53-
template <typename T_y>
54-
bool is_finite(const std::vector<T_y>& y) {
55-
bool all = true;
56-
for (size_t n = 0; n < stan::math::size(y); ++n) {
57-
all &= is_finite(y[n]);
58-
}
59-
return all;
60-
}
61-
} // namespace internal
62-
63-
/**
64-
* Check if <code>y</code> is finite.
65-
* This function is vectorized and will check each element of
66-
* <code>y</code>.
67-
* @tparam T_y Type of y
68-
* @param function Function name (for error messages)
69-
* @param name Variable name (for error messages)
70-
* @param y Variable to check
71-
* @throw <code>domain_error</code> if y is infinity, -infinity, or NaN
72-
*/
73-
template <typename T_y, require_stan_scalar_t<T_y>* = nullptr>
74-
inline void check_finite(const char* function, const char* name, const T_y& y) {
75-
if (!internal::is_finite(y)) {
76-
throw_domain_error(function, name, y, "is ", ", but must be finite!");
77-
}
78-
}
79-
80-
/**
81-
* Return <code>true</code> if all values in the std::vector are finite.
82-
*
83-
* @tparam T_y type of elements in the std::vector
84-
*
85-
* @param function name of function (for error messages)
86-
* @param name variable name (for error messages)
87-
* @param y std::vector to test
88-
* @return <code>true</code> if all values are finite
89-
**/
90-
template <typename T_y, require_stan_scalar_t<T_y>* = nullptr>
91-
inline void check_finite(const char* function, const char* name,
92-
const std::vector<T_y>& y) {
93-
for (size_t n = 0; n < stan::math::size(y); n++) {
94-
if (!internal::is_finite(stan::get(y, n))) {
95-
throw_domain_error_vec(function, name, y, n, "is ",
96-
", but must be finite!");
97-
}
98-
}
99-
}
100-
101-
/**
102-
* Return <code>true</code> is the specified matrix is finite.
103-
*
104-
* @tparam Derived Eigen derived type
105-
*
106-
* @param function name of function (for error messages)
107-
* @param name variable name (for error messages)
108-
* @param y matrix to test
109-
* @return <code>true</code> if the matrix is finite
110-
**/
111-
template <typename Mat, require_matrix_t<Mat>* = nullptr>
112-
inline void check_finite(const char* function, const char* name, const Mat& y) {
113-
if (!value_of(y).allFinite()) {
114-
for (int n = 0; n < y.size(); ++n) {
115-
if (!std::isfinite(value_of_rec(y(n)))) {
116-
throw_domain_error_vec(function, name, value_of(y), n, "is ",
117-
", but must be finite!");
118-
}
119-
}
120-
}
121-
}
12215

12316
/**
124-
* Return <code>true</code> if all values in the std::vector are finite.
17+
* Return <code>true</code> if all values in `y` are finite. `y` can be a
18+
*scalar, `std::vector` or Eigen type.
12519
*
126-
* @tparam T_y type of elements in the std::vector
20+
* @tparam T_y type of `y`
12721
*
12822
* @param function name of function (for error messages)
12923
* @param name variable name (for error messages)
130-
* @param y std::vector to test
24+
* @param y scalar or container to test
13125
* @return <code>true</code> if all values are finite
13226
**/
133-
template <typename T_y, require_not_stan_scalar_t<T_y>* = nullptr>
134-
inline void check_finite(const char* function, const char* name,
135-
const std::vector<T_y>& y) {
136-
for (size_t n = 0; n < stan::math::size(y); n++) {
137-
if (!internal::is_finite(stan::get(y, n))) {
138-
throw_domain_error(function, name, "", "", "is not finite!");
139-
}
140-
}
27+
template <typename T_y>
28+
inline void check_finite(const char* function, const char* name, const T_y& y) {
29+
elementwise_check([](double x) { return std::isfinite(x); }, function, name,
30+
y, "finite");
14131
}
14232

14333
} // namespace math

stan/math/prim/err/check_nonnegative.hpp

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,14 @@
22
#define STAN_MATH_PRIM_ERR_CHECK_NONNEGATIVE_HPP
33

44
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/err/throw_domain_error.hpp>
6-
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
5+
#include <stan/math/prim/err/elementwise_check.hpp>
76
#include <stan/math/prim/fun/get.hpp>
87
#include <stan/math/prim/fun/size.hpp>
98
#include <type_traits>
109

1110
namespace stan {
1211
namespace math {
1312

14-
namespace internal {
15-
template <typename T_y, bool is_vec>
16-
struct nonnegative {
17-
static void check(const char* function, const char* name, const T_y& y) {
18-
// have to use not is_unsigned. is_signed will be false for
19-
// floating point types that have no unsigned versions.
20-
if (!std::is_unsigned<T_y>::value && !(y >= 0)) {
21-
throw_domain_error(function, name, y, "is ", ", but must be >= 0!");
22-
}
23-
}
24-
};
25-
26-
template <typename T_y>
27-
struct nonnegative<T_y, true> {
28-
static void check(const char* function, const char* name, const T_y& y) {
29-
for (size_t n = 0; n < stan::math::size(y); n++) {
30-
if (!std::is_unsigned<typename value_type<T_y>::type>::value
31-
&& !(stan::get(y, n) >= 0)) {
32-
throw_domain_error_vec(function, name, y, n, "is ",
33-
", but must be >= 0!");
34-
}
35-
}
36-
}
37-
};
38-
} // namespace internal
39-
4013
/**
4114
* Check if <code>y</code> is non-negative.
4215
* This function is vectorized and will check each element of <code>y</code>.
@@ -50,8 +23,8 @@ struct nonnegative<T_y, true> {
5023
template <typename T_y>
5124
inline void check_nonnegative(const char* function, const char* name,
5225
const T_y& y) {
53-
internal::nonnegative<T_y, is_vector_like<T_y>::value>::check(function, name,
54-
y);
26+
elementwise_check([](double x) { return x >= 0; }, function, name, y,
27+
"nonnegative");
5528
}
5629
} // namespace math
5730
} // namespace stan

stan/math/prim/err/check_not_nan.hpp

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
#define STAN_MATH_PRIM_ERR_CHECK_NOT_NAN_HPP
33

44
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/err/throw_domain_error.hpp>
6-
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
5+
#include <stan/math/prim/err/elementwise_check.hpp>
76
#include <stan/math/prim/fun/get.hpp>
87
#include <stan/math/prim/fun/is_nan.hpp>
98
#include <stan/math/prim/fun/size.hpp>
@@ -12,29 +11,6 @@
1211
namespace stan {
1312
namespace math {
1413

15-
namespace internal {
16-
template <typename T_y, bool is_vec>
17-
struct not_nan {
18-
static void check(const char* function, const char* name, const T_y& y) {
19-
if (is_nan(value_of_rec(y))) {
20-
throw_domain_error(function, name, y, "is ", ", but must not be nan!");
21-
}
22-
}
23-
};
24-
25-
template <typename T_y>
26-
struct not_nan<T_y, true> {
27-
static void check(const char* function, const char* name, const T_y& y) {
28-
for (size_t n = 0; n < stan::math::size(y); n++) {
29-
if (is_nan(value_of_rec(stan::get(y, n)))) {
30-
throw_domain_error_vec(function, name, y, n, "is ",
31-
", but must not be nan!");
32-
}
33-
}
34-
}
35-
};
36-
} // namespace internal
37-
3814
/**
3915
* Check if <code>y</code> is not <code>NaN</code>.
4016
* This function is vectorized and will check each element of
@@ -49,7 +25,8 @@ struct not_nan<T_y, true> {
4925
template <typename T_y>
5026
inline void check_not_nan(const char* function, const char* name,
5127
const T_y& y) {
52-
internal::not_nan<T_y, is_vector_like<T_y>::value>::check(function, name, y);
28+
elementwise_check([](double x) { return !std::isnan(x); }, function, name, y,
29+
"not nan");
5330
}
5431

5532
} // namespace math

stan/math/prim/err/check_positive.hpp

Lines changed: 3 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
#define STAN_MATH_PRIM_ERR_CHECK_POSITIVE_HPP
33

44
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/err/throw_domain_error.hpp>
6-
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
5+
#include <stan/math/prim/err/elementwise_check.hpp>
76
#include <stan/math/prim/err/invalid_argument.hpp>
87
#include <stan/math/prim/fun/get.hpp>
98
#include <stan/math/prim/fun/size.hpp>
@@ -13,34 +12,6 @@
1312
namespace stan {
1413
namespace math {
1514

16-
namespace {
17-
18-
template <typename T_y, bool is_vec>
19-
struct positive {
20-
static void check(const char* function, const char* name, const T_y& y) {
21-
// have to use not is_unsigned. is_signed will be false
22-
// floating point types that have no unsigned versions.
23-
if (!std::is_unsigned<T_y>::value && !(y > 0)) {
24-
throw_domain_error(function, name, y, "is ", ", but must be > 0!");
25-
}
26-
}
27-
};
28-
29-
template <typename T_y>
30-
struct positive<T_y, true> {
31-
static void check(const char* function, const char* name, const T_y& y) {
32-
for (size_t n = 0; n < stan::math::size(y); n++) {
33-
if (!std::is_unsigned<typename value_type<T_y>::type>::value
34-
&& !(stan::get(y, n) > 0)) {
35-
throw_domain_error_vec(function, name, y, n, "is ",
36-
", but must be > 0!");
37-
}
38-
}
39-
}
40-
};
41-
42-
} // namespace
43-
4415
/**
4516
* Check if <code>y</code> is positive.
4617
* This function is vectorized and will check each element of
@@ -55,7 +26,8 @@ struct positive<T_y, true> {
5526
template <typename T_y>
5627
inline void check_positive(const char* function, const char* name,
5728
const T_y& y) {
58-
positive<T_y, is_vector_like<T_y>::value>::check(function, name, y);
29+
elementwise_check([](double x) { return x > 0; }, function, name, y,
30+
"positive");
5931
}
6032

6133
/**

stan/math/prim/err/check_positive_finite.hpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
#define STAN_MATH_PRIM_ERR_CHECK_POSITIVE_FINITE_HPP
33

44
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/err/check_positive.hpp>
6-
#include <stan/math/prim/err/check_finite.hpp>
5+
#include <stan/math/prim/err/elementwise_check.hpp>
76

87
namespace stan {
98
namespace math {
@@ -22,8 +21,8 @@ namespace math {
2221
template <typename T_y>
2322
inline void check_positive_finite(const char* function, const char* name,
2423
const T_y& y) {
25-
check_positive(function, name, y);
26-
check_finite(function, name, y);
24+
elementwise_check([](double x) { return x > 0 && std::isfinite(x); },
25+
function, name, y, "positive finite");
2726
}
2827

2928
} // namespace math

0 commit comments

Comments
 (0)