Skip to content

Commit a0511be

Browse files
authored
Merge pull request #2737 from stan-dev/feature/complex-abs-vectorized
allows complex abs() to be vectorized
2 parents 8f73b3c + 1c28dd7 commit a0511be

File tree

2 files changed

+201
-59
lines changed

2 files changed

+201
-59
lines changed

stan/math/prim/fun/abs.hpp

Lines changed: 24 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -22,40 +22,10 @@ namespace math {
2222
* @return absolute value of argument
2323
*/
2424
template <typename T, require_arithmetic_t<T>* = nullptr>
25-
T abs(T x) {
25+
inline T abs(T x) {
2626
return std::abs(x);
2727
}
2828

29-
/*
30-
* Return the elementwise absolute value of the specified container.
31-
*
32-
* @tparam T type of elements in the vector
33-
* @param x vector argument
34-
* @return elementwise absolute value of argument
35-
*/
36-
template <typename T>
37-
std::vector<T> abs(const std::vector<T>& x) {
38-
std::vector<T> y(x.size());
39-
for (size_t n = 0; n < x.size(); ++n)
40-
y[n] = abs(x[n]);
41-
return y;
42-
}
43-
44-
/**
45-
* Return the elementwise absolute value of the specified matrix,
46-
* vector, or row vector.
47-
*
48-
* @tparam T type of scalar for matrix argument (real or complex)
49-
* @tparam R row specification (1 or -1)
50-
* @tparam C column specification (1 or -1)
51-
* @param x argument
52-
* @return elementwise absolute value of argument
53-
*/
54-
template <typename T, int R, int C>
55-
Eigen::Matrix<T, R, C> abs(const Eigen::Matrix<T, R, C>& x) {
56-
return fabs(x);
57-
}
58-
5929
/**
6030
* Return the absolute value (also known as the norm, modulus, or
6131
* magnitude) of the specified complex argument.
@@ -65,7 +35,7 @@ Eigen::Matrix<T, R, C> abs(const Eigen::Matrix<T, R, C>& x) {
6535
* @return absolute value of argument (a real number)
6636
*/
6737
template <typename T, require_complex_t<T>* = nullptr>
68-
auto abs(T x) {
38+
inline auto abs(T x) {
6939
return hypot(x.real(), x.imag());
7040
}
7141

@@ -80,7 +50,7 @@ auto abs(T x) {
8050
struct abs_fun {
8151
template <typename T>
8252
static inline T fun(const T& x) {
83-
return fabs(x);
53+
return abs(x);
8454
}
8555
};
8656

@@ -92,28 +62,28 @@ struct abs_fun {
9262
* @param x argument
9363
* @return Absolute value of each variable in the container.
9464
*/
95-
// template <typename Container,
96-
// require_not_container_st<std::is_arithmetic, Container>* = nullptr,
97-
// require_not_var_matrix_t<Container>* = nullptr,
98-
// require_not_stan_scalar_t<Container>* = nullptr>
99-
// inline auto abs(const Container& x) {
100-
// return apply_scalar_unary<abs_fun, Container>::apply(x);
101-
// }
65+
template <typename Container,
66+
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
67+
require_not_var_matrix_t<Container>* = nullptr,
68+
require_not_stan_scalar_t<Container>* = nullptr>
69+
inline auto abs(const Container& x) {
70+
return apply_scalar_unary<abs_fun, Container>::apply(x);
71+
}
10272

103-
// /**
104-
// * Version of `abs()` that accepts std::vectors, Eigen Matrix/Array objects
105-
// * or expressions, and containers of these.
106-
// *
107-
// * @tparam Container Type of x
108-
// * @param x argument
109-
// * @return Absolute value of each variable in the container.
110-
// */
111-
// template <typename Container,
112-
// require_container_st<std::is_arithmetic, Container>* = nullptr>
113-
// inline auto abs(const Container& x) {
114-
// return apply_vector_unary<Container>::apply(
115-
// x, [&](const auto& v) { return v.array().abs(); });
116-
// }
73+
/**
74+
* Version of `abs()` that accepts std::vectors, Eigen Matrix/Array objects
75+
* or expressions, and containers of these.
76+
*
77+
* @tparam Container Type of x
78+
* @param x argument
79+
* @return Absolute value of each variable in the container.
80+
*/
81+
template <typename Container,
82+
require_container_st<std::is_arithmetic, Container>* = nullptr>
83+
inline auto abs(const Container& x) {
84+
return apply_vector_unary<Container>::apply(
85+
x, [&](const auto& v) { return v.array().abs(); });
86+
}
11787

11888
namespace internal {
11989
/**

test/unit/math/mix/fun/abs_test.cpp

Lines changed: 177 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@ TEST(mixFun, absBasics) {
1919
}
2020

2121
TEST(mixFun, abs) {
22-
auto f = [](const auto& x) {
23-
using std::abs;
24-
return abs(x);
25-
};
22+
auto f = [](const auto& x) { return stan::math::abs(x); };
2623
stan::test::expect_common_nonzero_unary(f);
24+
// 0 (no derivative at 0)
2725
stan::test::expect_value(f, 0);
2826
stan::test::expect_value(f, 0.0);
2927

@@ -37,12 +35,186 @@ TEST(mixFun, abs) {
3735
stan::test::expect_ad(f, 2.0);
3836
stan::test::expect_ad(f, 4.0);
3937

40-
// not differentiable at zero
38+
// complex tests
4139
for (double re : std::vector<double>{-4, -2.5, -1.5, -0.3, 1.3, 2.1, 3.9}) {
4240
for (double im : std::vector<double>{-4, -2.5, -1.5, -0.3, 1.3, 2.1, 3.9}) {
4341
stan::test::expect_ad(f, std::complex<double>(re, im));
4442
}
4543
}
44+
45+
// vector<double>
46+
using svd_t = std::vector<double>;
47+
stan::test::expect_ad(f, svd_t{});
48+
stan::test::expect_ad(f, svd_t{1.0});
49+
stan::test::expect_ad(f, svd_t{1.9, -2.3});
50+
51+
// vector<vector<double>>
52+
using svvd_t = std::vector<svd_t>;
53+
stan::test::expect_ad(f, svvd_t{});
54+
stan::test::expect_ad(f, svvd_t{svd_t{}});
55+
stan::test::expect_ad(f, svvd_t{svd_t{1.9, 4.8}});
56+
stan::test::expect_ad(f, svvd_t{svd_t{1.9}, svd_t{-13.987}});
57+
stan::test::expect_ad(f, svvd_t{svd_t{1.9, -2.7}, svd_t{-13.987, 8.8}});
58+
59+
// vector<complex<double>>
60+
using c_t = std::complex<double>;
61+
using svc_t = std::vector<c_t>;
62+
stan::test::expect_ad(f, svc_t{});
63+
stan::test::expect_ad(f, svc_t{c_t{1.0, -1.9}});
64+
stan::test::expect_ad(f, svc_t{c_t{1.0, -1.9}, c_t{-9.3, -128.987654}});
65+
66+
// vector<vector<complex<double>>>
67+
using svvc_t = std::vector<svc_t>;
68+
stan::test::expect_ad(f, svvc_t{});
69+
stan::test::expect_ad(f, svvc_t{{}});
70+
stan::test::expect_ad(f, svvc_t{svc_t{c_t{1.2, -2.3}, c_t{-32.8, 1}}});
71+
stan::test::expect_ad(f, svvc_t{svc_t{c_t{1.2, -2.3}, c_t{-32.8, 1}},
72+
svc_t{c_t{9.3, 9.4}, c_t{182, -95}}});
73+
74+
// VectorXd
75+
using v_t = Eigen::VectorXd;
76+
v_t a0(0);
77+
stan::test::expect_ad(f, a0);
78+
stan::test::expect_ad_matvar(f, a0);
79+
v_t a1(1);
80+
a1 << 1.9;
81+
stan::test::expect_ad(f, a1);
82+
stan::test::expect_ad_matvar(f, a1);
83+
v_t a2(2);
84+
a2 << 1.9, -2.3;
85+
stan::test::expect_ad(f, a2);
86+
stan::test::expect_ad_matvar(f, a2);
87+
88+
// RowVectorXd
89+
using rv_t = Eigen::RowVectorXd;
90+
rv_t b0(0);
91+
stan::test::expect_ad(f, b0);
92+
stan::test::expect_ad_matvar(f, b0);
93+
rv_t b1(1);
94+
b1 << 1.9;
95+
stan::test::expect_ad(f, b1);
96+
stan::test::expect_ad_matvar(f, b1);
97+
rv_t b2(2);
98+
b2 << 1.9, -2.3;
99+
stan::test::expect_ad(f, b2);
100+
stan::test::expect_ad_matvar(f, b2);
101+
102+
// MatrixXd
103+
using m_t = Eigen::MatrixXd;
104+
m_t c0(0, 0);
105+
stan::test::expect_ad(f, c0);
106+
stan::test::expect_ad_matvar(f, c0);
107+
m_t c0i(0, 2);
108+
stan::test::expect_ad(f, c0i);
109+
stan::test::expect_ad_matvar(f, c0i);
110+
m_t c0ii(2, 0);
111+
stan::test::expect_ad(f, c0ii);
112+
stan::test::expect_ad_matvar(f, c0ii);
113+
m_t c2(2, 1);
114+
c2 << 1.3, -2.9;
115+
stan::test::expect_ad(f, c2);
116+
stan::test::expect_ad_matvar(f, c2);
117+
m_t c6(3, 2);
118+
c6 << 1.3, 2.9, -13.456, 1.898, -0.01, 1.87e21;
119+
stan::test::expect_ad(f, c6);
120+
stan::test::expect_ad_matvar(f, c6);
121+
122+
// vector<VectorXd>
123+
using av_t = std::vector<Eigen::VectorXd>;
124+
av_t d0;
125+
stan::test::expect_ad(f, d0);
126+
stan::test::expect_ad_matvar(f, d0);
127+
av_t d1{a0};
128+
stan::test::expect_ad(f, d1);
129+
stan::test::expect_ad_matvar(f, d1);
130+
av_t d2{a1, a2};
131+
stan::test::expect_ad(f, d2);
132+
stan::test::expect_ad_matvar(f, d2);
133+
134+
// vector<RowVectorXd>
135+
using arv_t = std::vector<Eigen::RowVectorXd>;
136+
arv_t e0;
137+
stan::test::expect_ad(f, e0);
138+
stan::test::expect_ad_matvar(f, e0);
139+
arv_t e1{b0};
140+
stan::test::expect_ad(f, e1);
141+
stan::test::expect_ad_matvar(f, e1);
142+
arv_t e2{b1, b2};
143+
stan::test::expect_ad(f, e2);
144+
stan::test::expect_ad_matvar(f, e2);
145+
146+
// vector<MatrixXd>
147+
using am_t = std::vector<Eigen::MatrixXd>;
148+
am_t g0;
149+
stan::test::expect_ad(f, g0);
150+
stan::test::expect_ad_matvar(f, g0);
151+
am_t g1{c0};
152+
stan::test::expect_ad(f, g1);
153+
stan::test::expect_ad_matvar(f, g1);
154+
am_t g2{c2, c6};
155+
stan::test::expect_ad(f, g2);
156+
stan::test::expect_ad_matvar(f, g2);
157+
158+
// VectorXcd
159+
using vc_t = Eigen::VectorXcd;
160+
vc_t h0(0);
161+
stan::test::expect_ad(f, h0);
162+
vc_t h1(1);
163+
h1 << c_t{1.9, -1.8};
164+
stan::test::expect_ad(f, h1);
165+
vc_t h2(2);
166+
h2 << c_t{1.9, -1.8}, c_t{-128.7, 1.3};
167+
stan::test::expect_ad(f, h2);
168+
169+
// RowVectorXcd
170+
using rvc_t = Eigen::RowVectorXcd;
171+
rvc_t j0(0);
172+
stan::test::expect_ad(f, j0);
173+
rvc_t j1(1);
174+
j1 << c_t{1.9, -1.8};
175+
stan::test::expect_ad(f, j1);
176+
rvc_t j2(2);
177+
j2 << c_t{1.9, -1.8}, c_t{-128.7, 1.3};
178+
stan::test::expect_ad(f, j2);
179+
180+
// MatrixXcd
181+
using mc_t = Eigen::MatrixXcd;
182+
mc_t k0(0, 0);
183+
stan::test::expect_ad(f, k0);
184+
mc_t k2(1, 2);
185+
k2 << c_t{1.9, -1.8}, c_t{128.735, 128.734};
186+
stan::test::expect_ad(f, k2);
187+
mc_t k6(3, 2);
188+
k6 << c_t{1.9, -1.8}, c_t{-128.7, 1.3}, c_t{1, 2}, c_t{0.3, -0.5},
189+
c_t{-13, 125.7}, c_t{-12.5, -10.5};
190+
stan::test::expect_ad(f, k6);
191+
192+
// vector<VectorXcd>
193+
using avc_t = std::vector<vc_t>;
194+
avc_t m0;
195+
stan::test::expect_ad(f, m0);
196+
avc_t m1{h1};
197+
stan::test::expect_ad(f, m1);
198+
avc_t m2{h1, h2};
199+
stan::test::expect_ad(f, m2);
200+
201+
// vector<RowVectorXcd>
202+
using arvc_t = std::vector<rvc_t>;
203+
arvc_t p0(0);
204+
stan::test::expect_ad(f, p0);
205+
arvc_t p1{j1};
206+
stan::test::expect_ad(f, p1);
207+
arvc_t p2{j1, j2};
208+
stan::test::expect_ad(f, p2);
209+
210+
// vector<MatrixXcd>
211+
using amc_t = std::vector<mc_t>;
212+
amc_t q0;
213+
stan::test::expect_ad(f, q0);
214+
amc_t q1{k2};
215+
stan::test::expect_ad(f, q1);
216+
amc_t q2{k2, k6};
217+
stan::test::expect_ad(f, q2);
46218
}
47219
TEST(mixFun, absReturnType) {
48220
// validate return types not overpromoted to complex by assignability

0 commit comments

Comments
 (0)