From b1971a1903a8b2ac2624c21b77ee900db800d9a3 Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 28 Jul 2020 20:58:58 +0200 Subject: [PATCH 1/2] Fix detecting parameter type for Contains method for Linq provider --- src/NHibernate.Test/Async/Linq/EnumTests.cs | 8 +++++ src/NHibernate.Test/Linq/EnumTests.cs | 8 +++++ .../Linq/ParameterTypeLocatorTests.cs | 16 ++++++++++ .../Linq/Visitors/ParameterTypeLocator.cs | 32 ++++++++++++++++--- 4 files changed, 60 insertions(+), 4 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/EnumTests.cs b/src/NHibernate.Test/Async/Linq/EnumTests.cs index 6e9355d294c..e08a2c90829 100644 --- a/src/NHibernate.Test/Async/Linq/EnumTests.cs +++ b/src/NHibernate.Test/Async/Linq/EnumTests.cs @@ -62,6 +62,14 @@ public async Task CanQueryOnEnumStoredAsString_Small_1Async() Assert.AreEqual(expectedCount, query.Count); } + [Test] + public async Task CanQueryWithContainsOnEnumStoredAsString_Small_1Async() + { + var values = new[] { EnumStoredAsString.Small, EnumStoredAsString.Medium }; + var query = await (db.Users.Where(x => values.Contains(x.Enum1)).ToListAsync()); + Assert.AreEqual(3, query.Count); + } + [Test] public async Task ConditionalNavigationPropertyAsync() { diff --git a/src/NHibernate.Test/Linq/EnumTests.cs b/src/NHibernate.Test/Linq/EnumTests.cs index aeea060b51e..7f312de7e42 100644 --- a/src/NHibernate.Test/Linq/EnumTests.cs +++ b/src/NHibernate.Test/Linq/EnumTests.cs @@ -49,6 +49,14 @@ public void CanQueryOnEnumStoredAsString(EnumStoredAsString type, int expectedCo Assert.AreEqual(expectedCount, query.Count); } + [Test] + public void CanQueryWithContainsOnEnumStoredAsString_Small_1() + { + var values = new[] { EnumStoredAsString.Small, EnumStoredAsString.Medium }; + var query = db.Users.Where(x => values.Contains(x.Enum1)).ToList(); + Assert.AreEqual(3, query.Count); + } + [Test] public void ConditionalNavigationProperty() { diff --git a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs index 511e23f88cd..39cb2d22d74 100644 --- a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs @@ -84,6 +84,22 @@ public void EqualStringEnumTest() ); } + [Test] + public void ContainsStringEnumTest() + { + var values = new[] {EnumStoredAsString.Small}; + AssertResults( + new Dictionary> + { + {"value(NHibernate.DomainModel.Northwind.Entities.EnumStoredAsString[])", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => values.Contains(o.Enum1)), + db.Users.Where(o => values.Contains(o.NullableEnum1.Value)), + db.Users.Where(o => values.Contains(o.Name == o.Name ? o.Enum1 : o.NullableEnum1.Value)), + db.Timesheets.Where(o => o.Users.Any(u => values.Contains(u.Enum1))) + ); + } + [Test] public void EqualStringEnumTestWithFetch() { diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 34326640169..606bffcd55b 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -1,12 +1,15 @@ using System.Collections.Generic; using System.Dynamic; +using System.Linq; using System.Linq.Expressions; using NHibernate.Engine; using NHibernate.Param; using NHibernate.Type; using NHibernate.Util; using Remotion.Linq; +using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Parsing; namespace NHibernate.Linq.Visitors @@ -219,14 +222,35 @@ protected override Expression VisitConstant(ConstantExpression node) return node; } - public override Expression Visit(Expression node) + protected override Expression VisitSubQuery(SubQueryExpression node) { - if (node is SubQueryExpression subQueryExpression) + // ReLinq wraps all ResultOperatorExpressionNodeBase into a SubQueryExpression. In case of + // ContainsResultOperator where the constant expression is dislocated from the related expression, + // we have to manually link the related expressions. + var containsOperator = node.QueryModel.ResultOperators.OfType().FirstOrDefault(); + if (containsOperator != null && + node.QueryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference && + querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause && + mainFromClause.FromExpression is ConstantExpression constantExpression) { - subQueryExpression.QueryModel.TransformExpressions(Visit); + VisitConstant(constantExpression); + AddRelatedExpression(constantExpression, Unwrap(Visit(containsOperator.Item))); + // Copy all found MemberExpressions to the constant expression + // (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2) + if (RelatedExpressions.TryGetValue(containsOperator.Item, out var set)) + { + foreach (var nestedMemberExpression in set) + { + AddRelatedExpression(constantExpression, nestedMemberExpression); + } + } + } + else + { + node.QueryModel.TransformExpressions(Visit); } - return base.Visit(node); + return node; } private void VisitAssign(Expression leftNode, Expression rightNode) From 290047d31e98b83857d95f0cdb8374091140121a Mon Sep 17 00:00:00 2001 From: maca88 Date: Wed, 29 Jul 2020 16:57:06 +0200 Subject: [PATCH 2/2] Code review changes --- src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 606bffcd55b..40f3ec0d3d3 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -227,8 +227,8 @@ protected override Expression VisitSubQuery(SubQueryExpression node) // ReLinq wraps all ResultOperatorExpressionNodeBase into a SubQueryExpression. In case of // ContainsResultOperator where the constant expression is dislocated from the related expression, // we have to manually link the related expressions. - var containsOperator = node.QueryModel.ResultOperators.OfType().FirstOrDefault(); - if (containsOperator != null && + if (node.QueryModel.ResultOperators.Count == 1 && + node.QueryModel.ResultOperators[0] is ContainsResultOperator containsOperator && node.QueryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference && querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause && mainFromClause.FromExpression is ConstantExpression constantExpression)