Skip to content

Commit 085a334

Browse files
authored
proxy, util: meter public IP addresses into public traffic (#1006)
1 parent 91fab67 commit 085a334

File tree

5 files changed

+155
-12
lines changed

5 files changed

+155
-12
lines changed

pkg/balance/router/group.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,14 @@ func (g *Group) Match(clientInfo ClientInfo) bool {
9595
if g.matchType == MatchClientCIDR {
9696
addr = clientInfo.ClientAddr
9797
}
98-
contains, err := netutil.CIDRContainsIP(g.cidrList, addr)
98+
ip, err := netutil.NetAddr2IP(addr)
9999
if err != nil {
100-
g.lg.Error("checking CIDR failed", zap.String("addr", addr.String()), zap.Error(err))
100+
g.lg.Error("checking CIDR failed", zap.Stringer("addr", addr), zap.Error(err))
101+
return false
102+
}
103+
contains, err := netutil.CIDRContainsIP(g.cidrList, ip)
104+
if err != nil {
105+
g.lg.Error("checking CIDR failed", zap.Stringer("addr", addr), zap.Error(err))
101106
}
102107
return contains
103108
}

pkg/proxy/proxy.go

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package proxy
66
import (
77
"context"
88
"net"
9+
"reflect"
910
"strings"
1011
"sync"
1112
"time"
@@ -166,11 +167,12 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) {
166167
zap.String("addr", addr))
167168
clientConn := client.NewClientConnection(logger.Named("conn"), conn, s.certMgr.ServerSQLTLS(), s.certMgr.SQLTLS(),
168169
s.hsHandler, s.cpt, connID, addr, &backend.BCConfig{
169-
ProxyProtocol: s.mu.proxyProtocol,
170-
RequireBackendTLS: s.mu.requireBackendTLS,
171-
HealthyKeepAlive: s.mu.healthyKeepAlive,
172-
UnhealthyKeepAlive: s.mu.unhealthyKeepAlive,
173-
ConnBufferSize: s.mu.connBufferSize,
170+
ProxyProtocol: s.mu.proxyProtocol,
171+
RequireBackendTLS: s.mu.requireBackendTLS,
172+
HealthyKeepAlive: s.mu.healthyKeepAlive,
173+
UnhealthyKeepAlive: s.mu.unhealthyKeepAlive,
174+
ConnBufferSize: s.mu.connBufferSize,
175+
FromPublicEndpoints: s.fromPublicEndpoint,
174176
}, s.meter)
175177
s.mu.clients[connID] = clientConn
176178
logger.Debug("new connection", zap.Bool("proxy-protocol", s.mu.proxyProtocol), zap.Bool("require_backend_tls", s.mu.requireBackendTLS))
@@ -204,6 +206,31 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) {
204206
clientConn.Run(ctx)
205207
}
206208

209+
func (s *SQLServer) fromPublicEndpoint(addr net.Addr) bool {
210+
if addr == nil || reflect.ValueOf(addr).IsNil() {
211+
return false
212+
}
213+
s.mu.RLock()
214+
publicEndpoints := s.mu.publicEndpoints
215+
s.mu.RUnlock()
216+
ip, err := netutil.NetAddr2IP(addr)
217+
if err != nil {
218+
s.logger.Warn("failed to check public endpoint", zap.Any("addr", addr), zap.Error(err))
219+
return false
220+
}
221+
contains, err := netutil.CIDRContainsIP(publicEndpoints, ip)
222+
if err != nil {
223+
s.logger.Warn("failed to check public endpoint", zap.Any("ip", ip), zap.Error(err))
224+
return false
225+
}
226+
if contains {
227+
return true
228+
}
229+
// The public NLB may enable preserveIP, and the incoming address is the client address, which may be a public address.
230+
// Even if the private NLB enables preserveIP, the client address is still a private address.
231+
return !netutil.IsPrivate(ip)
232+
}
233+
207234
func (s *SQLServer) PreClose() {
208235
// Step 1: HTTP status returns unhealthy so that NLB takes this instance offline and then new connections won't come.
209236
s.mu.Lock()

pkg/proxy/proxy_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/pingcap/tiproxy/pkg/proxy/client"
2626
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
2727
"github.com/stretchr/testify/require"
28+
"go.uber.org/zap"
2829
)
2930

