Skip to content

Fix lbeta for large arguments #1612

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
4dce6ad
Failing test
martinmodrak Jan 14, 2020
2fedf57
Fixes #1611
martinmodrak Jan 14, 2020
b0394a4
Test against derivatives, identities at around cutoff for Stirling
martinmodrak Jan 14, 2020
c89f18c
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 14, 2020
5ae5f93
Fixing lint errors
martinmodrak Jan 14, 2020
10bbde3
Lint errors
martinmodrak Jan 14, 2020
e67e7a2
Revert expect_near_rel.hpp
martinmodrak Jan 14, 2020
b74c0bc
Line end
martinmodrak Jan 14, 2020
6b87d37
Format and comment improvements as suggested in review.
martinmodrak Jan 14, 2020
4157942
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 14, 2020
c74b152
Fixed headers
martinmodrak Jan 14, 2020
abe6d58
Removed problematic constexpr
martinmodrak Jan 16, 2020
0a93fdc
Tighten test accuracy, use more terms of Stirling series in lgamma_st…
martinmodrak Jan 16, 2020
6778e62
Merge commit '45dce152f1fdb6fc079218d76b611d2664bf305e' into HEAD
yashikno Jan 16, 2020
571e727
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 16, 2020
ab16ea5
Ceased using expect_near_rel
martinmodrak Jan 16, 2020
06cf322
[Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e…
stan-buildbot Jan 16, 2020
1ced285
Slightly relaxed test to pass on Linux
martinmodrak Jan 17, 2020
1c96115
Suggestions from @mcol
martinmodrak Jan 17, 2020
bcf5d48
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 17, 2020
95aa753
Further tweak to test tolerance
martinmodrak Jan 17, 2020
a0ae6c3
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Jan 17, 2020
17f3680
Tweak test tolerance for Linux
martinmodrak Jan 17, 2020
084ddfb
Merge remote-tracking branch 'stan-dev/develop' into bugfix/1611-lbet…
martinmodrak Jan 18, 2020
9ce2093
Slightly better Mathematica code to generate tests
martinmodrak Jan 27, 2020
e731457
Merge remote-tracking branch 'origin/bugfix/1611-lbeta-large-argument…
martinmodrak Jan 27, 2020
c1cea9c
Merge remote-tracking branch 'stan-dev/develop' into bugfix/1611-lbet…
martinmodrak Jan 27, 2020
71b983e
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 27, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 62 additions & 3 deletions stan/math/prim/fun/lbeta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@
#define STAN_MATH_PRIM_FUN_LBETA_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/inv.hpp>
#include <stan/math/prim/fun/is_any_nan.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/lgamma_stirling.hpp>
#include <stan/math/prim/fun/lgamma_stirling_diff.hpp>
#include <stan/math/prim/fun/log_sum_exp.hpp>
#include <stan/math/prim/fun/log1m.hpp>
#include <stan/math/prim/fun/multiply_log.hpp>

namespace stan {
namespace math {
Expand All @@ -22,7 +31,7 @@ namespace math {
*
* See stan::math::lgamma() for the double-based and stan::math for the
* variable-based log Gamma function.
*
* This function is numerically more stable than naive evaluation via lgamma.
*
\f[
\mbox{lbeta}(\alpha, \beta) =
Expand Down Expand Up @@ -54,8 +63,58 @@ namespace math {
* @tparam T2 Type of second value.
*/
template <typename T1, typename T2>
inline return_type_t<T1, T2> lbeta(const T1 a, const T2 b) {
return lgamma(a) + lgamma(b) - lgamma(a + b);
return_type_t<T1, T2> lbeta(const T1 a, const T2 b) {
  using result_t = return_type_t<T1, T2>;

  // A NaN in either argument short-circuits to NaN before any validation.
  if (is_any_nan(a, b)) {
    return NOT_A_NUMBER;
  }

  // Both arguments are checked for non-negativity under the name "lbeta".
  static const char* fn_name = "lbeta";
  check_nonnegative(fn_name, "first argument", a);
  check_nonnegative(fn_name, "second argument", b);

  // From here on work with the ordered pair (smaller, larger); the branches
  // below only need to know which argument is the bigger one.
  const bool first_is_smaller = a < b;
  const result_t smaller = first_is_smaller ? result_t(a) : result_t(b);
  const result_t larger = first_is_smaller ? result_t(b) : result_t(a);

  // Boundary values. Zero is tested first, so a zero argument yields INFTY
  // even when the other argument is infinite.
  if (smaller == 0) {
    return INFTY;
  }
  if (is_inf(larger)) {
    return NEGATIVE_INFTY;
  }

  // When an argument is large, each lgamma value is split into its Stirling
  // approximation plus a correction term. The Stirling parts are combined
  // analytically and the corrections are added afterwards.
  //
  // The overall approach is inspired by the code in R, where the algorithm is
  // credited to W. Fullerton of Los Alamos Scientific Laboratory.

  // Regime 1: both arguments are below the cutoff, so the direct formula is
  // used.
  if (larger < lgamma_stirling_diff_useful) {
    return lgamma(smaller) + lgamma(larger) - lgamma(smaller + larger);
  }

  const result_t smaller_fraction = smaller / (smaller + larger);

  // Regime 2: only the larger argument is at or above the cutoff. lgamma of
  // the smaller argument is evaluated directly; the Stirling split is applied
  // to lgamma(larger) and lgamma(smaller + larger) only.
  if (smaller < lgamma_stirling_diff_useful) {
    const result_t correction
        = lgamma_stirling_diff(larger) - lgamma_stirling_diff(smaller + larger);
    const result_t approx = (larger - 0.5) * log1m(smaller_fraction)
                            + smaller * (1 - log(smaller + larger));
    return approx + lgamma(smaller) + correction;
  }

  // Regime 3: both arguments are at or above the cutoff, so all three lgamma
  // terms go through the Stirling split.
  const result_t correction = lgamma_stirling_diff(smaller)
                              + lgamma_stirling_diff(larger)
                              - lgamma_stirling_diff(smaller + larger);
  const result_t approx = (smaller - 0.5) * log(smaller_fraction)
                          + larger * log1m(smaller_fraction)
                          + HALF_LOG_TWO_PI - 0.5 * log(larger);
  return approx + correction;
}

} // namespace math
Expand Down
34 changes: 34 additions & 0 deletions stan/math/prim/fun/lgamma_stirling.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef STAN_MATH_PRIM_FUN_LGAMMA_STIRLING_HPP
#define STAN_MATH_PRIM_FUN_LGAMMA_STIRLING_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <cmath>

namespace stan {
namespace math {

/**
 * Evaluate Stirling's approximation to the log of the gamma function.
 *
 * The returned value is
 *
   \f[
   \mbox{lgamma_stirling}(x) =
      \frac{1}{2} \log(2\pi) + (x-\frac{1}{2})*\log(x) - x
   \f]
 *
 * @tparam T Type of value.
 * @param x value
 * @return Stirling's approximation to lgamma(x).
 */
template <typename T>
return_type_t<T> lgamma_stirling(const T x) {
  // (x - 1/2) * log(x), the dominant term of the approximation.
  const return_type_t<T> weighted_log = (x - 0.5) * log(x);
  return HALF_LOG_TWO_PI + weighted_log - x;
}

} // namespace math
} // namespace stan

#endif
81 changes: 81 additions & 0 deletions stan/math/prim/fun/lgamma_stirling_diff.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#ifndef STAN_MATH_PRIM_FUN_LGAMMA_STIRLING_DIFF_HPP
#define STAN_MATH_PRIM_FUN_LGAMMA_STIRLING_DIFF_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/inv.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/lgamma_stirling.hpp>
#include <stan/math/prim/fun/square.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <cmath>

namespace stan {
namespace math {

/**
 * Cutoff on the argument of lgamma_stirling_diff(). Below this value the
 * function falls back to subtracting lgamma_stirling() from lgamma(), which
 * can lose precision; at or above it the Stirling series is summed instead.
 */
constexpr double lgamma_stirling_diff_useful = 10;

/**
 * Return the difference between log of the gamma function and its Stirling
 * approximation.
 * This is useful to stably compute log of ratios of gamma functions with large
 * arguments where the Stirling approximation allows for analytic solution
 * and the (small) differences can be added afterwards.
 * This is for example used in the implementation of lbeta.
 *
 * The function will return correct value for all arguments, but using it can
 * lead to a loss of precision when x < lgamma_stirling_diff_useful.
 *
   \f[
   \mbox{lgamma_stirling_diff}(x) =
      \log(\Gamma(x)) - \left( \frac{1}{2} \log(2\pi) +
      (x-\frac{1}{2})*\log(x) - x \right)
   \f]
 *
 * Special values: a NaN argument returns NaN and a zero argument returns
 * positive infinity. The argument is checked for non-negativity.
 *
 * @tparam T Type of value.
 * @param x value
 * @return Difference between lgamma(x) and its Stirling approximation.
 */
template <typename T>
return_type_t<T> lgamma_stirling_diff(const T x) {
  using T_ret = return_type_t<T>;

  if (is_nan(value_of_rec(x))) {
    return NOT_A_NUMBER;
  }
  check_nonnegative("lgamma_stirling_diff", "argument", x);

  if (x == 0) {
    return INFTY;
  }
  // Below the cutoff, compute the difference directly from its definition.
  if (value_of(x) < lgamma_stirling_diff_useful) {
    return lgamma(x) - lgamma_stirling(x);
  }

  // Using the Stirling series as expressed in formula 5.11.1. at
  // https://dlmf.nist.gov/5.11
  // Entry n multiplies x^-(2n + 1).
  constexpr double stirling_series[]{
      0.0833333333333333333333333,   -0.00277777777777777777777778,
      0.000793650793650793650793651, -0.000595238095238095238095238,
      0.000841750841750841750841751, -0.00191752691752691752691753,
      0.00641025641025641025641026,  -0.0295506535947712418300654};

  // Only the leading terms of the table are summed; the remaining entries
  // are kept available but unused.
  constexpr int n_stirling_terms = 6;
  static_assert(
      n_stirling_terms
          <= static_cast<int>(sizeof(stirling_series)
                              / sizeof(stirling_series[0])),
      "n_stirling_terms must not exceed the number of tabulated coefficients");

  T_ret result(0.0);
  // multiplier runs through 1/x, 1/x^3, 1/x^5, ...
  T_ret multiplier = inv(x);
  T_ret inv_x_squared = square(multiplier);
  for (int n = 0; n < n_stirling_terms; n++) {
    if (n > 0) {
      multiplier *= inv_x_squared;
    }
    result += stirling_series[n] * multiplier;
  }
  return result;
}

} // namespace math
} // namespace stan

#endif
89 changes: 89 additions & 0 deletions test/unit/math/prim/fun/lbeta_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
#include <gtest/gtest.h>
#include <cmath>
#include <limits>
#include <string>
#include <vector>
#include <algorithm>

TEST(MathFunctions, lbeta) {
using stan::math::lbeta;
Expand All @@ -21,3 +24,89 @@ TEST(MathFunctions, lbeta_nan) {

EXPECT_TRUE(std::isnan(stan::math::lbeta(nan, nan)));
}

// Checks the values lbeta returns at the edges of its domain: a zero
// argument, an infinite argument, and combinations of the two.
TEST(MathFunctions, lbeta_extremes_errors) {
double inf = std::numeric_limits<double>::infinity();
// Smallest double strictly greater than lgamma_stirling_diff_useful.
double after_stirling
= std::nextafter(stan::math::lgamma_stirling_diff_useful, inf);
using stan::math::lbeta;

// A zero argument gives +infinity, whichever position it is in and whether
// the other argument is below or just above the Stirling cutoff.
EXPECT_FLOAT_EQ(lbeta(0.0, 1.0), inf);
EXPECT_FLOAT_EQ(lbeta(1.0, 0.0), inf);
EXPECT_FLOAT_EQ(lbeta(0.0, after_stirling), inf);
EXPECT_FLOAT_EQ(lbeta(after_stirling, 0.0), inf);
EXPECT_FLOAT_EQ(lbeta(0.0, 0.0), inf);

// Zero paired with infinity still gives +infinity.
EXPECT_FLOAT_EQ(lbeta(inf, 0.0), inf);
EXPECT_FLOAT_EQ(lbeta(0.0, inf), inf);
// An infinite argument paired with a positive argument gives -infinity.
EXPECT_FLOAT_EQ(lbeta(inf, 1), -inf);
EXPECT_FLOAT_EQ(lbeta(1e8, inf), -inf);
EXPECT_FLOAT_EQ(lbeta(inf, inf), -inf);
}

// Verifies lbeta against three analytic identities over a grid of arguments
// spanning many orders of magnitude.
TEST(MathFunctions, lbeta_identities) {
  using stan::math::lbeta;
  using stan::math::pi;

  const std::vector<double> grid
      = {1e-100, 1e-8, 1e-1, 1, 1 + 1e-6, 1e3, 1e30, 1e100};

  // Tolerance: 1e-15 relative to the mean magnitude of the two sides, with an
  // absolute floor of 1e-15.
  auto tol = [](double x, double y) -> double {
    const double mean_magnitude = 0.5 * (fabs(x) + fabs(y));
    return std::max(1e-15 * mean_magnitude, 1e-15);
  };

  // Successor identity: B(x, y) = B(x + 1, y) + B(x, y + 1), on the log scale.
  for (const double x : grid) {
    for (const double y : grid) {
      std::stringstream context;
      context << std::setprecision(22) << "successors: x = " << x
              << "; y = " << y;
      double lh = lbeta(x, y);
      double rh = stan::math::log_sum_exp(lbeta(x + 1, y), lbeta(x, y + 1));
      EXPECT_NEAR(lh, rh, tol(lh, rh)) << context.str();
    }
  }

  // Reflection identity: B(x, 1 - x) = pi / sin(pi * x), for grid points
  // below 1.
  for (const double x : grid) {
    if (!(x < 1)) {
      continue;
    }
    std::stringstream context;
    context << std::setprecision(22) << "sin: x = " << x;
    double lh = lbeta(x, 1.0 - x);
    double rh = log(pi()) - log(sin(pi() * x));
    EXPECT_NEAR(lh, rh, tol(lh, rh)) << context.str();
  }

  // Inverse identity: B(x, 1) = 1 / x.
  for (const double x : grid) {
    std::stringstream context;
    context << std::setprecision(22) << "inv: x = " << x;
    double lh = lbeta(x, 1.0);
    double rh = -log(x);
    EXPECT_NEAR(lh, rh, tol(lh, rh)) << context.str();
  }
}

// Checks that lbeta has no jump at the Stirling cutoff: the step from just
// below the cutoff to the cutoff must match the step from the cutoff to just
// above it.
TEST(MathFunctions, lbeta_stirling_cutoff) {
  using stan::math::lbeta;
  using stan::math::lgamma_stirling_diff_useful;

  // Nearest representable doubles on either side of the cutoff.
  const double before_stirling
      = std::nextafter(lgamma_stirling_diff_useful, 0.0);
  const double after_stirling
      = std::nextafter(lgamma_stirling_diff_useful, stan::math::INFTY);

  const std::vector<double> first_args
      = {1e-100, 1e-8, 1e-1, 1, 1 + 1e-6, 1e3, 1e30, 1e100,
         before_stirling, after_stirling};

  for (const double x : first_args) {
    const double before = lbeta(x, before_stirling);
    const double at = lbeta(x, lgamma_stirling_diff_useful);
    const double after = lbeta(x, after_stirling);

    const double diff_before = at - before;
    const double diff_after = after - at;

    // Tolerance is the larger of 1e-15 relative to the mean step size and
    // 1e-14 relative to the value at the cutoff.
    const double mean_step = 0.5 * (fabs(diff_before) + fabs(diff_after));
    const double tol = std::max(1e-15 * mean_step, 1e-14 * fabs(at));

    EXPECT_NEAR(diff_before, diff_after, tol)
        << "diff before and after cutoff: x = " << x << "; before = " << before
        << "; at = " << at << "; after = " << after;
  }
}
Loading