Skip to content

Revert "Feature/elementwise checks revert revert" #2147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions stan/math/prim/err.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include <stan/math/prim/err/check_corr_matrix.hpp>
#include <stan/math/prim/err/check_cov_matrix.hpp>
#include <stan/math/prim/err/check_finite.hpp>
#include <stan/math/prim/err/check_finite_screen.hpp>
#include <stan/math/prim/err/check_flag_sundials.hpp>
#include <stan/math/prim/err/check_greater.hpp>
#include <stan/math/prim/err/check_greater_or_equal.hpp>
Expand All @@ -27,7 +26,6 @@
#include <stan/math/prim/err/check_nonnegative.hpp>
#include <stan/math/prim/err/check_nonzero_size.hpp>
#include <stan/math/prim/err/check_not_nan.hpp>
#include <stan/math/prim/err/check_not_nan_screen.hpp>
#include <stan/math/prim/err/check_ordered.hpp>
#include <stan/math/prim/err/check_sorted.hpp>
#include <stan/math/prim/err/check_pos_definite.hpp>
Expand Down
136 changes: 126 additions & 10 deletions stan/math/prim/err/check_finite.hpp
Original file line number Diff line number Diff line change
@@ -1,27 +1,143 @@
#ifndef STAN_MATH_PRIM_ERR_CHECK_FINITE_HPP
#define STAN_MATH_PRIM_ERR_CHECK_FINITE_HPP

#include <stan/math/prim/err/elementwise_check.hpp>
#include <stan/math/prim/err/check_finite_screen.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err/is_scal_finite.hpp>
#include <stan/math/prim/err/throw_domain_error.hpp>
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/get.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>
#include <cmath>

