Skip to content

Commit 7991f40

Browse files
committed
Fixes #1592
1 parent 8a8ea67 commit 7991f40

File tree

1 file changed

+38
-13
lines changed

1 file changed

+38
-13
lines changed

stan/math/prim/fun/binomial_coefficient_log.hpp

+38-13
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1-
#ifndef STAN_MATH_PRIM_FUN_BINOMIAL_COEFFICIENT_LOG_HPP
2-
#define STAN_MATH_PRIM_FUN_BINOMIAL_COEFFICIENT_LOG_HPP
1+
#ifndef STAN_MATH_PRIM_SCAL_FUN_BINOMIAL_COEFFICIENT_LOG_HPP
2+
#define STAN_MATH_PRIM_SCAL_FUN_BINOMIAL_COEFFICIENT_LOG_HPP
33

4+
#include <limits>
45
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/fun/inv.hpp>
66
#include <stan/math/prim/fun/lgamma.hpp>
7-
#include <stan/math/prim/fun/multiply_log.hpp>
7+
#include <stan/math/prim/fun/lbeta.hpp>
8+
#include <stan/math/prim/fun/constants.hpp>
9+
#include <stan/math/prim/err/check_nonnegative.hpp>
10+
#include <stan/math/prim/err/check_greater_or_equal.hpp>
811

912
namespace stan {
1013
namespace math {
14+
1115
/**
1216
* Return the log of the binomial coefficient for the specified
1317
* arguments.
@@ -23,11 +27,15 @@ namespace math {
2327
* \f$ \log {N \choose n}
2428
* = \log \ \Gamma(N+1) - \log \Gamma(n+1) - \log \Gamma(N-n+1)\f$.
2529
*
30+
*
31+
* TODO[martinmodrak] figure out the cases for x < 0 and for partials
2632
\f[
2733
\mbox{binomial\_coefficient\_log}(x, y) =
2834
\begin{cases}
29-
\textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\
30-
\ln\Gamma(x+1) & \mbox{if } 0\leq y \leq x \\
35+
\textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x
36+
< -1\\
37+
\textrm{-\infty} & \mbox{if } y = x + 1 \textrm{ or } y = -1\\
38+
\ln\Gamma(x+1) & \mbox{if } -1 < y < x + 1 \\
3139
\quad -\ln\Gamma(y+1)& \\
3240
\quad -\ln\Gamma(x-y+1)& \\[6pt]
3341
\textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
@@ -54,22 +62,39 @@ namespace math {
5462
\end{cases}
5563
\f]
5664
*
65+
* This function is numerically more stable than naive evaluation via lgamma
66+
*
5767
* @param N total number of objects.
5868
* @param n number of objects chosen.
5969
* @return log (N choose n).
6070
*/
71+
6172
template <typename T_N, typename T_n>
6273
inline return_type_t<T_N, T_n> binomial_coefficient_log(const T_N N,
6374
const T_n n) {
64-
const double CUTOFF = 1000;
65-
if (N - n < CUTOFF) {
66-
const T_N N_plus_1 = N + 1;
75+
if (is_nan(value_of_rec(N)) || is_nan(value_of_rec(n))) {
76+
return std::numeric_limits<double>::quiet_NaN();
77+
}
78+
79+
// For some uses it is important this works even when N < 0 and therefore
80+
// it is before checks
81+
if (n == 0) {
82+
return 0;
83+
}
84+
const T_N N_plus_1 = N + 1;
85+
86+
static const char* function = "binomial_coefficient_log";
87+
check_greater_or_equal(function, "first argument", N, -1);
88+
check_greater_or_equal(function, "second argument", n, -1);
89+
check_greater_or_equal(function, "(first argument - second argument + 1)",
90+
N - n + 1, 0.0);
91+
92+
if (N / 2 < n) {
93+
return binomial_coefficient_log(N, N - n);
94+
} else if (N_plus_1 < lgamma_stirling_diff_useful) {
6795
return lgamma(N_plus_1) - lgamma(n + 1) - lgamma(N_plus_1 - n);
6896
} else {
69-
return_type_t<T_N, T_n> N_minus_n = N - n;
70-
const double one_twelfth = inv(12);
71-
return multiply_log(n, N_minus_n) + multiply_log((N + 0.5), N / N_minus_n)
72-
+ one_twelfth / N - n - one_twelfth / N_minus_n - lgamma(n + 1);
97+
return -lbeta(N - n + 1, n + 1) - log(N_plus_1);
7398
}
7499
}
75100

0 commit comments

Comments
 (0)