1313using ShardingCore . Core . VirtualRoutes . TableRoutes . RouteTails . Abstractions ;
1414using ShardingCore . Core . DbContextCreator ;
1515using ShardingCore . Core . RuntimeContexts ;
16+ using ShardingCore . Core . ShardingConfigurations ;
1617using ShardingCore . Core . VirtualRoutes . Abstractions ;
1718using ShardingCore . Extensions ;
1819using ShardingCore . Sharding . Abstractions ;
@@ -38,6 +39,7 @@ public class ShardingDbContextExecutor : IShardingDbContextExecutor
3839 //private readonly ConcurrentDictionary<string, ConcurrentDictionary<string, DbContext>> _dbContextCaches = new ConcurrentDictionary<string, ConcurrentDictionary<string, DbContext>>();
3940 private readonly ConcurrentDictionary < string , IDataSourceDbContext > _dbContextCaches = new ConcurrentDictionary < string , IDataSourceDbContext > ( ) ;
4041 private readonly IShardingRuntimeContext _shardingRuntimeContext ;
42+ private readonly ShardingConfigOptions _shardingConfigOptions ;
4143 private readonly IVirtualDataSource _virtualDataSource ;
4244 private readonly IDataSourceRouteManager _dataSourceRouteManager ;
4345 private readonly ITableRouteManager _tableRouteManager ;
@@ -66,6 +68,7 @@ public ShardingDbContextExecutor(DbContext shardingDbContext)
6668 //初始化
6769 _shardingRuntimeContext = shardingDbContext . GetShardingRuntimeContext ( ) ;
6870 _shardingRuntimeContext . GetOrCreateShardingRuntimeModel ( shardingDbContext ) ;
71+ _shardingConfigOptions = _shardingRuntimeContext . GetShardingConfigOptions ( ) ;
6972 _virtualDataSource = _shardingRuntimeContext . GetVirtualDataSource ( ) ;
7073 _dataSourceRouteManager = _shardingRuntimeContext . GetDataSourceRouteManager ( ) ;
7174 _tableRouteManager = _shardingRuntimeContext . GetTableRouteManager ( ) ;
@@ -154,6 +157,7 @@ private string GetTableTail<TEntity>(string dataSourceName,TEntity entity) where
154157 i += await dbContextCache . Value . SaveChangesAsync ( acceptAllChangesOnSuccess , cancellationToken ) ;
155158 }
156159
160+ AutoUseWriteConnectionString ( ) ;
157161 return i ;
158162 }
159163
@@ -165,6 +169,7 @@ public int SaveChanges(bool acceptAllChangesOnSuccess)
165169 i += dbContextCache . Value . SaveChanges ( acceptAllChangesOnSuccess ) ;
166170 }
167171
172+ AutoUseWriteConnectionString ( ) ;
168173 return i ;
169174 }
170175
@@ -182,6 +187,8 @@ public void Rollback()
182187 {
183188 dbContextCache . Value . Rollback ( ) ;
184189 }
190+
191+ AutoUseWriteConnectionString ( ) ;
185192 }
186193
187194 public void Commit ( )
@@ -190,6 +197,8 @@ public void Commit()
190197 {
191198 dbContextCache . Value . Commit ( _dbContextCaches . Count ) ;
192199 }
200+
201+ AutoUseWriteConnectionString ( ) ;
193202 }
194203
195204 public IDictionary < string , IDataSourceDbContext > GetCurrentDbContexts ( )
@@ -215,6 +224,8 @@ public void Dispose()
215224 {
216225 await dbContextCache . Value . RollbackAsync ( cancellationToken ) ;
217226 }
227+
228+ AutoUseWriteConnectionString ( ) ;
218229 }
219230
220231 public async Task CommitAsync ( CancellationToken cancellationToken = new CancellationToken ( ) )
@@ -223,6 +234,8 @@ public void Dispose()
223234 {
224235 await dbContextCache . Value . CommitAsync ( _dbContextCaches . Count , cancellationToken ) ;
225236 }
237+
238+ AutoUseWriteConnectionString ( ) ;
226239 }
227240 public async ValueTask DisposeAsync ( )
228241 {
@@ -232,5 +245,16 @@ public async ValueTask DisposeAsync()
232245 }
233246 }
234247#endif
248+
249+ /// <summary>
250+ /// 自动切换成写库连接
251+ /// </summary>
252+ private void AutoUseWriteConnectionString ( )
253+ {
254+ if ( _shardingConfigOptions . AutoUseWriteConnectionStringAfterWriteDb )
255+ {
256+ ( ( IShardingDbContext ) _shardingDbContext ) . ReadWriteSeparationWriteOnly ( ) ;
257+ }
258+ }
235259 }
236260}
0 commit comments