1
- #ifndef STAN_MATH_PRIM_FUN_BINOMIAL_COEFFICIENT_LOG_HPP
2
- #define STAN_MATH_PRIM_FUN_BINOMIAL_COEFFICIENT_LOG_HPP
1
+ #ifndef STAN_MATH_PRIM_SCAL_FUN_BINOMIAL_COEFFICIENT_LOG_HPP
2
+ #define STAN_MATH_PRIM_SCAL_FUN_BINOMIAL_COEFFICIENT_LOG_HPP
3
3
4
+ #include < limits>
4
5
#include < stan/math/prim/meta.hpp>
5
- #include < stan/math/prim/fun/inv.hpp>
6
6
#include < stan/math/prim/fun/lgamma.hpp>
7
- #include < stan/math/prim/fun/multiply_log.hpp>
7
+ #include < stan/math/prim/fun/lbeta.hpp>
8
+ #include < stan/math/prim/fun/constants.hpp>
9
+ #include < stan/math/prim/err/check_nonnegative.hpp>
10
+ #include < stan/math/prim/err/check_greater_or_equal.hpp>
8
11
9
12
namespace stan {
10
13
namespace math {
14
+
11
15
/* *
12
16
* Return the log of the binomial coefficient for the specified
13
17
* arguments.
@@ -23,11 +27,15 @@ namespace math {
23
27
* \f$ \log {N \choose n}
24
28
* = \log \ \Gamma(N+1) - \log \Gamma(n+1) - \log \Gamma(N-n+1)\f$.
25
29
*
30
+ *
31
+ * TODO[martinmodrak] figure out the cases for x < 0 and for partials
26
32
\f[
27
33
\mbox{binomial\_coefficient\_log}(x, y) =
28
34
\begin{cases}
29
- \textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\
30
- \ln\Gamma(x+1) & \mbox{if } 0\leq y \leq x \\
35
+ \textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x
36
+ < -1\\
37
+ \textrm{-\infty} & \mbox{if } y = x + 1 \textrm{ or } y = -1\\
38
+ \ln\Gamma(x+1) & \mbox{if } -1 < y < x + 1 \\
31
39
\quad -\ln\Gamma(y+1)& \\
32
40
\quad -\ln\Gamma(x-y+1)& \\[6pt]
33
41
\textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
@@ -54,22 +62,39 @@ namespace math {
54
62
\end{cases}
55
63
\f]
56
64
*
65
+ * This function is numerically more stable than naive evaluation via lgamma
66
+ *
57
67
* @param N total number of objects.
58
68
* @param n number of objects chosen.
59
69
* @return log (N choose n).
60
70
*/
71
+
61
72
template <typename T_N, typename T_n>
62
73
inline return_type_t <T_N, T_n> binomial_coefficient_log (const T_N N,
63
74
const T_n n) {
64
- const double CUTOFF = 1000 ;
65
- if (N - n < CUTOFF) {
66
- const T_N N_plus_1 = N + 1 ;
75
+ if (is_nan (value_of_rec (N)) || is_nan (value_of_rec (n))) {
76
+ return std::numeric_limits<double >::quiet_NaN ();
77
+ }
78
+
79
+ // For some uses it is important this works even when N < 0 and therefore
80
+ // it is before checks
81
+ if (n == 0 ) {
82
+ return 0 ;
83
+ }
84
+ const T_N N_plus_1 = N + 1 ;
85
+
86
+ static const char * function = " binomial_coefficient_log" ;
87
+ check_greater_or_equal (function, " first argument" , N, -1 );
88
+ check_greater_or_equal (function, " second argument" , n, -1 );
89
+ check_greater_or_equal (function, " (first argument - second argument + 1)" ,
90
+ N - n + 1 , 0.0 );
91
+
92
+ if (N / 2 < n) {
93
+ return binomial_coefficient_log (N, N - n);
94
+ } else if (N_plus_1 < lgamma_stirling_diff_useful) {
67
95
return lgamma (N_plus_1) - lgamma (n + 1 ) - lgamma (N_plus_1 - n);
68
96
} else {
69
- return_type_t <T_N, T_n> N_minus_n = N - n;
70
- const double one_twelfth = inv (12 );
71
- return multiply_log (n, N_minus_n) + multiply_log ((N + 0.5 ), N / N_minus_n)
72
- + one_twelfth / N - n - one_twelfth / N_minus_n - lgamma (n + 1 );
97
+ return -lbeta (N - n + 1 , n + 1 ) - log (N_plus_1);
73
98
}
74
99
}
75
100
0 commit comments