3031
func TestCreateConn(t *testing.T) {
@@ -285,6 +286,44 @@ func TestRecoverPanic(t *testing.T) {
285286
certManager.Close()
286287
}
287288

289+
func TestPublicEndpoint(t *testing.T) {
290+
tests := []struct {
291+
publicEndpoints []string
292+
publicIps []string
293+
privateIps []string
294+
}{
295+
{
296+
publicIps: []string{"137.84.2.178"},
297+
privateIps: []string{"10.10.10.10"},
298+
},
299+
{
300+
publicEndpoints: []string{"10.10.10.0/24"},
301+
publicIps: []string{"137.84.2.178", "10.10.10.10"},
302+
privateIps: []string{"10.10.20.10"},
303+
},
304+
{
305+
publicEndpoints: []string{"10.10.10.0/24", "10.10.20.10"},
306+
publicIps: []string{"137.84.2.178", "10.10.10.10", "10.10.20.10"},
307+
privateIps: []string{"10.10.20.11"},
308+
},
309+
}
310+
311+
server, err := NewSQLServer(zap.NewNop(), &config.Config{}, nil, id.NewIDManager(), nil, nil, backend.NewDefaultHandshakeHandler(nil))
312+
require.NoError(t, err)
313+
for i, test := range tests {
314+
cfg := &config.Config{}
315+
cfg.Proxy.PublicEndpoints = test.publicEndpoints
316+
server.reset(cfg)
317+
for j, ip := range test.publicIps {
318+
require.True(t, server.fromPublicEndpoint(&net.TCPAddr{IP: net.ParseIP(ip), Port: 1000}), "test %d %d", i, j)
319+
}
320+
for j, ip := range test.privateIps {
321+
require.False(t, server.fromPublicEndpoint(&net.TCPAddr{IP: net.ParseIP(ip), Port: 1000}), "test %d %d", i, j)
322+
}
323+
require.False(t, server.fromPublicEndpoint(nil))
324+
}
325+
}
326+
288327
type mockHsHandler struct {
289328
backend.DefaultHandshakeHandler
290329
handshakeResp func(ctx backend.ConnContext, _ *pnet.HandshakeResp) error

pkg/util/netutil/netutil.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,25 @@ func ParseCIDRList(strList []string) ([]*net.IPNet, error) {
2828
return cidrList, parseErr
2929
}
3030

31-
func CIDRContainsIP(cidrList []*net.IPNet, addr net.Addr) (bool, error) {
31+
func NetAddr2IP(addr net.Addr) (net.IP, error) {
3232
if addr == nil || reflect.ValueOf(addr).IsNil() {
33-
return false, errors.New("address is nil")
33+
return nil, errors.New("address is nil")
3434
}
3535
value := addr.String()
3636
ipStr, _, err := net.SplitHostPort(value)
3737
if err != nil {
38-
return false, errors.Wrapf(err, "failed to parse address '%s'", value)
38+
return nil, errors.Wrapf(err, "failed to parse address '%s'", value)
3939
}
4040
ip := net.ParseIP(ipStr)
4141
if ip == nil {
42-
return false, errors.Errorf("failed to parse IP '%s'", value)
42+
return nil, errors.Errorf("failed to parse IP '%s'", value)
43+
}
44+
return ip, nil
45+
}
46+
47+
func CIDRContainsIP(cidrList []*net.IPNet, ip net.IP) (bool, error) {
48+
if ip == nil || reflect.ValueOf(ip).IsNil() {
49+
return false, errors.New("address is nil")
4350
}
4451
for _, cidr := range cidrList {
4552
if cidr.Contains(ip) {
@@ -48,3 +55,15 @@ func CIDRContainsIP(cidrList []*net.IPNet, addr net.Addr) (bool, error) {
4855
}
4956
return false, nil
5057
}
58+
59+
func IsPrivate(ip net.IP) bool {
60+
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
61+
return true
62+
}
63+
if ipv4 := ip.To4(); ipv4 != nil {
64+
if ipv4[0] == 100 && ipv4[1] >= 64 && ipv4[1] <= 127 {
65+
return true
66+
}
67+
}
68+
return ip.IsPrivate()
69+
}

pkg/util/netutil/netutil_test.go

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,60 @@ func TestContainIP(t *testing.T) {
8989
for _, test := range tests {
9090
list, err := ParseCIDRList(test.cidrs)
9191
require.NoError(t, err, "ip: %s, cidrs: %v", test.ip, test.cidrs)
92-
contain, _ := CIDRContainsIP(list, &net.TCPAddr{IP: net.ParseIP(test.ip), Port: 1000})
92+
contain, _ := CIDRContainsIP(list, net.ParseIP(test.ip))
9393
require.Equal(t, test.success, contain, "ip: %s, cidrs: %v", test.ip, test.cidrs)
9494
}
9595
}
96+
97+
func TestIsPrivate(t *testing.T) {
98+
tests := []struct {
99+
ip string
100+
private bool
101+
}{
102+
{
103+
ip: "8.8.8.8",
104+
private: false,
105+
},
106+
{
107+
ip: "192.168.1.1",
108+
private: true,
109+
},
110+
{
111+
ip: "172.16.0.1",
112+
private: true,
113+
},
114+
{
115+
ip: "10.0.0.1",
116+
private: true,
117+
},
118+
{
119+
ip: "127.0.0.1",
120+
private: true,
121+
},
122+
{
123+
ip: "::1",
124+
private: true,
125+
},
126+
{
127+
ip: "169.254.1.1",
128+
private: true,
129+
},
130+
{
131+
ip: "100.64.1.1",
132+
private: true,
133+
},
134+
{
135+
ip: "2001:4860:4860::8888",
136+
private: false,
137+
},
138+
{
139+
ip: "fc00::1",
140+
private: true,
141+
},
142+
}
143+
144+
for _, test := range tests {
145+
ip := net.ParseIP(test.ip)
146+
require.Equal(t, test.private, IsPrivate(ip), "ip: %s", test.ip)
147+
}
148+
}

0 commit comments

Comments
 (0)