Skip to content

Commit a46427f

Browse files
authored
Merge pull request #1612 from martinmodrak/bugfix/1611-lbeta-large-arguments
Fix lbeta for large arguments
2 parents 59c2bff + 71b983e commit a46427f

File tree

6 files changed

+587
-3
lines changed

6 files changed

+587
-3
lines changed

stan/math/prim/fun/lbeta.hpp

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,16 @@
22
#define STAN_MATH_PRIM_FUN_LBETA_HPP
33

44
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/constants.hpp>
7+
#include <stan/math/prim/fun/inv.hpp>
8+
#include <stan/math/prim/fun/is_any_nan.hpp>
59
#include <stan/math/prim/fun/lgamma.hpp>
10+
#include <stan/math/prim/fun/lgamma_stirling.hpp>
11+
#include <stan/math/prim/fun/lgamma_stirling_diff.hpp>
12+
#include <stan/math/prim/fun/log_sum_exp.hpp>
13+
#include <stan/math/prim/fun/log1m.hpp>
14+
#include <stan/math/prim/fun/multiply_log.hpp>
615

716
namespace stan {
817
namespace math {
@@ -22,7 +31,7 @@ namespace math {
2231
*
2332
* See stan::math::lgamma() for the double-based and stan::math for the
2433
* variable-based log Gamma function.
25-
*
34+
* This function is numerically more stable than naive evaluation via lgamma.
2635
*
2736
\f[
2837
\mbox{lbeta}(\alpha, \beta) =
@@ -54,8 +63,58 @@ namespace math {
5463
* @tparam T2 Type of second value.
5564
*/
5665
template <typename T1, typename T2>
57-
inline return_type_t<T1, T2> lbeta(const T1 a, const T2 b) {
58-
return lgamma(a) + lgamma(b) - lgamma(a + b);
66+
return_type_t<T1, T2> lbeta(const T1 a, const T2 b) {
  using T_ret = return_type_t<T1, T2>;

  // NaN inputs are answered with NaN before the non-negativity checks run.
  if (is_any_nan(a, b)) {
    return NOT_A_NUMBER;
  }

  static const char* function = "lbeta";
  check_nonnegative(function, "first argument", a);
  check_nonnegative(function, "second argument", b);

  // Everything below works on the ordered pair (smaller, larger), so the
  // result does not depend on the order in which a and b were passed.
  T_ret smaller;
  T_ret larger;
  if (a < b) {
    smaller = a;
    larger = b;
  } else {
    smaller = b;
    larger = a;
  }

  // Degenerate inputs: a zero argument wins over an infinite one.
  if (smaller == 0) {
    return INFTY;
  }
  if (is_inf(larger)) {
    return NEGATIVE_INFTY;
  }

  // For large arguments, each lgamma value is split into its Stirling
  // approximation plus a correction term. The Stirling parts are simplified
  // analytically and the corrections are added back at the end.
  //
  // The overall approach is inspired by the code in R, where the algorithm is
  // credited to W. Fullerton of Los Alamos Scientific Laboratory

  // Case 1: both arguments are below the cutoff -- evaluate directly.
  if (larger < lgamma_stirling_diff_useful) {
    return lgamma(smaller) + lgamma(larger) - lgamma(smaller + larger);
  }

  T_ret ratio = smaller / (smaller + larger);

  // Case 2: only the larger argument is above the cutoff. lgamma(smaller) is
  // kept as is; the Stirling split is applied to lgamma(larger) and
  // lgamma(smaller + larger).
  if (smaller < lgamma_stirling_diff_useful) {
    T_ret correction
        = lgamma_stirling_diff(larger) - lgamma_stirling_diff(smaller + larger);
    T_ret leading = (larger - 0.5) * log1m(ratio)
                    + smaller * (1 - log(smaller + larger));
    return leading + lgamma(smaller) + correction;
  }

  // Case 3: both arguments are above the cutoff -- all three lgamma terms
  // are split.
  T_ret correction = lgamma_stirling_diff(smaller)
                     + lgamma_stirling_diff(larger)
                     - lgamma_stirling_diff(smaller + larger);
  T_ret leading = (smaller - 0.5) * log(ratio) + larger * log1m(ratio)
                  + HALF_LOG_TWO_PI - 0.5 * log(larger);
  return leading + correction;
}
60119

61120
} // namespace math
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#ifndef STAN_MATH_PRIM_FUN_LGAMMA_STIRLING_HPP
2+
#define STAN_MATH_PRIM_FUN_LGAMMA_STIRLING_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/fun/constants.hpp>
6+
#include <stan/math/prim/fun/lgamma.hpp>
7+
#include <cmath>
8+
9+
namespace stan {
10+
namespace math {
11+
12+
/**
13+
* Return the Stirling approximation to the lgamma function.
14+
*
15+
16+
\f[
17+
\mbox{lgamma_stirling}(x) =
18+
\frac{1}{2} \log(2\pi) + (x-\frac{1}{2})*\log(x) - x
19+
\f]
20+
21+
*
22+
* @tparam T Type of value.
23+
* @param x value
24+
* @return Stirling's approximation to lgamma(x).
25+
*/
26+
template <typename T>
27+
return_type_t<T> lgamma_stirling(const T x) {
28+
return HALF_LOG_TWO_PI + (x - 0.5) * log(x) - x;
29+
}
30+
31+
} // namespace math
32+
} // namespace stan
33+
34+
#endif
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#ifndef STAN_MATH_PRIM_FUN_LGAMMA_STIRLING_DIFF_HPP
2+
#define STAN_MATH_PRIM_FUN_LGAMMA_STIRLING_DIFF_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/constants.hpp>
7+
#include <stan/math/prim/fun/inv.hpp>
8+
#include <stan/math/prim/fun/lgamma.hpp>
9+
#include <stan/math/prim/fun/lgamma_stirling.hpp>
10+
#include <stan/math/prim/fun/square.hpp>
11+
#include <stan/math/prim/fun/value_of.hpp>
12+
#include <cmath>
13+
14+
namespace stan {
15+
namespace math {
16+
17+
// Cutoff shared by lgamma_stirling_diff and lbeta: for arguments below this
// value the code evaluates via lgamma directly instead of relying on the
// truncated Stirling series.
constexpr double lgamma_stirling_diff_useful = 10;
18+
19+
/**
20+
* Return the difference between log of the gamma function and its Stirling
21+
* approximation.
22+
* This is useful to stably compute log of ratios of gamma functions with large
23+
* arguments where the Stirling approximation allows for analytic solution
24+
* and the (small) differences can be added afterwards.
25+
* This is for example used in the implementation of lbeta.
26+
*
27+
* The function returns the correct value for all arguments, but using it can
28+
* lead to a loss of precision when x < lgamma_stirling_diff_useful.
29+
*
30+
\f[
31+
\mbox{lgamma_stirling_diff}(x) =
32+
\log(\Gamma(x)) - \frac{1}{2} \log(2\pi) -
33+
(x-\frac{1}{2})*\log(x) + x
34+
\f]
35+
36+
*
37+
* @tparam T Type of value.
38+
* @param x value
39+
* @return Difference between lgamma(x) and its Stirling approximation.
40+
*/
41+
template <typename T>
42+
return_type_t<T> lgamma_stirling_diff(const T x) {
43+
using T_ret = return_type_t<T>;
44+
45+
if (is_nan(value_of_rec(x))) {
46+
return NOT_A_NUMBER;
47+
}
48+
check_nonnegative("lgamma_stirling_diff", "argument", x);
49+
50+
if (x == 0) {
51+
return INFTY;
52+
}
53+
if (value_of(x) < lgamma_stirling_diff_useful) {
54+
return lgamma(x) - lgamma_stirling(x);
55+
}
56+
57+
// Using the Stirling series as expressed in formula 5.11.1. at
58+
// https://dlmf.nist.gov/5.11
59+
constexpr double stirling_series[]{
60+
0.0833333333333333333333333, -0.00277777777777777777777778,
61+
0.000793650793650793650793651, -0.000595238095238095238095238,
62+
0.000841750841750841750841751, -0.00191752691752691752691753,
63+
0.00641025641025641025641026, -0.0295506535947712418300654};
64+
65+
constexpr int n_stirling_terms = 6;
66+
T_ret result(0.0);
67+
T_ret multiplier = inv(x);
68+
T_ret inv_x_squared = square(multiplier);
69+
for (int n = 0; n < n_stirling_terms; n++) {
70+
if (n > 0) {
71+
multiplier *= inv_x_squared;
72+
}
73+
result += stirling_series[n] * multiplier;
74+
}
75+
return result;
76+
}
77+
78+
} // namespace math
79+
} // namespace stan
80+
81+
#endif

test/unit/math/prim/fun/lbeta_test.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
#include <gtest/gtest.h>
33
#include <cmath>
44
#include <limits>
5+
#include <string>
6+
#include <vector>
7+
#include <algorithm>
58

69
TEST(MathFunctions, lbeta) {
710
using stan::math::lbeta;
@@ -21,3 +24,89 @@ TEST(MathFunctions, lbeta_nan) {
2124

2225
EXPECT_TRUE(std::isnan(stan::math::lbeta(nan, nan)));
2326
}
27+
28+
TEST(MathFunctions, lbeta_extremes_errors) {
  using stan::math::lbeta;

  const double inf = std::numeric_limits<double>::infinity();
  // Smallest double strictly greater than the Stirling cutoff.
  const double after_stirling
      = std::nextafter(stan::math::lgamma_stirling_diff_useful, inf);

  // A zero argument gives +infinity on either side of the cutoff.
  EXPECT_FLOAT_EQ(lbeta(0.0, 1.0), inf);
  EXPECT_FLOAT_EQ(lbeta(1.0, 0.0), inf);
  EXPECT_FLOAT_EQ(lbeta(0.0, after_stirling), inf);
  EXPECT_FLOAT_EQ(lbeta(after_stirling, 0.0), inf);
  EXPECT_FLOAT_EQ(lbeta(0.0, 0.0), inf);

  // Zero takes precedence over infinity; otherwise an infinite argument
  // gives -infinity.
  EXPECT_FLOAT_EQ(lbeta(inf, 0.0), inf);
  EXPECT_FLOAT_EQ(lbeta(0.0, inf), inf);
  EXPECT_FLOAT_EQ(lbeta(inf, 1), -inf);
  EXPECT_FLOAT_EQ(lbeta(1e8, inf), -inf);
  EXPECT_FLOAT_EQ(lbeta(inf, inf), -inf);
}
46+
47+
TEST(MathFunctions, lbeta_identities) {
  using stan::math::lbeta;
  using stan::math::pi;

  const std::vector<double> to_test{1e-100,   1e-8, 1e-1, 1,
                                    1 + 1e-6, 1e3,  1e30, 1e100};

  // Tolerance relative to the mean magnitude of both sides, never below
  // 1e-15.
  auto tol = [](double x, double y) {
    const double mean_magnitude = 0.5 * (fabs(x) + fabs(y));
    return std::max(1e-15 * mean_magnitude, 1e-15);
  };

  // Identity: lbeta(x, y) == log_sum_exp(lbeta(x + 1, y), lbeta(x, y + 1))
  for (double x : to_test) {
    for (double y : to_test) {
      std::stringstream msg;
      msg << std::setprecision(22) << "successors: x = " << x << "; y = " << y;
      const double lh = lbeta(x, y);
      const double rh
          = stan::math::log_sum_exp(lbeta(x + 1, y), lbeta(x, y + 1));
      EXPECT_NEAR(lh, rh, tol(lh, rh)) << msg.str();
    }
  }

  // Identity: lbeta(x, 1 - x) == log(pi) - log(sin(pi * x)), checked for
  // x < 1 only.
  for (double x : to_test) {
    if (!(x < 1)) {
      continue;
    }
    std::stringstream msg;
    msg << std::setprecision(22) << "sin: x = " << x;
    const double lh = lbeta(x, 1.0 - x);
    const double rh = log(pi()) - log(sin(pi() * x));
    EXPECT_NEAR(lh, rh, tol(lh, rh)) << msg.str();
  }

  // Identity: lbeta(x, 1) == -log(x)
  for (double x : to_test) {
    std::stringstream msg;
    msg << std::setprecision(22) << "inv: x = " << x;
    const double lh = lbeta(x, 1.0);
    const double rh = -log(x);
    EXPECT_NEAR(lh, rh, tol(lh, rh)) << msg.str();
  }
}
85+
86+
TEST(MathFunctions, lbeta_stirling_cutoff) {
  using stan::math::lbeta;
  using stan::math::lgamma_stirling_diff_useful;

  // Nearest representable doubles on either side of the cutoff.
  const double before_stirling = std::nextafter(lgamma_stirling_diff_useful, 0);
  const double after_stirling
      = std::nextafter(lgamma_stirling_diff_useful, stan::math::INFTY);

  const std::vector<double> to_test{
      1e-100, 1e-8, 1e-1,  1,     1 + 1e-6,
      1e3,    1e30, 1e100, before_stirling, after_stirling};

  for (const double x : to_test) {
    const double before = lbeta(x, before_stirling);
    const double at = lbeta(x, lgamma_stirling_diff_useful);
    const double after = lbeta(x, after_stirling);

    // The step into the cutoff and the step out of it should match, i.e. the
    // switch between code paths must not introduce a jump.
    const double diff_before = at - before;
    const double diff_after = after - at;
    const double mean_step = 0.5 * (fabs(diff_before) + fabs(diff_after));
    const double tol = std::max(1e-15 * mean_step, 1e-14 * fabs(at));

    EXPECT_NEAR(diff_before, diff_after, tol)
        << "diff before and after cutoff: x = " << x << "; before = " << before
        << "; at = " << at << "; after = " << after;
  }
}

0 commit comments

Comments
 (0)