Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
"os"
"reflect"
"strings"

"github.com/aws/aws-lambda-go/events"
)
Expand Down Expand Up @@ -46,13 +46,70 @@ func (ml *MockLambda) start(h interface{}) Response {
} else if inputCount == 2 && inputTypes[0] == "context.Context" && inputTypes[1] == "events.SQSEvent" {
response = ml.sqs(h.(func(ctx context.Context, request events.SQSEvent) error))
} else {
response.Payload.Error = "no handler found for method signature func(" + strings.Join(inputTypes, ", ") + ")"
handler := reflect.ValueOf(h)
handlerType := reflect.TypeOf(h)
if handlerType.Kind() != reflect.Func {
response.Payload.Error = fmt.Sprintf("handler kind %s is not %s", handlerType.Kind(), reflect.Func)
return response
}

takesContext, err := validateArguments(handlerType)
if err != nil {
response.Payload.Error = err.Error()
return response
}

if err := validateReturns(handlerType); err != nil {
response.Payload.Error = err.Error()
return response
}

response = func(ctx context.Context, payload string) Response {
var (
args []reflect.Value
result Response
)

if takesContext {
args = append(args, reflect.ValueOf(ctx))
}
if (handlerType.NumIn() == 1 && !takesContext) || handlerType.NumIn() == 2 {
eventType := handlerType.In(handlerType.NumIn() - 1)
event := reflect.New(eventType)

if err := decode(payload, event.Interface()); err != nil {
result.Payload.Error = err.Error()
return result
}

args = append(args, event.Elem())
}

callResult := handler.Call(args)

// convert return values into (interface{}, error)
if len(callResult) > 0 {
if errVal, ok := callResult[len(callResult)-1].Interface().(error); ok {
result.Payload.Error = errVal.Error()
}
}
if len(callResult) > 1 {
result.Payload.Success = callResult[0].Interface()
}

return result
}(context.TODO(), os.Getenv("LAMBDA_EVENT"))
}

return response
}

func Start(h interface{}) {
if h == nil {
fmt.Println("no handler found")
return
}

ml := MockLambda{api: api, token: token, request: request, sqs: sqs}
response := ml.start(h)

Expand Down
42 changes: 42 additions & 0 deletions validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package lambda

import (
"context"
"fmt"
"reflect"
)

func validateArguments(handler reflect.Type) (bool, error) {
handlerTakesContext := false
if handler.NumIn() > 2 {
return false, fmt.Errorf("handlers may not take more than two arguments, but handler takes %d", handler.NumIn())
} else if handler.NumIn() > 0 {
contextType := reflect.TypeOf((*context.Context)(nil)).Elem()
argumentType := handler.In(0)
handlerTakesContext = argumentType.Implements(contextType)
if handler.NumIn() > 1 && !handlerTakesContext {
return false, fmt.Errorf("handler takes two arguments, but the first is not Context. got %s", argumentType.Kind())
}
}

return handlerTakesContext, nil
}

func validateReturns(handler reflect.Type) error {
errorType := reflect.TypeOf((*error)(nil)).Elem()

switch n := handler.NumOut(); {
case n > 2:
return fmt.Errorf("handler may not return more than two values")
case n > 1:
if !handler.Out(1).Implements(errorType) {
return fmt.Errorf("handler returns two values, but the second does not implement error")
}
case n == 1:
if !handler.Out(0).Implements(errorType) {
return fmt.Errorf("handler returns a single value, but it does not implement error")
}
}

return nil
}