Skip to content

Commit 5ee773f

Browse files
committed
Implement $accumulator
JAVA-3640
1 parent 67b48f5 commit 5ee773f

File tree

4 files changed

+387
-11
lines changed

4 files changed

+387
-11
lines changed

driver-core/src/main/com/mongodb/client/model/Accumulators.java

Lines changed: 135 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@
1616

1717
package com.mongodb.client.model;
1818

19+
import com.mongodb.lang.Nullable;
20+
import org.bson.BsonArray;
21+
import org.bson.BsonDocument;
22+
import org.bson.BsonString;
23+
24+
import java.util.List;
25+
26+
import static java.util.stream.Collectors.toList;
27+
1928
/**
2029
* Builders for accumulators used in the group pipeline stage of an aggregation pipeline.
2130
*
@@ -38,7 +47,7 @@ public final class Accumulators {
3847
* @mongodb.driver.manual reference/operator/aggregation/sum/ $sum
3948
*/
4049
public static <TExpression> BsonField sum(final String fieldName, final TExpression expression) {
41-
return accumulator("$sum", fieldName, expression);
50+
return accumulatorOperator("$sum", fieldName, expression);
4251
}
4352

4453
/**
@@ -52,7 +61,7 @@ public static <TExpression> BsonField sum(final String fieldName, final TExpress
5261
* @mongodb.driver.manual reference/operator/aggregation/avg/ $avg
5362
*/
5463
public static <TExpression> BsonField avg(final String fieldName, final TExpression expression) {
55-
return accumulator("$avg", fieldName, expression);
64+
return accumulatorOperator("$avg", fieldName, expression);
5665
}
5766

5867
/**
@@ -66,7 +75,7 @@ public static <TExpression> BsonField avg(final String fieldName, final TExpress
6675
* @mongodb.driver.manual reference/operator/aggregation/first/ $first
6776
*/
6877
public static <TExpression> BsonField first(final String fieldName, final TExpression expression) {
69-
return accumulator("$first", fieldName, expression);
78+
return accumulatorOperator("$first", fieldName, expression);
7079
}
7180

7281
/**
@@ -80,7 +89,7 @@ public static <TExpression> BsonField first(final String fieldName, final TExpre
8089
* @mongodb.driver.manual reference/operator/aggregation/last/ $last
8190
*/
8291
public static <TExpression> BsonField last(final String fieldName, final TExpression expression) {
83-
return accumulator("$last", fieldName, expression);
92+
return accumulatorOperator("$last", fieldName, expression);
8493
}
8594

8695
/**
@@ -94,7 +103,7 @@ public static <TExpression> BsonField last(final String fieldName, final TExpres
94103
* @mongodb.driver.manual reference/operator/aggregation/max/ $max
95104
*/
96105
public static <TExpression> BsonField max(final String fieldName, final TExpression expression) {
97-
return accumulator("$max", fieldName, expression);
106+
return accumulatorOperator("$max", fieldName, expression);
98107
}
99108

100109
/**
@@ -108,7 +117,7 @@ public static <TExpression> BsonField max(final String fieldName, final TExpress
108117
* @mongodb.driver.manual reference/operator/aggregation/min/ $min
109118
*/
110119
public static <TExpression> BsonField min(final String fieldName, final TExpression expression) {
111-
return accumulator("$min", fieldName, expression);
120+
return accumulatorOperator("$min", fieldName, expression);
112121
}
113122

114123
/**
@@ -122,7 +131,7 @@ public static <TExpression> BsonField min(final String fieldName, final TExpress
122131
* @mongodb.driver.manual reference/operator/aggregation/push/ $push
123132
*/
124133
public static <TExpression> BsonField push(final String fieldName, final TExpression expression) {
125-
return accumulator("$push", fieldName, expression);
134+
return accumulatorOperator("$push", fieldName, expression);
126135
}
127136

128137
/**
@@ -136,7 +145,7 @@ public static <TExpression> BsonField push(final String fieldName, final TExpres
136145
* @mongodb.driver.manual reference/operator/aggregation/addToSet/ $addToSet
137146
*/
138147
public static <TExpression> BsonField addToSet(final String fieldName, final TExpression expression) {
139-
return accumulator("$addToSet", fieldName, expression);
148+
return accumulatorOperator("$addToSet", fieldName, expression);
140149
}
141150

142151
/**
@@ -155,7 +164,7 @@ public static <TExpression> BsonField addToSet(final String fieldName, final TEx
155164
* @since 3.2
156165
*/
157166
public static <TExpression> BsonField stdDevPop(final String fieldName, final TExpression expression) {
158-
return accumulator("$stdDevPop", fieldName, expression);
167+
return accumulatorOperator("$stdDevPop", fieldName, expression);
159168
}
160169

