diff --git a/doh-client/client.go b/doh-client/client.go index e306d2ec..f6e15810 100644 --- a/doh-client/client.go +++ b/doh-client/client.go @@ -90,6 +90,29 @@ func NewClient(conf *config.Config) (c *Client, err error) { Net: "tcp", Timeout: time.Duration(conf.Other.Timeout) * time.Second, } + + if c.conf.Other.Interface != "" { + localV4, localV6, err := c.getInterfaceIPs() + if err != nil { + return nil, fmt.Errorf("failed to get interface IPs for %s: %v", c.conf.Other.Interface, err) + } + var localAddr net.IP + if localV4 != nil { + localAddr = localV4 + } else { + localAddr = localV6 + } + + c.udpClient.Dialer = &net.Dialer{ + Timeout: time.Duration(conf.Other.Timeout) * time.Second, + LocalAddr: &net.UDPAddr{IP: localAddr}, + } + c.tcpClient.Dialer = &net.Dialer{ + Timeout: time.Duration(conf.Other.Timeout) * time.Second, + LocalAddr: &net.TCPAddr{IP: localAddr}, + } + } + for _, addr := range conf.Listen { c.udpServers = append(c.udpServers, &dns.Server{ Addr: addr, @@ -120,6 +143,38 @@ func NewClient(conf *config.Config) (c *Client, err error) { PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { var d net.Dialer + if c.conf.Other.Interface != "" { + localV4, localV6, err := c.getInterfaceIPs() + if err != nil { + log.Printf("Bootstrap dial warning: %v", err) + } else { + numServers := len(c.bootstrap) + bootstrap := c.bootstrap[rand.Intn(numServers)] + host, _, _ := net.SplitHostPort(bootstrap) + ip := net.ParseIP(host) + if ip != nil { + if ip.To4() != nil { + if localV4 != nil { + if strings.HasPrefix(network, "udp") { + d.LocalAddr = &net.UDPAddr{IP: localV4} + } else { + d.LocalAddr = &net.TCPAddr{IP: localV4} + } + } + } else { + if localV6 != nil { + if strings.HasPrefix(network, "udp") { + d.LocalAddr = &net.UDPAddr{IP: localV6} + } else { + d.LocalAddr = &net.TCPAddr{IP: localV6} + } + } + } + } + conn, err := d.DialContext(ctx, network, bootstrap) + return conn, err + } + } numServers := len(c.bootstrap) bootstrap := c.bootstrap[rand.Intn(numServers)] conn, err := d.DialContext(ctx, network, bootstrap) @@ -235,14 +290,72 @@ func (c *Client) newHTTPClient() error { if c.httpTransport != nil { c.httpTransport.CloseIdleConnections() } - dialer := &net.Dialer{ + + localV4, localV6, err := c.getInterfaceIPs() + if err != nil { + log.Printf("Interface binding error: %v", err) + return err + } + + baseDialer := &net.Dialer{ Timeout: time.Duration(c.conf.Other.Timeout) * time.Second, KeepAlive: 30 * time.Second, - // DualStack: true, - Resolver: c.bootstrapResolver, + Resolver: c.bootstrapResolver, } + c.httpTransport = &http.Transport{ - DialContext: dialer.DialContext, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if c.conf.Other.Interface == "" { + return baseDialer.DialContext(ctx, network, addr) + } + + if network == "tcp4" && localV4 != nil { + d := *baseDialer + d.LocalAddr = &net.TCPAddr{IP: localV4} + return d.DialContext(ctx, network, addr) + } + if network == "tcp6" && localV6 != nil { + d := *baseDialer + d.LocalAddr = &net.TCPAddr{IP: localV6} + return d.DialContext(ctx, network, addr) + } + + // Manual Dual-Stack: Resolve host and try compatible families sequentially + host, port, _ := net.SplitHostPort(addr) + ips, err := c.bootstrapResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + + var lastErr error + for _, ip := range ips { + d := *baseDialer + targetAddr := net.JoinHostPort(ip.String(), port) + + if ip.IP.To4() != nil { + if localV4 == nil { + continue + } + d.LocalAddr = &net.TCPAddr{IP: localV4} + } else { + if localV6 == nil { + continue + } + d.LocalAddr = &net.TCPAddr{IP: localV6} + } + + conn, err := d.DialContext(ctx, "tcp", targetAddr) + if err == nil { + return conn, nil + } + lastErr = err + } + + if lastErr != nil { + return nil, lastErr + } + return nil, fmt.Errorf("connection to %s failed: no matching local/remote IP families on interface %s", addr, c.conf.Other.Interface) + }, ExpectContinueTimeout: 1 * time.Second, IdleConnTimeout: 90 * time.Second, MaxIdleConns: 100, @@ -251,15 +364,18 @@ func (c *Client) newHTTPClient() error { TLSHandshakeTimeout: time.Duration(c.conf.Other.Timeout) * time.Second, TLSClientConfig: &tls.Config{InsecureSkipVerify: c.conf.Other.TLSInsecureSkipVerify}, } + if c.conf.Other.NoIPv6 { + originalDial := c.httpTransport.DialContext c.httpTransport.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) { if strings.HasPrefix(network, "tcp") { network = "tcp4" } - return dialer.DialContext(ctx, network, address) + return originalDial(ctx, network, address) } } - err := http2.ConfigureTransport(c.httpTransport) + + err = http2.ConfigureTransport(c.httpTransport) if err != nil { return err } @@ -485,3 +601,38 @@ func (c *Client) findClientIP(w dns.ResponseWriter, r *dns.Msg) (ednsClientAddre } return } + +// getInterfaceIPs returns the first valid IPv4 and IPv6 addresses found on the interface +func (c *Client) getInterfaceIPs() (v4, v6 net.IP, err error) { + if c.conf.Other.Interface == "" { + return nil, nil, nil + } + ifi, err := net.InterfaceByName(c.conf.Other.Interface) + if err != nil { + return nil, nil, err + } + addrs, err := ifi.Addrs() + if err != nil { + return nil, nil, err + } + + for _, addr := range addrs { + ip, _, err := net.ParseCIDR(addr.String()) + if err != nil { + continue + } + if ip4 := ip.To4(); ip4 != nil { + if v4 == nil { + v4 = ip4 + } + } else { + if v6 == nil && !c.conf.Other.NoIPv6 { + v6 = ip + } + } + } + if v4 == nil && v6 == nil { + return nil, nil, fmt.Errorf("no valid IP addresses found on interface %s", c.conf.Other.Interface) + } + return v4, v6, nil +} diff --git a/doh-client/config/config.go b/doh-client/config/config.go index e57e22b3..78207a3e 100644 --- a/doh-client/config/config.go +++ b/doh-client/config/config.go @@ -50,6 +50,7 @@ type others struct { Bootstrap []string `toml:"bootstrap"` Passthrough []string `toml:"passthrough"` Timeout uint `toml:"timeout"` + Interface string `toml:"interface"` NoCookies bool `toml:"no_cookies"` NoECS bool `toml:"no_ecs"` NoIPv6 bool `toml:"no_ipv6"` diff --git a/doh-client/doh-client.conf b/doh-client/doh-client.conf index 01d12912..20b2f41e 100644 --- a/doh-client/doh-client.conf +++ b/doh-client/doh-client.conf @@ -97,6 +97,11 @@ passthrough = [ # Timeout for upstream request in seconds timeout = 30 +# Interface to bind to for outgoing connections. +# If empty, the system default route is used (usually eth0 or wlan0). +# Example: "eth1", "wlan0" +interface = "" + # Disable HTTP Cookies # # Cookies may be useful if your upstream resolver is protected by some