5151 checkErr bool
5252 pathBR string
5353 pathDumpling string
54+ pathCDC string
55+ downstream string
56+
57+ downStreamHost string
58+ downStreamPort string
59+ downStreamUser string
60+ downStreamPassword string
61+ downStreamDB string
5462)
5563
5664func init () {
@@ -67,8 +75,10 @@ func init() {
6775 flag .IntVar (& retryConnCount , "retry-connection-count" , 120 , "The max number to retry to connect to the database." )
6876 flag .BoolVar (& checkErr , "check-error" , false , "if --error ERR does not match, return error instead of just warn" )
6977 flag .BoolVar (& collationDisable , "collation-disable" , false , "run collation related-test with new-collation disabled" )
70- flag .StringVar (& pathBR , "path-br" , "" , "Path of BR" )
71- flag .StringVar (& pathDumpling , "path-dumpling" , "" , "Path of Dumpling" )
78+ flag .StringVar (& pathBR , "path-br" , "" , "Path of BR binary" )
79+ flag .StringVar (& pathDumpling , "path-dumpling" , "" , "Path of Dumpling binary" )
80+ flag .StringVar (& pathCDC , "path-cdc" , "" , "Path of TiCDC binary" )
81+ flag .StringVar (& downstream , "downstream" , "" , "Connection string of downstream TiDB cluster" )
7282}
7383
7484const (
@@ -165,6 +175,12 @@ type tester struct {
165175
166176 // dump and import context through --dump_and_import $SOURCE_TABLE as $TARGET_TABLE'
167177 dumpAndImport * SourceAndTarget
178+
179+ // replication checkpoint database name
180+ replicationCheckpointDB string
181+
182+ // replication checkpoint ID
183+ replicationCheckpointID int
168184}
169185
170186func newTester (name string ) * tester {
@@ -179,6 +195,8 @@ func newTester(name string) *tester {
179195 t .enableConcurrent = false
180196 t .enableInfo = false
181197
198+ t .replicationCheckpointDB = "checkpoint-" + uuid .NewString ()
199+ t .replicationCheckpointID = 0
182200 return t
183201}
184202
@@ -219,7 +237,7 @@ func isTiDB(db *sql.DB) bool {
219237 return true
220238}
221239
222- func (t * tester ) addConnection (connName , hostName , userName , password , db string ) {
240+ func (t * tester ) addConnection (connName , hostName , port , userName , password , db string ) {
223241 var (
224242 mdb * sql.DB
225243 err error
@@ -285,6 +303,64 @@ func (t *tester) disconnect(connName string) {
285303 t .currConnName = default_connection
286304}
287305
306+ func parseUserInfo (userInfo string ) (string , string , error ) {
307+ colonIndex := strings .Index (userInfo , ":" )
308+ if colonIndex == - 1 {
309+ return "" , "" , fmt .Errorf ("missing password in userinfo" )
310+ }
311+ return userInfo [:colonIndex ], userInfo [colonIndex + 1 :], nil
312+ }
313+
314+ func parseHostPort (hostPort string ) (string , string , error ) {
315+ colonIndex := strings .Index (hostPort , ":" )
316+ if colonIndex == - 1 {
317+ return "" , "" , fmt .Errorf ("missing port in host:port" )
318+ }
319+ return hostPort [:colonIndex ], hostPort [colonIndex + 1 :], nil
320+ }
321+
322+ func parseDownstream (connStr string ) (dbname string , host string , port string , user string , password string ) {
323+ // Splitting into userinfo and network/database parts
324+ parts := strings .SplitN (connStr , "@" , 2 )
325+ if len (parts ) != 2 {
326+ fmt .Println ("Invalid connection string format" )
327+ return
328+ }
329+
330+ // Parsing userinfo
331+ userInfo := parts [0 ]
332+ user , password , err := parseUserInfo (userInfo )
333+ if err != nil {
334+ fmt .Println ("Error parsing userinfo:" , err )
335+ return
336+ }
337+
338+ // Splitting network type and database part
339+ networkAndDB := parts [1 ]
340+ networkTypeIndex := strings .Index (networkAndDB , "(" )
341+ if networkTypeIndex == - 1 {
342+ fmt .Println ("Invalid connection string format: missing network type" )
343+ return
344+ }
345+
346+ // Extracting host, port, and database name
347+ hostPortDB := networkAndDB [networkTypeIndex + 1 :]
348+ hostPortDBParts := strings .SplitN (hostPortDB , ")/" , 2 )
349+ if len (hostPortDBParts ) != 2 {
350+ fmt .Println ("Invalid connection string format" )
351+ return
352+ }
353+
354+ host , port , err = parseHostPort (hostPortDBParts [0 ])
355+ if err != nil {
356+ fmt .Println ("Error parsing host and port:" , err )
357+ return
358+ }
359+
360+ dbname = hostPortDBParts [1 ]
361+ return
362+ }
363+
288364func (t * tester ) preProcess () {
289365 dbName := "test"
290366 mdb , err := OpenDBWithRetry ("mysql" , user + ":" + passwd + "@tcp(" + host + ":" + port + ")/" + dbName + "?time_zone=%27Asia%2FShanghai%27&allowAllFiles=true" + params , retryConnCount )
@@ -303,6 +379,7 @@ func (t *tester) preProcess() {
303379 }
304380 for rows .Next () {
305381 rows .Scan (& dbName )
382+ fmt .Println ("Scanning database:" , dbName )
306383 t .originalSchemas [dbName ] = struct {}{}
307384 }
308385 }
@@ -313,13 +390,32 @@ func (t *tester) preProcess() {
313390 log .Fatalf ("Executing create db %s err[%v]" , dbName , err )
314391 }
315392 t .mdb = mdb
393+
316394 conn , err := initConn (mdb , user , passwd , host , dbName )
317395 if err != nil {
318396 log .Fatalf ("Open db err %v" , err )
319397 }
320398 t .conn [default_connection ] = conn
321399 t .curr = conn
322400 t .currConnName = default_connection
401+
402+ if downstream != "" {
403+ // create replication checkpoint database
404+ if _ , err := t .mdb .Exec (fmt .Sprintf ("create database if not exists `%s`" , t .replicationCheckpointDB )); err != nil {
405+ log .Fatalf ("Executing create db %s err[%v]" , t .replicationCheckpointDB , err )
406+ }
407+
408+ downStreamDB , downStreamHost , downStreamPort , downStreamUser , downStreamPassword = parseDownstream (downstream )
409+
410+ fmt .Println ("downStreamDB:" , downStreamDB )
411+ fmt .Println ("downStreamHost:" , downStreamHost )
412+ fmt .Println ("downStreamPort:" , downStreamPort )
413+ fmt .Println ("downStreamUser:" , downStreamUser )
414+ fmt .Println ("downStreamPassword:" , downStreamPassword )
415+
416+ t .addConnection ("downstream" , downStreamHost , downStreamPort , downStreamUser , downStreamPassword , downStreamDB )
417+ }
418+ t .switchConnection (default_connection )
323419}
324420
325421func (t * tester ) postProcess () {
@@ -329,6 +425,7 @@ func (t *tester) postProcess() {
329425 }
330426 t .mdb .Close ()
331427 }()
428+ t .switchConnection (default_connection )
332429 if ! reserveSchema {
333430 rows , err := t .mdb .Query ("show databases" )
334431 if err != nil {
@@ -339,6 +436,7 @@ func (t *tester) postProcess() {
339436 for rows .Next () {
340437 rows .Scan (& dbName )
341438 if _ , exists := t .originalSchemas [dbName ]; ! exists {
439+ fmt .Println ("Dropping database:" , dbName )
342440 _ , err := t .mdb .Exec (fmt .Sprintf ("drop database `%s`" , dbName ))
343441 if err != nil {
344442 log .Errorf ("failed to drop database: %s" , err .Error ())
@@ -421,6 +519,49 @@ func (t *tester) importTableStmt(path, target string) string {
421519 ` , target , path )
422520}
423521
522+ func (t * tester ) waitForReplicationCheckpoint () error {
523+ curr := t .currConnName
524+ defer t .switchConnection (curr )
525+
526+ if err := t .executeStmt (fmt .Sprintf ("use `%s`" , t .replicationCheckpointDB )); err != nil {
527+ return err
528+ }
529+
530+ markerTable := fmt .Sprintf ("marker_%d" , t .replicationCheckpointID )
531+ if err := t .executeStmt (fmt .Sprintf ("create table `%s`.`%s` (id int primary key)" , t .replicationCheckpointDB , markerTable )); err != nil {
532+ return err
533+ }
534+
535+ t .switchConnection ("downstream" )
536+
537+ checkInterval := 1 * time .Second
538+ queryTimeout := 10 * time .Second
539+
540+ // Keep querying until the table is found
541+ for {
542+ ctx , cancel := context .WithTimeout (context .Background (), queryTimeout )
543+ defer cancel ()
544+
545+ query := fmt .Sprintf ("select * from information_schema.tables where table_schema = '%s' and table_name = '%s';" , t .replicationCheckpointDB , markerTable )
546+ rows , err := t .mdb .QueryContext (ctx , query )
547+ if err != nil {
548+ log .Printf ("Error checking for table: %v" , err )
549+ return err
550+ }
551+
552+ if rows .Next () {
553+ fmt .Printf ("Table '%s' found!\n " , markerTable )
554+ break
555+ } else {
556+ fmt .Printf ("Table '%s' not found. Retrying in %v...\n " , markerTable , checkInterval )
557+ }
558+
559+ time .Sleep (checkInterval )
560+ }
561+
562+ return nil
563+ }
564+
424565func (t * tester ) Run () error {
425566 t .preProcess ()
426567 defer t .postProcess ()
@@ -543,7 +684,7 @@ func (t *tester) Run() error {
543684 for i := 0 ; i < 4 ; i ++ {
544685 args = append (args , "" )
545686 }
546- t .addConnection (args [0 ], args [1 ], args [2 ], args [3 ], args [4 ])
687+ t .addConnection (args [0 ], args [1 ], port , args [2 ], args [3 ], args [4 ])
547688 case Q_CONNECTION :
548689 q .Query = strings .TrimSpace (q .Query )
549690 if q .Query [len (q .Query )- 1 ] == ';' {
@@ -646,7 +787,10 @@ func (t *tester) Run() error {
646787 return err
647788 }
648789 log .WithFields (log.Fields {"stmt" : importStmt , "line" : q .Line }).Warn ("Restore end" )
649-
790+ case Q_REPLICATION_CHECKPOINT :
791+ if err := t .waitForReplicationCheckpoint (); err != nil {
792+ return err
793+ }
650794 default :
651795 log .WithFields (log.Fields {"command" : q .firstWord , "arguments" : q .Query , "line" : q .Line }).Warn ("command not implemented" )
652796 }
0 commit comments