diff --git a/go.mod b/go.mod index 7c5a04311..0fe0c49af 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/h2non/filetype v1.1.3 github.com/jedib0t/go-pretty/v6 v6.7.10 github.com/manifoldco/promptui v0.9.0 - github.com/mattn/go-sqlite3 v1.14.42 + github.com/mattn/go-sqlite3 v1.14.44 github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 github.com/prometheus/client_golang v1.23.2 github.com/rivo/tview v0.42.0 @@ -69,7 +69,7 @@ require ( github.com/go-openapi/swag/typeutils v0.26.0 // indirect github.com/go-openapi/swag/yamlutils v0.26.0 // indirect github.com/go-openapi/validate v0.25.2 // indirect - github.com/go-sql-driver/mysql v1.9.3 // indirect + github.com/go-sql-driver/mysql v1.10.0 // indirect github.com/go-viper/mapstructure/v2 v2.5.0 // indirect github.com/golang-jwt/jwt/v4 v4.5.2 // indirect github.com/google/go-querystring v1.2.0 // indirect @@ -77,9 +77,9 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/lucasb-eyer/go-colorful v1.4.0 // indirect - github.com/mattn/go-isatty v0.0.21 // indirect + github.com/mattn/go-isatty v0.0.22 // indirect github.com/mattn/go-runewidth v0.0.23 // indirect - github.com/minio/sio v0.4.3 // indirect + github.com/minio/sio v0.5.1 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/oklog/ulid/v2 v2.1.1 // indirect github.com/pkg/errors v0.9.1 // indirect diff --git a/go.sum b/go.sum index 46c6ea674..c45535bf4 100644 --- a/go.sum +++ b/go.sum @@ -86,8 +86,8 @@ github.com/go-openapi/testify/v2 v2.5.0 h1:UOCr63aAsMIDydZbZGqo5Ev01D4eydItRbekD github.com/go-openapi/testify/v2 v2.5.0/go.mod h1:SgsVHtfooshd0tublTtJ50FPKhujf47YRqauXXOUxfw= github.com/go-openapi/validate v0.25.2 h1:12NsfLAwGegqbGWr2CnvT65X/Q2USJipmJ9b7xDJZz0= github.com/go-openapi/validate v0.25.2/go.mod h1:Pgl1LpPPGFnZ+ys4/hTlDiRYQdI1ocKypgE+8Q8BLfY= -github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= -github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= +github.com/go-sql-driver/mysql v1.10.0 h1:Q+1LV8DkHJvSYAdR83XzuhDaTykuDx0l6fkXxoWCWfw= +github.com/go-sql-driver/mysql v1.10.0/go.mod h1:M+cqaI7+xxXGG9swrdeUIoPG3Y3KCkF0pZej+SK+nWk= github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro= github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= @@ -143,14 +143,14 @@ github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA= github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg= -github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs= -github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= +github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4= +github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw= github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/microsoft/go-mssqldb v1.7.2 h1:CHkFJiObW7ItKTJfHo1QX7QBBD1iV+mn1eOyRP3b/PA= github.com/microsoft/go-mssqldb v1.7.2/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA= -github.com/minio/sio v0.4.3 h1:JqyID1XM86KwBZox5RAdLD4MLPIDoCY2cke2CXCJCkg= -github.com/minio/sio v0.4.3/go.mod h1:4ANoe4CCXqnt1FCiLM0+vlBUhhWZzVOhYCz0069KtFc= +github.com/minio/sio v0.5.1 h1:sqtImrnCSHbDqO/lVy3tfbsctHzfelDv3NbXWEVcWT8= +github.com/minio/sio v0.5.1/go.mod h1:4ANoe4CCXqnt1FCiLM0+vlBUhhWZzVOhYCz0069KtFc= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 h1:4kuARK6Y6FxaNu/BnU2OAaLF86eTVhP2hjTB6iMvItA= diff --git a/vendor/github.com/go-sql-driver/mysql/AUTHORS b/vendor/github.com/go-sql-driver/mysql/AUTHORS index ec346e203..42c7f02c0 100644 --- a/vendor/github.com/go-sql-driver/mysql/AUTHORS +++ b/vendor/github.com/go-sql-driver/mysql/AUTHORS @@ -18,8 +18,8 @@ Alex Snast Alexey Palazhchenko Andrew Reid Animesh Ray -Arne Hormann Ariel Mashraki +Arne Hormann Artur Melanchyk Asta Xie B Lamarche @@ -38,6 +38,7 @@ Daniel Montoya Daniel Nichter Daniƫl van Eeden Dave Protasowski +Demouth Diego Dupin Dirkjan Bussink DisposaBoy @@ -66,6 +67,7 @@ Jeff Hodges Jeffrey Charles Jennifer Purevsuren Jerome Meyer +Jiabin Zhang Jiajia Zhong Jian Zhen Joe Mann @@ -85,10 +87,12 @@ Linh Tran Tuan Lion Yang Luca Looz Lucas Liu -Lunny Xiao Luke Scott +Lunny Xiao Maciej Zimnoch Michael Woolnough +Minh Quang +Morgan Tocker Nao Yokotsuka Nathanial Murphy Nicola Peduzzi @@ -99,7 +103,6 @@ Paul Bonser Paulius Lozys Peter Schultz Phil Porada -Minh Quang Rebecca Chin Reed Allman Richard Wilkes @@ -134,6 +137,7 @@ Ziheng Lyu # Organizations Barracuda Networks, Inc. +Block, Inc. Counting Ltd. Defined Networking Inc. DigitalOcean Inc. diff --git a/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md b/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md index 75674b603..b24af9bed 100644 --- a/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md +++ b/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md @@ -1,13 +1,26 @@ # Changelog +## v1.10.0 (2026-04-28) + +* Fix `getSystemVar("max_allowed_packet")` potentially returned wrong value. (#1754) + This affects only when `maxAllowedPacket=0` is set. + +* Bump filippo.io/edwards25519 from 1.1.1 to 1.2.0. (#1756) + While older versions have reported CVEs, they do not affect go-mysql. + +* Update Go versions to 1.24-1.26. (#1763) + +* Enhance interpolateParams to correctly handle placeholders. (#1732) + The question mark (?) within strings and comments will no longer be treated as a placeholder. + + ## v1.9.3 (2025-06-13) * `tx.Commit()` and `tx.Rollback()` returned `ErrInvalidConn` always. Now they return cached real error if present. (#1690) -* Optimize reading small resultsets to fix performance regression - introduced by compression protocol support. (#1707) - +* Optimize reading small result sets to fix a performance regression + introduced by compression protocol support. (`#1707`) * Fix `db.Ping()` on compressed connection. (#1723) diff --git a/vendor/github.com/go-sql-driver/mysql/README.md b/vendor/github.com/go-sql-driver/mysql/README.md index da4593ccf..3da0538c7 100644 --- a/vendor/github.com/go-sql-driver/mysql/README.md +++ b/vendor/github.com/go-sql-driver/mysql/README.md @@ -1,5 +1,8 @@ # Go-MySQL-Driver +[![DeepWiki](https://img.shields.io/badge/DeepWiki-go--sql--driver%2Fmysql-blue.svg?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACwAAAAyCAYAAAAnWDnqAAAAAXNSR0IArs4c6QAAA05JREFUaEPtmUtyEzEQhtWTQyQLHNak2AB7ZnyXZMEjXMGeK/AIi+QuHrMnbChYY7MIh8g01fJoopFb0uhhEqqcbWTp06/uv1saEDv4O3n3dV60RfP947Mm9/SQc0ICFQgzfc4CYZoTPAswgSJCCUJUnAAoRHOAUOcATwbmVLWdGoH//PB8mnKqScAhsD0kYP3j/Yt5LPQe2KvcXmGvRHcDnpxfL2zOYJ1mFwrryWTz0advv1Ut4CJgf5uhDuDj5eUcAUoahrdY/56ebRWeraTjMt/00Sh3UDtjgHtQNHwcRGOC98BJEAEymycmYcWwOprTgcB6VZ5JK5TAJ+fXGLBm3FDAmn6oPPjR4rKCAoJCal2eAiQp2x0vxTPB3ALO2CRkwmDy5WohzBDwSEFKRwPbknEggCPB/imwrycgxX2NzoMCHhPkDwqYMr9tRcP5qNrMZHkVnOjRMWwLCcr8ohBVb1OMjxLwGCvjTikrsBOiA6fNyCrm8V1rP93iVPpwaE+gO0SsWmPiXB+jikdf6SizrT5qKasx5j8ABbHpFTx+vFXp9EnYQmLx02h1QTTrl6eDqxLnGjporxl3NL3agEvXdT0WmEost648sQOYAeJS9Q7bfUVoMGnjo4AZdUMQku50McDcMWcBPvr0SzbTAFDfvJqwLzgxwATnCgnp4wDl6Aa+Ax283gghmj+vj7feE2KBBRMW3FzOpLOADl0Isb5587h/U4gGvkt5v60Z1VLG8BhYjbzRwyQZemwAd6cCR5/XFWLYZRIMpX39AR0tjaGGiGzLVyhse5C9RKC6ai42ppWPKiBagOvaYk8lO7DajerabOZP46Lby5wKjw1HCRx7p9sVMOWGzb/vA1hwiWc6jm3MvQDTogQkiqIhJV0nBQBTU+3okKCFDy9WwferkHjtxib7t3xIUQtHxnIwtx4mpg26/HfwVNVDb4oI9RHmx5WGelRVlrtiw43zboCLaxv46AZeB3IlTkwouebTr1y2NjSpHz68WNFjHvupy3q8TFn3Hos2IAk4Ju5dCo8B3wP7VPr/FGaKiG+T+v+TQqIrOqMTL1VdWV1DdmcbO8KXBz6esmYWYKPwDL5b5FA1a0hwapHiom0r/cKaoqr+27/XcrS5UwSMbQAAAABJRU5ErkJggg==)](https://deepwiki.com/go-sql-driver/mysql) + + A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) package ![Go-MySQL-Driver logo](https://raw.github.com/wiki/go-sql-driver/mysql/gomysql_m.png "Golang Gopher holding the MySQL Dolphin") @@ -42,8 +45,8 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac ## Requirements -* Go 1.21 or higher. We aim to support the 3 latest versions of Go. -* MySQL (5.7+) and MariaDB (10.5+) are supported. +* Go 1.24 or higher. We aim to support the 3 latest versions of Go. +* MySQL (5.7+) and MariaDB (10.5+) are supported by maintainers. * [TiDB](https://github.com/pingcap/tidb) is supported by PingCAP. * Do not ask questions about TiDB in our issue tracker or forum. * [Document](https://docs.pingcap.com/tidb/v6.1/dev-guide-sample-application-golang) diff --git a/vendor/github.com/go-sql-driver/mysql/auth.go b/vendor/github.com/go-sql-driver/mysql/auth.go index 74e1bd03e..610044fc1 100644 --- a/vendor/github.com/go-sql-driver/mysql/auth.go +++ b/vendor/github.com/go-sql-driver/mysql/auth.go @@ -305,7 +305,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { if !mc.cfg.AllowNativePasswords { return nil, ErrNativePassword } - // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html + // https://dev.mysql.com/doc/dev/mysql-server/8.4.5/page_protocol_connection_phase_authentication_methods_native_password_authentication.html // Native password authentication only need and will need 20-byte challenge. authResp := scramblePassword(authData[:20], mc.cfg.Passwd) return authResp, nil diff --git a/vendor/github.com/go-sql-driver/mysql/conncheck.go b/vendor/github.com/go-sql-driver/mysql/conncheck.go index 0ea721720..f9c5cb65c 100644 --- a/vendor/github.com/go-sql-driver/mysql/conncheck.go +++ b/vendor/github.com/go-sql-driver/mysql/conncheck.go @@ -7,7 +7,6 @@ // You can obtain one at http://mozilla.org/MPL/2.0/. //go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd || solaris || illumos -// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos package mysql diff --git a/vendor/github.com/go-sql-driver/mysql/conncheck_dummy.go b/vendor/github.com/go-sql-driver/mysql/conncheck_dummy.go index a56c138f2..0ebf05c21 100644 --- a/vendor/github.com/go-sql-driver/mysql/conncheck_dummy.go +++ b/vendor/github.com/go-sql-driver/mysql/conncheck_dummy.go @@ -7,7 +7,6 @@ // You can obtain one at http://mozilla.org/MPL/2.0/. //go:build !linux && !darwin && !dragonfly && !freebsd && !netbsd && !openbsd && !solaris && !illumos -// +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos package mysql diff --git a/vendor/github.com/go-sql-driver/mysql/connection.go b/vendor/github.com/go-sql-driver/mysql/connection.go index 3e455a3ff..65204e2d2 100644 --- a/vendor/github.com/go-sql-driver/mysql/connection.go +++ b/vendor/github.com/go-sql-driver/mysql/connection.go @@ -33,7 +33,8 @@ type mysqlConn struct { connector *connector maxAllowedPacket int maxWriteSize int - flags clientFlag + capabilities capabilityFlag + extCapabilities extendedCapabilityFlag status statusFlag sequence uint8 compressSequence uint8 @@ -171,7 +172,7 @@ func (mc *mysqlConn) close() { } // Closes the network connection and unsets internal variables. Do not call this -// function after successfully authentication, call Close instead. This function +// function after successful authentication, call Close instead. This function // is called before auth or on auth failure because MySQL will have already // closed the network connection. func (mc *mysqlConn) cleanup() { @@ -223,13 +224,21 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { columnCount, err := stmt.readPrepareResultPacket() if err == nil { if stmt.paramCount > 0 { - if err = mc.readUntilEOF(); err != nil { + if err = mc.skipColumns(stmt.paramCount); err != nil { return nil, err } } if columnCount > 0 { - err = mc.readUntilEOF() + if mc.extCapabilities&clientCacheMetadata != 0 { + if stmt.columns, err = mc.readColumns(int(columnCount), nil); err != nil { + return nil, err + } + } else { + if err = mc.skipColumns(int(columnCount)); err != nil { + return nil, err + } + } } } @@ -237,100 +246,184 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { } func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { - // Number of ? should be same to len(args) - if strings.Count(query, "?") != len(args) { - return "", driver.ErrSkip - } + noBackslashEscapes := (mc.status & statusNoBackslashEscapes) != 0 + const ( + stateNormal = iota + stateString + stateEscape + stateEOLComment + stateSlashStarComment + stateBacktick + ) + + const ( + QUOTE_BYTE = byte('\'') + DBL_QUOTE_BYTE = byte('"') + BACKSLASH_BYTE = byte('\\') + QUESTION_MARK_BYTE = byte('?') + SLASH_BYTE = byte('/') + STAR_BYTE = byte('*') + HASH_BYTE = byte('#') + MINUS_BYTE = byte('-') + LINE_FEED_BYTE = byte('\n') + BACKTICK_BYTE = byte('`') + ) buf, err := mc.buf.takeCompleteBuffer() if err != nil { - // can not take the buffer. Something must be wrong with the connection mc.cleanup() - // interpolateParams would be called before sending any query. - // So its safe to retry. return "", driver.ErrBadConn } buf = buf[:0] + state := stateNormal + singleQuotes := false + lastChar := byte(0) argPos := 0 - - for i := 0; i < len(query); i++ { - q := strings.IndexByte(query[i:], '?') - if q == -1 { - buf = append(buf, query[i:]...) - break - } - buf = append(buf, query[i:i+q]...) - i += q - - arg := args[argPos] - argPos++ - - if arg == nil { - buf = append(buf, "NULL"...) + lenQuery := len(query) + lastIdx := 0 + + for i := range lenQuery { + currentChar := query[i] + if state == stateEscape && !((currentChar == QUOTE_BYTE && singleQuotes) || (currentChar == DBL_QUOTE_BYTE && !singleQuotes)) { + state = stateString + lastChar = currentChar continue } - - switch v := arg.(type) { - case int64: - buf = strconv.AppendInt(buf, v, 10) - case uint64: - // Handle uint64 explicitly because our custom ConvertValue emits unsigned values - buf = strconv.AppendUint(buf, v, 10) - case float64: - buf = strconv.AppendFloat(buf, v, 'g', -1, 64) - case bool: - if v { - buf = append(buf, '1') - } else { - buf = append(buf, '0') + switch currentChar { + case STAR_BYTE: + if state == stateNormal && lastChar == SLASH_BYTE { + state = stateSlashStarComment } - case time.Time: - if v.IsZero() { - buf = append(buf, "'0000-00-00'"...) - } else { - buf = append(buf, '\'') - buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate) - if err != nil { - return "", err - } - buf = append(buf, '\'') + case SLASH_BYTE: + if state == stateSlashStarComment && lastChar == STAR_BYTE { + state = stateNormal + // Clear lastChar so the '/' that closed the comment isn't + // reused to start a new comment with a following '*'. + lastChar = 0 + continue } - case json.RawMessage: - buf = append(buf, '\'') - if mc.status&statusNoBackslashEscapes == 0 { - buf = escapeBytesBackslash(buf, v) - } else { - buf = escapeBytesQuotes(buf, v) + case HASH_BYTE: + if state == stateNormal { + state = stateEOLComment } - buf = append(buf, '\'') - case []byte: - if v == nil { - buf = append(buf, "NULL"...) - } else { - buf = append(buf, "_binary'"...) - if mc.status&statusNoBackslashEscapes == 0 { - buf = escapeBytesBackslash(buf, v) + case MINUS_BYTE: + if state == stateNormal && lastChar == MINUS_BYTE { + // -- only starts a comment if followed by whitespace or control char + if i+1 < lenQuery { + nextChar := query[i+1] + if nextChar == ' ' || nextChar == '\t' || nextChar == '\n' || nextChar == '\r' { + state = stateEOLComment + } } else { - buf = escapeBytesQuotes(buf, v) + state = stateEOLComment } - buf = append(buf, '\'') } - case string: - buf = append(buf, '\'') - if mc.status&statusNoBackslashEscapes == 0 { - buf = escapeStringBackslash(buf, v) - } else { - buf = escapeStringQuotes(buf, v) + case LINE_FEED_BYTE: + if state == stateEOLComment { + state = stateNormal } - buf = append(buf, '\'') - default: - return "", driver.ErrSkip - } + case DBL_QUOTE_BYTE: + if state == stateNormal { + state = stateString + singleQuotes = false + } else if state == stateString && !singleQuotes { + state = stateNormal + } else if state == stateEscape { + state = stateString + } + case QUOTE_BYTE: + if state == stateNormal { + state = stateString + singleQuotes = true + } else if state == stateString && singleQuotes { + state = stateNormal + } else if state == stateEscape { + state = stateString + } + case BACKSLASH_BYTE: + if state == stateString && !noBackslashEscapes { + state = stateEscape + } + case QUESTION_MARK_BYTE: + if state == stateNormal { + if argPos >= len(args) { + return "", driver.ErrSkip + } + buf = append(buf, query[lastIdx:i]...) + arg := args[argPos] + argPos++ + + if arg == nil { + buf = append(buf, "NULL"...) + lastIdx = i + 1 + break + } - if len(buf)+4 > mc.maxAllowedPacket { - return "", driver.ErrSkip + switch v := arg.(type) { + case int64: + buf = strconv.AppendInt(buf, v, 10) + case uint64: + buf = strconv.AppendUint(buf, v, 10) + case float64: + buf = strconv.AppendFloat(buf, v, 'g', -1, 64) + case bool: + if v { + buf = append(buf, '1') + } else { + buf = append(buf, '0') + } + case time.Time: + if v.IsZero() { + buf = append(buf, "'0000-00-00'"...) + } else { + buf = append(buf, '\'') + buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate) + if err != nil { + return "", err + } + buf = append(buf, '\'') + } + case json.RawMessage: + if noBackslashEscapes { + buf = escapeBytesQuotes(buf, v, false) + } else { + buf = escapeBytesBackslash(buf, v, false) + } + case []byte: + if v == nil { + buf = append(buf, "NULL"...) + } else { + if noBackslashEscapes { + buf = escapeBytesQuotes(buf, v, true) + } else { + buf = escapeBytesBackslash(buf, v, true) + } + } + case string: + if noBackslashEscapes { + buf = escapeStringQuotes(buf, v) + } else { + buf = escapeStringBackslash(buf, v) + } + default: + return "", driver.ErrSkip + } + + if len(buf)+4 > mc.maxAllowedPacket { + return "", driver.ErrSkip + } + lastIdx = i + 1 + } + case BACKTICK_BYTE: + if state == stateBacktick { + state = stateNormal + } else if state == stateNormal { + state = stateBacktick + } } + lastChar = currentChar } + buf = append(buf, query[lastIdx:]...) if argPos != len(args) { return "", driver.ErrSkip } @@ -370,19 +463,19 @@ func (mc *mysqlConn) exec(query string) error { } // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, _, err := handleOk.readResultSetHeaderPacket() if err != nil { return err } if resLen > 0 { // columns - if err := mc.readUntilEOF(); err != nil { + if err := mc.skipColumns(resLen); err != nil { return err } // rows - if err := mc.readUntilEOF(); err != nil { + if err := mc.skipRows(); err != nil { return err } } @@ -419,7 +512,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) // Read Result var resLen int - resLen, err = handleOk.readResultSetHeaderPacket() + resLen, _, err = handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } @@ -439,21 +532,20 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) } // Columns - rows.rs.columns, err = mc.readColumns(resLen) + rows.rs.columns, err = mc.readColumns(resLen, nil) return rows, err } // Gets the value of the given MySQL System Variable -// The returned byte slice is only valid until the next read -func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { +func (mc *mysqlConn) getSystemVar(name string) (string, error) { // Send command handleOk := mc.clearResult() if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { - return nil, err + return "", err } // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, _, err := handleOk.readResultSetHeaderPacket() if err == nil { rows := new(textRows) rows.mc = mc @@ -461,17 +553,20 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { if resLen > 0 { // Columns - if err := mc.readUntilEOF(); err != nil { - return nil, err + if err := mc.skipColumns(resLen); err != nil { + return "", err } } dest := make([]driver.Value, resLen) if err = rows.readRow(dest); err == nil { - return dest[0].([]byte), mc.readUntilEOF() + // Convert to string before skipRows, which may + // overwrite the read buffer that dest[0] points into. + val := string(dest[0].([]byte)) + return val, mc.skipRows() } } - return nil, err + return "", err } // cancel is called when the query has canceled. diff --git a/vendor/github.com/go-sql-driver/mysql/connector.go b/vendor/github.com/go-sql-driver/mysql/connector.go index bc1d46afc..3d3760477 100644 --- a/vendor/github.com/go-sql-driver/mysql/connector.go +++ b/vendor/github.com/go-sql-driver/mysql/connector.go @@ -42,7 +42,7 @@ func encodeConnectionAttributes(cfg *Config) string { } // user-defined connection attributes - for _, connAttr := range strings.Split(cfg.ConnectionAttributes, ",") { + for connAttr := range strings.SplitSeq(cfg.ConnectionAttributes, ",") { k, v, found := strings.Cut(connAttr, ":") if !found { continue @@ -131,7 +131,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.buf = newBuffer() // Reading Handshake Initialization Packet - authData, plugin, err := mc.readHandshakePacket() + authData, serverCapabilities, serverExtCapabilities, plugin, err := mc.readHandshakePacket() if err != nil { mc.cleanup() return nil, err @@ -153,6 +153,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } } + mc.initCapabilities(serverCapabilities, serverExtCapabilities, mc.cfg) if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { mc.cleanup() return nil, err @@ -161,13 +162,14 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { // Handle response to auth packet, switch methods if possible if err = mc.handleAuthResult(authData, plugin); err != nil { // Authentication failed and MySQL has already closed the connection - // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). + // (https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html#sect_protocol_connection_phase_fast_path_fails). // Do not send COM_QUIT, just cleanup and return the error. mc.cleanup() return nil, err } - if mc.cfg.compress && mc.flags&clientCompress == clientCompress { + // compression is enabled after auth, not right after sending handshake response. + if mc.capabilities&clientCompress > 0 { mc.compress = true mc.compIO = newCompIO(mc) } @@ -180,7 +182,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.Close() return nil, err } - n, err := strconv.Atoi(string(maxap)) + n, err := strconv.Atoi(maxap) if err != nil { mc.Close() return nil, fmt.Errorf("invalid max_allowed_packet value (%q): %w", maxap, err) diff --git a/vendor/github.com/go-sql-driver/mysql/const.go b/vendor/github.com/go-sql-driver/mysql/const.go index 4aadcd642..6f0cdf303 100644 --- a/vendor/github.com/go-sql-driver/mysql/const.go +++ b/vendor/github.com/go-sql-driver/mysql/const.go @@ -32,7 +32,7 @@ const ( ) // MySQL constants documentation: -// http://dev.mysql.com/doc/internals/en/client-server-protocol.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/PAGE_PROTOCOL.html const ( iOK byte = 0x00 @@ -42,11 +42,12 @@ const ( iERR byte = 0xff ) -// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags -type clientFlag uint32 +// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html +// https://mariadb.com/kb/en/connection/#capabilities +type capabilityFlag uint32 const ( - clientLongPassword clientFlag = 1 << iota + clientMySQL capabilityFlag = 1 << iota clientFoundRows clientLongFlag clientConnectWithDB @@ -73,6 +74,18 @@ const ( clientDeprecateEOF ) +// https://mariadb.com/kb/en/connection/#capabilities +type extendedCapabilityFlag uint32 + +const ( + progressIndicator extendedCapabilityFlag = 1 << iota + clientComMulti + clientStmtBulkOperations + clientExtendedMetadata + clientCacheMetadata + clientUnitBulkResult +) + const ( comQuit byte = iota + 1 comInitDB diff --git a/vendor/github.com/go-sql-driver/mysql/dsn.go b/vendor/github.com/go-sql-driver/mysql/dsn.go index ecf62567a..491e10f37 100644 --- a/vendor/github.com/go-sql-driver/mysql/dsn.go +++ b/vendor/github.com/go-sql-driver/mysql/dsn.go @@ -15,6 +15,7 @@ import ( "crypto/tls" "errors" "fmt" + "maps" "math/big" "net" "net/url" @@ -157,9 +158,7 @@ func (cfg *Config) Clone() *Config { } if len(cp.Params) > 0 { cp.Params = make(map[string]string, len(cfg.Params)) - for k, v := range cfg.Params { - cp.Params[k] = v - } + maps.Copy(cp.Params, cfg.Params) } if cfg.pubKey != nil { cp.pubKey = &rsa.PublicKey{ @@ -414,7 +413,7 @@ func ParseDSN(dsn string) (cfg *Config, err error) { if dsn[j] == '@' { // username[:password] // Find the first ':' in dsn[:j] - for k = 0; k < j; k++ { + for k = 0; k < j; k++ { // We cannot use k = range j here, because we use dsn[:k] below if dsn[k] == ':' { cfg.Passwd = dsn[k+1 : j] break @@ -477,7 +476,7 @@ func ParseDSN(dsn string) (cfg *Config, err error) { // parseDSNParams parses the DSN "query string" // Values must be url.QueryEscape'ed func parseDSNParams(cfg *Config, params string) (err error) { - for _, v := range strings.Split(params, "&") { + for v := range strings.SplitSeq(params, "&") { key, value, found := strings.Cut(v, "=") if !found { continue diff --git a/vendor/github.com/go-sql-driver/mysql/fields.go b/vendor/github.com/go-sql-driver/mysql/fields.go index be5cd809a..ee9d96417 100644 --- a/vendor/github.com/go-sql-driver/mysql/fields.go +++ b/vendor/github.com/go-sql-driver/mysql/fields.go @@ -120,23 +120,24 @@ func (mf *mysqlField) typeDatabaseName() string { } var ( - scanTypeFloat32 = reflect.TypeOf(float32(0)) - scanTypeFloat64 = reflect.TypeOf(float64(0)) - scanTypeInt8 = reflect.TypeOf(int8(0)) - scanTypeInt16 = reflect.TypeOf(int16(0)) - scanTypeInt32 = reflect.TypeOf(int32(0)) - scanTypeInt64 = reflect.TypeOf(int64(0)) - scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) - scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) - scanTypeNullTime = reflect.TypeOf(sql.NullTime{}) - scanTypeUint8 = reflect.TypeOf(uint8(0)) - scanTypeUint16 = reflect.TypeOf(uint16(0)) - scanTypeUint32 = reflect.TypeOf(uint32(0)) - scanTypeUint64 = reflect.TypeOf(uint64(0)) - scanTypeString = reflect.TypeOf("") - scanTypeNullString = reflect.TypeOf(sql.NullString{}) - scanTypeBytes = reflect.TypeOf([]byte{}) - scanTypeUnknown = reflect.TypeOf(new(any)) + scanTypeFloat32 = reflect.TypeFor[float32]() + scanTypeFloat64 = reflect.TypeFor[float64]() + scanTypeInt8 = reflect.TypeFor[int8]() + scanTypeInt16 = reflect.TypeFor[int16]() + scanTypeInt32 = reflect.TypeFor[int32]() + scanTypeInt64 = reflect.TypeFor[int64]() + scanTypeNullFloat = reflect.TypeFor[sql.NullFloat64]() + scanTypeNullInt = reflect.TypeFor[sql.NullInt64]() + scanTypeNullUint = reflect.TypeFor[sql.Null[uint64]]() + scanTypeNullTime = reflect.TypeFor[sql.NullTime]() + scanTypeUint8 = reflect.TypeFor[uint8]() + scanTypeUint16 = reflect.TypeFor[uint16]() + scanTypeUint32 = reflect.TypeFor[uint32]() + scanTypeUint64 = reflect.TypeFor[uint64]() + scanTypeString = reflect.TypeFor[string]() + scanTypeNullString = reflect.TypeFor[sql.NullString]() + scanTypeBytes = reflect.TypeFor[[]byte]() + scanTypeUnknown = reflect.TypeFor[*any]() ) type mysqlField struct { @@ -185,6 +186,9 @@ func (mf *mysqlField) scanType() reflect.Type { } return scanTypeInt64 } + if mf.flags&flagUnsigned != 0 { + return scanTypeNullUint + } return scanTypeNullInt case fieldTypeFloat: diff --git a/vendor/github.com/go-sql-driver/mysql/infile.go b/vendor/github.com/go-sql-driver/mysql/infile.go index 453ae091e..597b5e7f6 100644 --- a/vendor/github.com/go-sql-driver/mysql/infile.go +++ b/vendor/github.com/go-sql-driver/mysql/infile.go @@ -95,10 +95,7 @@ const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead a func (mc *okHandler) handleInFileRequest(name string) (err error) { var rdr io.Reader - packetSize := defaultPacketSize - if mc.maxWriteSize < packetSize { - packetSize = mc.maxWriteSize - } + packetSize := min(mc.maxWriteSize, defaultPacketSize) if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader // The server might return an an absolute path. See issue #355. diff --git a/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/go-sql-driver/mysql/packets.go index 831fca6ca..d0b21b06c 100644 --- a/vendor/github.com/go-sql-driver/mysql/packets.go +++ b/vendor/github.com/go-sql-driver/mysql/packets.go @@ -179,20 +179,22 @@ func (mc *mysqlConn) writePacket(data []byte) error { ******************************************************************************/ // Handshake Initialization Packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html +// https://mariadb.com/kb/en/connection/#initial-handshake-packet +func (mc *mysqlConn) readHandshakePacket() (data []byte, capabilities capabilityFlag, extendedCapabilities extendedCapabilityFlag, plugin string, err error) { data, err = mc.readPacket() if err != nil { return } if data[0] == iERR { - return nil, "", mc.handleErrorPacket(data) + err = mc.handleErrorPacket(data) + return } // protocol version [1 byte] if data[0] < minProtocolVersion { - return nil, "", fmt.Errorf( + return nil, 0, 0, "", fmt.Errorf( "unsupported protocol version %d. Version %d or higher is required", data[0], minProtocolVersion, @@ -210,15 +212,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro pos += 8 + 1 // capability flags (lower 2 bytes) [2 bytes] - mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) - if mc.flags&clientProtocol41 == 0 { - return nil, "", ErrOldProtocol + capabilities = capabilityFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) + if capabilities&clientProtocol41 == 0 { + return nil, capabilities, 0, "", ErrOldProtocol } - if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil { + if capabilities&clientSSL == 0 && mc.cfg.TLS != nil { if mc.cfg.AllowFallbackToPlaintext { mc.cfg.TLS = nil } else { - return nil, "", ErrNoTLS + return nil, capabilities, 0, "", ErrNoTLS } } pos += 2 @@ -228,11 +230,16 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro // status flags [2 bytes] pos += 3 // capability flags (upper 2 bytes) [2 bytes] - mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 + capabilities |= capabilityFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 pos += 2 // length of auth-plugin-data [1 byte] - // reserved (all [00]) [10 bytes] - pos += 11 + // reserved (all [00]) [6 bytes] + pos += 7 + if capabilities&clientMySQL == 0 { + // MariaDB server extended flag + extendedCapabilities = extendedCapabilityFlag(binary.LittleEndian.Uint32(data[pos : pos+4])) + } + pos += 4 // second part of the password cipher [minimum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) @@ -260,82 +267,72 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro // make a memory safe copy of the cipher slice var b [20]byte copy(b[:], authData) - return b[:], plugin, nil + return b[:], capabilities, extendedCapabilities, plugin, nil } // make a memory safe copy of the cipher slice var b [8]byte copy(b[:], authData) - return b[:], plugin, nil + return b[:], capabilities, 0, plugin, nil } -// Client Authentication Packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { - // Adjust client flags based on server support - clientFlags := clientProtocol41 | - clientSecureConn | - clientLongPassword | - clientTransactions | - clientLocalFiles | - clientPluginAuth | - clientMultiResults | - mc.flags&clientConnectAttrs | - mc.flags&clientLongFlag - - sendConnectAttrs := mc.flags&clientConnectAttrs != 0 - - if mc.cfg.ClientFoundRows { - clientFlags |= clientFoundRows +// initCapabilities initializes the capabilities based on server support and configuration +func (mc *mysqlConn) initCapabilities(serverCapabilities capabilityFlag, serverExtCapabilities extendedCapabilityFlag, cfg *Config) { + clientCapabilities := + clientMySQL | + clientLongFlag | + clientProtocol41 | + clientSecureConn | + clientTransactions | + clientPluginAuthLenEncClientData | + clientLocalFiles | + clientPluginAuth | + clientMultiResults | + clientConnectAttrs | + clientDeprecateEOF + + if cfg.ClientFoundRows { + clientCapabilities |= clientFoundRows } - if mc.cfg.compress && mc.flags&clientCompress == clientCompress { - clientFlags |= clientCompress + if cfg.compress { + clientCapabilities |= clientCompress } // To enable TLS / SSL if mc.cfg.TLS != nil { - clientFlags |= clientSSL + clientCapabilities |= clientSSL } if mc.cfg.MultiStatements { - clientFlags |= clientMultiStatements + clientCapabilities |= clientMultiStatements } - - // encode length of the auth plugin data - var authRespLEIBuf [9]byte - authRespLen := len(authResp) - authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen)) - if len(authRespLEI) > 1 { - // if the length can not be written in 1 byte, it must be written as a - // length encoded integer - clientFlags |= clientPluginAuthLenEncClientData + if n := len(cfg.DBName); n > 0 { + clientCapabilities |= clientConnectWithDB } - pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 + // only keep client capabilities that server have + mc.capabilities = clientCapabilities & serverCapabilities - // To specify a db name - if n := len(mc.cfg.DBName); n > 0 { - clientFlags |= clientConnectWithDB - pktLen += n + 1 - } - - // encode length of the connection attributes - var connAttrsLEI []byte - if sendConnectAttrs { - var connAttrsLEIBuf [9]byte - connAttrsLen := len(mc.connector.encodedAttributes) - connAttrsLEI = appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen)) - pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes) - } + // set MariaDB extended clientCacheMetadata capability if server support it + mc.extCapabilities = clientCacheMetadata & serverExtCapabilities +} - // Calculate packet length and get buffer with that size - data, err := mc.buf.takeBuffer(pktLen + 4) +// Client Authentication Packet +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_response.html +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { + // packet header 4 + // capabilities 4 + // maxPacketSize 4 + // collation id 1 + // filler 23 + data, err := mc.buf.takeSmallBuffer(4*3 + 24) if err != nil { mc.cleanup() return err } + _ = data[4*3+23] // boundery check - // ClientFlags [32 bit] - binary.LittleEndian.PutUint32(data[4:], uint32(clientFlags)) + // clientCapabilities [32 bit] + binary.LittleEndian.PutUint32(data[4:], uint32(mc.capabilities)) // MaxPacketSize [32 bit] (none) binary.LittleEndian.PutUint32(data[8:], 0) @@ -353,16 +350,26 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string } // Filler [23 bytes] (all 0x00) + // or filler 19bytes + mariadb extCapabilities pos := 13 - for ; pos < 13+23; pos++ { - data[pos] = 0 + if mc.capabilities&clientMySQL == 0 { + for ; pos < 13+19; pos++ { + data[pos] = 0 + } + // MariaDB Extended Capabilities + binary.LittleEndian.PutUint32(data[13+19:], uint32(mc.extCapabilities)) + } else { + for ; pos < 13+23; pos++ { + data[pos] = 0 + } } // SSL Connection Request Packet - // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_ssl_request.html + // https://mariadb.com/kb/en/connection/#sslrequest-packet if mc.cfg.TLS != nil { // Send TLS / SSL request packet - if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { + if err := mc.writePacket(data); err != nil { return err } @@ -379,37 +386,35 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string // User [null terminated string] if len(mc.cfg.User) > 0 { - pos += copy(data[pos:], mc.cfg.User) + data = append(data, mc.cfg.User...) } - data[pos] = 0x00 - pos++ + data = append(data, 0) // Auth Data [length encoded integer] - pos += copy(data[pos:], authRespLEI) - pos += copy(data[pos:], authResp) + data = appendLengthEncodedInteger(data, uint64(len(authResp))) + data = append(data, authResp...) - // Databasename [null terminated string] - if len(mc.cfg.DBName) > 0 { - pos += copy(data[pos:], mc.cfg.DBName) - data[pos] = 0x00 - pos++ + // Database name [null terminated string] + if mc.capabilities&clientConnectWithDB != 0 { + data = append(data, mc.cfg.DBName...) + data = append(data, 0) } - pos += copy(data[pos:], plugin) - data[pos] = 0x00 - pos++ + data = append(data, plugin...) + data = append(data, 0) // Connection Attributes - if sendConnectAttrs { - pos += copy(data[pos:], connAttrsLEI) - pos += copy(data[pos:], []byte(mc.connector.encodedAttributes)) + if mc.capabilities&clientConnectAttrs != 0 { + connAttrsLen := len(mc.connector.encodedAttributes) + data = appendLengthEncodedInteger(data, uint64(connAttrsLen)) + data = append(data, mc.connector.encodedAttributes...) } // Send Auth packet - return mc.writePacket(data[:pos]) + return mc.writePacket(data) } -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_response.html func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { pktLen := 4 + len(authData) data, err := mc.buf.takeBuffer(pktLen) @@ -511,7 +516,7 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { case iEOF: if len(data) == 1 { - // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_old_auth_switch_request.html return nil, "mysql_old_password", nil } pluginEndIndex := bytes.IndexByte(data, 0x00) @@ -545,36 +550,41 @@ func (mc *okHandler) readResultOK() error { // Result Set Header Packet // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html -func (mc *okHandler) readResultSetHeaderPacket() (int, error) { +func (mc *okHandler) readResultSetHeaderPacket() (int, bool, error) { // handleOkPacket replaces both values; other cases leave the values unchanged. mc.result.affectedRows = append(mc.result.affectedRows, 0) mc.result.insertIds = append(mc.result.insertIds, 0) data, err := mc.conn().readPacket() if err != nil { - return 0, err + return 0, false, err } switch data[0] { case iOK: - return 0, mc.handleOkPacket(data) + return 0, false, mc.handleOkPacket(data) case iERR: - return 0, mc.conn().handleErrorPacket(data) + return 0, false, mc.conn().handleErrorPacket(data) case iLocalInFile: - return 0, mc.handleInFileRequest(string(data[1:])) + return 0, false, mc.handleInFileRequest(string(data[1:])) } // column count // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html - num, _, _ := readLengthEncodedInteger(data) + // https://mariadb.com/kb/en/result-set-packets/#column-count-packet + num, _, len := readLengthEncodedInteger(data) + + if mc.extCapabilities&clientCacheMetadata != 0 { + return int(num), data[len] == 0x01, nil + } // ignore remaining data in the packet. see #1478. - return int(num), nil + return int(num), true, nil } // Error Packet -// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_err_packet.html func (mc *mysqlConn) handleErrorPacket(data []byte) error { if data[0] != iERR { return ErrMalformPkt @@ -656,7 +666,7 @@ func (mc *mysqlConn) clearResult() *okHandler { } // Ok Packet -// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_ok_packet.html func (mc *okHandler) handleOkPacket(data []byte) error { var n, m int var affectedRows, insertId uint64 @@ -690,24 +700,19 @@ func (mc *okHandler) handleOkPacket(data []byte) error { } // Read Packets as Field Packets until EOF-Packet or an Error appears -// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 -func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset_column_definition.html#sect_protocol_com_query_response_text_resultset_column_definition_41 +func (mc *mysqlConn) readColumns(count int, old []mysqlField) ([]mysqlField, error) { columns := make([]mysqlField, count) + if len(old) != count { + old = nil + } - for i := 0; ; i++ { + for i := range count { data, err := mc.readPacket() if err != nil { return nil, err } - // EOF Packet - if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { - if i == count { - return columns, nil - } - return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns)) - } - // Catalog pos, err := skipLengthEncodedString(data) if err != nil { @@ -728,7 +733,12 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { return nil, err } pos += n - columns[i].tableName = string(tableName) + if old != nil && old[i].tableName == string(tableName) { + // avoid allocating new string + columns[i].tableName = old[i].tableName + } else { + columns[i].tableName = string(tableName) + } } else { n, err = skipLengthEncodedString(data[pos:]) if err != nil { @@ -749,7 +759,12 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { if err != nil { return nil, err } - columns[i].name = string(name) + if old != nil && old[i].name == string(name) { + // avoid allocating new string + columns[i].name = old[i].name + } else { + columns[i].name = string(name) + } pos += n // Original name [len coded string] @@ -780,17 +795,17 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { // Decimals [uint8] columns[i].decimals = data[pos] - //pos++ + } - // Default value [len coded binary] - //if pos < len(data) { - // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) - //} + // skip EOF packet if client does not support deprecateEOF + if err := mc.skipEof(); err != nil { + return nil, err } + return columns, nil } // Read Packets as Field Packets until EOF-Packet or an Error appears -// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset_row.html func (rows *textRows) readRow(dest []driver.Value) error { mc := rows.mc @@ -804,9 +819,20 @@ func (rows *textRows) readRow(dest []driver.Value) error { } // EOF Packet - if data[0] == iEOF && len(data) == 5 { - // server_status [2 bytes] - rows.mc.status = readStatus(data[3:]) + // text row packets may starts with LengthEncodedString. + // In such case, 0xFE can mean string larger than 0xffffff. + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_dt_integers.html#sect_protocol_basic_dt_int_le + if data[0] == iEOF && len(data) <= 0xffffff { + if mc.capabilities&clientDeprecateEOF == 0 { + // Deprecated EOF packet + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_eof_packet.html + mc.status = readStatus(data[3:]) + } else { + // Ok Packet with an 0xFE header + _, _, n := readLengthEncodedInteger(data[1:]) // affected_rows + _, _, m := readLengthEncodedInteger(data[1+n:]) // last_insert_id + mc.status = readStatus(data[1+n+m:]) + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil @@ -880,8 +906,34 @@ func (rows *textRows) readRow(dest []driver.Value) error { return nil } -// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read -func (mc *mysqlConn) readUntilEOF() error { +func (mc *mysqlConn) skipPackets(n int) error { + for range n { + if _, err := mc.readPacket(); err != nil { + return err + } + } + return nil +} + +// skips EOF packet after n * ColumnDefinition packets when clientDeprecateEOF is not set +func (mc *mysqlConn) skipEof() error { + if mc.capabilities&clientDeprecateEOF == 0 { + if _, err := mc.readPacket(); err != nil { + return err + } + } + return nil +} + +func (mc *mysqlConn) skipColumns(n int) error { + if err := mc.skipPackets(n); err != nil { + return err + } + return mc.skipEof() +} + +// Reads Packets until EOF-Packet or an Error appears. +func (mc *mysqlConn) skipRows() error { for { data, err := mc.readPacket() if err != nil { @@ -892,10 +944,20 @@ func (mc *mysqlConn) readUntilEOF() error { case iERR: return mc.handleErrorPacket(data) case iEOF: - if len(data) == 5 { - mc.status = readStatus(data[3:]) + // text row packets may starts with LengthEncodedString. + // In such case, 0xFE can mean string larger than 0xffffff. + if len(data) <= 0xffffff { + if mc.capabilities&clientDeprecateEOF == 0 { + // EOF packet + mc.status = readStatus(data[3:]) + } else { + // OK packet with an 0xFE header + _, _, n := readLengthEncodedInteger(data[1:]) // affected_rows + _, _, m := readLengthEncodedInteger(data[1+n:]) // last_insert_id + mc.status = readStatus(data[1+n+m:]) + } + return nil } - return nil } } } @@ -905,7 +967,7 @@ func (mc *mysqlConn) readUntilEOF() error { ******************************************************************************/ // Prepare Result Packets -// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_prepare.html#sect_protocol_com_stmt_prepare_response func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { data, err := stmt.mc.readPacket() if err == nil { @@ -932,7 +994,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { return 0, err } -// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_send_long_data.html func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { maxLen := stmt.mc.maxAllowedPacket - 1 pktLen := maxLen @@ -979,7 +1041,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { } // Execute Prepared Statement -// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if len(args) != stmt.paramCount { return fmt.Errorf( @@ -993,10 +1055,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { mc := stmt.mc // Determine threshold dynamically to avoid packet size shortage. - longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) - if longDataSize < 64 { - longDataSize = 64 - } + longDataSize := max(mc.maxAllowedPacket/(stmt.paramCount+1), 64) // Reset packet-sequence mc.resetSequence() @@ -1185,17 +1244,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // mc.affectedRows and mc.insertIds. func (mc *okHandler) discardResults() error { for mc.status&statusMoreResultsExists != 0 { - resLen, err := mc.readResultSetHeaderPacket() + resLen, _, err := mc.readResultSetHeaderPacket() if err != nil { return err } if resLen > 0 { // columns - if err := mc.conn().readUntilEOF(); err != nil { + if err := mc.conn().skipColumns(resLen); err != nil { return err } // rows - if err := mc.conn().readUntilEOF(); err != nil { + if err := mc.conn().skipRows(); err != nil { return err } } @@ -1203,7 +1262,7 @@ func (mc *okHandler) discardResults() error { return nil } -// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_binary_resultset.html#sect_protocol_binary_resultset_row func (rows *binaryRows) readRow(dest []driver.Value) error { data, err := rows.mc.readPacket() if err != nil { @@ -1212,9 +1271,17 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // packet indicator [1 byte] if data[0] != iOK { - // EOF Packet - if data[0] == iEOF && len(data) == 5 { - rows.mc.status = readStatus(data[3:]) + // EOF/OK Packet + if data[0] == iEOF { + if rows.mc.capabilities&clientDeprecateEOF == 0 { + // EOF packet + rows.mc.status = readStatus(data[3:]) + } else { + // OK Packet with an 0xFE header + _, _, n := readLengthEncodedInteger(data[1:]) + _, _, m := readLengthEncodedInteger(data[1+n:]) + rows.mc.status = readStatus(data[1+n+m:]) + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil diff --git a/vendor/github.com/go-sql-driver/mysql/result.go b/vendor/github.com/go-sql-driver/mysql/result.go index d51631468..82dc0f9b6 100644 --- a/vendor/github.com/go-sql-driver/mysql/result.go +++ b/vendor/github.com/go-sql-driver/mysql/result.go @@ -8,6 +8,8 @@ package mysql +import "slices" + import "database/sql/driver" // Result exposes data not available through *connection.Result. @@ -42,9 +44,9 @@ func (res *mysqlResult) RowsAffected() (int64, error) { } func (res *mysqlResult) AllLastInsertIds() []int64 { - return append([]int64{}, res.insertIds...) // defensive copy + return slices.Clone(res.insertIds) // defensive copy } func (res *mysqlResult) AllRowsAffected() []int64 { - return append([]int64{}, res.affectedRows...) // defensive copy + return slices.Clone(res.affectedRows) // defensive copy } diff --git a/vendor/github.com/go-sql-driver/mysql/rows.go b/vendor/github.com/go-sql-driver/mysql/rows.go index df98417b8..190e75f9b 100644 --- a/vendor/github.com/go-sql-driver/mysql/rows.go +++ b/vendor/github.com/go-sql-driver/mysql/rows.go @@ -113,7 +113,7 @@ func (rows *mysqlRows) Close() (err error) { // Remove unread packets from stream if !rows.rs.done { - err = mc.readUntilEOF() + err = mc.skipRows() } if err == nil { handleOk := mc.clearResult() @@ -143,7 +143,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) { // Remove unread packets from stream if !rows.rs.done { - if err := rows.mc.readUntilEOF(); err != nil { + if err := rows.mc.skipRows(); err != nil { return 0, err } rows.rs.done = true @@ -156,7 +156,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) { rows.rs = resultSet{} // rows.mc.affectedRows and rows.mc.insertIds accumulate on each call to // nextResultSet. - resLen, err := rows.mc.resultUnchanged().readResultSetHeaderPacket() + resLen, _, err := rows.mc.resultUnchanged().readResultSetHeaderPacket() if err != nil { // Clean up about multi-results flag rows.rs.done = true @@ -186,7 +186,7 @@ func (rows *binaryRows) NextResultSet() error { return err } - rows.rs.columns, err = rows.mc.readColumns(resLen) + rows.rs.columns, err = rows.mc.readColumns(resLen, nil) return err } @@ -208,7 +208,7 @@ func (rows *textRows) NextResultSet() (err error) { return err } - rows.rs.columns, err = rows.mc.readColumns(resLen) + rows.rs.columns, err = rows.mc.readColumns(resLen, nil) return err } diff --git a/vendor/github.com/go-sql-driver/mysql/statement.go b/vendor/github.com/go-sql-driver/mysql/statement.go index 35df85457..0261903b9 100644 --- a/vendor/github.com/go-sql-driver/mysql/statement.go +++ b/vendor/github.com/go-sql-driver/mysql/statement.go @@ -20,6 +20,7 @@ type mysqlStmt struct { mc *mysqlConn id uint32 paramCount int + columns []mysqlField } func (stmt *mysqlStmt) Close() error { @@ -64,19 +65,26 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { handleOk := stmt.mc.clearResult() // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } if resLen > 0 { // Columns - if err = mc.readUntilEOF(); err != nil { - return nil, err + if metadataFollows && stmt.mc.extCapabilities&clientCacheMetadata != 0 { + // we can not skip column metadata because next stmt.Query() may use it. + if stmt.columns, err = mc.readColumns(resLen, stmt.columns); err != nil { + return nil, err + } + } else { + if err = mc.skipColumns(resLen); err != nil { + return nil, err + } } // Rows - if err := mc.readUntilEOF(); err != nil { + if err = mc.skipRows(); err != nil { return nil, err } } @@ -107,7 +115,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { // Read Result handleOk := stmt.mc.clearResult() - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } @@ -116,7 +124,17 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { if resLen > 0 { rows.mc = mc - rows.rs.columns, err = mc.readColumns(resLen) + if metadataFollows { + if rows.rs.columns, err = mc.readColumns(resLen, stmt.columns); err != nil { + return nil, err + } + stmt.columns = rows.rs.columns + } else { + if err = mc.skipEof(); err != nil { + return nil, err + } + rows.rs.columns = stmt.columns + } } else { rows.rs.done = true @@ -131,7 +149,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { return rows, err } -var jsonType = reflect.TypeOf(json.RawMessage{}) +var jsonType = reflect.TypeFor[json.RawMessage]() type converter struct{} @@ -193,7 +211,7 @@ func (c converter) ConvertValue(v any) (driver.Value, error) { return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) } -var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() +var valuerReflectType = reflect.TypeFor[driver.Valuer]() // callValuerValue returns vr.Value(), with one exception: // If vr.Value is an auto-generated method on a pointer type and the diff --git a/vendor/github.com/go-sql-driver/mysql/utils.go b/vendor/github.com/go-sql-driver/mysql/utils.go index 8716c26c5..2dccb7d53 100644 --- a/vendor/github.com/go-sql-driver/mysql/utils.go +++ b/vendor/github.com/go-sql-driver/mysql/utils.go @@ -182,7 +182,7 @@ func parseDateTime(b []byte, loc *time.Location) (time.Time, error) { func parseByteYear(b []byte) (int, error) { year, n := 0, 1000 - for i := 0; i < 4; i++ { + for i := range 4 { v, err := bToi(b[i]) if err != nil { return 0, err @@ -207,7 +207,7 @@ func parseByte2Digits(b1, b2 byte) (int, error) { func parseByteNanoSec(b []byte) (int, error) { ns, digit := 0, 100000 // max is 6-digits - for i := 0; i < len(b); i++ { + for i := range b { v, err := bToi(b[i]) if err != nil { return 0, err @@ -625,108 +625,80 @@ func reserveBuffer(buf []byte, appendSize int) []byte { return buf[:newSize] } -// escapeBytesBackslash escapes []byte with backslashes (\) -// This escapes the contents of a string (provided as []byte) by adding backslashes before special -// characters, and turning others into specific escape sequences, such as -// turning newlines into \n and null bytes into \0. -// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932 -func escapeBytesBackslash(buf, v []byte) []byte { - pos := len(buf) - buf = reserveBuffer(buf, len(v)*2) +// Lookup table for backslash escapes (used for both string and bytes) +var backslashEscapeTable [256]byte - for _, c := range v { - switch c { - case '\x00': - buf[pos+1] = '0' - buf[pos] = '\\' - pos += 2 - case '\n': - buf[pos+1] = 'n' - buf[pos] = '\\' - pos += 2 - case '\r': - buf[pos+1] = 'r' - buf[pos] = '\\' - pos += 2 - case '\x1a': - buf[pos+1] = 'Z' - buf[pos] = '\\' - pos += 2 - case '\'': - buf[pos+1] = '\'' - buf[pos] = '\\' - pos += 2 - case '"': - buf[pos+1] = '"' - buf[pos] = '\\' - pos += 2 - case '\\': - buf[pos+1] = '\\' +func init() { + backslashEscapeTable['\x00'] = '0' + backslashEscapeTable['\n'] = 'n' + backslashEscapeTable['\r'] = 'r' + backslashEscapeTable['\x1a'] = 'Z' + backslashEscapeTable['\''] = '\'' + backslashEscapeTable['"'] = '"' + backslashEscapeTable['\\'] = '\\' +} + +// escapeStringBackslash is similar to escapeBytesBackslash but for string. +func escapeStringBackslash(buf []byte, v string) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2+2) + buf[pos] = '\'' + pos++ + for i := 0; i < len(v); i++ { + c := v[i] + if esc := backslashEscapeTable[c]; esc != 0 { + buf[pos+1] = esc buf[pos] = '\\' pos += 2 - default: + } else { buf[pos] = c pos++ } } - + buf[pos] = '\'' + pos++ return buf[:pos] } -// escapeStringBackslash is similar to escapeBytesBackslash but for string. -func escapeStringBackslash(buf []byte, v string) []byte { +// escapeBytesBackslash appends _binary'...' or '...' with backslash escaping for bytes. +func escapeBytesBackslash(buf, v []byte, binary bool) []byte { pos := len(buf) - buf = reserveBuffer(buf, len(v)*2) - - for i := 0; i < len(v); i++ { - c := v[i] - switch c { - case '\x00': - buf[pos+1] = '0' - buf[pos] = '\\' - pos += 2 - case '\n': - buf[pos+1] = 'n' - buf[pos] = '\\' - pos += 2 - case '\r': - buf[pos+1] = 'r' - buf[pos] = '\\' - pos += 2 - case '\x1a': - buf[pos+1] = 'Z' - buf[pos] = '\\' - pos += 2 - case '\'': - buf[pos+1] = '\'' - buf[pos] = '\\' - pos += 2 - case '"': - buf[pos+1] = '"' - buf[pos] = '\\' - pos += 2 - case '\\': - buf[pos+1] = '\\' + if binary { + buf = reserveBuffer(buf, len(v)*2+9) + copy(buf[pos:], []byte("_binary'")) + pos += 8 + } else { + buf = reserveBuffer(buf, len(v)*2+2) + buf[pos] = '\'' + pos++ + } + for _, c := range v { + if esc := backslashEscapeTable[c]; esc != 0 { + buf[pos+1] = esc buf[pos] = '\\' pos += 2 - default: + } else { buf[pos] = c pos++ } } - + buf[pos] = '\'' + pos++ return buf[:pos] } -// escapeBytesQuotes escapes apostrophes in []byte by doubling them up. -// This escapes the contents of a string by doubling up any apostrophes that -// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in -// effect on the server. -// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038 -func escapeBytesQuotes(buf, v []byte) []byte { +// escapeBytesQuotes appends _binary'...' or '...' with single-quote escaping for bytes. +func escapeBytesQuotes(buf, v []byte, binary bool) []byte { pos := len(buf) - buf = reserveBuffer(buf, len(v)*2) - + if binary { + buf = reserveBuffer(buf, len(v)*2+9) + copy(buf[pos:], []byte("_binary'")) + pos += 8 + } else { + buf = reserveBuffer(buf, len(v)*2+2) + buf[pos] = '\'' + pos++ + } for _, c := range v { if c == '\'' { buf[pos+1] = '\'' @@ -737,16 +709,18 @@ func escapeBytesQuotes(buf, v []byte) []byte { pos++ } } - + buf[pos] = '\'' + pos++ return buf[:pos] } // escapeStringQuotes is similar to escapeBytesQuotes but for string. func escapeStringQuotes(buf []byte, v string) []byte { pos := len(buf) - buf = reserveBuffer(buf, len(v)*2) - - for i := 0; i < len(v); i++ { + buf = reserveBuffer(buf, len(v)*2+2) + buf[pos] = '\'' + pos++ + for i := range len(v) { c := v[i] if c == '\'' { buf[pos+1] = '\'' @@ -757,7 +731,8 @@ func escapeStringQuotes(buf []byte, v string) []byte { pos++ } } - + buf[pos] = '\'' + pos++ return buf[:pos] } diff --git a/vendor/github.com/mattn/go-isatty/isatty_windows.go b/vendor/github.com/mattn/go-isatty/isatty_windows.go index 41edab076..5f29c11dd 100644 --- a/vendor/github.com/mattn/go-isatty/isatty_windows.go +++ b/vendor/github.com/mattn/go-isatty/isatty_windows.go @@ -47,9 +47,10 @@ func IsTerminal(fd uintptr) bool { // Check pipe name is used for cygwin/msys2 pty. // Cygwin/MSYS2 PTY has a name like: // \{cygwin,msys}-XXXXXXXXXXXXXXXX-ptyN-{from,to}-master +// On Windows 7 a trailing suffix (e.g. "-nat") may be appended. func isCygwinPipeName(name string) bool { token := strings.Split(name, "-") - if len(token) != 5 { + if len(token) < 5 { return false } @@ -76,6 +77,12 @@ func isCygwinPipeName(name string) bool { return false } + for _, t := range token[5:] { + if t == "" { + return false + } + } + return true } diff --git a/vendor/github.com/minio/sio/sio.go b/vendor/github.com/minio/sio/sio.go index 8f379dd51..6c3fe988f 100644 --- a/vendor/github.com/minio/sio/sio.go +++ b/vendor/github.com/minio/sio/sio.go @@ -21,11 +21,13 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rand" + "encoding/binary" "errors" "fmt" "io" "runtime" + "golang.org/x/crypto/chacha20" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/sys/cpu" ) @@ -38,8 +40,8 @@ const ( ) const ( - // AES_256_GCM specifies the cipher suite AES-GCM with 256 bit keys. - AES_256_GCM byte = iota + // AES_GCM specifies the cipher suite AES-GCM with 128 or 256 bit keys. + AES_GCM byte = iota // CHACHA20_POLY1305 specifies the cipher suite ChaCha20Poly1305 with 256 bit keys. CHACHA20_POLY1305 ) @@ -66,8 +68,6 @@ func detectAESSupport() bool { } const ( - keySize = 32 - headerSize = 16 maxPayloadSize = 1 << 16 tagSize = 16 @@ -77,17 +77,56 @@ const ( maxEncryptedSize = maxDecryptedSize + ((headerSize + tagSize) * 1 << 32) ) -var newAesGcm = func(key []byte) (cipher.AEAD, error) { - aes256, err := aes.NewCipher(key) +var newAES = func(key []byte) (cipher.AEAD, error) { + block, err := aes.NewCipher(key) if err != nil { return nil, err } - return cipher.NewGCM(aes256) + return cipher.NewGCM(block) +} + +var newChaCha20 = func(key []byte) (cipher.AEAD, error) { + // ChaCha20 requires 256 bit keys. Therefore, we expand the + // given 128 bit key into a 256 bit key. However, we don't + // use a simple algebraic key expansion to avoid any strutural + // relation/pattern in the resulting 256 bit key. Even though + // no effective attack against ChaCha20 is known that exploits + // related keys, we don't want to make any additional security + // assumption about ChaCha20. + // Therefore, we use the HChaCha20 KDF to derive a uniformly + // random 256 bit key from a 128 bit key as following: + // + // Input: K_i as 128 bit key + // Output: K_o as 256 bit key + // + // k = K_i | 0...0 # Expand K_i by 16 zero bytes to make it 256 bit long + // n = 0...0 | 1 # Set a 128 bit nonce value n to 1 + // K_o = HChaCha20(k, n) # Derive K_o as output of the HCHaCha20 KDF + // + // This construction ensures that the K_o is 256 bit long and is not strutural related + // to K_i since HChaCha20 is a secure PRF. As long as an attacker has no effective attack + // distinguishing HChaCha20 from randomness, the attacker's best attack is a key space + // search trying to find the right 128 bit key among 2^128 possible candidates - which + // is not feasible. + if len(key) == 128/8 { + var ( + k [32]byte + nonce [16]byte + err error + ) + copy(k[:], key) + binary.BigEndian.PutUint32(nonce[12:], 1) + + if key, err = chacha20.HChaCha20(k[:], nonce[:]); err != nil { + return nil, err + } + } + return chacha20poly1305.New(key) } var supportedCiphers = [...]func([]byte) (cipher.AEAD, error){ - AES_256_GCM: newAesGcm, - CHACHA20_POLY1305: chacha20poly1305.New, + AES_GCM: newAES, + CHACHA20_POLY1305: newChaCha20, } var ( @@ -336,9 +375,9 @@ func DecryptWriter(dst io.Writer, config Config) (io.WriteCloser, error) { func defaultCipherSuites() []byte { if supportsAES { - return []byte{AES_256_GCM, CHACHA20_POLY1305} + return []byte{AES_GCM, CHACHA20_POLY1305} } - return []byte{CHACHA20_POLY1305, AES_256_GCM} + return []byte{CHACHA20_POLY1305, AES_GCM} } func setConfigDefaults(config *Config) error { @@ -348,7 +387,7 @@ func setConfigDefaults(config *Config) error { if config.MaxVersion > Version20 { return errors.New("sio: unknown maximum version") } - if len(config.Key) != keySize { + if len(config.Key) != 128/8 && len(config.Key) != 256/8 { return errors.New("sio: invalid key size") } if len(config.CipherSuites) > 2 { diff --git a/vendor/modules.txt b/vendor/modules.txt index 020eea174..b43d43d3f 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -163,8 +163,8 @@ github.com/go-openapi/swag/yamlutils # github.com/go-openapi/validate v0.25.2 ## explicit; go 1.24.0 github.com/go-openapi/validate -# github.com/go-sql-driver/mysql v1.9.3 -## explicit; go 1.21.0 +# github.com/go-sql-driver/mysql v1.10.0 +## explicit; go 1.24.0 github.com/go-sql-driver/mysql # github.com/go-viper/mapstructure/v2 v2.5.0 ## explicit; go 1.18 @@ -221,16 +221,16 @@ github.com/lucasb-eyer/go-colorful github.com/manifoldco/promptui github.com/manifoldco/promptui/list github.com/manifoldco/promptui/screenbuf -# github.com/mattn/go-isatty v0.0.21 +# github.com/mattn/go-isatty v0.0.22 ## explicit; go 1.21 github.com/mattn/go-isatty # github.com/mattn/go-runewidth v0.0.23 ## explicit; go 1.20 github.com/mattn/go-runewidth -# github.com/mattn/go-sqlite3 v1.14.42 => github.com/gabriel-samfira/go-sqlite3 v0.0.0-20251005121134-bc61ecf9b4c7 +# github.com/mattn/go-sqlite3 v1.14.44 => github.com/gabriel-samfira/go-sqlite3 v0.0.0-20251005121134-bc61ecf9b4c7 ## explicit; go 1.19 github.com/mattn/go-sqlite3 -# github.com/minio/sio v0.4.3 +# github.com/minio/sio v0.5.1 ## explicit; go 1.24.0 github.com/minio/sio # github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822