Skip to content

Commit a98fe70

Browse files
committed
CSHARP-5481: ScalarDiscriminatorConvention class should implement IScalarDiscriminatorConvention interface.
1 parent b289310 commit a98fe70

File tree

15 files changed

+1037
-254
lines changed

15 files changed

+1037
-254
lines changed

src/MongoDB.Bson/Serialization/BsonSerializer.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,37 @@ internal static void EnsureKnownTypesAreRegistered(Type nominalType)
778778
}
779779
}
780780

781+
// internal static methods
782+
internal static BsonValue[] GetDiscriminatorsForTypeAndSubTypes(Type type)
783+
{
784+
// note: EnsureKnownTypesAreRegistered handles its own locking so call from outside any lock
785+
EnsureKnownTypesAreRegistered(type);
786+
787+
var discriminators = new List<BsonValue>();
788+
789+
__configLock.EnterReadLock();
790+
try
791+
{
792+
foreach (var entry in __discriminators)
793+
{
794+
var discriminator = entry.Key;
795+
var actualTypes = entry.Value;
796+
797+
var matchingType = actualTypes.SingleOrDefault(t => t == type || t.IsSubclassOf(type));
798+
if (matchingType != null)
799+
{
800+
discriminators.Add(discriminator);
801+
}
802+
}
803+
}
804+
finally
805+
{
806+
__configLock.ExitReadLock();
807+
}
808+
809+
return discriminators.OrderBy(x => x).ToArray();
810+
}
811+
781812
// private static methods
782813
private static void CreateSerializerRegistry()
783814
{

src/MongoDB.Bson/Serialization/Conventions/ScalarDiscriminatorConvention.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414
*/
1515

1616
using System;
17+
using System.Collections.Concurrent;
1718

1819
namespace MongoDB.Bson.Serialization.Conventions
1920
{
2021
/// <summary>
2122
/// Represents a discriminator convention where the discriminator is provided by the class map of the actual type.
2223
/// </summary>
23-
public class ScalarDiscriminatorConvention : StandardDiscriminatorConvention
24+
public class ScalarDiscriminatorConvention : StandardDiscriminatorConvention, IScalarDiscriminatorConvention
2425
{
26+
private readonly ConcurrentDictionary<Type, BsonValue[]> _cachedTypeAndSubTypeDiscriminators = new();
27+
2528
// constructors
2629
/// <summary>
2730
/// Initializes a new instance of the ScalarDiscriminatorConvention class.
@@ -52,5 +55,11 @@ public override BsonValue GetDiscriminator(Type nominalType, Type actualType)
5255
return null;
5356
}
5457
}
58+
59+
/// <inheritdoc/>
60+
public BsonValue[] GetDiscriminatorsForTypeAndSubTypes(Type type)
61+
{
62+
return _cachedTypeAndSubTypeDiscriminators.GetOrAdd(type, BsonSerializer.GetDiscriminatorsForTypeAndSubTypes);
63+
}
5564
}
5665
}

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
using MongoDB.Bson;
1919
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
2020
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters;
21+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Stages;
2122
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Visitors;
23+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
2224

2325
namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Optimizers
2426
{
@@ -351,6 +353,33 @@ elemMatchOperation.Filter is AstFieldOperationFilter elemFilter &&
351353
}
352354
}
353355

356+
public override AstNode VisitFilterExpression(AstFilterExpression node)
357+
{
358+
var inputExpression = VisitAndConvert(node.Input);
359+
var condExpression = VisitAndConvert(node.Cond);
360+
var limitExpression = VisitAndConvert(node.Limit);
361+
362+
if (condExpression is AstConstantExpression condConstantExpression &&
363+
condConstantExpression.Value is BsonBoolean condBsonBoolean)
364+
{
365+
if (condBsonBoolean.Value)
366+
{
367+
// { $filter : { input : <input>, as : "x", cond : true } } => <input>
368+
if (limitExpression == null)
369+
{
370+
return inputExpression;
371+
}
372+
}
373+
else
374+
{
375+
// { $filter : { input : <input>, as : "x", cond : false, optional-limit } } => []
376+
return AstExpression.Constant(new BsonArray());
377+
}
378+
}
379+
380+
return node.Update(inputExpression, condExpression, limitExpression);
381+
}
382+
354383
public override AstNode VisitGetFieldExpression(AstGetFieldExpression node)
355384
{
356385
if (TrySimplifyAsFieldPath(node, out var simplified))
@@ -448,6 +477,26 @@ public override AstNode VisitNotFilterOperation(AstNotFilterOperation node)
448477
return base.VisitNotFilterOperation(node);
449478
}
450479

