@@ -115,62 +115,39 @@ internal static void SetParameterTypes(
115115 }
116116 }
117117
118- private static HashSet < IType > GetCandidateTypes (
118+ private static IType GetCandidateType (
119119 ISessionFactoryImplementor sessionFactory ,
120120 IEnumerable < ConstantExpression > constantExpressions ,
121121 ConstantTypeLocatorVisitor visitor )
122122 {
123- var candidateTypes = new HashSet < IType > ( ) ;
123+ IType candidateType = null ;
124124 foreach ( var expression in constantExpressions )
125125 {
126126 // In order to get the actual type we have to check first the related member expressions, as
127127 // an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string.
128128 // By getting the type from a related member expression we also get the correct length in case of StringType
129129 // or precision when having a DecimalType.
130- if ( visitor . RelatedExpressions . TryGetValue ( expression , out var relatedExpressions ) )
130+ if ( ! visitor . RelatedExpressions . TryGetValue ( expression , out var relatedExpressions ) )
131+ continue ;
132+ foreach ( var relatedExpression in relatedExpressions )
131133 {
132- foreach ( var relatedExpression in relatedExpressions )
133- {
134- if ( ExpressionsHelper . TryGetMappedType ( sessionFactory , relatedExpression , out var candidateType , out _ , out _ , out _ ) )
135- {
136- if ( candidateType . IsAssociationType && visitor . SequenceSelectorExpressions . Contains ( relatedExpression ) )
137- {
138- var collection = ( IQueryableCollection ) ( ( IAssociationType ) candidateType ) . GetAssociatedJoinable ( sessionFactory ) ;
139- candidateType = collection . ElementType ;
140- }
134+ if ( ! ExpressionsHelper . TryGetMappedType ( sessionFactory , relatedExpression , out var mappedType , out _ , out _ , out _ ) )
135+ continue ;
141136
142- candidateTypes . Add ( candidateType ) ;
143- }
137+ if ( mappedType . IsAssociationType && visitor . SequenceSelectorExpressions . Contains ( relatedExpression ) )
138+ {
139+ var collection = ( IQueryableCollection ) ( ( IAssociationType ) mappedType ) . GetAssociatedJoinable ( sessionFactory ) ;
140+ mappedType = collection . ElementType ;
144141 }
145- }
146- }
147-
148- return candidateTypes ;
149- }
150-
151- private static bool GetCandidateType (
152- ISessionFactoryImplementor sessionFactory ,
153- IEnumerable < ConstantExpression > constantExpressions ,
154- ConstantTypeLocatorVisitor visitor ,
155- System . Type constantType ,
156- out IType candidateType )
157- {
158- var candidateTypes = GetCandidateTypes ( sessionFactory , constantExpressions , visitor ) ;
159- if ( candidateTypes . Count == 1 )
160- {
161- candidateType = candidateTypes . First ( ) ;
162142
163- // When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type
164- // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)).
165- if ( ! IntegralNumericTypes . Contains ( candidateType . ReturnedClass ) ||
166- ! FloatingPointNumericTypes . Contains ( constantType ) )
167- {
168- return true ;
143+ if ( candidateType == null )
144+ candidateType = mappedType ;
145+ else if ( ! candidateType . Equals ( mappedType ) )
146+ return null ;
169147 }
170148 }
171149
172- candidateType = null ;
173- return false ;
150+ return candidateType ;
174151 }
175152
176153 private static IType GetParameterType (
@@ -183,7 +160,11 @@ private static IType GetParameterType(
183160 // All constant expressions have the same type/value
184161 var constantExpression = constantExpressions . First ( ) ;
185162 var constantType = constantExpression . Type . UnwrapIfNullable ( ) ;
186- if ( GetCandidateType ( sessionFactory , constantExpressions , visitor , constantType , out var candidateType ) )
163+ var candidateType = GetCandidateType ( sessionFactory , constantExpressions , visitor ) ;
164+ if ( candidateType != null &&
165+ // When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type
166+ // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)).
167+ ! ( FloatingPointNumericTypes . Contains ( constantType ) && IntegralNumericTypes . Contains ( candidateType . ReturnedClass ) ) )
187168 {
188169 return candidateType ;
189170 }
0 commit comments