Skip to content

Commit fe3a41c

Browse files
authored
Merge pull request #1614 from martinmodrak/bugfix/1592-binomial_coefficient_log
Improved numerical stability of binomial_coefficient_log
2 parents f2a3c1a + 6aed316 commit fe3a41c

File tree

5 files changed

+582
-31
lines changed

5 files changed

+582
-31
lines changed

stan/math/prim/fun/binomial_coefficient_log.hpp

+99-28
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,14 @@
22
#define STAN_MATH_PRIM_FUN_BINOMIAL_COEFFICIENT_LOG_HPP
33

44
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/fun/inv.hpp>
5+
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/constants.hpp>
7+
#include <stan/math/prim/fun/digamma.hpp>
8+
#include <stan/math/prim/fun/is_any_nan.hpp>
9+
#include <stan/math/prim/fun/log1p.hpp>
10+
#include <stan/math/prim/fun/lbeta.hpp>
611
#include <stan/math/prim/fun/lgamma.hpp>
7-
#include <stan/math/prim/fun/multiply_log.hpp>
12+
#include <stan/math/prim/fun/value_of.hpp>
813

914
namespace stan {
1015
namespace math {
@@ -13,22 +18,24 @@ namespace math {
1318
* Return the log of the binomial coefficient for the specified
1419
* arguments.
1520
*
16-
* The binomial coefficient, \f${N \choose n}\f$, read "N choose n", is
17-
* defined for \f$0 \leq n \leq N\f$ by
21+
* The binomial coefficient, \f${n \choose k}\f$, read "n choose k", is
22+
* defined for \f$0 \leq k \leq n\f$ by
1823
*
19-
* \f${N \choose n} = \frac{N!}{n! (N-n)!}\f$.
24+
* \f${n \choose k} = \frac{n!}{k! (n-k)!}\f$.
2025
*
2126
* This function uses Gamma functions to define the log
22-
* and generalize the arguments to continuous N and n.
27+
* and generalize the arguments to continuous n and k.
28+
*
29+
* \f$ \log {n \choose k}
30+
* = \log \ \Gamma(n+1) - \log \Gamma(k+1) - \log \Gamma(n-k+1)\f$.
2331
*
24-
* \f$ \log {N \choose n}
25-
* = \log \ \Gamma(N+1) - \log \Gamma(n+1) - \log \Gamma(N-n+1)\f$.
2632
*
2733
\f[
2834
\mbox{binomial\_coefficient\_log}(x, y) =
2935
\begin{cases}
30-
\textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\
31-
\ln\Gamma(x+1) & \mbox{if } 0\leq y \leq x \\
36+
\textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x
37+
< -1\\
38+
\ln\Gamma(x+1) & \mbox{if } -1 < y < x + 1 \\
3239
\quad -\ln\Gamma(y+1)& \\
3340
\quad -\ln\Gamma(x-y+1)& \\[6pt]
3441
\textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
@@ -38,7 +45,8 @@ namespace math {
3845
\f[
3946
\frac{\partial\, \mbox{binomial\_coefficient\_log}(x, y)}{\partial x} =
4047
\begin{cases}
41-
\textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\
48+
\textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x
49+
< -1\\
4250
\Psi(x+1) & \mbox{if } -1 < y < x + 1 \\
4351
\quad -\Psi(x-y+1)& \\[6pt]
4452
\textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
@@ -48,32 +56,95 @@ namespace math {
4856
\f[
4957
\frac{\partial\, \mbox{binomial\_coefficient\_log}(x, y)}{\partial y} =
5058
\begin{cases}
51-
\textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\
59+
\textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x
60+
< -1\\
5261
-\Psi(y+1) & \mbox{if } -1 < y < x + 1 \\
5362
\quad +\Psi(x-y+1)& \\[6pt]
5463
\textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
5564
\end{cases}
5665
\f]
5766
*
58-
* @tparam T_N type of the first argument
59-
* @tparam T_n type of the second argument
60-
* @param N total number of objects.
61-
* @param n number of objects chosen.
62-
* @return log (N choose n).
67+
* This function is numerically more stable than naive evaluation via lgamma.
68+
*
69+
* @tparam T_n type of the first argument
70+
* @tparam T_k type of the second argument
71+
*
72+
* @param n total number of objects.
73+
* @param k number of objects chosen.
74+
* @return log (n choose k).
6375
*/
64-
template <typename T_N, typename T_n>
65-
inline return_type_t<T_N, T_n> binomial_coefficient_log(const T_N N,
66-
const T_n n) {
67-
const double CUTOFF = 1000;
68-
if (N - n < CUTOFF) {
69-
const T_N N_plus_1 = N + 1;
70-
return lgamma(N_plus_1) - lgamma(n + 1) - lgamma(N_plus_1 - n);
76+
77+
template <typename T_n, typename T_k>
78+
inline return_type_t<T_n, T_k> binomial_coefficient_log(const T_n n,
79+
const T_k k) {
80+
using T_partials_return = partials_return_t<T_n, T_k>;
81+
82+
if (is_any_nan(n, k)) {
83+
return NOT_A_NUMBER;
84+
}
85+
86+
// Choosing the more stable of the symmetric branches
87+
if (n > -1 && k > value_of_rec(n) / 2.0 + 1e-8) {
88+
return binomial_coefficient_log(n, n - k);
89+
}
90+
91+
const T_partials_return n_dbl = value_of(n);
92+
const T_partials_return k_dbl = value_of(k);
93+
const T_partials_return n_plus_1 = n_dbl + 1;
94+
const T_partials_return n_plus_1_mk = n_plus_1 - k_dbl;
95+
96+
static const char* function = "binomial_coefficient_log";
97+
check_greater_or_equal(function, "first argument", n, -1);
98+
check_greater_or_equal(function, "second argument", k, -1);
99+
check_greater_or_equal(function, "(first argument - second argument + 1)",
100+
n_plus_1_mk, 0.0);
101+
102+
operands_and_partials<T_n, T_k> ops_partials(n, k);
103+
104+
T_partials_return value;
105+
if (k_dbl == 0) {
106+
value = 0;
107+
} else if (n_plus_1 < lgamma_stirling_diff_useful) {
108+
value = lgamma(n_plus_1) - lgamma(k_dbl + 1) - lgamma(n_plus_1_mk);
71109
} else {
72-
return_type_t<T_N, T_n> N_minus_n = N - n;
73-
const double one_twelfth = inv(12);
74-
return multiply_log(n, N_minus_n) + multiply_log((N + 0.5), N / N_minus_n)
75-
+ one_twelfth / N - n - one_twelfth / N_minus_n - lgamma(n + 1);
110+
value = -lbeta(n_plus_1_mk, k_dbl + 1) - log1p(n_dbl);
76111
}
112+
113+
if (!is_constant_all<T_n, T_k>::value) {
114+
// Branching on all the edge cases.
115+
// In direct computation many of those would be NaN
116+
// But one-sided limits from within the domain exist, all of the below
117+
// follows from lim x->0 from above digamma(x) == -Inf
118+
//
119+
// Note that we have k < n / 2 (see the first branch in this function)
120+
// so we can ignore the n == k - 1 edge case.
121+
T_partials_return digamma_n_plus_1_mk = digamma(n_plus_1_mk);
122+
123+
if (!is_constant_all<T_n>::value) {
124+
if (n_dbl == -1.0) {
125+
if (k_dbl == 0) {
126+
ops_partials.edge1_.partials_[0] = 0;
127+
} else {
128+
ops_partials.edge1_.partials_[0] = NEGATIVE_INFTY;
129+
}
130+
} else {
131+
ops_partials.edge1_.partials_[0]
132+
= (digamma(n_plus_1) - digamma_n_plus_1_mk);
133+
}
134+
}
135+
if (!is_constant_all<T_k>::value) {
136+
if (k_dbl == 0 && n_dbl == -1.0) {
137+
ops_partials.edge2_.partials_[0] = NEGATIVE_INFTY;
138+
} else if (k_dbl == -1) {
139+
ops_partials.edge2_.partials_[0] = INFTY;
140+
} else {
141+
ops_partials.edge2_.partials_[0]
142+
= (digamma_n_plus_1_mk - digamma(k_dbl + 1));
143+
}
144+
}
145+
}
146+
147+
return ops_partials.build(value);
77148
}
78149

79150
} // namespace math

test/unit/math/mix/fun/binomial_coefficient_log_test.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,9 @@ TEST(mathMixScalFun, binomialCoefficientLog) {
66
};
77
stan::test::expect_ad(f, 3, 2);
88
stan::test::expect_ad(f, 24.0, 12.0);
9+
stan::test::expect_ad(f, 1.0, 0.0);
10+
stan::test::expect_ad(f, 0.0, 1.0);
11+
stan::test::expect_ad(f, -0.3, 0.5);
12+
913
stan::test::expect_common_nonzero_binary(f);
1014
}

test/unit/math/prim/fun/binomial_coefficient_log_test.cpp

+29-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
#include <stan/math/prim.hpp>
2+
#include <test/unit/math/expect_near_rel.hpp>
23
#include <gtest/gtest.h>
34
#include <cmath>
4-
#include <limits>
55

66
template <typename T_N, typename T_n>
77
void test_binom_coefficient(const T_N& N, const T_n& n) {
88
using stan::math::binomial_coefficient_log;
99
EXPECT_FLOAT_EQ(lgamma(N + 1) - lgamma(n + 1) - lgamma(N - n + 1),
10-
binomial_coefficient_log(N, n));
10+
binomial_coefficient_log(N, n))
11+
<< "N = " << N << ", n = " << n;
1112
}
1213

1314
TEST(MathFunctions, binomial_coefficient_log) {
@@ -19,6 +20,13 @@ TEST(MathFunctions, binomial_coefficient_log) {
1920

2021
EXPECT_FLOAT_EQ(29979.16, binomial_coefficient_log(100000, 91116));
2122

23+
EXPECT_EQ(binomial_coefficient_log(-1, 0), 0); // Needed for neg_binomial_2
24+
EXPECT_EQ(binomial_coefficient_log(50, 0), 0);
25+
EXPECT_EQ(binomial_coefficient_log(10000, 0), 0);
26+
27+
EXPECT_EQ(binomial_coefficient_log(10, 11), stan::math::NEGATIVE_INFTY);
28+
EXPECT_EQ(binomial_coefficient_log(10, -1), stan::math::NEGATIVE_INFTY);
29+
2230
for (int n = 0; n < 1010; ++n) {
2331
test_binom_coefficient(1010, n);
2432
test_binom_coefficient(1010.0, n);
@@ -32,9 +40,27 @@ TEST(MathFunctions, binomial_coefficient_log) {
3240
}
3341

3442
TEST(MathFunctions, binomial_coefficient_log_nan) {
35-
double nan = std::numeric_limits<double>::quiet_NaN();
43+
double nan = stan::math::NOT_A_NUMBER;
3644

3745
EXPECT_TRUE(std::isnan(stan::math::binomial_coefficient_log(2.0, nan)));
3846
EXPECT_TRUE(std::isnan(stan::math::binomial_coefficient_log(nan, 2.0)));
3947
EXPECT_TRUE(std::isnan(stan::math::binomial_coefficient_log(nan, nan)));
4048
}
49+
50+
TEST(MathFunctions, binomial_coefficient_log_errors_edge_cases) {
51+
using stan::math::INFTY;
52+
using stan::math::binomial_coefficient_log;
53+
54+
EXPECT_NO_THROW(binomial_coefficient_log(10, 11));
55+
EXPECT_THROW(binomial_coefficient_log(10, 11.01), std::domain_error);
56+
EXPECT_THROW(binomial_coefficient_log(10, -1.1), std::domain_error);
57+
EXPECT_THROW(binomial_coefficient_log(-1, 0.3), std::domain_error);
58+
EXPECT_NO_THROW(binomial_coefficient_log(-0.5, 0.49));
59+
EXPECT_NO_THROW(binomial_coefficient_log(10, -0.9));
60+
61+
EXPECT_FLOAT_EQ(binomial_coefficient_log(0, -1), -INFTY);
62+
EXPECT_FLOAT_EQ(binomial_coefficient_log(-1, 0), 0);
63+
EXPECT_FLOAT_EQ(binomial_coefficient_log(-1, -0.3), INFTY);
64+
EXPECT_FLOAT_EQ(binomial_coefficient_log(0.3, -1), -INFTY);
65+
EXPECT_FLOAT_EQ(binomial_coefficient_log(5.0, 6.0), -INFTY);
66+
}

0 commit comments

Comments
 (0)