diff --git a/processor.go b/processor.go index f9c8087..58a6457 100644 --- a/processor.go +++ b/processor.go @@ -4,8 +4,8 @@ import ( "context" "encoding/json" "fmt" + "os" "reflect" - "strings" "github.com/aws/aws-lambda-go/events" ) @@ -46,13 +46,70 @@ func (ml *MockLambda) start(h interface{}) Response { } else if inputCount == 2 && inputTypes[0] == "context.Context" && inputTypes[1] == "events.SQSEvent" { response = ml.sqs(h.(func(ctx context.Context, request events.SQSEvent) error)) } else { - response.Payload.Error = "no handler found for method signature func(" + strings.Join(inputTypes, ", ") + ")" + handler := reflect.ValueOf(h) + handlerType := reflect.TypeOf(h) + if handlerType.Kind() != reflect.Func { + response.Payload.Error = fmt.Sprintf("handler kind %s is not %s", handlerType.Kind(), reflect.Func) + return response + } + + takesContext, err := validateArguments(handlerType) + if err != nil { + response.Payload.Error = err.Error() + return response + } + + if err := validateReturns(handlerType); err != nil { + response.Payload.Error = err.Error() + return response + } + + response = func(ctx context.Context, payload string) Response { + var ( + args []reflect.Value + result Response + ) + + if takesContext { + args = append(args, reflect.ValueOf(ctx)) + } + if (handlerType.NumIn() == 1 && !takesContext) || handlerType.NumIn() == 2 { + eventType := handlerType.In(handlerType.NumIn() - 1) + event := reflect.New(eventType) + + if err := decode(payload, event.Interface()); err != nil { + result.Payload.Error = err.Error() + return result + } + + args = append(args, event.Elem()) + } + + callResult := handler.Call(args) + + // convert return values into (interface{}, error) + if len(callResult) > 0 { + if errVal, ok := callResult[len(callResult)-1].Interface().(error); ok { + result.Payload.Error = errVal.Error() + } + } + if len(callResult) > 1 { + result.Payload.Success = callResult[0].Interface() + } + + return result + }(context.TODO(), os.Getenv("LAMBDA_EVENT")) } return response } func Start(h interface{}) { + if h == nil { + fmt.Println("no handler found") + return + } + ml := MockLambda{api: api, token: token, request: request, sqs: sqs} response := ml.start(h) diff --git a/validation.go b/validation.go new file mode 100644 index 0000000..6448fcf --- /dev/null +++ b/validation.go @@ -0,0 +1,42 @@ +package lambda + +import ( + "context" + "fmt" + "reflect" +) + +func validateArguments(handler reflect.Type) (bool, error) { + handlerTakesContext := false + if handler.NumIn() > 2 { + return false, fmt.Errorf("handlers may not take more than two arguments, but handler takes %d", handler.NumIn()) + } else if handler.NumIn() > 0 { + contextType := reflect.TypeOf((*context.Context)(nil)).Elem() + argumentType := handler.In(0) + handlerTakesContext = argumentType.Implements(contextType) + if handler.NumIn() > 1 && !handlerTakesContext { + return false, fmt.Errorf("handler takes two arguments, but the first is not Context. got %s", argumentType.Kind()) + } + } + + return handlerTakesContext, nil +} + +func validateReturns(handler reflect.Type) error { + errorType := reflect.TypeOf((*error)(nil)).Elem() + + switch n := handler.NumOut(); { + case n > 2: + return fmt.Errorf("handler may not return more than two values") + case n > 1: + if !handler.Out(1).Implements(errorType) { + return fmt.Errorf("handler returns two values, but the second does not implement error") + } + case n == 1: + if !handler.Out(0).Implements(errorType) { + return fmt.Errorf("handler returns a single value, but it does not implement error") + } + } + + return nil +}