Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ jobs:
go-version-file: "go.mod"

- name: Build
run: go build -v -tags with_clash_api ./...
run: go build -v -tags with_clash_api,with_conntrack ./...
- name: Test
run: go test -v -tags with_clash_api ./...
run: go test -v -tags with_clash_api,with_conntrack ./...
110 changes: 110 additions & 0 deletions vpn/drain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package vpn

import (
"net"
"testing"
"time"

"github.com/sagernet/sing-box/common/conntrack"

"github.com/getlantern/radiance/vpn/ipc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestDrainConnectionsNoConnections(t *testing.T) {
// With no active connections, drainConnections should return immediately.
origTimeout := DrainTimeout
DrainTimeout = 1 * time.Second
defer func() { DrainTimeout = origTimeout }()

start := time.Now()
drainConnections()
elapsed := time.Since(start)

assert.Less(t, elapsed, 100*time.Millisecond, "should return immediately with no connections")
}

func TestDrainConnectionsWaitsForConnections(t *testing.T) {
if !conntrack.Enabled {
t.Skip("conntrack not enabled (need build tag with_conntrack)")
}

origTimeout := DrainTimeout
origPoll := DrainPollInterval
DrainTimeout = 5 * time.Second
DrainPollInterval = 10 * time.Millisecond
defer func() {
DrainTimeout = origTimeout
DrainPollInterval = origPoll
}()

// Create a tracked connection using conntrack.
server, client := net.Pipe()
defer server.Close()

tracked, err := conntrack.NewConn(client)
require.NoError(t, err)
require.Equal(t, 1, conntrack.Count(), "should have 1 tracked connection")

// Close the connection after a short delay to simulate graceful drain.
drainDelay := 200 * time.Millisecond
go func() {
time.Sleep(drainDelay)
tracked.Close()
}()

start := time.Now()
drainConnections()
elapsed := time.Since(start)

assert.Equal(t, 0, conntrack.Count(), "all connections should be drained")
assert.GreaterOrEqual(t, elapsed, drainDelay, "should have waited for connection to close")
assert.Less(t, elapsed, DrainTimeout, "should not have waited the full timeout")
}

func TestDrainConnectionsTimeout(t *testing.T) {
if !conntrack.Enabled {
t.Skip("conntrack not enabled (need build tag with_conntrack)")
}

origTimeout := DrainTimeout
origPoll := DrainPollInterval
DrainTimeout = 200 * time.Millisecond
DrainPollInterval = 10 * time.Millisecond
defer func() {
DrainTimeout = origTimeout
DrainPollInterval = origPoll
}()

// Create a tracked connection that never closes.
server, client := net.Pipe()
defer server.Close()

tracked, err := conntrack.NewConn(client)
require.NoError(t, err)
defer tracked.Close()
Comment thread
myleshorton marked this conversation as resolved.
require.Equal(t, 1, conntrack.Count(), "should have 1 tracked connection")

start := time.Now()
drainConnections()
elapsed := time.Since(start)

assert.GreaterOrEqual(t, elapsed, DrainTimeout, "should have waited the full timeout")
assert.Less(t, elapsed, DrainTimeout+100*time.Millisecond, "should not overshoot timeout significantly")
assert.Equal(t, 1, conntrack.Count(), "connection should still be open after timeout")
}

func TestTunnelCloseCallsDrain(t *testing.T) {
// Verify that tunnel.close() invokes the drain phase by checking that closing a tunnel
// with no closers and no connections completes quickly.
tun := &tunnel{}
tun.status.Store(ipc.Disconnected)

start := time.Now()
err := tun.close()
elapsed := time.Since(start)

assert.NoError(t, err)
assert.Less(t, elapsed, 500*time.Millisecond, "close with no connections should be fast")
}
50 changes: 50 additions & 0 deletions vpn/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
lsync "github.com/getlantern/common/sync"
box "github.com/getlantern/lantern-box"

"github.com/sagernet/sing-box/common/conntrack"

lbA "github.com/getlantern/lantern-box/adapter"
"github.com/getlantern/lantern-box/adapter/groups"
lblog "github.com/getlantern/lantern-box/log"
Expand Down Expand Up @@ -226,7 +228,19 @@ func (t *tunnel) selectOutbound(group, tag string) error {
return nil
}

// DrainTimeout is the maximum time to wait for active connections to drain before forcibly closing
// the tunnel. Exported for testing.
var DrainTimeout = 5 * time.Second

// DrainPollInterval is how often to check for remaining connections during the drain phase.
var DrainPollInterval = 50 * time.Millisecond
Comment thread
myleshorton marked this conversation as resolved.

func (t *tunnel) close() error {
// Drain phase: wait for active connections to finish before cancelling context and
// tearing down the tunnel. This gives in-flight requests (e.g. video streams) a chance
// to complete gracefully instead of being killed by context cancellation or TUN teardown.
drainConnections()

if t.cancel != nil {
t.cancel()
}
Expand All @@ -253,6 +267,42 @@ func (t *tunnel) close() error {
return err
}

// drainConnections waits for active connections to close gracefully up to DrainTimeout.
// It polls conntrack.Count() and returns early if all connections have closed.
func drainConnections() {
initial := conntrack.Count()
if initial == 0 {
return
}
slog.Info("Draining active connections before tunnel close", "count", initial, "timeout", DrainTimeout)

pollInterval := DrainPollInterval
if pollInterval <= 0 {
pollInterval = 50 * time.Millisecond
}

timer := time.NewTimer(DrainTimeout)
defer timer.Stop()
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()

for {
select {
case <-timer.C:
remaining := conntrack.Count()
if remaining > 0 {
slog.Warn("Drain timeout reached, proceeding with tunnel teardown", "remaining", remaining)
}
return
case <-ticker.C:
if conntrack.Count() == 0 {
slog.Info("All connections drained successfully")
return
}
}
}
}

func (t *tunnel) Status() ipc.VPNStatus {
return t.status.Load().(ipc.VPNStatus)
}
Expand Down