480+
public override AstNode VisitPipeline(AstPipeline node)
481+
{
482+
var stages = VisitAndConvert(node.Stages);
483+
484+
// { $match : { } } => remove redundant stage
485+
if (stages.Any(stage => IsMatchEverythingStage(stage)))
486+
{
487+
stages = stages.Where(stage => !IsMatchEverythingStage(stage)).AsReadOnlyList();
488+
}
489+
490+
return node.Update(stages);
491+
492+
static bool IsMatchEverythingStage(AstStage stage)
493+
{
494+
return
495+
stage is AstMatchStage matchStage &&
496+
matchStage.Filter is AstMatchesEverythingFilter;
497+
}
498+
}
499+
451500
public override AstNode VisitSliceExpression(AstSliceExpression node)
452501
{
453502
node = (AstSliceExpression)base.VisitSliceExpression(node);
@@ -498,15 +547,34 @@ arrayConstant.Value is BsonArray bsonArrayConstant &&
498547

499548
public override AstNode VisitUnaryExpression(AstUnaryExpression node)
500549
{
550+
var arg = VisitAndConvert(node.Arg);
551+
501552
// { $first : <arg> } => { $arrayElemAt : [<arg>, 0] } (or -1 for $last)
502553
if (node.Operator == AstUnaryOperator.First || node.Operator == AstUnaryOperator.Last)
503554
{
504-
var simplifiedArg = VisitAndConvert(node.Arg);
505555
var index = node.Operator == AstUnaryOperator.First ? 0 : -1;
506-
return AstExpression.ArrayElemAt(simplifiedArg, index);
556+
return AstExpression.ArrayElemAt(arg, index);
557+
}
558+
559+
// { $not : booleanConstant } => !booleanConstant
560+
if (node.Operator is AstUnaryOperator.Not &&
561+
arg is AstConstantExpression argConstantExpression &&
562+
argConstantExpression.Value is BsonBoolean argBsonBoolean)
563+
{
564+
return AstExpression.Constant(!argBsonBoolean.Value);
565+
}
566+
567+
// { $not : { $eq : [expr1, expr2] } } => { $ne : [expr1, expr2] }
568+
// { $not : { $ne : [expr1, expr2] } } => { $eq : [expr1, expr2] }
569+
if (node.Operator is AstUnaryOperator.Not &&
570+
arg is AstBinaryExpression argBinaryExpression &&
571+
argBinaryExpression.Operator is AstBinaryOperator.Eq or AstBinaryOperator.Ne)
572+
{
573+
var oppositeComparisonOperator = argBinaryExpression.Operator == AstBinaryOperator.Eq ? AstBinaryOperator.Ne : AstBinaryOperator.Eq;
574+
return AstExpression.Binary(oppositeComparisonOperator, argBinaryExpression.Arg1, argBinaryExpression.Arg2);
507575
}
508576

509-
return base.VisitUnaryExpression(node);
577+
return node.Update(arg);
510578
}
511579
}
512580
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/OfTypeMethodToAggregationExpressionTranslator.cs

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
/* Copyright 2010-present MongoDB Inc.
2-
*
3-
* Licensed under the Apache License, Version 2.0 (the "License");
4-
* you may not use this file except in compliance with the License.
5-
* You may obtain a copy of the License at
6-
*
7-
* http://www.apache.org/licenses/LICENSE-2.0
8-
*
9-
* Unless required by applicable law or agreed to in writing, software
10-
* distributed under the License is distributed on an "AS IS" BASIS,
11-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
* See the License for the specific language governing permissions and
13-
* limitations under the License.
14-
*/
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
1515

16+
using System.Linq;
1617
using System.Linq.Expressions;
1718
using System.Reflection;
1819
using MongoDB.Bson.Serialization;
@@ -26,7 +27,7 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg
2627
{
2728
internal static class OfTypeMethodToAggregationExpressionTranslator
2829
{
29-
private static readonly MethodInfo[] __ofTypeMethods =
30+
private static MethodInfo[] __ofTypeMethods =
3031
{
3132
EnumerableMethod.OfType,
3233
QueryableMethod.OfType
@@ -42,35 +43,46 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC
4243
var sourceExpression = arguments[0];
4344
var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression);
4445
NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation);
45-
var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer);
4646