namespace stan {
namespace math {
namespace internal {
/**
* Return true if y is finite
*
* @tparam T_y type of y
* @param y parameter to check
* @return boolean
*/
template <typename T_y>
bool is_finite(const T_y& y) {
return is_scal_finite(y);
}

/**
* Return true if every element of the matrix y is finite
*
* @tparam T_y type of elements y
* @param y matrix to check
* @return boolean
*/
template <typename T_y, int R, int C>
bool is_finite(const Eigen::Matrix<T_y, R, C>& y) {
bool all = true;
for (size_t n = 0; n < y.size(); ++n) {
all &= is_finite(y(n));
}
return all;
}

/**
* Return true if every element of the vector y is finite
*
* @tparam T_y type of elements y
* @param y vector to check
* @return boolean
*/
template <typename T_y>
bool is_finite(const std::vector<T_y>& y) {
bool all = true;
for (size_t n = 0; n < stan::math::size(y); ++n) {
all &= is_finite(y[n]);
}
return all;
}
} // namespace internal

/**
* Check if <code>y</code> is finite.
* This function is vectorized and will check each element of
* <code>y</code>.
* @tparam T_y type of y
* @param function function name (for error messages)
* @param name variable name (for error messages)
* @param y variable to check
* @tparam T_y Type of y
* @param function Function name (for error messages)
* @param name Variable name (for error messages)
* @param y Variable to check
* @throw <code>domain_error</code> if y is infinity, -infinity, or NaN
*/
template <typename T_y>
template <typename T_y, require_stan_scalar_t<T_y>* = nullptr>
inline void check_finite(const char* function, const char* name, const T_y& y) {
if (check_finite_screen(y)) {
auto is_good = [](const auto& y) { return std::isfinite(y); };
elementwise_check(is_good, function, name, y, ", but must be finite!");
if (!internal::is_finite(y)) {
throw_domain_error(function, name, y, "is ", ", but must be finite!");
}
}

/**
* Return <code>true</code> if all values in the std::vector are finite.
*
* @tparam T_y type of elements in the std::vector
*
* @param function name of function (for error messages)
* @param name variable name (for error messages)
* @param y std::vector to test
* @return <code>true</code> if all values are finite
**/
template <typename T_y, require_stan_scalar_t<T_y>* = nullptr>
inline void check_finite(const char* function, const char* name,
const std::vector<T_y>& y) {
for (size_t n = 0; n < stan::math::size(y); n++) {
if (!internal::is_finite(stan::get(y, n))) {
throw_domain_error_vec(function, name, y, n, "is ",
", but must be finite!");
}
}
}

/**
* Return <code>true</code> is the specified matrix is finite.
*
* @tparam Derived Eigen derived type
*
* @param function name of function (for error messages)
* @param name variable name (for error messages)
* @param y matrix to test
* @return <code>true</code> if the matrix is finite
**/
template <typename EigMat, require_eigen_t<EigMat>* = nullptr>
inline void check_finite(const char* function, const char* name,
const EigMat& y) {
if (!value_of(y).allFinite()) {
for (int n = 0; n < y.size(); ++n) {
if (!std::isfinite(value_of_rec(y(n)))) {
throw_domain_error_vec(function, name, y, n, "is ",
", but must be finite!");
}
}
}
}

/**
* Return <code>true</code> if all values in the std::vector are finite.
*
* @tparam T_y type of elements in the std::vector
*
* @param function name of function (for error messages)
* @param name variable name (for error messages)
* @param y std::vector to test
* @return <code>true</code> if all values are finite
**/
template <typename T_y, require_not_stan_scalar_t<T_y>* = nullptr>
inline void check_finite(const char* function, const char* name,
const std::vector<T_y>& y) {
for (size_t n = 0; n < stan::math::size(y); n++) {
if (!internal::is_finite(stan::get(y, n))) {
throw_domain_error(function, name, "", "", "is not finite!");
}
}
}

Expand Down
43 changes: 0 additions & 43 deletions stan/math/prim/err/check_finite_screen.hpp

This file was deleted.

38 changes: 34 additions & 4 deletions stan/math/prim/err/check_nonnegative.hpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,42 @@
#ifndef STAN_MATH_PRIM_ERR_CHECK_NONNEGATIVE_HPP
#define STAN_MATH_PRIM_ERR_CHECK_NONNEGATIVE_HPP

#include <stan/math/prim/err/elementwise_check.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err/throw_domain_error.hpp>
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
#include <stan/math/prim/fun/get.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <type_traits>

namespace stan {
namespace math {

namespace internal {
template <typename T_y, bool is_vec>
struct nonnegative {
static void check(const char* function, const char* name, const T_y& y) {
// have to use not is_unsigned. is_signed will be false for
// floating point types that have no unsigned versions.
if (!std::is_unsigned<T_y>::value && !(y >= 0)) {
throw_domain_error(function, name, y, "is ", ", but must be >= 0!");
}
}
};

template <typename T_y>
struct nonnegative<T_y, true> {
static void check(const char* function, const char* name, const T_y& y) {
for (size_t n = 0; n < stan::math::size(y); n++) {
if (!std::is_unsigned<typename value_type<T_y>::type>::value
&& !(stan::get(y, n) >= 0)) {
throw_domain_error_vec(function, name, y, n, "is ",
", but must be >= 0!");
}
}
}
};
} // namespace internal

/**
* Check if <code>y</code> is non-negative.
* This function is vectorized and will check each element of <code>y</code>.
Expand All @@ -19,10 +50,9 @@ namespace math {
template <typename T_y>
inline void check_nonnegative(const char* function, const char* name,
const T_y& y) {
auto is_good = [](const auto& y) { return y >= 0; };
elementwise_check(is_good, function, name, y, ", but must be >= 0!");
internal::nonnegative<T_y, is_vector_like<T_y>::value>::check(function, name,
y);
}

} // namespace math
} // namespace stan
#endif
37 changes: 31 additions & 6 deletions stan/math/prim/err/check_not_nan.hpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,40 @@
#ifndef STAN_MATH_PRIM_ERR_CHECK_NOT_NAN_HPP
#define STAN_MATH_PRIM_ERR_CHECK_NOT_NAN_HPP

#include <stan/math/prim/err/elementwise_check.hpp>
#include <stan/math/prim/err/check_not_nan_screen.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err/throw_domain_error.hpp>
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
#include <stan/math/prim/fun/get.hpp>
#include <stan/math/prim/fun/is_nan.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>

namespace stan {
namespace math {

namespace internal {
template <typename T_y, bool is_vec>
struct not_nan {
static void check(const char* function, const char* name, const T_y& y) {
if (is_nan(value_of_rec(y))) {
throw_domain_error(function, name, y, "is ", ", but must not be nan!");
}
}
};

template <typename T_y>
struct not_nan<T_y, true> {
static void check(const char* function, const char* name, const T_y& y) {
for (size_t n = 0; n < stan::math::size(y); n++) {
if (is_nan(value_of_rec(stan::get(y, n)))) {
throw_domain_error_vec(function, name, y, n, "is ",
", but must not be nan!");
}
}
}
};
} // namespace internal

/**
* Check if <code>y</code> is not <code>NaN</code>.
* This function is vectorized and will check each element of
Expand All @@ -21,10 +49,7 @@ namespace math {
template <typename T_y>
inline void check_not_nan(const char* function, const char* name,
const T_y& y) {
if (check_not_nan_screen(y)) {
auto is_good = [](const auto& y) { return !std::isnan(y); };
elementwise_check(is_good, function, name, y, ", but must not be nan!");
}
internal::not_nan<T_y, is_vector_like<T_y>::value>::check(function, name, y);
}

} // namespace math
Expand Down
43 changes: 0 additions & 43 deletions stan/math/prim/err/check_not_nan_screen.hpp

This file was deleted.

Loading