Skip to content

Commit 63b6b42

Browse files
authored
Fix parameter detection for Contains method for Linq provider (#2520)
Fixes #2512, fixes #2514, fixes #2515
1 parent a1e709b commit 63b6b42

File tree

5 files changed

+185
-30
lines changed

5 files changed

+185
-30
lines changed

src/NHibernate.DomainModel/Northwind/Entities/AnotherEntityRequired.cs

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ public class AnotherEntityRequired
2222

2323
public virtual ISet<AnotherEntity> RelatedItems { get; set; } = new HashSet<AnotherEntity>();
2424

25+
public virtual ISet<AnotherEntityRequired> RequiredRelatedItems { get; set; } = new HashSet<AnotherEntityRequired>();
26+
2527
public virtual bool? NullableBool { get; set; }
2628
}
2729

src/NHibernate.DomainModel/Northwind/Mappings/AnotherEntityRequired.hbm.xml

+4
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,9 @@
1919
<key column="Id"/>
2020
<one-to-many class="AnotherEntity"/>
2121
</set>
22+
<set name="RequiredRelatedItems" lazy="true" inverse="true">
23+
<key column="Id"/>
24+
<one-to-many class="AnotherEntityRequired"/>
25+
</set>
2226
</class>
2327
</hibernate-mapping>

src/NHibernate.Test/Async/Linq/ParameterTests.cs

+74
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,80 @@ public async Task UsingEntityParameterTwiceAsync()
8181
1));
8282
}
8383

84+
[Test]
85+
public async Task UsingEntityParameterForCollectionAsync()
86+
{
87+
var item = await (db.OrderLines.FirstAsync());
88+
await (AssertTotalParametersAsync(
89+
db.Orders.Where(o => o.OrderLines.Contains(item)),
90+
1));
91+
}
92+
93+
[Test]
94+
public async Task UsingProxyParameterForCollectionAsync()
95+
{
96+
var item = await (session.LoadAsync<Order>(10248));
97+
Assert.That(NHibernateUtil.IsInitialized(item), Is.False);
98+
await (AssertTotalParametersAsync(
99+
db.Customers.Where(o => o.Orders.Contains(item)),
100+
1));
101+
}
102+
103+
[Test]
104+
public async Task UsingFieldProxyParameterForCollectionAsync()
105+
{
106+
var item = await (session.Query<AnotherEntityRequired>().FirstAsync());
107+
await (AssertTotalParametersAsync(
108+
session.Query<AnotherEntityRequired>().Where(o => o.RequiredRelatedItems.Contains(item)),
109+
1));
110+
}
111+
112+
[Test]
113+
public async Task UsingEntityParameterInSubQueryAsync()
114+
{
115+
var item = await (db.Customers.FirstAsync());
116+
var subQuery = db.Orders.Select(o => o.Customer).Where(o => o == item);
117+
await (AssertTotalParametersAsync(
118+
db.Orders.Where(o => subQuery.Contains(o.Customer)),
119+
1));
120+
}
121+
122+
[Test]
123+
public async Task UsingEntityParameterForCollectionSelectionAsync()
124+
{
125+
var item = await (db.OrderLines.FirstAsync());
126+
await (AssertTotalParametersAsync(
127+
db.Orders.SelectMany(o => o.OrderLines).Where(o => o == item),
128+
1));
129+
}
130+
131+
[Test]
132+
public async Task UsingFieldProxyParameterForCollectionSelectionAsync()
133+
{
134+
var item = await (session.Query<AnotherEntityRequired>().FirstAsync());
135+
await (AssertTotalParametersAsync(
136+
session.Query<AnotherEntityRequired>().SelectMany(o => o.RequiredRelatedItems).Where(o => o == item),
137+
1));
138+
}
139+
140+
[Test]
141+
public async Task UsingEntityListParameterForCollectionSelectionAsync()
142+
{
143+
var items = new[] {await (db.OrderLines.FirstAsync())};
144+
await (AssertTotalParametersAsync(
145+
db.Orders.SelectMany(o => o.OrderLines).Where(o => items.Contains(o)),
146+
1));
147+
}
148+
149+
[Test]
150+
public async Task UsingFieldProxyListParameterForCollectionSelectionAsync()
151+
{
152+
var items = new[] {await (session.Query<AnotherEntityRequired>().FirstAsync())};
153+
await (AssertTotalParametersAsync(
154+
session.Query<AnotherEntityRequired>().SelectMany(o => o.RequiredRelatedItems).Where(o => items.Contains(o)),
155+
1));
156+
}
157+
84158
[Test]
85159
public async Task UsingTwoEntityParametersAsync()
86160
{

src/NHibernate.Test/Linq/ParameterTests.cs

+74
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,80 @@ public void UsingEntityParameterTwice()
6969
1);
7070
}
7171

