@@ -29,6 +29,13 @@ import (
2929 "vitess.io/vitess/go/vt/vterrors"
3030)
3131
32+ // The states a pool can be in.
33+ const (
34+ UNINITIALIZED = iota
35+ OPENED
36+ CLOSED
37+ )
38+
3239var (
3340 // ErrTimeout is returned if a connection get times out.
3441 ErrTimeout = vterrors .New (vtrpcpb .Code_RESOURCE_EXHAUSTED , "connection pool timed out" )
@@ -124,8 +131,12 @@ type ConnPool[C Connection] struct {
124131 capacity atomic.Int64
125132
126133 // workers is a waitgroup for all the currently running worker goroutines
127- workers sync.WaitGroup
128- close chan struct {}
134+ workers sync.WaitGroup
135+ close chan struct {}
136+
137+ // state represents the state the pool is in: uninitialized, open, or closed.
138+ state atomic.Uint32
139+
129140 capacityMu sync.Mutex
130141
131142 config struct {
@@ -193,6 +204,19 @@ func (pool *ConnPool[C]) open() {
193204 // The expire worker takes care of removing from the waiter list any clients whose
194205 // context has been cancelled.
195206 pool .runWorker (pool .close , 100 * time .Millisecond , func (_ time.Time ) bool {
207+ if pool .IsClosed () {
208+ // Clean up any waiters that may have been added after the pool was closed
209+ pool .wait .expire (true )
210+
211+ // If there are no more active connections, we can close the channel and stop
212+ // the workers
213+ if pool .active .Load () == 0 {
214+ close (pool .close )
215+ }
216+
217+ return true
218+ }
219+
196220 maybeStarving := pool .wait .expire (false )
197221
198222 // Do not allow connections to starve; if there's waiters in the queue
@@ -234,8 +258,8 @@ func (pool *ConnPool[C]) open() {
234258// Open starts the background workers that manage the pool and gets it ready
235259// to start serving out connections.
236260func (pool * ConnPool [C ]) Open (connect Connector [C ], refresh RefreshCheck ) * ConnPool [C ] {
237- if pool .close != nil {
238- // already open
261+ if ! pool .state . CompareAndSwap ( UNINITIALIZED , OPENED ) {
262+ // already open or closed
239263 return pool
240264 }
241265
@@ -263,20 +287,41 @@ func (pool *ConnPool[C]) CloseWithContext(ctx context.Context) error {
263287 pool .capacityMu .Lock ()
264288 defer pool .capacityMu .Unlock ()
265289
266- if pool .close == nil || pool . capacity . Load () == 0 {
267- // already closed
290+ if ! pool .state . CompareAndSwap ( OPENED , CLOSED ) {
291+ // Already closed or uninitialized
268292 return nil
269293 }
270294
271- // close all the connections in the pool; if we time out while waiting for
272- // users to return our connections, we still want to finish the shutdown
273- // for the pool
274- err := pool .setCapacity (ctx , 0 )
295+ // close all the connections in the pool
275296
276- close (pool .close )
277- pool .workers .Wait ()
278- pool .close = nil
279- return err
297+ newcap := int64 (0 )
298+ oldcap := pool .capacity .Swap (newcap )
299+ if oldcap == newcap {
300+ return nil
301+ }
302+
303+ // close connections until we're under capacity
304+ for {
305+ // make sure there's no clients waiting for connections because they won't be returned in the future
306+ pool .wait .expire (true )
307+
308+ // try closing from connections which are currently idle in the stacks
309+ conn := pool .getFromSettingsStack (nil )
310+ if conn == nil {
311+ conn = pool .pop (& pool .clean )
312+ }
313+ if conn == nil {
314+ break
315+ }
316+ conn .Close ()
317+ pool .closedConn ()
318+ }
319+
320+ if pool .active .Load () == 0 {
321+ close (pool .close )
322+ }
323+
324+ return nil
280325}
281326
282327func (pool * ConnPool [C ]) reopen () {
@@ -305,7 +350,12 @@ func (pool *ConnPool[C]) reopen() {
305350
306351// IsOpen returns whether the pool is open
307352func (pool * ConnPool [C ]) IsOpen () bool {
308- return pool .close != nil
353+ return pool .state .Load () == OPENED
354+ }
355+
356+ // IsClosed returns whether the pool is closed
357+ func (pool * ConnPool [C ]) IsClosed () bool {
358+ return pool .state .Load () == CLOSED
309359}
310360
311361// Capacity returns the maximum amount of connections that this pool can maintain open
@@ -363,7 +413,7 @@ func (pool *ConnPool[C]) Get(ctx context.Context, setting *Setting) (*Pooled[C],
363413 if ctx .Err () != nil {
364414 return nil , ErrCtxTimeout
365415 }
366- if pool .capacity .Load () == 0 {
416+ if pool .state .Load () != OPENED {
367417 return nil , ErrConnPoolClosed
368418 }
369419 if setting == nil {
@@ -377,6 +427,16 @@ func (pool *ConnPool[C]) Get(ctx context.Context, setting *Setting) (*Pooled[C],
377427func (pool * ConnPool [C ]) put (conn * Pooled [C ]) {
378428 pool .borrowed .Add (- 1 )
379429
430+ // Close connection if pool is closed
431+ if pool .IsClosed () {
432+ if conn != nil {
433+ conn .Close ()
434+ pool .closedConn ()
435+ }
436+
437+ return
438+ }
439+
380440 if conn == nil {
381441 var err error
382442 // Using context.Background() is fine since MySQL connection already enforces
@@ -412,10 +472,24 @@ func (pool *ConnPool[C]) tryReturnConn(conn *Pooled[C]) bool {
412472 connSetting := conn .Conn .Setting ()
413473 if connSetting == nil {
414474 pool .clean .Push (conn )
475+
476+ // Close connection if pool is closed
477+ if pool .IsClosed () {
478+ conn = pool .pop (& pool .clean )
479+ conn .Close ()
480+ pool .closedConn ()
481+ }
415482 } else {
416483 stack := connSetting .bucket & stackMask
417484 pool .settings [stack ].Push (conn )
418485 pool .freshSettingsStack .Store (int64 (stack ))
486+
487+ // Close connection if pool is closed
488+ if pool .IsClosed () {
489+ conn = pool .pop (& pool .settings [stack ])
490+ conn .Close ()
491+ pool .closedConn ()
492+ }
419493 }
420494 return false
421495}
@@ -759,55 +833,133 @@ func (pool *ConnPool[C]) StatsJSON() map[string]any {
759833 }
760834}
761835
762- // RegisterStats registers this pool's metrics into a stats Exporter
763- func (pool * ConnPool [C ]) RegisterStats (stats * servenv.Exporter , name string ) {
764- if stats == nil || name == "" {
765- return
766- }
836+ type StatsExporter [C Connection ] struct {
837+ // The Pool for which this exporter is exporting stats.
838+ // It is an atomic pointer so that it can be updated safely.
839+ // The pointer is nil if the pool has not been registered yet.
840+ pool atomic.Pointer [ConnPool [C ]]
841+ }
767842
768- pool .Name = name
843+ func NewStatsExporter [C Connection ](stats * servenv.Exporter , name string ) * StatsExporter [C ] {
844+ se := & StatsExporter [C ]{}
769845
770846 stats .NewGaugeFunc (name + "Capacity" , "Tablet server conn pool capacity" , func () int64 {
847+ pool := se .pool .Load ()
848+ if pool == nil {
849+ return 0
850+ }
851+
771852 return pool .Capacity ()
772853 })
773854 stats .NewGaugeFunc (name + "Available" , "Tablet server conn pool available" , func () int64 {
855+ pool := se .pool .Load ()
856+ if pool == nil {
857+ return 0
858+ }
859+
774860 return pool .Available ()
775861 })
776862 stats .NewGaugeFunc (name + "Active" , "Tablet server conn pool active" , func () int64 {
863+ pool := se .pool .Load ()
864+ if pool == nil {
865+ return 0
866+ }
867+
777868 return pool .Active ()
778869 })
779870 stats .NewGaugeFunc (name + "InUse" , "Tablet server conn pool in use" , func () int64 {
871+ pool := se .pool .Load ()
872+ if pool == nil {
873+ return 0
874+ }
875+
780876 return pool .InUse ()
781877 })
782878 stats .NewGaugeFunc (name + "MaxCap" , "Tablet server conn pool max cap" , func () int64 {
879+ pool := se .pool .Load ()
880+ if pool == nil {
881+ return 0
882+ }
883+
783884 // the smartconnpool doesn't have a maximum capacity
784885 return pool .Capacity ()
785886 })
786887 stats .NewCounterFunc (name + "WaitCount" , "Tablet server conn pool wait count" , func () int64 {
888+ pool := se .pool .Load ()
889+ if pool == nil {
890+ return 0
891+ }
892+
787893 return pool .Metrics .WaitCount ()
788894 })
789895 stats .NewCounterDurationFunc (name + "WaitTime" , "Tablet server wait time" , func () time.Duration {
896+ pool := se .pool .Load ()
897+ if pool == nil {
898+ return 0
899+ }
900+
790901 return pool .Metrics .WaitTime ()
791902 })
792903 stats .NewGaugeDurationFunc (name + "IdleTimeout" , "Tablet server idle timeout" , func () time.Duration {
904+ pool := se .pool .Load ()
905+ if pool == nil {
906+ return 0
907+ }
908+
793909 return pool .IdleTimeout ()
794910 })
795911 stats .NewCounterFunc (name + "IdleClosed" , "Tablet server conn pool idle closed" , func () int64 {
912+ pool := se .pool .Load ()
913+ if pool == nil {
914+ return 0
915+ }
916+
796917 return pool .Metrics .IdleClosed ()
797918 })
798919 stats .NewCounterFunc (name + "MaxLifetimeClosed" , "Tablet server conn pool refresh closed" , func () int64 {
920+ pool := se .pool .Load ()
921+ if pool == nil {
922+ return 0
923+ }
924+
799925 return pool .Metrics .MaxLifetimeClosed ()
800926 })
801927 stats .NewCounterFunc (name + "Get" , "Tablet server conn pool get count" , func () int64 {
928+ pool := se .pool .Load ()
929+ if pool == nil {
930+ return 0
931+ }
932+
802933 return pool .Metrics .GetCount ()
803934 })
804935 stats .NewCounterFunc (name + "GetSetting" , "Tablet server conn pool get with setting count" , func () int64 {
936+ pool := se .pool .Load ()
937+ if pool == nil {
938+ return 0
939+ }
940+
805941 return pool .Metrics .GetSettingCount ()
806942 })
807943 stats .NewCounterFunc (name + "DiffSetting" , "Number of times pool applied different setting" , func () int64 {
944+ pool := se .pool .Load ()
945+ if pool == nil {
946+ return 0
947+ }
948+
808949 return pool .Metrics .DiffSettingCount ()
809950 })
810951 stats .NewCounterFunc (name + "ResetSetting" , "Number of times pool reset the setting" , func () int64 {
952+ pool := se .pool .Load ()
953+ if pool == nil {
954+ return 0
955+ }
956+
811957 return pool .Metrics .ResetSettingCount ()
812958 })
959+
960+ return se
961+ }
962+
963+ func (se * StatsExporter [C ]) SetPool (pool * ConnPool [C ]) {
964+ se .pool .Store (pool )
813965}
0 commit comments