diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 2b371ca..47620d4 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,14 +13,15 @@ jobs: name: lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v5 - - uses: actions/setup-go@v3 + - uses: actions/setup-go@v6 with: go-version-file: .go-version - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v9 with: - version: latest + version: v2.9.0 + verify: false args: --config .golangci.yml --timeout 5m --verbose diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 35702d5..a85ee6e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -54,13 +54,13 @@ jobs: needs: [create_release_tag] steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v5 with: fetch-depth: 0 ref: ${{ github.event.inputs.version }} - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v6 with: go-version-file: .go-version diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e2c2ad2..6f5b31b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,13 +19,14 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: [1.18, 1.19, 1.19.2] + # go: [1.19, 1.22, 1.23, 1.24, 1.25] + go: [1.23, 1.24, 1.25] name: Go ${{ matrix.go }} tests steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v5 - name: Setup go - uses: actions/setup-go@v3 + uses: actions/setup-go@v6 with: go-version: ${{ matrix.go }} diff --git a/.gitignore b/.gitignore index c926f86..2298607 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /internal/thing /dist .DS_Store -/coverage/ \ No newline at end of file +/coverage/ +.claude diff --git a/.go-version b/.go-version index 8068c6e..7f6db7f 100644 --- a/.go-version +++ b/.go-version @@ -1 +1 @@ -1.19 \ No newline at end of file +1.25.5 \ No newline at end of file diff --git a/.golangci.yml b/.golangci.yml index ba44d53..8d6ffdc 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,17 +1,15 @@ +version: "2" + linters: - disable-all: true + default: none enable: - errcheck - - gofmt - - goimports - gosec - - gosimple - govet - ineffassign - misspell - prealloc - staticcheck - - typecheck - unconvert - unused - asciicheck @@ -21,51 +19,55 @@ linters: - makezero - nonamedreturns - predeclared + settings: + depguard: + rules: + main: + deny: + - pkg: github.com/aws/aws-sdk-go/aws + desc: 'use v2 sdk instead' + gocritic: + disabled-checks: + - newDeref # it's wrong on generics + govet: + enable: + - shadow + exclusions: + rules: + - text: "Error return value of .((os\\.)?std(out|err)\\..*|.*Close|.*\\.Exit|.*Flush|os\\.Remove(All)?|.*printf?|fmt\\.Fprintln|os\\.(Un)?Setenv|io\\.WriteString|io\\.Copy). is not checked" + linters: + - errcheck + - text: 'declaration of "err"' + linters: + - govet + # it's all fake + - text: "G101" # G101: potential hard coded creds + linters: + - gosec + # yea, it's a test proxy. we're way beyond security + - text: "G402" # TLS InsecureSkipVerify set true + linters: + - gosec + # it's used for a serial number + - text: "G404" # G404: Use of weak random number generator (math/rand instead of crypto/rand) + linters: + - gosec + - text: "QF1001" + linters: + - staticcheck + - text: "QF1003" + linters: + - staticcheck + paths: + - scripts -issues: - exclude: - - "Error return value of .((os\\.)?std(out|err)\\..*|.*Close|.*\\.Exit|.*Flush|os\\.Remove(All)?|.*printf?|os\\.(Un)?Setenv|io\\.WriteString|io\\.Copy). is not checked" - - 'declaration of "err"' - exclude-rules: - - # it's all fake - - text: "G101" # G101: potential hard coded creds - linters: - - gosec - - # yea, it's a test proxy. we're way beyond security - - text: "G402" # TLS InsecureSkipVerify set true - linters: - - gosec - - # it's used for a serial number - - text: "G404" # G404: Use of weak random number generator (math/rand instead of crypto/rand) - linters: - - gosec - +formatters: + enable: + - gofmt + - goimports -# output configuration options output: - format: 'colored-line-number' - print-issued-lines: true - print-linter-name: true - -linters-settings: - depguard: - list-type: denylist - packages: - - github.com/aws/aws-sdk-go/aws - packages-with-error-message: - - github.com/aws/aws-sdk-go/aws: 'use v2 sdk instead' - - gocritic: - disabled-checks: - - newDeref # it's wrong on generics - - govet: - check-shadowing: true - # enable-all: true - -run: - skip-dirs: - - scripts \ No newline at end of file + formats: + text: + print-issued-lines: true + print-linter-name: true diff --git a/.revive.toml b/.revive.toml new file mode 100644 index 0000000..3420964 --- /dev/null +++ b/.revive.toml @@ -0,0 +1,29 @@ +ignoreGeneratedHeader = false +severity = "warning" +confidence = 0.8 +errorCode = 0 +warningCode = 0 + +[rule.blank-imports] +[rule.context-as-argument] +[rule.context-keys-type] +[rule.dot-imports] +[rule.empty-block] +[rule.error-naming] +[rule.error-return] +[rule.error-strings] +[rule.errorf] +# [rule.exported] +[rule.increment-decrement] +[rule.indent-error-flow] +[rule.package-comments] +[rule.range] +# [rule.receiver-naming] +[rule.redefines-builtin-id] +[rule.superfluous-else] +[rule.time-naming] +[rule.unexported-return] +[rule.unreachable-code] +[rule.unused-parameter] +[rule.var-declaration] +# [rule.var-naming] \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..6fa315e --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "go.testEnvVars": { + "AWSMOCKER_DEBUG": "1", + } +} \ No newline at end of file diff --git a/Makefile b/Makefile index 68c67ca..562d0b4 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,6 @@ generate: go generate ./... - .PHONY: tidy tidy: go mod verify @@ -15,7 +14,7 @@ tidy: .PHONY: lint-install lint-install: @echo "Installing golangci-lint" - go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.46.2 + go install github.com/golangci/golangci-lint/cmd/golangci-lint@v2.9.0 .PHONY: lint lint: @@ -25,6 +24,10 @@ lint: } @golangci-lint run && echo "All Good!" +.PHONY: outdated +outdated: + @go list -u -m -f '{{if not .Indirect}}{{if .Update}}{{.}}{{end}}{{end}}' all + .PHONY: test-release test-release: goreleaser release --skip-publish --rm-dist --snapshot diff --git a/README.md b/README.md index f0aecc1..031daef 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,11 @@ Easily create a proxy to allow easy testing of AWS API calls. **:warning: This is considered alpha quality right now. It might not work for all of AWS's APIs.** +> [!IMPORTANT] +> Version 1.0.0 has **BREAKING CHANGES**: +> * You must use the AWS config from `Config()` function returned by `Start()`. +> * The `Start` function has been modified to accept variable arguments for options setting + If you find problems, please create an Issue or make a PR. ## Installation @@ -16,30 +21,17 @@ go get -u github.com/webdestroya/awsmocker ``` ## Configuration -The default configuration when passing `nil` will setup a few mocks for STS. +The default configuration will setup a few mocks for STS. ```go -awsmocker.Start(t, nil) +m := awsmocker.Start(t) ``` For advanced usage and adding other mocks, you can use the following options: ```go -awsmocker.Start(t, &awsmocker.MockerOptions{ - // parameters -}) +m := awsmocker.Start(t, ...OPTIONS) ``` -| Option Key | Type | Description | -| ----------- | ---- | ------ | -| `Mocks` | `[]*MockedEndpoint` | A list of MockedEndpoints that will be matched against all incoming requests. | -| `Timeout` | `time.Duration` | If provided, then requests that run longer than this will be terminated. Generally you should not need to set this | -| `MockEc2Metadata` | `bool` | Set this to `true` and mocks for common EC2 IMDS endpoints will be added. These are not exhaustive, so if you have a special need you will have to add it. | -| `SkipDefaultMocks` | `bool` | Setting this to true will prevent mocks for STS being added. Note: any mocks you add will be evaluated before the default mocks, so this option is generally not necessary. | -| `ReturnAwsConfig` | `bool` | For many applications, the test suite will have the ability to pass a custom `aws.Config` value. If you have the ability to do this, you can bypass setting all the HTTP_PROXY environment variables. This makes your test cleaner. Setting this to true will add the `AwsConfig` value to the returned value of the Start call. | -| `DoNotProxy` | `string` | Optional list of hostname globs that will be added to the `NO_PROXY` environment variable. These hostnames will bypass the mocker. Use this if you are making actual HTTP requests elsewhere in your code that you want to allow through. | -| `DoNotFailUnhandledRequests` | `bool` | By default, if the mocker receives any request that does not have a matching mock, it will fail the test. This is usually desired as it prevents requests without error checking from allowing tests to pass. If you explicitly want a request to fail, you can define that. | -| `DoNotOverrideCreds` | `bool` | This will stop the test mocker from overriding the AWS environment variables with fake values. This means if you do not properly configure the mocker, you could end up making real requests to AWS. This is not recommended. | - ## Defining Mocks ```go @@ -47,6 +39,7 @@ awsmocker.Start(t, &awsmocker.MockerOptions{ Request: &awsmocker.MockedRequest{}, Response: &awsmocker.MockedResponse{}, } + ``` ### Mocking Requests @@ -92,9 +85,7 @@ awsmocker.Start(t, &awsmocker.MockerOptions{ ```go func TestSomethingThatCallsAws(t *testing.T) { - awsmocker.Start(t, &awsmocker.MockerOptions{ - // List out the mocks - Mocks: []*awsmocker.MockedEndpoint{ + m := awsmocker.Start(t, awsmocker.WithMocks( // Simple construction of a response awsmocker.NewSimpleMockedEndpoint("sts", "GetCallerIdentity", sts.GetCallerIdentityOutput{ Account: aws.String("123456789012"), @@ -103,7 +94,7 @@ func TestSomethingThatCallsAws(t *testing.T) { }), // advanced construction - { + &awsmocker.MockedEndpoint{ Request: &awsmocker.MockedRequest{ // specify the service/action to respond to Service: "ecs", @@ -111,8 +102,8 @@ func TestSomethingThatCallsAws(t *testing.T) { }, // provide the response to give Response: &awsmocker.MockedResponse{ - Body: map[string]interface{}{ - "services": []map[string]interface{}{ + Body: map[string]any{ + "services": []map[string]any{ { "serviceName": "someservice", }, @@ -120,12 +111,10 @@ func TestSomethingThatCallsAws(t *testing.T) { }, }, }, - }, - }) + ), + ) - cfg, _ := config.LoadDefaultConfig(context.TODO()) - - stsClient := sts.NewFromConfig(cfg) + stsClient := sts.NewFromConfig(m.Config()) resp, err := stsClient.GetCallerIdentity(context.TODO(), nil) if err != nil { @@ -154,7 +143,7 @@ func Mock_Events_PutRule_Generic() *awsmocker.MockedEndpoint { name, _ := jmespath.Search("Name", rr.JsonPayload) - return util.Must(util.Jsonify(map[string]interface{}{ + return util.Must(util.Jsonify(map[string]any{ "RuleArn": fmt.Sprintf("arn:aws:events:%s:%s:rule/%s", rr.Region, awsmocker.DefaultAccountId, name.(string)), })) }, @@ -177,5 +166,10 @@ To see the request/response traffic, you can use either of the following: * if you provide a response object, it will be encoded to JSON or XML based on the requesting content type. If you need a response in a special format, please provide the content type and a string for the body. * There is very little "error handling". If something goes wrong, it just panics. This might be less than ideal, but the only usecase for this library is within a test, which would make the test fail. This is the goal. +## Possible Issues + +**Receiving error: "not found, ResolveEndpointV2"**: +Upgrade aws modules: `go get -u github.com/aws/aws-sdk-go-v2/...` + ## See Also * Heavily influenced by [hashicorp's servicemocks](github.com/hashicorp/aws-sdk-go-base/v2/servicemocks) \ No newline at end of file diff --git a/awsconfig.go b/awsconfig.go index f223459..9179c39 100644 --- a/awsconfig.go +++ b/awsconfig.go @@ -1,39 +1,54 @@ package awsmocker import ( - "bytes" "context" "net/http" - "net/url" "time" "github.com/aws/aws-sdk-go-v2/aws" - awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" + // awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" ) // If your application is setup to where you can provide an aws.Config object for your clients, // then using the one provided by this method will make testing much easier. -func (m *mocker) buildAwsConfig() aws.Config { - - httpClient := awshttp.NewBuildableClient().WithTimeout(10 * time.Second).WithTransportOptions(func(t *http.Transport) { - proxyUrl, _ := url.Parse(m.httpServer.URL) - t.Proxy = http.ProxyURL(proxyUrl) - - // remove the need for CA bundle? - // t.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - }) - - cfg, err := config.LoadDefaultConfig(context.TODO(), - config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("XXfakekey", "XXfakesecret", "xxtoken")), - config.WithDefaultRegion(DefaultRegion), - config.WithHTTPClient(httpClient), - config.WithCustomCABundle(bytes.NewReader(caCert)), - config.WithRetryer(func() aws.Retryer { - return aws.NopRetryer{} - }), - ) +func (m *mocker) buildAwsConfig(opts ...AwsLoadOptionsFunc) aws.Config { + + // httpClient := awshttp.NewBuildableClient().WithTimeout(10 * time.Second).WithTransportOptions(func(t *http.Transport) { + // proxyUrl, _ := url.Parse(m.httpServer.URL) + // t.Proxy = http.ProxyURL(proxyUrl) + + // // remove the need for CA bundle? + // // t.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + + // }) + // _ = httpClient + + c := &http.Client{ + Transport: m, + Timeout: 2 * time.Second, + } + + options := make([]AwsLoadOptionsFunc, 0, 15) + + options = append(options, config.WithDisableRequestCompression(aws.Bool(true))) + options = append(options, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("XXfakekey", "XXfakesecret", "xxtoken"))) + options = append(options, config.WithDefaultRegion(DefaultRegion)) + // options = append(options, config.WithHTTPClient(httpClient)) + // options = append(options, config.WithHTTPClient(m)) + options = append(options, config.WithHTTPClient(c)) + // options = append(options, config.WithCustomCABundle(bytes.NewReader(caCert))) + options = append(options, config.WithRetryer(func() aws.Retryer { + return aws.NopRetryer{} + })) + + // apply the options the user wanted + options = append(options, opts...) + + options = append(options, addMiddlewareConfigOption(m)) + + cfg, err := config.LoadDefaultConfig(context.TODO(), options...) if err != nil { panic(err) } diff --git a/awsconfig_test.go b/awsconfig_test.go index 9890b79..03708f6 100644 --- a/awsconfig_test.go +++ b/awsconfig_test.go @@ -10,10 +10,8 @@ import ( ) func TestAwsConfigBuilder(t *testing.T) { - info := awsmocker.Start(t, &awsmocker.MockerOptions{ - ReturnAwsConfig: true, - }) - stsClient := sts.NewFromConfig(*info.AwsConfig) + info := awsmocker.Start(t) + stsClient := sts.NewFromConfig(info.Config()) resp, err := stsClient.GetCallerIdentity(context.TODO(), nil) require.NoError(t, err) diff --git a/certs.go b/certs.go index c7fbc55..2554ae5 100644 --- a/certs.go +++ b/certs.go @@ -4,7 +4,6 @@ import ( "crypto/tls" "crypto/x509" _ "embed" - "os" ) //go:embed cacert.pem @@ -37,9 +36,10 @@ func CACertPEM() []byte { // Returns the parsed X509 Certificate func CACert() *x509.Certificate { + // sync.OnceValue() return caKeyPair.Leaf } -func writeCABundle(filePath string) error { - return os.WriteFile(filePath, caCert, 0o600) -} +// func writeCABundle(filePath string) error { +// return os.WriteFile(filePath, caCert, 0o600) +// } diff --git a/certstore.go b/certstore.go index 41fcfb4..48fa8b9 100644 --- a/certstore.go +++ b/certstore.go @@ -14,13 +14,13 @@ import ( ) var ( - globalCertStore *CertStorage + globalCertStore *certStorage leafCertStart = time.Unix(time.Now().Unix()-2592000, 0) // 2592000 = 30 day leafCertEnd = time.Unix(time.Now().Unix()+31536000, 0) ) -type CertStorage struct { +type certStorage struct { certs sync.Map mu sync.Mutex @@ -29,7 +29,7 @@ type CertStorage struct { privateKey *rsa.PrivateKey } -func (tcs *CertStorage) Fetch(hostname string) *tls.Certificate { +func (tcs *certStorage) Fetch(hostname string) *tls.Certificate { icert, ok := tcs.certs.Load(hostname) if ok { @@ -40,7 +40,7 @@ func (tcs *CertStorage) Fetch(hostname string) *tls.Certificate { return tcs.generateCert(hostname) } -func (tcs *CertStorage) generateCert(hostname string) *tls.Certificate { +func (tcs *certStorage) generateCert(hostname string) *tls.Certificate { tcs.mu.Lock() defer tcs.mu.Unlock() @@ -92,7 +92,7 @@ func init() { privKey, _ := rsa.GenerateKey(rand.Reader, 2048) startSerial, _ := rand.Int(rand.Reader, big.NewInt(int64(math.Pow(2, 40)))) - globalCertStore = &CertStorage{ + globalCertStore = &certStorage{ certs: sync.Map{}, privateKey: privKey, nextSerial: startSerial.Int64(), diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..1c82865 --- /dev/null +++ b/doc.go @@ -0,0 +1,48 @@ +// Package awsmocker allows easier mocking of AWS API responses. +// +// # Example Usage +// +// The following is a complete example using awsmocker in an example test: +// +// import ( +// "testing" +// "context" +// +// "github.com/aws/aws-sdk-go-v2/aws" +// "github.com/aws/aws-sdk-go-v2/config" +// "github.com/aws/aws-sdk-go-v2/service/ecs" +// "github.com/webdestroya/awsmocker" +// ) +// +// +// func TestEcsDescribeServices(t *testing.T) { +// m := awsmocker.Start(t, awsmocker.WithMocks(&awsmocker.MockedEndpoint{ +// Request: &awsmocker.MockedRequest{ +// Service: "ecs", +// Action: "DescribeServices", +// }, +// Response: &awsmocker.MockedResponse{ +// Body: map[string]any{ +// "services": []map[string]any{ +// { +// "serviceName": "someservice", +// }, +// }, +// }, +// }, +// })) +// +// client := ecs.NewFromConfig(m.Config()) +// +// resp, err := client.DescribeServices(context.TODO(), &ecs.DescribeServicesInput{ +// Services: []string{"someservice"}, +// Cluster: aws.String("testcluster"), +// }) +// if err != nil { +// t.Errorf(err) +// } +// if *resp.Services[0].ServiceName != "someservice" { +// t.Errorf("Service name was wrong") +// } +// } +package awsmocker diff --git a/error.go b/error_response.go similarity index 100% rename from error.go rename to error_response.go diff --git a/error_test.go b/error_response_test.go similarity index 100% rename from error_test.go rename to error_response_test.go diff --git a/go.mod b/go.mod index 2114d49..0a39365 100644 --- a/go.mod +++ b/go.mod @@ -1,31 +1,34 @@ module github.com/webdestroya/awsmocker -go 1.19 +go 1.23.0 + +toolchain go1.25.5 require ( - github.com/aws/aws-sdk-go-v2 v1.17.1 - github.com/aws/aws-sdk-go-v2/config v1.17.10 - github.com/aws/aws-sdk-go-v2/credentials v1.12.23 - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.19 - github.com/aws/aws-sdk-go-v2/service/ecs v1.18.26 - github.com/aws/aws-sdk-go-v2/service/eventbridge v1.16.17 - github.com/aws/aws-sdk-go-v2/service/sts v1.17.1 + github.com/aws/aws-sdk-go-v2 v1.36.3 + github.com/aws/aws-sdk-go-v2/config v1.29.14 + github.com/aws/aws-sdk-go-v2/credentials v1.17.67 + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 + github.com/aws/aws-sdk-go-v2/service/ec2 v1.212.0 + github.com/aws/aws-sdk-go-v2/service/ecs v1.55.0 + github.com/aws/aws-sdk-go-v2/service/eventbridge v1.39.0 + github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 + github.com/aws/smithy-go v1.22.3 github.com/clbanning/mxj v1.8.4 - github.com/google/uuid v1.3.0 + github.com/google/uuid v1.6.0 github.com/jmespath/go-jmespath v0.4.0 - github.com/stretchr/testify v1.8.1 - golang.org/x/exp v0.0.0-20221106115401-f9659909a136 + github.com/stretchr/testify v1.10.0 ) require ( - github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.25 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.19 // indirect - github.com/aws/aws-sdk-go-v2/internal/ini v1.3.26 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.16 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.19 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.11.25 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.8 // indirect - github.com/aws/smithy-go v1.13.4 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index fa602a8..c33c991 100644 --- a/go.sum +++ b/go.sum @@ -1,42 +1,44 @@ -github.com/aws/aws-sdk-go-v2 v1.17.1 h1:02c72fDJr87N8RAC2s3Qu0YuvMRZKNZJ9F+lAehCazk= -github.com/aws/aws-sdk-go-v2 v1.17.1/go.mod h1:JLnGeGONAyi2lWXI1p0PCIOIy333JMVK1U7Hf0aRFLw= -github.com/aws/aws-sdk-go-v2/config v1.17.10 h1:zBy5QQ/mkvHElM1rygHPAzuH+sl8nsdSaxSWj0+rpdE= -github.com/aws/aws-sdk-go-v2/config v1.17.10/go.mod h1:/4np+UiJJKpWHN7Q+LZvqXYgyjgeXm5+lLfDI6TPZao= -github.com/aws/aws-sdk-go-v2/credentials v1.12.23 h1:LctvcJMIb8pxvk5hQhChpCu0WlU6oKQmcYb1HA4IZSA= -github.com/aws/aws-sdk-go-v2/credentials v1.12.23/go.mod h1:0awX9iRr/+UO7OwRQFpV1hNtXxOVuehpjVEzrIAYNcA= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.19 h1:E3PXZSI3F2bzyj6XxUXdTIfvp425HHhwKsFvmzBwHgs= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.19/go.mod h1:VihW95zQpeKQWVPGkwT+2+WJNQV8UXFfMTWdU6VErL8= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.25 h1:nBO/RFxeq/IS5G9Of+ZrgucRciie2qpLy++3UGZ+q2E= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.25/go.mod h1:Zb29PYkf42vVYQY6pvSyJCJcFHlPIiY+YKdPtwnvMkY= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.19 h1:oRHDrwCTVT8ZXi4sr9Ld+EXk7N/KGssOr2ygNeojEhw= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.19/go.mod h1:6Q0546uHDp421okhmmGfbxzq2hBqbXFNpi4k+Q1JnQA= -github.com/aws/aws-sdk-go-v2/internal/ini v1.3.26 h1:Mza+vlnZr+fPKFKRq/lKGVvM6B/8ZZmNdEopOwSQLms= -github.com/aws/aws-sdk-go-v2/internal/ini v1.3.26/go.mod h1:Y2OJ+P+MC1u1VKnavT+PshiEuGPyh/7DqxoDNij4/bg= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.16 h1:2EXB7dtGwRYIN3XQ9qwIW504DVbKIw3r89xQnonGdsQ= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.16/go.mod h1:XH+3h395e3WVdd6T2Z3mPxuI+x/HVtdqVOREkTiyubs= -github.com/aws/aws-sdk-go-v2/service/ecs v1.18.26 h1:EHJAYkUnlFJ/KwuFMvUs/bPbb0DaqAI+gTfXxffTPZ0= -github.com/aws/aws-sdk-go-v2/service/ecs v1.18.26/go.mod h1:NpR78BP2STxvF/R1GXLDM4gAEfjz68W/h0nC5b6Jk3s= -github.com/aws/aws-sdk-go-v2/service/eventbridge v1.16.17 h1:MSUSEjlL0+WOhFzYmDp7S2M09AzVC3bjLQke6+yc54g= -github.com/aws/aws-sdk-go-v2/service/eventbridge v1.16.17/go.mod h1:8g5GmQrg6Q44ap2NIxBb6eCZojS70QhJiv0qsgHVSKo= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.19 h1:GE25AWCdNUPh9AOJzI9KIJnja7IwUc1WyUqz/JTyJ/I= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.19/go.mod h1:02CP6iuYP+IVnBX5HULVdSAku/85eHB2Y9EsFhrkEwU= -github.com/aws/aws-sdk-go-v2/service/sso v1.11.25 h1:GFZitO48N/7EsFDt8fMa5iYdmWqkUDDB3Eje6z3kbG0= -github.com/aws/aws-sdk-go-v2/service/sso v1.11.25/go.mod h1:IARHuzTXmj1C0KS35vboR0FeJ89OkEy1M9mWbK2ifCI= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.8 h1:jcw6kKZrtNfBPJkaHrscDOZoe5gvi9wjudnxvozYFJo= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.8/go.mod h1:er2JHN+kBY6FcMfcBBKNGCT3CarImmdFzishsqBmSRI= -github.com/aws/aws-sdk-go-v2/service/sts v1.17.1 h1:KRAix/KHvjGODaHAMXnxRk9t0D+4IJVUuS/uwXxngXk= -github.com/aws/aws-sdk-go-v2/service/sts v1.17.1/go.mod h1:bXcN3koeVYiJcdDU89n3kCYILob7Y34AeLopUbZgLT4= -github.com/aws/smithy-go v1.13.4 h1:/RN2z1txIJWeXeOkzX+Hk/4Uuvv7dWtCjbmVJcrskyk= -github.com/aws/smithy-go v1.13.4/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= +github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= +github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= +github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= +github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34 h1:ZNTqv4nIdE/DiBfUUfXcLZ/Spcuz+RjeziUtNJackkM= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34/go.mod h1:zf7Vcd1ViW7cPqYWEHLHJkS50X0JS2IKz9Cgaj6ugrs= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.212.0 h1:z5thR/zKUlw7gd1OT59xBHm4AKBf2kPXKHFvVzLMfBk= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.212.0/go.mod h1:ouvGEfHbLaIlWwpDpOVWPWR+YwO0HDv3vm5tYLq8ImY= +github.com/aws/aws-sdk-go-v2/service/ecs v1.55.0 h1:7rmrcEBkAK22a8VYfxJ+LeBlHMiYYbnXSGRTEQ20OzE= +github.com/aws/aws-sdk-go-v2/service/ecs v1.55.0/go.mod h1:wAtdeFanDuF9Re/ge4DRDaYe3Wy1OGrU7jG042UcuI4= +github.com/aws/aws-sdk-go-v2/service/eventbridge v1.39.0 h1:XfMLLbZdz57JwIuETa789jOgqeEemR9gzam7x37HGS4= +github.com/aws/aws-sdk-go-v2/service/eventbridge v1.39.0/go.mod h1:QiEUHcyXhCdsTzHAbfmgwlFEmW3WgfqL4L1bS+E9IlA= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= +github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= +github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/clbanning/mxj v1.8.4 h1:HuhwZtbyvyOw+3Z1AowPkU87JkJUSv751ELWaiTpj8I= github.com/clbanning/mxj v1.8.4/go.mod h1:BVjHeAH+rl9rs6f+QIpeRl0tfu10SXn1pUSa5PVGJng= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= -github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= @@ -44,18 +46,11 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfC github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -golang.org/x/exp v0.0.0-20221106115401-f9659909a136 h1:Fq7F/w7MAa1KJ5bt2aJ62ihqp9HDcRuyILskkpIAurw= -golang.org/x/exp v0.0.0-20221106115401-f9659909a136/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/hacks.go b/hacks.go deleted file mode 100644 index a065bc7..0000000 --- a/hacks.go +++ /dev/null @@ -1,17 +0,0 @@ -package awsmocker - -import ( - _ "unsafe" -) - -// GO actually caches proxy env vars which totally breaks our test flow -// so this hacks in a call to Go's internal method... This is pretty janky - -//go:linkname resetProxyConfig net/http.resetProxyConfig -func resetProxyConfig() - -// Force call it just to make sure it works -// if Go updates this, this will make it very obvious -func init() { - resetProxyConfig() -} diff --git a/http.go b/http.go index 5595243..e7db4c0 100644 --- a/http.go +++ b/http.go @@ -19,6 +19,6 @@ func (m *mocker) handleHttp(w http.ResponseWriter, r *http.Request) { } } -var handleNonProxyRequest = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +var handleNonProxyRequest = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "AWSMocker is meant to be used as a proxy server. Don't send requests directly to it.", http.StatusNotImplemented) }) diff --git a/http_test.go b/http_test.go index e327c14..bc2ad40 100644 --- a/http_test.go +++ b/http_test.go @@ -2,8 +2,6 @@ package awsmocker_test import ( "net/http" - "net/url" - "os" "testing" "github.com/stretchr/testify/require" @@ -11,13 +9,11 @@ import ( ) func TestProxyHttp(t *testing.T) { - awsmocker.Start(t, &awsmocker.MockerOptions{ - DoNotFailUnhandledRequests: true, - }) + info := awsmocker.Start(t, awsmocker.WithoutFailingUnhandledRequests()) - transport := http.Transport{} - proxyUrl, _ := url.Parse(os.Getenv("HTTP_PROXY")) - transport.Proxy = http.ProxyURL(proxyUrl) // set proxy + transport := http.Transport{ + Proxy: info.Proxy(), + } // transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //set ssl client := &http.Client{ diff --git a/imds_mocks.go b/imds_mocks.go index 928d1bc..c90392a 100644 --- a/imds_mocks.go +++ b/imds_mocks.go @@ -23,6 +23,8 @@ type IMDSMockOptions struct { InstanceProfileName string } +type IMDSMockOptionFunc = func(*IMDSMockOptions) + func getDefaultImdsIdentityDocument() imds.InstanceIdentityDocument { return imds.InstanceIdentityDocument{ Version: "2017-09-30", @@ -36,7 +38,7 @@ func getDefaultImdsIdentityDocument() imds.InstanceIdentityDocument { // Provides an array of mocks that will provide a decent replication of the // EC2 Instance Metadata Service -func Mock_IMDS_Common(optFns ...func(*IMDSMockOptions)) []*MockedEndpoint { +func Mock_IMDS_Common(optFns ...IMDSMockOptionFunc) []*MockedEndpoint { cfg := IMDSMockOptions{ IdentityDocument: getDefaultImdsIdentityDocument(), @@ -143,7 +145,7 @@ func Mock_IMDS_IAM_Info(profileName string) *MockedEndpoint { }, Response: &MockedResponse{ Encoding: ResponseEncodingJSON, - Body: map[string]interface{}{ + Body: map[string]any{ "Code": "Success", "LastUpdated": time.Now().UTC().Format(time.RFC3339), "InstanceProfileArn": fmt.Sprintf("arn:aws:iam::%s:instance-profile/%s", DefaultAccountId, profileName), @@ -175,7 +177,7 @@ func Mock_IMDS_IAM_Credentials(roleName string) *MockedEndpoint { }, Response: &MockedResponse{ Encoding: ResponseEncodingJSON, - Body: map[string]interface{}{ + Body: map[string]any{ "Code": "Success", "Type": "AWS-HMAC", "LastUpdated": time.Now().UTC().Format(time.RFC3339), diff --git a/imds_test.go b/imds_test.go index ee5b46f..b4380f0 100644 --- a/imds_test.go +++ b/imds_test.go @@ -12,11 +12,9 @@ import ( ) func TestEc2IMDS(t *testing.T) { - awsmocker.Start(t, &awsmocker.MockerOptions{ - MockEc2Metadata: true, - }) + m := awsmocker.Start(t, awsmocker.WithEC2Metadata()) - client := imds.NewFromConfig(testutil.GetAwsConfig()) + client := imds.NewFromConfig(m.Config()) ctx := context.TODO() @@ -42,7 +40,9 @@ func TestEc2IMDS(t *testing.T) { }) t.Run("iam creds", func(t *testing.T) { - provider := ec2rolecreds.New() + provider := ec2rolecreds.New(func(o *ec2rolecreds.Options) { + o.Client = m.IMDSClient() + }) creds, err := provider.Retrieve(ctx) require.NoError(t, err) diff --git a/internal/certgen/main.go b/internal/certgen/main.go index 74ac7b4..02ffbf2 100644 --- a/internal/certgen/main.go +++ b/internal/certgen/main.go @@ -30,10 +30,10 @@ func main() { CommonName: "AWSMocker Root CA", Country: []string{"US"}, Organization: []string{"webdestroya"}, - OrganizationalUnit: []string{"aws-mocker"}, + OrganizationalUnit: []string{"awsmocker"}, }, NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), + NotAfter: time.Now().AddDate(50, 0, 0), IsCA: true, KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, BasicConstraintsValid: true, diff --git a/internal/testutil/util.go b/internal/testutil/util.go index b362765..843e93c 100644 --- a/internal/testutil/util.go +++ b/internal/testutil/util.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/credentials" ) +// Deprecated: DONT USE THIS func GetAwsConfig() aws.Config { cfg, err := config.LoadDefaultConfig(context.TODO(), // add creds just in case something happens diff --git a/jmespath_test.go b/jmespath_test.go index 4ce2c31..5b21806 100644 --- a/jmespath_test.go +++ b/jmespath_test.go @@ -10,14 +10,14 @@ import ( func TestJmesPathMatching(t *testing.T) { var j = []byte(`{"foo": {"bar": {"baz": [0, 1, 2, 3, 4, 4.5, true, false, null, "hello"]}}}`) - var d interface{} + var d any err := json.Unmarshal(j, &d) require.NoError(t, err) t.Run("test generic expressions", func(t *testing.T) { tables := []struct { expr string - val interface{} + val any }{ {"foo.bar.baz[99]", nil}, {"foo.bar.baz[2]", int(2)}, @@ -45,7 +45,7 @@ func TestJmesPathMatching(t *testing.T) { t.Run("bad expected values", func(t *testing.T) { tables := []struct { expr string - val interface{} + val any }{ {"some goofy expression", []int{2}}, diff --git a/jmesutil.go b/jmesutil.go index e71c3dd..417d180 100644 --- a/jmesutil.go +++ b/jmesutil.go @@ -27,7 +27,7 @@ func jmesValueNormalize(value any) any { // return v // } switch v := value.(type) { - case string, bool, nil, float64: + case string, bool, float64: return v case int: return float64(v) @@ -68,7 +68,7 @@ func JMESMatch(obj any, expression string, expected any) bool { resp, err := jmespath.Search(expression, obj) if err != nil { - panic(fmt.Errorf("Failed to parse expression: '%s': %w", expression, err)) + panic(fmt.Errorf("failed to parse expression: '%s': %w", expression, err)) } exp := jmesValueNormalize(expected) diff --git a/main.go b/main.go index 8144e25..7ece4ad 100644 --- a/main.go +++ b/main.go @@ -1,57 +1,21 @@ package awsmocker -import ( - "os" - "time" - - "github.com/aws/aws-sdk-go-v2/aws" -) - -// Returned when you start the server, provides you some information if needed -type MockerInfo struct { - // URL of the proxy server - ProxyURL string - - // Aws configuration to use - // This is only provided if you gave ReturnAwsConfig in the options - AwsConfig *aws.Config -} - -func Start(t TestingT, options *MockerOptions) *MockerInfo { +// Start the mocker +func Start(t TestingT, optFns ...MockerOptionFunc) MockerInfo { if h, ok := t.(tHelper); ok { h.Helper() } - if options == nil { - options = &MockerOptions{} - } - - if options.Timeout == 0 { - options.Timeout = 5 * time.Second - } + options := newOptions() - if !options.SkipDefaultMocks { - options.Mocks = append(options.Mocks, MockStsGetCallerIdentityValid) - } + for _, optFn := range optFns { - if options.MockEc2Metadata { - options.Mocks = append(options.Mocks, Mock_IMDS_Common()...) - } - - // proxy bypass configuration - if options.DoNotProxy != "" { - noProxyStr := os.Getenv("NO_PROXY") - if noProxyStr == "" { - noProxyStr = os.Getenv("no_proxy") - } - if noProxyStr != "" { - noProxyStr += "," + // makes transitioning somewhat easier + if optFn == nil { + continue } - noProxyStr += options.DoNotProxy - - t.Setenv("NO_PROXY", noProxyStr) - t.Setenv("no_proxy", noProxyStr) + optFn(options) } mocks := make([]*MockedEndpoint, 0, len(options.Mocks)) @@ -69,19 +33,13 @@ func Start(t TestingT, options *MockerOptions) *MockerInfo { debugTraffic: getDebugMode(), // options.DebugTraffic, doNotOverrideCreds: options.DoNotOverrideCreds, doNotFailUnhandled: options.DoNotFailUnhandledRequests, + noMiddleware: options.noMiddleware, mocks: mocks, - usingAwsConfig: options.ReturnAwsConfig, + // usingAwsConfig: true, } server.Start() - info := &MockerInfo{ - ProxyURL: server.httpServer.URL, - } - - if options.ReturnAwsConfig { - cfg := server.buildAwsConfig() - info.AwsConfig = &cfg - } + server.awsConfig = server.buildAwsConfig(options.AwsConfigOptions...) - return info + return server } diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..42efdc9 --- /dev/null +++ b/middleware.go @@ -0,0 +1,67 @@ +package awsmocker + +import ( + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +const ( + mwHeaderService = `X-Awsmocker-Service` + mwHeaderOperation = `X-Awsmocker-Operation` + mwHeaderParamType = `X-Awsmocker-Param-Type` + mwHeaderRequestId = `X-Awsmocker-Request-Id` + mwHeaderError = `X-Awsmocker-Error` + mwHeaderUseDB = `X-Awsmocker-Use-Db` +) + +type ( + mwCtxKeyReqId struct{} + mwCtxKeyParams struct{} + mwCtxKeyUseDB struct{} //nolint:unused +) + +type ( + mwRequest = smithyhttp.Request //nolint:unused + mwResponse = smithyhttp.Response +) + +type mockerMiddleware struct { + mocker *mocker +} + +func (mockerMiddleware) ID() string { + return "awsmocker" +} + +func addMiddlewareConfigOption(m *mocker) AwsLoadOptionsFunc { + + mockerMW := &mockerMiddleware{ + mocker: m, + } + + return config.WithAPIOptions([]func(*middleware.Stack) error{ + func(stack *middleware.Stack) error { + if err := stack.Initialize.Add(mockerMW, middleware.After); err != nil { + return err + } + + if err := stack.Serialize.Add(mockerMW, middleware.After); err != nil { + return err + } + + if err := stack.Build.Add(mockerMW, middleware.After); err != nil { + return err + } + + if err := stack.Deserialize.Add(mockerMW, middleware.Before); err != nil { + return err + } + + if err := stack.Finalize.Add(mockerMW, middleware.After); err != nil { + return err + } + return nil + }, + }) +} diff --git a/middleware_build.go b/middleware_build.go new file mode 100644 index 0000000..72416f3 --- /dev/null +++ b/middleware_build.go @@ -0,0 +1,13 @@ +package awsmocker + +import ( + "context" + + "github.com/aws/smithy-go/middleware" +) + +var _ middleware.BuildMiddleware = (*mockerMiddleware)(nil) + +func (mockerMiddleware) HandleBuild(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) (middleware.BuildOutput, middleware.Metadata, error) { + return next.HandleBuild(ctx, in) +} diff --git a/middleware_dbentry.go b/middleware_dbentry.go new file mode 100644 index 0000000..6efdf3f --- /dev/null +++ b/middleware_dbentry.go @@ -0,0 +1,7 @@ +package awsmocker + +type mwDBEntry struct { + Parameters any + Response any + Error error +} diff --git a/middleware_deserialize.go b/middleware_deserialize.go new file mode 100644 index 0000000..0a98142 --- /dev/null +++ b/middleware_deserialize.go @@ -0,0 +1,34 @@ +package awsmocker + +import ( + "context" + "errors" + + "github.com/aws/smithy-go/middleware" +) + +var _ middleware.DeserializeMiddleware = (*mockerMiddleware)(nil) + +func (m *mockerMiddleware) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) (middleware.DeserializeOutput, middleware.Metadata, error) { + out, meta, err := next.HandleDeserialize(ctx, in) + + if resp, ok := out.RawResponse.(*mwResponse); ok { + if resp.Header.Get(mwHeaderUseDB) == "true" { + if reqId, ok := middleware.GetStackValue(ctx, mwCtxKeyReqId{}).(uint64); ok { + + res, _ := m.mocker.requestLog.LoadAndDelete(reqId) + entry := res.(mwDBEntry) + + return middleware.DeserializeOutput{ + RawResponse: resp, + Result: entry.Response, + }, meta, entry.Error + + } else { + return out, meta, errors.New("invalid mocker result?") + } + } + } + + return out, meta, err +} diff --git a/middleware_finalize.go b/middleware_finalize.go new file mode 100644 index 0000000..4255b3c --- /dev/null +++ b/middleware_finalize.go @@ -0,0 +1,33 @@ +package awsmocker + +import ( + "context" + "fmt" + "strconv" + "strings" + + "github.com/aws/smithy-go/middleware" + "github.com/aws/smithy-go/transport/http" +) + +var _ middleware.FinalizeMiddleware = (*mockerMiddleware)(nil) + +func (mockerMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (middleware.FinalizeOutput, middleware.Metadata, error) { + + if req, ok := in.Request.(*http.Request); ok { + req.Header.Add(mwHeaderService, strings.ToLower(middleware.GetServiceID(ctx))) + req.Header.Add(mwHeaderOperation, middleware.GetOperationName(ctx)) + + if params := middleware.GetStackValue(ctx, mwCtxKeyParams{}); ok { + req.Header.Add(mwHeaderParamType, fmt.Sprintf("%T", params)) + } + + if reqId, ok := middleware.GetStackValue(ctx, mwCtxKeyReqId{}).(uint64); ok { + req.Header.Add(mwHeaderRequestId, strconv.FormatUint(reqId, 10)) + } + + in.Request = req + } + + return next.HandleFinalize(ctx, in) +} diff --git a/middleware_initialize.go b/middleware_initialize.go new file mode 100644 index 0000000..3b0b5d5 --- /dev/null +++ b/middleware_initialize.go @@ -0,0 +1,39 @@ +package awsmocker + +import ( + "context" + + "github.com/aws/smithy-go/middleware" +) + +var _ middleware.InitializeMiddleware = (*mockerMiddleware)(nil) + +func (m *mockerMiddleware) HandleInitialize(ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler) (middleware.InitializeOutput, middleware.Metadata, error) { + + // if _, ok := in.Parameters.(*ec2.DescribeSubnetsInput); ok { + // out := middleware.InitializeOutput{ + // Result: &ec2.DescribeSubnetsOutput{ + // Subnets: []ec2Types.Subnet{ + // { + // SubnetId: aws.String("subnet-aaaaaaaa"), + // VpcId: aws.String("vpc-11111111"), + // }, + // }, + // }, + // } + // return out, middleware.Metadata{}, nil + // } + + reqId := m.mocker.mwReqCounter.Add(1) + + ctx = middleware.WithStackValue(ctx, mwCtxKeyReqId{}, reqId) + ctx = middleware.WithStackValue(ctx, mwCtxKeyParams{}, in.Parameters) + + m.mocker.requestLog.Store(reqId, mwDBEntry{ + Parameters: in.Parameters, + }) + + defer m.mocker.requestLog.Delete(reqId) + + return next.HandleInitialize(ctx, in) +} diff --git a/middleware_serialize.go b/middleware_serialize.go new file mode 100644 index 0000000..6a05ad6 --- /dev/null +++ b/middleware_serialize.go @@ -0,0 +1,13 @@ +package awsmocker + +import ( + "context" + + "github.com/aws/smithy-go/middleware" +) + +var _ middleware.SerializeMiddleware = (*mockerMiddleware)(nil) + +func (mockerMiddleware) HandleSerialize(ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler) (middleware.SerializeOutput, middleware.Metadata, error) { + return next.HandleSerialize(ctx, in) +} diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 0000000..dfa50ee --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,90 @@ +package awsmocker + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2Types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/aws-sdk-go-v2/service/ecs" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/stretchr/testify/require" +) + +func TestSerializeMiddleware(t *testing.T) { + + m := Start(t, + WithMocks(&MockedEndpoint{ + Request: &MockedRequest{ + Service: "ecs", + Action: "DescribeServices", + }, + Response: &MockedResponse{ + Body: map[string]any{ + "services": []map[string]any{ + { + "serviceName": "someservice", + }, + }, + }, + }, + })) + + client := ecs.NewFromConfig(m.Config()) + _, _ = sts.NewFromConfig(m.Config()).GetCallerIdentity(context.TODO(), nil) + + resp, err := client.DescribeServices(context.TODO(), &ecs.DescribeServicesInput{ + Services: []string{"someservice"}, + Cluster: aws.String("testcluster"), + }) + require.NoError(t, err) + require.Equalf(t, "someservice", *resp.Services[0].ServiceName, "Service name was wrong") +} + +func TestEC2Middleware(t *testing.T) { + + m := Start(t, + WithMocks(&MockedEndpoint{ + Request: &MockedRequest{ + Service: "ec2", + Action: "DescribeSubnets", + }, + // Response: &MockedResponse{ + // DoNotWrap: true, + // Body: map[string]any{ + // "requestId": "43e9cb52-0e10-40fe-b457-988c8fbfea26", + // "subnetSet": map[string]any{ + // "item": []any{ + // map[string]any{ + // "subnetId": "subnet-633333333333", + // "vpcId": "vpc-123456789", + // }, + // map[string]any{ + // "subnetId": "subnet-644444444444", + // "vpcId": "vpc-123456789", + // }, + // }, + // }, + // }, + // }, + Response: &MockedResponse{ + Body: &ec2.DescribeSubnetsOutput{ + Subnets: []ec2Types.Subnet{ + { + VpcId: aws.String("vpc-123456789"), + }, + }, + }, + }, + })) + + client := ec2.NewFromConfig(m.Config()) + + resp, err := client.DescribeSubnets(context.TODO(), &ec2.DescribeSubnetsInput{}) + require.NoError(t, err) + require.NotNil(t, resp) + + require.GreaterOrEqual(t, len(resp.Subnets), 1) + require.Equal(t, "vpc-123456789", *resp.Subnets[0].VpcId) +} diff --git a/mock.go b/mock.go new file mode 100644 index 0000000..0b528fc --- /dev/null +++ b/mock.go @@ -0,0 +1 @@ +package awsmocker diff --git a/mocked_endpoint.go b/mocked_endpoint.go index 786b257..eace56f 100644 --- a/mocked_endpoint.go +++ b/mocked_endpoint.go @@ -21,6 +21,7 @@ func (m *MockedEndpoint) getResponse(rr *ReceivedRequest) *httpResponse { return m.Response.getResponse(rr) } +// Generates a simple [MockedEndpoint] for the Service:Action func NewSimpleMockedEndpoint(service, action string, responseObj any) *MockedEndpoint { return &MockedEndpoint{ Request: &MockedRequest{ diff --git a/mocked_endpoint_test.go b/mocked_endpoint_test.go new file mode 100644 index 0000000..2a04853 --- /dev/null +++ b/mocked_endpoint_test.go @@ -0,0 +1,19 @@ +package awsmocker_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/webdestroya/awsmocker" +) + +func TestNewSimpleMockedEndpoint(t *testing.T) { + me := awsmocker.NewSimpleMockedEndpoint( + "sts", + "GetCallerIdentity", + "SimpleBody", + ) + require.Equal(t, "sts", me.Request.Service) + require.Equal(t, "GetCallerIdentity", me.Request.Action) + require.Equal(t, "SimpleBody", me.Response.Body) +} diff --git a/mocked_request.go b/mocked_request.go index b6f2d20..9ba364f 100644 --- a/mocked_request.go +++ b/mocked_request.go @@ -7,8 +7,8 @@ import ( "strings" "sync" - "golang.org/x/exp/maps" - "golang.org/x/exp/slices" + "maps" + "slices" ) // Describes a request that should be matched @@ -173,7 +173,7 @@ func (m *MockedRequest) matchRequestLazy(rr *ReceivedRequest) bool { return false } - if m.JMESPathMatches != nil && len(m.JMESPathMatches) > 0 { + if len(m.JMESPathMatches) > 0 { if ret := m.matchJmespath(rr); !ret { return false } @@ -183,9 +183,9 @@ func (m *MockedRequest) matchRequestLazy(rr *ReceivedRequest) bool { return false } - if m.Params != nil && len(m.Params) > 0 { + if len(m.Params) > 0 { // if the request has no params, it cant match something with params... - if rr.HttpRequest.Form == nil || len(rr.HttpRequest.Form) == 0 { + if len(rr.HttpRequest.Form) == 0 { return false } @@ -206,7 +206,7 @@ func (m *MockedRequest) matchRequestLazy(rr *ReceivedRequest) bool { func (m *MockedRequest) matchJmespath(rr *ReceivedRequest) bool { // just bail out if there is nothing to match - if m.JMESPathMatches == nil || len(m.JMESPathMatches) == 0 { + if len(m.JMESPathMatches) == 0 { return true } diff --git a/mocked_response.go b/mocked_response.go index cddc8ae..2a1b7e1 100644 --- a/mocked_response.go +++ b/mocked_response.go @@ -4,8 +4,10 @@ import ( "encoding/xml" "net/http" "reflect" + "strconv" "strings" + "github.com/aws/smithy-go/document" "github.com/clbanning/mxj" ) @@ -17,8 +19,13 @@ const ( var ( byteArrayType = reflect.SliceOf(reflect.TypeOf((*byte)(nil)).Elem()) + rrType = reflect.TypeFor[*ReceivedRequest]() + errType = reflect.TypeFor[error]() ) +// type used to generate +type directTypeFunc = func(*ReceivedRequest) (any, error) + type MockedResponse struct { // modify the status code. default is 200 StatusCode int @@ -34,6 +41,9 @@ type MockedResponse struct { // func(*ReceivedRequest) (string) = string payload (with 200 OK, inferred content type) // func(*ReceivedRequest) (string, int) = string payload, status code (with inferred content type) // func(*ReceivedRequest) (string, int, string) = string payload, status code, content type + // func(*ReceivedRequest) (*service.ACTIONOutput, error) = return the result type directly, or error + // func(*ReceivedRequest, *service.ACTIONInput) (*service.ACTIONOutput, error) = return the result type directly, or error + // func(*service.ACTIONInput) (*service.ACTIONOutput, error) = return the result type directly, or error Body any // Do not wrap the xml response in ACTIONResponse>ACTIONResult @@ -78,6 +88,10 @@ func (m *MockedResponse) getResponse(rr *ReceivedRequest) *httpResponse { } } + if dir := m.processDirectRequest(rr); dir != nil { + return dir + } + actionName := m.action if actionName == "" { actionName = rr.Action @@ -179,26 +193,27 @@ func (m *MockedResponse) getResponse(rr *ReceivedRequest) *httpResponse { StatusCode: m.StatusCode, contentType: ContentTypeXML, } - } else { - resultName := "" + actionName + "Result" - wrappedMap := map[string]interface{}{ - resultName: m.Body, - "ResponseMetadata": map[string]string{ - "RequestId": "01234567-89ab-cdef-0123-456789abcdef", - }, - } + } - xmlout, err := mxj.AnyXmlIndent(wrappedMap, "", " ", ""+actionName+"Response") - if err != nil { - return generateErrorStruct(0, "BadMockBody", "Could not serialize body to XML: %s", err).getResponse(rr) - } + resultName := "" + actionName + "Result" + wrappedMap := map[string]any{ + resultName: m.Body, + "ResponseMetadata": map[string]string{ + "RequestId": "01234567-89ab-cdef-0123-456789abcdef", + }, + } - return &httpResponse{ - bodyRaw: xmlout, - StatusCode: m.StatusCode, - contentType: ContentTypeXML, - } + xmlout, err := mxj.AnyXmlIndent(wrappedMap, "", " ", ""+actionName+"Response") + if err != nil { + return generateErrorStruct(0, "BadMockBody", "Could not serialize body to XML: %s", err).getResponse(rr) + } + + return &httpResponse{ + bodyRaw: xmlout, + StatusCode: m.StatusCode, + contentType: ContentTypeXML, } + case bodyKind == reflect.Slice && rBody.Type() == byteArrayType: cType := m.ContentType @@ -215,3 +230,137 @@ func (m *MockedResponse) getResponse(rr *ReceivedRequest) *httpResponse { } return generateErrorStruct(0, "BadMockResponse", "Don't know how to encode a kind=%v using content type=%s", bodyKind, m.ContentType).getResponse(rr) } + +func (m *MockedResponse) processDirectRequest(rr *ReceivedRequest) *httpResponse { + + if m.Body == nil { + return nil + } + + body := m.Body + var err error + + if rr.HttpRequest == nil || len(rr.HttpRequest.Header) == 0 { + return nil + } + reqId, perr := strconv.ParseUint(rr.HttpRequest.Header.Get(mwHeaderRequestId), 10, 64) + if perr != nil { + return generateErrorStruct(0, "BadMockBody", "Failed to get direct mocker: %s", perr.Error()).getResponse(rr) + + } + + mkr := rr.mocker + if mkr == nil { + return generateErrorStruct(0, "BadMockBody", "Failed to get direct mocker").getResponse(rr) + } + + entry, ok := mkr.requestLog.Load(reqId) + if !ok { + return generateErrorStruct(0, "BadMockBody", "Failed to find mock in DB??").getResponse(rr) + } + + rec := entry.(mwDBEntry) + + if !document.IsNoSerde(body) { + + // check if fancy func + if fn, ok := body.(directTypeFunc); ok { + body, err = fn(rr) + } else { + + body, err = processDirectRequestFunc(rec, rr, reflect.Indirect(reflect.ValueOf(body))) + + } + + if body != nil && !document.IsNoSerde(body) { + return nil + } + } + + if body == nil && err == nil { + return nil + } + + if reflect.TypeOf(body).Kind() == reflect.Struct { + val := reflect.ValueOf(body) + vp := reflect.New(val.Type()) + vp.Elem().Set(val) + body = vp.Interface() + } + + rec.Error = err + rec.Response = body + + mkr.requestLog.Store(reqId, rec) + + return &httpResponse{ + StatusCode: http.StatusOK, + Body: "", + contentType: ContentTypeJSON, + extraHeaders: map[string]string{ + mwHeaderUseDB: "true", + }, + } +} + +func processDirectRequestFunc(entry mwDBEntry, rr *ReceivedRequest, fnv reflect.Value) (any, error) { + typ := fnv.Type() + + if typ.Kind() != reflect.Func { + return nil, nil + } + + params := entry.Parameters + paramT := reflect.TypeOf(params) + + inputs := make([]reflect.Value, 0, 2) + + if typ.NumIn() == 1 { + + in1 := typ.In(0) + if in1 == rrType { + inputs = append(inputs, reflect.ValueOf(rr)) + } else if in1 == paramT { + inputs = append(inputs, reflect.ValueOf(params)) + } else { + return nil, nil + } + + } else if typ.NumIn() == 2 { + if in1 := typ.In(0); in1 == rrType { + inputs = append(inputs, reflect.ValueOf(rr)) + } else { + return nil, nil + } + + if in2 := typ.In(1); in2 == paramT { + inputs = append(inputs, reflect.ValueOf(params)) + } else { + return nil, nil + } + } else { + // invalid signature + return nil, nil + } + + if typ.NumOut() != 2 { + return nil, nil + } + + if out2 := typ.Out(1); out2 != errType { + // 2nd return must be error + return nil, nil + } + + outputs := fnv.Call(inputs) + + ret := outputs[0].Interface() + + if typ.NumOut() == 2 { + if err := outputs[1].Interface(); err != nil { + return ret, err.(error) + } + } + + return ret, nil +} diff --git a/mocked_response_test.go b/mocked_response_test.go index 0f56c43..ebb58b1 100644 --- a/mocked_response_test.go +++ b/mocked_response_test.go @@ -1,10 +1,13 @@ package awsmocker import ( + "maps" + "reflect" "testing" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/stretchr/testify/require" - "golang.org/x/exp/maps" ) func TestMockedResponse_getResponse(t *testing.T) { @@ -51,3 +54,53 @@ func TestMockedResponse_getResponse(t *testing.T) { }) } } + +func TestProcessDirectRequestFunc(t *testing.T) { + + entry := mwDBEntry{ + Parameters: &sts.GetCallerIdentityInput{}, + } + rr := &ReceivedRequest{} + + stsResponse := &sts.GetCallerIdentityOutput{ + Account: aws.String(DefaultAccountId), + Arn: aws.String("arn"), + UserId: aws.String("userid"), + } + + tables := []struct { + fn any + }{ + { + fn: func(_ *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { + return stsResponse, nil + }, + }, + + { + fn: func(_ *ReceivedRequest) (*sts.GetCallerIdentityOutput, error) { + return stsResponse, nil + }, + }, + + { + fn: func(_ *ReceivedRequest, _ *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { + return stsResponse, nil + }, + }, + + { + fn: func() *sts.GetCallerIdentityOutput { + return stsResponse + }, + }, + } + + for _, table := range tables { + t.Run("entry", func(t *testing.T) { + result, err := processDirectRequestFunc(entry, rr, reflect.ValueOf(table.fn)) + _ = result + _ = err + }) + } +} diff --git a/mocker.go b/mocker.go index 79d8596..ff8b133 100644 --- a/mocker.go +++ b/mocker.go @@ -5,29 +5,31 @@ import ( "fmt" "net/http" "net/http/httptest" - "os" - "path" + "sync" + "sync/atomic" "time" -) -const ( - envAwsCaBundle = "AWS_CA_BUNDLE" - envAwsAccessKey = "AWS_ACCESS_KEY_ID" - envAwsSecretKey = "AWS_SECRET_ACCESS_KEY" - envAwsSessionToken = "AWS_SESSION_TOKEN" - envAwsEc2MetaDisable = "AWS_EC2_METADATA_DISABLED" - envAwsContCredUri = "AWS_CONTAINER_CREDENTIALS_FULL_URI" - envAwsContCredRelUri = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" - envAwsContAuthToken = "AWS_CONTAINER_AUTHORIZATION_TOKEN" - envAwsConfigFile = "AWS_CONFIG_FILE" - envAwsSharedCredFile = "AWS_SHARED_CREDENTIALS_FILE" - envAwsWebIdentTFile = "AWS_WEB_IDENTITY_TOKEN_FILE" - envAwsDefaultRegion = "AWS_DEFAULT_REGION" - - // AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE - // AWS_EC2_METADATA_SERVICE_ENDPOINT + "github.com/aws/aws-sdk-go-v2/aws" ) +// const ( +// envAwsCaBundle = "AWS_CA_BUNDLE" +// envAwsAccessKey = "AWS_ACCESS_KEY_ID" +// envAwsSecretKey = "AWS_SECRET_ACCESS_KEY" +// envAwsSessionToken = "AWS_SESSION_TOKEN" +// envAwsEc2MetaDisable = "AWS_EC2_METADATA_DISABLED" +// envAwsContCredUri = "AWS_CONTAINER_CREDENTIALS_FULL_URI" +// envAwsContCredRelUri = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" +// envAwsContAuthToken = "AWS_CONTAINER_AUTHORIZATION_TOKEN" +// envAwsConfigFile = "AWS_CONFIG_FILE" +// envAwsSharedCredFile = "AWS_SHARED_CREDENTIALS_FILE" +// envAwsWebIdentTFile = "AWS_WEB_IDENTITY_TOKEN_FILE" +// envAwsDefaultRegion = "AWS_DEFAULT_REGION" + +// // AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE +// // AWS_EC2_METADATA_SERVICE_ENDPOINT +// ) + type mocker struct { t TestingT timeout time.Duration @@ -36,58 +38,66 @@ type mocker struct { verbose bool debugTraffic bool - usingAwsConfig bool + // usingAwsConfig bool doNotOverrideCreds bool doNotFailUnhandled bool - originalEnv map[string]*string + // originalEnv map[string]*string mocks []*MockedEndpoint -} -func (m *mocker) init() { - m.originalEnv = make(map[string]*string, 10) -} + // counter used by the middleware to track requests + mwReqCounter *atomic.Uint64 + requestLog *sync.Map -// Overrides an environment variable and then adds it to the stack to undo later -func (m *mocker) setEnv(k string, v any) { - val, ok := os.LookupEnv(k) - if ok { - m.originalEnv[k] = &val - } else { - m.originalEnv[k] = nil - } + awsConfig aws.Config - switch nval := v.(type) { - case string: - err := os.Setenv(k, nval) - if err != nil { - m.t.Errorf("Unable to set env var '%s': %s", k, err) - } - case nil: - err := os.Unsetenv(k) - if err != nil { - m.t.Errorf("Unable to unset env var '%s': %s", k, err) - } - default: - panic("WRONG ENV VAR VALUE TYPE: must be nil or a string") - } + noMiddleware bool } -func (m *mocker) revertEnv() { - for k, v := range m.originalEnv { - if v == nil { - _ = os.Unsetenv(k) - } else { - _ = os.Setenv(k, *v) - } - } +func (m *mocker) init() { + // m.originalEnv = make(map[string]*string, 10) + m.requestLog = &sync.Map{} + m.mwReqCounter = &atomic.Uint64{} } -func (m *mocker) Start() { - // reset Go's proxy cache - resetProxyConfig() +// Overrides an environment variable and then adds it to the stack to undo later +// func (m *mocker) setEnv(k string, v any) { +// return +// val, ok := os.LookupEnv(k) +// if ok { +// m.originalEnv[k] = &val +// } else { +// m.originalEnv[k] = nil +// } + +// switch nval := v.(type) { +// case string: +// err := os.Setenv(k, nval) +// if err != nil { +// m.t.Errorf("Unable to set env var '%s': %s", k, err) +// } +// case nil: +// err := os.Unsetenv(k) +// if err != nil { +// m.t.Errorf("Unable to unset env var '%s': %s", k, err) +// } +// default: +// panic("WRONG ENV VAR VALUE TYPE: must be nil or a string") +// } +// } + +// func (m *mocker) revertEnv() { +// for k, v := range m.originalEnv { +// if v == nil { +// _ = os.Unsetenv(k) +// } else { +// _ = os.Setenv(k, *v) +// } +// } +// } +func (m *mocker) Start() { m.init() m.t.Cleanup(m.Shutdown) @@ -96,46 +106,36 @@ func (m *mocker) Start() { m.mocks[i].prep() } - // if we are using aws config, then we don't need this - if !m.usingAwsConfig { - caBundlePath := path.Join(m.t.TempDir(), "awsmockcabundle.pem") - err := writeCABundle(caBundlePath) - if err != nil { - m.t.Errorf("Failed to write CA Bundle: %s", err) - } - m.setEnv(envAwsCaBundle, caBundlePath) - } - - ts := httptest.NewServer(m) - m.httpServer = ts - - m.setEnv("HTTP_PROXY", ts.URL) - m.setEnv("http_proxy", ts.URL) - m.setEnv("HTTPS_PROXY", ts.URL) - m.setEnv("https_proxy", ts.URL) + // ts := httptest.NewServer(m) + // m.httpServer = ts // m.setEnv(envAwsEc2MetaDisable, "true") - m.setEnv(envAwsDefaultRegion, DefaultRegion) - - if !m.doNotOverrideCreds { - m.setEnv(envAwsAccessKey, "fakekey") - m.setEnv(envAwsSecretKey, "fakesecret") - m.setEnv(envAwsSessionToken, "faketoken") - m.setEnv(envAwsConfigFile, "fakeconffile") - m.setEnv(envAwsSharedCredFile, "fakesharedfile") - } + // m.setEnv(envAwsDefaultRegion, DefaultRegion) + + // if !m.doNotOverrideCreds { + // m.setEnv(envAwsAccessKey, "fakekey") + // m.setEnv(envAwsSecretKey, "fakesecret") + // m.setEnv(envAwsSessionToken, "faketoken") + // m.setEnv(envAwsConfigFile, "fakeconffile") + // m.setEnv(envAwsSharedCredFile, "fakesharedfile") + // } } func (m *mocker) Shutdown() { - m.httpServer.Close() + if m.httpServer != nil { + m.httpServer.Close() + } + m.requestLog.Clear() - m.revertEnv() + // m.revertEnv() +} - // reset Go's proxy cache - if !m.usingAwsConfig { - resetProxyConfig() +func (m *mocker) startServer() { + if m.httpServer != nil { + return } + m.httpServer = httptest.NewServer(m) } func (m *mocker) Logf(format string, args ...any) { @@ -152,8 +152,14 @@ func (m *mocker) printf(format string, args ...any) { m.t.Logf("[AWSMOCKER] "+format, args...) } +func (m *mocker) RoundTrip(req *http.Request) (*http.Response, error) { + _, resp := m.handleRequest(req) + return resp, nil +} + func (m *mocker) handleRequest(req *http.Request) (*http.Request, *http.Response) { recvReq := newReceivedRequest(req) + recvReq.mocker = m // if recvReq.invalid { // recvReq.DebugDump() diff --git a/mocker_info.go b/mocker_info.go new file mode 100644 index 0000000..28ee278 --- /dev/null +++ b/mocker_info.go @@ -0,0 +1,53 @@ +package awsmocker + +import ( + "net/http" + "net/url" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" +) + +// Returned when you start the server, provides you some information if needed +type MockerInfo interface { + // URL of the proxy server + ProxyURL() string + + // Returns a function that can be used in [http.Transport] + Proxy() func(*http.Request) (*url.URL, error) + + // Preconfigured IMDS client + IMDSClient() *imds.Client + + // Aws configuration to use + Config() aws.Config +} + +var _ MockerInfo = (*mocker)(nil) + +func (m mocker) Config() aws.Config { + return m.awsConfig +} + +// Use this for custom proxy configurations +func (m mocker) Proxy() func(*http.Request) (*url.URL, error) { + uri, err := url.Parse(m.ProxyURL()) + return func(_ *http.Request) (*url.URL, error) { + return uri, err + } +} + +func (m *mocker) ProxyURL() string { + m.startServer() + return m.httpServer.URL +} + +func (m mocker) IMDSClient() *imds.Client { + return imds.NewFromConfig(m.Config()) +} + +// returns a preconfigured HTTP client. This will automatically use the proper proxy. +func (m *mocker) HTTPClient() *http.Client { + m.startServer() + return m.httpServer.Client() +} diff --git a/mocker_test.go b/mocker_test.go index b85031d..f6fefd9 100644 --- a/mocker_test.go +++ b/mocker_test.go @@ -2,7 +2,6 @@ package awsmocker_test import ( "context" - "crypto/tls" "fmt" "net/http" "testing" @@ -16,32 +15,26 @@ import ( "github.com/jmespath/go-jmespath" "github.com/stretchr/testify/require" "github.com/webdestroya/awsmocker" - "github.com/webdestroya/awsmocker/internal/testutil" ) func TestEcsDescribeServices(t *testing.T) { - awsmocker.Start(t, &awsmocker.MockerOptions{ - SkipDefaultMocks: true, - Mocks: []*awsmocker.MockedEndpoint{ - { - Request: &awsmocker.MockedRequest{ - Service: "ecs", - Action: "DescribeServices", - }, - Response: &awsmocker.MockedResponse{ - Body: map[string]interface{}{ - "services": []map[string]interface{}{ - { - "serviceName": "someservice", - }, - }, + m := awsmocker.Start(t, awsmocker.WithoutDefaultMocks(), awsmocker.WithMocks(&awsmocker.MockedEndpoint{ + Request: &awsmocker.MockedRequest{ + Service: "ecs", + Action: "DescribeServices", + }, + Response: &awsmocker.MockedResponse{ + Body: map[string]any{ + "services": []map[string]any{ + { + "serviceName": "someservice", }, }, }, }, - }) + })) - client := ecs.NewFromConfig(testutil.GetAwsConfig()) + client := ecs.NewFromConfig(m.Config()) resp, err := client.DescribeServices(context.TODO(), &ecs.DescribeServicesInput{ Services: []string{"someservice"}, @@ -52,130 +45,213 @@ func TestEcsDescribeServices(t *testing.T) { } func TestStsGetCallerIdentity_WithObj(t *testing.T) { - awsmocker.Start(t, &awsmocker.MockerOptions{ - SkipDefaultMocks: true, - Mocks: []*awsmocker.MockedEndpoint{ - { - Request: &awsmocker.MockedRequest{ - Service: "sts", - Action: "GetCallerIdentity", - }, - Response: &awsmocker.MockedResponse{ - Body: sts.GetCallerIdentityOutput{ - Account: aws.String(awsmocker.DefaultAccountId), - Arn: aws.String(fmt.Sprintf("arn:aws:iam::%s:user/fakeuser", awsmocker.DefaultAccountId)), - UserId: aws.String("AKIAI44QH8DHBEXAMPLE"), - }, - }, + m := awsmocker.Start(t, awsmocker.WithoutDefaultMocks(), awsmocker.WithMocks(&awsmocker.MockedEndpoint{ + Request: &awsmocker.MockedRequest{ + Service: "sts", + Action: "GetCallerIdentity", + }, + Response: &awsmocker.MockedResponse{ + Body: sts.GetCallerIdentityOutput{ + Account: aws.String(awsmocker.DefaultAccountId), + Arn: aws.String(fmt.Sprintf("arn:aws:iam::%s:user/fakeuser", awsmocker.DefaultAccountId)), + UserId: aws.String("AKIAI44QH8DHBEXAMPLE"), }, }, - }) + }), + ) - stsClient := sts.NewFromConfig(testutil.GetAwsConfig()) + stsClient := sts.NewFromConfig(m.Config()) resp, err := stsClient.GetCallerIdentity(context.TODO(), nil) require.NoError(t, err) require.Equalf(t, awsmocker.DefaultAccountId, *resp.Account, "AccountID Mismatch") } -func TestStsGetCallerIdentity_WithMap(t *testing.T) { - awsmocker.Start(t, &awsmocker.MockerOptions{ - SkipDefaultMocks: true, - Mocks: []*awsmocker.MockedEndpoint{ - { - Request: &awsmocker.MockedRequest{ - Service: "sts", - Action: "GetCallerIdentity", - }, - Response: &awsmocker.MockedResponse{ - Body: map[string]interface{}{ - "Account": awsmocker.DefaultAccountId, - "Arn": fmt.Sprintf("arn:aws:iam::%s:user/fakeuser", awsmocker.DefaultAccountId), - "UserId": "AKIAI44QH8DHBEXAMPLE", - }, - }, +func TestStsGetCallerIdentity_WithPointerObj(t *testing.T) { + m := awsmocker.Start(t, awsmocker.WithoutDefaultMocks(), awsmocker.WithMocks(&awsmocker.MockedEndpoint{ + Request: &awsmocker.MockedRequest{ + Service: "sts", + Action: "GetCallerIdentity", + }, + Response: &awsmocker.MockedResponse{ + Body: &sts.GetCallerIdentityOutput{ + Account: aws.String(awsmocker.DefaultAccountId), + Arn: aws.String(fmt.Sprintf("arn:aws:iam::%s:user/fakeuser", awsmocker.DefaultAccountId)), + UserId: aws.String("AKIAI44QH8DHBEXAMPLE"), }, }, - }) - stsClient := sts.NewFromConfig(testutil.GetAwsConfig()) + }), + ) + + stsClient := sts.NewFromConfig(m.Config()) resp, err := stsClient.GetCallerIdentity(context.TODO(), nil) require.NoError(t, err) - require.EqualValuesf(t, awsmocker.DefaultAccountId, *resp.Account, "account id mismatch") + require.Equalf(t, awsmocker.DefaultAccountId, *resp.Account, "AccountID Mismatch") } -func TestDynamicMocker(t *testing.T) { - awsmocker.Start(t, &awsmocker.MockerOptions{ - Mocks: []*awsmocker.MockedEndpoint{ - { - Request: &awsmocker.MockedRequest{ - Service: "events", - Action: "PutRule", - MaxMatchCount: 1, +func TestStsGetCallerIdentity_WithFancyFuncs(t *testing.T) { + reqMock := func() *awsmocker.MockedRequest { + return &awsmocker.MockedRequest{ + Service: "sts", + Action: "GetCallerIdentity", + MaxMatchCount: 1, + } + } + + normalResp := &sts.GetCallerIdentityOutput{ + Account: aws.String(awsmocker.DefaultAccountId), + Arn: aws.String(fmt.Sprintf("arn:aws:iam::%s:user/fakeuser", awsmocker.DefaultAccountId)), + UserId: aws.String("AKIAI44QH8DHBEXAMPLE"), + } + + mocks := []*awsmocker.MockedEndpoint{ + { + Request: reqMock(), + Response: &awsmocker.MockedResponse{ + Body: func(_ *awsmocker.ReceivedRequest) (any, error) { + return *normalResp, nil }, - Response: &awsmocker.MockedResponse{ - Body: func(rr *awsmocker.ReceivedRequest) string { - name, _ := jmespath.Search("Name", rr.JsonPayload) - return awsmocker.EncodeAsJson(map[string]interface{}{ - "RuleArn": fmt.Sprintf("arn:aws:events:%s:%s:rule/%s", rr.Region, awsmocker.DefaultAccountId, name.(string)), - }) - }, + }, + }, + { + Request: reqMock(), + Response: &awsmocker.MockedResponse{ + Body: func(_ *awsmocker.ReceivedRequest) (any, error) { + return normalResp, nil }, }, - { - Request: &awsmocker.MockedRequest{ - Service: "events", - Action: "PutRule", - MaxMatchCount: 1, + }, + { + Request: reqMock(), + Response: &awsmocker.MockedResponse{ + Body: func(_ *awsmocker.ReceivedRequest) (*sts.GetCallerIdentityOutput, error) { + return normalResp, nil }, - Response: &awsmocker.MockedResponse{ - Body: func(rr *awsmocker.ReceivedRequest) (string, int) { - name, _ := jmespath.Search("Name", rr.JsonPayload) - return awsmocker.EncodeAsJson(map[string]interface{}{ - "RuleArn": fmt.Sprintf("arn:aws:events:%s:%s:rule/x%s", rr.Region, awsmocker.DefaultAccountId, name.(string)), - }), 200 - }, + }, + }, + { + Request: reqMock(), + Response: &awsmocker.MockedResponse{ + Body: func(_ *awsmocker.ReceivedRequest, _ *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { + return normalResp, nil }, }, - { - Request: &awsmocker.MockedRequest{ - Service: "events", - Action: "PutRule", - MaxMatchCount: 1, + }, + { + Request: reqMock(), + Response: &awsmocker.MockedResponse{ + Body: func(_ *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { + return normalResp, nil }, - Response: &awsmocker.MockedResponse{ - Body: func(rr *awsmocker.ReceivedRequest) (string, int, string) { - name, _ := jmespath.Search("Name", rr.JsonPayload) - return awsmocker.EncodeAsJson(map[string]interface{}{ - "RuleArn": fmt.Sprintf("arn:aws:events:%s:%s:rule/y%s", rr.Region, awsmocker.DefaultAccountId, name.(string)), - }), 200, awsmocker.ContentTypeJSON - }, + }, + }, + } + + m := awsmocker.Start(t, awsmocker.WithoutDefaultMocks(), awsmocker.WithMocks(mocks...)) + stsClient := sts.NewFromConfig(m.Config()) + + for i := 0; i < len(mocks); i++ { + + resp, err := stsClient.GetCallerIdentity(context.TODO(), nil) + require.NoError(t, err) + require.Equalf(t, awsmocker.DefaultAccountId, *resp.Account, "AccountID Mismatch") + } +} + +func TestStsGetCallerIdentity_WithMap(t *testing.T) { + m := awsmocker.Start(t, awsmocker.WithoutDefaultMocks(), awsmocker.WithMocks(&awsmocker.MockedEndpoint{ + Request: &awsmocker.MockedRequest{ + Service: "sts", + Action: "GetCallerIdentity", + }, + Response: &awsmocker.MockedResponse{ + Body: map[string]any{ + "Account": awsmocker.DefaultAccountId, + "Arn": fmt.Sprintf("arn:aws:iam::%s:user/fakeuser", awsmocker.DefaultAccountId), + "UserId": "AKIAI44QH8DHBEXAMPLE", + }, + }, + }), + ) + stsClient := sts.NewFromConfig(m.Config()) + + resp, err := stsClient.GetCallerIdentity(context.TODO(), nil) + require.NoError(t, err) + require.EqualValuesf(t, awsmocker.DefaultAccountId, *resp.Account, "account id mismatch") +} + +func TestDynamicMocker(t *testing.T) { + m := awsmocker.Start(t, awsmocker.WithMocks([]*awsmocker.MockedEndpoint{ + { + Request: &awsmocker.MockedRequest{ + Service: "events", + Action: "PutRule", + MaxMatchCount: 1, + }, + Response: &awsmocker.MockedResponse{ + Body: func(rr *awsmocker.ReceivedRequest) string { + name, _ := jmespath.Search("Name", rr.JsonPayload) + return awsmocker.EncodeAsJson(map[string]any{ + "RuleArn": fmt.Sprintf("arn:aws:events:%s:%s:rule/%s", rr.Region, awsmocker.DefaultAccountId, name.(string)), + }) }, }, - { - Request: &awsmocker.MockedRequest{ - Service: "events", - Action: "PutRule", - MaxMatchCount: 1, + }, + { + Request: &awsmocker.MockedRequest{ + Service: "events", + Action: "PutRule", + MaxMatchCount: 1, + }, + Response: &awsmocker.MockedResponse{ + Body: func(rr *awsmocker.ReceivedRequest) (string, int) { + name, _ := jmespath.Search("Name", rr.JsonPayload) + return awsmocker.EncodeAsJson(map[string]any{ + "RuleArn": fmt.Sprintf("arn:aws:events:%s:%s:rule/x%s", rr.Region, awsmocker.DefaultAccountId, name.(string)), + }), 200 }, - Response: &awsmocker.MockedResponse{ - Body: func(rr *awsmocker.ReceivedRequest) (string, int, string, string) { - name, _ := jmespath.Search("Name", rr.JsonPayload) - return awsmocker.EncodeAsJson(map[string]interface{}{ - "RuleArn": fmt.Sprintf("arn:aws:events:%s:%s:rule/y%s", rr.Region, awsmocker.DefaultAccountId, name.(string)), - }), 200, awsmocker.ContentTypeJSON, "wut" - }, + }, + }, + { + Request: &awsmocker.MockedRequest{ + Service: "events", + Action: "PutRule", + MaxMatchCount: 1, + }, + Response: &awsmocker.MockedResponse{ + Body: func(rr *awsmocker.ReceivedRequest) (string, int, string) { + name, _ := jmespath.Search("Name", rr.JsonPayload) + return awsmocker.EncodeAsJson(map[string]any{ + "RuleArn": fmt.Sprintf("arn:aws:events:%s:%s:rule/y%s", rr.Region, awsmocker.DefaultAccountId, name.(string)), + }), 200, awsmocker.ContentTypeJSON }, }, }, - }) + { + Request: &awsmocker.MockedRequest{ + Service: "events", + Action: "PutRule", + MaxMatchCount: 1, + }, + Response: &awsmocker.MockedResponse{ + Body: func(rr *awsmocker.ReceivedRequest) (string, int, string, string) { + name, _ := jmespath.Search("Name", rr.JsonPayload) + return awsmocker.EncodeAsJson(map[string]any{ + "RuleArn": fmt.Sprintf("arn:aws:events:%s:%s:rule/y%s", rr.Region, awsmocker.DefaultAccountId, name.(string)), + }), 200, awsmocker.ContentTypeJSON, "wut" + }, + }, + }, + }...), + ) - client := eventbridge.NewFromConfig(testutil.GetAwsConfig()) + client := eventbridge.NewFromConfig(m.Config()) tables := []struct { name string expectedArn string - errorContains interface{} + errorContains any }{ {"testrule", "arn:aws:events:us-east-1:555555555555:rule/testrule", nil}, {"testrule", "arn:aws:events:us-east-1:555555555555:rule/xtestrule", nil}, @@ -214,9 +290,9 @@ func TestStartMockServerForTest(t *testing.T) { // END REALLY TALKING TO AWS // start the test mocker server - awsmocker.Start(t, &awsmocker.MockerOptions{}) + m := awsmocker.Start(t) - stsClient := sts.NewFromConfig(testutil.GetAwsConfig()) + stsClient := sts.NewFromConfig(m.Config()) resp, err := stsClient.GetCallerIdentity(context.TODO(), nil) require.NoError(t, err) @@ -224,65 +300,74 @@ func TestStartMockServerForTest(t *testing.T) { } func TestDefaultMocks(t *testing.T) { - awsmocker.Start(t, nil) + m := awsmocker.Start(t) - stsClient := sts.NewFromConfig(testutil.GetAwsConfig()) + stsClient := sts.NewFromConfig(m.Config()) resp, err := stsClient.GetCallerIdentity(context.TODO(), nil) require.NoError(t, err) require.EqualValuesf(t, awsmocker.DefaultAccountId, *resp.Account, "account id mismatch") } -func TestBypass(t *testing.T) { - awsmocker.Start(t, &awsmocker.MockerOptions{ - DoNotProxy: "example.com", - }) +func TestWithoutDefaultMocks(t *testing.T) { + m := awsmocker.Start(t, awsmocker.WithoutDefaultMocks(), awsmocker.WithoutFailingUnhandledRequests()) + stsClient := sts.NewFromConfig(m.Config()) - httpresp, err := http.Head("http://example.com/") - require.NoError(t, err) - require.Equal(t, http.StatusOK, httpresp.StatusCode) + _, err := stsClient.GetCallerIdentity(context.TODO(), nil) + require.Error(t, err) + require.ErrorContains(t, err, "AccessDenied") +} - stsClient := sts.NewFromConfig(testutil.GetAwsConfig()) +// func TestBypass(t *testing.T) { +// awsmocker.Start(t, &awsmocker.MockerOptions{ +// DoNotProxy: "example.com", +// }) - resp, err := stsClient.GetCallerIdentity(context.TODO(), nil) - require.NoError(t, err) - require.EqualValuesf(t, awsmocker.DefaultAccountId, *resp.Account, "account id mismatch") +// httpresp, err := http.Head("http://example.com/") +// require.NoError(t, err) +// require.Equal(t, http.StatusOK, httpresp.StatusCode) -} +// stsClient := sts.NewFromConfig(testutil.GetAwsConfig()) -func TestBypassReject(t *testing.T) { - awsmocker.Start(t, &awsmocker.MockerOptions{ - DoNotProxy: "example.com", - DoNotFailUnhandledRequests: true, - }) +// resp, err := stsClient.GetCallerIdentity(context.TODO(), nil) +// require.NoError(t, err) +// require.EqualValuesf(t, awsmocker.DefaultAccountId, *resp.Account, "account id mismatch") - client := &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - }, - } +// } - resp, err := client.Head("https://example.org/") - require.NoError(t, err) - require.Equal(t, "webdestroya", resp.TLS.PeerCertificates[0].Subject.Organization[0]) - require.Equal(t, http.StatusNotImplemented, resp.StatusCode) - // require.ErrorContains(t, err, "Not Implemented") +// func TestBypassReject(t *testing.T) { +// awsmocker.Start(t, &awsmocker.MockerOptions{ +// DoNotProxy: "example.com", +// DoNotFailUnhandledRequests: true, +// }) - resp = nil +// client := &http.Client{ +// Transport: &http.Transport{ +// Proxy: http.ProxyFromEnvironment, +// TLSClientConfig: &tls.Config{ +// InsecureSkipVerify: true, +// }, +// }, +// } - resp, err = http.Get("http://example.org/") - require.NoError(t, err) - defer resp.Body.Close() - require.Equal(t, http.StatusNotImplemented, resp.StatusCode) -} +// resp, err := client.Head("https://example.org/") +// require.NoError(t, err) +// require.Equal(t, "webdestroya", resp.TLS.PeerCertificates[0].Subject.Organization[0]) +// require.Equal(t, http.StatusNotImplemented, resp.StatusCode) +// // require.ErrorContains(t, err, "Not Implemented") + +// resp = nil + +// resp, err = http.Get("http://example.org/") +// require.NoError(t, err) +// defer resp.Body.Close() +// require.Equal(t, http.StatusNotImplemented, resp.StatusCode) +// } func TestSendingRegularRequestToProxy(t *testing.T) { info := awsmocker.Start(t, nil) - resp, err := http.Get(info.ProxyURL + "/testing") + resp, err := http.Get(info.ProxyURL() + "/testing") require.NoError(t, err) defer resp.Body.Close() require.Equal(t, http.StatusNotImplemented, resp.StatusCode) diff --git a/mocks_common.go b/mocks_common.go index 60860e5..bd26af7 100644 --- a/mocks_common.go +++ b/mocks_common.go @@ -7,6 +7,7 @@ func Mock_Failure(service, action string) *MockedEndpoint { return Mock_Failure_WithCode(0, service, action, "AccessDenied", "This mock was requested to fail") } +// Mocks a specific Service:Action call to return an error func Mock_Failure_WithCode(statusCode int, service, action, errorCode, errorMessage string) *MockedEndpoint { return &MockedEndpoint{ Request: &MockedRequest{ diff --git a/mocks_common_test.go b/mocks_common_test.go index 5bd115e..0c6d0b6 100644 --- a/mocks_common_test.go +++ b/mocks_common_test.go @@ -14,9 +14,9 @@ import ( ) func TestMockResponse_Error(t *testing.T) { - info := awsmocker.Start(t, &awsmocker.MockerOptions{ - SkipDefaultMocks: true, - Mocks: []*awsmocker.MockedEndpoint{ + info := awsmocker.Start(t, + awsmocker.WithoutDefaultMocks(), + awsmocker.WithMocks([]*awsmocker.MockedEndpoint{ { Request: &awsmocker.MockedRequest{ Hostname: "test.com", @@ -38,13 +38,13 @@ func TestMockResponse_Error(t *testing.T) { }, Response: awsmocker.MockResponse_Error(0, "SomeCode0", "SomeMessage"), }, - }, - }) + }...), + ) client := &http.Client{ Transport: &http.Transport{ Proxy: func(r *http.Request) (*url.URL, error) { - return url.Parse(info.ProxyURL) + return url.Parse(info.ProxyURL()) }, TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, @@ -86,16 +86,13 @@ func TestMockResponse_Error(t *testing.T) { } func TestMock_Failure(t *testing.T) { - info := awsmocker.Start(t, &awsmocker.MockerOptions{ - SkipDefaultMocks: true, - ReturnAwsConfig: true, - Mocks: []*awsmocker.MockedEndpoint{ - awsmocker.Mock_Failure("ecs", "ListClusters"), - awsmocker.Mock_Failure_WithCode(403, "ecs", "ListServices", "SomeCode", "SomeMessage"), - }, - }) + info := awsmocker.Start(t, + awsmocker.WithoutDefaultMocks(), + awsmocker.WithMocks(awsmocker.Mock_Failure("ecs", "ListClusters")), + awsmocker.WithMocks(awsmocker.Mock_Failure_WithCode(403, "ecs", "ListServices", "SomeCode", "SomeMessage")), + ) - ecsClient := ecs.NewFromConfig(*info.AwsConfig) + ecsClient := ecs.NewFromConfig(info.Config()) _, err := ecsClient.ListClusters(context.TODO(), &ecs.ListClustersInput{}) require.Error(t, err) diff --git a/mocks_sts.go b/mocks_sts.go index 48c104d..ffc1798 100644 --- a/mocks_sts.go +++ b/mocks_sts.go @@ -6,6 +6,7 @@ import ( ) var ( + // Default Mock for the sts:GetCallerIdentity request MockStsGetCallerIdentityValid = &MockedEndpoint{ Request: &MockedRequest{ Service: "sts", @@ -14,7 +15,7 @@ var ( Response: &MockedResponse{ StatusCode: http.StatusOK, Encoding: ResponseEncodingXML, - Body: map[string]interface{}{ + Body: map[string]any{ "Account": DefaultAccountId, "Arn": fmt.Sprintf("arn:aws:iam::%s:user/fakeuser", DefaultAccountId), "UserId": "AKIAI44QH8DHBEXAMPLE", diff --git a/options.go b/options.go index ba31e98..cb12300 100644 --- a/options.go +++ b/options.go @@ -1,10 +1,11 @@ package awsmocker import ( + "slices" "time" ) -type MockerOptions struct { +type mockerOptions struct { // Add extra logging. This is deprecated, you should just use the AWSMOCKER_DEBUG=1 env var and do a targeted test run Verbose bool @@ -14,24 +15,12 @@ type MockerOptions struct { // DoNotOverrideCreds bool - // if this is true, then default mocks for GetCallerIdentity and role assumptions will not be provided - SkipDefaultMocks bool - - // WARNING: Setting this to true assumes that you are able to use the config value returned - // If you do not use the provided config and set this true, then requests will not be routed properly. - ReturnAwsConfig bool - // Timeout for proxied requests. Timeout time.Duration // The mocks that will be responded to Mocks []*MockedEndpoint - // Comma separated list of hostname globs that should not be proxied - // if you are doing other HTTP/HTTPS requests within your test, you should - // add the hostnames used to this. - DoNotProxy string - // Add mocks for the EC2 Instance Metadata Service MockEc2Metadata bool @@ -39,4 +28,101 @@ type MockerOptions struct { // you can pass true to this if you do not want to fail your test when the mocker receives an // unmatched request DoNotFailUnhandledRequests bool + + AwsConfigOptions []AwsLoadOptionsFunc + + noMiddleware bool +} + +type MockerOptionFunc func(*mockerOptions) + +var defaultMocks = []*MockedEndpoint{ + MockStsGetCallerIdentityValid, +} + +func newOptions() *mockerOptions { + return &mockerOptions{ + Timeout: 5 * time.Second, + Mocks: slices.Clone(defaultMocks), + } +} + +// Default mocks for GetCallerIdentity and role assumptions will not be provided +func WithoutDefaultMocks() MockerOptionFunc { + return func(o *mockerOptions) { + o.Mocks = slices.DeleteFunc(o.Mocks, func(m *MockedEndpoint) bool { + return slices.Contains(defaultMocks, m) + }) + } +} + +// Disables setting credential environment variables +// This is dangerous, because if the proxy were to fail, then your requests may actually +// execute on AWS with real credentials. +// This means if you do not properly configure the mocker, you could end up making real requests to AWS. +// This is not recommended. +// Deprecated: You should really not be using this +func WithoutCredentialProtection() MockerOptionFunc { + return func(o *mockerOptions) { + o.DoNotOverrideCreds = true + } +} + +// By default, receiving an unmatched request will cause the test to be marked as failed +// you can pass true to this if you do not want to fail your test when the mocker receives an +// unmatched request +func WithoutFailingUnhandledRequests() MockerOptionFunc { + return func(o *mockerOptions) { + o.DoNotFailUnhandledRequests = true + } +} + +// Skip installation of middleware in AWS Options +// Use this if you have a very specific middleware need that the mocker interferes with. +// +// WARNING: This will severely impair the mocker if you use typed requests/responses +// you should probably fix the issue with using the middleware instead. +func WithoutMiddleware() MockerOptionFunc { + return func(o *mockerOptions) { + o.noMiddleware = true + } +} + +// Add mocks for the EC2 Instance Metadata Service +// These are not exhaustive, so if you have a special need you will have to add it. +func WithEC2Metadata(opts ...IMDSMockOptionFunc) MockerOptionFunc { + return WithMocks(Mock_IMDS_Common(opts...)...) +} + +// Additional AWS LoadConfig options to pass along to the Config builder +// Use this if you need to add custom middleware. +// Don't use this to set credentials, or HTTP client, as those will be overridden +func WithAWSConfigOptions(opts ...AwsLoadOptionsFunc) MockerOptionFunc { + return func(o *mockerOptions) { + o.AwsConfigOptions = append(o.AwsConfigOptions, opts...) + } +} + +// The mocks that will be responded to +func WithMocks(mocks ...*MockedEndpoint) MockerOptionFunc { + return func(o *mockerOptions) { + o.Mocks = append(o.Mocks, mocks...) + } +} + +// If provided, then requests that run longer than this will be terminated. +// Generally you should not need to set this +func WithTimeout(value time.Duration) MockerOptionFunc { + return func(mo *mockerOptions) { + mo.Timeout = value + } +} + +// Add extra logging. +// +// Deprecated: you should just use the AWSMOCKER_DEBUG=1 env var and do a targeted test run +func WithVerbosity(value bool) MockerOptionFunc { + return func(mo *mockerOptions) { + mo.Verbose = value + } } diff --git a/options_test.go b/options_test.go new file mode 100644 index 0000000..89f1f07 --- /dev/null +++ b/options_test.go @@ -0,0 +1,43 @@ +package awsmocker + +import ( + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/stretchr/testify/require" +) + +func TestMockerOptions(t *testing.T) { + t.Run("WithVerbosity", func(t *testing.T) { + mo := newOptions() + WithVerbosity(true)(mo) + require.True(t, mo.Verbose) + + WithVerbosity(false)(mo) + require.False(t, mo.Verbose) + }) + + t.Run("WithTimeout", func(t *testing.T) { + mo := newOptions() + WithTimeout(10 * time.Minute)(mo) + require.Equal(t, 10*time.Minute, mo.Timeout) + }) + + t.Run("WithoutDefaultMocks", func(t *testing.T) { + mo := newOptions() + WithEC2Metadata()(mo) + require.Contains(t, mo.Mocks, MockStsGetCallerIdentityValid) + WithoutDefaultMocks()(mo) + require.NotContains(t, mo.Mocks, MockStsGetCallerIdentityValid) + }) + + t.Run("WithAWSConfigOptions", func(t *testing.T) { + m := Start(t, + WithAWSConfigOptions(config.WithRegion("blah-yar")), + ) + + require.Equal(t, "blah-yar", m.Config().Region) + + }) +} diff --git a/received_request.go b/received_request.go index be78bf0..579b167 100644 --- a/received_request.go +++ b/received_request.go @@ -37,6 +37,9 @@ type ReceivedRequest struct { // TBA: maybe in the future we'll add invalid request flagging, for now allow all types // invalid bool + + // internal reference to the mocker parent + mocker *mocker } func (rr *ReceivedRequest) Inspect() string { @@ -58,9 +61,14 @@ func newReceivedRequest(req *http.Request) *ReceivedRequest { _ = req.ParseForm() - bodyBytes, err := io.ReadAll(req.Body) - if err == nil { - recvreq.RawBody = bodyBytes + var bodyBytes []byte + + if req.Body != nil { + bb, err := io.ReadAll(req.Body) + bodyBytes = bb + if err == nil { + recvreq.RawBody = bodyBytes + } } reqContentType := req.Header.Get("content-type") @@ -76,8 +84,7 @@ func newReceivedRequest(req *http.Request) *ReceivedRequest { } - authHeader := req.Header.Get("authorization") - if authHeader != "" { + if authHeader := req.Header.Get("authorization"); authHeader != "" { matches := credExtractRegexp.FindStringSubmatch(authHeader) if len(matches) > 1 { // 0 1 2 3 4 @@ -105,6 +112,14 @@ func newReceivedRequest(req *http.Request) *ReceivedRequest { recvreq.AssumedResponseType = ContentTypeXML } + if mhService := req.Header.Get(mwHeaderService); mhService != "" && recvreq.Service == "" { + recvreq.Service = mhService + } + + if mhAction := req.Header.Get(mwHeaderOperation); mhAction != "" && recvreq.Action == "" { + recvreq.Action = mhAction + } + // if recvreq.Action == "" { // log.Println("WARN: Received a request with no action????") // recvreq.invalid = true @@ -124,8 +139,8 @@ func (r *ReceivedRequest) DebugDump() { fmt.Fprintf(buf, "Operation: %s (service=%s @ %s)\n", r.Action, r.Service, r.Region) } - fmt.Fprintf(buf, "%s %s\n", r.HttpRequest.Method, r.HttpRequest.RequestURI) - fmt.Fprintf(buf, "Host: %s\n", r.HttpRequest.Host) + fmt.Fprintf(buf, "%s %s\n", r.HttpRequest.Method, coalesceString(r.HttpRequest.RequestURI, r.Path)) + fmt.Fprintf(buf, "Host: %s\n", coalesceString(r.HttpRequest.Host, r.Hostname)) for k, vlist := range r.HttpRequest.Header { for _, v := range vlist { fmt.Fprintf(buf, "%s: %s\n", k, v) @@ -137,7 +152,7 @@ func (r *ReceivedRequest) DebugDump() { if len(r.RawBody) > 0 { fmt.Fprintln(buf, "BODY:") fmt.Fprintln(buf, string(r.RawBody)) - } else if r.HttpRequest.Form != nil && len(r.HttpRequest.Form) > 0 { + } else if len(r.HttpRequest.Form) > 0 { fmt.Fprintln(buf, "PARAMS:") for k, vlist := range r.HttpRequest.Form { for _, v := range vlist { diff --git a/received_request_test.go b/received_request_test.go index 804815c..a78c698 100644 --- a/received_request_test.go +++ b/received_request_test.go @@ -29,17 +29,13 @@ func TestReceivedRequest_DebugDump(t *testing.T) { // } - info := awsmocker.Start(t, &awsmocker.MockerOptions{ - SkipDefaultMocks: true, - ReturnAwsConfig: true, - // V - Mocks: []*awsmocker.MockedEndpoint{ - awsmocker.Mock_Failure("ecs", "ListClusters"), - awsmocker.Mock_Failure_WithCode(403, "ecs", "ListServices", "SomeCode", "SomeMessage"), - }, - }) + info := awsmocker.Start(t, + awsmocker.WithoutDefaultMocks(), + awsmocker.WithMocks(awsmocker.Mock_Failure("ecs", "ListClusters")), + awsmocker.WithMocks(awsmocker.Mock_Failure_WithCode(403, "ecs", "ListServices", "SomeCode", "SomeMessage")), + ) - ecsClient := ecs.NewFromConfig(*info.AwsConfig) + ecsClient := ecs.NewFromConfig(info.Config()) _, err := ecsClient.ListClusters(context.TODO(), &ecs.ListClustersInput{}) require.Error(t, err) @@ -51,6 +47,7 @@ func TestReceivedRequest_DebugDump(t *testing.T) { require.Contains(t, debugStr, "AWSMOCKER RESPONSE:") require.Contains(t, debugStr, "POST") require.Contains(t, debugStr, "ecs.us-east-1.amazonaws.com") + require.Contains(t, debugStr, "service=ecs") require.Contains(t, debugStr, "ListClusters") } diff --git a/response.go b/response.go index 88ef1cd..253eeb1 100644 --- a/response.go +++ b/response.go @@ -62,7 +62,7 @@ func (hr *httpResponse) toHttpResponse(req *http.Request) *http.Response { resp.Header.Add("Content-Type", hr.contentType) resp.Header.Add("Server", "AWSMocker") - if hr.extraHeaders != nil && len(hr.extraHeaders) > 0 { + if len(hr.extraHeaders) > 0 { for k, v := range hr.extraHeaders { resp.Header.Set(k, v) } diff --git a/response_test.go b/response_test.go index c51616e..9fc4e01 100644 --- a/response_test.go +++ b/response_test.go @@ -12,29 +12,25 @@ import ( ) func TestResponseDebugLogging(t *testing.T) { - info := awsmocker.Start(t, &awsmocker.MockerOptions{ - SkipDefaultMocks: true, - ReturnAwsConfig: true, - Mocks: []*awsmocker.MockedEndpoint{ - { - Request: &awsmocker.MockedRequest{ - Hostname: "httptest.com", - }, - Response: awsmocker.MockResponse_Error(400, "SomeCode_HTTP", "SomeMessage"), + info := awsmocker.Start(t, awsmocker.WithoutDefaultMocks(), + awsmocker.WithMocks(&awsmocker.MockedEndpoint{ + Request: &awsmocker.MockedRequest{ + Hostname: "httptest.com", }, - { - Request: &awsmocker.MockedRequest{ - Hostname: "httpstest.com", - }, - Response: awsmocker.MockResponse_Error(401, "SomeCode_HTTPS", "SomeMessage"), + Response: awsmocker.MockResponse_Error(400, "SomeCode_HTTP", "SomeMessage"), + }), + awsmocker.WithMocks(&awsmocker.MockedEndpoint{ + Request: &awsmocker.MockedRequest{ + Hostname: "httpstest.com", }, - }, - }) + Response: awsmocker.MockResponse_Error(401, "SomeCode_HTTPS", "SomeMessage"), + }), + ) client := &http.Client{ Transport: &http.Transport{ Proxy: func(r *http.Request) (*url.URL, error) { - return url.Parse(info.ProxyURL) + return url.Parse(info.ProxyURL()) }, TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, diff --git a/sanity_test.go b/sanity_test.go new file mode 100644 index 0000000..11f98a0 --- /dev/null +++ b/sanity_test.go @@ -0,0 +1,66 @@ +package awsmocker_test + +import ( + "testing" + + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/smithy-go/document" + "github.com/stretchr/testify/require" +) + +func TestSanity(t *testing.T) { + t.Run("NoSerde", func(t *testing.T) { + require.False(t, document.IsNoSerde(uint(123))) + require.False(t, document.IsNoSerde(uint64(123))) + require.False(t, document.IsNoSerde(nil)) + + require.True(t, document.IsNoSerde(ec2.DescribeSubnetsOutput{})) + require.True(t, document.IsNoSerde(&ec2.DescribeSubnetsOutput{})) + }) + +} + +// func TestReflection(t *testing.T) { +// thing := &ec2.DescribeAccountAttributesOutput{} +// thing2 := ec2.DescribeAccountAttributesOutput{} +// var thing3 any = any(ec2.DescribeAccountAttributesOutput{}) + +// f1 := func(_ *awsmocker.ReceivedRequest) (*ec2.DescribeSubnetsOutput, error) { +// return nil, nil +// } + +// f2 := func(_ *awsmocker.ReceivedRequest, _ *ec2.DescribeSubnetsInput) (*ec2.DescribeSubnetsOutput, error) { +// return nil, nil +// } + +// f3 := func(_ *ec2.DescribeSubnetsInput) (*ec2.DescribeSubnetsOutput, error) { +// return nil, nil +// } +// _ = f2 +// _ = f3 +// _ = thing2 + +// t.Logf("Thing1:Kind: %s", reflect.TypeOf(thing).Kind().String()) +// t.Logf("Thing2:Kind: %s", reflect.TypeOf(thing2).Kind().String()) +// t.Logf("Thing3:Kind: %s", reflect.TypeOf(thing3).Kind().String()) +// t.Logf("Thing3:Kind: %s", reflect.TypeOf(&thing3).Kind().String()) +// t.Logf("F3:Kind: %s", reflect.TypeOf(f3).Kind().String()) +// t.Logf("New: %s", reflect.TypeFor[*awsmocker.ReceivedRequest]().String()) + +// typ := reflect.TypeOf(thing) + +// require.True(t, document.IsNoSerde(thing)) + +// t.Logf("Kind=%s", reflect.Indirect(reflect.ValueOf(thing)).Kind().String()) +// t.Logf("TYPE=%s", typ.String()) +// t.Logf("TYPE=%s", typ.Elem().String()) +// t.Logf("PKG=%s", typ.Elem().PkgPath()) + +// f1type := reflect.TypeOf(f1) +// t.Logf("F1 TYPE=%s", f1type.String()) +// t.Logf("F1 NumIn=%d", f1type.NumIn()) +// t.Logf("F1 NumOut=%d", f1type.NumOut()) +// t.Logf("F1 In.0=%s", f1type.In(0).String()) +// t.Logf("F1 Out.0=%s [%s]", f1type.Out(0).String(), f1type.Out(0).Elem().PkgPath()) + +// } diff --git a/types.go b/types.go index bdaf5e1..c652232 100644 --- a/types.go +++ b/types.go @@ -3,6 +3,9 @@ package awsmocker import ( "net" "net/http" + "testing" + + "github.com/aws/aws-sdk-go-v2/config" ) const ( @@ -25,6 +28,11 @@ type TestingT interface { // Fatalf(format string, args ...any) } +var ( + _ TestingT = (*testing.T)(nil) + _ TestingT = (*testing.B)(nil) +) + type tHelper interface { Helper() } @@ -48,3 +56,8 @@ const ( ) type MockedRequestHandler = func(*ReceivedRequest) *http.Response + +// Come on Amazon... +// Can't use {config.LoadOptionsFunc} because that is not a param for LoadDefaultConfig +// Why would you define the type and then not even use it... +type AwsLoadOptionsFunc = func(*config.LoadOptions) error diff --git a/util.go b/util.go index 86263a6..966bdb4 100644 --- a/util.go +++ b/util.go @@ -20,6 +20,7 @@ func encodeAsXml(obj any) string { return string(out) } +// JSONifies the given object func EncodeAsJson(obj any) string { out, err := json.Marshal(obj) if err != nil { @@ -63,3 +64,12 @@ func isAwsHostname(hostname string) bool { return false } */ + +func coalesceString(vals ...string) string { + for _, v := range vals { + if v != "" { + return v + } + } + return "" +}