Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go/binlog/gomysql_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func NewGoMySQLReader(migrationContext *base.MigrationContext) *GoMySQLReader {
UseDecimal: true,
TimestampStringLocation: time.UTC,
MaxReconnectAttempts: migrationContext.BinlogSyncerMaxReconnectAttempts,
Dialer: connectionConfig.Dialer,
}),
}
}
Expand Down
13 changes: 12 additions & 1 deletion go/mysql/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"os"
"strings"

gomysqlclient "github.com/go-mysql-org/go-mysql/client"
"github.com/go-sql-driver/mysql"
)

Expand All @@ -31,6 +32,10 @@ type ConnectionConfig struct {
Timeout float64
TransactionIsolation string
Charset string
// Network is the go-sql-driver network name. When empty, tcp is used.
Network string
// Dialer is used by go-mysql binlog connections. When nil, go-mysql uses net.Dialer.
Dialer gomysqlclient.Dialer

// use migrationContext.Uuid if useSSL
TLSKey string
Expand All @@ -54,6 +59,8 @@ func (this *ConnectionConfig) DuplicateCredentials(key InstanceKey) *ConnectionC
Timeout: this.Timeout,
TransactionIsolation: this.TransactionIsolation,
Charset: this.Charset,
Network: this.Network,
Dialer: this.Dialer,
TLSKey: this.TLSKey,
}

Expand Down Expand Up @@ -168,7 +175,11 @@ func (this *ConnectionConfig) GetDBUri(databaseName string) string {
fmt.Sprintf("writeTimeout=%fs", this.Timeout),
}

return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", this.User, this.Password, hostname, this.Key.Port, databaseName, strings.Join(connectionParams, "&"))
network := this.Network
if network == "" {
network = "tcp"
}
return fmt.Sprintf("%s:%s@%s(%s:%d)/%s?%s", this.User, this.Password, network, hostname, this.Key.Port, databaseName, strings.Join(connectionParams, "&"))
}

func GetDBTLSConfigKey(tlsKey string, tlsServerName string) string {
Expand Down
33 changes: 33 additions & 0 deletions go/mysql/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
package mysql

import (
"context"
"crypto/tls"
"errors"
"net"
"testing"

gomysqlclient "github.com/go-mysql-org/go-mysql/client"
"github.com/openark/golib/log"
"github.com/stretchr/testify/require"
)
Expand All @@ -34,6 +38,7 @@ func TestNewConnectionConfig(t *testing.T) {
}

func TestDuplicateCredentials(t *testing.T) {
dialerErr := errors.New("dialer called")
c := NewConnectionConfig()
c.Key = InstanceKey{Hostname: "myhost", Port: 3306}
c.User = "gromit"
Expand All @@ -44,6 +49,10 @@ func TestDuplicateCredentials(t *testing.T) {
}
c.TransactionIsolation = transactionIsolation
c.Charset = "utf8mb4"
c.Network = "mysql-tcp-12345678"
c.Dialer = func(context.Context, string, string) (net.Conn, error) {
return nil, dialerErr
}

dup := c.DuplicateCredentials(InstanceKey{Hostname: "otherhost", Port: 3310})
require.Equal(t, "otherhost", dup.Key.Hostname)
Expand All @@ -58,6 +67,9 @@ func TestDuplicateCredentials(t *testing.T) {
require.Equal(t, c.tlsConfig.InsecureSkipVerify, dup.tlsConfig.InsecureSkipVerify)
require.Equal(t, c.TransactionIsolation, dup.TransactionIsolation)
require.Equal(t, c.Charset, dup.Charset)
require.Equal(t, c.Network, dup.Network)
_, err := dup.Dialer(context.Background(), "tcp", "otherhost:3310")
require.ErrorIs(t, err, dialerErr)
}

func TestDuplicate(t *testing.T) {
Expand All @@ -67,6 +79,7 @@ func TestDuplicate(t *testing.T) {
c.Password = "penguin"
c.TransactionIsolation = transactionIsolation
c.Charset = "utf8mb4"
c.Network = "mysql-tcp-12345678"

dup := c.Duplicate()
require.Equal(t, "myhost", dup.Key.Hostname)
Expand All @@ -78,6 +91,13 @@ func TestDuplicate(t *testing.T) {
require.Equal(t, c.tlsConfig, dup.tlsConfig)
require.Equal(t, transactionIsolation, dup.TransactionIsolation)
require.Equal(t, "utf8mb4", dup.Charset)
require.Equal(t, c.Network, dup.Network)
}

func TestNewConnectionConfigHasNoCustomDialer(t *testing.T) {
c := NewConnectionConfig()
var dialer gomysqlclient.Dialer = c.Dialer
require.Nil(t, dialer)
}

func TestGetDBUri(t *testing.T) {
Expand Down Expand Up @@ -110,6 +130,19 @@ func TestGetDBUriWithTLSSetup(t *testing.T) {
require.Equal(t, `gromit:penguin@tcp(myhost:3306)/test?autocommit=true&interpolateParams=true&charset=utf8mb4_general_ci,utf8_general_ci,latin1&tls=uuidv4-myhost&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri)
}

func TestGetDBUriWithCustomNetwork(t *testing.T) {
c := NewConnectionConfig()
c.Key = InstanceKey{Hostname: "myhost", Port: 3306}
c.User = "gromit"
c.Password = "penguin"
c.Timeout = 1.2345
c.Charset = "utf8mb4,utf8,latin1"
c.Network = "mysql-tcp-12345678"

uri := c.GetDBUri("test")
require.Equal(t, `gromit:penguin@mysql-tcp-12345678(myhost:3306)/test?autocommit=true&interpolateParams=true&charset=utf8mb4,utf8,latin1&tls=false&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri)
}

func TestGetDBTLSConfigKey(t *testing.T) {
configKey := GetDBTLSConfigKey("", "myhost")
require.Equal(t, "ghost-myhost", configKey)
Expand Down
Loading