Skip to content

Commit 6b1a5af

Browse files
authored
Merge pull request #19 from stan-dev/bugfix/reduce_sum_arg_reorder
Reorder the reduce_sum arguments (this has been approved and implemented which is why I'm merging it)
2 parents 12d7f0c + 66a5c74 commit 6b1a5af

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

designs/0017-reduce_sum.md

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,13 @@ real reduce_sum(F func, T[] x, int grainsize, T1 s1, T2 s2, ...)
7474
The user-defined partial sum functions have the signature:
7575

7676
```
77-
real func(int start, int end, T[] x_subset, T1 arg1, T2 arg2, ...)
77+
real func(T[] x_subset, int start, int end, T1 arg1, T2 arg2, ...)
7878
```
7979

8080
and take the arguments:
81-
1. ```start``` - An integer specifying the first term in the partial sum
82-
2. ```end``` - An integer specifying the last term in the partial sum (inclusive)
83-
3. ```x_subset``` - The subset of ```x``` (from ```reduce_sum```) for which this partial sum is responsible (```x[start:end]```)
81+
1. ```x_subset``` - The subset of ```x``` (from ```reduce_sum```) for which this partial sum is responsible (```x[start:end]```)
82+
2. ```start``` - An integer specifying the first term in the partial sum
83+
3. ```end``` - An integer specifying the last term in the partial sum (inclusive)
8484
4-. ```arg1, arg2, ...``` Arguments shared in every term (passed on without modification from the reduce_sum call)
8585

8686
The user-provided function ```func``` is expect to compute the ```start``` through ```end``` terms of the overall sum, accumulate them, and return that value. The user function is passed the subset ```x[start:end]``` as ```x_subset```. ```start``` and ```end``` are passed so that ```func``` can index any of the tailing ```sM``` arguments as necessary. The trailing ```sM``` arguments are passed without modification to every call of ```func```.
@@ -94,15 +94,15 @@ real sum = reduce_sum(func, x, grainsize, s1, s2, ...)
9494
can be replaced by either:
9595

9696
```
97-
real sum = func(1, size(x), x, s1, s2, ...)
97+
real sum = func(x, 1, size(x), s1, s2, ...)
9898
```
9999

100100
or the code:
101101

102102
```
103103
real sum = 0.0;
104104
for(i in 1:size(x)) {
105-
sum = sum + func(i, i, { x[i] }, s1, s2, ...);
105+
sum = sum + func({ x[i] }, i, i, s1, s2, ...);
106106
}
107107
```
108108

@@ -144,10 +144,10 @@ updating the model block to use ```reduce_sum``` gives:
144144

145145
```
146146
functions {
147-
real partial_sum(int start, int end,
148-
int[] y_subset,
149-
vector x,
150-
vector beta) {
147+
real partial_sum(int[] y_subset,
148+
int start, int end,
149+
vector x,
150+
vector beta) {
151151
return bernoulli_logit_lpmf(y_subset | beta[1] + beta[2] * x[start:end]);
152152
}
153153
}
@@ -165,7 +165,7 @@ parameters {
165165
model {
166166
int grainsize = 100;
167167
beta ~ std_normal();
168-
target += reduce_sum(reducer_func, y,
168+
target += reduce_sum(partial_sum, y,
169169
grainsize,
170170
x, beta);
171171
}
@@ -233,7 +233,7 @@ real sum = reduce_sum(func, x, s1, s2, ...)
233233
can be replaced by:
234234

235235
```
236-
real sum = func(1, size(x), x, s1, s2, ...)
236+
real sum = func(x, 1, size(x), s1, s2, ...)
237237
```
238238

239239
where ```func``` was always provided by the user.

0 commit comments

Comments
 (0)