|
| 1 | +#ifndef STAN_MATH_FWD_FUN_HYPERGEOMETRIC_2F1_HPP |
| 2 | +#define STAN_MATH_FWD_FUN_HYPERGEOMETRIC_2F1_HPP |
| 3 | + |
| 4 | +#include <stan/math/prim/meta.hpp> |
| 5 | +#include <stan/math/fwd/core.hpp> |
| 6 | +#include <stan/math/prim/fun/hypergeometric_2F1.hpp> |
| 7 | +#include <stan/math/prim/fun/grad_2F1.hpp> |
| 8 | + |
| 9 | +namespace stan { |
| 10 | +namespace math { |
| 11 | + |
| 12 | +/** |
| 13 | + * Returns the Gauss hypergeometric function applied to the |
| 14 | + * input arguments: |
| 15 | + * \f$_2F_1(a_1,a_2;b;z)\f$ |
| 16 | + * |
| 17 | + * See 'grad_2F1.hpp' for the derivatives wrt each parameter |
| 18 | + * |
| 19 | + * @tparam Ta1 Type of scalar first 'a' argument |
| 20 | + * @tparam Ta2 Type of scalar second 'a' argument |
| 21 | + * @tparam Tb Type of scalar 'b' argument |
| 22 | + * @tparam Tz Type of scalar 'z' argument |
| 23 | + * @param[in] a1 First of 'a' arguments to function |
| 24 | + * @param[in] a2 Second of 'a' arguments to function |
| 25 | + * @param[in] b 'b' argument to function |
| 26 | + * @param[in] z Scalar z argument |
| 27 | + * @return Gauss hypergeometric function |
| 28 | + */ |
| 29 | +template <typename Ta1, typename Ta2, typename Tb, typename Tz, |
| 30 | + require_all_stan_scalar_t<Ta1, Ta2, Tb, Tz>* = nullptr, |
| 31 | + require_any_fvar_t<Ta1, Ta2, Tb, Tz>* = nullptr> |
| 32 | +inline return_type_t<Ta1, Ta1, Tb, Tz> hypergeometric_2F1(const Ta1& a1, |
| 33 | + const Ta2& a2, |
| 34 | + const Tb& b, |
| 35 | + const Tz& z) { |
| 36 | + using fvar_t = return_type_t<Ta1, Ta1, Tb, Tz>; |
| 37 | + |
| 38 | + auto a1_val = value_of(a1); |
| 39 | + auto a2_val = value_of(a2); |
| 40 | + auto b_val = value_of(b); |
| 41 | + auto z_val = value_of(z); |
| 42 | + |
| 43 | + auto grad_tuple = grad_2F1(a1, a2, b, z); |
| 44 | + |
| 45 | + typename fvar_t::Scalar grad = 0; |
| 46 | + |
| 47 | + if (!is_constant<Ta1>::value) { |
| 48 | + grad += forward_as<fvar_t>(a1).d() * std::get<0>(grad_tuple); |
| 49 | + } |
| 50 | + if (!is_constant<Ta2>::value) { |
| 51 | + grad += forward_as<fvar_t>(a2).d() * std::get<1>(grad_tuple); |
| 52 | + } |
| 53 | + if (!is_constant<Tb>::value) { |
| 54 | + grad += forward_as<fvar_t>(b).d() * std::get<2>(grad_tuple); |
| 55 | + } |
| 56 | + if (!is_constant<Tz>::value) { |
| 57 | + grad += forward_as<fvar_t>(z).d() * std::get<3>(grad_tuple); |
| 58 | + } |
| 59 | + |
| 60 | + return fvar_t(hypergeometric_2F1(a1_val, a2_val, b_val, z_val), grad); |
| 61 | +} |
| 62 | + |
| 63 | +} // namespace math |
| 64 | +} // namespace stan |
| 65 | +#endif |
0 commit comments