-
-
Notifications
You must be signed in to change notification settings - Fork 196
Incomplete Beta Function Inverse #2637
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 5 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
f3b1276
Initial implementation, start tests
andrjohns 6aeb64b
Update tests
andrjohns a444d9c
cpplint
andrjohns 96b9790
Merge commit 'a43562ea29ef1bb892cb7942787d682f002dfc7c' into HEAD
yashikno be10487
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot 2b88773
Update doc & latex
andrjohns 9be9695
Merge commit '83dbd82c5b431d22c6e5899b2dfc5830a4cf95ed' into HEAD
yashikno 39a7421
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot 085f8ea
Missing doc updates
andrjohns 1eef220
Update inv naming
andrjohns 010c1c8
Merge branch 'develop' into feature/ibeta_inv
andrjohns 5803eea
Missed test naming
andrjohns 6587219
Merge branch 'stan-dev:develop' into feature/ibeta_inv
andrjohns b33a57b
Trigger CI
andrjohns 488befa
Add constraint checks
andrjohns 5ebcebf
Merge commit '359f742a43cde1292885280c68d7228c1b29abd8' into HEAD
yashikno 7715e90
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
#ifndef STAN_MATH_FWD_FUN_INC_BETA_INV_HPP | ||
#define STAN_MATH_FWD_FUN_INC_BETA_INV_HPP | ||
|
||
#include <stan/math/fwd/meta.hpp> | ||
#include <stan/math/fwd/core.hpp> | ||
#include <stan/math/prim/fun/inc_beta_inv.hpp> | ||
#include <stan/math/prim/fun/inc_beta.hpp> | ||
#include <stan/math/prim/fun/exp.hpp> | ||
#include <stan/math/prim/fun/log.hpp> | ||
#include <stan/math/prim/fun/log_diff_exp.hpp> | ||
#include <stan/math/prim/fun/lbeta.hpp> | ||
#include <stan/math/prim/fun/lgamma.hpp> | ||
#include <stan/math/prim/fun/digamma.hpp> | ||
#include <stan/math/prim/fun/F32.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* The inverse of the normalized incomplete beta function of a, b, with | ||
* probability p. | ||
* | ||
* Used to compute the cumulative density function for the beta | ||
* distribution. | ||
* | ||
* @param a Shape parameter a >= 0; a and b can't both be 0 | ||
* @param b Shape parameter b >= 0 | ||
* @param p Random variate. 0 <= p <= 1 | ||
* @throws if constraints are violated or if any argument is NaN | ||
* @return The inverse of the normalized incomplete beta function. | ||
*/ | ||
template <typename T1, typename T2, typename T3, | ||
require_all_stan_scalar_t<T1, T2, T3>* = nullptr, | ||
require_any_fvar_t<T1, T2, T3>* = nullptr> | ||
inline fvar<partials_return_t<T1, T2, T3>> inc_beta_inv(const T1& a, | ||
const T2& b, | ||
const T3& p) { | ||
using T_return = partials_return_t<T1, T2, T3>; | ||
auto a_val = value_of(a); | ||
auto b_val = value_of(b); | ||
auto p_val = value_of(p); | ||
T_return w = inc_beta_inv(a_val, b_val, p_val); | ||
T_return log_w = log(w); | ||
T_return log1m_w = log1m(w); | ||
auto one_m_a = 1 - a_val; | ||
auto one_m_b = 1 - b_val; | ||
T_return one_m_w = 1 - w; | ||
auto ap1 = a_val + 1; | ||
auto bp1 = b_val + 1; | ||
auto lbeta_ab = lbeta(a_val, b_val); | ||
auto digamma_apb = digamma(a_val + b_val); | ||
|
||
T_return inv_d_(0); | ||
|
||
if (is_fvar<T1>::value) { | ||
auto da1 = exp(one_m_b * log1m_w + one_m_a * log_w); | ||
auto da2 | ||
= exp(a_val * log_w + 2 * lgamma(a_val) | ||
+ log(F32(a_val, a_val, one_m_b, ap1, ap1, w)) - 2 * lgamma(ap1)); | ||
auto da3 = inc_beta(a_val, b_val, w) * exp(lbeta_ab) | ||
* (log_w - digamma(a_val) + digamma_apb); | ||
inv_d_ += forward_as<fvar<T_return>>(a).d_ * da1 * (da2 - da3); | ||
} | ||
|
||
if (is_fvar<T2>::value) { | ||
auto db1 = (w - 1) * exp(-b_val * log1m_w + one_m_a * log_w); | ||
auto db2 = 2 * lgamma(b_val) | ||
+ log(F32(b_val, b_val, one_m_a, bp1, bp1, one_m_w)) | ||
- 2 * lgamma(bp1) + b_val * log1m_w; | ||
|
||
auto db3 = inc_beta(b_val, a_val, one_m_w) * exp(lbeta_ab) | ||
* (log1m_w - digamma(b_val) + digamma_apb); | ||
|
||
inv_d_ += forward_as<fvar<T_return>>(b).d_ * db1 * (exp(db2) - db3); | ||
} | ||
|
||
if (is_fvar<T3>::value) { | ||
inv_d_ += forward_as<fvar<T_return>>(p).d_ | ||
* exp(one_m_b * log1m_w + one_m_a * log_w + lbeta_ab); | ||
} | ||
|
||
return fvar<T_return>(w, inv_d_); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#ifndef STAN_MATH_PRIM_FUN_INC_BETA_INV_HPP | ||
#define STAN_MATH_PRIM_FUN_INC_BETA_INV_HPP | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/err.hpp> | ||
#include <stan/math/prim/fun/boost_policy.hpp> | ||
#include <boost/math/special_functions/beta.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* The inverse of the normalized incomplete beta function of a, b, with | ||
* probability p. | ||
* | ||
* Used to compute the cumulative density function for the beta | ||
* distribution. | ||
* | ||
* @param a Shape parameter a >= 0; a and b can't both be 0 | ||
* @param b Shape parameter b >= 0 | ||
* @param p Random variate. 0 <= p <= 1 | ||
* @throws if constraints are violated or if any argument is NaN | ||
* @return The inverse of the normalized incomplete beta function. | ||
*/ | ||
inline double inc_beta_inv(double a, double b, double p) { | ||
check_not_nan("inc_beta", "a", a); | ||
check_not_nan("inc_beta", "b", b); | ||
check_not_nan("inc_beta", "p", p); | ||
return boost::math::ibeta_inv(a, b, p, boost_policy_t<>()); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
#ifndef STAN_MATH_REV_FUN_INC_BETA_INV_HPP | ||
#define STAN_MATH_REV_FUN_INC_BETA_INV_HPP | ||
|
||
#include <stan/math/rev/meta.hpp> | ||
#include <stan/math/rev/core.hpp> | ||
#include <stan/math/prim/fun/constants.hpp> | ||
#include <stan/math/prim/fun/inc_beta_inv.hpp> | ||
#include <stan/math/prim/fun/inc_beta.hpp> | ||
#include <stan/math/prim/fun/exp.hpp> | ||
#include <stan/math/prim/fun/log.hpp> | ||
#include <stan/math/prim/fun/log_diff_exp.hpp> | ||
#include <stan/math/prim/fun/lbeta.hpp> | ||
#include <stan/math/prim/fun/lgamma.hpp> | ||
#include <stan/math/prim/fun/digamma.hpp> | ||
#include <stan/math/prim/fun/F32.hpp> | ||
#include <stan/math/prim/fun/is_any_nan.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* The inverse of the normalized incomplete beta function of a, b, with | ||
* probability p. | ||
* | ||
* Used to compute the cumulative density function for the beta | ||
* distribution. | ||
* | ||
* @param a Shape parameter a >= 0; a and b can't both be 0 | ||
* @param b Shape parameter b >= 0 | ||
* @param p Random variate. 0 <= p <= 1 | ||
* @throws if constraints are violated or if any argument is NaN | ||
* @return The inverse of the normalized incomplete beta function. | ||
*/ | ||
template <typename T1, typename T2, typename T3, | ||
require_all_stan_scalar_t<T1, T2, T3>* = nullptr, | ||
require_any_var_t<T1, T2, T3>* = nullptr> | ||
inline var inc_beta_inv(const T1& a, const T2& b, const T3& p) { | ||
double a_val = value_of(a); | ||
double b_val = value_of(b); | ||
double p_val = value_of(p); | ||
double w = inc_beta_inv(a_val, b_val, p_val); | ||
return make_callback_var(w, [a, b, p, a_val, b_val, p_val, w](auto& vi) { | ||
double log_w = log(w); | ||
double log1m_w = log1m(w); | ||
double one_m_a = 1 - a_val; | ||
double one_m_b = 1 - b_val; | ||
double one_m_w = 1 - w; | ||
double ap1 = a_val + 1; | ||
double bp1 = b_val + 1; | ||
double lbeta_ab = lbeta(a_val, b_val); | ||
double digamma_apb = digamma(a_val + b_val); | ||
|
||
if (!is_constant_all<T1>::value) { | ||
double da1 = exp(one_m_b * log1m_w + one_m_a * log_w); | ||
double da2 = a_val * log_w + 2 * lgamma(a_val) | ||
+ log(F32(a_val, a_val, one_m_b, ap1, ap1, w)) | ||
- 2 * lgamma(ap1); | ||
double da3 = inc_beta(a_val, b_val, w) * exp(lbeta_ab) | ||
* (log_w - digamma(a_val) + digamma_apb); | ||
|
||
forward_as<var>(a).adj() += vi.adj() * da1 * (exp(da2) - da3); | ||
} | ||
|
||
if (!is_constant_all<T2>::value) { | ||
double db1 = (w - 1) * exp(-b_val * log1m_w + one_m_a * log_w); | ||
double db2 = 2 * lgamma(b_val) | ||
+ log(F32(b_val, b_val, one_m_a, bp1, bp1, one_m_w)) | ||
- 2 * lgamma(bp1) + b_val * log1m_w; | ||
|
||
double db3 = inc_beta(b_val, a_val, one_m_w) * exp(lbeta_ab) | ||
* (log1m_w - digamma(b_val) + digamma_apb); | ||
|
||
forward_as<var>(b).adj() += vi.adj() * db1 * (exp(db2) - db3); | ||
} | ||
|
||
if (!is_constant_all<T3>::value) { | ||
forward_as<var>(p).adj() | ||
+= vi.adj() * exp(one_m_b * log1m_w + one_m_a * log_w + lbeta_ab); | ||
} | ||
}); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#include <stan/math/fwd.hpp> | ||
#include <gtest/gtest.h> | ||
|
||
TEST(AgradFwdMatrixIncBetaInv, fd_scalar) { | ||
using stan::math::fvar; | ||
using stan::math::inc_beta_inv; | ||
fvar<double> a = 6; | ||
fvar<double> b = 2; | ||
fvar<double> p = 0.9; | ||
a.d_ = 1.0; | ||
b.d_ = 1.0; | ||
p.d_ = 1.0; | ||
|
||
fvar<double> res = inc_beta_inv(a, b, p); | ||
|
||
EXPECT_FLOAT_EQ(res.d_, 0.0117172527399 - 0.0680999818473 + 0.455387298585); | ||
} | ||
|
||
TEST(AgradFwdMatrixIncBetaInv, ffd_scalar) { | ||
using stan::math::fvar; | ||
using stan::math::inc_beta_inv; | ||
fvar<fvar<double>> a = 7; | ||
fvar<fvar<double>> b = 4; | ||
fvar<fvar<double>> p = 0.15; | ||
a.val_.d_ = 1.0; | ||
b.val_.d_ = 1.0; | ||
p.val_.d_ = 1.0; | ||
|
||
fvar<fvar<double>> res = inc_beta_inv(a, b, p); | ||
|
||
EXPECT_FLOAT_EQ(res.val_.d_, | ||
0.0428905418857 - 0.0563420377808 + 0.664919819507); | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
#include <stan/math/mix.hpp> | ||
#include <gtest/gtest.h> | ||
#include <test/unit/math/rev/fun/util.hpp> | ||
|
||
TEST(ProbInternalMath, inc_beta_inv_fv1) { | ||
using stan::math::fvar; | ||
using stan::math::inc_beta_inv; | ||
using stan::math::var; | ||
double a_d = 1; | ||
double b_d = 2; | ||
double p_d = 0.5; | ||
fvar<var> a_v = a_d; | ||
fvar<var> b_v = b_d; | ||
fvar<var> p_v = p_d; | ||
a_v.d_ = 1.0; | ||
b_v.d_ = 1.0; | ||
p_v.d_ = 1.0; | ||
|
||
fvar<var> res = inc_beta_inv(a_v, b_v, p_v); | ||
res.val_.grad(); | ||
|
||
EXPECT_FLOAT_EQ(a_v.val_.adj(), 0.287698278597); | ||
EXPECT_FLOAT_EQ(b_v.val_.adj(), -0.122532267934); | ||
EXPECT_FLOAT_EQ(p_v.val_.adj(), 0.707106781187); | ||
|
||
a_v = a_d; | ||
b_v = b_d; | ||
p_v = p_d; | ||
a_v.d_ = 1.0; | ||
b_v.d_ = 1.0; | ||
p_v.d_ = 1.0; | ||
|
||
res = inc_beta_inv(a_d, b_v, p_v); | ||
res.val_.grad(); | ||
|
||
EXPECT_FLOAT_EQ(b_v.val_.adj(), -0.122532267934); | ||
EXPECT_FLOAT_EQ(p_v.val_.adj(), 0.707106781187); | ||
|
||
b_v = b_d; | ||
p_v = p_d; | ||
b_v.d_ = 1.0; | ||
p_v.d_ = 1.0; | ||
|
||
res = inc_beta_inv(a_v, b_d, p_v); | ||
res.val_.grad(); | ||
|
||
EXPECT_FLOAT_EQ(a_v.val_.adj(), 0.287698278597); | ||
EXPECT_FLOAT_EQ(p_v.val_.adj(), 0.707106781187); | ||
|
||
a_v = a_d; | ||
p_v = p_d; | ||
a_v.d_ = 1.0; | ||
p_v.d_ = 1.0; | ||
|
||
res = inc_beta_inv(a_v, b_v, p_d); | ||
res.val_.grad(); | ||
|
||
EXPECT_FLOAT_EQ(a_v.val_.adj(), 0.287698278597); | ||
EXPECT_FLOAT_EQ(b_v.val_.adj(), -0.122532267934); | ||
} | ||
|
||
TEST(ProbInternalMath, inc_beta_inv_fv2) { | ||
using stan::math::fvar; | ||
using stan::math::inc_beta_inv; | ||
using stan::math::var; | ||
fvar<fvar<var>> a = 2; | ||
fvar<fvar<var>> b = 5; | ||
fvar<fvar<var>> p = 0.1; | ||
a.d_ = 1.0; | ||
b.d_ = 1.0; | ||
p.d_ = 1.0; | ||
|
||
fvar<fvar<var>> res = inc_beta_inv(a, b, p); | ||
res.val_.val_.grad(); | ||
|
||
EXPECT_FLOAT_EQ(a.val_.val_.adj(), 0.0783025374798); | ||
EXPECT_FLOAT_EQ(b.val_.val_.adj(), -0.0161882044585); | ||
EXPECT_FLOAT_EQ(p.val_.val_.adj(), 0.530989359806); | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.