From e1b0f4d4e943aaa5e70aaba27135471388c9b042 Mon Sep 17 00:00:00 2001 From: ramnes Date: Sat, 6 Dec 2025 03:23:47 +0100 Subject: [PATCH 1/6] Store prepared statements field definitions --- client/stmt.go | 37 ++++++++++++++++++++++++++++++++----- server/command.go | 13 +++++++++++++ server/stmt.go | 16 ++++++++++++++-- 3 files changed, 59 insertions(+), 7 deletions(-) diff --git a/client/stmt.go b/client/stmt.go index 106e176de..3fb1454c4 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -19,6 +19,10 @@ type Stmt struct { params int columns int warnings int + + // Field definitions from the PREPARE response (for proxy passthrough) + ParamFields []*mysql.Field + ColumnFields []*mysql.Field } func (s *Stmt) ParamNum() int { @@ -33,6 +37,18 @@ func (s *Stmt) WarningsNum() int { return s.warnings } +// GetParamFields returns the parameter field definitions from the PREPARE response. +// Implements server.StmtFieldsProvider for proxy passthrough. +func (s *Stmt) GetParamFields() []*mysql.Field { + return s.ParamFields +} + +// GetColumnFields returns the column field definitions from the PREPARE response. +// Implements server.StmtFieldsProvider for proxy passthrough. +func (s *Stmt) GetColumnFields() []*mysql.Field { + return s.ColumnFields +} + func (s *Stmt) Execute(args ...interface{}) (*mysql.Result, error) { if err := s.write(args...); err != nil { return nil, errors.Trace(err) @@ -275,8 +291,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) { } if s.params > 0 { - for range s.params { - if _, err := s.conn.ReadPacket(); err != nil { + s.ParamFields = make([]*mysql.Field, s.params) + for i := range s.params { + data, err := s.conn.ReadPacket() + if err != nil { + return nil, errors.Trace(err) + } + s.ParamFields[i] = &mysql.Field{} + if err := s.ParamFields[i].Parse(data); err != nil { return nil, errors.Trace(err) } } @@ -290,9 +312,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) { } if s.columns > 0 { - // TODO process when CLIENT_CACHE_METADATA enabled - for range s.columns { - if _, err := s.conn.ReadPacket(); err != nil { + s.ColumnFields = make([]*mysql.Field, s.columns) + for i := range s.columns { + data, err := s.conn.ReadPacket() + if err != nil { + return nil, errors.Trace(err) + } + s.ColumnFields[i] = &mysql.Field{} + if err := s.ColumnFields[i].Parse(data); err != nil { return nil, errors.Trace(err) } } diff --git a/server/command.go b/server/command.go index e23244bbd..badcfbdf0 100644 --- a/server/command.go +++ b/server/command.go @@ -10,6 +10,13 @@ import ( "github.com/go-mysql-org/go-mysql/utils" ) +// StmtFieldsProvider is an optional interface that prepared statement contexts can implement +// to provide field definitions for proxy passthrough scenarios. +type StmtFieldsProvider interface { + GetParamFields() []*mysql.Field + GetColumnFields() []*mysql.Field +} + // Handler is what a server needs to implement the client-server protocol type Handler interface { // handle COM_INIT_DB command, you can check whether the dbName is valid, or other. @@ -112,6 +119,12 @@ func (c *Conn) dispatch(data []byte) interface{} { if st.Params, st.Columns, st.Context, err = c.h.HandleStmtPrepare(st.Query); err != nil { return err } else { + // If context provides field definitions (e.g., from a backend prepared statement), + // use them for accurate metadata passthrough in proxy scenarios. + if provider, ok := st.Context.(StmtFieldsProvider); ok { + st.ParamFields = provider.GetParamFields() + st.ColumnFields = provider.GetColumnFields() + } st.ResetParams() c.stmts[c.stmtID] = st return st diff --git a/server/stmt.go b/server/stmt.go index ca9eae796..dbd0b1ab5 100644 --- a/server/stmt.go +++ b/server/stmt.go @@ -25,6 +25,10 @@ type Stmt struct { Args []interface{} Context interface{} + + // Field definitions for proxy passthrough (optional, uses dummy fields if nil) + ParamFields []*mysql.Field + ColumnFields []*mysql.Field } func (s *Stmt) Rest(params int, columns int, context interface{}) { @@ -61,7 +65,11 @@ func (c *Conn) writePrepare(s *Stmt) error { if s.Params > 0 { for i := 0; i < s.Params; i++ { data = data[0:4] - data = append(data, paramFieldData...) + if s.ParamFields != nil && i < len(s.ParamFields) { + data = append(data, s.ParamFields[i].Dump()...) + } else { + data = append(data, paramFieldData...) + } if err := c.WritePacket(data); err != nil { return errors.Trace(err) @@ -76,7 +84,11 @@ func (c *Conn) writePrepare(s *Stmt) error { if s.Columns > 0 { for i := 0; i < s.Columns; i++ { data = data[0:4] - data = append(data, columnFieldData...) + if s.ColumnFields != nil && i < len(s.ColumnFields) { + data = append(data, s.ColumnFields[i].Dump()...) + } else { + data = append(data, columnFieldData...) + } if err := c.WritePacket(data); err != nil { return errors.Trace(err) From 877760517dece78d513cdf78ff951618696ace4a Mon Sep 17 00:00:00 2001 From: ramnes Date: Sat, 6 Dec 2025 03:54:57 +0100 Subject: [PATCH 2/6] Add a test --- server/stmt_test.go | 50 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/server/stmt_test.go b/server/stmt_test.go index bf9142f54..962bf054b 100644 --- a/server/stmt_test.go +++ b/server/stmt_test.go @@ -3,6 +3,7 @@ package server import ( "testing" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/stretchr/testify/require" ) @@ -46,3 +47,52 @@ func TestHandleStmtExecute(t *testing.T) { } } } + +type mockPrepareHandler struct { + EmptyHandler + context any + paramCount, columnCount int +} + +func (h *mockPrepareHandler) HandleStmtPrepare(query string) (int, int, any, error) { + return h.paramCount, h.columnCount, h.context, nil +} + +func TestStmtPrepareWithoutFieldsProvider(t *testing.T) { + c := &Conn{ + h: &mockPrepareHandler{context: "plain string", paramCount: 1, columnCount: 1}, + stmts: make(map[uint32]*Stmt), + } + + result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT * FROM t"...)) + + stmt := result.(*Stmt) + require.Nil(t, stmt.ParamFields) + require.Nil(t, stmt.ColumnFields) +} + +type mockFieldsProvider struct { + paramFields, columnFields []*mysql.Field +} + +func (m *mockFieldsProvider) GetParamFields() []*mysql.Field { return m.paramFields } +func (m *mockFieldsProvider) GetColumnFields() []*mysql.Field { return m.columnFields } + +func TestStmtPrepareWithFieldsProvider(t *testing.T) { + provider := &mockFieldsProvider{ + paramFields: []*mysql.Field{{Name: []byte("?"), Type: mysql.MYSQL_TYPE_LONG}}, + columnFields: []*mysql.Field{{Name: []byte("id"), Type: mysql.MYSQL_TYPE_LONGLONG}}, + } + c := &Conn{ + h: &mockPrepareHandler{context: provider, paramCount: 1, columnCount: 1}, + stmts: make(map[uint32]*Stmt), + } + + result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT id FROM t WHERE id = ?"...)) + + stmt := result.(*Stmt) + require.NotNil(t, stmt.ParamFields) + require.NotNil(t, stmt.ColumnFields) + require.Equal(t, mysql.MYSQL_TYPE_LONG, stmt.ParamFields[0].Type) + require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, stmt.ColumnFields[0].Type) +} From 52cb9a215bdfeda45e11ae9a21926033fd87de10 Mon Sep 17 00:00:00 2001 From: ramnes Date: Sat, 6 Dec 2025 12:59:49 +0100 Subject: [PATCH 3/6] Make `paramFields` and `columnFields` private --- client/stmt.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/client/stmt.go b/client/stmt.go index 3fb1454c4..baa89bc6d 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -21,8 +21,8 @@ type Stmt struct { warnings int // Field definitions from the PREPARE response (for proxy passthrough) - ParamFields []*mysql.Field - ColumnFields []*mysql.Field + paramFields []*mysql.Field + columnFields []*mysql.Field } func (s *Stmt) ParamNum() int { @@ -40,13 +40,13 @@ func (s *Stmt) WarningsNum() int { // GetParamFields returns the parameter field definitions from the PREPARE response. // Implements server.StmtFieldsProvider for proxy passthrough. func (s *Stmt) GetParamFields() []*mysql.Field { - return s.ParamFields + return s.paramFields } // GetColumnFields returns the column field definitions from the PREPARE response. // Implements server.StmtFieldsProvider for proxy passthrough. func (s *Stmt) GetColumnFields() []*mysql.Field { - return s.ColumnFields + return s.columnFields } func (s *Stmt) Execute(args ...interface{}) (*mysql.Result, error) { @@ -291,14 +291,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) { } if s.params > 0 { - s.ParamFields = make([]*mysql.Field, s.params) + s.paramFields = make([]*mysql.Field, s.params) for i := range s.params { data, err := s.conn.ReadPacket() if err != nil { return nil, errors.Trace(err) } - s.ParamFields[i] = &mysql.Field{} - if err := s.ParamFields[i].Parse(data); err != nil { + s.paramFields[i] = &mysql.Field{} + if err := s.paramFields[i].Parse(data); err != nil { return nil, errors.Trace(err) } } @@ -312,14 +312,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) { } if s.columns > 0 { - s.ColumnFields = make([]*mysql.Field, s.columns) + s.columnFields = make([]*mysql.Field, s.columns) for i := range s.columns { data, err := s.conn.ReadPacket() if err != nil { return nil, errors.Trace(err) } - s.ColumnFields[i] = &mysql.Field{} - if err := s.ColumnFields[i].Parse(data); err != nil { + s.columnFields[i] = &mysql.Field{} + if err := s.columnFields[i].Parse(data); err != nil { return nil, errors.Trace(err) } } From e4cb58725f525d39b7e869fae286ee8f7a4b4cde Mon Sep 17 00:00:00 2001 From: ramnes Date: Sat, 6 Dec 2025 14:39:53 +0100 Subject: [PATCH 4/6] Add comments --- client/stmt.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/client/stmt.go b/client/stmt.go index baa89bc6d..26d878d21 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -39,12 +39,14 @@ func (s *Stmt) WarningsNum() int { // GetParamFields returns the parameter field definitions from the PREPARE response. // Implements server.StmtFieldsProvider for proxy passthrough. +// The caller should not modify the returned slice. func (s *Stmt) GetParamFields() []*mysql.Field { return s.paramFields } // GetColumnFields returns the column field definitions from the PREPARE response. // Implements server.StmtFieldsProvider for proxy passthrough. +// The caller should not modify the returned slice. func (s *Stmt) GetColumnFields() []*mysql.Field { return s.columnFields } From a37e0bb71802b7f59a99a296489ddec6195cb492 Mon Sep 17 00:00:00 2001 From: ramnes Date: Fri, 12 Dec 2025 16:49:20 +0100 Subject: [PATCH 5/6] Move statements into a new `stmt` package --- client/stmt.go | 66 +++++++++++++++------------------------------ server/command.go | 16 +++-------- server/stmt.go | 21 ++++++--------- server/stmt_test.go | 35 +++++++++++------------- stmt/stmt.go | 43 +++++++++++++++++++++++++++++ 5 files changed, 92 insertions(+), 89 deletions(-) create mode 100644 stmt/stmt.go diff --git a/client/stmt.go b/client/stmt.go index 26d878d21..cd64f3524 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -8,49 +8,31 @@ import ( "runtime" "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/stmt" "github.com/go-mysql-org/go-mysql/utils" "github.com/pingcap/errors" ) type Stmt struct { - conn *Conn - id uint32 - - params int - columns int + conn *Conn warnings int - // Field definitions from the PREPARE response (for proxy passthrough) - paramFields []*mysql.Field - columnFields []*mysql.Field + // PreparedStmt contains common fields shared with server.Stmt for proxy passthrough + stmt.PreparedStmt } func (s *Stmt) ParamNum() int { - return s.params + return s.Params } func (s *Stmt) ColumnNum() int { - return s.columns + return s.Columns } func (s *Stmt) WarningsNum() int { return s.warnings } -// GetParamFields returns the parameter field definitions from the PREPARE response. -// Implements server.StmtFieldsProvider for proxy passthrough. -// The caller should not modify the returned slice. -func (s *Stmt) GetParamFields() []*mysql.Field { - return s.paramFields -} - -// GetColumnFields returns the column field definitions from the PREPARE response. -// Implements server.StmtFieldsProvider for proxy passthrough. -// The caller should not modify the returned slice. -func (s *Stmt) GetColumnFields() []*mysql.Field { - return s.columnFields -} - func (s *Stmt) Execute(args ...interface{}) (*mysql.Result, error) { if err := s.write(args...); err != nil { return nil, errors.Trace(err) @@ -68,7 +50,7 @@ func (s *Stmt) ExecuteSelectStreaming(result *mysql.Result, perRowCb SelectPerRo } func (s *Stmt) Close() error { - if err := s.conn.writeCommandUint32(mysql.COM_STMT_CLOSE, s.id); err != nil { + if err := s.conn.writeCommandUint32(mysql.COM_STMT_CLOSE, s.ID); err != nil { return errors.Trace(err) } @@ -78,10 +60,10 @@ func (s *Stmt) Close() error { // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html func (s *Stmt) write(args ...interface{}) error { defer clear(s.conn.queryAttributes) - paramsNum := s.params + paramsNum := s.Params if len(args) != paramsNum { - return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args)) + return fmt.Errorf("argument mismatch, need %d but got %d", s.Params, len(args)) } if (s.conn.capability&mysql.CLIENT_QUERY_ATTRIBUTES > 0) && (s.conn.includeLine >= 0) { @@ -205,7 +187,7 @@ func (s *Stmt) write(args ...interface{}) error { data.Write([]byte{0, 0, 0, 0}) data.WriteByte(mysql.COM_STMT_EXECUTE) - data.Write([]byte{byte(s.id), byte(s.id >> 8), byte(s.id >> 16), byte(s.id >> 24)}) + data.Write([]byte{byte(s.ID), byte(s.ID >> 8), byte(s.ID >> 16), byte(s.ID >> 24)}) flags := mysql.CURSOR_TYPE_NO_CURSOR if paramsNum > 0 { @@ -272,15 +254,15 @@ func (c *Conn) Prepare(query string) (*Stmt, error) { pos := 1 // for statement id - s.id = binary.LittleEndian.Uint32(data[pos:]) + s.ID = binary.LittleEndian.Uint32(data[pos:]) pos += 4 // number columns - s.columns = int(binary.LittleEndian.Uint16(data[pos:])) + s.Columns = int(binary.LittleEndian.Uint16(data[pos:])) pos += 2 // number params - s.params = int(binary.LittleEndian.Uint16(data[pos:])) + s.Params = int(binary.LittleEndian.Uint16(data[pos:])) pos += 2 // reserved @@ -292,17 +274,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) { // pos += 2 } - if s.params > 0 { - s.paramFields = make([]*mysql.Field, s.params) - for i := range s.params { + if s.Params > 0 { + s.RawParamFields = make([][]byte, s.Params) + for i := range s.Params { data, err := s.conn.ReadPacket() if err != nil { return nil, errors.Trace(err) } - s.paramFields[i] = &mysql.Field{} - if err := s.paramFields[i].Parse(data); err != nil { - return nil, errors.Trace(err) - } + s.RawParamFields[i] = data } if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 { if packet, err := s.conn.ReadPacket(); err != nil { @@ -313,17 +292,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) { } } - if s.columns > 0 { - s.columnFields = make([]*mysql.Field, s.columns) - for i := range s.columns { + if s.Columns > 0 { + s.RawColumnFields = make([][]byte, s.Columns) + for i := range s.Columns { data, err := s.conn.ReadPacket() if err != nil { return nil, errors.Trace(err) } - s.columnFields[i] = &mysql.Field{} - if err := s.columnFields[i].Parse(data); err != nil { - return nil, errors.Trace(err) - } + s.RawColumnFields[i] = data } if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 { if packet, err := s.conn.ReadPacket(); err != nil { diff --git a/server/command.go b/server/command.go index badcfbdf0..a645798a9 100644 --- a/server/command.go +++ b/server/command.go @@ -7,16 +7,10 @@ import ( "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/replication" + "github.com/go-mysql-org/go-mysql/stmt" "github.com/go-mysql-org/go-mysql/utils" ) -// StmtFieldsProvider is an optional interface that prepared statement contexts can implement -// to provide field definitions for proxy passthrough scenarios. -type StmtFieldsProvider interface { - GetParamFields() []*mysql.Field - GetColumnFields() []*mysql.Field -} - // Handler is what a server needs to implement the client-server protocol type Handler interface { // handle COM_INIT_DB command, you can check whether the dbName is valid, or other. @@ -119,11 +113,9 @@ func (c *Conn) dispatch(data []byte) interface{} { if st.Params, st.Columns, st.Context, err = c.h.HandleStmtPrepare(st.Query); err != nil { return err } else { - // If context provides field definitions (e.g., from a backend prepared statement), - // use them for accurate metadata passthrough in proxy scenarios. - if provider, ok := st.Context.(StmtFieldsProvider); ok { - st.ParamFields = provider.GetParamFields() - st.ColumnFields = provider.GetColumnFields() + if provider, ok := st.Context.(*stmt.PreparedStmt); ok { + st.RawParamFields = provider.RawParamFields + st.RawColumnFields = provider.RawColumnFields } st.ResetParams() c.stmts[c.stmtID] = st diff --git a/server/stmt.go b/server/stmt.go index dbd0b1ab5..553c9e695 100644 --- a/server/stmt.go +++ b/server/stmt.go @@ -7,6 +7,7 @@ import ( "strconv" "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/stmt" "github.com/pingcap/errors" ) @@ -16,19 +17,13 @@ var ( ) type Stmt struct { - ID uint32 Query string - - Params int - Columns int - - Args []interface{} + Args []interface{} Context interface{} - // Field definitions for proxy passthrough (optional, uses dummy fields if nil) - ParamFields []*mysql.Field - ColumnFields []*mysql.Field + // PreparedStmt contains common fields shared with client.Stmt for proxy passthrough + stmt.PreparedStmt } func (s *Stmt) Rest(params int, columns int, context interface{}) { @@ -65,8 +60,8 @@ func (c *Conn) writePrepare(s *Stmt) error { if s.Params > 0 { for i := 0; i < s.Params; i++ { data = data[0:4] - if s.ParamFields != nil && i < len(s.ParamFields) { - data = append(data, s.ParamFields[i].Dump()...) + if s.RawParamFields != nil && i < len(s.RawParamFields) { + data = append(data, s.RawParamFields[i]...) } else { data = append(data, paramFieldData...) } @@ -84,8 +79,8 @@ func (c *Conn) writePrepare(s *Stmt) error { if s.Columns > 0 { for i := 0; i < s.Columns; i++ { data = data[0:4] - if s.ColumnFields != nil && i < len(s.ColumnFields) { - data = append(data, s.ColumnFields[i].Dump()...) + if s.RawColumnFields != nil && i < len(s.RawColumnFields) { + data = append(data, s.RawColumnFields[i]...) } else { data = append(data, columnFieldData...) } diff --git a/server/stmt_test.go b/server/stmt_test.go index 962bf054b..9c2f64907 100644 --- a/server/stmt_test.go +++ b/server/stmt_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/stmt" "github.com/stretchr/testify/require" ) @@ -58,7 +59,7 @@ func (h *mockPrepareHandler) HandleStmtPrepare(query string) (int, int, any, err return h.paramCount, h.columnCount, h.context, nil } -func TestStmtPrepareWithoutFieldsProvider(t *testing.T) { +func TestStmtPrepareWithoutPreparedStmt(t *testing.T) { c := &Conn{ h: &mockPrepareHandler{context: "plain string", paramCount: 1, columnCount: 1}, stmts: make(map[uint32]*Stmt), @@ -66,22 +67,18 @@ func TestStmtPrepareWithoutFieldsProvider(t *testing.T) { result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT * FROM t"...)) - stmt := result.(*Stmt) - require.Nil(t, stmt.ParamFields) - require.Nil(t, stmt.ColumnFields) + st := result.(*Stmt) + require.Nil(t, st.RawParamFields) + require.Nil(t, st.RawColumnFields) } -type mockFieldsProvider struct { - paramFields, columnFields []*mysql.Field -} - -func (m *mockFieldsProvider) GetParamFields() []*mysql.Field { return m.paramFields } -func (m *mockFieldsProvider) GetColumnFields() []*mysql.Field { return m.columnFields } +func TestStmtPrepareWithPreparedStmt(t *testing.T) { + paramField := &mysql.Field{Name: []byte("?"), Type: mysql.MYSQL_TYPE_LONG} + columnField := &mysql.Field{Name: []byte("id"), Type: mysql.MYSQL_TYPE_LONGLONG} -func TestStmtPrepareWithFieldsProvider(t *testing.T) { - provider := &mockFieldsProvider{ - paramFields: []*mysql.Field{{Name: []byte("?"), Type: mysql.MYSQL_TYPE_LONG}}, - columnFields: []*mysql.Field{{Name: []byte("id"), Type: mysql.MYSQL_TYPE_LONGLONG}}, + provider := &stmt.PreparedStmt{ + RawParamFields: [][]byte{paramField.Dump()}, + RawColumnFields: [][]byte{columnField.Dump()}, } c := &Conn{ h: &mockPrepareHandler{context: provider, paramCount: 1, columnCount: 1}, @@ -90,9 +87,9 @@ func TestStmtPrepareWithFieldsProvider(t *testing.T) { result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT id FROM t WHERE id = ?"...)) - stmt := result.(*Stmt) - require.NotNil(t, stmt.ParamFields) - require.NotNil(t, stmt.ColumnFields) - require.Equal(t, mysql.MYSQL_TYPE_LONG, stmt.ParamFields[0].Type) - require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, stmt.ColumnFields[0].Type) + st := result.(*Stmt) + require.NotNil(t, st.RawParamFields) + require.NotNil(t, st.RawColumnFields) + require.Equal(t, mysql.MYSQL_TYPE_LONG, st.GetParamFields()[0].Type) + require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, st.GetColumnFields()[0].Type) } diff --git a/stmt/stmt.go b/stmt/stmt.go new file mode 100644 index 000000000..e5c3e3a09 --- /dev/null +++ b/stmt/stmt.go @@ -0,0 +1,43 @@ +package stmt + +import "github.com/go-mysql-org/go-mysql/mysql" + +type PreparedStmt struct { + ID uint32 + Params int + Columns int + + RawParamFields [][]byte + RawColumnFields [][]byte + + paramFields []*mysql.Field + columnFields []*mysql.Field +} + +func (s *PreparedStmt) GetParamFields() []*mysql.Field { + if s.RawParamFields == nil { + return nil + } + if s.paramFields == nil { + s.paramFields = make([]*mysql.Field, len(s.RawParamFields)) + for i, raw := range s.RawParamFields { + s.paramFields[i] = &mysql.Field{} + _ = s.paramFields[i].Parse(raw) + } + } + return s.paramFields +} + +func (s *PreparedStmt) GetColumnFields() []*mysql.Field { + if s.RawColumnFields == nil { + return nil + } + if s.columnFields == nil { + s.columnFields = make([]*mysql.Field, len(s.RawColumnFields)) + for i, raw := range s.RawColumnFields { + s.columnFields[i] = &mysql.Field{} + _ = s.columnFields[i].Parse(raw) + } + } + return s.columnFields +} From fcaff1fd19bba756953b43ad4efd59f5d2146ee8 Mon Sep 17 00:00:00 2001 From: ramnes Date: Fri, 12 Dec 2025 20:00:39 +0100 Subject: [PATCH 6/6] Return an error if we can't parse --- server/stmt_test.go | 8 ++++++-- stmt/stmt.go | 32 ++++++++++++++++++++------------ 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/server/stmt_test.go b/server/stmt_test.go index 9c2f64907..935597f68 100644 --- a/server/stmt_test.go +++ b/server/stmt_test.go @@ -90,6 +90,10 @@ func TestStmtPrepareWithPreparedStmt(t *testing.T) { st := result.(*Stmt) require.NotNil(t, st.RawParamFields) require.NotNil(t, st.RawColumnFields) - require.Equal(t, mysql.MYSQL_TYPE_LONG, st.GetParamFields()[0].Type) - require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, st.GetColumnFields()[0].Type) + paramFields, err := st.GetParamFields() + require.NoError(t, err) + require.Equal(t, mysql.MYSQL_TYPE_LONG, paramFields[0].Type) + columnFields, err := st.GetColumnFields() + require.NoError(t, err) + require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, columnFields[0].Type) } diff --git a/stmt/stmt.go b/stmt/stmt.go index e5c3e3a09..0f75724c0 100644 --- a/stmt/stmt.go +++ b/stmt/stmt.go @@ -14,30 +14,38 @@ type PreparedStmt struct { columnFields []*mysql.Field } -func (s *PreparedStmt) GetParamFields() []*mysql.Field { +func (s *PreparedStmt) GetParamFields() ([]*mysql.Field, error) { if s.RawParamFields == nil { - return nil + return nil, nil } if s.paramFields == nil { - s.paramFields = make([]*mysql.Field, len(s.RawParamFields)) + fields := make([]*mysql.Field, len(s.RawParamFields)) for i, raw := range s.RawParamFields { - s.paramFields[i] = &mysql.Field{} - _ = s.paramFields[i].Parse(raw) + field := &mysql.Field{} + if err := field.Parse(raw); err != nil { + return nil, err + } + fields[i] = field } + s.paramFields = fields } - return s.paramFields + return s.paramFields, nil } -func (s *PreparedStmt) GetColumnFields() []*mysql.Field { +func (s *PreparedStmt) GetColumnFields() ([]*mysql.Field, error) { if s.RawColumnFields == nil { - return nil + return nil, nil } if s.columnFields == nil { - s.columnFields = make([]*mysql.Field, len(s.RawColumnFields)) + fields := make([]*mysql.Field, len(s.RawColumnFields)) for i, raw := range s.RawColumnFields { - s.columnFields[i] = &mysql.Field{} - _ = s.columnFields[i].Parse(raw) + field := &mysql.Field{} + if err := field.Parse(raw); err != nil { + return nil, err + } + fields[i] = field } + s.columnFields = fields } - return s.columnFields + return s.columnFields, nil }