Skip to content

Commit bb3a5ca

Browse files
authored
Merge pull request #1728 from stan-dev/cleanup/912-beta-binomial-cdfs
Keep computations in log space and simplify expressions in beta binomial_*cdf
2 parents 9fa4e13 + 8544549 commit bb3a5ca

File tree

3 files changed

+44
-76
lines changed

3 files changed

+44
-76
lines changed

stan/math/prim/prob/beta_binomial_cdf.hpp

+15-24
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include <stan/math/prim/fun/exp.hpp>
1010
#include <stan/math/prim/fun/F32.hpp>
1111
#include <stan/math/prim/fun/grad_F32.hpp>
12-
#include <stan/math/prim/fun/lgamma.hpp>
12+
#include <stan/math/prim/fun/lbeta.hpp>
1313
#include <stan/math/prim/fun/max_size.hpp>
1414
#include <stan/math/prim/fun/size.hpp>
1515
#include <stan/math/prim/fun/size_zero.hpp>
@@ -87,46 +87,37 @@ return_type_t<T_size1, T_size2> beta_binomial_cdf(const T_n& n, const T_N& N,
8787
const T_partials_return N_dbl = value_of(N_vec[i]);
8888
const T_partials_return alpha_dbl = value_of(alpha_vec[i]);
8989
const T_partials_return beta_dbl = value_of(beta_vec[i]);
90-
90+
const T_partials_return N_minus_n = N_dbl - n_dbl;
9191
const T_partials_return mu = alpha_dbl + n_dbl + 1;
92-
const T_partials_return nu = beta_dbl + N_dbl - n_dbl - 1;
92+
const T_partials_return nu = beta_dbl + N_minus_n - 1;
93+
const T_partials_return one = 1;
9394

9495
const T_partials_return F
95-
= F32((T_partials_return)1, mu, -N_dbl + n_dbl + 1, n_dbl + 2, 1 - nu,
96-
(T_partials_return)1);
97-
98-
T_partials_return C = lgamma(nu) - lgamma(N_dbl - n_dbl);
99-
C += lgamma(mu) - lgamma(n_dbl + 2);
100-
C += lgamma(N_dbl + 2) - lgamma(N_dbl + alpha_dbl + beta_dbl);
101-
C = exp(C);
96+
= F32(one, mu, 1 - N_minus_n, n_dbl + 2, 1 - nu, one);
10297

103-
C *= F / stan::math::beta(alpha_dbl, beta_dbl);
104-
C /= N_dbl + 1;
98+
T_partials_return C = lbeta(nu, mu) - lbeta(alpha_dbl, beta_dbl)
99+
- lbeta(N_minus_n, n_dbl + 2);
100+
C = F * exp(C) / (N_dbl + 1);
105101

106102
const T_partials_return Pi = 1 - C;
107103

108104
P *= Pi;
109105

110106
T_partials_return dF[6];
111-
T_partials_return digammaOne = 0;
112-
T_partials_return digammaTwo = 0;
107+
T_partials_return digammaDiff = 0;
113108

114109
if (!is_constant_all<T_size1, T_size2>::value) {
115-
digammaOne = digamma(mu + nu);
116-
digammaTwo = digamma(alpha_dbl + beta_dbl);
117-
grad_F32(dF, (T_partials_return)1, mu, -N_dbl + n_dbl + 1, n_dbl + 2,
118-
1 - nu, (T_partials_return)1);
110+
digammaDiff = digamma(mu + nu) - digamma(alpha_dbl + beta_dbl);
111+
grad_F32(dF, one, mu, 1 - N_minus_n, n_dbl + 2, 1 - nu, one);
119112
}
120113
if (!is_constant_all<T_size1>::value) {
121-
const T_partials_return g = -C
122-
* (digamma(mu) - digammaOne + dF[1] / F
123-
- digamma(alpha_dbl) + digammaTwo);
114+
const T_partials_return g
115+
= -C * (digamma(mu) - digamma(alpha_dbl) - digammaDiff + dF[1] / F);
124116
ops_partials.edge1_.partials_[i] += g / Pi;
125117
}
126118
if (!is_constant_all<T_size2>::value) {
127-
const T_partials_return g = -C
128-
* (digamma(nu) - digammaOne - dF[4] / F
129-
- digamma(beta_dbl) + digammaTwo);
119+
const T_partials_return g
120+
= -C * (digamma(nu) - digamma(beta_dbl) - digammaDiff - dF[4] / F);
130121
ops_partials.edge2_.partials_[i] += g / Pi;
131122
}
132123
}

