Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 41 additions & 56 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,17 @@ func (c *Conn) close(cause error) error {
}

func (c *Conn) readMessages(ctx context.Context) {
var err error
for err == nil {
for {
var m anyMessage
err = c.stream.ReadObject(&m)
err := c.stream.ReadObject(&m)
if err != nil {
break
c.close(err)
return
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice I looked at the old code and have no idea why we structured it that way. Thanks

}

switch {
// TODO: handle the case where both request and response are nil.

case m.request != nil:
for _, onRecv := range c.onRecv {
onRecv(m.request, nil)
Expand All @@ -204,43 +206,36 @@ func (c *Conn) readMessages(ctx context.Context) {

case m.response != nil:
resp := m.response
if resp != nil {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The nested if is unnecessary, since the switch case just proved that the response is non-nil.

id := resp.ID
c.mu.Lock()
call := c.pending[id]
delete(c.pending, id)
c.mu.Unlock()
id := resp.ID
c.mu.Lock()
call := c.pending[id]
delete(c.pending, id)
c.mu.Unlock()

var req *Request
if call != nil {
call.response = resp
req = call.request
}

for _, onRecv := range c.onRecv {
onRecv(req, resp)
}

if call == nil {
c.logger.Printf("jsonrpc2: ignoring response #%s with no corresponding request\n", id)
continue
}

if call != nil {
call.response = resp
}

if len(c.onRecv) > 0 {
var req *Request
if call != nil {
req = call.request
}
for _, onRecv := range c.onRecv {
onRecv(req, resp)
}
}

switch {
case call == nil:
c.logger.Printf("jsonrpc2: ignoring response #%s with no corresponding request\n", id)

case resp.Error != nil:
call.done <- resp.Error
close(call.done)

default:
call.done <- nil
close(call.done)
}
var err error
if resp.Error != nil {
err = resp.Error
}

call.done <- err
close(call.done)
}
}
c.close(err)
}

func (c *Conn) send(_ context.Context, m *anyMessage, wait bool) (cc *call, err error) {
Expand Down Expand Up @@ -339,25 +334,20 @@ type Waiter struct {
// error is returned.
func (w Waiter) Wait(ctx context.Context, result interface{}) error {
select {
case <-ctx.Done():
return ctx.Err()

case err, ok := <-w.call.done:
if !ok {
err = ErrClosed
return ErrClosed
}
if err != nil {
if err != nil || result == nil {
return err
}
if result != nil {
if w.call.response.Result == nil {
w.call.response.Result = &jsonNull
}
if err := json.Unmarshal(*w.call.response.Result, result); err != nil {
return err
}
if w.call.response.Result == nil {
w.call.response.Result = &jsonNull
}
return nil

case <-ctx.Done():
return ctx.Err()
return json.Unmarshal(*w.call.response.Result, result)
}
}

Expand Down Expand Up @@ -423,12 +413,7 @@ func (m *anyMessage) UnmarshalJSON(data []byte) error {
return errors.New("jsonrpc2: invalid empty batch")
}
for i := range msgs {
if err := checkType(&msg{
ID: msgs[i].ID,
Method: msgs[i].Method,
Result: msgs[i].Result,
Error: msgs[i].Error,
}); err != nil {
Comment on lines -426 to -431
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to construct a new msg literal, since the result is identical to what is already in msgs[i].

if err := checkType(&msgs[i]); err != nil {
return err
}
}
Expand Down
4 changes: 1 addition & 3 deletions conn_opt.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,9 @@ func LogMessages(logger Logger) ConnOpt {
OnRecv(func(req *Request, resp *Response) {
switch {
case resp != nil:
var method string
method := "(no matching request)"
if req != nil {
method = req.Method
} else {
method = "(no matching request)"
}
switch {
case resp.Result != nil:
Expand Down
20 changes: 8 additions & 12 deletions handler_with_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,16 @@ func (h *HandlerWithErrorConfigurer) Handle(ctx context.Context, conn *Conn, req
if err == nil {
err = resp.SetResult(result)
}
if err != nil {
if e, ok := err.(*Error); ok {
resp.Error = e
} else {
resp.Error = &Error{Message: err.Error()}
}

if e, ok := err.(*Error); ok {
resp.Error = e
} else if err != nil {
resp.Error = &Error{Message: err.Error()}
}

if !req.Notif {
if err := conn.SendResponse(ctx, resp); err != nil {
if err != ErrClosed || !h.suppressErrClosed {
conn.logger.Printf("jsonrpc2 handler: sending response %s: %v\n", resp.ID, err)
}
}
err = conn.SendResponse(ctx, resp)
if err != nil && (err != ErrClosed || !h.suppressErrClosed) {
conn.logger.Printf("jsonrpc2 handler: sending response %s: %v\n", resp.ID, err)
}
}

Expand Down
32 changes: 18 additions & 14 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ func (r Request) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.
func (r *Request) UnmarshalJSON(data []byte) error {
r2 := make(map[string]interface{})
pop := func(key string) interface{} {
defer delete(r2, key)
return r2[key]
}

// Detect if the "params" or "meta" fields are JSON "null" or just not
// present by seeing if the field gets overwritten to nil.
Expand All @@ -68,36 +72,37 @@ func (r *Request) UnmarshalJSON(data []byte) error {
if err := decoder.Decode(&r2); err != nil {
return err
}

var ok bool
r.Method, ok = r2["method"].(string)
r.Method, ok = pop("method").(string)
if !ok {
return errors.New("missing method field")
}
switch {
case r2["params"] == nil:
switch params := pop("params"); params {
case nil:
r.Params = &jsonNull
case r2["params"] == emptyParams:
case emptyParams:
r.Params = nil
default:
b, err := json.Marshal(r2["params"])
b, err := json.Marshal(params)
if err != nil {
return fmt.Errorf("failed to marshal params: %w", err)
}
r.Params = (*json.RawMessage)(&b)
}
switch {
case r2["meta"] == nil:
switch meta := pop("meta"); meta {
case nil:
r.Meta = &jsonNull
case r2["meta"] == emptyMeta:
case emptyMeta:
r.Meta = nil
default:
b, err := json.Marshal(r2["meta"])
b, err := json.Marshal(meta)
if err != nil {
return fmt.Errorf("failed to marshal Meta: %w", err)
}
r.Meta = (*json.RawMessage)(&b)
}
switch rawID := r2["id"].(type) {
switch rawID := pop("id").(type) {
case nil:
r.ID = ID{}
r.Notif = true
Expand All @@ -115,13 +120,12 @@ func (r *Request) UnmarshalJSON(data []byte) error {
return fmt.Errorf("unexpected ID type: %T", rawID)
}

// The jsonrpc field should not be added to ExtraFields.
delete(r2, "jsonrpc")

// Clear the extra fields before populating them again.
r.ExtraFields = nil
for name, value := range r2 {
switch name {
case "id", "jsonrpc", "meta", "method", "params":
continue
}
Comment on lines -121 to -124
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Through use of pop (and delete), none of these fields will be present.

r.ExtraFields = append(r.ExtraFields, RequestField{
Name: name,
Value: value,
Expand Down
Loading