Skip to content

Commit 2fedf57

Browse files
committed
Fixes #1611
1 parent 4dce6ad commit 2fedf57

File tree

4 files changed

+268
-4
lines changed

4 files changed

+268
-4
lines changed

stan/math/prim/fun/lbeta.hpp

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
#ifndef STAN_MATH_PRIM_FUN_LBETA_HPP
2-
#define STAN_MATH_PRIM_FUN_LBETA_HPP
1+
#ifndef STAN_MATH_PRIM_SCAL_FUN_LBETA_HPP
2+
#define STAN_MATH_PRIM_SCAL_FUN_LBETA_HPP
33

4+
#include <limits>
45
#include <stan/math/prim/meta.hpp>
56
#include <stan/math/prim/fun/lgamma.hpp>
7+
#include <stan/math/prim/fun/lgamma_stirling.hpp>
8+
#include <stan/math/prim/fun/log_sum_exp.hpp>
9+
#include <stan/math/prim/fun/lgamma_stirling_diff.hpp>
10+
#include <stan/math/prim/fun/multiply_log.hpp>
11+
#include <stan/math/prim/fun/inv.hpp>
12+
#include <stan/math/prim/err/check_nonnegative.hpp>
613

714
namespace stan {
815
namespace math {
@@ -22,7 +29,7 @@ namespace math {
2229
*
2330
* See stan::math::lgamma() for the double-based and stan::math for the
2431
* variable-based log Gamma function.
25-
*
32+
* This function is numerically more stable than naive evaluation via lgamma
2633
*
2734
\f[
2835
\mbox{lbeta}(\alpha, \beta) =
@@ -55,7 +62,48 @@ namespace math {
5562
*/
5663
template <typename T1, typename T2>
5764
inline return_type_t<T1, T2> lbeta(const T1 a, const T2 b) {
58-
return lgamma(a) + lgamma(b) - lgamma(a + b);
65+
typedef return_type_t<T1, T2> T_ret;
66+
67+
if (is_nan(value_of_rec(a)) || is_nan(value_of_rec(b))) {
68+
return std::numeric_limits<double>::quiet_NaN();
69+
}
70+
71+
static const char* function = "lbeta";
72+
check_nonnegative(function, "first argument", a);
73+
check_nonnegative(function, "second argument", b);
74+
T_ret x; // x is the smaller of the two
75+
T_ret y;
76+
if (a < b) {
77+
x = a;
78+
y = b;
79+
} else {
80+
x = b;
81+
y = a;
82+
}
83+
84+
// For large x or y, separate the lgamma values into Stirling approximations
85+
// and appropriate corrections. The Stirling approximations allow for
86+
// analytic simplifaction and the corrections are added later.
87+
//
88+
// The overall approach is inspired by the code in R, where the algorithm is
89+
// credited to W. Fullerton of Los Alamos Scientific Laboratory
90+
if (y < lgamma_stirling_diff_useful) {
91+
// both small
92+
return lgamma(x) + lgamma(y) - lgamma(x + y);
93+
} else if (x < lgamma_stirling_diff_useful) {
94+
// y large, x small
95+
T_ret stirling_diff = lgamma_stirling_diff(y) - lgamma_stirling_diff(x + y);
96+
T_ret log_x_y = log(x + y);
97+
T_ret stirling = (y - 0.5) * log1p(-x / (x + y)) + x * (1 - log_x_y);
98+
return stirling + lgamma(x) + stirling_diff;
99+
} else {
100+
// both large
101+
T_ret stirling_diff = lgamma_stirling_diff(x) + lgamma_stirling_diff(y)
102+
- lgamma_stirling_diff(x + y);
103+
T_ret stirling = (x - 0.5) * log(x / (x + y)) + y * log1p(-x / (x + y))
104+
+ 0.5 * (LOG_TWO_PI - log(y));
105+
return stirling + stirling_diff;
106+
}
59107
}
60108

61109
} // namespace math
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#ifndef STAN_MATH_PRIM_FUN_LGAMMA_STIRLING_HPP
2+
#define STAN_MATH_PRIM_FUN_LGAMMA_STIRLING_HPP
3+
4+
#include <cmath>
5+
#include <stan/math/prim/fun/constants.hpp>
6+
#include <stan/math/prim/fun/lgamma.hpp>
7+
#include <stan/math/prim/meta/return_type.hpp>
8+
9+
namespace stan {
10+
namespace math {
11+
12+
/**
13+
* Return the Stirling approximation to the gamma function.
14+
*
15+
16+
\f[
17+
\mbox{lgamma_stirling}(x) =
18+
\frac{1}{2} \log(2\pi) + (x-\frac{1}{2})*\log(x) - x
19+
\f]
20+
21+
*
22+
* @param x value
23+
* @return Stirling's approximation to lgamma(x).
24+
* @tparam T Type of value.
25+
*/
26+
template <typename T>
27+
return_type_t<T> lgamma_stirling(const T x) {
28+
return 0.5 * LOG_TWO_PI + (x - 0.5) * log(x) - x;
29+
}
30+
31+
} // namespace math
32+
} // namespace stan
33+
34+
#endif
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#ifndef STAN_MATH_PRIM_FUN_LGAMMA_STIRLING_DIFF_HPP
2+
#define STAN_MATH_PRIM_FUN_LGAMMA_STIRLING_DIFF_HPP
3+
4+
#include <cmath>
5+
#include <stan/math/prim/fun/value_of.hpp>
6+
#include <stan/math/prim/fun/constants.hpp>
7+
#include <stan/math/prim/fun/lgamma.hpp>
8+
#include <stan/math/prim/fun/inv.hpp>
9+
#include <stan/math/prim/fun/square.hpp>
10+
#include <stan/math/prim/fun/lgamma_stirling.hpp>
11+
#include <stan/math/prim/err/check_nonnegative.hpp>
12+
#include <stan/math/prim/meta/return_type.hpp>
13+
14+
namespace stan {
15+
namespace math {
16+
17+
constexpr double lgamma_stirling_diff_useful = 10;
18+
19+
/**
20+
* Return the difference between log of the gamma function and it's Stirling
21+
* approximation.
22+
* This is useful to stably compute log of ratios of gamma functions with large
23+
* arguments where the Stirling approximation allows for analytic solution
24+
* and the (small) differences can be added afterwards.
25+
* This is for example used in the implementation of lbeta.
26+
*
27+
* The function will return correct value for all arguments, but the can add
28+
* precision only when x >= lgamma_stirling_diff_useful.
29+
*
30+
\f[
31+
\mbox{lgamma_stirling_diff}(x) =
32+
\log(\Gamma(x)) - \frac{1}{2} \log(2\pi) +
33+
(x-\frac{1}{2})*\log(x) - x
34+
\f]
35+
36+
*
37+
* @param x value
38+
* @return Difference between lgamma(x) and it's Stirling approximation.
39+
* @tparam T Type of value.
40+
*/
41+
42+
template <typename T>
43+
return_type_t<T> lgamma_stirling_diff(const T x) {
44+
if (is_nan(value_of_rec(x))) {
45+
return std::numeric_limits<double>::quiet_NaN();
46+
}
47+
typedef return_type_t<T> T_Ret;
48+
49+
static const char* function = "lgamma_stirling_diff";
50+
check_nonnegative(function, "argument", x);
51+
52+
if(x == 0) {
53+
return std::numeric_limits<double>::infinity();
54+
} else if (value_of(x) < lgamma_stirling_diff_useful) {
55+
return lgamma(x) - lgamma_stirling(x);
56+
} else {
57+
// Using the Stirling series as expressed in formula 5.11.1. at
58+
// https://dlmf.nist.gov/5.11
59+
constexpr double stirling_series[] = {
60+
0.0833333333333333333333333,
61+
-0.00277777777777777777777778,
62+
0.000793650793650793650793651,
63+
-0.000595238095238095238095238,
64+
};
65+
constexpr int n_stirling_terms = 3;
66+
T_Ret inv_x = inv(x);
67+
T_Ret inv_x_squared = square(inv_x);
68+
T_Ret inv_x_cubed = inv_x * inv_x_squared;
69+
T_Ret inv_x_fifth = inv_x_cubed * inv_x_squared;
70+
return stirling_series[0] * inv_x + stirling_series[1] * inv_x_cubed
71+
+ stirling_series[2] * inv_x_fifth;
72+
}
73+
}
74+
75+
} // namespace math
76+
} // namespace stan
77+
78+
#endif
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
#include <cmath>
2+
#include <limits>
3+
#include <vector>
4+
#include <gtest/gtest.h>
5+
#include <test/unit/math/expect_near_rel.hpp>
6+
#include <stan/math/prim.hpp>
7+
#include <stan/math/prim/fun/lgamma_stirling_diff.hpp>
8+
9+
TEST(MathFunctions, lgamma_stirling_diff_errors_special_cases) {
10+
using stan::math::lgamma_stirling_diff;
11+
12+
double nan = std::numeric_limits<double>::quiet_NaN();
13+
double inf = std::numeric_limits<double>::infinity();
14+
15+
EXPECT_TRUE(std::isnan(lgamma_stirling_diff(nan)));
16+
EXPECT_FLOAT_EQ(lgamma_stirling_diff(inf), 0);
17+
EXPECT_THROW(std::isnan(lgamma_stirling_diff(-1.0)), std::domain_error);
18+
EXPECT_TRUE(std::isinf(lgamma_stirling_diff(0.0)));
19+
EXPECT_TRUE(lgamma_stirling_diff(0) > 0);
20+
}
21+
22+
TEST(MathFunctions, lgamma_stirling_diff_accuracy) {
23+
using stan::math::lgamma_stirling_diff;
24+
using stan::math::lgamma_stirling_diff_useful;
25+
using stan::test::expect_near_rel;
26+
27+
double start = std::nextafter(10, 11);
28+
for (double x = start; x < 1e8; x *= 1.5) {
29+
long double x_l = static_cast<long double>(x);
30+
long double stirling
31+
= 0.5 * std::log(2 * static_cast<long double>(stan::math::pi()))
32+
+ (x_l - 0.5) * std::log(x_l) - x_l;
33+
long double lgamma_res = std::lgamma(x_l);
34+
double diff_actual = static_cast<double>(lgamma_res - stirling);
35+
double diff = lgamma_stirling_diff(x);
36+
37+
std::ostringstream msg;
38+
msg << "x = " << x << "; lgamma = " << lgamma_res
39+
<< "; stirling = " << stirling;
40+
expect_near_rel(msg.str(), diff, diff_actual, 1e-4);
41+
}
42+
43+
double before_big = std::nextafter(lgamma_stirling_diff_useful, 0);
44+
double after_big = std::nextafter(lgamma_stirling_diff_useful,
45+
stan::math::positive_infinity());
46+
expect_near_rel("big cutoff", lgamma_stirling_diff(before_big),
47+
lgamma_stirling_diff(after_big));
48+
}
49+
50+
namespace lgamma_stirling_diff_test_internal {
51+
struct TestValue {
52+
double x;
53+
double val;
54+
};
55+
56+
std::vector<TestValue> testValues = {
57+
{1.049787068367863943, 0.077388806767834476832},
58+
{1.1353352832366126919, 0.071790358566585005482},
59+
{1.3678794411714423216, 0.059960812482712981438},
60+
{2., 0.041340695955409294094},
61+
{3.7182818284590452354, 0.022358812123082674471},
62+
{8.3890560989306502272, 0.0099288907523535997267},
63+
{21.085536923187667741, 0.0039518599801395734578},
64+
{55.598150033144239078, 0.0014988346688724404687},
65+
{149.41315910257660342, 0.0005577367442531155476},
66+
{404.42879349273512261, 0.00020605188772717995062},
67+
{1097.6331584284585993, 0.000075920930766205666598},
68+
{2981.9579870417282747, 0.00002794584410078046085},
69+
{8104.0839275753840077, 0.000010282881326966996581},
70+
{22027.465794806716517, 3.7831557249429676373e-6},
71+
{59875.141715197818455, 1.3917851539949910276e-6},
72+
{162755.79141900392081, 5.1201455018391551878e-7},
73+
{442414.39200892050333, 1.8836035815859912686e-7},
74+
{1.2026052841647767777e6, 6.9294002305342748064e-8},
75+
{3.2690183724721106393e6, 2.5491852243801980915e-8},
76+
{8.8861115205078726368e6, 9.3779301712579119232e-9},
77+
{2.4154953753575298215e7, 3.4499479561619419703e-9},
78+
{6.5659970137330511139e7, 1.2691649593966957356e-9},
79+
{1.7848230196318726084e8, 4.6689970051216164913e-10},
80+
{4.8516519640979027797e8, 1.7176280151585029843e-10},
81+
{1.3188157354832146972e9, 6.3188003518019870566e-11},
82+
{3.5849128471315915617e9, 2.3245567434090096532e-11},
83+
{9.7448034472489026e9, 8.5515663588740238069e-12},
84+
{2.6489122130843472294e10, 3.1459454534471511195e-12},
85+
{7.2004899338385872524e10, 1.1573286553975954675e-12},
86+
{1.9572960942983876427e11, 4.2575741900310182743e-13},
87+
{5.3204824060279861668e11, 1.5662740137796255552e-13},
88+
{1.4462570642924751737e12, 5.7620000891128517638e-14},
89+
};
90+
91+
} // namespace lgamma_stirling_diff_test_internal
92+
93+
TEST(MathFunctions, lgamma_stirling_diff_precomputed) {
94+
using stan::math::lgamma_stirling_diff;
95+
using stan::test::expect_near_rel;
96+
using lgamma_stirling_diff_test_internal::TestValue;
97+
using lgamma_stirling_diff_test_internal::testValues;
98+
99+
for (TestValue t : testValues) {
100+
std::ostringstream msg;
101+
msg << "x = " << t.x;
102+
expect_near_rel(msg.str(), lgamma_stirling_diff(t.x), t.val, 1e-10);
103+
}
104+
}

0 commit comments

Comments
 (0)