stan/math/prim/prob/beta_binomial_lccdf.hpp

+13-26
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include <stan/math/prim/fun/exp.hpp>
1010
#include <stan/math/prim/fun/F32.hpp>
1111
#include <stan/math/prim/fun/grad_F32.hpp>
12-
#include <stan/math/prim/fun/lgamma.hpp>
12+
#include <stan/math/prim/fun/lbeta.hpp>
1313
#include <stan/math/prim/fun/log.hpp>
1414
#include <stan/math/prim/fun/max_size.hpp>
1515
#include <stan/math/prim/fun/size.hpp>
@@ -89,47 +89,34 @@ return_type_t<T_size1, T_size2> beta_binomial_lccdf(const T_n& n, const T_N& N,
8989
const T_partials_return N_dbl = value_of(N_vec[i]);
9090
const T_partials_return alpha_dbl = value_of(alpha_vec[i]);
9191
const T_partials_return beta_dbl = value_of(beta_vec[i]);
92-
9392
const T_partials_return mu = alpha_dbl + n_dbl + 1;
9493
const T_partials_return nu = beta_dbl + N_dbl - n_dbl - 1;
94+
const T_partials_return one = 1;
9595

9696
const T_partials_return F
97-
= F32((T_partials_return)1, mu, -N_dbl + n_dbl + 1, n_dbl + 2, 1 - nu,
98-
(T_partials_return)1);
99-
100-
T_partials_return C = lgamma(nu) - lgamma(N_dbl - n_dbl);
101-
C += lgamma(mu) - lgamma(n_dbl + 2);
102-
C += lgamma(N_dbl + 2) - lgamma(N_dbl + alpha_dbl + beta_dbl);
103-
C = exp(C);
104-
105-
C *= F / stan::math::beta(alpha_dbl, beta_dbl);
106-
C /= N_dbl + 1;
97+
= F32(one, mu, -N_dbl + n_dbl + 1, n_dbl + 2, 1 - nu, one);
98+
T_partials_return C = lbeta(nu, mu) - lbeta(alpha_dbl, beta_dbl)
99+
- lbeta(N_dbl - n_dbl, n_dbl + 2);
100+
C = F * exp(C) / (N_dbl + 1);
107101

108102
const T_partials_return Pi = C;
109103

110104
P += log(Pi);
111105

112106
T_partials_return dF[6];
113-
T_partials_return digammaOne = 0;
114-
T_partials_return digammaTwo = 0;
107+
T_partials_return digammaDiff = 0;
115108

116109
if (!is_constant_all<T_size1, T_size2>::value) {
117-
digammaOne = digamma(mu + nu);
118-
digammaTwo = digamma(alpha_dbl + beta_dbl);
119-
grad_F32(dF, (T_partials_return)1, mu, -N_dbl + n_dbl + 1, n_dbl + 2,
120-
1 - nu, (T_partials_return)1);
110+
digammaDiff = digamma(mu + nu) - digamma(alpha_dbl + beta_dbl);
111+
grad_F32(dF, one, mu, -N_dbl + n_dbl + 1, n_dbl + 2, 1 - nu, one);
121112
}
122113
if (!is_constant_all<T_size1>::value) {
123-
const T_partials_return g = -C
124-
* (digamma(mu) - digammaOne + dF[1] / F
125-
- digamma(alpha_dbl) + digammaTwo);
126-
ops_partials.edge1_.partials_[i] -= g / Pi;
114+
ops_partials.edge1_.partials_[i]
115+
+= digamma(mu) - digamma(alpha_dbl) - digammaDiff + dF[1] / F;
127116
}
128117
if (!is_constant_all<T_size2>::value) {
129-
const T_partials_return g = -C
130-
* (digamma(nu) - digammaOne - dF[4] / F
131-
- digamma(beta_dbl) + digammaTwo);
132-
ops_partials.edge2_.partials_[i] -= g / Pi;
118+
ops_partials.edge2_.partials_[i]
119+
+= digamma(nu) - digamma(beta_dbl) - digammaDiff - dF[4] / F;
133120
}
134121
}
135122

