diff --git a/conn.go b/conn.go index 8ce19d1..18466cb 100644 --- a/conn.go +++ b/conn.go @@ -187,15 +187,17 @@ func (c *Conn) close(cause error) error { } func (c *Conn) readMessages(ctx context.Context) { - var err error - for err == nil { + for { var m anyMessage - err = c.stream.ReadObject(&m) + err := c.stream.ReadObject(&m) if err != nil { - break + c.close(err) + return } switch { + // TODO: handle the case where both request and response are nil. + case m.request != nil: for _, onRecv := range c.onRecv { onRecv(m.request, nil) @@ -204,43 +206,36 @@ func (c *Conn) readMessages(ctx context.Context) { case m.response != nil: resp := m.response - if resp != nil { - id := resp.ID - c.mu.Lock() - call := c.pending[id] - delete(c.pending, id) - c.mu.Unlock() + id := resp.ID + c.mu.Lock() + call := c.pending[id] + delete(c.pending, id) + c.mu.Unlock() + + var req *Request + if call != nil { + call.response = resp + req = call.request + } + + for _, onRecv := range c.onRecv { + onRecv(req, resp) + } + + if call == nil { + c.logger.Printf("jsonrpc2: ignoring response #%s with no corresponding request\n", id) + continue + } - if call != nil { - call.response = resp - } - - if len(c.onRecv) > 0 { - var req *Request - if call != nil { - req = call.request - } - for _, onRecv := range c.onRecv { - onRecv(req, resp) - } - } - - switch { - case call == nil: - c.logger.Printf("jsonrpc2: ignoring response #%s with no corresponding request\n", id) - - case resp.Error != nil: - call.done <- resp.Error - close(call.done) - - default: - call.done <- nil - close(call.done) - } + var err error + if resp.Error != nil { + err = resp.Error } + + call.done <- err + close(call.done) } } - c.close(err) } func (c *Conn) send(_ context.Context, m *anyMessage, wait bool) (cc *call, err error) { @@ -339,25 +334,20 @@ type Waiter struct { // error is returned. func (w Waiter) Wait(ctx context.Context, result interface{}) error { select { + case <-ctx.Done(): + return ctx.Err() + case err, ok := <-w.call.done: if !ok { - err = ErrClosed + return ErrClosed } - if err != nil { + if err != nil || result == nil { return err } - if result != nil { - if w.call.response.Result == nil { - w.call.response.Result = &jsonNull - } - if err := json.Unmarshal(*w.call.response.Result, result); err != nil { - return err - } + if w.call.response.Result == nil { + w.call.response.Result = &jsonNull } - return nil - - case <-ctx.Done(): - return ctx.Err() + return json.Unmarshal(*w.call.response.Result, result) } } @@ -423,12 +413,7 @@ func (m *anyMessage) UnmarshalJSON(data []byte) error { return errors.New("jsonrpc2: invalid empty batch") } for i := range msgs { - if err := checkType(&msg{ - ID: msgs[i].ID, - Method: msgs[i].Method, - Result: msgs[i].Result, - Error: msgs[i].Error, - }); err != nil { + if err := checkType(&msgs[i]); err != nil { return err } } diff --git a/conn_opt.go b/conn_opt.go index a83ccc3..8a29f80 100644 --- a/conn_opt.go +++ b/conn_opt.go @@ -44,11 +44,9 @@ func LogMessages(logger Logger) ConnOpt { OnRecv(func(req *Request, resp *Response) { switch { case resp != nil: - var method string + method := "(no matching request)" if req != nil { method = req.Method - } else { - method = "(no matching request)" } switch { case resp.Result != nil: diff --git a/handler_with_error.go b/handler_with_error.go index 2bd5c1d..d727237 100644 --- a/handler_with_error.go +++ b/handler_with_error.go @@ -30,20 +30,16 @@ func (h *HandlerWithErrorConfigurer) Handle(ctx context.Context, conn *Conn, req if err == nil { err = resp.SetResult(result) } - if err != nil { - if e, ok := err.(*Error); ok { - resp.Error = e - } else { - resp.Error = &Error{Message: err.Error()} - } + + if e, ok := err.(*Error); ok { + resp.Error = e + } else if err != nil { + resp.Error = &Error{Message: err.Error()} } - if !req.Notif { - if err := conn.SendResponse(ctx, resp); err != nil { - if err != ErrClosed || !h.suppressErrClosed { - conn.logger.Printf("jsonrpc2 handler: sending response %s: %v\n", resp.ID, err) - } - } + err = conn.SendResponse(ctx, resp) + if err != nil && (err != ErrClosed || !h.suppressErrClosed) { + conn.logger.Printf("jsonrpc2 handler: sending response %s: %v\n", resp.ID, err) } } diff --git a/request.go b/request.go index 372b3e7..b9cdde0 100644 --- a/request.go +++ b/request.go @@ -55,6 +55,10 @@ func (r Request) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. func (r *Request) UnmarshalJSON(data []byte) error { r2 := make(map[string]interface{}) + pop := func(key string) interface{} { + defer delete(r2, key) + return r2[key] + } // Detect if the "params" or "meta" fields are JSON "null" or just not // present by seeing if the field gets overwritten to nil. @@ -68,36 +72,37 @@ func (r *Request) UnmarshalJSON(data []byte) error { if err := decoder.Decode(&r2); err != nil { return err } + var ok bool - r.Method, ok = r2["method"].(string) + r.Method, ok = pop("method").(string) if !ok { return errors.New("missing method field") } - switch { - case r2["params"] == nil: + switch params := pop("params"); params { + case nil: r.Params = &jsonNull - case r2["params"] == emptyParams: + case emptyParams: r.Params = nil default: - b, err := json.Marshal(r2["params"]) + b, err := json.Marshal(params) if err != nil { return fmt.Errorf("failed to marshal params: %w", err) } r.Params = (*json.RawMessage)(&b) } - switch { - case r2["meta"] == nil: + switch meta := pop("meta"); meta { + case nil: r.Meta = &jsonNull - case r2["meta"] == emptyMeta: + case emptyMeta: r.Meta = nil default: - b, err := json.Marshal(r2["meta"]) + b, err := json.Marshal(meta) if err != nil { return fmt.Errorf("failed to marshal Meta: %w", err) } r.Meta = (*json.RawMessage)(&b) } - switch rawID := r2["id"].(type) { + switch rawID := pop("id").(type) { case nil: r.ID = ID{} r.Notif = true @@ -115,13 +120,12 @@ func (r *Request) UnmarshalJSON(data []byte) error { return fmt.Errorf("unexpected ID type: %T", rawID) } + // The jsonrpc field should not be added to ExtraFields. + delete(r2, "jsonrpc") + // Clear the extra fields before populating them again. r.ExtraFields = nil for name, value := range r2 { - switch name { - case "id", "jsonrpc", "meta", "method", "params": - continue - } r.ExtraFields = append(r.ExtraFields, RequestField{ Name: name, Value: value,