161170
/**
@@ -173,10 +182,125 @@ public static <TExpression> BsonField stdDevPop(final String fieldName, final TE
173182
* @since 3.2
174183
*/
175184
public static <TExpression> BsonField stdDevSamp(final String fieldName, final TExpression expression) {
176-
return accumulator("$stdDevSamp", fieldName, expression);
185+
return accumulatorOperator("$stdDevSamp", fieldName, expression);
186+
}
187+
188+
/**
189+
* Creates an $accumulator pipeline stage
190+
*
191+
* @param fieldName the field name
192+
* @param initFunction a function used to initialize the state
193+
* @param accumulateFunction a function used to accumulate documents
194+
* @param mergeFunction a function used to merge two internal states, e.g. accumulated on different shards or threads. It
195+
* returns the resulting state of the accumulator.
196+
* @return the $accumulator pipeline stage
197+
* @mongodb.driver.manual reference/operator/aggregation/accumulator/ $accumulator
198+
* @mongodb.server.release 4.4
199+
* @since 4.1
200+
*/
201+
public static BsonField accumulator(final String fieldName, final String initFunction, final String accumulateFunction,
202+
final String mergeFunction) {
203+
return accumulator(fieldName, initFunction, null, accumulateFunction, null, mergeFunction, null, "js");
204+
}
205+
206+
/**
207+
* Creates an $accumulator pipeline stage
208+
*
209+
* @param fieldName the field name
210+
* @param initFunction a function used to initialize the state
211+
* @param accumulateFunction a function used to accumulate documents
212+
* @param mergeFunction a function used to merge two internal states, e.g. accumulated on different shards or threads. It
213+
* returns the resulting state of the accumulator.
214+
* @param finalizeFunction a function used to finalize the state and return the result (may be null)
215+
* @return the $accumulator pipeline stage
216+
* @mongodb.driver.manual reference/operator/aggregation/accumulator/ $accumulator
217+
* @mongodb.server.release 4.4
218+
* @since 4.1
219+
*/
220+
public static BsonField accumulator(final String fieldName, final String initFunction, final String accumulateFunction,
221+
final String mergeFunction, @Nullable final String finalizeFunction) {
222+
return accumulator(fieldName, initFunction, null, accumulateFunction, null, mergeFunction, finalizeFunction, "js");
223+
}
224+
225+
/**
226+
* Creates an $accumulator pipeline stage
227+
*
228+
* @param fieldName the field name
229+
* @param initFunction a function used to initialize the state
230+
* @param initArgs init function’s arguments (may be null)
231+
* @param accumulateFunction a function used to accumulate documents
232+
* @param accumulateArgs additional accumulate function’s arguments (may be null). The first argument to the function
233+
* is ‘state’.
234+
* @param mergeFunction a function used to merge two internal states, e.g. accumulated on different shards or threads. It
235+
* returns the resulting state of the accumulator.
236+
* @param finalizeFunction a function used to finalize the state and return the result (may be null)
237+
* @return the $accumulator pipeline stage
238+
* @mongodb.driver.manual reference/operator/aggregation/accumulator/ $accumulator
239+
* @mongodb.server.release 4.4
240+
* @since 4.1
241+
*/
242+
public static BsonField accumulator(final String fieldName, final String initFunction, @Nullable final List<String> initArgs,
243+
final String accumulateFunction, @Nullable final List<String> accumulateArgs,
244+
final String mergeFunction, @Nullable final String finalizeFunction) {
245+
return accumulator(fieldName, initFunction, initArgs, accumulateFunction, accumulateArgs, mergeFunction, finalizeFunction, "js");
246+
}
247+
248+
/**
249+
* Creates an $accumulator pipeline stage
250+
*
251+
* @param fieldName the field name
252+
* @param initFunction a function used to initialize the state
253+
* @param accumulateFunction a function used to accumulate documents
254+
* @param mergeFunction a function used to merge two internal states, e.g. accumulated on different shards or threads. It
255+
* returns the resulting state of the accumulator.
256+
* @param finalizeFunction a function used to finalize the state and return the result (may be null)
257+
* @param lang a language specifier
258+
* @return the $accumulator pipeline stage
259+
* @mongodb.driver.manual reference/operator/aggregation/accumulator/ $accumulator
260+
* @mongodb.server.release 4.4
261+
* @since 4.1
262+
*/
263+
public static BsonField accumulator(final String fieldName, final String initFunction, final String accumulateFunction,
264+
final String mergeFunction, @Nullable final String finalizeFunction, final String lang) {
265+
return accumulator(fieldName, initFunction, null, accumulateFunction, null, mergeFunction, finalizeFunction, lang);
266+
}
267+
268+
/**
269+
* Creates an $accumulator pipeline stage
270+
*
271+
* @param fieldName the field name
272+
* @param initFunction a function used to initialize the state
273+
* @param initArgs init function’s arguments (may be null)
274+
* @param accumulateFunction a function used to accumulate documents
275+
* @param accumulateArgs additional accumulate function’s arguments (may be null). The first argument to the function
276+
* is ‘state’.
277+
* @param mergeFunction a function used to merge two internal states, e.g. accumulated on different shards or threads. It
278+
* returns the resulting state of the accumulator.
279+
* @param finalizeFunction a function used to finalize the state and return the result (may be null)
280+
* @param lang a language specifier
281+
* @return the $accumulator pipeline stage
282+
* @mongodb.driver.manual reference/operator/aggregation/accumulator/ $accumulator
283+
* @mongodb.server.release 4.4
284+
* @since 4.1
285+
*/
286+
public static BsonField accumulator(final String fieldName, final String initFunction, @Nullable final List<String> initArgs,
287+
final String accumulateFunction, @Nullable final List<String> accumulateArgs,
288+
final String mergeFunction, @Nullable final String finalizeFunction, final String lang) {
289+
BsonDocument accumulatorStage = new BsonDocument("init", new BsonString(initFunction))
290+
.append("initArgs", initArgs != null ? new BsonArray(initArgs.stream().map(initArg ->
291+
new BsonString(initArg)).collect(toList())) : new BsonArray())
292+
.append("accumulate", new BsonString(accumulateFunction))
293+
.append("accumulateArgs", accumulateArgs != null ? new BsonArray(accumulateArgs.stream().map(accumulateArg ->
294+
new BsonString(accumulateArg)).collect(toList())) : new BsonArray())
295+
.append("merge", new BsonString(mergeFunction))
296+
.append("lang", new BsonString(lang));
297+
if (finalizeFunction != null) {
298+
accumulatorStage.append("finalize", new BsonString(finalizeFunction));
299+
}
300+
return accumulatorOperator("$accumulator", fieldName, accumulatorStage);
177301
}
178302

