package pool_test

import (
	"context"
	"net"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/redis/go-redis/v9/internal/pool"
)
| 12 | + |
| 13 | +// TestRaceConditionFreeTurn tests the race condition where: |
| 14 | +// 1. Dial completes and tryDeliver succeeds |
| 15 | +// 2. Waiter's context times out before receiving from result channel |
| 16 | +// 3. Waiter's defer retrieves connection via cancel() and delivers to another waiter |
| 17 | +// 4. Turn must be freed by the defer, not by dial goroutine or new waiter |
func TestRaceConditionFreeTurn(t *testing.T) {
	// Create a pool with PoolSize=2 to make the race easier to trigger
	opt := &pool.Options{
		Dialer: func(ctx context.Context) (net.Conn, error) {
			// Slow dial to allow context timeout to race with delivery
			time.Sleep(50 * time.Millisecond)
			return dummyDialer(ctx)
		},
		PoolSize:           2,
		MaxConcurrentDials: 2,
		DialTimeout:        1 * time.Second,
		PoolTimeout:        1 * time.Second,
	}
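
	// Approximate per-iteration timeline, given the timings used below:
	//   t~0ms   waiter 1 calls Get with a 30ms timeout; a 50ms dial starts
	//   t~20ms  waiter 2 calls Get with a 200ms timeout
	//   t~30ms  waiter 1's context expires
	//   t~50ms  the dial finishes; if the race is hit, waiter 1's cleanup hands
	//           the connection to waiter 2, and waiter 1's turn must still be freed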

	connPool := pool.NewConnPool(opt)
	defer connPool.Close()

	// Run multiple iterations to increase the chance of hitting the race
	for iteration := 0; iteration < 10; iteration++ {
		// Request 1: will time out quickly
		ctx1, cancel1 := context.WithTimeout(context.Background(), 30*time.Millisecond)
		defer cancel1()

		var wg sync.WaitGroup
		wg.Add(2)

		// Goroutine 1: request with a short timeout
		go func() {
			defer wg.Done()
			cn, err := connPool.Get(ctx1)
			if err == nil {
				// If we got a connection, put it back
				connPool.Put(ctx1, cn)
			}
			// Expected: context deadline exceeded
		}()
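		// If Get failed with a context error here, the dial it started may still
		// complete; per the scenario above, waiter 1's cleanup should deliver that
		// connection to the second waiter and still free waiter 1's turn.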

		// Goroutine 2: Request with longer timeout, should receive the orphaned connection
		go func() {
			defer wg.Done()
			time.Sleep(20 * time.Millisecond) // Start slightly after first request
			ctx2, cancel2 := context.WithTimeout(context.Background(), 200*time.Millisecond)
			defer cancel2()

			cn, err := connPool.Get(ctx2)
			if err != nil {
				t.Logf("Request 2 error: %v", err)
				return
			}
			// Got connection, put it back
			connPool.Put(ctx2, cn)
		}()

		wg.Wait()

		// Give some time for all operations to complete
		time.Sleep(100 * time.Millisecond)

		// Check QueueLen - should be 0 (all turns freed)
		queueLen := connPool.QueueLen()
		if queueLen != 0 {
			t.Errorf("Iteration %d: QueueLen = %d, expected 0 (turn leak detected!)", iteration, queueLen)
		}
	}
}

// TestRaceConditionFreeTurnStress is a stress test for the same race condition.
func TestRaceConditionFreeTurnStress(t *testing.T) {
	var dialCount int32
	opt := &pool.Options{
		Dialer: func(ctx context.Context) (net.Conn, error) {
			// Variable dial time to create more race opportunities.
			// Dials may run concurrently, so the counter is updated atomically.
			count := atomic.AddInt32(&dialCount, 1)
			time.Sleep(time.Duration(10+count%40) * time.Millisecond)
			return dummyDialer(ctx)
		},
		PoolSize:           10,
		MaxConcurrentDials: 10,
		DialTimeout:        1 * time.Second,
		PoolTimeout:        500 * time.Millisecond,
	}
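	// With dial times of 10-49ms, request timeouts of 20-99ms, and a 500ms
	// PoolTimeout, some waiters are likely to time out while their dial is still
	// in flight, repeatedly exercising the orphaned-connection handoff and
	// turn-release path described in TestRaceConditionFreeTurn.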

	connPool := pool.NewConnPool(opt)
	defer connPool.Close()

	const numRequests = 50

	var wg sync.WaitGroup
	wg.Add(numRequests)

	// Launch many concurrent requests with varying timeouts
	for i := 0; i < numRequests; i++ {
		go func(i int) {
			defer wg.Done()

			// Varying timeouts to create race conditions
			timeout := time.Duration(20+i%80) * time.Millisecond
			ctx, cancel := context.WithTimeout(context.Background(), timeout)
			defer cancel()

			cn, err := connPool.Get(ctx)
			if err == nil {
				// Simulate some work
				time.Sleep(time.Duration(i%20) * time.Millisecond)
				connPool.Put(ctx, cn)
			}
		}(i)
	}

	wg.Wait()

	// Give time for all cleanup to complete
	time.Sleep(200 * time.Millisecond)

	// Check for turn leaks
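	// QueueLen is expected to report the number of turns currently held; with
	// every request finished and cleanup done, a non-zero value means a turn
	// was acquired but never freed.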
	queueLen := connPool.QueueLen()
	if queueLen != 0 {
		t.Errorf("QueueLen = %d, expected 0 (turn leak detected!)", queueLen)
		t.Errorf("This indicates that some turns were never freed")
	}

	// Also check pool stats
	stats := connPool.Stats()
	t.Logf("Pool stats: Hits=%d, Misses=%d, Timeouts=%d, TotalConns=%d, IdleConns=%d",
		stats.Hits, stats.Misses, stats.Timeouts, stats.TotalConns, stats.IdleConns)
}