11using System ;
2+ using System . Data ;
23using System . Dynamic ;
34using System . Linq ;
45using System . Linq . Expressions ;
@@ -240,10 +241,13 @@ constant.Value is CallSite site &&
240241 protected HqlTreeNode VisitNhAverage ( NhAverageExpression expression )
241242 {
242243 var hqlExpression = VisitExpression ( expression . Expression ) . AsExpression ( ) ;
243- if ( expression . Type != expression . Expression . Type )
244- hqlExpression = _hqlTreeBuilder . Cast ( hqlExpression , expression . Type ) ;
244+ hqlExpression = IsCastRequired ( expression . Expression , expression . Type )
245+ ? ( HqlExpression ) _hqlTreeBuilder . Cast ( hqlExpression , expression . Type )
246+ : _hqlTreeBuilder . TransparentCast ( hqlExpression , expression . Type ) ;
245247
246- return _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Average ( hqlExpression ) , expression . Type ) ;
248+ return IsCastRequired ( expression . Type , "avg" )
249+ ? ( HqlTreeNode ) _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Average ( hqlExpression ) , expression . Type )
250+ : _hqlTreeBuilder . TransparentCast ( _hqlTreeBuilder . Average ( hqlExpression ) , expression . Type ) ;
247251 }
248252
249253 protected HqlTreeNode VisitNhCount ( NhCountExpression expression )
@@ -263,17 +267,9 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression)
263267
264268 protected HqlTreeNode VisitNhSum ( NhSumExpression expression )
265269 {
266- var type = expression . Type . UnwrapIfNullable ( ) ;
267- var nhType = TypeFactory . GetDefaultTypeFor ( type ) ;
268- if ( nhType != null && _parameters . SessionFactory . SQLFunctionRegistry . FindSQLFunction ( "sum" )
269- ? . ReturnType ( nhType , _parameters . SessionFactory ) ? . ReturnedClass == type )
270- {
271- return _hqlTreeBuilder . TransparentCast (
272- _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) ,
273- expression . Type ) ;
274- }
275-
276- return _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type ) ;
270+ return IsCastRequired ( expression . Type , "sum" )
271+ ? ( HqlTreeNode ) _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type )
272+ : _hqlTreeBuilder . TransparentCast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type ) ;
277273 }
278274
279275 protected HqlTreeNode VisitNhDistinct ( NhDistinctExpression expression )
@@ -487,15 +483,9 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression)
487483 case ExpressionType . Convert :
488484 case ExpressionType . ConvertChecked :
489485 case ExpressionType . TypeAs :
490- var operandType = expression . Operand . Type . UnwrapIfNullable ( ) ;
491- if ( ( operandType . IsPrimitive || operandType == typeof ( decimal ) ) &&
492- ( expression . Type . IsPrimitive || expression . Type == typeof ( decimal ) ) &&
493- expression . Type != operandType )
494- {
495- return _hqlTreeBuilder . Cast ( VisitExpression ( expression . Operand ) . AsExpression ( ) , expression . Type ) ;
496- }
497-
498- return VisitExpression ( expression . Operand ) ;
486+ return IsCastRequired ( expression . Operand , expression . Type )
487+ ? _hqlTreeBuilder . Cast ( VisitExpression ( expression . Operand ) . AsExpression ( ) , expression . Type )
488+ : VisitExpression ( expression . Operand ) ;
499489 }
500490
501491 throw new NotSupportedException ( expression . ToString ( ) ) ;
@@ -596,5 +586,96 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression)
596586 var expressionSubTree = expression . Expressions . ToArray ( exp => VisitExpression ( exp ) ) ;
597587 return _hqlTreeBuilder . ExpressionSubTreeHolder ( expressionSubTree ) ;
598588 }
589+
590+ private bool IsCastRequired ( Expression expression , System . Type toType )
591+ {
592+ return toType != typeof ( object ) && IsCastRequired ( GetType ( expression ) , TypeFactory . GetDefaultTypeFor ( toType ) ) ;
593+ }
594+
595+ private bool IsCastRequired ( IType type , IType toType )
596+ {
597+ // A type can be null when casting an entity into a base class, in that case we should not cast
598+ if ( type == null || toType == null || Equals ( type , toType ) )
599+ {
600+ return false ;
601+ }
602+
603+ var sqlTypes = type . SqlTypes ( _parameters . SessionFactory ) ;
604+ var toSqlTypes = toType . SqlTypes ( _parameters . SessionFactory ) ;
605+ if ( sqlTypes . Length != 1 || toSqlTypes . Length != 1 )
606+ {
607+ return false ; // Casting a multi-column type is not possible
608+ }
609+
610+ if ( type . ReturnedClass . IsEnum && sqlTypes [ 0 ] . DbType == DbType . String )
611+ {
612+ return false ; // Never cast an enum that is mapped as string, the type will provide a string for the parameter value
613+ }
614+
615+ return sqlTypes [ 0 ] . DbType != toSqlTypes [ 0 ] . DbType ;
616+ }
617+
618+ private bool IsCastRequired ( System . Type type , string sqlFunctionName )
619+ {
620+ if ( type == typeof ( object ) )
621+ {
622+ return false ;
623+ }
624+
625+ var toType = TypeFactory . GetDefaultTypeFor ( type ) ;
626+ if ( toType == null )
627+ {
628+ return true ; // Fallback to the old behavior
629+ }
630+
631+ var sqlFunction = _parameters . SessionFactory . SQLFunctionRegistry . FindSQLFunction ( sqlFunctionName ) ;
632+ if ( sqlFunction == null )
633+ {
634+ return true ; // Fallback to the old behavior
635+ }
636+
637+ var fnReturnType = sqlFunction . ReturnType ( toType , _parameters . SessionFactory ) ;
638+ return fnReturnType == null || IsCastRequired ( fnReturnType , toType ) ;
639+ }
640+
641+ private IType GetType ( Expression expression )
642+ {
643+ if ( ! ( expression is MemberExpression memberExpression ) )
644+ {
645+ return expression . Type != typeof ( object )
646+ ? TypeFactory . GetDefaultTypeFor ( expression . Type )
647+ : null ;
648+ }
649+
650+ // Try to get the mapped type for the member as it may be a non default one
651+ var entityName = TryGetEntityName ( memberExpression ) ;
652+ if ( entityName == null )
653+ {
654+ return TypeFactory . GetDefaultTypeFor ( expression . Type ) ; // Not mapped
655+ }
656+
657+ var persister = _parameters . SessionFactory . GetEntityPersister ( entityName ) ;
658+ var index = persister . EntityMetamodel . GetPropertyIndexOrNull ( memberExpression . Member . Name ) ;
659+ return ! index . HasValue
660+ ? TypeFactory . GetDefaultTypeFor ( expression . Type ) // Not mapped
661+ : persister . EntityMetamodel . PropertyTypes [ index . Value ] ;
662+ }
663+
664+ private string TryGetEntityName ( MemberExpression memberExpression )
665+ {
666+ System . Type entityType ;
667+ // Try to get the actual entity type from the query source if possbile as member can be declared
668+ // in a base type
669+ if ( memberExpression . Expression is QuerySourceReferenceExpression querySourceReferenceExpression )
670+ {
671+ entityType = querySourceReferenceExpression . Type ;
672+ }
673+ else
674+ {
675+ entityType = memberExpression . Member . ReflectedType ;
676+ }
677+
678+ return _parameters . SessionFactory . TryGetGuessEntityName ( entityType ) ;
679+ }
599680 }
600681}
0 commit comments