Skip to content

Commit ecf6419

Browse files
committed
修复优化当表达式内嵌使用属性的情况下出现:Cannot use multiple context instances within a single query execution. Ensure the query use a single context instance.的错误
1 parent 6e4afa7 commit ecf6419

File tree

3 files changed

+39
-36
lines changed

3 files changed

+39
-36
lines changed

samples/Sample.MySql/Controllers/WeatherForecastController.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,37 @@ public class abc
2222
public string name { get; set; }
2323
public int count { get; set; }
2424
}
25+
26+
public class ABC
27+
{
28+
private readonly DefaultShardingDbContext _defaultTableDbContext;
29+
30+
public ABC(DefaultShardingDbContext defaultTableDbContext)
31+
{
32+
_defaultTableDbContext = defaultTableDbContext;
33+
}
34+
35+
public IQueryable<SysTest> GetAll()
36+
{
37+
return _defaultTableDbContext.Set<SysTest>();
38+
}
39+
40+
public virtual IQueryable<SysTest> Select => this.GetAll();
41+
}
2542
[ApiController]
2643
[Route("[controller]/[action]")]
2744
public class WeatherForecastController : ControllerBase
2845
{
2946

3047
private readonly DefaultShardingDbContext _defaultTableDbContext;
3148
private readonly IShardingRuntimeContext _shardingRuntimeContext;
49+
private readonly ABC _abc;
3250

3351
public WeatherForecastController(DefaultShardingDbContext defaultTableDbContext,IShardingRuntimeContext shardingRuntimeContext)
3452
{
3553
_defaultTableDbContext = defaultTableDbContext;
3654
_shardingRuntimeContext = shardingRuntimeContext;
55+
_abc=new ABC(_defaultTableDbContext);
3756
}
3857

3958
public IQueryable<SysTest> GetAll()
@@ -86,7 +105,12 @@ public async Task<IActionResult> Get()
86105
// var firstOrDefault = _defaultTableDbContext.Set<SysUserMod>().FromSqlRaw($"select * from {nameof(SysUserMod)}").FirstOrDefault();
87106

88107
var sysUserMods1 = _defaultTableDbContext.Set<SysTest>()
108+
.Select(o => new ssss(){ Id = o.Id, C = _abc.Select.Count(x => x.Id == o.Id) }).ToList();
109+
var sysUserMods2 = _defaultTableDbContext.Set<SysTest>()
89110
.Select(o => new ssss(){ Id = o.Id, C = GetAll().Count(x => x.Id == o.Id) }).ToList();
111+
var sysTests = GetAll();
112+
var sysUserMods3 = _defaultTableDbContext.Set<SysTest>()
113+
.Select(o => new ssss(){ Id = o.Id, C = sysTests.Count(x => x.Id == o.Id) }).ToList();
90114
var resultX = await _defaultTableDbContext.Set<SysUserMod>()
91115
.Where(o => o.Id == "2" || o.Id == "3").FirstOrDefaultAsync();
92116
var resultY = await _defaultTableDbContext.Set<SysUserMod>().FirstOrDefaultAsync(o => o.Id == "2" || o.Id == "3");

src/ShardingCore/Extensions/CommonExtension.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ public static bool IsMemberQueryable(this MemberExpression memberExpression)
176176
{
177177
if (memberExpression == null)
178178
throw new ArgumentNullException(nameof(memberExpression));
179-
return (memberExpression.Type.FullName?.StartsWith("System.Linq.IQueryable`1") ?? false) || typeof(DbContext).IsAssignableFrom(memberExpression.Type);
179+
return (memberExpression.Type.FullName?.StartsWith("System.Linq.IQueryable`1") ?? false) ||typeof(IQueryable).IsAssignableFrom(memberExpression.Type) || typeof(DbContext).IsAssignableFrom(memberExpression.Type);
180180
}
181181

182182
public static Type GetSequenceType(this Type type)

src/ShardingCore/Sharding/Visitors/DbContextReplaceQueryableVisitor.cs

Lines changed: 14 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace ShardingCore.Core.Internal.Visitors
1818
* @Email: 326308290@qq.com
1919
*/
2020

21-
internal class DbContextInnerMemberReferenceReplaceQueryableVisitor : ExpressionVisitor
21+
internal class DbContextInnerMemberReferenceReplaceQueryableVisitor : ShardingExpressionVisitor
2222
{
2323
private readonly DbContext _dbContext;
2424
protected bool RootIsVisit = false;
@@ -28,50 +28,29 @@ public DbContextInnerMemberReferenceReplaceQueryableVisitor(DbContext dbContext)
2828
_dbContext = dbContext;
2929
}
3030

31+
// public override Expression Visit(Expression node)
32+
// {
33+
// Console.WriteLine("1");
34+
// return base.Visit(node);
35+
// }
36+
3137
protected override Expression VisitMember
3238
(MemberExpression memberExpression)
3339
{
3440
// Recurse down to see if we can simplify...
35-
//if (memberExpression.IsMemberQueryable()) //2x,3x 路由 单元测试 分表和不分表
36-
//{
37-
var expression = Visit(memberExpression.Expression);
38-
39-
// If we've ended up with a constant, and it's a property or a field,
40-
// we can simplify ourselves to a constant
41-
if (expression is ConstantExpression constantExpression)
41+
if (memberExpression.IsMemberQueryable()) //2x,3x 路由 单元测试 分表和不分表
4242
{
43-
object container = constantExpression.Value;
44-
var member = memberExpression.Member;
45-
if (member is FieldInfo fieldInfo)
43+
var expressionValue = GetExpressionValue(memberExpression);
44+
if (expressionValue is IQueryable queryable)
4645
{
47-
object value = fieldInfo.GetValue(container);
48-
if (value is IQueryable queryable)
49-
{
50-
return ReplaceMemberExpression(queryable);
51-
}
52-
53-
if (value is DbContext dbContext)
54-
{
55-
return ReplaceMemberExpression(dbContext);
56-
}
57-
//return Expression.Constant(value);
46+
return ReplaceMemberExpression(queryable);
5847
}
5948

60-
if (member is PropertyInfo propertyInfo)
49+
if (expressionValue is DbContext dbContext)
6150
{
62-
object value = propertyInfo.GetValue(container, null);
63-
if (value is IQueryable queryable)
64-
{
65-
return ReplaceMemberExpression(queryable);
66-
}
67-
68-
if (value is DbContext dbContext)
69-
{
70-
return ReplaceMemberExpression(dbContext);
71-
}
51+
return ReplaceMemberExpression(dbContext);
7252
}
73-
}
74-
//}
53+
}
7554

7655
return base.VisitMember(memberExpression);
7756
}

0 commit comments

Comments
 (0)