Skip to content

Commit 9207233

Browse files
Merge pull request #2771 from andrjohns/feature/as-int-fun
Add to_int function and tests
2 parents 57fe69b + 242901f commit 9207233

File tree

3 files changed

+140
-0
lines changed

3 files changed

+140
-0
lines changed

stan/math/prim/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@
338338
#include <stan/math/prim/fun/to_array_1d.hpp>
339339
#include <stan/math/prim/fun/to_array_2d.hpp>
340340
#include <stan/math/prim/fun/to_complex.hpp>
341+
#include <stan/math/prim/fun/to_int.hpp>
341342
#include <stan/math/prim/fun/to_matrix.hpp>
342343
#include <stan/math/prim/fun/to_ref.hpp>
343344
#include <stan/math/prim/fun/to_row_vector.hpp>

stan/math/prim/fun/to_int.hpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#ifndef STAN_MATH_PRIM_FUN_TO_INT_HPP
2+
#define STAN_MATH_PRIM_FUN_TO_INT_HPP
3+
4+
#include <stan/math/prim/err/check_bounded.hpp>
5+
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
6+
7+
namespace stan {
8+
namespace math {
9+
10+
/**
11+
* Returns the input scalar as an integer type. Specialisation for integral
12+
* types which do not need conversion, reduces to a no-op.
13+
*
14+
* @tparam T type of integral argument
15+
* @param x argument
16+
* @return Input argument unchanged
17+
*/
18+
template <typename T, require_integral_t<T>* = nullptr>
19+
inline T to_int(T x) {
20+
return std::forward<T>(x);
21+
}
22+
23+
/**
24+
* Returns the input scalar as an integer type. This function performs no
25+
* rounding and simply truncates the decimal to return only the signficand as an
26+
* integer.
27+
*
28+
* Casting NaN and Inf values to integers is considered undefined behavior as
29+
* NaN and Inf cannot be represented as an integer and most implementations
30+
* simply overflow, as such this function throws for these inputs.
31+
*
32+
* The function also throws for floating-point values that are too large to be
33+
* represented as an integer.
34+
*
35+
* @tparam T type of argument (must be arithmetic)
36+
* @param x argument
37+
* @return Integer value of argument
38+
* @throw std::domain_error for NaN, Inf, or floating point values not in range
39+
* to be represented as int
40+
*/
41+
template <typename T, require_floating_point_t<T>* = nullptr>
42+
inline int to_int(T x) {
43+
static const char* function = "to_int";
44+
check_bounded(function, "x", x, std::numeric_limits<int>::min(),
45+
std::numeric_limits<int>::max());
46+
return static_cast<int>(x);
47+
}
48+
49+
/**
50+
* Return elementwise integer value of the specified real-valued
51+
* container.
52+
*
53+
* @tparam T type of argument
54+
* @param x argument
55+
* @return Integer value of argument
56+
*/
57+
struct to_int_fun {
58+
template <typename T>
59+
static inline auto fun(const T& x) {
60+
return to_int(x);
61+
}
62+
};
63+
64+
/**
65+
* Returns the elementwise `to_int()` of the input,
66+
* which may be a scalar or any Stan container of numeric scalars.
67+
*
68+
* @tparam Container type of container
69+
* @param x argument
70+
* @return Integer value of each variable in the container.
71+
*/
72+
template <typename Container,
73+
require_std_vector_st<std::is_arithmetic, Container>* = nullptr>
74+
inline auto to_int(const Container& x) {
75+
return apply_scalar_unary<to_int_fun, Container>::apply(x);
76+
}
77+
78+
} // namespace math
79+
} // namespace stan
80+
81+
#endif
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#include <stan/math/prim.hpp>
2+
#include <test/unit/util.hpp>
3+
#include <gtest/gtest.h>
4+
#include <cmath>
5+
#include <limits>
6+
7+
TEST(MathFunctions, to_int_values) {
8+
using stan::math::to_int;
9+
10+
EXPECT_EQ(2, to_int(2.0));
11+
EXPECT_EQ(2, to_int(2.1));
12+
EXPECT_EQ(2, to_int(2.9));
13+
EXPECT_EQ(2, to_int(2.999999999));
14+
15+
EXPECT_EQ(-36574, to_int(-36574.0));
16+
EXPECT_EQ(-36574, to_int(-36574.1));
17+
EXPECT_EQ(-36574, to_int(-36574.9));
18+
EXPECT_EQ(-36574, to_int(-36574.999999999));
19+
}
20+
21+
TEST(MathFunctions, to_int_errors) {
22+
using stan::math::INFTY;
23+
using stan::math::NEGATIVE_INFTY;
24+
using stan::math::NOT_A_NUMBER;
25+
using stan::math::to_int;
26+
27+
EXPECT_THROW(to_int(std::numeric_limits<int>::max() + 1.0),
28+
std::domain_error);
29+
EXPECT_THROW(to_int(std::numeric_limits<int>::min() - 1.0),
30+
std::domain_error);
31+
32+
EXPECT_THROW(to_int(NOT_A_NUMBER), std::domain_error);
33+
EXPECT_THROW(to_int(INFTY), std::domain_error);
34+
EXPECT_THROW(to_int(NEGATIVE_INFTY), std::domain_error);
35+
}
36+
37+
TEST(MathFunctions, to_int_vec) {
38+
using stan::math::to_int;
39+
40+
std::vector<double> inputs_1{2.1, -34.64, 10.89, 1000000};
41+
std::vector<double> inputs_2{-409831.987, 403.1, 10.61, -0.00001};
42+
std::vector<std::vector<double>> inputs{inputs_1, inputs_2};
43+
44+
std::vector<int> target_result_1{2, -34, 10, 1000000};
45+
std::vector<int> target_result_2{-409831, 403, 10, 0};
46+
std::vector<std::vector<int>> target_result{target_result_1, target_result_2};
47+
48+
EXPECT_STD_VECTOR_EQ(to_int(inputs), target_result);
49+
50+
inputs[0][2] = std::numeric_limits<int>::min() - 1.0;
51+
EXPECT_THROW(to_int(inputs), std::domain_error);
52+
53+
std::vector<double> inputs_empty;
54+
std::vector<double> inputs_size_one{1.5};
55+
56+
EXPECT_NO_THROW(to_int(inputs_empty));
57+
EXPECT_STD_VECTOR_EQ(to_int(inputs_size_one), std::vector<int>{1});
58+
}

0 commit comments

Comments
 (0)