@@ -21,6 +21,7 @@ namespace ShardingCore.Core.Internal.Visitors
2121 internal class DbContextInnerMemberReferenceReplaceQueryableVisitor : ExpressionVisitor
2222 {
2323 private readonly DbContext _dbContext ;
24+ protected bool RootIsVisit = false ;
2425
2526 public DbContextInnerMemberReferenceReplaceQueryableVisitor ( DbContext dbContext )
2627 {
@@ -86,18 +87,6 @@ private MemberExpression ReplaceMemberExpression(IQueryable queryable)
8687 Expression . Property ( ConstantExpression . Constant ( tempVariable ) , nameof ( TempVariable < object > . Queryable ) ) ;
8788 return queryableMemberReplaceExpression ;
8889 }
89- private MethodCallExpression ReplaceMethodCallExpression ( IQueryable queryable )
90- {
91- var dbContextReplaceQueryableVisitor = new DbContextReplaceQueryableVisitor ( _dbContext ) ;
92- var newExpression = dbContextReplaceQueryableVisitor . Visit ( queryable . Expression ) ;
93- var newQueryable = dbContextReplaceQueryableVisitor . Source . Provider . CreateQuery ( newExpression ) ;
94- var tempVariableGenericType = typeof ( TempVariable < > ) . GetGenericType0 ( queryable . ElementType ) ;
95- var tempVariable = Activator . CreateInstance ( tempVariableGenericType , newQueryable ) ;
96- // MemberExpression queryableMemberReplaceExpression =
97- // Expression.Property(, nameof(TempVariable<object>.Queryable));
98-
99- return Expression . Call ( ConstantExpression . Constant ( tempVariable ) , tempVariableGenericType . GetMethod ( nameof ( TempVariable < object > . GetQueryable ) ) , new Expression [ 0 ] ) ;
100- }
10190
10291 private MemberExpression ReplaceMemberExpression ( DbContext dbContext )
10392 {
@@ -110,7 +99,7 @@ private MemberExpression ReplaceMemberExpression(DbContext dbContext)
11099 }
111100 protected override Expression VisitMethodCall ( MethodCallExpression node )
112101 {
113- if ( node . Method . ReturnType . IsMethodReturnTypeQueryableType ( ) && node . Method . ReturnType . IsGenericType )
102+ if ( RootIsVisit && node . Method . ReturnType . IsMethodReturnTypeQueryableType ( ) && node . Method . ReturnType . IsGenericType )
114103 {
115104#if EFCORE2 || EFCORE3
116105 var notRoot = node . Arguments . All ( o => ! ( o is ConstantExpression constantExpression && constantExpression . Value is IQueryable ) ) ;
@@ -120,31 +109,27 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
120109#endif
121110 if ( notRoot )
122111 {
123- var objQueryable = Expression . Lambda ( node ) . Compile ( ) . DynamicInvoke ( ) ;
124- if ( objQueryable != null && objQueryable is IQueryable queryable )
125- {
126- return ReplaceMethodCallExpression ( queryable ) ;
127- // var whereCallExpression = ReplaceMethodCallExpression(replaceMemberExpression);
128- // return base.VisitMethodCall(whereCallExpression);;
129- // Console.WriteLine("1");
130- }
112+ var entityType = node . Method . ReturnType . GenericTypeArguments [ 0 ] ;
113+
114+ var whereCallExpression = ReplaceMethodCallExpression ( node , entityType ) ;
115+ return whereCallExpression ;
131116 }
132117 }
133118
134119 return base . VisitMethodCall ( node ) ;
135120 }
136121
137- // private MethodCallExpression ReplaceMethodCallExpression(MemberExpression memberExpression)
138- // {
139- // var lambdaExpression = GetType().GetMethod(nameof(WhereTrueExpression)).MakeGenericMethod(new Type[] { queryable.ElementType }).Invoke(this,new object[]{});
140- // MethodCallExpression whereCallExpression = Expression.Call(
141- // typeof(Queryable ),
142- // nameof(Queryable.Where ),
143- // new Type[] { queryable.ElementType },
144- // queryable.Expression, (LambdaExpression)lambdaExpression
145- // );
146- // return whereCallExpression;
147- // }
122+ private MethodCallExpression ReplaceMethodCallExpression ( MethodCallExpression methodCallExpression ,
123+ Type entityType )
124+ {
125+ MethodCallExpression whereCallExpression = Expression . Call (
126+ typeof ( IShardingQueryableExtension ) ,
127+ nameof ( IShardingQueryableExtension . ReplaceDbContextQueryableWithType ) ,
128+ new Type [ ] { entityType } ,
129+ methodCallExpression , Expression . Constant ( _dbContext )
130+ ) ;
131+ return whereCallExpression ;
132+ }
148133
149134 public Expression < Func < T , bool > > WhereTrueExpression < T > ( )
150135 {
@@ -166,6 +151,20 @@ public IQueryable<T1> GetQueryable()
166151 return Queryable ;
167152 }
168153 }
154+ internal sealed class TempMethodVariable < T1 >
155+ {
156+ public IQueryable < T1 > Queryable { get ; }
157+
158+ public TempMethodVariable ( Func < IQueryable < T1 > > func )
159+ {
160+ Queryable = func ( ) ;
161+ }
162+
163+ public IQueryable < T1 > GetQueryable ( )
164+ {
165+ return Queryable ;
166+ }
167+ }
169168
170169 internal sealed class TempDbVariable < T1 >
171170 {
@@ -207,6 +206,7 @@ protected override Expression VisitConstant(ConstantExpression node)
207206 if ( Source == null )
208207 Source = newQueryable ;
209208 // return base.Visit(Expression.Constant(newQueryable));
209+ RootIsVisit = true ;
210210 return Expression . Constant ( newQueryable ) ;
211211 }
212212
@@ -245,6 +245,7 @@ protected override Expression VisitExtension(Expression node)
245245 //如何替换ef5的set
246246 var replaceQueryRoot = new ReplaceSingleQueryRootExpressionVisitor ( ) ;
247247 replaceQueryRoot . Visit ( newQueryable . Expression ) ;
248+ RootIsVisit = true ;
248249 return base . VisitExtension ( replaceQueryRoot . QueryRootExpression ) ;
249250 }
250251
0 commit comments