stan/math/prim/prob/beta_binomial_lcdf.hpp

+16-26
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include <stan/math/prim/fun/exp.hpp>
1010
#include <stan/math/prim/fun/F32.hpp>
1111
#include <stan/math/prim/fun/grad_F32.hpp>
12-
#include <stan/math/prim/fun/lgamma.hpp>
12+
#include <stan/math/prim/fun/lbeta.hpp>
1313
#include <stan/math/prim/fun/log.hpp>
1414
#include <stan/math/prim/fun/max_size.hpp>
1515
#include <stan/math/prim/fun/size.hpp>
@@ -89,46 +89,36 @@ return_type_t<T_size1, T_size2> beta_binomial_lcdf(const T_n& n, const T_N& N,
8989
const T_partials_return N_dbl = value_of(N_vec[i]);
9090
const T_partials_return alpha_dbl = value_of(alpha_vec[i]);
9191
const T_partials_return beta_dbl = value_of(beta_vec[i]);
92-
92+
const T_partials_return N_minus_n = N_dbl - n_dbl;
9393
const T_partials_return mu = alpha_dbl + n_dbl + 1;
94-
const T_partials_return nu = beta_dbl + N_dbl - n_dbl - 1;
95-
96-
T_partials_return F;
97-
F = F32((T_partials_return)1, mu, -N_dbl + n_dbl + 1, n_dbl + 2, 1 - nu,
98-
(T_partials_return)1);
99-
100-
T_partials_return C = lgamma(nu) - lgamma(N_dbl - n_dbl);
101-
C += lgamma(mu) - lgamma(n_dbl + 2);
102-
C += lgamma(N_dbl + 2) - lgamma(N_dbl + alpha_dbl + beta_dbl);
103-
C = exp(C);
94+
const T_partials_return nu = beta_dbl + N_minus_n - 1;
95+
const T_partials_return one = 1;
10496

105-
C *= F / stan::math::beta(alpha_dbl, beta_dbl);
106-
C /= N_dbl + 1;
97+
const T_partials_return F
98+
= F32(one, mu, 1 - N_minus_n, n_dbl + 2, 1 - nu, one);
99+
T_partials_return C = lbeta(nu, mu) - lbeta(alpha_dbl, beta_dbl)
100+
- lbeta(N_minus_n, n_dbl + 2);
101+
C = F * exp(C) / (N_dbl + 1);
107102

108103
const T_partials_return Pi = 1 - C;
109104

110105
P += log(Pi);
111106

112107
T_partials_return dF[6];
113-
T_partials_return digammaOne = 0;
114-
T_partials_return digammaTwo = 0;
108+
T_partials_return digammaDiff = 0;
115109

116110
if (!is_constant_all<T_size1, T_size2>::value) {
117-
digammaOne = digamma(mu + nu);
118-
digammaTwo = digamma(alpha_dbl + beta_dbl);
119-
grad_F32(dF, (T_partials_return)1, mu, -N_dbl + n_dbl + 1, n_dbl + 2,
120-
1 - nu, (T_partials_return)1);
111+
digammaDiff = digamma(mu + nu) - digamma(alpha_dbl + beta_dbl);
112+
grad_F32(dF, one, mu, 1 - N_minus_n, n_dbl + 2, 1 - nu, one);
121113
}
122114
if (!is_constant_all<T_size1>::value) {
123-
const T_partials_return g = -C
124-
* (digamma(mu) - digammaOne + dF[1] / F
125-
- digamma(alpha_dbl) + digammaTwo);
115+
const T_partials_return g
116+
= -C * (digamma(mu) - digamma(alpha_dbl) - digammaDiff + dF[1] / F);
126117
ops_partials.edge1_.partials_[i] += g / Pi;
127118
}
128119
if (!is_constant_all<T_size2>::value) {
129-
const T_partials_return g = -C
130-
* (digamma(nu) - digammaOne - dF[4] / F
131-
- digamma(beta_dbl) + digammaTwo);
120+
const T_partials_return g
121+
= -C * (digamma(nu) - digamma(beta_dbl) - digammaDiff - dF[4] / F);
132122
ops_partials.edge2_.partials_[i] += g / Pi;
133123
}
134124
}

0 commit comments

Comments
 (0)