Skip to content

Commit 6db1e57

Browse files
Merge pull request #1497 from martinmodrak/bugfix/1496-poisson-phi-cutoff
Fixing negative binomial phi cutoff
2 parents 57e469a + 4b2f032 commit 6db1e57

File tree

6 files changed

+632
-77
lines changed

6 files changed

+632
-77
lines changed

stan/math/opencl/kernel_generator/load.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class load_
4949
* Creates a deep copy of this expression.
5050
* @return copy of \c *this
5151
*/
52-
inline load_<T&> deep_copy() const & { return load_<T&>(a_); }
52+
inline load_<T&> deep_copy() const& { return load_<T&>(a_); }
5353
inline load_<T> deep_copy() && { return load_<T>(std::forward<T>(a_)); }
5454

5555
/**

stan/math/prim/prob/neg_binomial_2_lpmf.hpp

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/binomial_coefficient_log.hpp>
67
#include <stan/math/prim/fun/digamma.hpp>
7-
#include <stan/math/prim/fun/lgamma.hpp>
88
#include <stan/math/prim/fun/log.hpp>
99
#include <stan/math/prim/fun/max_size.hpp>
1010
#include <stan/math/prim/fun/multiply_log.hpp>
1111
#include <stan/math/prim/fun/size.hpp>
1212
#include <stan/math/prim/fun/size_zero.hpp>
1313
#include <stan/math/prim/fun/value_of.hpp>
14-
#include <stan/math/prim/prob/poisson_lpmf.hpp>
1514
#include <cmath>
1615

1716
namespace stan {
@@ -47,7 +46,7 @@ return_type_t<T_location, T_precision> neg_binomial_2_lpmf(
4746
size_t size_phi = stan::math::size(phi);
4847
size_t size_mu_phi = max_size(mu, phi);
4948
size_t size_n_phi = max_size(n, phi);
50-
size_t max_size_seq_view = max_size(n, mu, phi);
49+
size_t size_all = max_size(n, mu, phi);
5150

5251
VectorBuilder<true, T_partials_return, T_location> mu_val(size_mu);
5352
for (size_t i = 0; i < size_mu; ++i) {
@@ -76,39 +75,30 @@ return_type_t<T_location, T_precision> neg_binomial_2_lpmf(
7675
n_plus_phi[i] = n_vec[i] + phi_val[i];
7776
}
7877

79-
for (size_t i = 0; i < max_size_seq_view; i++) {
80-
// if phi is large we probably overflow, defer to Poisson:
81-
if (phi_val[i] > 1e5) {
82-
// TODO(martinmodrak) This is wrong (doesn't pass propto information),
83-
// and inaccurate for n = 0, but shouldn't break most models.
84-
// Also the 1e5 cutoff is too small.
85-
// Will be addressed better in PR #1497
86-
logp += poisson_lpmf(n_vec[i], mu_val[i]);
87-
} else {
88-
if (include_summand<propto>::value) {
89-
logp -= lgamma(n_vec[i] + 1.0);
90-
}
91-
if (include_summand<propto, T_precision>::value) {
92-
logp += multiply_log(phi_val[i], phi_val[i]) - lgamma(phi_val[i]);
93-
}
94-
if (include_summand<propto, T_location>::value) {
95-
logp += multiply_log(n_vec[i], mu_val[i]);
96-
}
97-
if (include_summand<propto, T_precision>::value) {
98-
logp += lgamma(n_plus_phi[i]);
99-
}
100-
logp -= n_plus_phi[i] * log_mu_plus_phi[i];
78+
for (size_t i = 0; i < size_all; i++) {
79+
if (include_summand<propto, T_precision>::value) {
80+
logp += binomial_coefficient_log(n_plus_phi[i] - 1, n_vec[i]);
81+
}
82+
if (include_summand<propto, T_location>::value) {
83+
logp += multiply_log(n_vec[i], mu_val[i]);
10184
}
85+
logp += -phi_val[i] * (log1p(mu_val[i] / phi_val[i]))
86+
- n_vec[i] * log_mu_plus_phi[i];
10287

10388
if (!is_constant_all<T_location>::value) {
10489
ops_partials.edge1_.partials_[i]
105-
+= n_vec[i] / mu_val[i] - n_plus_phi[i] / mu_plus_phi[i];
90+
+= n_vec[i] / mu_val[i] - (n_vec[i] + phi_val[i]) / (mu_plus_phi[i]);
10691
}
10792
if (!is_constant_all<T_precision>::value) {
108-
ops_partials.edge2_.partials_[i] += 1.0 - n_plus_phi[i] / mu_plus_phi[i]
109-
+ log_phi[i] - log_mu_plus_phi[i]
110-
- digamma(phi_val[i])
111-
+ digamma(n_plus_phi[i]);
93+
T_partials_return log_term;
94+
if (mu_val[i] < phi_val[i]) {
95+
log_term = log1p(-mu_val[i] / (mu_plus_phi[i]));
96+
} else {
97+
log_term = log_phi[i] - log_mu_plus_phi[i];
98+
}
99+
ops_partials.edge2_.partials_[i]
100+
+= (mu_val[i] - n_vec[i]) / (mu_plus_phi[i]) + log_term
101+
- (digamma(phi_val[i]) - digamma(n_plus_phi[i]));
112102
}
113103
}
114104
return ops_partials.build(logp);

test/unit/math/prim/prob/neg_binomial_2_log_test.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,10 @@ TEST(ProbNegBinomial2, log_matches_lpmf) {
212212
TEST(ProbDistributionsNegBinomial2Log, neg_binomial_2_log_grid_test) {
213213
std::vector<double> mu_log_to_test
214214
= {-101, -27, -3, -1, -0.132, 0, 4, 10, 87};
215-
std::vector<double> phi_to_test = {2e-5, 0.36, 1, 2.3e5, 1.8e10, 6e16};
215+
// TODO(martinmodrak) Reducing the span of the test, should be fixed
216+
// along with #1495
217+
// std::vector<double> phi_to_test = {2e-5, 0.36, 1, 10, 2.3e5, 1.8e10, 6e16};
218+
std::vector<double> phi_to_test = {0.36, 1, 10};
216219
std::vector<int> n_to_test = {0, 1, 10, 39, 101, 3048, 150054};
217220

218221
// TODO(martinmdorak) Only weak tolerance for this quick fix

test/unit/math/prim/prob/neg_binomial_2_test.cpp

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <test/unit/math/prim/prob/vector_rng_test_helper.hpp>
33
#include <test/unit/math/prim/prob/NegativeBinomial2LogTestRig.hpp>
44
#include <test/unit/math/prim/prob/VectorIntRNGTestRig.hpp>
5+
#include <test/unit/math/expect_near_rel.hpp>
56
#include <gtest/gtest.h>
67
#include <boost/random/mersenne_twister.hpp>
78
#include <boost/math/distributions.hpp>
@@ -238,27 +239,40 @@ TEST(ProbDistributionsNegBinomial2, chiSquareGoodnessFitTest4) {
238239
}
239240

240241
TEST(ProbDistributionsNegBinomial2, extreme_values) {
241-
int N = 100;
242-
double mu = 8;
243-
double phi = 1e12;
244-
for (int n = 0; n < 10; ++n) {
245-
phi *= 10;
246-
double logp = stan::math::neg_binomial_2_log<false>(N, mu, phi);
247-
EXPECT_LT(logp, 0);
242+
std::vector<int> n_to_test = {0, 1, 5, 100, 12985, 1968422};
243+
std::vector<double> mu_to_test = {1e-5, 0.1, 8, 713, 28311, 19850054};
244+
for (double mu : mu_to_test) {
245+
for (int n : n_to_test) {
246+
// Test across a range of phi
247+
for (double phi = 1e12; phi < 1e22; phi *= 10) {
248+
double logp = stan::math::neg_binomial_2_log<false>(n, mu, phi);
249+
EXPECT_LT(logp, 0) << "n = " << n << ", mu = " << mu
250+
<< ", phi = " << phi;
251+
}
252+
}
248253
}
249254
}
250255

251-
TEST(ProbDistributionsNegBinomial2, vectorAroundCutoff) {
252-
int y = 10;
253-
double mu = 9.36;
254-
std::vector<double> phi;
255-
phi.push_back(1);
256-
phi.push_back(1e15);
257-
double vector_value = stan::math::neg_binomial_2_lpmf(y, mu, phi);
258-
double scalar_value = stan::math::neg_binomial_2_lpmf(y, mu, phi[0])
259-
+ stan::math::neg_binomial_2_lpmf(y, mu, phi[1]);
260-
261-
EXPECT_FLOAT_EQ(vector_value, scalar_value);
256+
TEST(ProbDistributionsNegBinomial2, zeroOne) {
257+
using stan::test::expect_near_rel;
258+
259+
std::vector<double> mu_to_test = {2.345e-5, 0.2, 13, 150, 1621, 18432, 1e10};
260+
double phi_start = 1e-8;
261+
double phi_max = 1e22;
262+
for (double mu : mu_to_test) {
263+
for (double phi = phi_start; phi < phi_max; phi *= stan::math::pi()) {
264+
std::stringstream msg;
265+
msg << ", mu = " << mu << ", phi = " << phi;
266+
267+
double expected_value_0 = phi * (-log1p(mu / phi));
268+
double value_0 = stan::math::neg_binomial_2_lpmf(0, mu, phi);
269+
expect_near_rel("n = 0 " + msg.str(), value_0, expected_value_0);
270+
271+
double expected_value_1 = (phi + 1) * (-log1p(mu / phi)) + log(mu);
272+
double value_1 = stan::math::neg_binomial_2_lpmf(1, mu, phi);
273+
expect_near_rel("n = 1 " + msg.str(), value_1, expected_value_1);
274+
}
275+
}
262276
}
263277

264278
TEST(ProbDistributionsNegativeBinomial2Log, distributionCheck) {

test/unit/math/rev/fun/lbeta_test.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,6 @@ TEST(MathFunctions, lbeta_identities_gradient) {
7474
// Successors: beta(a,b) = beta(a + 1, b) + beta(a, b + 1)
7575
for (double x : to_test) {
7676
for (double y : to_test) {
77-
// TODO(martinmodrak) this restriction on testing should be lifted once
78-
// the log_sum_exp bug (#1679) is resolved
79-
if (x > 1e10 || y > 1e10) {
80-
continue;
81-
}
8277
auto rh = [](const var& a, const var& b) {
8378
return stan::math::log_sum_exp(lbeta(a + 1, b), lbeta(a, b + 1));
8479
};

0 commit comments

Comments
 (0)