From 9b38089b77928ec488b7d66a6b94f7dcc1bd7508 Mon Sep 17 00:00:00 2001 From: Lukas Jorg Date: Thu, 25 May 2023 07:45:50 +0200 Subject: [PATCH 1/5] Implemented periodic ping to keep connection of transaction alive --- expression_ext.go | 7 ++ main.go | 158 ++++++++++++++++++++++++++++++---------------- 2 files changed, 112 insertions(+), 53 deletions(-) diff --git a/expression_ext.go b/expression_ext.go index 9a0a140428..c11d892e82 100644 --- a/expression_ext.go +++ b/expression_ext.go @@ -601,6 +601,13 @@ func (db *DB) UpdateFields(fields ...string) *DB { return db.clone().Set("gorm:save_associations", false).Set("gorm:association_save_reference", false).Update(sets) } +// UpdateFieldsWithoutHooks updates the specified fields of the current model without calling any +// Update hooks and without touching the UpdatedAt column (if any exists). +// The specified fields have to be the names of the struct variables. +func (db *DB) UpdateFieldsWithoutHooks(fields ...string) *DB { + return db.clone().Set("gorm:update_column", true).UpdateFields(fields...) +} + func (db *DB) SelectFields(fields ...string) *DB { selects := strings.Join(fields, ", ") diff --git a/main.go b/main.go index d07b797837..388f2cfd35 100644 --- a/main.go +++ b/main.go @@ -32,15 +32,17 @@ type DB struct { // Open initialize a new db connection, need to import driver first, e.g: // -// import _ "github.com/go-sql-driver/mysql" -// func main() { -// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") -// } +// import _ "github.com/go-sql-driver/mysql" +// func main() { +// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") +// } +// // GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with -// import _ "github.com/jinzhu/gorm/dialects/mysql" -// // import _ "github.com/jinzhu/gorm/dialects/postgres" -// // import _ "github.com/jinzhu/gorm/dialects/sqlite" -// // import _ "github.com/jinzhu/gorm/dialects/mssql" +// +// import _ "github.com/jinzhu/gorm/dialects/mysql" +// // import _ "github.com/jinzhu/gorm/dialects/postgres" +// // import _ "github.com/jinzhu/gorm/dialects/sqlite" +// // import _ "github.com/jinzhu/gorm/dialects/mssql" func Open(dialect string, args ...interface{}) (db *DB, err error) { if len(args) == 0 { err = errors.New("invalid database source") @@ -121,7 +123,9 @@ func (s *DB) Dialect() Dialect { } // Callback return `Callbacks` container, you could add/change/delete callbacks with it -// db.Callback().Create().Register("update_created_at", updateCreated) +// +// db.Callback().Create().Register("update_created_at", updateCreated) +// // Refer https://jinzhu.github.io/gorm/development.html#callbacks func (s *DB) Callback() *Callback { s.parent.callbacks = s.parent.callbacks.clone() @@ -224,9 +228,10 @@ func (s *DB) Offset(offset interface{}) *DB { } // Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions -// db.Order("name DESC") -// db.Order("name DESC", true) // reorder -// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression +// +// db.Order("name DESC") +// db.Order("name DESC", true) // reorder +// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression func (s *DB) Order(value interface{}, reorder ...bool) *DB { return s.clone().search.Order(value, reorder...).db } @@ -253,23 +258,26 @@ func (s *DB) Having(query interface{}, values ...interface{}) *DB { } // Joins specify Joins conditions -// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +// +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) func (s *DB) Joins(query interface{}, args ...interface{}) *DB { return s.clone().search.Joins(query, args...).db } // Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically -// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { -// return db.Where("amount > ?", 1000) -// } // -// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { -// return func (db *gorm.DB) *gorm.DB { -// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) -// } -// } +// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { +// return db.Where("amount > ?", 1000) +// } +// +// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { +// return func (db *gorm.DB) *gorm.DB { +// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) +// } +// } +// +// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) // -// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) // Refer https://jinzhu.github.io/gorm/crud.html#scopes func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { for _, f := range funcs { @@ -356,8 +364,9 @@ func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { } // Pluck used to query single column from a model as a map -// var ages []int64 -// db.Find(&users).Pluck("age", &ages) +// +// var ages []int64 +// db.Find(&users).Pluck("age", &ages) func (s *DB) Pluck(column string, value interface{}) *DB { return s.NewScope(s.Value).pluck(column, value).db } @@ -454,7 +463,8 @@ func (s *DB) Delete(value interface{}, where ...interface{}) *DB { } // Raw use raw sql as conditions, won't run it unless invoked by other methods -// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) +// +// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) func (s *DB) Raw(sql string, values ...interface{}) *DB { return s.clone().search.Raw(true).Where(sql, values...).db } @@ -469,10 +479,11 @@ func (s *DB) Exec(sql string, values ...interface{}) *DB { } // Model specify the model you would like to run db operations -// // update all users's name to `hello` -// db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` -// db.Model(&user).Update("name", "hello") +// +// // update all users's name to `hello` +// db.Model(&User{}).Update("name", "hello") +// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` +// db.Model(&user).Update("name", "hello") func (s *DB) Model(value interface{}) *DB { c := s.clone() c.Value = value @@ -530,32 +541,70 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) { if _, ok := s.db.(*sql.Tx); ok { // Already in a transaction return f(s) - } else { - // Lets start a new transaction - tx := s.Begin() - if err = tx.Error; err != nil { - return + } + + // Lets start a new transaction + tx := s.Begin() + if err = tx.Error; err != nil { + return err + } + + // Create a channel to stop the ping goroutine. + stopTxPing := make(chan bool) + // Start a goroutine that pings the database connection for a keep-alive. + go func() { + for { + select { + // Stop the goroutine when the stop channel receives a value .. + case <-stopTxPing: + return + // .. otherwise ping the database connection every 10 seconds. + case <-time.After(10 * time.Second): + err := tx.DB().Ping() + if err != nil { + tx.AddError( + fmt.Errorf( + "Could not ping database connection for transaction: %w", + err, + ), + ) + + return + } + } } - panicked := true - defer func() { - if panicked || err != nil { - rollbackErr := tx.Rollback().Error - if rollbackErr != nil { - if err == nil { - err = rollbackErr - } else { - err = fmt.Errorf("Transacton code and rollback failed: %s; %s", err, rollbackErr) - } + }() + + panicked := true + + defer func() { + if panicked || err != nil { + rollbackErr := tx.Rollback().Error + if rollbackErr != nil { + if err == nil { + err = rollbackErr + } else { + err = fmt.Errorf("Transacton code and rollback failed: %s; %s", err, rollbackErr) } } - }() - err = f(tx) - if err == nil { - err = tx.Commit().Error } - panicked = false - return + }() + + err = f(tx) + + // As soon as the inner stack has returned, stop the ping goroutine. As the transaction will be + // only committed after this point, the ping would fail and the goroutine will exit. + stopTxPing <- true + // Last but not least, close the stop ping channel. + close(stopTxPing) + + if err == nil { + err = tx.Commit().Error } + + panicked = false + + return err } // SkipAssocSave disables saving of associations @@ -674,7 +723,8 @@ func (s *DB) RemoveIndex(indexName string) *DB { } // AddForeignKey Add foreign key to the given scope, e.g: -// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") +// +// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { scope := s.NewScope(s.Value) scope.addForeignKey(field, dest, onDelete, onUpdate) @@ -682,7 +732,8 @@ func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate } // RemoveForeignKey Remove foreign key from the given scope, e.g: -// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") +// +// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") func (s *DB) RemoveForeignKey(field string, dest string) *DB { scope := s.clone().NewScope(s.Value) scope.removeForeignKey(field, dest) @@ -712,7 +763,8 @@ func (s *DB) Association(column string) *Association { } // Preload preload associations with given conditions -// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) +// +// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) func (s *DB) Preload(column string, conditions ...interface{}) *DB { return s.clone().search.Preload(column, conditions...).db } From 806a552a9fa540a227ecaa471a511642006e4bf0 Mon Sep 17 00:00:00 2001 From: Lukas Jorg Date: Thu, 25 May 2023 09:28:55 +0200 Subject: [PATCH 2/5] wip --- main.go | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/main.go b/main.go index 388f2cfd35..046bb40688 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "database/sql" "errors" "fmt" @@ -551,6 +552,12 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) { // Create a channel to stop the ping goroutine. stopTxPing := make(chan bool) + // Get the database connection for the transaction. + txConn, err := tx.DB().Conn(context.Background()) + if err != nil { + return fmt.Errorf("Could not get database connection for transaction: %w", err) + } + // Start a goroutine that pings the database connection for a keep-alive. go func() { for { @@ -560,16 +567,18 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) { return // .. otherwise ping the database connection every 10 seconds. case <-time.After(10 * time.Second): - err := tx.DB().Ping() - if err != nil { - tx.AddError( - fmt.Errorf( - "Could not ping database connection for transaction: %w", - err, - ), - ) - - return + if txConn != nil { + err := txConn.PingContext(context.Background()) + if err != nil { + tx.AddError( + fmt.Errorf( + "Could not ping database connection for transaction: %w", + err, + ), + ) + + return + } } } } From d5fb9383dd529a36f47bd1bbd85498faca9547d2 Mon Sep 17 00:00:00 2001 From: Lukas Jorg Date: Thu, 25 May 2023 10:32:59 +0200 Subject: [PATCH 3/5] wip --- main.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/main.go b/main.go index 046bb40688..a4990d7ccb 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package gorm import ( - "context" "database/sql" "errors" "fmt" @@ -553,10 +552,15 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) { // Create a channel to stop the ping goroutine. stopTxPing := make(chan bool) // Get the database connection for the transaction. - txConn, err := tx.DB().Conn(context.Background()) - if err != nil { - return fmt.Errorf("Could not get database connection for transaction: %w", err) - } + // txConn, err := tx.DB().Conn(context.Background()) + // if err != nil { + // return fmt.Errorf("Could not get database connection for transaction: %w", err) + // } + + // tx.DB().Ping() + + // txBlub := tx.db.(*sql.Tx) + // txBlub. // Start a goroutine that pings the database connection for a keep-alive. go func() { @@ -567,8 +571,8 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) { return // .. otherwise ping the database connection every 10 seconds. case <-time.After(10 * time.Second): - if txConn != nil { - err := txConn.PingContext(context.Background()) + if s != nil && s.DB() != nil { + err := s.DB().Ping() if err != nil { tx.AddError( fmt.Errorf( From 4c674eb63315425473ccf17a1687843a71b3e146 Mon Sep 17 00:00:00 2001 From: Lukas Jorg Date: Thu, 25 May 2023 10:38:49 +0200 Subject: [PATCH 4/5] wip --- main.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/main.go b/main.go index a4990d7ccb..7e609bfd46 100644 --- a/main.go +++ b/main.go @@ -571,19 +571,19 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) { return // .. otherwise ping the database connection every 10 seconds. case <-time.After(10 * time.Second): - if s != nil && s.DB() != nil { - err := s.DB().Ping() - if err != nil { - tx.AddError( - fmt.Errorf( - "Could not ping database connection for transaction: %w", - err, - ), - ) - - return - } + // if s != nil && s.DB() != nil { + err := s.DB().Ping() + if err != nil { + tx.AddError( + fmt.Errorf( + "Could not ping database connection for transaction: %w", + err, + ), + ) + + return } + // } } } }() From f0159e00da57e03d93631d09d9bdd9e439b8e066 Mon Sep 17 00:00:00 2001 From: Lukas Jorg Date: Thu, 25 May 2023 11:06:45 +0200 Subject: [PATCH 5/5] wip copyright DH+LJ --- main.go | 44 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/main.go b/main.go index 7e609bfd46..8b7bc9f1c5 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package gorm import ( + "context" "database/sql" "errors" "fmt" @@ -516,6 +517,30 @@ func (s *DB) Begin() *DB { return c } +func (s *DB) BeginFancy() (*DB, *sql.Conn) { + var conn *sql.Conn + var err error + + c := s.clone() + + if db, ok := c.db.(sqlDb); ok && db != nil { + conn, err = c.DB().Conn(context.Background()) + if err != nil { + c.AddError(err) + + return c, nil + } + + tx, err := conn.BeginTx(context.Background(), nil) + c.db = interface{}(tx).(SQLCommon) + c.AddError(err) + } else { + c.AddError(ErrCantStartTransaction) + } + + return c, conn +} + // Commit commit a transaction func (s *DB) Commit() *DB { if db, ok := s.db.(sqlTx); ok && db != nil { @@ -543,8 +568,17 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) { return f(s) } + // sConn, err := s.DB().Conn(context.Background()) + // if err != nil { + // return fmt.Errorf("Could not get database connection: %w", err) + // } + + // defer sConn.Close() + + // s.db + // Lets start a new transaction - tx := s.Begin() + tx, conn := s.BeginFancy() if err = tx.Error; err != nil { return err } @@ -572,7 +606,7 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) { // .. otherwise ping the database connection every 10 seconds. case <-time.After(10 * time.Second): // if s != nil && s.DB() != nil { - err := s.DB().Ping() + err := conn.PingContext(context.Background()) if err != nil { tx.AddError( fmt.Errorf( @@ -597,10 +631,14 @@ func (s *DB) WrapInTx(f func(tx *DB) error) (err error) { if err == nil { err = rollbackErr } else { - err = fmt.Errorf("Transacton code and rollback failed: %s; %s", err, rollbackErr) + err = fmt.Errorf("Transaction code and rollback failed: %s; %s", err, rollbackErr) } } } + + if conn != nil { + err = conn.Close() + } }() err = f(tx)