diff --git a/conn.go b/conn.go index 667c5fa9..3ed7ecfa 100644 --- a/conn.go +++ b/conn.go @@ -47,19 +47,28 @@ var ( // Compile time validation that our types implement the expected interfaces var ( - _ driver.Driver = Driver{} - _ driver.ConnBeginTx = (*conn)(nil) - _ driver.ConnPrepareContext = (*conn)(nil) - _ driver.Execer = (*conn)(nil) //lint:ignore SA1019 x - _ driver.ExecerContext = (*conn)(nil) - _ driver.NamedValueChecker = (*conn)(nil) - _ driver.Pinger = (*conn)(nil) - _ driver.Queryer = (*conn)(nil) //lint:ignore SA1019 x - _ driver.QueryerContext = (*conn)(nil) - _ driver.SessionResetter = (*conn)(nil) - _ driver.Validator = (*conn)(nil) - _ driver.StmtExecContext = (*stmt)(nil) - _ driver.StmtQueryContext = (*stmt)(nil) + _ driver.Driver = Driver{} + //_ driver.DriverContext = Driver{} // TODO: https://github.com/lib/pq/pull/900 + _ driver.Connector = (*Connector)(nil) + _ driver.Conn = (*conn)(nil) + _ driver.ConnBeginTx = (*conn)(nil) + _ driver.ConnPrepareContext = (*conn)(nil) + _ driver.ExecerContext = (*conn)(nil) + _ driver.NamedValueChecker = (*conn)(nil) + _ driver.Pinger = (*conn)(nil) + _ driver.QueryerContext = (*conn)(nil) + _ driver.SessionResetter = (*conn)(nil) + _ driver.Validator = (*conn)(nil) + _ driver.Stmt = (*stmt)(nil) + _ driver.StmtExecContext = (*stmt)(nil) + _ driver.StmtQueryContext = (*stmt)(nil) + _ driver.Rows = (*rows)(nil) + _ driver.RowsColumnTypeDatabaseTypeName = (*rows)(nil) + _ driver.RowsColumnTypeLength = (*rows)(nil) + //_ driver.RowsColumnTypeNullable = (*rows)(nil) // TODO + _ driver.RowsColumnTypePrecisionScale = (*rows)(nil) + _ driver.RowsColumnTypeScanType = (*rows)(nil) + _ driver.RowsNextResultSet = (*rows)(nil) ) func init() { @@ -479,11 +488,29 @@ func (cn *conn) checkIsInTransaction(intxn bool) error { return nil } -func (cn *conn) Begin() (_ driver.Tx, err error) { - return cn.begin("") -} +// Implement [driver.ConnBeginTx]. +func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + var mode string + switch sql.IsolationLevel(opts.Isolation) { + case sql.LevelDefault: + // Don't touch mode: use the server's default + case sql.LevelReadUncommitted: + mode = " ISOLATION LEVEL READ UNCOMMITTED" + case sql.LevelReadCommitted: + mode = " ISOLATION LEVEL READ COMMITTED" + case sql.LevelRepeatableRead: + mode = " ISOLATION LEVEL REPEATABLE READ" + case sql.LevelSerializable: + mode = " ISOLATION LEVEL SERIALIZABLE" + default: + return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation) + } + if opts.ReadOnly { + mode += " READ ONLY" + } else { + mode += " READ WRITE" + } -func (cn *conn) begin(mode string) (_ driver.Tx, err error) { if err := cn.err.get(); err != nil { return nil, err } @@ -503,17 +530,17 @@ func (cn *conn) begin(mode string) (_ driver.Tx, err error) { cn.err.set(driver.ErrBadConn) return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) } - return cn, nil -} -func (cn *conn) closeTxn() { - if finish := cn.txnFinish; finish != nil { - finish() - } + cn.txnFinish = cn.watchCancel(ctx, false) + return cn, nil } func (cn *conn) Commit() error { - defer cn.closeTxn() + defer func() { + if cn.txnFinish != nil { + cn.txnFinish() + } + }() if err := cn.err.get(); err != nil { return err } @@ -549,16 +576,17 @@ func (cn *conn) Commit() error { } func (cn *conn) Rollback() error { - defer cn.closeTxn() + defer func() { + if cn.txnFinish != nil { + cn.txnFinish() + } + }() if err := cn.err.get(); err != nil { return err } err := cn.rollback() - if err != nil { - return cn.handleError(err) - } - return nil + return cn.handleError(err) } func (cn *conn) rollback() (err error) { @@ -799,7 +827,9 @@ func (cn *conn) prepareTo(q, stmtName string) (*stmt, error) { return st, nil } -func (cn *conn) Prepare(q string) (driver.Stmt, error) { +// Implement [driver.ConnPrepareContext]. +func (cn *conn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) { + defer cn.watchCancel(ctx, false)() if err := cn.err.get(); err != nil { return nil, err } @@ -829,14 +859,6 @@ func (cn *conn) Close() error { return cn.c.Close() } -func toNamedValue(v []driver.Value) []driver.NamedValue { - v2 := make([]driver.NamedValue, len(v)) - for i := range v { - v2[i] = driver.NamedValue{Ordinal: i + 1, Value: v[i]} - } - return v2 -} - // CheckNamedValue implements [driver.NamedValueChecker]. func (cn *conn) CheckNamedValue(nv *driver.NamedValue) error { if cn.cfg.BinaryParameters { @@ -884,9 +906,18 @@ func (cn *conn) CheckNamedValue(nv *driver.NamedValue) error { } } -// Implement the "Queryer" interface -func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { - return cn.query(query, toNamedValue(args)) +// Implement [driver.QueryerContext]. +func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + finish := cn.watchCancel(ctx, false) + r, err := cn.query(query, args) + if err != nil { + if finish != nil { + finish() + } + return nil, err + } + r.finish = finish + return r, nil } func (cn *conn) query(query string, args []driver.NamedValue) (*rows, error) { @@ -947,8 +978,9 @@ func (cn *conn) query(query string, args []driver.NamedValue) (*rows, error) { }, nil } -// Implement the optional "Execer" interface for one-shot queries -func (cn *conn) Exec(query string, args []driver.Value) (driver.Result, error) { +// Implement [driver.ExecerContext]. +func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + defer cn.watchCancel(ctx, false)() if err := cn.err.get(); err != nil { return nil, err } @@ -956,16 +988,14 @@ func (cn *conn) Exec(query string, args []driver.Value) (driver.Result, error) { return nil, errQueryInProgress } - // Check to see if we can use the "simpleExec" interface, which is *much* - // faster than going through prepare/exec + // simpleExec is *much* faster than going through prepare/exec. if len(args) == 0 { - // ignore commandTag, our caller doesn't care - r, _, err := cn.simpleExec(query) + r, _, err := cn.simpleExec(query) // Ignore commandTag, our caller doesn't care. return r, cn.handleError(err, query) } if cn.cfg.BinaryParameters { - err := cn.sendBinaryModeQuery(query, toNamedValue(args)) + err := cn.sendBinaryModeQuery(query, args) if err != nil { return nil, cn.handleError(err, query) } @@ -996,13 +1026,23 @@ func (cn *conn) Exec(query string, args []driver.Value) (driver.Result, error) { if err != nil { return nil, cn.handleError(err, query) } - r, err := st.Exec(args) + r, err := st.ExecContext(ctx, args) if err != nil { return nil, cn.handleError(err, query) } return r, nil } +func (cn *conn) Ping(ctx context.Context) error { + defer cn.watchCancel(ctx, false)() + rows, err := cn.simpleQuery(";") + if err != nil { + return driver.ErrBadConn + } + _ = rows.Close() + return nil +} + type safeRetryError struct{ Err error } func (se *safeRetryError) Error() string { return se.Err.Error() } @@ -1179,6 +1219,78 @@ func (cn *conn) recv1() (proto.ResponseCode, *readBuf, error) { return t, r, nil } +// We need to let PostgreSQL know the query is cancelled: just dropping the +// connection won't stop the query. +// +// So create a goroutine which selects on ctx.Done() and a finish channel. +// Returns a function to send to this, which should be called after the query is +// finished. +func (cn *conn) watchCancel(ctx context.Context, fromStmt bool) func() { + if ctx.Done() == nil { // "may return nil if this context can never be canceled" + return func() {} + } + + finished := make(chan struct{}, 1) + go func() { + select { + case <-finished: // Query finished successfully. + case <-ctx.Done(): + select { + case finished <- struct{}{}: + default: // Raced with the finish func, let the next query handle this with the context. + return + } + if !fromStmt { + cn.err.set(ctx.Err()) // Set the connection state to bad so it does not get reused. + } + cn.sendCancelRequest() // TODO: maybe handle error, somehow? + } + }() + + return func() { + select { + case <-finished: + if !fromStmt { + cn.err.set(ctx.Err()) + cn.Close() + } + case finished <- struct{}{}: + } + } +} + +func (cn *conn) sendCancelRequest() error { + // Use a copy since a new connection is created here. This is necessary + // because cancel is called from a goroutine in watchCancel. + cfg := cn.cfg.Clone() + + // Can't pass in context from parent, as that one may be cancelled. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) // TODO: use connect_timeout? + defer cancel() + + c, err := dial(ctx, cn.dialer, cfg) + if err != nil { + return err + } + defer c.Close() + + cn2 := conn{c: c} + if err := cn2.ssl(cfg, cfg.SSLMode); err != nil { + return err + } + w := cn2.writeBuf(0) + w.int32(proto.CancelRequestCode) + w.int32(cn.pid) + w.bytes(cn.secretKey) + if err := cn2.sendStartupPacket(w); err != nil { + return err + } + + // Read until EOF to ensure that the server received the cancel. + _, err = io.Copy(io.Discard, c) + return err +} + // Don't refer to Config.SSLMode here, as the mode in arguments may be different // in case of sslmode=allow or prefer. func (cn *conn) ssl(cfg Config, mode SSLMode) error { diff --git a/conn_go18.go b/conn_go18.go deleted file mode 100644 index 16de38eb..00000000 --- a/conn_go18.go +++ /dev/null @@ -1,226 +0,0 @@ -package pq - -import ( - "context" - "database/sql" - "database/sql/driver" - "fmt" - "io" - "time" - - "github.com/lib/pq/internal/proto" -) - -const watchCancelDialContextTimeout = 10 * time.Second - -// Implement the "QueryerContext" interface -func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - finish := cn.watchCancel(ctx) - r, err := cn.query(query, args) - if err != nil { - if finish != nil { - finish() - } - return nil, err - } - r.finish = finish - return r, nil -} - -// Implement the "ExecerContext" interface -func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - list := make([]driver.Value, len(args)) - for i, nv := range args { - list[i] = nv.Value - } - - if finish := cn.watchCancel(ctx); finish != nil { - defer finish() - } - - return cn.Exec(query, list) -} - -// Implement the "ConnPrepareContext" interface -func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - if finish := cn.watchCancel(ctx); finish != nil { - defer finish() - } - return cn.Prepare(query) -} - -// Implement the "ConnBeginTx" interface -func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - var mode string - switch sql.IsolationLevel(opts.Isolation) { - case sql.LevelDefault: - // Don't touch mode: use the server's default - case sql.LevelReadUncommitted: - mode = " ISOLATION LEVEL READ UNCOMMITTED" - case sql.LevelReadCommitted: - mode = " ISOLATION LEVEL READ COMMITTED" - case sql.LevelRepeatableRead: - mode = " ISOLATION LEVEL REPEATABLE READ" - case sql.LevelSerializable: - mode = " ISOLATION LEVEL SERIALIZABLE" - default: - return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation) - } - if opts.ReadOnly { - mode += " READ ONLY" - } else { - mode += " READ WRITE" - } - - tx, err := cn.begin(mode) - if err != nil { - return nil, err - } - cn.txnFinish = cn.watchCancel(ctx) - return tx, nil -} - -func (cn *conn) Ping(ctx context.Context) error { - if finish := cn.watchCancel(ctx); finish != nil { - defer finish() - } - rows, err := cn.simpleQuery(";") - if err != nil { - return driver.ErrBadConn - } - _ = rows.Close() - return nil -} - -func (cn *conn) watchCancel(ctx context.Context) func() { - if done := ctx.Done(); done != nil { - finished := make(chan struct{}, 1) - go func() { - select { - case <-done: - select { - case finished <- struct{}{}: - default: - // We raced with the finish func, let the next query handle this with the - // context. - return - } - - // Set the connection state to bad so it does not get reused. - cn.err.set(ctx.Err()) - - // At this point the function level context is canceled, - // so it must not be used for the additional network - // request to cancel the query. - // Create a new context to pass into the dial. - ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout) - defer cancel() - - _ = cn.cancel(ctxCancel) - case <-finished: - } - }() - return func() { - select { - case <-finished: - cn.err.set(ctx.Err()) - _ = cn.Close() - case finished <- struct{}{}: - } - } - } - return nil -} - -func (cn *conn) cancel(ctx context.Context) error { - // Use a copy since a new connection is created here. This is necessary - // because cancel is called from a goroutine in watchCancel. - cfg := cn.cfg.Clone() - - c, err := dial(ctx, cn.dialer, cfg) - if err != nil { - return err - } - defer func() { _ = c.Close() }() - - cn2 := conn{c: c} - err = cn2.ssl(cfg, cfg.SSLMode) - if err != nil { - return err - } - - w := cn2.writeBuf(0) - w.int32(proto.CancelRequestCode) - w.int32(cn.pid) - w.bytes(cn.secretKey) - if err := cn2.sendStartupPacket(w); err != nil { - return err - } - - // Read until EOF to ensure that the server received the cancel. - _, err = io.Copy(io.Discard, c) - return err -} - -// Implement the "StmtQueryContext" interface -func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - finish := st.watchCancel(ctx) - r, err := st.query(args) - if err != nil { - if finish != nil { - finish() - } - return nil, err - } - r.finish = finish - return r, nil -} - -// Implement the "StmtExecContext" interface -func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - if finish := st.watchCancel(ctx); finish != nil { - defer finish() - } - if err := st.cn.err.get(); err != nil { - return nil, err - } - - err := st.exec(args) - if err != nil { - return nil, st.cn.handleError(err) - } - res, _, err := st.cn.readExecuteResponse("simple query") - return res, st.cn.handleError(err) -} - -// watchCancel is implemented on stmt in order to not mark the parent conn as bad -func (st *stmt) watchCancel(ctx context.Context) func() { - if done := ctx.Done(); done != nil { - finished := make(chan struct{}) - go func() { - select { - case <-done: - // At this point the function level context is canceled, so it - // must not be used for the additional network request to cancel - // the query. Create a new context to pass into the dial. - ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout) - defer cancel() - - _ = st.cancel(ctxCancel) - finished <- struct{}{} - case <-finished: - } - }() - return func() { - select { - case <-finished: - case finished <- struct{}{}: - } - } - } - return nil -} - -func (st *stmt) cancel(ctx context.Context) error { - return st.cn.cancel(ctx) -} diff --git a/connector_test.go b/connector_test.go index b9699b59..792e1a12 100644 --- a/connector_test.go +++ b/connector_test.go @@ -34,8 +34,8 @@ func TestNewConnector(t *testing.T) { t.Fatal(err) } tx.Rollback() - case driver.Conn: - tx, err := db.Begin() //lint:ignore SA1019 x + case driver.ConnBeginTx: + tx, err := db.BeginTx(context.Background(), driver.TxOptions{}) if err != nil { t.Fatal(err) } diff --git a/copy.go b/copy.go index a7c73e01..6a153b2e 100644 --- a/copy.go +++ b/copy.go @@ -281,9 +281,7 @@ func (ci *copyin) CopyData(ctx context.Context, line string) (driver.Result, err if ci.closed { return nil, errCopyInClosed } - if finish := ci.cn.watchCancel(ctx); finish != nil { - defer finish() - } + defer ci.cn.watchCancel(ctx, false)() if err := ci.getBad(); err != nil { return nil, err } diff --git a/deprecated.go b/deprecated.go index d43934a0..86107677 100644 --- a/deprecated.go +++ b/deprecated.go @@ -3,10 +3,17 @@ package pq import ( "bytes" "database/sql" + "database/sql/driver" "github.com/lib/pq/pqerror" ) +// Never called, but need to retain them for interface compatibility. +func (cn *conn) Prepare(q string) (driver.Stmt, error) { panic("conn.Prepare") } +func (cn *conn) Begin() (driver.Tx, error) { panic("conn.Begin") } +func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { panic("stmt.Query") } +func (st *stmt) Exec(v []driver.Value) (driver.Result, error) { panic("stmt.Exec") } + // [pq.Error.Severity] values. // // Deprecated: use pqerror.Severity[..] values. diff --git a/rows.go b/rows.go index d87e6767..10c1dbd7 100644 --- a/rows.go +++ b/rows.go @@ -41,8 +41,8 @@ type ( ) func (rs *rows) Close() error { - if finish := rs.finish; finish != nil { - defer finish() + if rs.finish != nil { + defer rs.finish() } // no need to look at cn.bad as Next() will for { diff --git a/stmt.go b/stmt.go index ca6ecc89..2653efd3 100644 --- a/stmt.go +++ b/stmt.go @@ -62,27 +62,39 @@ func (st *stmt) Close() error { return nil } -func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { - return st.query(toNamedValue(v)) -} - -func (st *stmt) query(v []driver.NamedValue) (*rows, error) { +// Implement [driver.StmtQueryContext]. +func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + finish := st.cn.watchCancel(ctx, true) if err := st.cn.err.get(); err != nil { return nil, err } - err := st.exec(v) + err := st.exec(args) if err != nil { + finish() return nil, st.cn.handleError(err) } + return &rows{ cn: st.cn, rowsHeader: st.rowsHeader, + finish: finish, }, nil } -func (st *stmt) Exec(v []driver.Value) (driver.Result, error) { - return st.ExecContext(context.Background(), toNamedValue(v)) +// Implement [driver.StmtExecContext]. +func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + defer st.cn.watchCancel(ctx, true)() + if err := st.cn.err.get(); err != nil { + return nil, err + } + + err := st.exec(args) + if err != nil { + return nil, st.cn.handleError(err) + } + res, _, err := st.cn.readExecuteResponse("simple query") + return res, st.cn.handleError(err) } func (st *stmt) exec(v []driver.NamedValue) error {