diff --git a/go/binlog/gomysql_reader.go b/go/binlog/gomysql_reader.go index 77b36e5b1..39bdf74b1 100644 --- a/go/binlog/gomysql_reader.go +++ b/go/binlog/gomysql_reader.go @@ -55,6 +55,7 @@ func NewGoMySQLReader(migrationContext *base.MigrationContext) *GoMySQLReader { UseDecimal: true, TimestampStringLocation: time.UTC, MaxReconnectAttempts: migrationContext.BinlogSyncerMaxReconnectAttempts, + Dialer: connectionConfig.Dialer, }), } } diff --git a/go/mysql/connection.go b/go/mysql/connection.go index f116b22fe..45ebdec76 100644 --- a/go/mysql/connection.go +++ b/go/mysql/connection.go @@ -14,6 +14,7 @@ import ( "os" "strings" + gomysqlclient "github.com/go-mysql-org/go-mysql/client" "github.com/go-sql-driver/mysql" ) @@ -31,6 +32,10 @@ type ConnectionConfig struct { Timeout float64 TransactionIsolation string Charset string + // Network is the go-sql-driver network name. When empty, tcp is used. + Network string + // Dialer is used by go-mysql binlog connections. When nil, go-mysql uses net.Dialer. + Dialer gomysqlclient.Dialer // use migrationContext.Uuid if useSSL TLSKey string @@ -54,6 +59,8 @@ func (this *ConnectionConfig) DuplicateCredentials(key InstanceKey) *ConnectionC Timeout: this.Timeout, TransactionIsolation: this.TransactionIsolation, Charset: this.Charset, + Network: this.Network, + Dialer: this.Dialer, TLSKey: this.TLSKey, } @@ -168,7 +175,11 @@ func (this *ConnectionConfig) GetDBUri(databaseName string) string { fmt.Sprintf("writeTimeout=%fs", this.Timeout), } - return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", this.User, this.Password, hostname, this.Key.Port, databaseName, strings.Join(connectionParams, "&")) + network := this.Network + if network == "" { + network = "tcp" + } + return fmt.Sprintf("%s:%s@%s(%s:%d)/%s?%s", this.User, this.Password, network, hostname, this.Key.Port, databaseName, strings.Join(connectionParams, "&")) } func GetDBTLSConfigKey(tlsKey string, tlsServerName string) string { diff --git a/go/mysql/connection_test.go b/go/mysql/connection_test.go index 73f5d9bd4..594dab062 100644 --- a/go/mysql/connection_test.go +++ b/go/mysql/connection_test.go @@ -6,9 +6,13 @@ package mysql import ( + "context" "crypto/tls" + "errors" + "net" "testing" + gomysqlclient "github.com/go-mysql-org/go-mysql/client" "github.com/openark/golib/log" "github.com/stretchr/testify/require" ) @@ -34,6 +38,7 @@ func TestNewConnectionConfig(t *testing.T) { } func TestDuplicateCredentials(t *testing.T) { + dialerErr := errors.New("dialer called") c := NewConnectionConfig() c.Key = InstanceKey{Hostname: "myhost", Port: 3306} c.User = "gromit" @@ -44,6 +49,10 @@ func TestDuplicateCredentials(t *testing.T) { } c.TransactionIsolation = transactionIsolation c.Charset = "utf8mb4" + c.Network = "mysql-tcp-12345678" + c.Dialer = func(context.Context, string, string) (net.Conn, error) { + return nil, dialerErr + } dup := c.DuplicateCredentials(InstanceKey{Hostname: "otherhost", Port: 3310}) require.Equal(t, "otherhost", dup.Key.Hostname) @@ -58,6 +67,9 @@ func TestDuplicateCredentials(t *testing.T) { require.Equal(t, c.tlsConfig.InsecureSkipVerify, dup.tlsConfig.InsecureSkipVerify) require.Equal(t, c.TransactionIsolation, dup.TransactionIsolation) require.Equal(t, c.Charset, dup.Charset) + require.Equal(t, c.Network, dup.Network) + _, err := dup.Dialer(context.Background(), "tcp", "otherhost:3310") + require.ErrorIs(t, err, dialerErr) } func TestDuplicate(t *testing.T) { @@ -67,6 +79,7 @@ func TestDuplicate(t *testing.T) { c.Password = "penguin" c.TransactionIsolation = transactionIsolation c.Charset = "utf8mb4" + c.Network = "mysql-tcp-12345678" dup := c.Duplicate() require.Equal(t, "myhost", dup.Key.Hostname) @@ -78,6 +91,13 @@ func TestDuplicate(t *testing.T) { require.Equal(t, c.tlsConfig, dup.tlsConfig) require.Equal(t, transactionIsolation, dup.TransactionIsolation) require.Equal(t, "utf8mb4", dup.Charset) + require.Equal(t, c.Network, dup.Network) +} + +func TestNewConnectionConfigHasNoCustomDialer(t *testing.T) { + c := NewConnectionConfig() + var dialer gomysqlclient.Dialer = c.Dialer + require.Nil(t, dialer) } func TestGetDBUri(t *testing.T) { @@ -110,6 +130,19 @@ func TestGetDBUriWithTLSSetup(t *testing.T) { require.Equal(t, `gromit:penguin@tcp(myhost:3306)/test?autocommit=true&interpolateParams=true&charset=utf8mb4_general_ci,utf8_general_ci,latin1&tls=uuidv4-myhost&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri) } +func TestGetDBUriWithCustomNetwork(t *testing.T) { + c := NewConnectionConfig() + c.Key = InstanceKey{Hostname: "myhost", Port: 3306} + c.User = "gromit" + c.Password = "penguin" + c.Timeout = 1.2345 + c.Charset = "utf8mb4,utf8,latin1" + c.Network = "mysql-tcp-12345678" + + uri := c.GetDBUri("test") + require.Equal(t, `gromit:penguin@mysql-tcp-12345678(myhost:3306)/test?autocommit=true&interpolateParams=true&charset=utf8mb4,utf8,latin1&tls=false&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri) +} + func TestGetDBTLSConfigKey(t *testing.T) { configKey := GetDBTLSConfigKey("", "myhost") require.Equal(t, "ghost-myhost", configKey)