Skip to content

Commit 7dd7b31

Browse files
authored
Merge pull request #2556 from stan-dev/fix/check-less-greater
Vectorize checks called by compiler
2 parents adb413d + 31b69b6 commit 7dd7b31

29 files changed

+2750
-912
lines changed

stan/math/prim/err.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@
7373
#include <stan/math/prim/err/out_of_range.hpp>
7474
#include <stan/math/prim/err/system_error.hpp>
7575
#include <stan/math/prim/err/throw_domain_error.hpp>
76+
#include <stan/math/prim/err/throw_domain_error_mat.hpp>
7677
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
7778
#include <stan/math/prim/err/validate_non_negative_index.hpp>
7879
#include <stan/math/prim/err/validate_positive_index.hpp>
7980
#include <stan/math/prim/err/validate_unit_vector_index.hpp>
80-
8181
#endif
Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,71 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_CHOLESKY_FACTOR_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_CHOLESKY_FACTOR_HPP
33

4-
#include <stan/math/prim/meta.hpp>
54
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/meta.hpp>
6+
#include <stan/math/prim/fun/to_ref.hpp>
7+
#include <stan/math/prim/fun/value_of_rec.hpp>
68
#include <stan/math/prim/err/check_positive.hpp>
79
#include <stan/math/prim/err/check_less_or_equal.hpp>
810
#include <stan/math/prim/err/check_lower_triangular.hpp>
11+
#include <stan/math/prim/err/make_iter_name.hpp>
912

1013
namespace stan {
1114
namespace math {
1215

1316
/**
14-
* Check if the specified matrix is a valid Cholesky factor.
15-
* A Cholesky factor is a lower triangular matrix whose diagonal
16-
* elements are all positive. Note that Cholesky factors need not
17-
* be square, but require at least as many rows M as columns N
18-
* (i.e., M &gt;= N).
19-
* @tparam EigMat Type of the Cholesky factor (must be derived from \c
20-
* Eigen::MatrixBase)
17+
* Throw an exception if the specified matrix is not a valid Cholesky factor. A
18+
* Cholesky factor is a lower triangular matrix whose diagonal elements are all
19+
* positive. Note that Cholesky factors need not be square, but require at
20+
* least as many rows M as columns N (i.e., `M >= N`).
21+
* @tparam Mat Type inheriting from `MatrixBase` with neither rows or columns
22+
* defined at compile time to be equal to 1 or a `var_value` with the var's
23+
* inner type inheriting from `Eigen::MatrixBase` with neither rows or columns
24+
* defined at compile time to be equal to 1
2125
* @param function Function name (for error messages)
2226
* @param name Variable name (for error messages)
2327
* @param y Matrix to test
24-
* @throw <code>std::domain_error</code> if y is not a valid Cholesky
25-
* factor, if number of rows is less than the number of columns,
26-
* if there are 0 columns, or if any element in matrix is NaN
28+
* @throw `std::domain_error` if y is not a valid Cholesky factor, if number of
29+
* rows is less than the number of columns, if there are 0 columns, or if any
30+
* element in matrix is `NaN`
2731
*/
28-
template <typename EigMat, require_eigen_t<EigMat>* = nullptr>
32+
template <typename Mat, require_matrix_t<Mat>* = nullptr>
2933
inline void check_cholesky_factor(const char* function, const char* name,
30-
const EigMat& y) {
34+
const Mat& y) {
3135
check_less_or_equal(function, "columns and rows of Cholesky factor", y.cols(),
3236
y.rows());
3337
check_positive(function, "columns of Cholesky factor", y.cols());
34-
const Eigen::Ref<const plain_type_t<EigMat>>& y_ref = y;
38+
auto&& y_ref = to_ref(value_of_rec(y));
3539
check_lower_triangular(function, name, y_ref);
3640
check_positive(function, name, y_ref.diagonal());
3741
}
3842

43+
/**
44+
* Throw an exception if the specified matrix is not a valid Cholesky factor. A
45+
* Cholesky factor is a lower triangular matrix whose diagonal elements are all
46+
* positive. Note that Cholesky factors need not be square, but require at
47+
* least as many rows M as columns N (i.e., `M >= N`).
48+
* @tparam StdVec A standard vector with inner type either inheriting from
49+
* `MatrixBase` with neither rows or columns defined at compile time to be equal
50+
* to 1 or a `var_value` with the var's inner type inheriting from
51+
* `Eigen::MatrixBase` with neither rows or columns defined at compile time to
52+
* be equal to 1
53+
* @param function Function name (for error messages)
54+
* @param name Variable name (for error messages)
55+
* @param y Standard vector of matrices to test
56+
* @throw `std::domain_error` if y is not a valid Cholesky factor, if number of
57+
* rows is less than the number of columns, if there are 0 columns, or if any
58+
* element in matrix is `NaN`
59+
*/
60+
template <typename StdVec, require_std_vector_t<StdVec>* = nullptr>
61+
void check_cholesky_factor(const char* function, const char* name,
62+
const StdVec& y) {
63+
for (size_t i = 0; i < y.size(); ++i) {
64+
check_cholesky_factor(function, internal::make_iter_name(name, i).c_str(),
65+
y[i]);
66+
}
67+
}
68+
3969
} // namespace math
4070
} // namespace stan
4171
#endif

