Skip to content

Commit df0bc12

Browse files
committed
allows complex abs() to be vectorized
1 parent 43ec11b commit df0bc12

File tree

2 files changed

+182
-57
lines changed

2 files changed

+182
-57
lines changed

stan/math/prim/fun/abs.hpp

Lines changed: 24 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -22,40 +22,10 @@ namespace math {
2222
* @return absolute value of argument
2323
*/
2424
template <typename T, require_arithmetic_t<T>* = nullptr>
25-
T abs(T x) {
25+
inline T abs(T x) {
2626
return std::abs(x);
2727
}
2828

29-
/*
30-
* Return the elementwise absolute value of the specified container.
31-
*
32-
* @tparam T type of elements in the vector
33-
* @param x vector argument
34-
* @return elementwise absolute value of argument
35-
*/
36-
template <typename T>
37-
std::vector<T> abs(const std::vector<T>& x) {
38-
std::vector<T> y(x.size());
39-
for (size_t n = 0; n < x.size(); ++n)
40-
y[n] = abs(x[n]);
41-
return y;
42-
}
43-
44-
/**
45-
* Return the elementwise absolute value of the specified matrix,
46-
* vector, or row vector.
47-
*
48-
* @tparam T type of scalar for matrix argument (real or complex)
49-
* @tparam R row specification (1 or -1)
50-
* @tparam C column specification (1 or -1)
51-
* @param x argument
52-
* @return elementwise absolute value of argument
53-
*/
54-
template <typename T, int R, int C>
55-
Eigen::Matrix<T, R, C> abs(const Eigen::Matrix<T, R, C>& x) {
56-
return fabs(x);
57-
}
58-
5929
/**
6030
* Return the absolute value (also known as the norm, modulus, or
6131
* magnitude) of the specified complex argument.
@@ -65,7 +35,7 @@ Eigen::Matrix<T, R, C> abs(const Eigen::Matrix<T, R, C>& x) {
6535
* @return absolute value of argument (a real number)
6636
*/
6737
template <typename T, require_complex_t<T>* = nullptr>
68-
auto abs(T x) {
38+
inline auto abs(T x) {
6939
return hypot(x.real(), x.imag());
7040
}
7141

@@ -80,7 +50,7 @@ auto abs(T x) {
8050
struct abs_fun {
8151
template <typename T>
8252
static inline T fun(const T& x) {
83-
return fabs(x);
53+
return abs(x);
8454
}
8555
};
8656

@@ -92,28 +62,28 @@ struct abs_fun {
9262
* @param x argument
9363
* @return Absolute value of each variable in the container.
9464
*/
95-
// template <typename Container,
96-
// require_not_container_st<std::is_arithmetic, Container>* = nullptr,
97-
// require_not_var_matrix_t<Container>* = nullptr,
98-
// require_not_stan_scalar_t<Container>* = nullptr>
99-
// inline auto abs(const Container& x) {
100-
// return apply_scalar_unary<abs_fun, Container>::apply(x);
101-
// }
65+
template <typename Container,
66+
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
67+
require_not_var_matrix_t<Container>* = nullptr,
68+
require_not_stan_scalar_t<Container>* = nullptr>
69+
inline auto abs(const Container& x) {
70+
return apply_scalar_unary<abs_fun, Container>::apply(x);
71+
}
10272

103-
// /**
104-
// * Version of `abs()` that accepts std::vectors, Eigen Matrix/Array objects
105-
// * or expressions, and containers of these.
106-
// *
107-
// * @tparam Container Type of x
108-
// * @param x argument
109-
// * @return Absolute value of each variable in the container.
110-
// */
111-
// template <typename Container,
112-
// require_container_st<std::is_arithmetic, Container>* = nullptr>
113-
// inline auto abs(const Container& x) {
114-
// return apply_vector_unary<Container>::apply(
115-
// x, [&](const auto& v) { return v.array().abs(); });
116-
// }
73+
/**
74+
* Version of `abs()` that accepts std::vectors, Eigen Matrix/Array objects
75+
* or expressions, and containers of these.
76+
*
77+
* @tparam Container Type of x
78+
* @param x argument
79+
* @return Absolute value of each variable in the container.
80+
*/
81+
template <typename Container,
82+
require_container_st<std::is_arithmetic, Container>* = nullptr>
83+
inline auto abs(const Container& x) {
84+
return apply_vector_unary<Container>::apply(
85+
x, [&](const auto& v) { return v.array().abs(); });
86+
}
11787

11888
namespace internal {
11989
/**

test/unit/math/mix/fun/abs_test.cpp

Lines changed: 158 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ TEST(mixFun, absBasics) {
2020

2121
TEST(mixFun, abs) {
2222
auto f = [](const auto& x) {
23-
using std::abs;
24-
return abs(x);
23+
return stan::math::abs(x);
2524
};
2625
stan::test::expect_common_nonzero_unary(f);
26+
// 0 (no derivative at 0)
2727
stan::test::expect_value(f, 0);
2828
stan::test::expect_value(f, 0.0);
2929

@@ -37,12 +37,167 @@ TEST(mixFun, abs) {
3737
stan::test::expect_ad(f, 2.0);
3838
stan::test::expect_ad(f, 4.0);
3939

40-
// not differentiable at zero
40+
// complex tests
4141
for (double re : std::vector<double>{-4, -2.5, -1.5, -0.3, 1.3, 2.1, 3.9}) {
4242
for (double im : std::vector<double>{-4, -2.5, -1.5, -0.3, 1.3, 2.1, 3.9}) {
4343
stan::test::expect_ad(f, std::complex<double>(re, im));
4444
}
4545
}
46+
47+
48+
// vector<double>
49+
using svd_t = std::vector<double>;
50+
stan::test::expect_ad(f, svd_t{});
51+
stan::test::expect_ad(f, svd_t{1.0});
52+
stan::test::expect_ad(f, svd_t{1.9, -2.3});
53+
54+
// vector<vector<double>>
55+
using svvd_t = std::vector<svd_t>;
56+
stan::test::expect_ad(f, svvd_t{});
57+
stan::test::expect_ad(f, svvd_t{svd_t{}});
58+
stan::test::expect_ad(f, svvd_t{svd_t{1.9, 4.8}});
59+
stan::test::expect_ad(f, svvd_t{svd_t{1.9}, svd_t{-13.987}});
60+
stan::test::expect_ad(f, svvd_t{svd_t{1.9, -2.7}, svd_t{-13.987, 8.8}});
61+
62+
// vector<complex<double>>
63+
using c_t = std::complex<double>;
64+
using svc_t = std::vector<c_t>;
65+
stan::test::expect_ad(f, svc_t{});
66+
stan::test::expect_ad(f, svc_t{c_t{1.0, -1.9}});
67+
stan::test::expect_ad(f, svc_t{c_t{1.0, -1.9}, c_t{-9.3, -128.987654}});
68+
69+
// vector<vector<complex<double>>>
70+
using svvc_t = std::vector<svc_t>;
71+
stan::test::expect_ad(f, svvc_t{});
72+
stan::test::expect_ad(f, svvc_t{{}});
73+
stan::test::expect_ad(f, svvc_t{svc_t{c_t{1.2, -2.3}, c_t{-32.8, 1}}});
74+
stan::test::expect_ad(f, svvc_t{svc_t{c_t{1.2, -2.3}, c_t{-32.8, 1}},
75+
svc_t{c_t{9.3, 9.4}, c_t{182, -95}}});
76+
77+
// VectorXd
78+
using v_t = Eigen::VectorXd;
79+
v_t a0(0);
80+
stan::test::expect_ad(f, a0);
81+
v_t a1(1);
82+
a1 << 1.9;
83+
stan::test::expect_ad(f, a1);
84+
v_t a2(2);
85+
a2 << 1.9, -2.3;
86+
stan::test::expect_ad(f, a2);
87+
88+
// RowVectorXd
89+
using rv_t = Eigen::RowVectorXd;
90+
rv_t b0(0);
91+
stan::test::expect_ad(f, b0);
92+
rv_t b1(1);
93+
b1 << 1.9;
94+
stan::test::expect_ad(f, b1);
95+
rv_t b2(2);
96+
b2 << 1.9, -2.3;
97+
stan::test::expect_ad(f, b2);
98+
99+
// MatrixXd
100+
using m_t = Eigen::MatrixXd;
101+
m_t c0(0, 0);
102+
stan::test::expect_ad(f, c0);
103+
m_t c0i(0, 2);
104+
stan::test::expect_ad(f, c0i);
105+
m_t c0ii(2, 0);
106+
stan::test::expect_ad(f, c0ii);
107+
m_t c2(2, 1);
108+
c2 << 1.3, -2.9;
109+
stan::test::expect_ad(f, c2);
110+
m_t c6(3, 2);
111+
c6 << 1.3, 2.9, -13.456, 1.898, -0.01, 1.87e21;
112+
stan::test::expect_ad(f, c6);
113+
114+
// vector<VectorXd>
115+
using av_t = std::vector<Eigen::VectorXd>;
116+
av_t d0;
117+
stan::test::expect_ad(f, d0);
118+
av_t d1{a0};
119+
stan::test::expect_ad(f, d1);
120+
av_t d2{a1, a2};
121+
stan::test::expect_ad(f, d2);
122+
123+
// vector<RowVectorXd>
124+
using arv_t = std::vector<Eigen::RowVectorXd>;
125+
arv_t e0;
126+
stan::test::expect_ad(f, e0);
127+
arv_t e1{b0};
128+
stan::test::expect_ad(f, e1);
129+
arv_t e2{b1, b2};
130+
stan::test::expect_ad(f, e2);
131+
132+
// vector<MatrixXd>
133+
using am_t = std::vector<Eigen::MatrixXd>;
134+
am_t g0;
135+
stan::test::expect_ad(f, g0);
136+
am_t g1{c0};
137+
stan::test::expect_ad(f, g1);
138+
am_t g2{c2, c6};
139+
stan::test::expect_ad(f, g2);
140+
141+
// VectorXcd
142+
using vc_t = Eigen::VectorXcd;
143+
vc_t h0(0);
144+
stan::test::expect_ad(f, h0);
145+
vc_t h1(1);
146+
h1 << c_t{1.9, -1.8};
147+
stan::test::expect_ad(f, h1);
148+
vc_t h2(2);
149+
h2 << c_t{1.9, -1.8}, c_t{-128.7, 1.3};
150+
stan::test::expect_ad(f, h2);
151+
152+
// RowVectorXcd
153+
using rvc_t = Eigen::RowVectorXcd;
154+
rvc_t j0(0);
155+
stan::test::expect_ad(f, j0);
156+
rvc_t j1(1);
157+
j1 << c_t{1.9, -1.8};
158+
stan::test::expect_ad(f, j1);
159+
rvc_t j2(2);
160+
j2 << c_t{1.9, -1.8}, c_t{-128.7, 1.3};
161+
stan::test::expect_ad(f, j2);
162+
163+
// MatrixXcd
164+
using mc_t = Eigen::MatrixXcd;
165+
mc_t k0(0, 0);
166+
stan::test::expect_ad(f, k0);
167+
mc_t k2(1, 2);
168+
k2 << c_t{1.9, -1.8}, c_t{128.735, 128.734};
169+
stan::test::expect_ad(f, k2);
170+
mc_t k6(3, 2);
171+
k6 << c_t{1.9, -1.8}, c_t{-128.7, 1.3}, c_t{1, 2}, c_t{0.3, -0.5},
172+
c_t{-13, 125.7}, c_t{-12.5, -10.5};
173+
stan::test::expect_ad(f, k6);
174+
175+
// vector<VectorXcd>
176+
using avc_t = std::vector<vc_t>;
177+
avc_t m0;
178+
stan::test::expect_ad(f, m0);
179+
avc_t m1{h1};
180+
stan::test::expect_ad(f, m1);
181+
avc_t m2{h1, h2};
182+
stan::test::expect_ad(f, m2);
183+
184+
// vector<RowVectorXcd>
185+
using arvc_t = std::vector<rvc_t>;
186+
arvc_t p0(0);
187+
stan::test::expect_ad(f, p0);
188+
arvc_t p1{j1};
189+
stan::test::expect_ad(f, p1);
190+
arvc_t p2{j1, j2};
191+
stan::test::expect_ad(f, p2);
192+
193+
// vector<MatrixXcd>
194+
using amc_t = std::vector<mc_t>;
195+
amc_t q0;
196+
stan::test::expect_ad(f, q0);
197+
amc_t q1{k2};
198+
stan::test::expect_ad(f, q1);
199+
amc_t q2{k2, k6};
200+
stan::test::expect_ad(f, q2);
46201
}
47202
TEST(mixFun, absReturnType) {
48203
// validate return types not overpromoted to complex by assignability

0 commit comments

Comments
 (0)