diff --git a/go/base/context.go b/go/base/context.go index 2fed82486..59550e045 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -276,12 +276,13 @@ type MigrationContext struct { // move tables: MoveTables struct { - TableNames []string // List of table names to be moved. - TargetHost string // Target hostname for the move. This must be a primary/writable host. - TargetPort int // Target MySQL port for the move. - TargetUser string // Target username for the move. If not specified, it will default to the source user. - TargetPass string // Target password for the move. If not specified, it will default to the source password. - TargetDatabase string // Target database name for the move. If not specified, it will default to the source database name. + TableNames []string // List of table names to be moved. + TargetHost string // Target hostname for the move. This must be a primary/writable host. + TargetPort int // Target MySQL port for the move. + TargetUser string // Target username for the move. If not specified, it will default to the source user. + TargetPass string // Target password for the move. If not specified, it will default to the source password. + TargetDatabase string // Target database name for the move. If not specified, it will default to the source database name. 
+ ConnectionConfig *mysql.ConnectionConfig } Log Logger @@ -356,6 +357,9 @@ func (mctx *MigrationContext) SetConnectionConfig(storageEngine string) error { } mctx.InspectorConnectionConfig.TransactionIsolation = transactionIsolation mctx.ApplierConnectionConfig.TransactionIsolation = transactionIsolation + if mctx.MoveTables.ConnectionConfig != nil { + mctx.MoveTables.ConnectionConfig.TransactionIsolation = transactionIsolation + } return nil } @@ -366,6 +370,9 @@ func (mctx *MigrationContext) SetConnectionCharset(charset string) { mctx.InspectorConnectionConfig.Charset = charset mctx.ApplierConnectionConfig.Charset = charset + if mctx.MoveTables.ConnectionConfig != nil { + mctx.MoveTables.ConnectionConfig.Charset = charset + } } func getSafeTableName(baseName string, suffix string) string { diff --git a/go/logic/applier.go b/go/logic/applier.go index f3474b3ef..24b47dc84 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -91,6 +91,12 @@ type Applier struct { migrationLockName string migrationLockStop chan struct{} migrationLockDone chan struct{} + + moveTablesTargetDB *gosql.DB + moveTablesConnectionConfig *mysql.ConnectionConfig + moveTablesCopySelectFirstQueryBuilder *sql.MoveTableCopySelectQueryBuilder + moveTablesCopySelectNextQueryBuilder *sql.MoveTableCopySelectQueryBuilder + moveTablesCopyInsertQueryBuilder *sql.MoveTableCopyInsertQueryBuilder } func NewApplier(migrationContext *base.MigrationContext) *Applier { @@ -99,6 +105,8 @@ func NewApplier(migrationContext *base.MigrationContext) *Applier { migrationContext: migrationContext, finishedMigrating: 0, name: "applier", + + moveTablesConnectionConfig: migrationContext.MoveTables.ConnectionConfig, } } @@ -149,6 +157,15 @@ func (apl *Applier) InitDBConnections() (err error) { if err := apl.readTableColumns(); err != nil { return err } + if apl.moveTablesConnectionConfig != nil { + moveTablesURI := apl.moveTablesConnectionConfig.GetDBUri(apl.migrationContext.MoveTables.TargetDatabase) + 
"&multiStatements=true" + if apl.moveTablesTargetDB, _, err = mysql.GetDB(apl.migrationContext.Uuid, moveTablesURI); err != nil { + return err + } + if _, err := base.ValidateConnection(apl.moveTablesTargetDB, apl.moveTablesConnectionConfig, apl.migrationContext, apl.name); err != nil { + return err + } + } apl.migrationContext.Log.Infof("Applier initiated on %+v, version %+v", apl.connectionConfig.ImpliedKey, apl.migrationContext.ApplierMySQLVersion) return nil } @@ -297,17 +314,24 @@ func (apl *Applier) releaseMigrationLock() { } func (apl *Applier) prepareQueries() (err error) { + targetDatabaseName := apl.migrationContext.DatabaseName + targetTableName := apl.migrationContext.GetGhostTableName() + if apl.migrationContext.IsMoveTablesMode() { + targetDatabaseName = apl.migrationContext.MoveTables.TargetDatabase + targetTableName = apl.migrationContext.OriginalTableName + } + if apl.dmlDeleteQueryBuilder, err = sql.NewDMLDeleteQueryBuilder( - apl.migrationContext.DatabaseName, - apl.migrationContext.GetGhostTableName(), + targetDatabaseName, + targetTableName, apl.migrationContext.OriginalTableColumns, &apl.migrationContext.UniqueKey.Columns, ); err != nil { return err } if apl.dmlInsertQueryBuilder, err = sql.NewDMLInsertQueryBuilder( - apl.migrationContext.DatabaseName, - apl.migrationContext.GetGhostTableName(), + targetDatabaseName, + targetTableName, apl.migrationContext.OriginalTableColumns, apl.migrationContext.SharedColumns, apl.migrationContext.MappedSharedColumns, @@ -315,8 +339,8 @@ func (apl *Applier) prepareQueries() (err error) { return err } if apl.dmlUpdateQueryBuilder, err = sql.NewDMLUpdateQueryBuilder( - apl.migrationContext.DatabaseName, - apl.migrationContext.GetGhostTableName(), + targetDatabaseName, + targetTableName, apl.migrationContext.OriginalTableColumns, apl.migrationContext.SharedColumns, apl.migrationContext.MappedSharedColumns, @@ -333,6 +357,35 @@ func (apl *Applier) prepareQueries() (err error) { return err } } + if 
apl.migrationContext.IsMoveTablesMode() {
+		if apl.moveTablesCopySelectFirstQueryBuilder, err = sql.NewMoveTableCopySelectQueryBuilder(
+			apl.migrationContext.DatabaseName,
+			apl.migrationContext.OriginalTableName,
+			apl.migrationContext.OriginalTableColumns,
+			apl.migrationContext.UniqueKey.Name,
+			&apl.migrationContext.UniqueKey.Columns,
+			true, // <-- include start range values for first select query
+		); err != nil {
+			return err
+		}
+		if apl.moveTablesCopySelectNextQueryBuilder, err = sql.NewMoveTableCopySelectQueryBuilder(
+			apl.migrationContext.DatabaseName,
+			apl.migrationContext.OriginalTableName,
+			apl.migrationContext.OriginalTableColumns,
+			apl.migrationContext.UniqueKey.Name,
+			&apl.migrationContext.UniqueKey.Columns,
+			false,
+		); err != nil {
+			return err
+		}
+		if apl.moveTablesCopyInsertQueryBuilder, err = sql.NewMoveTableCopyInsertQueryBuilder(
+			targetDatabaseName,
+			targetTableName,
+			apl.migrationContext.OriginalTableColumns,
+		); err != nil {
+			return err
+		}
+	}
 	return nil
 }
 
@@ -1164,6 +1217,142 @@ func (apl *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected i
 	return chunkSize, rowsAffected, duration, nil
 }
 
+// ApplyIterationMoveTableCopyQueries copies one chunk of rows from the original table to the
+// move-tables target table. It SELECTs the current iteration range through apl.db, then INSERTs
+// the fetched rows through apl.moveTablesTargetDB inside a single transaction.
+// It is used when `--move-tables` is specified, instead of ApplyIterationInsertQuery.
+//
+// It returns the chunk size in effect, the number of rows the INSERT reports as affected (0, with
+// no INSERT issued, when the range holds no rows) and the time spent. On error, duration is left
+// at its zero value.
+func (apl *Applier) ApplyIterationMoveTableCopyQueries() (chunkSize int64, rowsAffected int64, duration time.Duration, err error) {
+	startTime := time.Now()
+	chunkSize = atomic.LoadInt64(&apl.migrationContext.ChunkSize)
+
+	// First, select data from the source database:
+	rows, err := func() ([]*sql.ColumnValues, error) {
+		// Iteration 0 uses the builder that includes the range start values; later ones exclude them.
+		var qb *sql.MoveTableCopySelectQueryBuilder
+		if apl.migrationContext.GetIteration() == 0 {
+			qb = apl.moveTablesCopySelectFirstQueryBuilder
+		} else {
+			qb = apl.moveTablesCopySelectNextQueryBuilder
+		}
+		query, explodedArgs, err := qb.BuildQuery(
+			apl.migrationContext.MigrationIterationRangeMinValues.AbstractValues(),
+			apl.migrationContext.MigrationIterationRangeMaxValues.AbstractValues(),
+		)
+		if err != nil {
+			return nil, err
+		}
+		sqlRows, err := apl.db.Query(query, explodedArgs...)
+		if err != nil {
+			return nil, err
+		}
+		defer sqlRows.Close()
+		// Both copy query builders are prepared from OriginalTableColumns, so every scanned row
+		// must hold exactly that many values (SharedColumns may have a different length).
+		columnCount := apl.migrationContext.OriginalTableColumns.Len()
+		chunkRows := make([]*sql.ColumnValues, 0, chunkSize)
+		for sqlRows.Next() {
+			row := sql.NewColumnValues(columnCount)
+			if err := sqlRows.Scan(row.ValuesPointers...); err != nil {
+				return nil, err
+			}
+			chunkRows = append(chunkRows, row)
+		}
+		if err := sqlRows.Err(); err != nil {
+			return nil, err
+		}
+		return chunkRows, nil
+	}()
+	if err != nil {
+		return chunkSize, rowsAffected, duration, err
+	}
+
+	// no need to INSERT if there are no rows to copy:
+	if len(rows) == 0 {
+		duration = time.Since(startTime)
+		return chunkSize, 0, duration, nil
+	}
+
+	// Then, insert data into the destination database:
+	sqlResult, err := func() (gosql.Result, error) {
+		query, explodedArgs, err := apl.moveTablesCopyInsertQueryBuilder.BuildQuery(rows)
+		if err != nil {
+			return nil, err
+		}
+		tx, err := apl.moveTablesTargetDB.Begin()
+		if err != nil {
+			return nil, err
+		}
+		// Undoes the chunk on every early return below; its result is ignored once Commit has run.
+		defer tx.Rollback()
+
+		sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s', %s`,
+			apl.migrationContext.ApplierTimeZone,
+			apl.generateSqlModeQuery())
+		if _, err := tx.Exec(sessionQuery); err != nil {
+			return nil, err
+		}
+
+		sqlResult, err := tx.Exec(query, explodedArgs...)
+		if err != nil {
+			return nil, err
+		}
+
+		if apl.migrationContext.PanicOnWarnings {
+			warningRows, err := tx.Query("SHOW WARNINGS")
+			if err != nil {
+				return nil, err
+			}
+			defer warningRows.Close()
+			migrationKeyRegex, err := apl.compileMigrationKeyWarningRegex()
+			if err != nil {
+				return nil, err
+			}
+			var sqlWarnings []string
+			for warningRows.Next() {
+				var level, message string
+				var code int
+				if err := warningRows.Scan(&level, &code, &message); err != nil {
+					apl.migrationContext.Log.Warningf("Failed to read SHOW WARNINGS row")
+					continue
+				}
+				// "Duplicate entry" warnings that match the migration key regex are not recorded.
+				if strings.Contains(message, "Duplicate entry") && migrationKeyRegex.MatchString(message) {
+					continue
+				}
+				sqlWarnings = append(sqlWarnings, fmt.Sprintf("%s: %s (%d)", level, message, code))
+			}
+			// Iteration errors are only reported by Err once Next has returned false.
+			if err := warningRows.Err(); err != nil {
+				return nil, err
+			}
+			apl.migrationContext.MigrationLastInsertSQLWarnings = sqlWarnings
+		}
+
+		if err := tx.Commit(); err != nil {
+			return nil, err
+		}
+		return sqlResult, nil
+	}()
+	if err != nil {
+		return chunkSize, rowsAffected, duration, err
+	}
+	rowsAffected, _ = sqlResult.RowsAffected()
+	duration = time.Since(startTime)
+	apl.migrationContext.Log.Debugf(
+		"Issued SELECT+INSERT on range: [%s]..[%s]; iteration: %d; chunk-size: %d",
+		apl.migrationContext.MigrationIterationRangeMinValues,
+		apl.migrationContext.MigrationIterationRangeMaxValues,
+		apl.migrationContext.GetIteration(),
+		chunkSize,
+	)
+
+	return chunkSize, rowsAffected, duration, nil
+}
+
 // LockOriginalTable places a write lock on the original table
 func (apl *Applier) LockOriginalTable() error {
 	query := fmt.Sprintf(`lock /* gh-ost */ tables %s.%s write`,
@@ -1783,7 +1972,11 @@ func (apl *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) e
 	ctx := context.Background()
 
 	err := func() error {
-		conn, err := apl.db.Conn(ctx)
+		db := apl.db
+		if apl.migrationContext.IsMoveTablesMode() {
+			db = apl.moveTablesTargetDB
+		}
+		conn, err := db.Conn(ctx)
 		if err != nil {
 			return err
 		}
@@ -1888,6 +2081,9 @@ func (apl *Applier)
Teardown() { apl.releaseMigrationLock() apl.db.Close() apl.singletonDB.Close() + if apl.moveTablesTargetDB != nil { + apl.moveTablesTargetDB.Close() + } atomic.StoreInt64(&apl.finishedMigrating, 1) } diff --git a/go/logic/applier_test.go b/go/logic/applier_test.go index 85a5a01d3..2f1676617 100644 --- a/go/logic/applier_test.go +++ b/go/logic/applier_test.go @@ -9,6 +9,7 @@ import ( "context" gosql "database/sql" "errors" + "net" "strings" "testing" "time" @@ -271,6 +272,7 @@ type ApplierTestSuite struct { mysqlContainer testcontainers.Container db *gosql.DB + otherDB *gosql.DB } func (suite *ApplierTestSuite) SetupSuite() { @@ -291,12 +293,30 @@ func (suite *ApplierTestSuite) SetupSuite() { db, err := gosql.Open("mysql", dsn) suite.Require().NoError(err) - suite.db = db + + containerHost, err := mysqlContainer.Host(ctx) + suite.Require().NoError(err) + containerPort, err := mysqlContainer.MappedPort(ctx, "3306/tcp") + suite.Require().NoError(err) + + // Second database & connection for move-tables tests: + _, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", testMysqlDatabaseOther)) + suite.Require().NoError(err) + otherConf := drivermysql.NewConfig() + otherConf.DBName = testMysqlDatabaseOther + otherConf.User = testMysqlUser + otherConf.Passwd = testMysqlPass + otherConf.Net = "tcp" + otherConf.Addr = net.JoinHostPort(containerHost, containerPort.Port()) + otherDB, err := gosql.Open("mysql", otherConf.FormatDSN()) + suite.Require().NoError(err) + suite.otherDB = otherDB } func (suite *ApplierTestSuite) TeardownSuite() { suite.Assert().NoError(suite.db.Close()) + suite.Assert().NoError(suite.otherDB.Close()) suite.Assert().NoError(testcontainers.TerminateContainer(suite.mysqlContainer)) } @@ -313,6 +333,8 @@ func (suite *ApplierTestSuite) TearDownTest() { suite.Require().NoError(err) _, err = suite.db.ExecContext(ctx, "DROP TABLE IF EXISTS "+getTestGhostTableName()) suite.Require().NoError(err) + _, err = 
suite.otherDB.ExecContext(ctx, "DROP TABLE IF EXISTS "+getTestOtherTableName())
+	suite.Require().NoError(err)
 }
 
 func (suite *ApplierTestSuite) TestInitDBConnections() {
@@ -1619,6 +1641,214 @@ func (suite *ApplierTestSuite) TestMultipleDMLEventsInBatch() {
 	// Critically: id=2 (bob@example.com) is NOT present, proving event #3 was rolled back
 }
 
+// TestApplyDMLEventQueriesMoveTablesMode applies a single insert DML event with a move-tables
+// target configured and expects the row to show up in the target database's table.
+func (suite *ApplierTestSuite) TestApplyDMLEventQueriesMoveTablesMode() {
+	ctx := context.Background()
+	var err error
+
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT, item_id INT);", getTestTableName()))
+	suite.Require().NoError(err)
+	_, err = suite.otherDB.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT, item_id INT);", getTestOtherTableName()))
+	suite.Require().NoError(err)
+
+	connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer)
+	suite.Require().NoError(err)
+
+	migrationContext := newTestMigrationContext()
+	migrationContext.ApplierConnectionConfig = connectionConfig
+	migrationContext.MoveTables.ConnectionConfig = connectionConfig
+	suite.Require().NoError(migrationContext.SetConnectionConfig("innodb"))
+
+	migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "item_id"})
+	migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "item_id"})
+	migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "item_id"})
+	migrationContext.UniqueKey = &sql.UniqueKey{
+		Name:    "primary_key",
+		Columns: *sql.NewColumnList([]string{"id"}),
+	}
+	migrationContext.MoveTables.TableNames = []string{testMysqlTableName}
+	migrationContext.MoveTables.TargetDatabase = testMysqlDatabaseOther
+
+	applier := NewApplier(migrationContext)
+	suite.Require().NoError(applier.prepareQueries())
+	defer applier.Teardown()
+
+	err = applier.InitDBConnections()
+	suite.Require().NoError(err)
+
+	dmlEvents := []*binlog.BinlogDMLEvent{
+		{
+			DatabaseName:    testMysqlDatabase,
+			TableName:       testMysqlTableName,
+			DML:             binlog.InsertDML,
+			NewColumnValues: sql.ToColumnValues([]interface{}{123456, 42}),
+		},
+	}
+	err = applier.ApplyDMLEventQueries(dmlEvents)
+	suite.Require().NoError(err)
+
+	// Check that the row was inserted into the target table via moveTablesTargetDB
+	rows, err := suite.otherDB.Query("SELECT * FROM " + getTestOtherTableName())
+	suite.Require().NoError(err)
+	defer rows.Close()
+
+	var count, id, itemID int
+	for rows.Next() {
+		err = rows.Scan(&id, &itemID)
+		suite.Require().NoError(err)
+		count += 1
+	}
+	suite.Require().NoError(rows.Err())
+
+	suite.Require().Equal(1, count)
+	suite.Require().Equal(123456, id)
+	suite.Require().Equal(42, itemID)
+
+	suite.Require().Equal(int64(1), migrationContext.TotalDMLEventsApplied)
+	suite.Require().Equal(int64(0), migrationContext.RowsDeltaEstimate)
+}
+
+// TestApplyIterationMoveTableCopyQueries copies a three-row chunk and expects every row,
+// including its DATETIME value, to arrive unchanged in the target table.
+func (suite *ApplierTestSuite) TestApplyIterationMoveTableCopyQueries() {
+	ctx := context.Background()
+	var err error
+
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT NOT NULL, name VARCHAR(50), created_at DATETIME NOT NULL, PRIMARY KEY(id));", getTestTableName()))
+	suite.Require().NoError(err)
+	_, err = suite.otherDB.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT NOT NULL, name VARCHAR(50), created_at DATETIME NOT NULL, PRIMARY KEY(id));", getTestOtherTableName()))
+	suite.Require().NoError(err)
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name, created_at) VALUES (1, 'alice', '2024-01-15 10:30:00'), (2, 'bob', '2024-06-20 14:45:00'), (3, 'carol', '2025-12-31 23:59:59');", getTestTableName()))
+	suite.Require().NoError(err)
+
+	connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer)
+	suite.Require().NoError(err)
+
+	migrationContext := newTestMigrationContext()
+	migrationContext.ApplierConnectionConfig = connectionConfig
+	migrationContext.MoveTables.ConnectionConfig = connectionConfig
+	suite.Require().NoError(migrationContext.SetConnectionConfig("innodb"))
+	migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "name", "created_at"})
+	migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "name", "created_at"})
+	migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "name", "created_at"})
+	migrationContext.UniqueKey = &sql.UniqueKey{
+		Name:    "PRIMARY",
+		Columns: *sql.NewColumnList([]string{"id"}),
+	}
+	migrationContext.MoveTables.TableNames = []string{testMysqlTableName}
+	migrationContext.MoveTables.TargetDatabase = testMysqlDatabaseOther
+
+	applier := NewApplier(migrationContext)
+	suite.Require().NoError(applier.prepareQueries())
+	defer applier.Teardown()
+
+	err = applier.InitDBConnections()
+	suite.Require().NoError(err)
+
+	err = applier.CreateChangelogTable()
+	suite.Require().NoError(err)
+
+	err = applier.ReadMigrationRangeValues()
+	suite.Require().NoError(err)
+
+	migrationContext.SetNextIterationRangeMinValues()
+	hasFurtherRange, err := applier.CalculateNextIterationRangeEndValues()
+	suite.Require().NoError(err)
+	suite.Require().True(hasFurtherRange)
+
+	chunkSize, rowsAffected, duration, err := applier.ApplyIterationMoveTableCopyQueries()
+	suite.Require().NoError(err)
+	suite.Require().Equal(int64(3), rowsAffected)
+	suite.Require().Equal(int64(1000), chunkSize)
+	suite.Require().Greater(duration, time.Duration(0))
+
+	// Verify rows were copied to the other table
+	rows, err := suite.otherDB.QueryContext(ctx, "SELECT id, name, created_at FROM "+getTestOtherTableName()+" ORDER BY id")
+	suite.Require().NoError(err)
+	defer rows.Close()
+
+	type row struct {
+		id        int
+		name      string
+		createdAt string
+	}
+	var results []row
+	for rows.Next() {
+		var r row
+		err = rows.Scan(&r.id, &r.name, &r.createdAt)
+		suite.Require().NoError(err)
+		results = append(results, r)
+	}
+	suite.Require().NoError(rows.Err())
+
+	suite.Require().Len(results, 3)
+	suite.Require().Equal(1, results[0].id)
+	suite.Require().Equal("alice", results[0].name)
+	suite.Require().Equal("2024-01-15 10:30:00", results[0].createdAt)
+	suite.Require().Equal(2, results[1].id)
+	suite.Require().Equal("bob", results[1].name)
+	suite.Require().Equal("2024-06-20 14:45:00", results[1].createdAt)
+	suite.Require().Equal(3, results[2].id)
+	suite.Require().Equal("carol", results[2].name)
+	suite.Require().Equal("2025-12-31 23:59:59", results[2].createdAt)
+}
+
+// TestApplyIterationMoveTableCopyQueriesNoRows points the iteration range at keys that match no
+// rows and expects zero rows affected and an empty target table.
+func (suite *ApplierTestSuite) TestApplyIterationMoveTableCopyQueriesNoRows() {
+	ctx := context.Background()
+	var err error
+
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT NOT NULL, name VARCHAR(50), created_at DATETIME NOT NULL, PRIMARY KEY(id));", getTestTableName()))
+	suite.Require().NoError(err)
+	_, err = suite.otherDB.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT NOT NULL, name VARCHAR(50), created_at DATETIME NOT NULL, PRIMARY KEY(id));", getTestOtherTableName()))
+	suite.Require().NoError(err)
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name, created_at) VALUES (1, 'alice', '2024-01-15 10:30:00'), (2, 'bob', '2024-06-20 14:45:00'), (3, 'carol', '2025-12-31 23:59:59');", getTestTableName()))
+	suite.Require().NoError(err)
+
+	connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer)
+	suite.Require().NoError(err)
+
+	migrationContext := newTestMigrationContext()
+	migrationContext.ApplierConnectionConfig = connectionConfig
+	migrationContext.MoveTables.ConnectionConfig = connectionConfig
+	suite.Require().NoError(migrationContext.SetConnectionConfig("innodb"))
+	migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "name", "created_at"})
+	migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "name", "created_at"})
+	migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "name", "created_at"})
+	migrationContext.UniqueKey = &sql.UniqueKey{
+		Name:    "PRIMARY",
+		Columns: *sql.NewColumnList([]string{"id"}),
+	}
+	migrationContext.MoveTables.TableNames = []string{testMysqlTableName}
+	migrationContext.MoveTables.TargetDatabase = testMysqlDatabaseOther
+
+	applier := NewApplier(migrationContext)
+	suite.Require().NoError(applier.prepareQueries())
+	defer applier.Teardown()
+
+	err = applier.InitDBConnections()
+	suite.Require().NoError(err)
+
+	// Point the iteration range at a key range that contains no rows so the
+	// SELECT returns an empty result set and the INSERT is skipped.
+	migrationContext.MigrationIterationRangeMinValues = sql.ToColumnValues([]interface{}{100})
+	migrationContext.MigrationIterationRangeMaxValues = sql.ToColumnValues([]interface{}{200})
+
+	chunkSize, rowsAffected, duration, err := applier.ApplyIterationMoveTableCopyQueries()
+	suite.Require().NoError(err)
+	suite.Require().Equal(int64(0), rowsAffected)
+	suite.Require().Equal(int64(1000), chunkSize)
+	suite.Require().Greater(duration, time.Duration(0))
+
+	// Verify no rows were copied to the target table.
+	var count int
+	err = suite.otherDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+getTestOtherTableName()).Scan(&count)
+	suite.Require().NoError(err)
+	suite.Require().Equal(0, count)
+}
+
 func TestApplier(t *testing.T) {
 	if testing.Short() {
 		t.Skip("skipping applier test suite in short mode")
diff --git a/go/logic/migrator.go b/go/logic/migrator.go
index 352d02ba2..c820d074d 100644
--- a/go/logic/migrator.go
+++ b/go/logic/migrator.go
@@ -1724,7 +1724,12 @@ func (mgtr *Migrator) iterateChunks() error {
 					// _ghost_ table, which no longer exists. So, bothering error messages and all, but no damage.
return nil
 				}
-				_, rowsAffected, _, err := mgtr.applier.ApplyIterationInsertQuery()
+				var rowsAffected int64
+				if mgtr.migrationContext.IsMoveTablesMode() {
+					_, rowsAffected, _, err = mgtr.applier.ApplyIterationMoveTableCopyQueries()
+				} else {
+					_, rowsAffected, _, err = mgtr.applier.ApplyIterationInsertQuery()
+				}
 				if err != nil {
 					return err // wrapping call will retry
 				}
diff --git a/go/logic/test_utils.go b/go/logic/test_utils.go
index cdcfcee84..49bca0859 100644
--- a/go/logic/test_utils.go
+++ b/go/logic/test_utils.go
@@ -17,6 +17,7 @@ var (
 	testMysqlUser = "root"
 	testMysqlPass = "root-password"
 	testMysqlDatabase = "test"
+	testMysqlDatabaseOther = "test_other"
 	testMysqlTableName = "testing"
 )
 
@@ -36,6 +37,10 @@ func getTestOldTableName() string {
 	return fmt.Sprintf("`%s`.`_%s_del`", testMysqlDatabase, testMysqlTableName)
 }
 
+func getTestOtherTableName() string {
+	return fmt.Sprintf("`%s`.`%s`", testMysqlDatabaseOther, testMysqlTableName)
+}
+
 func getTestConnectionConfig(ctx context.Context, container testcontainers.Container) (*mysql.ConnectionConfig, error) {
 	host, err := container.Host(ctx)
 	if err != nil {
diff --git a/go/sql/builder.go b/go/sql/builder.go
index 7d0864601..1c3c612fa 100644
--- a/go/sql/builder.go
+++ b/go/sql/builder.go
@@ -7,6 +7,7 @@ package sql
 
 import (
 	"fmt"
+	"slices"
 	"strconv"
 	"strings"
 )
@@ -425,6 +426,168 @@ func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableNa
 	return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, mappedSharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable, noWait)
 }
 
+// MoveTableCopySelectQueryBuilder holds a range SELECT statement prepared once by
+// NewMoveTableCopySelectQueryBuilder, plus the mapping that expands the caller's range
+// start/end values into the statement's placeholder arguments.
+type MoveTableCopySelectQueryBuilder struct {
+	preparedStatement string
+	argsMapping       []int // per placeholder: index into the start values followed by the end values
+	argsCount         int   // total number of range values (start + end) BuildQuery expects
+}
+
+// NewMoveTableCopySelectQueryBuilder prepares a SELECT of the given columns from the source table,
+// forcing the given unique key index and bounded by a range over uniqueKeyColumns. The lower bound
+// is inclusive only when includeRangeStartValues is true; the upper bound is always inclusive.
+func NewMoveTableCopySelectQueryBuilder(sourceDatabaseName, sourceTableName string, columns *ColumnList, uniqueKey string, uniqueKeyColumns *ColumnList, includeRangeStartValues bool) (*MoveTableCopySelectQueryBuilder, error) {
+	sourceDatabaseName = EscapeName(sourceDatabaseName)
+	sourceTableName = EscapeName(sourceTableName)
+	columnNames := columns.Names()
+	for i := range columnNames {
+		columnNames[i] = EscapeName(columnNames[i])
+	}
+	sharedColumnsListing := strings.Join(columnNames, ", ")
+	uniqueKey = EscapeName(uniqueKey)
+	var minRangeComparisonSign = GreaterThanComparisonSign
+	if includeRangeStartValues {
+		minRangeComparisonSign = GreaterThanOrEqualsComparisonSign
+	}
+	rangeStartValues := buildColumnsPreparedValues(uniqueKeyColumns)
+	rangeEndValues := buildColumnsPreparedValues(uniqueKeyColumns)
+	// Use each key column's position as a stand-in argument: the order in which
+	// BuildRangeComparison repeats these stand-ins tells which real value each placeholder needs.
+	dummyArgs := make([]any, len(uniqueKeyColumns.Columns()))
+	for i := range dummyArgs {
+		dummyArgs[i] = i
+	}
+	var argsMapping []int
+
+	rangeStartComparison, rangeExplodedArgs, err := BuildRangeComparison(uniqueKeyColumns.Names(), rangeStartValues, dummyArgs, minRangeComparisonSign)
+	if err != nil {
+		return nil, err
+	}
+	for _, a := range rangeExplodedArgs {
+		idx := slices.Index(dummyArgs, a)
+		if idx == -1 {
+			return nil, fmt.Errorf("failed to build args mapping, missing argument pointer %v", a)
+		}
+		argsMapping = append(argsMapping, idx)
+	}
+
+	rangeEndComparison, rangeExplodedArgs, err := BuildRangeComparison(uniqueKeyColumns.Names(), rangeEndValues, dummyArgs, LessThanOrEqualsComparisonSign)
+	if err != nil {
+		return nil, err
+	}
+	for _, a := range rangeExplodedArgs {
+		idx := slices.Index(dummyArgs, a)
+		if idx == -1 {
+			return nil, fmt.Errorf("failed to build args mapping, missing argument pointer %v", a)
+		}
+		// End values are addressed after the start values in the combined index space.
+		argsMapping = append(argsMapping, idx+len(dummyArgs))
+	}
+
+	stmt := fmt.Sprintf(`
+		select /* gh-ost %s.%s */ %s
+		from
+			%s.%s
+		force index (%s)
+		where
+			(%s and %s)
+	`,
+		sourceDatabaseName, sourceTableName, sharedColumnsListing,
+		sourceDatabaseName, sourceTableName,
+		uniqueKey,
+		rangeStartComparison, rangeEndComparison,
+	)
+	return &MoveTableCopySelectQueryBuilder{
+		preparedStatement: stmt,
+		argsMapping:       argsMapping,
+		argsCount:         len(dummyArgs) * 2,
+	}, nil
+}
+
+// BuildQuery returns the prepared SELECT together with its arguments, expanded from the given range
+// start and end values. Both slices must hold one value per unique key column.
+func (b *MoveTableCopySelectQueryBuilder) BuildQuery(rangeStartArgs, rangeEndArgs []any) (string, []any, error) {
+	if len(rangeStartArgs)+len(rangeEndArgs) != b.argsCount {
+		return "", nil, fmt.Errorf("got %d args but expected %d", len(rangeStartArgs)+len(rangeEndArgs), b.argsCount)
+	}
+	if len(rangeStartArgs) != len(rangeEndArgs) {
+		return "", nil, fmt.Errorf("mismatched number of start and end args: %d != %d", len(rangeStartArgs), len(rangeEndArgs))
+	}
+	explodedArgs := make([]any, 0, len(b.argsMapping))
+	for _, idx := range b.argsMapping {
+		if idx < len(rangeStartArgs) {
+			explodedArgs = append(explodedArgs, rangeStartArgs[idx])
+		} else {
+			explodedArgs = append(explodedArgs, rangeEndArgs[idx-len(rangeStartArgs)])
+		}
+	}
+	return b.preparedStatement, explodedArgs, nil
+}
+
+// MoveTableCopyInsertQueryBuilder holds the fixed prefix of a multi-row INSERT IGNORE statement and
+// the per-row placeholder tuple that BuildQuery appends once for every row.
+type MoveTableCopyInsertQueryBuilder struct {
+	preparedStatement    string
+	valueListPlaceholder string
+	valueListSize        int
+}
+
+// NewMoveTableCopyInsertQueryBuilder prepares an INSERT IGNORE into the target table listing the
+// given columns; the value tuples themselves are appended by BuildQuery.
+func NewMoveTableCopyInsertQueryBuilder(targetDatabaseName, targetTableName string, columns *ColumnList) (*MoveTableCopyInsertQueryBuilder, error) {
+	targetDatabaseName = EscapeName(targetDatabaseName)
+	targetTableName = EscapeName(targetTableName)
+	columnsNames := columns.Names()
+	for i := range columnsNames {
+		columnsNames[i] = EscapeName(columnsNames[i])
+	}
+	sharedColumnsListing := strings.Join(columnsNames, ", ")
+	valueListPlaceholder := "(" + strings.Join(buildColumnsPreparedValues(columns), ", ") + ")"
+	valueListSize := len(columnsNames)
+	stmt := fmt.Sprintf(`
+		insert /* gh-ost %s.%s */ ignore
+		into
+			%s.%s
+			(%s)
+		values
+	`,
+		targetDatabaseName, targetTableName,
+		targetDatabaseName, targetTableName,
+		sharedColumnsListing,
+	)
+	return &MoveTableCopyInsertQueryBuilder{
+		preparedStatement:    stmt,
+		valueListPlaceholder: valueListPlaceholder,
+		valueListSize:        valueListSize,
+	}, nil
+}
+
+// BuildQuery returns the INSERT statement with one placeholder tuple per row, and the rows' values
+// flattened in the same order. It fails if values is empty (the statement would carry no tuples)
+// or if any row does not hold exactly one value per column.
+func (b *MoveTableCopyInsertQueryBuilder) BuildQuery(values []*ColumnValues) (string, []any, error) {
+	if len(values) == 0 {
+		return "", nil, fmt.Errorf("got 0 rows in MoveTableCopyInsertQueryBuilder.BuildQuery")
+	}
+	explodedArgs := make([]any, 0, len(values)*b.valueListSize)
+	var builder strings.Builder
+	builder.WriteString(b.preparedStatement)
+	for i, value := range values {
+		rowValues := value.AbstractValues()
+		if len(rowValues) != b.valueListSize {
+			return "", nil, fmt.Errorf("got %d column values but expected %d", len(rowValues), b.valueListSize)
+		}
+		if i > 0 {
+			builder.WriteString(",\n")
+		}
+		builder.WriteString(b.valueListPlaceholder)
+		explodedArgs = append(explodedArgs, rowValues...)
+	}
+	return builder.String(), explodedArgs, nil
+}
+
 func BuildUniqueKeyRangeEndPreparedQueryViaOffset(databaseName, tableName string, uniqueKeyColumns *ColumnList, rangeStartArgs, rangeEndArgs []interface{}, chunkSize int64, includeRangeStartValues bool, hint string) (result string, explodedArgs []interface{}, err error) {
 	if uniqueKeyColumns.Len() == 0 {
 		return "", explodedArgs, fmt.Errorf("got 0 columns in BuildUniqueKeyRangeEndPreparedQuery")
diff --git a/go/sql/builder_test.go b/go/sql/builder_test.go
index be7075927..0fcf31441 100644
--- a/go/sql/builder_test.go
+++ b/go/sql/builder_test.go
@@ -1102,6 +1102,259 @@ func TestBuildDMLUpdateQuerySignedUnsigned(t *testing.T) {
 	}
 }
 
+// TestMoveTableCopySelectQueryBuilder checks the generated range SELECT and its expanded arguments.
+func TestMoveTableCopySelectQueryBuilder(t *testing.T) {
+	t.Run("single column unique key", func(t *testing.T) {
+		sharedColumns := NewColumnList([]string{"id", "name", "position"})
+		uniqueKeyColumns := NewColumnList([]string{"id"})
+
+		builder, err := NewMoveTableCopySelectQueryBuilder("mydb", "tbl", sharedColumns, "PRIMARY", uniqueKeyColumns, true)
+		require.NoError(t, err)
+
+		query, args, err := builder.BuildQuery([]any{3}, []any{103})
+		require.NoError(t, err)
+
+		expected := `
+			select /* gh-ost mydb.tbl */ id, name, position
+			from
+				mydb.tbl
+			force index (PRIMARY)
+			where
+				(((id > ?) or ((id = ?))) and ((id < ?) or ((id = ?))))
+		`
+		require.Equal(t, normalizeQuery(expected), normalizeQuery(query))
+		require.Equal(t, []any{3, 3, 103, 103}, args)
+	})
+
+	t.Run("single column unique key without range start", func(t *testing.T) {
+		sharedColumns := NewColumnList([]string{"id", "name", "position"})
+		uniqueKeyColumns := NewColumnList([]string{"id"})
+
+		builder, err := NewMoveTableCopySelectQueryBuilder("mydb", "tbl", sharedColumns, "PRIMARY", uniqueKeyColumns, false)
+		require.NoError(t, err)
+
+		query, args, err := builder.BuildQuery([]any{3}, []any{103})
+		require.NoError(t, err)
+
+		expected := `
+			select /* gh-ost mydb.tbl */ id, name, position
+			from
+				mydb.tbl
+			force index (PRIMARY)
+			where
+				(((id > ?)) and ((id < ?) or ((id = ?))))
+		`
+		require.Equal(t, normalizeQuery(expected), normalizeQuery(query))
+		require.Equal(t, []any{3, 103, 103}, args)
+	})
+
+	t.Run("compound unique key", func(t *testing.T) {
+		sharedColumns := NewColumnList([]string{"id", "name", "position"})
+		uniqueKeyColumns := NewColumnList([]string{"name", "position"})
+
+		builder, err := NewMoveTableCopySelectQueryBuilder("mydb", "tbl", sharedColumns, "name_position_uidx", uniqueKeyColumns, true)
+		require.NoError(t, err)
+
+		query, args, err := builder.BuildQuery([]any{3, 17}, []any{103, 117})
+		require.NoError(t, err)
+
+		expected := `
+			select /* gh-ost mydb.tbl */ id, name, position
+			from
+				mydb.tbl
+			force index (name_position_uidx)
+			where
+				(((name > ?) or (((name = ?)) AND (position > ?)) or ((name = ?) and (position = ?)))
+				and ((name < ?) or (((name = ?)) AND (position < ?)) or ((name = ?) and (position = ?))))
+		`
+		require.Equal(t, normalizeQuery(expected), normalizeQuery(query))
+		require.Equal(t, []any{3, 3, 17, 3, 17, 103, 103, 117, 103, 117}, args)
+	})
+
+	t.Run("reuses prepared statement across calls", func(t *testing.T) {
+		sharedColumns := NewColumnList([]string{"id", "name"})
+		uniqueKeyColumns := NewColumnList([]string{"id"})
+
+		builder, err := NewMoveTableCopySelectQueryBuilder("mydb", "tbl", sharedColumns, "PRIMARY", uniqueKeyColumns, true)
+		require.NoError(t, err)
+
+		query1, args1, err := builder.BuildQuery([]any{1}, []any{10})
+		require.NoError(t, err)
+		query2, args2, err := builder.BuildQuery([]any{11}, []any{20})
+		require.NoError(t, err)
+
+		require.Equal(t, query1, query2)
+		require.Equal(t, []any{1, 1, 10, 10}, args1)
+		require.Equal(t, []any{11, 11, 20, 20}, args2)
+	})
+
+	t.Run("wrong args count", func(t *testing.T) {
+		sharedColumns := NewColumnList([]string{"id", "name"})
+		uniqueKeyColumns := NewColumnList([]string{"id"})
+
+		builder, err := NewMoveTableCopySelectQueryBuilder("mydb", "tbl", sharedColumns, "PRIMARY", uniqueKeyColumns, true)
+		require.NoError(t, err)
+
+		_, _, err = builder.BuildQuery([]any{1, 2}, []any{10})
+		require.Error(t, err)
+	})
+
+	t.Run("mismatched start and end args count", func(t *testing.T) {
+		sharedColumns := NewColumnList([]string{"id", "name", "position"})
+		uniqueKeyColumns := NewColumnList([]string{"name", "position"})
+
+		builder, err := NewMoveTableCopySelectQueryBuilder("mydb", "tbl", sharedColumns, "name_position_uidx", uniqueKeyColumns, true)
+		require.NoError(t, err)
+
+		// Total args count matches argsCount (4), but start and end counts differ.
+		_, _, err = builder.BuildQuery([]any{1, 2, 3}, []any{10})
+		require.Error(t, err)
+		require.Contains(t, err.Error(), "mismatched number of start and end args")
+	})
+}
+
+// BenchmarkMoveTableCopySelectQueryBuilderBuildQuery measures BuildQuery on a prepared compound-key builder.
+func BenchmarkMoveTableCopySelectQueryBuilderBuildQuery(b *testing.B) {
+	sharedColumns := NewColumnList([]string{"id", "name", "position"})
+	uniqueKeyColumns := NewColumnList([]string{"name", "position"})
+
+	builder, err := NewMoveTableCopySelectQueryBuilder("mydb", "tbl", sharedColumns, "name_position_uidx", uniqueKeyColumns, true)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	rangeStartArgs := []any{3, 17}
+	rangeEndArgs := []any{103, 117}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, _, err := builder.BuildQuery(rangeStartArgs, rangeEndArgs)
+		if err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+// TestMoveTableCopyInsertQueryBuilder checks the generated multi-row INSERT and its flattened arguments.
+func TestMoveTableCopyInsertQueryBuilder(t *testing.T) {
+	t.Run("single row", func(t *testing.T) {
+		sharedColumns := NewColumnList([]string{"id", "name", "position"})
+
+		builder, err := NewMoveTableCopyInsertQueryBuilder("mydb", "ghost", sharedColumns)
+		require.NoError(t, err)
+
+		values := []*ColumnValues{
+			ToColumnValues([]interface{}{1, "alice", 10}),
+		}
+		query, args, err := builder.BuildQuery(values)
+		require.NoError(t, err)
+
+		expected := `
+			insert /* gh-ost mydb.ghost */ ignore
+			into
+				mydb.ghost
+				(id, name, position)
+			values
+				(?, ?, ?)
+		`
+		require.Equal(t, normalizeQuery(expected), normalizeQuery(query))
+		require.Equal(t, []any{1, "alice", 10}, args)
+	})
+
+	t.Run("multiple rows", func(t *testing.T) {
+		sharedColumns := NewColumnList([]string{"id", "name", "position"})
+
+		builder, err := NewMoveTableCopyInsertQueryBuilder("mydb", "ghost", sharedColumns)
+		require.NoError(t, err)
+
+		values := []*ColumnValues{
+			ToColumnValues([]interface{}{1, "alice", 10}),
+			ToColumnValues([]interface{}{2, "bob", 20}),
+			ToColumnValues([]interface{}{3, "carol", 30}),
+		}
+		query, args, err := builder.BuildQuery(values)
+		require.NoError(t, err)
+
+		expected := `
+			insert /* gh-ost mydb.ghost */ ignore
+			into
+				mydb.ghost
+				(id, name, position)
+			values
+				(?, ?, ?),
+				(?, ?, ?),
+				(?, ?, ?)
+		`
+		require.Equal(t, normalizeQuery(expected), normalizeQuery(query))
+		require.Equal(t, []any{1, "alice", 10, 2, "bob", 20, 3, "carol", 30}, args)
+	})
+
+	t.Run("wrong column count", func(t *testing.T) {
+		sharedColumns := NewColumnList([]string{"id", "name", "position"})
+
+		builder, err := NewMoveTableCopyInsertQueryBuilder("mydb", "ghost", sharedColumns)
+		require.NoError(t, err)
+
+		values := []*ColumnValues{
+			ToColumnValues([]interface{}{1, "alice"}),
+		}
+		_, _, err = builder.BuildQuery(values)
+		require.Error(t, err)
+	})
+
+	t.Run("no rows", func(t *testing.T) {
+		sharedColumns := NewColumnList([]string{"id", "name", "position"})
+
+		builder, err := NewMoveTableCopyInsertQueryBuilder("mydb", "ghost", sharedColumns)
+		require.NoError(t, err)
+
+		_, _, err = builder.BuildQuery(nil)
+		require.Error(t, err)
+	})
+
+	t.Run("reuses prepared statement", func(t *testing.T) {
+		sharedColumns := NewColumnList([]string{"id", "name"})
+
+		builder, err := NewMoveTableCopyInsertQueryBuilder("mydb", "ghost", sharedColumns)
+		require.NoError(t, err)
+
+		values1 := []*ColumnValues{ToColumnValues([]interface{}{1, "a"})}
+		values2 := []*ColumnValues{ToColumnValues([]interface{}{2, "b"})}
+
+		query1, args1, err := builder.BuildQuery(values1)
+		require.NoError(t, err)
+		query2, args2, err := builder.BuildQuery(values2)
+		require.NoError(t, err)
+
+		require.Equal(t, query1, query2)
+		require.Equal(t, []any{1, "a"}, args1)
+		require.Equal(t, []any{2, "b"}, args2)
+	})
+}
+
+func
BenchmarkMoveTableCopyInsertQueryBuilderBuildQuery(b *testing.B) { + sharedColumns := NewColumnList([]string{"id", "name", "position"}) + + builder, err := NewMoveTableCopyInsertQueryBuilder("mydb", "ghost", sharedColumns) + if err != nil { + b.Fatal(err) + } + + values := []*ColumnValues{ + ToColumnValues([]interface{}{1, "alice", 10}), + ToColumnValues([]interface{}{2, "bob", 20}), + ToColumnValues([]interface{}{3, "carol", 30}), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, err := builder.BuildQuery(values) + if err != nil { + b.Fatal(err) + } + } +} + func TestCheckpointQueryBuilder(t *testing.T) { databaseName := "mydb" tableName := "_tbl_ghk"