2
2
#define STAN_MATH_PRIM_FUN_BINOMIAL_COEFFICIENT_LOG_HPP
3
3
4
4
#include < stan/math/prim/meta.hpp>
5
- #include < stan/math/prim/fun/inv.hpp>
5
+ #include < stan/math/prim/err.hpp>
6
+ #include < stan/math/prim/fun/constants.hpp>
7
+ #include < stan/math/prim/fun/digamma.hpp>
8
+ #include < stan/math/prim/fun/is_any_nan.hpp>
9
+ #include < stan/math/prim/fun/log1p.hpp>
10
+ #include < stan/math/prim/fun/lbeta.hpp>
6
11
#include < stan/math/prim/fun/lgamma.hpp>
7
- #include < stan/math/prim/fun/multiply_log .hpp>
12
+ #include < stan/math/prim/fun/value_of .hpp>
8
13
9
14
namespace stan {
10
15
namespace math {
@@ -13,22 +18,24 @@ namespace math {
13
18
* Return the log of the binomial coefficient for the specified
14
19
* arguments.
15
20
*
16
- * The binomial coefficient, \f${N \choose n }\f$, read "N choose n ", is
17
- * defined for \f$0 \leq n \leq N \f$ by
21
+ * The binomial coefficient, \f${n \choose k }\f$, read "n choose k ", is
22
+ * defined for \f$0 \leq k \leq n \f$ by
18
23
*
19
- * \f${N \choose n } = \frac{N !}{n ! (N-n )!}\f$.
24
+ * \f${n \choose k } = \frac{n !}{k ! (n-k )!}\f$.
20
25
*
21
26
* This function uses Gamma functions to define the log
22
- * and generalize the arguments to continuous N and n.
27
+ * and generalize the arguments to continuous n and k.
28
+ *
29
+ * \f$ \log {n \choose k}
30
+ * = \log \ \Gamma(n+1) - \log \Gamma(k+1) - \log \Gamma(n-k+1)\f$.
23
31
*
24
- * \f$ \log {N \choose n}
25
- * = \log \ \Gamma(N+1) - \log \Gamma(n+1) - \log \Gamma(N-n+1)\f$.
26
32
*
27
33
\f[
28
34
\mbox{binomial\_coefficient\_log}(x, y) =
29
35
\begin{cases}
30
- \textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\
31
- \ln\Gamma(x+1) & \mbox{if } 0\leq y \leq x \\
36
+ \textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x
37
+ < -1\\
38
+ \ln\Gamma(x+1) & \mbox{if } -1 < y < x + 1 \\
32
39
\quad -\ln\Gamma(y+1)& \\
33
40
\quad -\ln\Gamma(x-y+1)& \\[6pt]
34
41
\textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
@@ -38,7 +45,8 @@ namespace math {
38
45
\f[
39
46
\frac{\partial\, \mbox{binomial\_coefficient\_log}(x, y)}{\partial x} =
40
47
\begin{cases}
41
- \textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\
48
+ \textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x
49
+ < -1\\
42
50
\Psi(x+1) & \mbox{if } 0\leq y \leq x \\
43
51
\quad -\Psi(x-y+1)& \\[6pt]
44
52
\textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
@@ -48,32 +56,95 @@ namespace math {
48
56
\f[
49
57
\frac{\partial\, \mbox{binomial\_coefficient\_log}(x, y)}{\partial y} =
50
58
\begin{cases}
51
- \textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\
59
+ \textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x
60
+ < -1\\
52
61
-\Psi(y+1) & \mbox{if } 0\leq y \leq x \\
53
62
\quad +\Psi(x-y+1)& \\[6pt]
54
63
\textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
55
64
\end{cases}
56
65
\f]
57
66
*
58
- * @tparam T_N type of the first argument
59
- * @tparam T_n type of the second argument
60
- * @param N total number of objects.
61
- * @param n number of objects chosen.
62
- * @return log (N choose n).
67
+ * This function is numerically more stable than naive evaluation via lgamma.
68
+ *
69
+ * @tparam T_n type of the first argument
70
+ * @tparam T_k type of the second argument
71
+ *
72
+ * @param n total number of objects.
73
+ * @param k number of objects chosen.
74
+ * @return log (n choose k).
63
75
*/
64
- template <typename T_N, typename T_n>
65
- inline return_type_t <T_N, T_n> binomial_coefficient_log (const T_N N,
66
- const T_n n) {
67
- const double CUTOFF = 1000 ;
68
- if (N - n < CUTOFF) {
69
- const T_N N_plus_1 = N + 1 ;
70
- return lgamma (N_plus_1) - lgamma (n + 1 ) - lgamma (N_plus_1 - n);
76
+
77
+ template <typename T_n, typename T_k>
78
+ inline return_type_t <T_n, T_k> binomial_coefficient_log (const T_n n,
79
+ const T_k k) {
80
+ using T_partials_return = partials_return_t <T_n, T_k>;
81
+
82
+ if (is_any_nan (n, k)) {
83
+ return NOT_A_NUMBER;
84
+ }
85
+
86
+ // Choosing the more stable of the symmetric branches
87
+ if (n > -1 && k > value_of_rec (n) / 2.0 + 1e-8 ) {
88
+ return binomial_coefficient_log (n, n - k);
89
+ }
90
+
91
+ const T_partials_return n_dbl = value_of (n);
92
+ const T_partials_return k_dbl = value_of (k);
93
+ const T_partials_return n_plus_1 = n_dbl + 1 ;
94
+ const T_partials_return n_plus_1_mk = n_plus_1 - k_dbl;
95
+
96
+ static const char * function = " binomial_coefficient_log" ;
97
+ check_greater_or_equal (function, " first argument" , n, -1 );
98
+ check_greater_or_equal (function, " second argument" , k, -1 );
99
+ check_greater_or_equal (function, " (first argument - second argument + 1)" ,
100
+ n_plus_1_mk, 0.0 );
101
+
102
+ operands_and_partials<T_n, T_k> ops_partials (n, k);
103
+
104
+ T_partials_return value;
105
+ if (k_dbl == 0 ) {
106
+ value = 0 ;
107
+ } else if (n_plus_1 < lgamma_stirling_diff_useful) {
108
+ value = lgamma (n_plus_1) - lgamma (k_dbl + 1 ) - lgamma (n_plus_1_mk);
71
109
} else {
72
- return_type_t <T_N, T_n> N_minus_n = N - n;
73
- const double one_twelfth = inv (12 );
74
- return multiply_log (n, N_minus_n) + multiply_log ((N + 0.5 ), N / N_minus_n)
75
- + one_twelfth / N - n - one_twelfth / N_minus_n - lgamma (n + 1 );
110
+ value = -lbeta (n_plus_1_mk, k_dbl + 1 ) - log1p (n_dbl);
76
111
}
112
+
113
+ if (!is_constant_all<T_n, T_k>::value) {
114
+ // Branching on all the edge cases.
115
+ // In direct computation many of those would be NaN
116
+ // But one-sided limits from within the domain exist, all of the below
117
+ // follows from lim x->0 from above digamma(x) == -Inf
118
+ //
119
+ // Note that we have k < n / 2 (see the first branch in this function)
120
+ // se we can ignore the n == k - 1 edge case.
121
+ T_partials_return digamma_n_plus_1_mk = digamma (n_plus_1_mk);
122
+
123
+ if (!is_constant_all<T_n>::value) {
124
+ if (n_dbl == -1.0 ) {
125
+ if (k_dbl == 0 ) {
126
+ ops_partials.edge1_ .partials_ [0 ] = 0 ;
127
+ } else {
128
+ ops_partials.edge1_ .partials_ [0 ] = NEGATIVE_INFTY;
129
+ }
130
+ } else {
131
+ ops_partials.edge1_ .partials_ [0 ]
132
+ = (digamma (n_plus_1) - digamma_n_plus_1_mk);
133
+ }
134
+ }
135
+ if (!is_constant_all<T_k>::value) {
136
+ if (k_dbl == 0 && n_dbl == -1.0 ) {
137
+ ops_partials.edge2_ .partials_ [0 ] = NEGATIVE_INFTY;
138
+ } else if (k_dbl == -1 ) {
139
+ ops_partials.edge2_ .partials_ [0 ] = INFTY;
140
+ } else {
141
+ ops_partials.edge2_ .partials_ [0 ]
142
+ = (digamma_n_plus_1_mk - digamma (k_dbl + 1 ));
143
+ }
144
+ }
145
+ }
146
+
147
+ return ops_partials.build (value);
77
148
}
78
149
79
150
} // namespace math
0 commit comments