Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 161 additions & 49 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,28 @@ var (

// Compile time validation that our types implement the expected interfaces
var (
_ driver.Driver = Driver{}
_ driver.ConnBeginTx = (*conn)(nil)
_ driver.ConnPrepareContext = (*conn)(nil)
_ driver.Execer = (*conn)(nil) //lint:ignore SA1019 x
_ driver.ExecerContext = (*conn)(nil)
_ driver.NamedValueChecker = (*conn)(nil)
_ driver.Pinger = (*conn)(nil)
_ driver.Queryer = (*conn)(nil) //lint:ignore SA1019 x
_ driver.QueryerContext = (*conn)(nil)
_ driver.SessionResetter = (*conn)(nil)
_ driver.Validator = (*conn)(nil)
_ driver.StmtExecContext = (*stmt)(nil)
_ driver.StmtQueryContext = (*stmt)(nil)
_ driver.Driver = Driver{}
//_ driver.DriverContext = Driver{} // TODO: https://github.com/lib/pq/pull/900
_ driver.Connector = (*Connector)(nil)
_ driver.Conn = (*conn)(nil)
_ driver.ConnBeginTx = (*conn)(nil)
_ driver.ConnPrepareContext = (*conn)(nil)
_ driver.ExecerContext = (*conn)(nil)
_ driver.NamedValueChecker = (*conn)(nil)
_ driver.Pinger = (*conn)(nil)
_ driver.QueryerContext = (*conn)(nil)
_ driver.SessionResetter = (*conn)(nil)
_ driver.Validator = (*conn)(nil)
_ driver.Stmt = (*stmt)(nil)
_ driver.StmtExecContext = (*stmt)(nil)
_ driver.StmtQueryContext = (*stmt)(nil)
_ driver.Rows = (*rows)(nil)
_ driver.RowsColumnTypeDatabaseTypeName = (*rows)(nil)
_ driver.RowsColumnTypeLength = (*rows)(nil)
//_ driver.RowsColumnTypeNullable = (*rows)(nil) // TODO
_ driver.RowsColumnTypePrecisionScale = (*rows)(nil)
_ driver.RowsColumnTypeScanType = (*rows)(nil)
_ driver.RowsNextResultSet = (*rows)(nil)
)

func init() {
Expand Down Expand Up @@ -479,11 +488,29 @@ func (cn *conn) checkIsInTransaction(intxn bool) error {
return nil
}

func (cn *conn) Begin() (_ driver.Tx, err error) {
return cn.begin("")
}
// Implement [driver.ConnBeginTx].
func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
var mode string
switch sql.IsolationLevel(opts.Isolation) {
case sql.LevelDefault:
// Don't touch mode: use the server's default
case sql.LevelReadUncommitted:
mode = " ISOLATION LEVEL READ UNCOMMITTED"
case sql.LevelReadCommitted:
mode = " ISOLATION LEVEL READ COMMITTED"
case sql.LevelRepeatableRead:
mode = " ISOLATION LEVEL REPEATABLE READ"
case sql.LevelSerializable:
mode = " ISOLATION LEVEL SERIALIZABLE"
default:
return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation)
}
if opts.ReadOnly {
mode += " READ ONLY"
} else {
mode += " READ WRITE"
}

func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
if err := cn.err.get(); err != nil {
return nil, err
}
Expand All @@ -503,17 +530,17 @@ func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
cn.err.set(driver.ErrBadConn)
return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
}
return cn, nil
}

func (cn *conn) closeTxn() {
if finish := cn.txnFinish; finish != nil {
finish()
}
cn.txnFinish = cn.watchCancel(ctx, false)
return cn, nil
}

