@@ -153,26 +153,44 @@ private static HashSet<IType> GetCandidateTypes(
153153 return candidateTypes ;
154154 }
155155
156+ private static bool GetCandidateType (
157+ ISessionFactoryImplementor sessionFactory ,
158+ IEnumerable < ConstantExpression > constantExpressions ,
159+ ConstantTypeLocatorVisitor visitor ,
160+ System . Type constantType ,
161+ out IType candidateType )
162+ {
163+ var candidateTypes = GetCandidateTypes ( sessionFactory , constantExpressions , visitor ) ;
164+ if ( candidateTypes . Count == 1 )
165+ {
166+ candidateType = candidateTypes . First ( ) ;
167+
168+ // When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type
169+ // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)).
170+ if ( ! IntegralNumericTypes . Contains ( candidateType . ReturnedClass ) ||
171+ ! FloatingPointNumericTypes . Contains ( constantType ) )
172+ {
173+ return true ;
174+ }
175+ }
176+
177+ candidateType = null ;
178+ return false ;
179+ }
180+
156181 private static IType GetParameterType (
157182 ISessionFactoryImplementor sessionFactory ,
158183 HashSet < ConstantExpression > constantExpressions ,
159184 ConstantTypeLocatorVisitor visitor ,
160185 NamedParameter namedParameter )
161186 {
162- var candidateTypes = GetCandidateTypes ( sessionFactory , constantExpressions , visitor ) ;
163187
164188 // All constant expressions have the same type/value
165189 var constantExpression = constantExpressions . First ( ) ;
166190 var constantType = constantExpression . Type . UnwrapIfNullable ( ) ;
167- if (
168- candidateTypes . Count == 1 &&
169- // When comparing an integral column with a floating-point parameter, the parameter type must remain floating-point type
170- // and the column needs to be casted in order to prevent invalid results (e.g. Where(o => o.Integer >= 2.2d)).
171- ! ( candidateTypes . Any ( t => IntegralNumericTypes . Contains ( t . ReturnedClass ) ) &&
172- FloatingPointNumericTypes . Contains ( constantType ) )
173- )
191+ if ( GetCandidateType ( sessionFactory , constantExpressions , visitor , constantType , out var candidateType ) )
174192 {
175- return candidateTypes . First ( ) ;
193+ return candidateType ;
176194 }
177195
178196 // No related MemberExpressions was found, guess the type by value or its type when null.
0 commit comments