179-
private static <TExpression> BsonField accumulator(final String name, final String fieldName, final TExpression expression) {
303+
private static <TExpression> BsonField accumulatorOperator(final String name, final String fieldName, final TExpression expression) {
180304
return new BsonField(fieldName, new SimpleExpression<TExpression>(name, expression));
181305
}
182306

driver-core/src/test/functional/com/mongodb/client/model/AggregatesFunctionalSpecification.groovy

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.bson.conversions.Bson
2626
import spock.lang.IgnoreIf
2727

2828
import static com.mongodb.ClusterFixture.serverVersionAtLeast
29+
import static com.mongodb.client.model.Accumulators.accumulator
2930
import static com.mongodb.client.model.Accumulators.addToSet
3031
import static com.mongodb.client.model.Accumulators.avg
3132
import static com.mongodb.client.model.Accumulators.first
@@ -787,6 +788,50 @@ class AggregatesFunctionalSpecification extends OperationFunctionalSpecification
787788
helper?.drop()
788789
}
789790

791+
@IgnoreIf({ !serverVersionAtLeast(4, 3) })
792+
def '$accumulator'() {
793+
given:
794+
def helper = getCollectionHelper()
795+
796+
when:
797+
helper.drop()
798+
helper.insertDocuments(Document.parse('{_id: 1, x: "string"}'))
799+
def init = 'function() { return { x: "test string" } }'
800+
def accumulate = 'function(state) { return state }'
801+
def merge = 'function(state1, state2) { return state1 }'
802+
def accumulatorExpr = accumulator('testString', init, accumulate, merge);
803+
def results1 = helper.aggregate([group('$x', asList(accumulatorExpr))])
804+
805+
then:
806+
results1.size() == 1
807+
results1.contains(Document.parse('{ _id: "string", testString: { x: "test string" } }'))
808+
809+
when:
810+
helper.drop()
811+
helper.insertDocuments(Document.parse('{_id: 8751, title: "The Banquet", author: "Dante", copies: 2}'),
812+
Document.parse('{_id: 8752, title: "Divine Comedy", author: "Dante", copies: 1}'),
813+
Document.parse('{_id: 8645, title: "Eclogues", author: "Dante", copies: 2}'),
814+
Document.parse('{_id: 7000, title: "The Odyssey", author: "Homer", copies: 10}'),
815+
Document.parse('{_id: 7020, title: "Iliad", author: "Homer", copies: 10}'))
816+
def initFunction = 'function(initCount, initSum) { return { count: parseInt(initCount), sum: parseInt(initSum) } }';
817+
def accumulateFunction = 'function(state, numCopies) { return { count : state.count + 1, sum : state.sum + numCopies } }';
818+
def mergeFunction = 'function(state1, state2) { return { count : state1.count + state2.count, sum : state1.sum + state2.sum } }';
819+
def finalizeFunction = 'function(state) { return (state.sum / state.count) }';
820+
def accumulatorExpression = accumulator('avgCopies', initFunction, [ '0', '0' ], accumulateFunction,
821+
[ '$copies' ], mergeFunction, finalizeFunction)
822+
def results2 = helper.aggregate([group('$author', asList(
823+
new BsonField('minCopies', new Document('$min', '$copies')), accumulatorExpression,
824+
new BsonField('maxCopies', new Document('$max', '$copies'))))])
825+
826+
then:
827+
results2.size() == 2
828+
results2.contains(Document.parse('{_id: "Dante", minCopies: 1, avgCopies: 1.6666666666666667, maxCopies : 2}'))
829+
results2.contains(Document.parse('{_id: "Homer", minCopies: 10, avgCopies: 10.0, maxCopies : 10}'))
830+
831+
cleanup:
832+
helper?.drop()
833+
}
834+
790835
@IgnoreIf({ !serverVersionAtLeast(3, 4) })
791836
def '$addFields'() {
792837
given:

0 commit comments

Comments
 (0)