-
-
Notifications
You must be signed in to change notification settings - Fork 193
Vectorize checks called by compiler #2556
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b7dd98c
46ce2e8
11d11b4
df9255c
83224ea
a5ffbd5
e82fd29
a678bf4
f892c31
e318a49
e3f196c
c0fd490
6cba218
64c6019
f4ed5c1
87b15e5
383f6e9
3ff2da0
2d1029c
2b1b480
6bf2271
71e8b6b
82dc731
9994174
6b14ee7
07f31cb
ea2a9eb
3a36e2b
f96a182
4273c1e
f26b24c
4a02109
840075f
aa4d183
64c804f
31b69b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,41 +1,71 @@ | ||
#ifndef STAN_MATH_PRIM_ERR_CHECK_CHOLESKY_FACTOR_HPP | ||
#define STAN_MATH_PRIM_ERR_CHECK_CHOLESKY_FACTOR_HPP | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/fun/Eigen.hpp> | ||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/fun/to_ref.hpp> | ||
#include <stan/math/prim/fun/value_of_rec.hpp> | ||
#include <stan/math/prim/err/check_positive.hpp> | ||
#include <stan/math/prim/err/check_less_or_equal.hpp> | ||
#include <stan/math/prim/err/check_lower_triangular.hpp> | ||
#include <stan/math/prim/err/make_iter_name.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Check if the specified matrix is a valid Cholesky factor. | ||
* A Cholesky factor is a lower triangular matrix whose diagonal | ||
* elements are all positive. Note that Cholesky factors need not | ||
* be square, but require at least as many rows M as columns N | ||
* (i.e., M >= N). | ||
* @tparam EigMat Type of the Cholesky factor (must be derived from \c | ||
* Eigen::MatrixBase) | ||
* Throw an exception if the specified matrix is not a valid Cholesky factor. A | ||
* Cholesky factor is a lower triangular matrix whose diagonal elements are all | ||
* positive. Note that Cholesky factors need not be square, but require at | ||
* least as many rows M as columns N (i.e., `M >= N`). | ||
* @tparam Mat Type inheriting from `MatrixBase` with neither rows or columns | ||
* defined at compile time to be equal to 1 or a `var_value` with the var's | ||
* inner type inheriting from `Eigen::MatrixBase` with neither rows or columns | ||
* defined at compile time to be equal to 1 | ||
* @param function Function name (for error messages) | ||
* @param name Variable name (for error messages) | ||
* @param y Matrix to test | ||
* @throw <code>std::domain_error</code> if y is not a valid Cholesky | ||
* factor, if number of rows is less than the number of columns, | ||
* if there are 0 columns, or if any element in matrix is NaN | ||
* @throw `std::domain_error` if y is not a valid Cholesky factor, if number of | ||
* rows is less than the number of columns, if there are 0 columns, or if any | ||
* element in matrix is `NaN` | ||
*/ | ||
template <typename EigMat, require_eigen_t<EigMat>* = nullptr> | ||
template <typename Mat, require_matrix_t<Mat>* = nullptr> | ||
inline void check_cholesky_factor(const char* function, const char* name, | ||
const EigMat& y) { | ||
const Mat& y) { | ||
check_less_or_equal(function, "columns and rows of Cholesky factor", y.cols(), | ||
y.rows()); | ||
check_positive(function, "columns of Cholesky factor", y.cols()); | ||
const Eigen::Ref<const plain_type_t<EigMat>>& y_ref = y; | ||
auto&& y_ref = to_ref(value_of_rec(y)); | ||
check_lower_triangular(function, name, y_ref); | ||
check_positive(function, name, y_ref.diagonal()); | ||
} | ||
|
||
/** | ||
* Throw an exception if the specified matrix is not a valid Cholesky factor. A | ||
* Cholesky factor is a lower triangular matrix whose diagonal elements are all | ||
* positive. Note that Cholesky factors need not be square, but require at | ||
* least as many rows M as columns N (i.e., `M >= N`). | ||
* @tparam StdVec A standard vector with inner type either inheriting from | ||
* `MatrixBase` with neither rows or columns defined at compile time to be equal | ||
* to 1 or a `var_value` with the var's inner type inheriting from | ||
* `Eigen::MatrixBase` with neither rows or columns defined at compile time to | ||
* be equal to 1 | ||
* @param function Function name (for error messages) | ||
* @param name Variable name (for error messages) | ||
* @param y Standard vector of matrices to test | ||
* @throw `std::domain_error` if y is not a valid Cholesky factor, if number of | ||
* rows is less than the number of columns, if there are 0 columns, or if any | ||
* element in matrix is `NaN` | ||
*/ | ||
template <typename StdVec, require_std_vector_t<StdVec>* = nullptr> | ||
void check_cholesky_factor(const char* function, const char* name, | ||
const StdVec& y) { | ||
for (size_t i = 0; i < y.size(); ++i) { | ||
check_cholesky_factor(function, internal::make_iter_name(name, i).c_str(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This needs a performance evaluation as it's going to proactively create string names for each entry, which is pretty expensive. [question] There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm fine with running a little performance check here, I def expect it to be slower for small matrices though hopefully not much.
I tried thinking about this but the only thing I could figure out is to change all of the checks to take in a lambda instead of a const char* that doesn't evaluate until a throw occurs. |
||
y[i]); | ||
} | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,35 +4,39 @@ | |
#include <stan/math/prim/fun/Eigen.hpp> | ||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/fun/to_ref.hpp> | ||
#include <stan/math/prim/fun/value_of_rec.hpp> | ||
#include <stan/math/prim/err/check_positive.hpp> | ||
#include <stan/math/prim/err/check_lower_triangular.hpp> | ||
#include <stan/math/prim/err/check_square.hpp> | ||
#include <stan/math/prim/err/check_unit_vector.hpp> | ||
#include <stan/math/prim/err/make_iter_name.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Check if the specified matrix is a valid Cholesky factor of a | ||
* correlation matrix. | ||
* A Cholesky factor is a lower triangular matrix whose diagonal | ||
* elements are all positive. Note that Cholesky factors need not | ||
* be square, but require at least as many rows M as columns N | ||
* (i.e., M >= N). | ||
* Tolerance is specified by <code>math::CONSTRAINT_TOLERANCE</code>. | ||
* @tparam EigMat Type inheriting from `MatrixBase` with dynamic rows and | ||
* columns. | ||
* Throw an exception if the specified matrix is not a valid Cholesky factor of | ||
* a correlation matrix. A Cholesky factor is a lower triangular matrix whose | ||
* diagonal elements are all positive and each row has unit Euclidean length. | ||
* Note that Cholesky factors need not be square, but require at least as many | ||
* rows M as columns N (i.e., `M >= N`). Tolerance is specified by | ||
* `math::CONSTRAINT_TOLERANCE`. Tolerance is specified by | ||
* `math::CONSTRAINT_TOLERANCE`. | ||
* @tparam Mat Type inheriting from `MatrixBase` with neither rows or columns | ||
* defined at compile time to be equal to 1 or a `var_value` with the var's | ||
* inner type inheriting from `Eigen::MatrixBase` with neither rows or columns | ||
* defined at compile time to be equal to 1 | ||
* @param function Function name (for error messages) | ||
* @param name Variable name (for error messages) | ||
* @param y Matrix to test | ||
* @throw <code>std::domain_error</code> if y is not a valid Cholesky | ||
* factor, if number of rows is less than the number of columns, | ||
* if there are 0 columns, or if any element in matrix is NaN | ||
* @throw `std::domain_error` if y is not a valid Cholesky factor, if number of | ||
* rows is less than the number of columns, if there are 0 columns, or if any | ||
* element in matrix is NaN | ||
*/ | ||
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr> | ||
template <typename Mat, require_matrix_t<Mat>* = nullptr> | ||
void check_cholesky_factor_corr(const char* function, const char* name, | ||
const EigMat& y) { | ||
const auto& y_ref = to_ref(y); | ||
const Mat& y) { | ||
const auto& y_ref = to_ref(value_of_rec(y)); | ||
check_square(function, name, y_ref); | ||
check_lower_triangular(function, name, y_ref); | ||
check_positive(function, name, y_ref.diagonal()); | ||
|
@@ -41,6 +45,34 @@ void check_cholesky_factor_corr(const char* function, const char* name, | |
} | ||
} | ||
|
||
/** | ||
* Throw an exception if the specified matrix is not a valid Cholesky factor of | ||
* a correlation matrix. A Cholesky factor is a lower triangular matrix whose | ||
* diagonal elements are all positive and each row has unit Euclidean length. | ||
* Note that Cholesky factors need not be square, but require at least as many | ||
* rows M as columns N (i.e., `M >= N`). Tolerance is specified by | ||
* `math::CONSTRAINT_TOLERANCE`. Tolerance is specified by | ||
* `math::CONSTRAINT_TOLERANCE`. | ||
* @tparam StdVec A standard vector with inner type either inheriting from | ||
* `MatrixBase` with neither rows or columns defined at compile time to be equal | ||
* to 1 or a `var_value` with the var's inner type inheriting from | ||
* `Eigen::MatrixBase` with neither rows or columns defined at compile time to | ||
* be equal to 1 | ||
* @param function Function name (for error messages) | ||
* @param name Variable name (for error messages) | ||
* @param y Standard vector of matrics to test | ||
* @throw `std::domain_error` if y[i] is not a valid Cholesky factor, if number | ||
* of rows is less than the number of columns, if there are 0 columns, or if any | ||
* element in matrix is NaN | ||
*/ | ||
template <typename StdVec, require_std_vector_t<StdVec>* = nullptr> | ||
void check_cholesky_factor_corr(const char* function, const char* name, | ||
const StdVec& y) { | ||
for (size_t i = 0; i < y.size(); ++i) { | ||
check_cholesky_factor_corr(function, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same question for all of these and efficiency. |
||
internal::make_iter_name(name, i).c_str(), y[i]); | ||
} | ||
} | ||
} // namespace math | ||
} // namespace stan | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,10 +3,11 @@ | |
|
||
#include <stan/math/prim/fun/Eigen.hpp> | ||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/fun/to_ref.hpp> | ||
#include <stan/math/prim/fun/value_of_rec.hpp> | ||
#include <stan/math/prim/err/throw_domain_error.hpp> | ||
#include <stan/math/prim/err/check_pos_definite.hpp> | ||
#include <stan/math/prim/err/check_square.hpp> | ||
#include <stan/math/prim/fun/to_ref.hpp> | ||
#include <sstream> | ||
#include <string> | ||
#include <cmath> | ||
|
@@ -15,44 +16,68 @@ namespace stan { | |
namespace math { | ||
|
||
/** | ||
* Check if the specified matrix is a valid correlation matrix. | ||
* A valid correlation matrix is symmetric, has a unit diagonal | ||
* (all 1 values), and has all values between -1 and 1 | ||
* (inclusive). | ||
* This function throws exceptions if the variable is not a valid | ||
* correlation matrix. | ||
* @tparam EigMat Type inheriting from `MatrixBase` with dynamic rows and | ||
* columns. | ||
* Throw an exception if the specified matrix is not a valid correlation matrix. | ||
* A valid correlation matrix is symmetric positive definite, has a unit | ||
* diagonal (all 1 values), and has all values between -1 and 1 (inclusive). | ||
* @tparam Mat Type inheriting from `MatrixBase` with neither rows or columns | ||
* defined at compile time to be equal to 1 or a `var_value` with the var's | ||
* inner type inheriting from `Eigen::MatrixBase` with neither rows or columns | ||
* defined at compile time to be equal to 1 | ||
* @param function Name of the function this was called from | ||
* @param name Name of the variable | ||
* @param y Matrix to test | ||
* @throw <code>std::invalid_argument</code> if the matrix is not square | ||
* @throw <code>std::domain_error</code> if the matrix is non-symmetric, | ||
* diagonals not near 1, not positive definite, or any of the | ||
* elements nan | ||
* @throw `std::invalid_argument` if the matrix is not square | ||
* @throw `std::domain_error` if the matrix is non-symmetric, diagonals not near | ||
* 1, not positive definite, or any of the elements are `NaN` | ||
*/ | ||
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr> | ||
template <typename Mat, require_matrix_t<Mat>* = nullptr> | ||
inline void check_corr_matrix(const char* function, const char* name, | ||
const EigMat& y) { | ||
const auto& y_ref = to_ref(y); | ||
const Mat& y) { | ||
auto&& y_ref = to_ref(value_of_rec(y)); | ||
check_square(function, name, y_ref); | ||
using std::fabs; | ||
if (y_ref.size() == 0) { | ||
return; | ||
} | ||
|
||
for (Eigen::Index k = 0; k < y.rows(); ++k) { | ||
if (!(fabs(y_ref(k, k) - 1.0) <= CONSTRAINT_TOLERANCE)) { | ||
std::ostringstream msg; | ||
msg << "is not a valid correlation matrix. " << name << "(" | ||
<< stan::error_index::value + k << "," << stan::error_index::value + k | ||
<< ") is "; | ||
std::string msg_str(msg.str()); | ||
throw_domain_error(function, name, y_ref(k, k), msg_str.c_str(), | ||
", but should be near 1.0"); | ||
if (!(fabs(y_ref.coeff(k, k) - 1.0) <= CONSTRAINT_TOLERANCE)) { | ||
[&y_ref, name, k, function]() STAN_COLD_PATH { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See issue #2249 for more info, but essentially STAN_COLD_PATH is equal to |
||
std::ostringstream msg; | ||
msg << "is not a valid correlation matrix. " << name << "(" | ||
<< stan::error_index::value + k << "," | ||
<< stan::error_index::value + k << ") is "; | ||
std::string msg_str(msg.str()); | ||
throw_domain_error(function, name, y_ref(k, k), msg_str.c_str(), | ||
", but should be near 1.0"); | ||
}(); | ||
} | ||
} | ||
check_pos_definite(function, "y", y_ref); | ||
check_pos_definite(function, name, y_ref); | ||
} | ||
|
||
/** | ||
* Throw an exception if the specified matrix is not a valid correlation matrix. | ||
* A valid correlation matrix is symmetric positive definite, has a unit | ||
* diagonal (all 1 values), and has all values between -1 and 1 (inclusive). | ||
* @tparam StdVec A standard vector with inner type either inheriting from | ||
* `Eigen::MatrixBase` with neither rows or columns defined at compile time to | ||
* be equal to 1 or a `var_value` with the var's inner type inheriting from | ||
* `Eigen::MatrixBase` with neither rows or columns defined at compile time to | ||
* be equal to 1. | ||
* @param function Name of the function this was called from | ||
* @param name Name of the variable | ||
* @param y Matrix to test | ||
* @throw `std::invalid_argument` if the matrix is not square | ||
* @throw `std::domain_error` if the matrix is non-symmetric, diagonals not near | ||
* 1, not positive definite, or any of the elements are `NaN` | ||
*/ | ||
template <typename StdVec, require_std_vector_t<StdVec>* = nullptr> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a public function and thus should have doc. |
||
void check_corr_matrix(const char* function, const char* name, | ||
const StdVec& y) { | ||
for (auto&& y_i : y) { | ||
check_corr_matrix(function, name, y_i); | ||
} | ||
} | ||
|
||
} // namespace math | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,31 +3,58 @@ | |
|
||
#include <stan/math/prim/fun/Eigen.hpp> | ||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/fun/to_ref.hpp> | ||
#include <stan/math/prim/fun/value_of_rec.hpp> | ||
#include <stan/math/prim/err/check_pos_definite.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
/** | ||
* Check if the specified matrix is a valid covariance matrix. | ||
* A valid covariance matrix is a square, symmetric matrix that is | ||
* positive definite. | ||
* @tparam EigMat Type inheriting from `MatrixBase` with dynamic rows and | ||
* columns. | ||
* Throw an exception if the specified matrix is not a valid covariance matrix. | ||
* A valid covariance matrix is a square, symmetric matrix that is positive | ||
* definite. | ||
* @tparam Mat Type inheriting from `MatrixBase` with neither rows or columns | ||
* defined at compile time to be equal to 1 or a `var_value` with the var's | ||
* inner type inheriting from `Eigen::MatrixBase` with neither rows or columns | ||
* defined at compile time to be equal to 1 | ||
* @param function Function name (for error messages) | ||
* @param name Variable name (for error messages) | ||
* @param y Matrix to test | ||
* @throw <code>std::invalid_argument</code> if the matrix is not square | ||
* or if the matrix is 0x0 | ||
* @throw <code>std::domain_error</code> if the matrix is not symmetric, | ||
* if the matrix is not positive definite, | ||
* or if any element of the matrix is nan | ||
* @throw `std::invalid_argument` if the matrix is not square or if the matrix | ||
* is 0x0 | ||
* @throw `std::domain_error` if the matrix is not symmetric, if the matrix is | ||
* not positive definite, or if any element of the matrix is `NaN` | ||
*/ | ||
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr> | ||
template <typename Mat, require_matrix_t<Mat>* = nullptr> | ||
inline void check_cov_matrix(const char* function, const char* name, | ||
const EigMat& y) { | ||
const Mat& y) { | ||
check_pos_definite(function, name, y); | ||
} | ||
|
||
/** | ||
* Throw an exception if the specified matrix is not a valid covariance matrix. | ||
* A valid covariance matrix is a square, symmetric matrix that is positive | ||
* definite. | ||
* @tparam StdVec A standard vector with inner type either inheriting from | ||
* `MatrixBase` with neither rows or columns defined at compile time to be equal | ||
* to 1 or a `var_value` with the var's inner type inheriting from | ||
* `Eigen::MatrixBase` with neither rows or columns defined at compile time to | ||
* be equal to 1 | ||
* @param function Function name (for error messages) | ||
* @param name Variable name (for error messages) | ||
* @param y standard vector of matrices to test. | ||
* @throw `std::invalid_argument` if the matrix is not square | ||
* or if the matrix is 0x0 | ||
* @throw `std::domain_error` if the matrix is not symmetric, if the matrix is | ||
* not positive definite, or if any element of the matrix is `NaN` | ||
*/ | ||
template <typename StdVec, require_std_vector_t<StdVec>* = nullptr> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Needs doc. |
||
void check_cov_matrix(const char* function, const char* name, const StdVec& y) { | ||
for (auto&& y_i : y) { | ||
check_cov_matrix(function, name, y_i); | ||
} | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[optional]
I'd rather keep shorter template parameters, like just
V
for standard vectors or maybeC
for generic containers.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For very generic functions I like using
V
orT
etc. but for functions with specific requirements I like that the template parameter's name gives an idea of what the requirement is to use the function.