@@ -72,10 +72,15 @@ const (
7272type query struct {
7373 firstWord string
7474 Query string
75+ File string
7576 Line int
7677 tp int
7778}
7879
80+ func (q * query ) location () string {
81+ return fmt .Sprintf ("%s:%d" , q .File , q .Line )
82+ }
83+
7984type Conn struct {
8085 // DB might be a shared one by multiple Conn, if the connection information are the same.
8186 mdb * sql.DB
@@ -325,7 +330,7 @@ func (t *tester) addSuccess(testSuite *XUnitTestSuite, startTime *time.Time, cnt
325330func (t * tester ) Run () error {
326331 t .preProcess ()
327332 defer t .postProcess ()
328- queries , err := t .loadQueries ()
333+ queries , err := t .loadQueries (t . testFileName () )
329334 if err != nil {
330335 err = errors .Trace (err )
331336 t .addFailure (& testSuite , & err , 0 )
@@ -338,17 +343,33 @@ func (t *tester) Run() error {
338343 return err
339344 }
340345
341- var s string
342346 defer func () {
343347 if t .resultFD != nil {
344348 t .resultFD .Close ()
345349 }
346350 }()
347351
348- testCnt := 0
349352 startTime := time .Now ()
353+ testCnt , err := t .runQueries (queries )
354+ if err != nil {
355+ return err
356+ }
357+
358+ fmt .Printf ("%s: ok! %d test cases passed, take time %v s\n " , t .testFileName (), testCnt , time .Since (startTime ).Seconds ())
359+
360+ if xmlPath != "" {
361+ t .addSuccess (& testSuite , & startTime , testCnt )
362+ }
363+
364+ return t .flushResult ()
365+ }
366+
367+ func (t * tester ) runQueries (queries []query ) (int , error ) {
368+ testCnt := 0
350369 var concurrentQueue []query
351370 var concurrentSize int
371+ var s string
372+ var err error
352373 for _ , q := range queries {
353374 s = q .Query
354375 switch q .tp {
@@ -379,15 +400,15 @@ func (t *tester) Run() error {
379400 if err != nil {
380401 err = errors .Annotate (err , "Atoi failed" )
381402 t .addFailure (& testSuite , & err , testCnt )
382- return err
403+ return testCnt , err
383404 }
384405 }
385406 case Q_END_CONCURRENT :
386407 t .enableConcurrent = false
387408 if err = t .concurrentRun (concurrentQueue , concurrentSize ); err != nil {
388409 err = errors .Annotate (err , fmt .Sprintf ("concurrent test failed in %v" , t .name ))
389410 t .addFailure (& testSuite , & err , testCnt )
390- return err
411+ return testCnt , err
391412 }
392413 t .expectedErrs = nil
393414 case Q_ERROR :
@@ -406,7 +427,7 @@ func (t *tester) Run() error {
406427 } else if err = t .execute (q ); err != nil {
407428 err = errors .Annotate (err , fmt .Sprintf ("sql:%v" , q .Query ))
408429 t .addFailure (& testSuite , & err , testCnt )
409- return err
430+ return testCnt , err
410431 }
411432
412433 testCnt ++
@@ -426,7 +447,7 @@ func (t *tester) Run() error {
426447 if err != nil {
427448 err = errors .Annotate (err , fmt .Sprintf ("Could not parse column in --replace_column: sql:%v" , q .Query ))
428449 t .addFailure (& testSuite , & err , testCnt )
429- return err
450+ return testCnt , err
430451 }
431452
432453 t .replaceColumn = append (t .replaceColumn , ReplaceColumn {col : colNr , replace : []byte (cols [i + 1 ])})
@@ -473,7 +494,7 @@ func (t *tester) Run() error {
473494 r , err := t .executeStmtString (s )
474495 if err != nil {
475496 log .WithFields (log.Fields {
476- "query" : s , "line" : q .Line },
497+ "query" : s , "line" : q .location () },
477498 ).Error ("failed to perform let query" )
478499 return ""
479500 }
@@ -484,27 +505,59 @@ func (t *tester) Run() error {
484505 case Q_REMOVE_FILE :
485506 err = os .Remove (strings .TrimSpace (q .Query ))
486507 if err != nil {
487- return errors .Annotate (err , "failed to remove file" )
508+ return testCnt , errors .Annotate (err , "failed to remove file" )
488509 }
489510 case Q_REPLACE_REGEX :
490511 t .replaceRegex = nil
491512 regex , err := ParseReplaceRegex (q .Query )
492513 if err != nil {
493- return errors .Annotate (err , fmt .Sprintf ("Could not parse regex in --replace_regex: line: %d sql:%v" , q .Line , q .Query ))
514+ return testCnt , errors .Annotate (
515+ err , fmt .Sprintf ("Could not parse regex in --replace_regex: line: %s sql:%v" ,
516+ q .location (), q .Query ))
494517 }
495518 t .replaceRegex = regex
496- default :
497- log .WithFields (log.Fields {"command" : q .firstWord , "arguments" : q .Query , "line" : q .Line }).Warn ("command not implemented" )
498- }
499- }
519+ case Q_SOURCE :
520+ fileName := strings .TrimSpace (q .Query )
521+ cwd , err := os .Getwd ()
522+ if err != nil {
523+ return testCnt , err
524+ }
500525
501- fmt .Printf ("%s: ok! %d test cases passed, take time %v s\n " , t .testFileName (), testCnt , time .Since (startTime ).Seconds ())
526+ // For security, don't allow to include files from other locations
527+ fullpath , err := filepath .Abs (fileName )
528+ if err != nil {
529+ return testCnt , err
530+ }
531+ if ! strings .HasPrefix (fullpath , cwd ) {
532+ return testCnt , errors .Errorf ("included file %s is not prefixed with %s" , fullpath , cwd )
533+ }
502534
503- if xmlPath != "" {
504- t .addSuccess (& testSuite , & startTime , testCnt )
505- }
535+ // Make sure we have a useful error message if the file can't be found or isn't a regular file
536+ s , err := os .Stat (fileName )
537+ if err != nil {
538+ return testCnt , errors .Annotate (err ,
539+ fmt .Sprintf ("file sourced with --source doesn't exist: line %s, file: %s" ,
540+ q .location (), fileName ))
541+ }
542+ if ! s .Mode ().IsRegular () {
543+ return testCnt , errors .Errorf ("file sourced with --source isn't a regular file: line %s, file: %s" ,
544+ q .location (), fileName )
545+ }
506546
507- return t .flushResult ()
547+ // Process the queries in the file
548+ includedQueries , err := t .loadQueries (fileName )
549+ if err != nil {
550+ return testCnt , errors .Annotate (err , fmt .Sprintf ("error loading queries from %s" , fileName ))
551+ }
552+ _ , err = t .runQueries (includedQueries )
553+ if err != nil {
554+ return testCnt , err
555+ }
556+ default :
557+ log .WithFields (log.Fields {"command" : q .firstWord , "arguments" : q .Query , "line" : q .location ()}).Warn ("command not implemented" )
558+ }
559+ }
560+ return testCnt , nil
508561}
509562
510563func (t * tester ) concurrentRun (concurrentQueue []query , concurrentSize int ) error {
@@ -606,8 +659,8 @@ func (t *tester) concurrentExecute(querys []query, wg *sync.WaitGroup, errOccure
606659 }
607660}
608661
609- func (t * tester ) loadQueries () ([]query , error ) {
610- data , err := os .ReadFile (t . testFileName () )
662+ func (t * tester ) loadQueries (fileName string ) ([]query , error ) {
663+ data , err := os .ReadFile (fileName )
611664 if err != nil {
612665 return nil , err
613666 }
@@ -623,18 +676,30 @@ func (t *tester) loadQueries() ([]query, error) {
623676 newStmt = true
624677 continue
625678 } else if strings .HasPrefix (s , "--" ) {
626- queries = append (queries , query {Query : s , Line : i + 1 })
679+ queries = append (queries , query {
680+ Query : s ,
681+ Line : i + 1 ,
682+ File : fileName ,
683+ })
627684 newStmt = true
628685 continue
629686 } else if len (s ) == 0 {
630687 continue
631688 }
632689
633690 if newStmt {
634- queries = append (queries , query {Query : s , Line : i + 1 })
691+ queries = append (queries , query {
692+ Query : s ,
693+ Line : i + 1 ,
694+ File : fileName ,
695+ })
635696 } else {
636697 lastQuery := queries [len (queries )- 1 ]
637- lastQuery = query {Query : fmt .Sprintf ("%s\n %s" , lastQuery .Query , s ), Line : lastQuery .Line }
698+ lastQuery = query {
699+ Query : fmt .Sprintf ("%s\n %s" , lastQuery .Query , s ),
700+ Line : lastQuery .Line ,
701+ File : fileName ,
702+ }
638703 queries [len (queries )- 1 ] = lastQuery
639704 }
640705
@@ -668,8 +733,8 @@ func (t *tester) checkExpectedError(q query, err error) error {
668733 }
669734 }
670735 if ! checkErr {
671- log .Warnf ("%s:%d query succeeded, but expected error(s)! (expected errors: %s) (query: %s)" ,
672- t . name , q . Line , strings .Join (t .expectedErrs , "," ), q .Query )
736+ log .Warnf ("%s query succeeded, but expected error(s)! (expected errors: %s) (query: %s)" ,
737+ q . location () , strings .Join (t .expectedErrs , "," ), q .Query )
673738 return nil
674739 }
675740 return errors .Errorf ("Statement succeeded, expected error(s) '%s'" , strings .Join (t .expectedErrs , "," ))
@@ -684,7 +749,7 @@ func (t *tester) checkExpectedError(q query, err error) error {
684749 errNo = int (innerErr .Number )
685750 }
686751 if errNo == 0 {
687- log .Warnf ("%s:%d Could not parse mysql error: %s" , t . name , q . Line , err .Error ())
752+ log .Warnf ("%s Could not parse mysql error: %s" , q . location () , err .Error ())
688753 return err
689754 }
690755 for _ , s := range t .expectedErrs {
@@ -696,9 +761,9 @@ func (t *tester) checkExpectedError(q query, err error) error {
696761 checkErrNo = i
697762 } else {
698763 if len (t .expectedErrs ) > 1 {
699- log .Warnf ("%s:%d Unknown named error %s in --error %s" , t . name , q . Line , s , strings .Join (t .expectedErrs , "," ))
764+ log .Warnf ("%s Unknown named error %s in --error %s" , q . location () , s , strings .Join (t .expectedErrs , "," ))
700765 } else {
701- log .Warnf ("%s:%d Unknown named --error %s" , t . name , q . Line , s )
766+ log .Warnf ("%s Unknown named --error %s" , q . location () , s )
702767 }
703768 continue
704769 }
@@ -726,11 +791,11 @@ func (t *tester) checkExpectedError(q query, err error) error {
726791 }
727792 }
728793 if len (t .expectedErrs ) > 1 {
729- log .Warnf ("%s:%d query failed with non expected error(s)! (%s not in %s) (err: %s) (query: %s)" ,
730- t . name , q . Line , gotErrCode , strings .Join (t .expectedErrs , "," ), err .Error (), q .Query )
794+ log .Warnf ("%s query failed with non expected error(s)! (%s not in %s) (err: %s) (query: %s)" ,
795+ q . location () , gotErrCode , strings .Join (t .expectedErrs , "," ), err .Error (), q .Query )
731796 } else {
732- log .Warnf ("%s:%d query failed with non expected error(s)! (%s != %s) (err: %s) (query: %s)" ,
733- t . name , q . Line , gotErrCode , t .expectedErrs [0 ], err .Error (), q .Query )
797+ log .Warnf ("%s query failed with non expected error(s)! (%s != %s) (err: %s) (query: %s)" ,
798+ q . location () , gotErrCode , t .expectedErrs [0 ], err .Error (), q .Query )
734799 }
735800 errStr := err .Error ()
736801 for _ , reg := range t .replaceRegex {
0 commit comments