11
11
#include < stan/math/prim/fun/max_size.hpp>
12
12
#include < stan/math/prim/fun/size.hpp>
13
13
#include < stan/math/prim/fun/size_zero.hpp>
14
+ #include < stan/math/prim/fun/to_ref.hpp>
14
15
#include < stan/math/prim/fun/value_of.hpp>
15
16
#include < stan/math/prim/functor/operands_and_partials.hpp>
16
17
@@ -34,75 +35,65 @@ template <bool propto, typename T_n, typename T_N, typename T_prob>
34
35
return_type_t <T_prob> binomial_logit_lpmf (const T_n& n, const T_N& N,
35
36
const T_prob& alpha) {
36
37
using T_partials_return = partials_return_t <T_n, T_N, T_prob>;
37
- using std::log ;
38
+ using T_n_ref = ref_type_if_t <!is_constant<T_n>::value, T_n>;
39
+ using T_N_ref = ref_type_if_t <!is_constant<T_N>::value, T_N>;
40
+ using T_alpha_ref = ref_type_if_t <!is_constant<T_prob>::value, T_prob>;
38
41
static const char * function = " binomial_logit_lpmf" ;
39
- check_bounded (function, " Successes variable" , n, 0 , N);
40
- check_nonnegative (function, " Population size parameter" , N);
41
- check_finite (function, " Probability parameter" , alpha);
42
42
check_consistent_sizes (function, " Successes variable" , n,
43
43
" Population size parameter" , N,
44
44
" Probability parameter" , alpha);
45
45
46
- if (size_zero (n, N, alpha)) {
47
- return 0.0 ;
48
- }
49
- if (!include_summand<propto, T_prob>::value) {
50
- return 0.0 ;
51
- }
46
+ T_n_ref n_ref = n;
47
+ T_N_ref N_ref = N;
48
+ T_alpha_ref alpha_ref = alpha;
52
49
53
- T_partials_return logp = 0 ;
54
- operands_and_partials<T_prob> ops_partials (alpha);
50
+ const auto & n_col = as_column_vector_or_scalar (n_ref);
51
+ const auto & N_col = as_column_vector_or_scalar (N_ref);
52
+ const auto & alpha_col = as_column_vector_or_scalar (alpha_ref);
55
53
56
- scalar_seq_view<T_n> n_vec (n);
57
- scalar_seq_view<T_N> N_vec (N);
58
- scalar_seq_view<T_prob> alpha_vec (alpha);
59
- size_t size_alpha = stan::math::size (alpha);
60
- size_t max_size_seq_view = max_size (n, N, alpha);
54
+ const auto & n_arr = as_array_or_scalar (n_col);
55
+ const auto & N_arr = as_array_or_scalar (N_col);
56
+ const auto & alpha_arr = as_array_or_scalar (alpha_col);
61
57
62
- if (include_summand<propto>::value) {
63
- for (size_t i = 0 ; i < max_size_seq_view; ++i) {
64
- logp += binomial_coefficient_log (N_vec[i], n_vec[i]);
65
- }
66
- }
58
+ ref_type_t <decltype (value_of (n_arr))> n_val = value_of (n_arr);
59
+ ref_type_t <decltype (value_of (N_arr))> N_val = value_of (N_arr);
60
+ ref_type_t <decltype (value_of (alpha_arr))> alpha_val = value_of (alpha_arr);
67
61
68
- VectorBuilder<true , T_partials_return, T_prob> inv_logit_alpha (size_alpha);
69
- VectorBuilder<true , T_partials_return, T_prob> inv_logit_neg_alpha (
70
- size_alpha);
71
- VectorBuilder<true , T_partials_return, T_prob> log_inv_logit_alpha (
72
- size_alpha);
73
- VectorBuilder<true , T_partials_return, T_prob> log_inv_logit_neg_alpha (
74
- size_alpha);
62
+ check_bounded (function, " Successes variable" , n_val, 0 , N_val);
63
+ check_nonnegative (function, " Population size parameter" , N_val);
64
+ check_finite (function, " Probability parameter" , alpha_val);
75
65
76
- for (size_t i = 0 ; i < size_alpha; ++i) {
77
- const T_partials_return alpha_dbl = value_of (alpha_vec[i]);
78
- inv_logit_alpha[i] = inv_logit (alpha_dbl);
79
- inv_logit_neg_alpha[i] = inv_logit (-alpha_dbl);
80
- log_inv_logit_alpha[i] = log (inv_logit_alpha[i]);
81
- log_inv_logit_neg_alpha[i] = log (inv_logit_neg_alpha[i]);
66
+ if (size_zero (n, N, alpha)) {
67
+ return 0.0 ;
68
+ }
69
+ if (!include_summand<propto, T_prob>::value) {
70
+ return 0.0 ;
82
71
}
72
+ const auto & inv_logit_alpha
73
+ = to_ref_if<!is_constant_all<T_prob>::value>(inv_logit (alpha_val));
74
+ const auto & inv_logit_neg_alpha
75
+ = to_ref_if<!is_constant_all<T_prob>::value>(inv_logit (-alpha_val));
83
76
84
- for (size_t i = 0 ; i < max_size_seq_view; ++i) {
85
- logp += n_vec[i] * log_inv_logit_alpha[i]
86
- + (N_vec[i] - n_vec[i]) * log_inv_logit_neg_alpha[i];
77
+ size_t maximum_size = max_size (n, N, alpha);
78
+ const auto & log_inv_logit_alpha = log (inv_logit_alpha);
79
+ const auto & log_inv_logit_neg_alpha = log (inv_logit_neg_alpha);
80
+ T_partials_return logp = sum (n_val * log_inv_logit_alpha
81
+ + (N_val - n_val) * log_inv_logit_neg_alpha);
82
+ if (include_summand<propto, T_n, T_N>::value) {
83
+ logp += sum (binomial_coefficient_log (N_val, n_val)) * maximum_size
84
+ / max_size (n, N);
87
85
}
88
86
87
+ operands_and_partials<T_alpha_ref> ops_partials (alpha_ref);
89
88
if (!is_constant_all<T_prob>::value) {
90
- if (size_alpha == 1 ) {
91
- T_partials_return sum_n = 0 ;
92
- T_partials_return sum_N = 0 ;
93
- for (size_t i = 0 ; i < max_size_seq_view; ++i) {
94
- sum_n += n_vec[i];
95
- sum_N += N_vec[i];
96
- }
97
- ops_partials.edge1_ .partials_ [0 ]
98
- += sum_n * inv_logit_neg_alpha[0 ]
99
- - (sum_N - sum_n) * inv_logit_alpha[0 ];
89
+ if (is_vector<T_prob>::value) {
90
+ ops_partials.edge1_ .partials_
91
+ = n_val * inv_logit_neg_alpha - (N_val - n_val) * inv_logit_alpha;
100
92
} else {
101
- for (size_t i = 0 ; i < max_size_seq_view; ++i) {
102
- ops_partials.edge1_ .partials_ [i]
103
- += n_vec[i] * inv_logit_neg_alpha[i]
104
- - (N_vec[i] - n_vec[i]) * inv_logit_alpha[i];
105
- }
93
+ T_partials_return sum_n = sum (n_val) * maximum_size / size (n);
94
+ ops_partials.edge1_ .partials_ [0 ] = forward_as<T_partials_return>(
95
+ sum_n * inv_logit_neg_alpha
96
+ - (sum (N_val) * maximum_size / size (N) - sum_n) * inv_logit_alpha);
106
97
}
107
98
}
108
99
0 commit comments