stan/math/prim/err/check_cholesky_factor_corr.hpp

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,39 @@
44
#include <stan/math/prim/fun/Eigen.hpp>
55
#include <stan/math/prim/meta.hpp>
66
#include <stan/math/prim/fun/to_ref.hpp>
7+
#include <stan/math/prim/fun/value_of_rec.hpp>
78
#include <stan/math/prim/err/check_positive.hpp>
89
#include <stan/math/prim/err/check_lower_triangular.hpp>
910
#include <stan/math/prim/err/check_square.hpp>
1011
#include <stan/math/prim/err/check_unit_vector.hpp>
12+
#include <stan/math/prim/err/make_iter_name.hpp>
1113

1214
namespace stan {
1315
namespace math {
1416

1517
/**
16-
* Check if the specified matrix is a valid Cholesky factor of a
17-
* correlation matrix.
18-
* A Cholesky factor is a lower triangular matrix whose diagonal
19-
* elements are all positive. Note that Cholesky factors need not
20-
* be square, but require at least as many rows M as columns N
21-
* (i.e., M &gt;= N).
22-
* Tolerance is specified by <code>math::CONSTRAINT_TOLERANCE</code>.
23-
* @tparam EigMat Type inheriting from `MatrixBase` with dynamic rows and
24-
* columns.
18+
* Throw an exception if the specified matrix is not a valid Cholesky factor of
19+
* a correlation matrix. A Cholesky factor is a lower triangular matrix whose
20+
* diagonal elements are all positive and each row has unit Euclidean length.
21+
* Note that Cholesky factors need not be square, but require at least as many
22+
* rows M as columns N (i.e., `M >= N`). Tolerance is specified by
23+
* `math::CONSTRAINT_TOLERANCE`. Tolerance is specified by
24+
* `math::CONSTRAINT_TOLERANCE`.
25+
* @tparam Mat Type inheriting from `MatrixBase` with neither rows or columns
26+
* defined at compile time to be equal to 1 or a `var_value` with the var's
27+
* inner type inheriting from `Eigen::MatrixBase` with neither rows or columns
28+
* defined at compile time to be equal to 1
2529
* @param function Function name (for error messages)
2630
* @param name Variable name (for error messages)
2731
* @param y Matrix to test
28-
* @throw <code>std::domain_error</code> if y is not a valid Cholesky
29-
* factor, if number of rows is less than the number of columns,
30-
* if there are 0 columns, or if any element in matrix is NaN
32+
* @throw `std::domain_error` if y is not a valid Cholesky factor, if number of
33+
* rows is less than the number of columns, if there are 0 columns, or if any
34+
* element in matrix is NaN
3135
*/
32-
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr>
36+
template <typename Mat, require_matrix_t<Mat>* = nullptr>
3337
void check_cholesky_factor_corr(const char* function, const char* name,
34-
const EigMat& y) {
35-
const auto& y_ref = to_ref(y);
38+
const Mat& y) {
39+
const auto& y_ref = to_ref(value_of_rec(y));
3640
check_square(function, name, y_ref);
3741
check_lower_triangular(function, name, y_ref);
3842
check_positive(function, name, y_ref.diagonal());
@@ -41,6 +45,34 @@ void check_cholesky_factor_corr(const char* function, const char* name,
4145
}
4246
}
4347

48+
/**
49+
* Throw an exception if the specified matrix is not a valid Cholesky factor of
50+
* a correlation matrix. A Cholesky factor is a lower triangular matrix whose
51+
* diagonal elements are all positive and each row has unit Euclidean length.
52+
* Note that Cholesky factors need not be square, but require at least as many
53+
* rows M as columns N (i.e., `M >= N`). Tolerance is specified by
54+
* `math::CONSTRAINT_TOLERANCE`. Tolerance is specified by
55+
* `math::CONSTRAINT_TOLERANCE`.
56+
* @tparam StdVec A standard vector with inner type either inheriting from
57+
* `MatrixBase` with neither rows or columns defined at compile time to be equal
58+
* to 1 or a `var_value` with the var's inner type inheriting from
59+
* `Eigen::MatrixBase` with neither rows or columns defined at compile time to
60+
* be equal to 1
61+
* @param function Function name (for error messages)
62+
* @param name Variable name (for error messages)
63+
* @param y Standard vector of matrics to test
64+
* @throw `std::domain_error` if y[i] is not a valid Cholesky factor, if number
65+
* of rows is less than the number of columns, if there are 0 columns, or if any
66+
* element in matrix is NaN
67+
*/
68+
template <typename StdVec, require_std_vector_t<StdVec>* = nullptr>
69+
void check_cholesky_factor_corr(const char* function, const char* name,
70+
const StdVec& y) {
71+
for (size_t i = 0; i < y.size(); ++i) {
72+
check_cholesky_factor_corr(function,
73+
internal::make_iter_name(name, i).c_str(), y[i]);
74+
}
75+
}
4476
} // namespace math
4577
} // namespace stan
4678
#endif

