Skip to content

Feature/2729 vectorize abs #2734

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions stan/math/fwd/fun/abs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,20 @@
namespace stan {
namespace math {

/**
* Return the absolute value of the forward-mode autodiff argument.
*
* @tparam T value type for autodiff variable
* @param[in] x argument
* @return absolute value of argument
*/
template <typename T>
inline fvar<T> abs(const fvar<T>& x) {
if (x.val_ > 0.0) {
if (x.val_ > 0) {
return x;
} else if (x.val_ < 0.0) {
} else if (x.val_ < 0) {
return fvar<T>(-x.val_, -x.d_);
} else if (x.val_ == 0.0) {
} else if (x.val_ == 0) {
return fvar<T>(0, 0);
} else {
return fvar<T>(abs(x.val_), NOT_A_NUMBER);
Expand Down
153 changes: 68 additions & 85 deletions stan/math/prim/fun/abs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,121 +13,104 @@
namespace stan {
namespace math {

namespace internal {
/**
* Return the absolute value of the specified arithmetic argument.
* The return type is the same as the argument type.
* Return the absolute value of the complex argument.
*
* @tparam T type of argument (must be arithmetic)
* @param x argument
* @return absolute value of argument
* @tparam V value type of argument
* @param[in] z argument
* @return absolute value of the argument
*/
template <typename T, require_arithmetic_t<T>* = nullptr>
T abs(T x) {
return std::abs(x);
template <typename V>
inline V complex_abs(const std::complex<V>& z) {
return hypot(z.real(), z.imag());
}
} // namespace internal

/*
* Return the elementwise absolute value of the specified container.
/**
* Metaprogram calculating return type for applying `abs` to an
* argument of the specified template type. This struct defines the
* default case, which will be used for arithmetic types, with
* specializations dealing with other cases.
*
* @tparam T type of elements in the vector
* @param x vector argument
* @return elementwise absolute value of argument
* @tparam T argument type
*/
template <typename T>
std::vector<T> abs(const std::vector<T>& x) {
std::vector<T> y(x.size());
for (size_t n = 0; n < x.size(); ++n)
y[n] = abs(x[n]);
return y;
}
struct abs_return {
/**
* Return type of `abs(T)`.
*/
using type = T;
};

/**
* Return the elementwise absolute value of the specified matrix,
* vector, or row vector.
* Helper typedef for abs_return. With this definition,
* `typename abs_return<T>::type` can be abbreviated to `abs_return_t<T>`.
*
* @tparam T type of scalar for matrix argument (real or complex)
* @tparam R row specification (1 or -1)
* @tparam C column specification (1 or -1)
* @param x argument
* @return elementwise absolute value of argument
* @tparam T type of argument
*/
template <typename T>
using abs_return_t = typename abs_return<T>::type;

template <typename T>
struct abs_return<std::complex<T>> {
using type = T;
};

template <typename T, int R, int C>
Eigen::Matrix<T, R, C> abs(const Eigen::Matrix<T, R, C>& x) {
return fabs(x);
}
struct abs_return<Eigen::Matrix<T, R, C>> {
using type = Eigen::Matrix<abs_return_t<T>, R, C>;
};

/**
* Return the absolute value (also known as the norm, modulus, or
* magnitude) of the specified complex argument.
*
* @tparam T type of argument (must be complex)
* @param x argument
* @return absolute value of argument (a real number)
*/
template <typename T, require_complex_t<T>* = nullptr>
auto abs(T x) {
return hypot(x.real(), x.imag());
}
template <typename T>
struct abs_return<std::vector<T>> {
using type = std::vector<abs_return_t<T>>;
};

/**
* Return elementwise absolute value of the specified real-valued
* container.
* Return the absolute value of the specified arithmetic argument.
*
* @tparam T type of argument
* @tparam T type of argument (must be arithmetic)
* @param x argument
* @return absolute value of argument
*/
struct abs_fun {
template <typename T>
static inline T fun(const T& x) {
return fabs(x);
}
};
template <typename T>
inline auto abs(T x) {
return std::abs(x);
}

/**
* Returns the elementwise `abs()` of the input,
* which may be a scalar or any Stan container of numeric scalars.
* Return the elementwise absolute value of the specified matrix or vector
* argument.
*
* @tparam Container type of container
* @tparam T type of matrix elements
* @param x argument
* @return Absolute value of each variable in the container.
* @return elementwise absolute value of argument
*/
// template <typename Container,
// require_not_container_st<std::is_arithmetic, Container>* = nullptr,
// require_not_var_matrix_t<Container>* = nullptr,
// require_not_stan_scalar_t<Container>* = nullptr>
// inline auto abs(const Container& x) {
// return apply_scalar_unary<abs_fun, Container>::apply(x);
// }

// /**
// * Version of `abs()` that accepts std::vectors, Eigen Matrix/Array objects
// * or expressions, and containers of these.
// *
// * @tparam Container Type of x
// * @param x argument
// * @return Absolute value of each variable in the container.
// */
// template <typename Container,
// require_container_st<std::is_arithmetic, Container>* = nullptr>
// inline auto abs(const Container& x) {
// return apply_vector_unary<Container>::apply(
// x, [&](const auto& v) { return v.array().abs(); });
// }
template <typename T, int R, int C>
inline auto abs(const Eigen::Matrix<T, R, C>& x) {
Eigen::Matrix<abs_return_t<T>, R, C> y(x.rows(), x.cols());
for (int i = 0; i < x.size(); ++i)
y(i) = abs(x(i));
return y;
}

namespace internal {
/**
* Return the absolute value of the complex argument.
* Return the elementwise absolute value of the specified standard vector
* argument.
*
* @tparam V value type of argument
* @param[in] z argument
* @return absolute value of the argument
* @tparam T type of vector elements
* @param x argument
* @return elementwise absolute value of argument
*/
template <typename V>
inline V complex_abs(const std::complex<V>& z) {
return hypot(z.real(), z.imag());
template <typename T>
inline auto abs(const std::vector<T>& x) {
std::vector<abs_return_t<T>> y;
y.reserve(x.size());
for (const auto& xi : x)
y.push_back(abs(xi));
return y;
}
} // namespace internal

} // namespace math
} // namespace stan
Expand Down
34 changes: 7 additions & 27 deletions stan/math/rev/fun/abs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,15 @@ namespace stan {
namespace math {

/**
* Return the absolute value of the variable (std).
* Return the absolute value of the argument.
*
* Delegates to <code>fabs()</code> (see for doc).
*
\f[
\mbox{abs}(x) =
\begin{cases}
|x| & \mbox{if } -\infty\leq x\leq \infty \\[6pt]
\textrm{NaN} & \mbox{if } x = \textrm{NaN}
\end{cases}
\f]

\f[
\frac{\partial\, \mbox{abs}(x)}{\partial x} =
\begin{cases}
-1 & \mbox{if } x < 0 \\
0 & \mbox{if } x = 0 \\
1 & \mbox{if } x > 0 \\[6pt]
\textrm{NaN} & \mbox{if } x = \textrm{NaN}
\end{cases}
\f]
*
* @tparam T A floating point type or an Eigen type with floating point scalar.
* @param a Variable input.
* @return Absolute value of variable.
* @tparam T floating point or var_mat type
* @param[in] x argument
* @return absolute value of argument
*/
template <typename T>
inline auto abs(const var_value<T>& a) {
return fabs(a);
inline auto abs(const var_value<T>& x) {
return fabs(x);
}

/**
Expand All @@ -48,7 +28,7 @@ inline auto abs(const var_value<T>& a) {
* @param[in] z argument
* @return absolute value of the argument
*/
inline var abs(const std::complex<var>& z) { return internal::complex_abs(z); }
inline auto abs(const std::complex<var>& z) { return internal::complex_abs(z); }

} // namespace math
} // namespace stan
Expand Down
Loading