Skip to content

Commit a87459f

Browse files
authored
Merge pull request #2256 from stan-dev/feature/varmat-a-to-c-unary
var matrix specializations for a-c unary functions in rev
2 parents 3d03131 + c5caa59 commit a87459f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+739
-207
lines changed

stan/math/prim/fun/abs.hpp

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,57 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/fun/hypot.hpp>
6+
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
7+
#include <stan/math/prim/functor/apply_vector_unary.hpp>
68
#include <cmath>
79

810
namespace stan {
911
namespace math {
1012

1113
/**
12-
* Return floating-point absolute value.
14+
* Structure to wrap `abs()` so it can be vectorized.
1315
*
14-
* Delegates to <code>fabs(double)</code> rather than
15-
* <code>std::abs(int)</code>.
16+
* @tparam T type of variable
17+
* @param x argument
18+
* @return Absolute value of variable.
19+
*/
20+
struct abs_fun {
21+
template <typename T>
22+
static inline T fun(const T& x) {
23+
using std::fabs;
24+
return fabs(x);
25+
}
26+
};
27+
28+
/**
29+
* Returns the elementwise `abs()` of the input,
30+
* which may be a scalar or any Stan container of numeric scalars.
1631
*
17-
* @param x scalar
18-
* @return absolute value of scalar
32+
* @tparam Container type of container
33+
* @param x argument
34+
* @return Absolute value of each variable in the container.
1935
*/
20-
inline double abs(double x) { return std::fabs(x); }
36+
template <typename Container,
37+
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
38+
require_not_var_matrix_t<Container>* = nullptr>
39+
inline auto abs(const Container& x) {
40+
return apply_scalar_unary<abs_fun, Container>::apply(x);
41+
}
42+
43+
/**
44+
* Version of `abs()` that accepts std::vectors, Eigen Matrix/Array objects
45+
* or expressions, and containers of these.
46+
*
47+
* @tparam Container Type of x
48+
* @param x argument
49+
* @return Absolute value of each variable in the container.
50+
*/
51+
template <typename Container,
52+
require_container_st<std::is_arithmetic, Container>* = nullptr>
53+
inline auto abs(const Container& x) {
54+
return apply_vector_unary<Container>::apply(
55+
x, [](const auto& v) { return v.array().abs(); });
56+
}
2157

2258
namespace internal {
2359
/**

stan/math/prim/fun/acos.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ struct acos_fun {
4343
*/
4444
template <typename Container,
4545
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
46+
require_not_var_matrix_t<Container>* = nullptr,
4647
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
4748
Container>* = nullptr>
4849
inline auto acos(const Container& x) {

stan/math/prim/fun/acosh.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,9 @@ struct acosh_fun {
8383
* @param x container
8484
* @return Elementwise acosh of members of container.
8585
*/
86-
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
87-
T>* = nullptr>
86+
template <
87+
typename T, require_not_var_matrix_t<T>* = nullptr,
88+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
8889
inline auto acosh(const T& x) {
8990
return apply_scalar_unary<acosh_fun, T>::apply(x);
9091
}

stan/math/prim/fun/asin.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ struct asin_fun {
4141
*/
4242
template <typename Container,
4343
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
44+
require_not_var_matrix_t<Container>* = nullptr,
4445
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
4546
Container>* = nullptr>
4647
inline auto asin(const Container& x) {

stan/math/prim/fun/asinh.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ struct asinh_fun {
4343
* @param x container
4444
* @return Inverse hyperbolic sine of each value in the container.
4545
*/
46-
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
47-
T>* = nullptr>
46+
template <
47+
typename T, require_not_var_matrix_t<T>* = nullptr,
48+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
4849
inline auto asinh(const T& x) {
4950
return apply_scalar_unary<asinh_fun, T>::apply(x);
5051
}

stan/math/prim/fun/atan.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ struct atan_fun {
3939
*/
4040
template <typename Container,
4141
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
42+
require_not_var_matrix_t<Container>* = nullptr,
4243
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
4344
Container>* = nullptr>
4445
inline auto atan(const Container& x) {

stan/math/prim/fun/atanh.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,9 @@ struct atanh_fun {
7272
* @param x container
7373
* @return Elementwise atanh of members of container.
7474
*/
75-
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
76-
T>* = nullptr>
75+
template <
76+
typename T, require_not_var_matrix_t<T>* = nullptr,
77+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
7778
inline auto atanh(const T& x) {
7879
return apply_scalar_unary<atanh_fun, T>::apply(x);
7980
}

stan/math/prim/fun/block.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace math {
1818
* @param ncols Number of columns in block.
1919
* @throw std::out_of_range if either index is out of range.
2020
*/
21-
template <typename T, require_eigen_t<T>* = nullptr>
21+
template <typename T, require_matrix_t<T>* = nullptr>
2222
inline auto block(const T& m, size_t i, size_t j, size_t nrows, size_t ncols) {
2323
check_row_index("block", "i", m, i);
2424
check_row_index("block", "i+nrows-1", m, i + nrows - 1);

stan/math/prim/fun/cbrt.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ struct cbrt_fun {
3131
* @param x container
3232
* @return Cube root of each value in x.
3333
*/
34-
template <typename T>
34+
template <typename T, require_not_var_matrix_t<T>* = nullptr>
3535
inline auto cbrt(const T& x) {
3636
return apply_scalar_unary<cbrt_fun, T>::apply(x);
3737
}

stan/math/prim/fun/cos.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ struct cos_fun {
3838
*/
3939
template <typename Container,
4040
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
41+
require_not_var_matrix_t<Container>* = nullptr,
4142
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
4243
Container>* = nullptr>
4344
inline auto cos(const Container& x) {

stan/math/prim/fun/cosh.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ struct cosh_fun {
3737
*/
3838
template <typename Container,
3939
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
40+
require_not_var_matrix_t<Container>* = nullptr,
4041
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
4142
Container>* = nullptr>
4243
inline auto cosh(const Container& x) {

stan/math/prim/fun/fabs.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ struct fabs_fun {
3434
* @return Absolute value of each value in x.
3535
*/
3636
template <typename Container,
37-
require_not_container_st<std::is_arithmetic, Container>* = nullptr>
37+
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
38+
require_not_var_matrix_t<Container>* = nullptr>
3839
inline auto fabs(const Container& x) {
3940
return apply_scalar_unary<fabs_fun, Container>::apply(x);
4041
}

stan/math/prim/fun/sin.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ struct sin_fun {
3737
*/
3838
template <
3939
typename T, require_not_container_st<std::is_arithmetic, T>* = nullptr,
40+
require_not_var_matrix_t<T>* = nullptr,
4041
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
4142
inline auto sin(const T& x) {
4243
return apply_scalar_unary<sin_fun, T>::apply(x);

stan/math/prim/fun/sinh.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ struct sinh_fun {
3535
*/
3636
template <typename Container,
3737
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
38+
require_not_var_matrix_t<Container>* = nullptr,
3839
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
3940
Container>* = nullptr>
4041
inline auto sinh(const Container& x) {

stan/math/prim/fun/tan.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ struct tan_fun {
3737
*/
3838
template <typename Container,
3939
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
40+
require_not_var_matrix_t<Container>* = nullptr,
4041
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
4142
Container>* = nullptr>
4243
inline auto tan(const Container& x) {

stan/math/prim/fun/tanh.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ struct tanh_fun {
3737
*/
3838
template <typename Container,
3939
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
40+
require_not_var_matrix_t<Container>* = nullptr,
4041
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
4142
Container>* = nullptr>
4243
inline auto tanh(const Container& x) {

stan/math/prim/functor/apply_vector_unary.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,12 @@ struct apply_vector_unary<T, require_std_vector_vt<is_stan_scalar, T>> {
159159
}
160160
};
161161

162+
namespace internal {
163+
template <typename T>
164+
using is_container_or_var_matrix
165+
= disjunction<is_container<T>, is_var_matrix<T>>;
166+
}
167+
162168
/**
163169
* Specialisation for use with nested containers (std::vectors).
164170
* For each of the member functions, an std::vector with the appropriate
@@ -170,7 +176,8 @@ struct apply_vector_unary<T, require_std_vector_vt<is_stan_scalar, T>> {
170176
*
171177
*/
172178
template <typename T>
173-
struct apply_vector_unary<T, require_std_vector_vt<is_container, T>> {
179+
struct apply_vector_unary<
180+
T, require_std_vector_vt<internal::is_container_or_var_matrix, T>> {
174181
using T_vt = value_type_t<T>;
175182

176183
/**

stan/math/prim/meta/is_container.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ namespace stan {
2020
*/
2121
template <typename Container>
2222
using is_container = bool_constant<
23-
math::disjunction<is_eigen<Container>, is_std_vector<Container>,
24-
is_var_matrix<Container>>::value>;
23+
math::disjunction<is_eigen<Container>, is_std_vector<Container>>::value>;
2524

2625
STAN_ADD_REQUIRE_UNARY(container, is_container, general_types);
2726
STAN_ADD_REQUIRE_CONTAINER(container, is_container, general_types);

stan/math/rev/fun/abs.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,14 @@ namespace math {
3333
\end{cases}
3434
\f]
3535
*
36+
* @tparam T A floating point type or an Eigen type with floating point scalar.
3637
* @param a Variable input.
3738
* @return Absolute value of variable.
3839
*/
39-
inline var abs(const var& a) { return fabs(a); }
40+
template <typename T>
41+
inline auto abs(const var_value<T>& a) {
42+
return fabs(a);
43+
}
4044

4145
/**
4246
* Return the absolute value of the complex argument.

stan/math/rev/fun/acos.hpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,6 @@
1616
namespace stan {
1717
namespace math {
1818

19-
namespace internal {
20-
class acos_vari : public op_v_vari {
21-
public:
22-
explicit acos_vari(vari* avi) : op_v_vari(std::acos(avi->val_), avi) {}
23-
void chain() {
24-
avi_->adj_ -= adj_ / std::sqrt(1.0 - (avi_->val_ * avi_->val_));
25-
}
26-
};
27-
} // namespace internal
28-
2919
/**
3020
* Return the principal value of the arc cosine of a variable,
3121
* in radians (cmath).
@@ -62,7 +52,27 @@ class acos_vari : public op_v_vari {
6252
* @param x argument
6353
* @return Arc cosine of variable, in radians.
6454
*/
65-
inline var acos(const var& x) { return var(new internal::acos_vari(x.vi_)); }
55+
inline var acos(const var& x) {
56+
return make_callback_var(std::acos(x.val()), [x](const auto& vi) mutable {
57+
x.adj() -= vi.adj() / std::sqrt(1.0 - (x.val() * x.val()));
58+
});
59+
}
60+
61+
/**
62+
* Return the principal value of the arc cosine of a variable,
63+
* in radians (cmath).
64+
*
65+
* @param x a `var_value` with inner Eigen type
66+
* @return Arc cosine of variable, in radians.
67+
*/
68+
template <typename VarMat, require_var_matrix_t<VarMat>* = nullptr>
69+
inline auto acos(const VarMat& x) {
70+
return make_callback_var(
71+
x.val().array().acos().matrix(), [x](const auto& vi) mutable {
72+
x.adj().array()
73+
-= vi.adj().array() / (1.0 - (x.val().array().square())).sqrt();
74+
});
75+
}
6676

6777
/**
6878
* Return the arc cosine of the complex argument.

stan/math/rev/fun/acosh.hpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,6 @@
2121
namespace stan {
2222
namespace math {
2323

24-
namespace internal {
25-
class acosh_vari : public op_v_vari {
26-
public:
27-
acosh_vari(double val, vari* avi) : op_v_vari(val, avi) {}
28-
void chain() {
29-
avi_->adj_ += adj_ / std::sqrt(avi_->val_ * avi_->val_ - 1.0);
30-
}
31-
};
32-
} // namespace internal
33-
3424
/**
3525
* The inverse hyperbolic cosine function for variables (C99).
3626
*
@@ -67,11 +57,31 @@ class acosh_vari : public op_v_vari {
6757
\frac{\partial \, \cosh^{-1}(x)}{\partial x} = \frac{1}{\sqrt{x^2-1}}
6858
\f]
6959
*
70-
* @param a The variable.
60+
* @param x The variable.
61+
* @return Inverse hyperbolic cosine of the variable.
62+
*/
63+
inline var acosh(const var& x) {
64+
return make_callback_var(acosh(x.val()), [x](const auto& vi) mutable {
65+
x.adj() += vi.adj() / std::sqrt(x.val() * x.val() - 1.0);
66+
});
67+
}
68+
/**
69+
* The inverse hyperbolic cosine function for variables (C99).
70+
*
71+
* For non-variable function, see ::acosh().
72+
*
73+
* @tparam Varmat a `var_value` with inner Eigen type
74+
* @param x The variable
7175
* @return Inverse hyperbolic cosine of the variable.
7276
*/
73-
inline var acosh(const var& a) {
74-
return var(new internal::acosh_vari(acosh(a.val()), a.vi_));
77+
template <typename VarMat, require_var_matrix_t<VarMat>* = nullptr>
78+
inline auto acosh(const VarMat& x) {
79+
return make_callback_var(
80+
x.val().unaryExpr([](const auto x) { return acosh(x); }),
81+
[x](const auto& vi) mutable {
82+
x.adj().array()
83+
+= vi.adj().array() / (x.val().array().square() - 1.0).sqrt();
84+
});
7585
}
7686

7787
/**

stan/math/rev/fun/asin.hpp

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,6 @@
1414
namespace stan {
1515
namespace math {
1616

17-
namespace internal {
18-
class asin_vari : public op_v_vari {
19-
public:
20-
explicit asin_vari(vari* avi) : op_v_vari(std::asin(avi->val_), avi) {}
21-
void chain() {
22-
avi_->adj_ += adj_ / std::sqrt(1.0 - (avi_->val_ * avi_->val_));
23-
}
24-
};
25-
} // namespace internal
26-
2717
/**
2818
* Return the principal value of the arc sine, in radians, of the
2919
* specified variable (cmath).
@@ -57,10 +47,31 @@ class asin_vari : public op_v_vari {
5747
\frac{\partial \, \arcsin(x)}{\partial x} = \frac{1}{\sqrt{1-x^2}}
5848
\f]
5949
*
60-
* @param a Variable in range [-1, 1].
50+
* @param x Variable in range [-1, 1].
51+
* @return Arc sine of variable, in radians.
52+
*/
53+
inline var asin(const var& x) {
54+
return make_callback_var(std::asin(x.val()), [x](const auto& vi) mutable {
55+
x.adj() += vi.adj() / std::sqrt(1.0 - (x.val() * x.val()));
56+
});
57+
}
58+
59+
/**
60+
* Return the principal value of the arc sine, in radians, of the
61+
* specified variable (cmath).
62+
*
63+
* @tparam Varmat a `var_value` with inner Eigen type
64+
* @param x Variable with cells in range [-1, 1].
6165
* @return Arc sine of variable, in radians.
6266
*/
63-
inline var asin(const var& a) { return var(new internal::asin_vari(a.vi_)); }
67+
template <typename VarMat, require_var_matrix_t<VarMat>* = nullptr>
68+
inline auto asin(const VarMat& x) {
69+
return make_callback_var(
70+
x.val().array().asin().matrix(), [x](const auto& vi) mutable {
71+
x.adj().array()
72+
+= vi.adj().array() / (1.0 - (x.val().array().square())).sqrt();
73+
});
74+
}
6475

6576
/**
6677
* Return the arc sine of the complex argument.

0 commit comments

Comments
 (0)