@@ -240,6 +240,9 @@ constant.Value is CallSite site &&
240240
241241 protected HqlTreeNode VisitNhAverage ( NhAverageExpression expression )
242242 {
243+ // We need to cast the argument when its type is different from Average method return type,
244+ // otherwise the result may be incorrect. In SQL Server avg always returns int
245+ // when the argument is int.
243246 var hqlExpression = VisitExpression ( expression . Expression ) . AsExpression ( ) ;
244247 hqlExpression = IsCastRequired ( expression . Expression , expression . Type , out _ )
245248 ? ( HqlExpression ) _hqlTreeBuilder . Cast ( hqlExpression , expression . Type )
@@ -267,7 +270,7 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression)
267270
268271 protected HqlTreeNode VisitNhSum ( NhSumExpression expression )
269272 {
270- return IsCastRequired ( expression . Type , "sum" , out _ )
273+ return IsCastRequired ( "sum" , expression . Expression , expression . Type )
271274 ? ( HqlTreeNode ) _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type )
272275 : _hqlTreeBuilder . TransparentCast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type ) ;
273276 }
@@ -593,7 +596,8 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression)
593596 private bool IsCastRequired ( Expression expression , System . Type toType , out bool existType )
594597 {
595598 existType = false ;
596- return toType != typeof ( object ) && IsCastRequired ( GetType ( expression ) , TypeFactory . GetDefaultTypeFor ( toType ) , out existType ) ;
599+ return toType != typeof ( object ) &&
600+ IsCastRequired ( GetType ( expression ) , TypeFactory . GetDefaultTypeFor ( toType ) , out existType ) ;
597601 }
598602
599603 private bool IsCastRequired ( IType type , IType toType , out bool existType )
@@ -635,59 +639,38 @@ private bool IsCastRequired(IType type, IType toType, out bool existType)
635639 return castTypeName != toCastTypeName ;
636640 }
637641
638- private bool IsCastRequired ( System . Type type , string sqlFunctionName , out bool existType )
642+ private bool IsCastRequired ( string sqlFunctionName , Expression argumentExpression , System . Type returnType )
639643 {
640- if ( type == typeof ( object ) )
644+ var argumentType = GetType ( argumentExpression ) ;
645+ if ( argumentType == null || returnType == typeof ( object ) )
641646 {
642- existType = false ;
643647 return false ;
644648 }
645649
646- var toType = TypeFactory . GetDefaultTypeFor ( type ) ;
647- if ( toType == null )
650+ var returnNhType = TypeFactory . GetDefaultTypeFor ( returnType ) ;
651+ if ( returnNhType == null )
648652 {
649- existType = false ;
650653 return true ; // Fallback to the old behavior
651654 }
652655
653- existType = true ;
654656 var sqlFunction = _parameters . SessionFactory . SQLFunctionRegistry . FindSQLFunction ( sqlFunctionName ) ;
655657 if ( sqlFunction == null )
656658 {
657659 return true ; // Fallback to the old behavior
658660 }
659661
660- var fnReturnType = sqlFunction . ReturnType ( toType , _parameters . SessionFactory ) ;
661- return fnReturnType == null || IsCastRequired ( fnReturnType , toType , out existType ) ;
662+ var fnReturnType = sqlFunction . ReturnType ( argumentType , _parameters . SessionFactory ) ;
663+ return fnReturnType == null || IsCastRequired ( fnReturnType , returnNhType , out _ ) ;
662664 }
663665
664666 private IType GetType ( Expression expression )
665667 {
666- if ( ! ( expression is MemberExpression memberExpression ) )
667- {
668- return expression . Type != typeof ( object )
669- ? TypeFactory . GetDefaultTypeFor ( expression . Type )
670- : null ;
671- }
672-
673668 // Try to get the mapped type for the member as it may be a non default one
674- var entityName = ExpressionsHelper . TryGetEntityName ( _parameters . SessionFactory , memberExpression , out var memberPath ) ;
675- if ( entityName == null )
676- {
677- return TypeFactory . GetDefaultTypeFor ( expression . Type ) ; // Not mapped
678- }
679-
680- var persister = _parameters . SessionFactory . GetEntityPersister ( entityName ) ;
681- var type = persister . EntityMetamodel . GetIdentifierPropertyType ( memberPath ) ;
682- if ( type != null )
683- {
684- return type ;
685- }
686-
687- var index = persister . EntityMetamodel . GetPropertyIndexOrNull ( memberPath ) ;
688- return ! index . HasValue
689- ? TypeFactory . GetDefaultTypeFor ( expression . Type ) // Not mapped
690- : persister . EntityMetamodel . PropertyTypes [ index . Value ] ;
669+ ExpressionsHelper . TryGetEntityName ( _parameters . SessionFactory , expression , out _ , out var type ) ;
670+ return type ??
671+ ( expression . Type != typeof ( object )
672+ ? TypeFactory . GetDefaultTypeFor ( expression . Type )
673+ : null ) ;
691674 }
692675 }
693676}
0 commit comments