72+
[Test]
73+
public void UsingEntityParameterForCollection()
74+
{
75+
var item = db.OrderLines.First();
76+
AssertTotalParameters(
77+
db.Orders.Where(o => o.OrderLines.Contains(item)),
78+
1);
79+
}
80+
81+
[Test]
82+
public void UsingProxyParameterForCollection()
83+
{
84+
var item = session.Load<Order>(10248);
85+
Assert.That(NHibernateUtil.IsInitialized(item), Is.False);
86+
AssertTotalParameters(
87+
db.Customers.Where(o => o.Orders.Contains(item)),
88+
1);
89+
}
90+
91+
[Test]
92+
public void UsingFieldProxyParameterForCollection()
93+
{
94+
var item = session.Query<AnotherEntityRequired>().First();
95+
AssertTotalParameters(
96+
session.Query<AnotherEntityRequired>().Where(o => o.RequiredRelatedItems.Contains(item)),
97+
1);
98+
}
99+
100+
[Test]
101+
public void UsingEntityParameterInSubQuery()
102+
{
103+
var item = db.Customers.First();
104+
var subQuery = db.Orders.Select(o => o.Customer).Where(o => o == item);
105+
AssertTotalParameters(
106+
db.Orders.Where(o => subQuery.Contains(o.Customer)),
107+
1);
108+
}
109+
110+
[Test]
111+
public void UsingEntityParameterForCollectionSelection()
112+
{
113+
var item = db.OrderLines.First();
114+
AssertTotalParameters(
115+
db.Orders.SelectMany(o => o.OrderLines).Where(o => o == item),
116+
1);
117+
}
118+
119+
[Test]
120+
public void UsingFieldProxyParameterForCollectionSelection()
121+
{
122+
var item = session.Query<AnotherEntityRequired>().First();
123+
AssertTotalParameters(
124+
session.Query<AnotherEntityRequired>().SelectMany(o => o.RequiredRelatedItems).Where(o => o == item),
125+
1);
126+
}
127+
128+
[Test]
129+
public void UsingEntityListParameterForCollectionSelection()
130+
{
131+
var items = new[] {db.OrderLines.First()};
132+
AssertTotalParameters(
133+
db.Orders.SelectMany(o => o.OrderLines).Where(o => items.Contains(o)),
134+
1);
135+
}
136+
137+
[Test]
138+
public void UsingFieldProxyListParameterForCollectionSelection()
139+
{
140+
var items = new[] {session.Query<AnotherEntityRequired>().First()};
141+
AssertTotalParameters(
142+
session.Query<AnotherEntityRequired>().SelectMany(o => o.RequiredRelatedItems).Where(o => items.Contains(o)),
143+
1);
144+
}
145+
72146
[Test]
73147
public void UsingTwoEntityParameters()
74148
{

src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs

+31-30
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ private static IType GetCandidateType(
117117
if (!ExpressionsHelper.TryGetMappedType(sessionFactory, relatedExpression, out var mappedType, out _, out _, out _))
118118
continue;
119119

120-
if (mappedType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression))
120+
if (mappedType.IsCollectionType)
121121
{
122122
var collection = (IQueryableCollection) ((IAssociationType) mappedType).GetAssociatedJoinable(sessionFactory);
123123
mappedType = collection.ElementType;
@@ -176,7 +176,6 @@ private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor
176176
new Dictionary<NamedParameter, HashSet<ConstantExpression>>();
177177
public readonly Dictionary<Expression, HashSet<Expression>> RelatedExpressions =
178178
new Dictionary<Expression, HashSet<Expression>>();
179-
public readonly HashSet<Expression> SequenceSelectorExpressions = new HashSet<Expression>();
180179

181180
public ConstantTypeLocatorVisitor(
182181
bool removeMappedAsCalls,
@@ -282,41 +281,43 @@ protected override Expression VisitConstant(ConstantExpression node)
282281
}
283282

284283
protected override Expression VisitSubQuery(SubQueryExpression node)
284+
{
285+
if (!TryLinkContainsMethod(node.QueryModel))
286+
{
287+
node.QueryModel.TransformExpressions(Visit);
288+
}
289+
290+
return node;
291+
}
292+
293+
private bool TryLinkContainsMethod(QueryModel queryModel)
285294
{
286295
// ReLinq wraps all ResultOperatorExpressionNodeBase into a SubQueryExpression. In case of
287296
// ContainsResultOperator where the constant expression is dislocated from the related expression,
288297
// we have to manually link the related expressions.
289-
if (node.QueryModel.ResultOperators.Count == 1 &&
290-
node.QueryModel.ResultOperators[0] is ContainsResultOperator containsOperator &&
291-
node.QueryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference &&
292-
querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause &&
293-
mainFromClause.FromExpression is ConstantExpression constantExpression)
298+
if (queryModel.ResultOperators.Count != 1 ||
299+
!(queryModel.ResultOperators[0] is ContainsResultOperator containsOperator) ||
300+
!(queryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference) ||
301+
!(querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause))
294302
{
295-
VisitConstant(constantExpression);
296-
AddRelatedExpression(constantExpression, UnwrapUnary(Visit(containsOperator.Item)));
297-
// Copy all found MemberExpressions to the constant expression
298-
// (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2)
299-
if (RelatedExpressions.TryGetValue(containsOperator.Item, out var set))
300-
{
301-
foreach (var nestedMemberExpression in set)
302-
{
303-
AddRelatedExpression(constantExpression, nestedMemberExpression);
304-
}
305-
}
303+
return false;
306304
}
307-
else
308-
{
309-
// In case a parameter is related to a sequence selector we will have to get the underlying item type
310-
// (e.g. q.Where(o => o.Users.Any(u => u == user)))
311-
if (node.QueryModel.ResultOperators.Any(o => o is ValueFromSequenceResultOperatorBase))
312-
{
313-
SequenceSelectorExpressions.Add(node.QueryModel.SelectClause.Selector);
314-
}
315305

316-
node.QueryModel.TransformExpressions(Visit);
306+
var left = UnwrapUnary(Visit(mainFromClause.FromExpression));
307+
var right = UnwrapUnary(Visit(containsOperator.Item));
308+
// The constant is on the left side (e.g. db.Users.Where(o => users.Contains(o)))
309+
// The constant is on the right side (e.g. db.Customers.Where(o => o.Orders.Contains(item)))
310+
if (left.NodeType != ExpressionType.Constant && right.NodeType != ExpressionType.Constant)
311+
{
312+
return false;
317313
}
318314

319-
return node;
315+
// Copy all found MemberExpressions to the constant expression
316+
// (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2)
317+
AddRelatedExpression(null, left, right);
318+
AddRelatedExpression(null, right, left);
319+
320+
return true;
320321
}
321322

322323
private void VisitAssign(Expression leftNode, Expression rightNode)
@@ -346,7 +347,7 @@ private void AddRelatedExpression(Expression node, Expression left, Expression r
346347
left is QuerySourceReferenceExpression)
347348
{
348349
AddRelatedExpression(right, left);
349-
if (NonVoidOperators.Contains(node.NodeType))
350+
if (node != null && NonVoidOperators.Contains(node.NodeType))
350351
{
351352
AddRelatedExpression(node, left);
352353
}
@@ -359,7 +360,7 @@ private void AddRelatedExpression(Expression node, Expression left, Expression r
359360
foreach (var nestedMemberExpression in set)
360361
{
361362
AddRelatedExpression(right, nestedMemberExpression);
362-
if (NonVoidOperators.Contains(node.NodeType))
363+
if (node != null && NonVoidOperators.Contains(node.NodeType))
363364
{
364365
AddRelatedExpression(node, nestedMemberExpression);
365366
}

0 commit comments

Comments
 (0)