19
19
using MongoDB . Driver . Linq . Linq3Implementation . Ast . Expressions ;
20
20
using MongoDB . Driver . Linq . Linq3Implementation . Misc ;
21
21
using MongoDB . Driver . Linq . Linq3Implementation . Reflection ;
22
+ using MongoDB . Driver . Linq . Linq3Implementation . Serializers ;
22
23
23
24
namespace MongoDB . Driver . Linq . Linq3Implementation . Translators . ExpressionToAggregationExpressionTranslators . MethodTranslators
24
25
{
@@ -34,39 +35,46 @@ internal static class AggregateMethodToAggregationExpressionTranslator
34
35
QueryableMethod . AggregateWithSeedFuncAndResultSelector
35
36
} ;
36
37
37
- private static readonly MethodInfo [ ] __aggregateWithoutSeedMethods =
38
+ private static readonly MethodInfo [ ] __aggregateWithFuncMethods =
38
39
{
39
40
EnumerableMethod . AggregateWithFunc ,
40
41
QueryableMethod . AggregateWithFunc
41
42
} ;
42
43
43
- private static readonly MethodInfo [ ] __aggregateWithSeedMethods =
44
+ private static readonly MethodInfo [ ] __aggregateWithSeedAndFuncMethods =
44
45
{
45
46
EnumerableMethod . AggregateWithSeedAndFunc ,
47
+ QueryableMethod . AggregateWithSeedAndFunc
48
+ } ;
49
+
50
+ private static readonly MethodInfo [ ] __aggregateWithSeedAndFuncAndResultSelectorMethods =
51
+ {
46
52
EnumerableMethod . AggregateWithSeedFuncAndResultSelector ,
47
- QueryableMethod . AggregateWithSeedAndFunc ,
48
53
QueryableMethod . AggregateWithSeedFuncAndResultSelector
49
54
} ;
50
55
51
- private static readonly MethodInfo [ ] __aggregateWithSeedFuncAndResultSelectorMethods =
52
- {
56
+ private static readonly MethodInfo [ ] __aggregateIncludingSeedMethods =
57
+ {
58
+ EnumerableMethod . AggregateWithSeedAndFunc ,
53
59
EnumerableMethod . AggregateWithSeedFuncAndResultSelector ,
60
+ QueryableMethod . AggregateWithSeedAndFunc ,
54
61
QueryableMethod . AggregateWithSeedFuncAndResultSelector
55
62
} ;
56
63
57
- public static AggregationExpression Translate ( TranslationContext context , MethodCallExpression expression )
64
+ public static AggregationExpression Translate ( TranslationContext context , MethodCallExpression expression , IBsonSerializer targetSerializer )
58
65
{
59
66
var method = expression . Method ;
60
67
var arguments = expression . Arguments ;
61
68
62
69
if ( method . IsOneOf ( __aggregateMethods ) )
63
70
{
64
71
var sourceExpression = arguments [ 0 ] ;
65
- var sourceTranslation = ExpressionToAggregationExpressionTranslator . TranslateEnumerable ( context , sourceExpression ) ;
72
+ var sourceTargetSerializer = GetSourceTargetSerializer ( method , targetSerializer ) ;
73
+ var sourceTranslation = ExpressionToAggregationExpressionTranslator . TranslateEnumerable ( context , sourceExpression , sourceTargetSerializer ) ;
66
74
NestedAsQueryableHelper . EnsureQueryableMethodHasNestedAsQueryableSource ( expression , sourceTranslation ) ;
67
75
var itemSerializer = ArraySerializerHelper . GetItemSerializer ( sourceTranslation . Serializer ) ;
68
76
69
- if ( method . IsOneOf ( __aggregateWithoutSeedMethods ) )
77
+ if ( method . IsOneOf ( __aggregateWithFuncMethods ) )
70
78
{
71
79
var funcLambda = ExpressionHelper . UnquoteLambdaIfQueryableMethod ( method , arguments [ 1 ] ) ;
72
80
var funcParameters = funcLambda . Parameters ;
@@ -75,7 +83,8 @@ public static AggregationExpression Translate(TranslationContext context, Method
75
83
var itemParameter = funcParameters [ 1 ] ;
76
84
var itemSymbol = context . CreateSymbolWithVarName ( itemParameter , varName : "this" , itemSerializer ) ; // note: MQL uses $$this for the item being processed
77
85
var funcContext = context . WithSymbols ( accumulatorSymbol , itemSymbol ) ;
78
- var funcTranslation = ExpressionToAggregationExpressionTranslator . Translate ( funcContext , funcLambda . Body ) ;
86
+ var funcTargetSerializer = GetFuncTargetSerializer ( method , targetSerializer ) ;
87
+ var funcTranslation = ExpressionToAggregationExpressionTranslator . Translate ( funcContext , funcLambda . Body , funcTargetSerializer ) ;
79
88
80
89
var ( sourceVarBinding , sourceAst ) = AstExpression . UseVarIfNotSimple ( "source" , sourceTranslation . Ast ) ;
81
90
var seedVar = AstExpression . Var ( "seed" ) ;
@@ -95,10 +104,11 @@ public static AggregationExpression Translate(TranslationContext context, Method
95
104
96
105
return new AggregationExpression ( expression , ast , itemSerializer ) ;
97
106
}
98
- else if ( method . IsOneOf ( __aggregateWithSeedMethods ) )
107
+ else if ( method . IsOneOf ( __aggregateIncludingSeedMethods ) )
99
108
{
100
109
var seedExpression = arguments [ 1 ] ;
101
- var seedTranslation = ExpressionToAggregationExpressionTranslator . Translate ( context , seedExpression ) ;
110
+ var seedTargetSerializer = GetSeedTargetSerializer ( method , targetSerializer ) ;
111
+ var seedTranslation = ExpressionToAggregationExpressionTranslator . Translate ( context , seedExpression , seedTargetSerializer ) ;
102
112
103
113
var funcLambda = ExpressionHelper . UnquoteLambdaIfQueryableMethod ( method , arguments [ 2 ] ) ;
104
114
var funcParameters = funcLambda . Parameters ;
@@ -108,21 +118,23 @@ public static AggregationExpression Translate(TranslationContext context, Method
108
118
var itemParameter = funcParameters [ 1 ] ;
109
119
var itemSymbol = context . CreateSymbolWithVarName ( itemParameter , varName : "this" , itemSerializer ) ; // note: MQL uses $$this for the item being processed
110
120
var funcContext = context . WithSymbols ( accumulatorSymbol , itemSymbol ) ;
111
- var funcTranslation = ExpressionToAggregationExpressionTranslator . Translate ( funcContext , funcLambda . Body ) ;
121
+ var funcTargetSerializer = GetFuncTargetSerializer ( method , targetSerializer ) ;
122
+ var funcTranslation = ExpressionToAggregationExpressionTranslator . Translate ( funcContext , funcLambda . Body , funcTargetSerializer ) ;
112
123
113
124
var ast = AstExpression . Reduce (
114
125
input : sourceTranslation . Ast ,
115
126
initialValue : seedTranslation . Ast ,
116
127
@in : funcTranslation . Ast ) ;
117
128
var serializer = accumulatorSerializer ;
118
129
119
- if ( method . IsOneOf ( __aggregateWithSeedFuncAndResultSelectorMethods ) )
130
+ if ( method . IsOneOf ( __aggregateWithSeedAndFuncAndResultSelectorMethods ) )
120
131
{
121
132
var resultSelectorLambda = ExpressionHelper . UnquoteLambdaIfQueryableMethod ( method , arguments [ 3 ] ) ;
122
133
var resultSelectorAccumulatorParameter = resultSelectorLambda . Parameters [ 0 ] ;
123
134
var resultSelectorAccumulatorSymbol = context . CreateSymbol ( resultSelectorAccumulatorParameter , accumulatorSerializer ) ;
124
135
var resultSelectorContext = context . WithSymbol ( resultSelectorAccumulatorSymbol ) ;
125
- var resultSelectorTranslation = ExpressionToAggregationExpressionTranslator . Translate ( resultSelectorContext , resultSelectorLambda . Body ) ;
136
+ var resultSelectorTargetSerializer = GetResultSelectorTargetSerializer ( method , targetSerializer ) ;
137
+ var resultSelectorTranslation = ExpressionToAggregationExpressionTranslator . Translate ( resultSelectorContext , resultSelectorLambda . Body , resultSelectorTargetSerializer ) ;
126
138
127
139
ast = AstExpression . Let (
128
140
var : AstExpression . VarBinding ( resultSelectorAccumulatorSymbol . Var , ast ) ,
@@ -136,5 +148,57 @@ public static AggregationExpression Translate(TranslationContext context, Method
136
148
137
149
throw new ExpressionNotSupportedException ( expression ) ;
138
150
}
151
+
152
+ private static IBsonSerializer GetFuncTargetSerializer ( MethodInfo method , IBsonSerializer targetSerializer )
153
+ {
154
+ if ( method . IsOneOf ( __aggregateWithFuncMethods , __aggregateWithSeedAndFuncMethods ) )
155
+ {
156
+ return targetSerializer ;
157
+ }
158
+
159
+ return null ;
160
+ }
161
+
162
+ private static IBsonSerializer GetResultSelectorTargetSerializer ( MethodInfo method , IBsonSerializer targetSerializer )
163
+ {
164
+ if ( method . IsOneOf ( __aggregateWithSeedAndFuncAndResultSelectorMethods ) )
165
+ {
166
+ return targetSerializer ;
167
+ }
168
+
169
+ return null ;
170
+ }
171
+
172
+ private static IBsonSerializer GetSeedTargetSerializer ( MethodInfo method , IBsonSerializer targetSerializer )
173
+ {
174
+ if ( method . IsOneOf ( __aggregateWithSeedAndFuncMethods ) )
175
+ {
176
+ return targetSerializer ;
177
+ }
178
+
179
+ return null ;
180
+ }
181
+
182
+ private static IBsonSerializer GetSourceTargetSerializer ( MethodInfo method , IBsonSerializer targetSerializer )
183
+ {
184
+ IBsonSerializer itemSerializer = null ;
185
+ if ( method . IsOneOf ( __aggregateWithFuncMethods ) )
186
+ {
187
+ itemSerializer = targetSerializer ;
188
+ }
189
+
190
+ if ( method . IsOneOf ( __aggregateWithSeedAndFuncMethods ) )
191
+ {
192
+ var genericArguments = method . GetGenericArguments ( ) ;
193
+ var sourceType = genericArguments [ 0 ] ;
194
+ var accumulateType = genericArguments [ 1 ] ;
195
+ if ( sourceType == accumulateType )
196
+ {
197
+ itemSerializer = targetSerializer ;
198
+ }
199
+ }
200
+
201
+ return itemSerializer == null ? null : IEnumerableSerializer . Create ( itemSerializer ) ;
202
+ }
139
203
}
140
204
}
0 commit comments