@@ -117,7 +117,7 @@ private static IType GetCandidateType(
117117 if (!ExpressionsHelper.TryGetMappedType(sessionFactory, relatedExpression, out var mappedType, out _, out _, out _))
118118 continue;
119119
120- if (mappedType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression) )
120+ if (mappedType.IsCollectionType )
121121 {
122122 var collection = (IQueryableCollection) ((IAssociationType) mappedType).GetAssociatedJoinable(sessionFactory);
123123 mappedType = collection.ElementType;
@@ -176,7 +176,6 @@ private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor
176176 new Dictionary<NamedParameter, HashSet<ConstantExpression>>();
177177 public readonly Dictionary<Expression, HashSet<Expression>> RelatedExpressions =
178178 new Dictionary<Expression, HashSet<Expression>>();
179- public readonly HashSet<Expression> SequenceSelectorExpressions = new HashSet<Expression>();
180179
181180 public ConstantTypeLocatorVisitor(
182181 bool removeMappedAsCalls,
@@ -282,41 +281,53 @@ protected override Expression VisitConstant(ConstantExpression node)
282281 }
283282
284283 protected override Expression VisitSubQuery(SubQueryExpression node)
284+ {
285+ if (!TryLinkContainsMethod(node.QueryModel))
286+ {
287+ node.QueryModel.TransformExpressions(Visit);
288+ }
289+
290+ return node;
291+ }
292+
293+ private bool TryLinkContainsMethod(QueryModel queryModel)
285294 {
286295 // ReLinq wraps all ResultOperatorExpressionNodeBase into a SubQueryExpression. In case of
287296 // ContainsResultOperator where the constant expression is dislocated from the related expression,
288297 // we have to manually link the related expressions.
289- if (node.QueryModel.ResultOperators.Count == 1 &&
290- node.QueryModel.ResultOperators[0] is ContainsResultOperator containsOperator &&
291- node.QueryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference &&
292- querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause &&
293- mainFromClause.FromExpression is ConstantExpression constantExpression)
298+ if (queryModel.ResultOperators.Count != 1 ||
299+ !(queryModel.ResultOperators[0] is ContainsResultOperator containsOperator) ||
300+ !(queryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference) ||
301+ !(querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause))
294302 {
295- VisitConstant(constantExpression);
296- AddRelatedExpression(constantExpression, UnwrapUnary(Visit(containsOperator.Item)));
297- // Copy all found MemberExpressions to the constant expression
298- // (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2)
299- if (RelatedExpressions.TryGetValue(containsOperator.Item, out var set))
300- {
301- foreach (var nestedMemberExpression in set)
302- {
303- AddRelatedExpression(constantExpression, nestedMemberExpression);
304- }
305- }
303+ return false;
304+ }
305+
306+ var left = UnwrapUnary(mainFromClause.FromExpression);
307+ var right = UnwrapUnary(containsOperator.Item);
308+ if (left.NodeType == ExpressionType.Constant)
309+ {
310+ // The constant is on the left side (e.g. db.Users.Where(o => users.Contains(o)))
311+ VisitConstant((ConstantExpression) left);
312+ right = UnwrapUnary(Visit(containsOperator.Item));
313+ }
314+ else if (right.NodeType == ExpressionType.Constant)
315+ {
316+ // The constant is on the right side (e.g. db.Customers.Where(o => o.Orders.Contains(item)))
317+ VisitConstant((ConstantExpression) right);
318+ left = UnwrapUnary(Visit(mainFromClause.FromExpression));
306319 }
307320 else
308321 {
309- // In case a parameter is related to a sequence selector we will have to get the underlying item type
310- // (e.g. q.Where(o => o.Users.Any(u => u == user)))
311- if (node.QueryModel.ResultOperators.Any(o => o is ValueFromSequenceResultOperatorBase))
312- {
313- SequenceSelectorExpressions.Add(node.QueryModel.SelectClause.Selector);
314- }
315-
316- node.QueryModel.TransformExpressions(Visit);
322+ return false;
317323 }
318324
319- return node;
325+ // Copy all found MemberExpressions to the constant expression
326+ // (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2)
327+ AddRelatedExpression(null, left, right);
328+ AddRelatedExpression(null, right, left);
329+
330+ return true;
320331 }
321332
322333 private void VisitAssign(Expression leftNode, Expression rightNode)
@@ -346,7 +357,7 @@ private void AddRelatedExpression(Expression node, Expression left, Expression r
346357 left is QuerySourceReferenceExpression)
347358 {
348359 AddRelatedExpression(right, left);
349- if (NonVoidOperators.Contains(node.NodeType))
360+ if (node != null && NonVoidOperators.Contains(node.NodeType))
350361 {
351362 AddRelatedExpression(node, left);
352363 }
@@ -359,7 +370,7 @@ private void AddRelatedExpression(Expression node, Expression left, Expression r
359370 foreach (var nestedMemberExpression in set)
360371 {
361372 AddRelatedExpression(right, nestedMemberExpression);
362- if (NonVoidOperators.Contains(node.NodeType))
373+ if (node != null && NonVoidOperators.Contains(node.NodeType))
363374 {
364375 AddRelatedExpression(node, nestedMemberExpression);
365376 }
0 commit comments