stan/math/prim/err/check_corr_matrix.hpp

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33

44
#include <stan/math/prim/fun/Eigen.hpp>
55
#include <stan/math/prim/meta.hpp>
6+
#include <stan/math/prim/fun/to_ref.hpp>
7+
#include <stan/math/prim/fun/value_of_rec.hpp>
68
#include <stan/math/prim/err/throw_domain_error.hpp>
79
#include <stan/math/prim/err/check_pos_definite.hpp>
810
#include <stan/math/prim/err/check_square.hpp>
9-
#include <stan/math/prim/fun/to_ref.hpp>
1011
#include <sstream>
1112
#include <string>
1213
#include <cmath>
@@ -15,44 +16,68 @@ namespace stan {
1516
namespace math {
1617

1718
/**
18-
* Check if the specified matrix is a valid correlation matrix.
19-
* A valid correlation matrix is symmetric, has a unit diagonal
20-
* (all 1 values), and has all values between -1 and 1
21-
* (inclusive).
22-
* This function throws exceptions if the variable is not a valid
23-
* correlation matrix.
24-
* @tparam EigMat Type inheriting from `MatrixBase` with dynamic rows and
25-
* columns.
19+
* Throw an exception if the specified matrix is not a valid correlation matrix.
20+
* A valid correlation matrix is symmetric positive definite, has a unit
21+
* diagonal (all 1 values), and has all values between -1 and 1 (inclusive).
22+
* @tparam Mat Type inheriting from `MatrixBase` with neither rows or columns
23+
* defined at compile time to be equal to 1 or a `var_value` with the var's
24+
* inner type inheriting from `Eigen::MatrixBase` with neither rows or columns
25+
* defined at compile time to be equal to 1
2626
* @param function Name of the function this was called from
2727
* @param name Name of the variable
2828
* @param y Matrix to test
29-
* @throw <code>std::invalid_argument</code> if the matrix is not square
30-
* @throw <code>std::domain_error</code> if the matrix is non-symmetric,
31-
* diagonals not near 1, not positive definite, or any of the
32-
* elements nan
29+
* @throw `std::invalid_argument` if the matrix is not square
30+
* @throw `std::domain_error` if the matrix is non-symmetric, diagonals not near
31+
* 1, not positive definite, or any of the elements are `NaN`
3332
*/
34-
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr>
33+
template <typename Mat, require_matrix_t<Mat>* = nullptr>
3534
inline void check_corr_matrix(const char* function, const char* name,
36-
const EigMat& y) {
37-
const auto& y_ref = to_ref(y);
35+
const Mat& y) {
36+
auto&& y_ref = to_ref(value_of_rec(y));
3837
check_square(function, name, y_ref);
3938
using std::fabs;
4039
if (y_ref.size() == 0) {
4140
return;
4241
}
4342

4443
for (Eigen::Index k = 0; k < y.rows(); ++k) {
45-
if (!(fabs(y_ref(k, k) - 1.0) <= CONSTRAINT_TOLERANCE)) {
46-
std::ostringstream msg;
47-
msg << "is not a valid correlation matrix. " << name << "("
48-
<< stan::error_index::value + k << "," << stan::error_index::value + k
49-
<< ") is ";
50-
std::string msg_str(msg.str());
51-
throw_domain_error(function, name, y_ref(k, k), msg_str.c_str(),
52-
", but should be near 1.0");
44+
if (!(fabs(y_ref.coeff(k, k) - 1.0) <= CONSTRAINT_TOLERANCE)) {
45+
[&y_ref, name, k, function]() STAN_COLD_PATH {
46+
std::ostringstream msg;
47+
msg << "is not a valid correlation matrix. " << name << "("
48+
<< stan::error_index::value + k << ","
49+
<< stan::error_index::value + k << ") is ";
50+
std::string msg_str(msg.str());
51+
throw_domain_error(function, name, y_ref(k, k), msg_str.c_str(),
52+
", but should be near 1.0");
53+
}();
5354
}
5455
}
55-
check_pos_definite(function, "y", y_ref);
56+
check_pos_definite(function, name, y_ref);
57+
}
58+
59+
/**
60+
* Throw an exception if the specified matrix is not a valid correlation matrix.
61+
* A valid correlation matrix is symmetric positive definite, has a unit
62+
* diagonal (all 1 values), and has all values between -1 and 1 (inclusive).
63+
* @tparam StdVec A standard vector with inner type either inheriting from
64+
* `Eigen::MatrixBase` with neither rows or columns defined at compile time to
65+
* be equal to 1 or a `var_value` with the var's inner type inheriting from
66+
* `Eigen::MatrixBase` with neither rows or columns defined at compile time to
67+
* be equal to 1.
68+
* @param function Name of the function this was called from
69+
* @param name Name of the variable
70+
* @param y Matrix to test
71+
* @throw `std::invalid_argument` if the matrix is not square
72+
* @throw `std::domain_error` if the matrix is non-symmetric, diagonals not near
73+
* 1, not positive definite, or any of the elements are `NaN`
74+
*/
75+
template <typename StdVec, require_std_vector_t<StdVec>* = nullptr>
76+
void check_corr_matrix(const char* function, const char* name,
77+
const StdVec& y) {
78+
for (auto&& y_i : y) {
79+
check_corr_matrix(function, name, y_i);
80+
}
5681
}
5782

5883
} // namespace math

stan/math/prim/err/check_cov_matrix.hpp

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,58 @@
33

44
#include <stan/math/prim/fun/Eigen.hpp>
55
#include <stan/math/prim/meta.hpp>
6+
#include <stan/math/prim/fun/to_ref.hpp>
7+
#include <stan/math/prim/fun/value_of_rec.hpp>
68
#include <stan/math/prim/err/check_pos_definite.hpp>
79

810
namespace stan {
911
namespace math {
1012
/**
11-
* Check if the specified matrix is a valid covariance matrix.
12-
* A valid covariance matrix is a square, symmetric matrix that is
13-
* positive definite.
14-
* @tparam EigMat Type inheriting from `MatrixBase` with dynamic rows and
15-
* columns.
13+
* Throw an exception if the specified matrix is not a valid covariance matrix.
14+
* A valid covariance matrix is a square, symmetric matrix that is positive
15+
* definite.
16+
* @tparam Mat Type inheriting from `MatrixBase` with neither rows or columns
17+
* defined at compile time to be equal to 1 or a `var_value` with the var's
18+
* inner type inheriting from `Eigen::MatrixBase` with neither rows or columns
19+
* defined at compile time to be equal to 1
1620
* @param function Function name (for error messages)
1721
* @param name Variable name (for error messages)
1822
* @param y Matrix to test
19-
* @throw <code>std::invalid_argument</code> if the matrix is not square
20-
* or if the matrix is 0x0
21-
* @throw <code>std::domain_error</code> if the matrix is not symmetric,
22-
* if the matrix is not positive definite,
23-
* or if any element of the matrix is nan
23+
* @throw `std::invalid_argument` if the matrix is not square or if the matrix
24+
* is 0x0
25+
* @throw `std::domain_error` if the matrix is not symmetric, if the matrix is
26+
* not positive definite, or if any element of the matrix is `NaN`
2427
*/
25-
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr>
28+
template <typename Mat, require_matrix_t<Mat>* = nullptr>
2629
inline void check_cov_matrix(const char* function, const char* name,
27-
const EigMat& y) {
30+
const Mat& y) {
2831
check_pos_definite(function, name, y);
2932
}
3033

34+
/**
35+
* Throw an exception if the specified matrix is not a valid covariance matrix.
36+
* A valid covariance matrix is a square, symmetric matrix that is positive
37+
* definite.
38+
* @tparam StdVec A standard vector with inner type either inheriting from
39+
* `MatrixBase` with neither rows or columns defined at compile time to be equal
40+
* to 1 or a `var_value` with the var's inner type inheriting from
41+
* `Eigen::MatrixBase` with neither rows or columns defined at compile time to
42+
* be equal to 1
43+
* @param function Function name (for error messages)
44+
* @param name Variable name (for error messages)
45+
* @param y standard vector of matrices to test.
46+
* @throw `std::invalid_argument` if the matrix is not square
47+
* or if the matrix is 0x0
48+
* @throw `std::domain_error` if the matrix is not symmetric, if the matrix is
49+
* not positive definite, or if any element of the matrix is `NaN`
50+
*/
51+
template <typename StdVec, require_std_vector_t<StdVec>* = nullptr>
52+
void check_cov_matrix(const char* function, const char* name, const StdVec& y) {
53+
for (auto&& y_i : y) {
54+
check_cov_matrix(function, name, y_i);
55+
}
56+
}
57+
3158
} // namespace math
3259
} // namespace stan
3360
#endif

0 commit comments

Comments
 (0)