diff --git a/conn.go b/conn.go index 012c8c7c..59e3bd58 100644 --- a/conn.go +++ b/conn.go @@ -600,7 +600,7 @@ func (cn *conn) gname() string { func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) { b := cn.writeBuf('Q') b.string(q) - cn.send(b) + cn.mustSendRetryable(b) for { t, r := cn.recv1() @@ -632,7 +632,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { b := cn.writeBuf('Q') b.string(q) - cn.send(b) + cn.mustSendRetryable(b) for { t, r := cn.recv1() @@ -765,7 +765,7 @@ func (cn *conn) prepareTo(q, stmtName string) *stmt { b.string(st.name) b.next('S') - cn.send(b) + cn.mustSendRetryable(b) cn.readParseResponse() st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse() @@ -882,13 +882,29 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err return r, err } -func (cn *conn) send(m *writeBuf) { - _, err := cn.c.Write(m.wrap()) +func (cn *conn) send(m *writeBuf) (int, error) { + return cn.c.Write(m.wrap()) +} + +func (cn *conn) mustSend(m *writeBuf) { + _, err := cn.send(m) if err != nil { panic(err) } } +func (cn *conn) mustSendRetryable(m *writeBuf) { + sentBytes, err := cn.send(m) + if err != nil { + if _, ok := err.(*net.OpError); ok { + if sentBytes == 0 { + err = &netErrorNoWrite{err} + } + } + panic(err) + } +} + func (cn *conn) sendStartupPacket(m *writeBuf) error { _, err := cn.c.Write((m.wrap())[1:]) return err @@ -1109,7 +1125,7 @@ func (cn *conn) auth(r *readBuf, o values) { case 3: w := cn.writeBuf('p') w.string(o["password"]) - cn.send(w) + cn.mustSend(w) t, r := cn.recv() if t != 'R' { @@ -1123,7 +1139,7 @@ func (cn *conn) auth(r *readBuf, o values) { s := string(r.next(4)) w := cn.writeBuf('p') w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) - cn.send(w) + cn.mustSend(w) t, r := cn.recv() if t != 'R' { @@ -1145,7 +1161,7 @@ func (cn *conn) auth(r *readBuf, o values) { w.string("SCRAM-SHA-256") w.int32(len(scOut)) w.bytes(scOut) - cn.send(w) + cn.mustSend(w) t, r := cn.recv() if t != 'R' { @@ -1165,7 +1181,7 @@ func (cn *conn) auth(r *readBuf, o values) { scOut = sc.Out() w = cn.writeBuf('p') w.bytes(scOut) - cn.send(w) + cn.mustSend(w) t, r = cn.recv() if t != 'R' { @@ -1219,9 +1235,9 @@ func (st *stmt) Close() (err error) { w := st.cn.writeBuf('C') w.byte('S') w.string(st.name) - st.cn.send(w) + st.cn.mustSend(w) - st.cn.send(st.cn.writeBuf('S')) + st.cn.mustSend(st.cn.writeBuf('S')) t, _ := st.cn.recv1() if t != '3' { @@ -1299,7 +1315,7 @@ func (st *stmt) exec(v []driver.Value) { w.int32(0) w.next('S') - cn.send(w) + cn.mustSend(w) cn.readBindResponse() cn.postExecuteWorkaround() @@ -1601,7 +1617,7 @@ func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { b.int32(0) b.next('S') - cn.send(b) + cn.mustSendRetryable(b) } func (cn *conn) processParameterStatus(r *readBuf) { diff --git a/error.go b/error.go index 96aae29c..a978460a 100644 --- a/error.go +++ b/error.go @@ -460,6 +460,18 @@ func errorf(s string, args ...interface{}) { panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) } +// NetErrorNoWrite is a network error that occured before a message that +// indicates the operation to execute was transfered to the server. +// These operations are safe to retry. This error should be replaced with +// driver.ErrBadConn before it's passed to the caller. +type netErrorNoWrite struct { + Err error +} + +func (e *netErrorNoWrite) Error() string { + return "netErrorNoWrite: " + e.Err.Error() +} + // TODO(ainar-g) Rename to errorf after removing panics. func fmterrorf(s string, args ...interface{}) error { return fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)) @@ -492,6 +504,9 @@ func (c *conn) errRecover(err *error) { } else { *err = v } + case *netErrorNoWrite: + c.bad = true + *err = driver.ErrBadConn case *net.OpError: c.bad = true *err = v