From 2d38c50faa93ba0ac00f18cbe884873a863369c7 Mon Sep 17 00:00:00 2001 From: Martin Tournoij Date: Wed, 8 Apr 2026 18:40:50 +0100 Subject: [PATCH] Refactor context support Previously the various *Context() methods called the non-context methods together with a goroutine to watches the context and send a cancel query on the context cancellation. This changes it to make the *Context() methods the primary methods, and removing or stubbing out the non-context ones. This cleans up a bunch of old pre-Go 1.9 code, and is also required to use the context in more places, such as network timeouts. database/sql never uses Exec() or Query() if the *Context() variants are implemented. These can safely be removed outright. Similarly, conn.Prepare(), conn.Begin(), stmt.Exec(), and stmt.Query() are also never called if the context variants are implemented, but we need to keep them around to satisfy driver.Conn and driver.Stmt. Make them panic to ensure it's not accidentally called from pq code or tests. This shouldn't change any behaviour --- conn.go | 210 ++++++++++++++++++++++++++++++++---------- conn_go18.go | 226 ---------------------------------------------- connector_test.go | 4 +- copy.go | 4 +- deprecated.go | 7 ++ rows.go | 4 +- stmt.go | 28 ++++-- 7 files changed, 193 insertions(+), 290 deletions(-) delete mode 100644 conn_go18.go diff --git a/conn.go b/conn.go index 667c5fa9c..3ed7ecfa4 100644 --- a/conn.go +++ b/conn.go @@ -47,19 +47,28 @@ var ( // Compile time validation that our types implement the expected interfaces var ( - _ driver.Driver = Driver{} - _ driver.ConnBeginTx = (*conn)(nil) - _ driver.ConnPrepareContext = (*conn)(nil) - _ driver.Execer = (*conn)(nil) //lint:ignore SA1019 x - _ driver.ExecerContext = (*conn)(nil) - _ driver.NamedValueChecker = (*conn)(nil) - _ driver.Pinger = (*conn)(nil) - _ driver.Queryer = (*conn)(nil) //lint:ignore SA1019 x - _ driver.QueryerContext = (*conn)(nil) - _ driver.SessionResetter = (*conn)(nil) - _ driver.Validator = (*conn)(nil) - _ driver.StmtExecContext = (*stmt)(nil) - _ driver.StmtQueryContext = (*stmt)(nil) + _ driver.Driver = Driver{} + //_ driver.DriverContext = Driver{} // TODO: https://github.com/lib/pq/pull/900 + _ driver.Connector = (*Connector)(nil) + _ driver.Conn = (*conn)(nil) + _ driver.ConnBeginTx = (*conn)(nil) + _ driver.ConnPrepareContext = (*conn)(nil) + _ driver.ExecerContext = (*conn)(nil) + _ driver.NamedValueChecker = (*conn)(nil) + _ driver.Pinger = (*conn)(nil) + _ driver.QueryerContext = (*conn)(nil) + _ driver.SessionResetter = (*conn)(nil) + _ driver.Validator = (*conn)(nil) + _ driver.Stmt = (*stmt)(nil) + _ driver.StmtExecContext = (*stmt)(nil) + _ driver.StmtQueryContext = (*stmt)(nil) + _ driver.Rows = (*rows)(nil) + _ driver.RowsColumnTypeDatabaseTypeName = (*rows)(nil) + _ driver.RowsColumnTypeLength = (*rows)(nil) + //_ driver.RowsColumnTypeNullable = (*rows)(nil) // TODO + _ driver.RowsColumnTypePrecisionScale = (*rows)(nil) + _ driver.RowsColumnTypeScanType = (*rows)(nil) + _ driver.RowsNextResultSet = (*rows)(nil) ) func init() { @@ -479,11 +488,29 @@ func (cn *conn) checkIsInTransaction(intxn bool) error { return nil } -func (cn *conn) Begin() (_ driver.Tx, err error) { - return cn.begin("") -} +// Implement [driver.ConnBeginTx]. +func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + var mode string + switch sql.IsolationLevel(opts.Isolation) { + case sql.LevelDefault: + // Don't touch mode: use the server's default + case sql.LevelReadUncommitted: + mode = " ISOLATION LEVEL READ UNCOMMITTED" + case sql.LevelReadCommitted: + mode = " ISOLATION LEVEL READ COMMITTED" + case sql.LevelRepeatableRead: + mode = " ISOLATION LEVEL REPEATABLE READ" + case sql.LevelSerializable: + mode = " ISOLATION LEVEL SERIALIZABLE" + default: + return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation) + } + if opts.ReadOnly { + mode += " READ ONLY" + } else { + mode += " READ WRITE" + } -func (cn *conn) begin(mode string) (_ driver.Tx, err error) { if err := cn.err.get(); err != nil { return nil, err } @@ -503,17 +530,17 @@ func (cn *conn) begin(mode string) (_ driver.Tx, err error) { cn.err.set(driver.ErrBadConn) return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) } - return cn, nil -} -func (cn *conn) closeTxn() { - if finish := cn.txnFinish; finish != nil { - finish() - } + cn.txnFinish = cn.watchCancel(ctx, false) + return cn, nil } func (cn *conn) Commit() error { - defer cn.closeTxn() + defer func() { + if cn.txnFinish != nil { + cn.txnFinish() + } + }() if err := cn.err.get(); err != nil { return err } @@ -549,16 +576,17 @@ func (cn *conn) Commit() error { } func (cn *conn) Rollback() error { - defer cn.closeTxn() + defer func() { + if cn.txnFinish != nil { + cn.txnFinish() + } + }() if err := cn.err.get(); err != nil { return err } err := cn.rollback() - if err != nil { - return cn.handleError(err) - } - return nil + return cn.handleError(err) } func (cn *conn) rollback() (err error) { @@ -799,7 +827,9 @@ func (cn *conn) prepareTo(q, stmtName string) (*stmt, error) { return st, nil } -func (cn *conn) Prepare(q string) (driver.Stmt, error) { +// Implement [driver.ConnPrepareContext]. +func (cn *conn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) { + defer cn.watchCancel(ctx, false)() if err := cn.err.get(); err != nil { return nil, err } @@ -829,14 +859,6 @@ func (cn *conn) Close() error { return cn.c.Close() } -func toNamedValue(v []driver.Value) []driver.NamedValue { - v2 := make([]driver.NamedValue, len(v)) - for i := range v { - v2[i] = driver.NamedValue{Ordinal: i + 1, Value: v[i]} - } - return v2 -} - // CheckNamedValue implements [driver.NamedValueChecker]. func (cn *conn) CheckNamedValue(nv *driver.NamedValue) error { if cn.cfg.BinaryParameters { @@ -884,9 +906,18 @@ func (cn *conn) CheckNamedValue(nv *driver.NamedValue) error { } } -// Implement the "Queryer" interface -func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { - return cn.query(query, toNamedValue(args)) +// Implement [driver.QueryerContext]. +func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + finish := cn.watchCancel(ctx, false) + r, err := cn.query(query, args) + if err != nil { + if finish != nil { + finish() + } + return nil, err + } + r.finish = finish + return r, nil } func (cn *conn) query(query string, args []driver.NamedValue) (*rows, error) { @@ -947,8 +978,9 @@ func (cn *conn) query(query string, args []driver.NamedValue) (*rows, error) { }, nil } -// Implement the optional "Execer" interface for one-shot queries -func (cn *conn) Exec(query string, args []driver.Value) (driver.Result, error) { +// Implement [driver.ExecerContext]. +func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + defer cn.watchCancel(ctx, false)() if err := cn.err.get(); err != nil { return nil, err } @@ -956,16 +988,14 @@ func (cn *conn) Exec(query string, args []driver.Value) (driver.Result, error) { return nil, errQueryInProgress } - // Check to see if we can use the "simpleExec" interface, which is *much* - // faster than going through prepare/exec + // simpleExec is *much* faster than going through prepare/exec. if len(args) == 0 { - // ignore commandTag, our caller doesn't care - r, _, err := cn.simpleExec(query) + r, _, err := cn.simpleExec(query) // Ignore commandTag, our caller doesn't care. return r, cn.handleError(err, query) } if cn.cfg.BinaryParameters { - err := cn.sendBinaryModeQuery(query, toNamedValue(args)) + err := cn.sendBinaryModeQuery(query, args) if err != nil { return nil, cn.handleError(err, query) } @@ -996,13 +1026,23 @@ func (cn *conn) Exec(query string, args []driver.Value) (driver.Result, error) { if err != nil { return nil, cn.handleError(err, query) } - r, err := st.Exec(args) + r, err := st.ExecContext(ctx, args) if err != nil { return nil, cn.handleError(err, query) } return r, nil } +func (cn *conn) Ping(ctx context.Context) error { + defer cn.watchCancel(ctx, false)() + rows, err := cn.simpleQuery(";") + if err != nil { + return driver.ErrBadConn + } + _ = rows.Close() + return nil +} + type safeRetryError struct{ Err error } func (se *safeRetryError) Error() string { return se.Err.Error() } @@ -1179,6 +1219,78 @@ func (cn *conn) recv1() (proto.ResponseCode, *readBuf, error) { return t, r, nil } +// We need to let PostgreSQL know the query is cancelled: just dropping the +// connection won't stop the query. +// +// So create a goroutine which selects on ctx.Done() and a finish channel. +// Returns a function to send to this, which should be called after the query is +// finished. +func (cn *conn) watchCancel(ctx context.Context, fromStmt bool) func() { + if ctx.Done() == nil { // "may return nil if this context can never be canceled" + return func() {} + } + + finished := make(chan struct{}, 1) + go func() { + select { + case <-finished: // Query finished successfully. + case <-ctx.Done(): + select { + case finished <- struct{}{}: + default: // Raced with the finish func, let the next query handle this with the context. + return + } + if !fromStmt { + cn.err.set(ctx.Err()) // Set the connection state to bad so it does not get reused. + } + cn.sendCancelRequest() // TODO: maybe handle error, somehow? + } + }() + + return func() { + select { + case <-finished: + if !fromStmt { + cn.err.set(ctx.Err()) + cn.Close() + } + case finished <- struct{}{}: + } + } +} + +func (cn *conn) sendCancelRequest() error { + // Use a copy since a new connection is created here. This is necessary + // because cancel is called from a goroutine in watchCancel. + cfg := cn.cfg.Clone() + + // Can't pass in context from parent, as that one may be cancelled. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) // TODO: use connect_timeout? + defer cancel() + + c, err := dial(ctx, cn.dialer, cfg) + if err != nil { + return err + } + defer c.Close() + + cn2 := conn{c: c} + if err := cn2.ssl(cfg, cfg.SSLMode); err != nil { + return err + } + w := cn2.writeBuf(0) + w.int32(proto.CancelRequestCode) + w.int32(cn.pid) + w.bytes(cn.secretKey) + if err := cn2.sendStartupPacket(w); err != nil { + return err + } + + // Read until EOF to ensure that the server received the cancel. + _, err = io.Copy(io.Discard, c) + return err +} + // Don't refer to Config.SSLMode here, as the mode in arguments may be different // in case of sslmode=allow or prefer. func (cn *conn) ssl(cfg Config, mode SSLMode) error { diff --git a/conn_go18.go b/conn_go18.go deleted file mode 100644 index 16de38ebe..000000000 --- a/conn_go18.go +++ /dev/null @@ -1,226 +0,0 @@ -package pq - -import ( - "context" - "database/sql" - "database/sql/driver" - "fmt" - "io" - "time" - - "github.com/lib/pq/internal/proto" -) - -const watchCancelDialContextTimeout = 10 * time.Second - -// Implement the "QueryerContext" interface -func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - finish := cn.watchCancel(ctx) - r, err := cn.query(query, args) - if err != nil { - if finish != nil { - finish() - } - return nil, err - } - r.finish = finish - return r, nil -} - -// Implement the "ExecerContext" interface -func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - list := make([]driver.Value, len(args)) - for i, nv := range args { - list[i] = nv.Value - } - - if finish := cn.watchCancel(ctx); finish != nil { - defer finish() - } - - return cn.Exec(query, list) -} - -// Implement the "ConnPrepareContext" interface -func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - if finish := cn.watchCancel(ctx); finish != nil { - defer finish() - } - return cn.Prepare(query) -} - -// Implement the "ConnBeginTx" interface -func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - var mode string - switch sql.IsolationLevel(opts.Isolation) { - case sql.LevelDefault: - // Don't touch mode: use the server's default - case sql.LevelReadUncommitted: - mode = " ISOLATION LEVEL READ UNCOMMITTED" - case sql.LevelReadCommitted: - mode = " ISOLATION LEVEL READ COMMITTED" - case sql.LevelRepeatableRead: - mode = " ISOLATION LEVEL REPEATABLE READ" - case sql.LevelSerializable: - mode = " ISOLATION LEVEL SERIALIZABLE" - default: - return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation) - } - if opts.ReadOnly { - mode += " READ ONLY" - } else { - mode += " READ WRITE" - } - - tx, err := cn.begin(mode) - if err != nil { - return nil, err - } - cn.txnFinish = cn.watchCancel(ctx) - return tx, nil -} - -func (cn *conn) Ping(ctx context.Context) error { - if finish := cn.watchCancel(ctx); finish != nil { - defer finish() - } - rows, err := cn.simpleQuery(";") - if err != nil { - return driver.ErrBadConn - } - _ = rows.Close() - return nil -} - -func (cn *conn) watchCancel(ctx context.Context) func() { - if done := ctx.Done(); done != nil { - finished := make(chan struct{}, 1) - go func() { - select { - case <-done: - select { - case finished <- struct{}{}: - default: - // We raced with the finish func, let the next query handle this with the - // context. - return - } - - // Set the connection state to bad so it does not get reused. - cn.err.set(ctx.Err()) - - // At this point the function level context is canceled, - // so it must not be used for the additional network - // request to cancel the query. - // Create a new context to pass into the dial. - ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout) - defer cancel() - - _ = cn.cancel(ctxCancel) - case <-finished: - } - }() - return func() { - select { - case <-finished: - cn.err.set(ctx.Err()) - _ = cn.Close() - case finished <- struct{}{}: - } - } - } - return nil -} - -func (cn *conn) cancel(ctx context.Context) error { - // Use a copy since a new connection is created here. This is necessary - // because cancel is called from a goroutine in watchCancel. - cfg := cn.cfg.Clone() - - c, err := dial(ctx, cn.dialer, cfg) - if err != nil { - return err - } - defer func() { _ = c.Close() }() - - cn2 := conn{c: c} - err = cn2.ssl(cfg, cfg.SSLMode) - if err != nil { - return err - } - - w := cn2.writeBuf(0) - w.int32(proto.CancelRequestCode) - w.int32(cn.pid) - w.bytes(cn.secretKey) - if err := cn2.sendStartupPacket(w); err != nil { - return err - } - - // Read until EOF to ensure that the server received the cancel. - _, err = io.Copy(io.Discard, c) - return err -} - -// Implement the "StmtQueryContext" interface -func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - finish := st.watchCancel(ctx) - r, err := st.query(args) - if err != nil { - if finish != nil { - finish() - } - return nil, err - } - r.finish = finish - return r, nil -} - -// Implement the "StmtExecContext" interface -func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - if finish := st.watchCancel(ctx); finish != nil { - defer finish() - } - if err := st.cn.err.get(); err != nil { - return nil, err - } - - err := st.exec(args) - if err != nil { - return nil, st.cn.handleError(err) - } - res, _, err := st.cn.readExecuteResponse("simple query") - return res, st.cn.handleError(err) -} - -// watchCancel is implemented on stmt in order to not mark the parent conn as bad -func (st *stmt) watchCancel(ctx context.Context) func() { - if done := ctx.Done(); done != nil { - finished := make(chan struct{}) - go func() { - select { - case <-done: - // At this point the function level context is canceled, so it - // must not be used for the additional network request to cancel - // the query. Create a new context to pass into the dial. - ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout) - defer cancel() - - _ = st.cancel(ctxCancel) - finished <- struct{}{} - case <-finished: - } - }() - return func() { - select { - case <-finished: - case finished <- struct{}{}: - } - } - } - return nil -} - -func (st *stmt) cancel(ctx context.Context) error { - return st.cn.cancel(ctx) -} diff --git a/connector_test.go b/connector_test.go index b9699b597..792e1a127 100644 --- a/connector_test.go +++ b/connector_test.go @@ -34,8 +34,8 @@ func TestNewConnector(t *testing.T) { t.Fatal(err) } tx.Rollback() - case driver.Conn: - tx, err := db.Begin() //lint:ignore SA1019 x + case driver.ConnBeginTx: + tx, err := db.BeginTx(context.Background(), driver.TxOptions{}) if err != nil { t.Fatal(err) } diff --git a/copy.go b/copy.go index a7c73e011..6a153b2eb 100644 --- a/copy.go +++ b/copy.go @@ -281,9 +281,7 @@ func (ci *copyin) CopyData(ctx context.Context, line string) (driver.Result, err if ci.closed { return nil, errCopyInClosed } - if finish := ci.cn.watchCancel(ctx); finish != nil { - defer finish() - } + defer ci.cn.watchCancel(ctx, false)() if err := ci.getBad(); err != nil { return nil, err } diff --git a/deprecated.go b/deprecated.go index d43934a0a..861076777 100644 --- a/deprecated.go +++ b/deprecated.go @@ -3,10 +3,17 @@ package pq import ( "bytes" "database/sql" + "database/sql/driver" "github.com/lib/pq/pqerror" ) +// Never called, but need to retain them for interface compatibility. +func (cn *conn) Prepare(q string) (driver.Stmt, error) { panic("conn.Prepare") } +func (cn *conn) Begin() (driver.Tx, error) { panic("conn.Begin") } +func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { panic("stmt.Query") } +func (st *stmt) Exec(v []driver.Value) (driver.Result, error) { panic("stmt.Exec") } + // [pq.Error.Severity] values. // // Deprecated: use pqerror.Severity[..] values. diff --git a/rows.go b/rows.go index d87e6767e..10c1dbd7d 100644 --- a/rows.go +++ b/rows.go @@ -41,8 +41,8 @@ type ( ) func (rs *rows) Close() error { - if finish := rs.finish; finish != nil { - defer finish() + if rs.finish != nil { + defer rs.finish() } // no need to look at cn.bad as Next() will for { diff --git a/stmt.go b/stmt.go index ca6ecc896..2653efd35 100644 --- a/stmt.go +++ b/stmt.go @@ -62,27 +62,39 @@ func (st *stmt) Close() error { return nil } -func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { - return st.query(toNamedValue(v)) -} - -func (st *stmt) query(v []driver.NamedValue) (*rows, error) { +// Implement [driver.StmtQueryContext]. +func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + finish := st.cn.watchCancel(ctx, true) if err := st.cn.err.get(); err != nil { return nil, err } - err := st.exec(v) + err := st.exec(args) if err != nil { + finish() return nil, st.cn.handleError(err) } + return &rows{ cn: st.cn, rowsHeader: st.rowsHeader, + finish: finish, }, nil } -func (st *stmt) Exec(v []driver.Value) (driver.Result, error) { - return st.ExecContext(context.Background(), toNamedValue(v)) +// Implement [driver.StmtExecContext]. +func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + defer st.cn.watchCancel(ctx, true)() + if err := st.cn.err.get(); err != nil { + return nil, err + } + + err := st.exec(args) + if err != nil { + return nil, st.cn.handleError(err) + } + res, _, err := st.cn.readExecuteResponse("simple query") + return res, st.cn.handleError(err) } func (st *stmt) exec(v []driver.NamedValue) error {