func (cn *conn) Commit() error {
defer cn.closeTxn()
defer func() {
if cn.txnFinish != nil {
cn.txnFinish()
}
}()
if err := cn.err.get(); err != nil {
return err
}
Expand Down Expand Up @@ -549,16 +576,17 @@ func (cn *conn) Commit() error {
}

func (cn *conn) Rollback() error {
defer cn.closeTxn()
defer func() {
if cn.txnFinish != nil {
cn.txnFinish()
}
}()
if err := cn.err.get(); err != nil {
return err
}

err := cn.rollback()
if err != nil {
return cn.handleError(err)
}
return nil
return cn.handleError(err)
}

func (cn *conn) rollback() (err error) {
Expand Down Expand Up @@ -799,7 +827,9 @@ func (cn *conn) prepareTo(q, stmtName string) (*stmt, error) {
return st, nil
}

func (cn *conn) Prepare(q string) (driver.Stmt, error) {
// Implement [driver.ConnPrepareContext].
func (cn *conn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) {
defer cn.watchCancel(ctx, false)()
if err := cn.err.get(); err != nil {
return nil, err
}
Expand Down Expand Up @@ -829,14 +859,6 @@ func (cn *conn) Close() error {
return cn.c.Close()
}

func toNamedValue(v []driver.Value) []driver.NamedValue {
v2 := make([]driver.NamedValue, len(v))
for i := range v {
v2[i] = driver.NamedValue{Ordinal: i + 1, Value: v[i]}
}
return v2
}

// CheckNamedValue implements [driver.NamedValueChecker].
func (cn *conn) CheckNamedValue(nv *driver.NamedValue) error {
if cn.cfg.BinaryParameters {
Expand Down Expand Up @@ -884,9 +906,18 @@ func (cn *conn) CheckNamedValue(nv *driver.NamedValue) error {
}
}

// Implement the "Queryer" interface
func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
return cn.query(query, toNamedValue(args))
// Implement [driver.QueryerContext].
func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
finish := cn.watchCancel(ctx, false)
r, err := cn.query(query, args)
if err != nil {
if finish != nil {
finish()
}
return nil, err
}
r.finish = finish
return r, nil
}

func (cn *conn) query(query string, args []driver.NamedValue) (*rows, error) {
Expand Down Expand Up @@ -947,25 +978,24 @@ func (cn *conn) query(query string, args []driver.NamedValue) (*rows, error) {
}, nil
}

// Implement the optional "Execer" interface for one-shot queries
func (cn *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
// Implement [driver.ExecerContext].
func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
defer cn.watchCancel(ctx, false)()
if err := cn.err.get(); err != nil {
return nil, err
}
if !cn.inProgress.CompareAndSwap(false, true) {
return nil, errQueryInProgress
}

// Check to see if we can use the "simpleExec" interface, which is *much*
// faster than going through prepare/exec
// simpleExec is *much* faster than going through prepare/exec.
if len(args) == 0 {
// ignore commandTag, our caller doesn't care
r, _, err := cn.simpleExec(query)
r, _, err := cn.simpleExec(query) // Ignore commandTag, our caller doesn't care.
return r, cn.handleError(err, query)
}

if cn.cfg.BinaryParameters {
err := cn.sendBinaryModeQuery(query, toNamedValue(args))
err := cn.sendBinaryModeQuery(query, args)
if err != nil {
return nil, cn.handleError(err, query)
}
Expand Down Expand Up @@ -996,13 +1026,23 @@ func (cn *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
if err != nil {
return nil, cn.handleError(err, query)
}
r, err := st.Exec(args)
r, err := st.ExecContext(ctx, args)
if err != nil {
return nil, cn.handleError(err, query)
}
return r, nil
}

func (cn *conn) Ping(ctx context.Context) error {
defer cn.watchCancel(ctx, false)()
rows, err := cn.simpleQuery(";")
if err != nil {
return driver.ErrBadConn
}
_ = rows.Close()
return nil
}

type safeRetryError struct{ Err error }

func (se *safeRetryError) Error() string { return se.Err.Error() }
Expand Down Expand Up @@ -1179,6 +1219,78 @@ func (cn *conn) recv1() (proto.ResponseCode, *readBuf, error) {
return t, r, nil
}

// We need to let PostgreSQL know the query is cancelled: just dropping the
// connection won't stop the query.
//
// So create a goroutine which selects on ctx.Done() and a finish channel.
// Returns a function to send to this, which should be called after the query is
// finished.
func (cn *conn) watchCancel(ctx context.Context, fromStmt bool) func() {
if ctx.Done() == nil { // "may return nil if this context can never be canceled"
return func() {}
}

finished := make(chan struct{}, 1)
go func() {
select {
case <-finished: // Query finished successfully.
case <-ctx.Done():
select {
case finished <- struct{}{}:
default: // Raced with the finish func, let the next query handle this with the context.
return
}
if !fromStmt {
cn.err.set(ctx.Err()) // Set the connection state to bad so it does not get reused.
}
cn.sendCancelRequest() // TODO: maybe handle error, somehow?
}
}()

return func() {
select {
case <-finished:
if !fromStmt {
cn.err.set(ctx.Err())
cn.Close()
}
case finished <- struct{}{}:
}
}
}

func (cn *conn) sendCancelRequest() error {
// Use a copy since a new connection is created here. This is necessary
// because cancel is called from a goroutine in watchCancel.
cfg := cn.cfg.Clone()

// Can't pass in context from parent, as that one may be cancelled.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) // TODO: use connect_timeout?
defer cancel()

c, err := dial(ctx, cn.dialer, cfg)
if err != nil {
return err
}
defer c.Close()

cn2 := conn{c: c}
if err := cn2.ssl(cfg, cfg.SSLMode); err != nil {
return err
}
w := cn2.writeBuf(0)
w.int32(proto.CancelRequestCode)
w.int32(cn.pid)
w.bytes(cn.secretKey)
if err := cn2.sendStartupPacket(w); err != nil {
return err
}

// Read until EOF to ensure that the server received the cancel.
_, err = io.Copy(io.Discard, c)
return err
}

// Don't refer to Config.SSLMode here, as the mode in arguments may be different
// in case of sslmode=allow or prefer.
func (cn *conn) ssl(cfg Config, mode SSLMode) error {
Expand Down
Loading