47-
var nominalType = sourceTranslation.Serializer.ValueType;
48-
var actualType = method.GetGenericArguments()[0];
47+
var sourceAst = sourceTranslation.Ast;
48+
var sourceSerializer = sourceTranslation.Serializer;
49+
if (sourceSerializer is IWrappedValueSerializer wrappedValueSerializer)
50+
{
51+
sourceAst = AstExpression.GetField(sourceAst, wrappedValueSerializer.FieldName);
52+
sourceSerializer = wrappedValueSerializer.ValueSerializer;
53+
}
54+
var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceSerializer);
4955

56+
var nominalType = itemSerializer.ValueType;
57+
var nominalTypeSerializer = itemSerializer;
58+
var actualType = method.GetGenericArguments().Single();
59+
var actualTypeSerializer = BsonSerializer.LookupSerializer(actualType);
60+
61+
AstExpression ast;
5062
if (nominalType == actualType)
5163
{
52-
return sourceTranslation;
64+
ast = sourceAst;
5365
}
66+
else
67+
{
68+
var discriminatorConvention = nominalTypeSerializer.GetDiscriminatorConvention();
69+
var itemVar = AstExpression.Var("item");
70+
var discriminatorField = AstExpression.GetField(itemVar, discriminatorConvention.ElementName);
5471

55-
var discriminatorConvention = itemSerializer.GetDiscriminatorConvention();
56-
var discriminatorElementName = discriminatorConvention.ElementName;
72+
var ofTypeExpression = discriminatorConvention switch
73+
{
74+
IHierarchicalDiscriminatorConvention hierarchicalDiscriminatorConvention => DiscriminatorAstExpression.TypeIs(discriminatorField, hierarchicalDiscriminatorConvention, nominalType, actualType),
75+
IScalarDiscriminatorConvention scalarDiscriminatorConvention => DiscriminatorAstExpression.TypeIs(discriminatorField, scalarDiscriminatorConvention, nominalType, actualType),
76+
_ => throw new ExpressionNotSupportedException(expression, because: "OfType is not supported with the configured discriminator convention")
77+
};
5778

58-
var itemVar = AstExpression.Var("this");
59-
var discriminatorField = AstExpression.GetField(itemVar, discriminatorElementName);
60-
var ofTypePredicate = discriminatorConvention switch
61-
{
62-
IHierarchicalDiscriminatorConvention hierarchicalDiscriminatorConvention => DiscriminatorAstExpression.TypeIs(discriminatorField, hierarchicalDiscriminatorConvention, nominalType, actualType),
63-
IScalarDiscriminatorConvention scalarDiscriminatorConvention => DiscriminatorAstExpression.TypeIs(discriminatorField, scalarDiscriminatorConvention, nominalType, actualType),
64-
_ => throw new ExpressionNotSupportedException(expression, because: "is operator is not supported with the configured discriminator convention")
65-
};
79+
ast = AstExpression.Filter(
80+
input: sourceAst,
81+
cond: ofTypeExpression,
82+
@as: "item");
83+
}
6684

67-
var ast = AstExpression.Filter(
68-
input: sourceTranslation.Ast,
69-
@as: itemVar.Name,
70-
cond: ofTypePredicate);
71-
var actualTypeSerializer = BsonSerializer.LookupSerializer(actualType);
7285
var resultSerializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, actualTypeSerializer);
73-
7486
return new TranslatedExpression(expression, ast, resultSerializer);
7587
}
7688

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,15 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC
4141
var sourceExpression = arguments[0];
4242
var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression);
4343
NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation);
44-
var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer);
44+
45+
var sourceAst = sourceTranslation.Ast;
46+
var sourceSerializer = sourceTranslation.Serializer;
47+
if (sourceSerializer is IWrappedValueSerializer wrappedValueSerializer)
48+
{
49+
sourceAst = AstExpression.GetField(sourceAst, wrappedValueSerializer.FieldName);
50+
sourceSerializer = wrappedValueSerializer.ValueSerializer;
51+
}
52+
var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceSerializer);
4553

4654
var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]);
4755
var predicateParameter = predicateLambda.Parameters[0];
@@ -57,7 +65,7 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC
5765
}
5866

5967
var ast = AstExpression.Filter(
60-
sourceTranslation.Ast,
68+
sourceAst,
6169
predicateTranslation.Ast,
6270
@as: predicateSymbol.Var.Name,
6371
limitTranslation?.Ast);

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/TypeIsExpressionToAggregationExpressionTranslator.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
* limitations under the License.
1414
*/
1515

16-
using System.Linq;
1716
using System.Linq.Expressions;
18-
using MongoDB.Bson;
1917
using MongoDB.Bson.Serialization;
2018
using MongoDB.Bson.Serialization.Conventions;
2119
using MongoDB.Bson.Serialization.Serializers;

0 commit comments

Comments
 (0)