diff --git a/iq/application.go b/iq/application.go index 21a0dce..769ad97 100644 --- a/iq/application.go +++ b/iq/application.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -44,13 +45,12 @@ type Application struct { } `json:"applicationTags,omitempty"` } -// GetApplicationByPublicID returns details on the named IQ application -func GetApplicationByPublicID(iq IQ, applicationPublicID string) (*Application, error) { +func GetApplicationByPublicIDContext(ctx context.Context, iq IQ, applicationPublicID string) (*Application, error) { doError := func(err error) error { return fmt.Errorf("application '%s' not found: %v", applicationPublicID, err) } endpoint := fmt.Sprintf(restApplicationByPublic, applicationPublicID) - body, _, err := iq.Get(endpoint) + body, _, err := iq.Get(ctx, endpoint) if err != nil { return nil, doError(err) } @@ -67,8 +67,12 @@ func GetApplicationByPublicID(iq IQ, applicationPublicID string) (*Application, return &resp.Applications[0], nil } -// CreateApplication creates an application in IQ with the given name and identifier -func CreateApplication(iq IQ, name, id, organizationID string) (string, error) { +// GetApplicationByPublicID returns details on the named IQ application +func GetApplicationByPublicID(iq IQ, applicationPublicID string) (*Application, error) { + return GetApplicationByPublicIDContext(context.Background(), iq, applicationPublicID) +} + +func CreateApplicationContext(ctx context.Context, iq IQ, name, id, organizationID string) (string, error) { if name == "" || id == "" || organizationID == "" { return "", fmt.Errorf("cannot create application with empty values") } @@ -82,7 +86,7 @@ func CreateApplication(iq IQ, name, id, organizationID string) (string, error) { return doError(err) } - body, _, err := iq.Post(restApplication, bytes.NewBuffer(request)) + body, _, err := iq.Post(ctx, restApplication, bytes.NewBuffer(request)) if err != nil { return doError(err) } @@ -95,17 +99,25 @@ func CreateApplication(iq IQ, name, id, organizationID string) (string, error) { return resp.ID, nil } -// DeleteApplication deletes an application in IQ with the given id -func DeleteApplication(iq IQ, applicationID string) error { - if resp, err := iq.Del(fmt.Sprintf("%s/%s", restApplication, applicationID)); err != nil && resp.StatusCode != http.StatusNoContent { +// CreateApplication creates an application in IQ with the given name and identifier +func CreateApplication(iq IQ, name, id, organizationID string) (string, error) { + return CreateApplicationContext(context.Background(), iq, name, id, organizationID) +} + +func DeleteApplicationContext(ctx context.Context, iq IQ, applicationID string) error { + if resp, err := iq.Del(ctx, fmt.Sprintf("%s/%s", restApplication, applicationID)); err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("application '%s' not deleted: %v", applicationID, err) } return nil } -// GetAllApplications returns a slice of all of the applications in an IQ instance -func GetAllApplications(iq IQ) ([]Application, error) { - body, _, err := iq.Get(restApplication) +// DeleteApplication deletes an application in IQ with the given id +func DeleteApplication(iq IQ, applicationID string) error { + return DeleteApplicationContext(context.Background(), iq, applicationID) +} + +func GetAllApplicationsContext(ctx context.Context, iq IQ) ([]Application, error) { + body, _, err := iq.Get(ctx, restApplication) if err != nil { return nil, fmt.Errorf("applications not found: %v", err) } @@ -118,14 +130,18 @@ func GetAllApplications(iq IQ) ([]Application, error) { return resp.Applications, nil } -// GetApplicationsByOrganization returns all applications under a given organization -func GetApplicationsByOrganization(iq IQ, organizationName string) ([]Application, error) { - org, err := GetOrganizationByName(iq, organizationName) +// GetAllApplications returns a slice of all of the applications in an IQ instance +func GetAllApplications(iq IQ) ([]Application, error) { + return GetAllApplicationsContext(context.Background(), iq) +} + +func GetApplicationsByOrganizationContext(ctx context.Context, iq IQ, organizationName string) ([]Application, error) { + org, err := GetOrganizationByNameContext(ctx, iq, organizationName) if err != nil { return nil, fmt.Errorf("organization not found: %v", err) } - apps, err := GetAllApplications(iq) + apps, err := GetAllApplicationsContext(ctx, iq) if err != nil { return nil, fmt.Errorf("could not get applications list: %v", err) } @@ -139,3 +155,8 @@ func GetApplicationsByOrganization(iq IQ, organizationName string) ([]Applicatio return orgApps, nil } + +// GetApplicationsByOrganization returns all applications under a given organization +func GetApplicationsByOrganization(iq IQ, organizationName string) ([]Application, error) { + return GetApplicationsByOrganizationContext(context.Background(), iq, organizationName) +} diff --git a/iq/application_test.go b/iq/application_test.go index 3f7ee71..b89c774 100644 --- a/iq/application_test.go +++ b/iq/application_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -109,7 +110,7 @@ func TestGetAllApplications(t *testing.T) { iq, mock := applicationTestIQ(t) defer mock.Close() - applications, err := GetAllApplications(iq) + applications, err := GetAllApplicationsContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -133,7 +134,7 @@ func TestGetApplicationByPublicID(t *testing.T) { dummyAppsIdx := 2 - got, err := GetApplicationByPublicID(iq, dummyApps[dummyAppsIdx].PublicID) + got, err := GetApplicationByPublicIDContext(context.Background(), iq, dummyApps[dummyAppsIdx].PublicID) if err != nil { t.Error(err) } @@ -197,7 +198,7 @@ func TestCreateApplication(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := CreateApplication(tt.args.iq, tt.args.name, tt.args.id, tt.args.organizationID) + got, err := CreateApplicationContext(context.Background(), tt.args.iq, tt.args.name, tt.args.id, tt.args.organizationID) if (err != nil) != tt.wantErr { t.Errorf("CreateApplication() error = %v, wantErr %v", err, tt.wantErr) return @@ -216,16 +217,16 @@ func TestDeleteApplication(t *testing.T) { deleteMeApp := Application{PublicID: "deleteMeApp", Name: "deleteMeApp", OrganizationID: "deleteMeAppOrgId"} var err error - deleteMeApp.ID, err = CreateApplication(iq, deleteMeApp.Name, deleteMeApp.PublicID, deleteMeApp.OrganizationID) + deleteMeApp.ID, err = CreateApplicationContext(context.Background(), iq, deleteMeApp.Name, deleteMeApp.PublicID, deleteMeApp.OrganizationID) if err != nil { t.Fatal(err) } - if err := DeleteApplication(iq, deleteMeApp.PublicID); err != nil { + if err := DeleteApplicationContext(context.Background(), iq, deleteMeApp.PublicID); err != nil { t.Fatal(err) } - if _, err := GetApplicationByPublicID(iq, deleteMeApp.PublicID); err == nil { + if _, err := GetApplicationByPublicIDContext(context.Background(), iq, deleteMeApp.PublicID); err == nil { t.Fatal("App was not deleted") } } @@ -254,7 +255,7 @@ func TestGetApplicationsByOrganization(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := GetApplicationsByOrganization(tt.args.iq, tt.args.organizationName) + got, err := GetApplicationsByOrganizationContext(context.Background(), tt.args.iq, tt.args.organizationName) if (err != nil) != tt.wantErr { t.Errorf("GetApplicationsByOrganization() error = %v, wantErr %v", err, tt.wantErr) return @@ -272,7 +273,7 @@ func ExampleGetAllApplications() { panic(err) } - applications, err := GetAllApplications(iq) + applications, err := GetAllApplicationsContext(context.Background(), iq) if err != nil { panic(err) } @@ -286,7 +287,7 @@ func ExampleCreateApplication() { panic(err) } - appID, err := CreateApplication(iq, "name", "id", "organization") + appID, err := CreateApplicationContext(context.Background(), iq, "name", "id", "organization") if err != nil { panic(err) } diff --git a/iq/componentDetails.go b/iq/componentDetails.go index 0885ed1..907947b 100644 --- a/iq/componentDetails.go +++ b/iq/componentDetails.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" ) @@ -42,17 +43,20 @@ type ComponentDetail struct { } `json:"securityData"` } -// GetComponent returns information on a named component -func GetComponent(iq IQ, component Component) (ComponentDetail, error) { - deets, err := GetComponents(iq, []Component{component}) +func GetComponentContext(ctx context.Context, iq IQ, component Component) (ComponentDetail, error) { + deets, err := GetComponentsContext(ctx, iq, []Component{component}) if deets == nil || len(deets) == 0 { return ComponentDetail{}, err } return deets[0], err } -// GetComponents returns information on the named components -func GetComponents(iq IQ, components []Component) ([]ComponentDetail, error) { +// GetComponent returns information on a named component +func GetComponent(iq IQ, component Component) (ComponentDetail, error) { + return GetComponentContext(context.Background(), iq, component) +} + +func GetComponentsContext(ctx context.Context, iq IQ, components []Component) ([]ComponentDetail, error) { reqComponents := detailsRequest{Components: make([]componentRequested, len(components))} for i, c := range components { reqComponents.Components[i] = componentRequestedFromComponent(c) @@ -63,7 +67,7 @@ func GetComponents(iq IQ, components []Component) ([]ComponentDetail, error) { return nil, fmt.Errorf("could not generate request: %v", err) } - body, _, err := iq.Post(restComponentDetails, bytes.NewBuffer(req)) + body, _, err := iq.Post(ctx, restComponentDetails, bytes.NewBuffer(req)) if err != nil { return nil, fmt.Errorf("could not find component details: %v", err) } @@ -76,13 +80,17 @@ func GetComponents(iq IQ, components []Component) ([]ComponentDetail, error) { return resp.ComponentDetails, nil } -// GetComponentsByApplication returns an array with all components along with their -func GetComponentsByApplication(iq IQ, appPublicID string) ([]ComponentDetail, error) { +// GetComponents returns information on the named components +func GetComponents(iq IQ, components []Component) ([]ComponentDetail, error) { + return GetComponentsContext(context.Background(), iq, components) +} + +func GetComponentsByApplicationContext(ctx context.Context, iq IQ, appPublicID string) ([]ComponentDetail, error) { componentHashes := make(map[string]struct{}) components := make([]Component, 0) stages := []Stage{StageBuild, StageStageRelease, StageRelease, StageOperate} for _, stage := range stages { - if report, err := GetRawReportByAppID(iq, appPublicID, string(stage)); err == nil { + if report, err := GetRawReportByAppIDContext(ctx, iq, appPublicID, string(stage)); err == nil { for _, c := range report.Components { if _, ok := componentHashes[c.Hash]; !ok { componentHashes[c.Hash] = struct{}{} @@ -92,12 +100,16 @@ func GetComponentsByApplication(iq IQ, appPublicID string) ([]ComponentDetail, e } } - return GetComponents(iq, components) + return GetComponentsContext(ctx, iq, components) } -// GetAllComponents returns an array with all components along with their -func GetAllComponents(iq IQ) ([]ComponentDetail, error) { - apps, err := GetAllApplications(iq) +// GetComponentsByApplication returns an array with all components along with their +func GetComponentsByApplication(iq IQ, appPublicID string) ([]ComponentDetail, error) { + return GetComponentsByApplicationContext(context.Background(), iq, appPublicID) +} + +func GetAllComponentsContext(ctx context.Context, iq IQ) ([]ComponentDetail, error) { + apps, err := GetAllApplicationsContext(ctx, iq) if err != nil { return nil, err } @@ -106,7 +118,7 @@ func GetAllComponents(iq IQ) ([]ComponentDetail, error) { components := make([]ComponentDetail, 0) for _, app := range apps { - appComponents, err := GetComponentsByApplication(iq, app.PublicID) + appComponents, err := GetComponentsByApplicationContext(ctx, iq, app.PublicID) // TODO: catcher if err != nil { return nil, err @@ -122,3 +134,8 @@ func GetAllComponents(iq IQ) ([]ComponentDetail, error) { return components, nil } + +// GetAllComponents returns an array with all components +func GetAllComponents(iq IQ) ([]ComponentDetail, error) { + return GetAllComponentsContext(context.Background(), iq) +} diff --git a/iq/componentDetails_test.go b/iq/componentDetails_test.go index e316fc6..8701509 100644 --- a/iq/componentDetails_test.go +++ b/iq/componentDetails_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -88,13 +89,13 @@ func TestGetComponent(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := GetComponent(tt.args.iq, tt.args.component) + got, err := GetComponentContext(context.Background(), tt.args.iq, tt.args.component) if (err != nil) != tt.wantErr { - t.Errorf("GetComponent() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("GetComponentContext() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("GetComponent() = %v, want %v", got, tt.want) + t.Errorf("GetComponentContext() = %v, want %v", got, tt.want) } }) } @@ -106,7 +107,7 @@ func TestGetComponents(t *testing.T) { expected := dummyComponentDetails[0] - details, err := GetComponents(iq, []Component{expected.Component}) + details, err := GetComponentsContext(context.Background(), iq, []Component{expected.Component}) if err != nil { t.Error(err) } @@ -145,13 +146,13 @@ func TestGetComponentsByApplication(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := GetComponentsByApplication(tt.args.iq, tt.args.appPublicID) + got, err := GetComponentsByApplicationContext(context.Background(), tt.args.iq, tt.args.appPublicID) if (err != nil) != tt.wantErr { - t.Errorf("GetComponentsByApplication() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("GetComponentsByApplicationContext() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("GetComponentsByApplication() = %v, want %v", got, tt.want) + t.Errorf("GetComponentsByApplicationContext() = %v, want %v", got, tt.want) } }) } @@ -180,13 +181,13 @@ func TestGetAllComponents(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := GetAllComponents(tt.args.iq) + got, err := GetAllComponentsContext(context.Background(), tt.args.iq) if (err != nil) != tt.wantErr { - t.Errorf("GetAllComponents() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("GetAllComponentsContext() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("GetAllComponents() = %v, want %v", got, tt.want) + t.Errorf("GetAllComponentsContext() = %v, want %v", got, tt.want) } }) } diff --git a/iq/componentLabels.go b/iq/componentLabels.go index bc9315e..8c3d7d3 100644 --- a/iq/componentLabels.go +++ b/iq/componentLabels.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -26,15 +27,14 @@ type IqComponentLabel struct { Color string `json:"color"` } -// ComponentLabelApply adds an existing label to a component for a given application -func ComponentLabelApply(iq IQ, comp Component, appID, label string) error { - app, err := GetApplicationByPublicID(iq, appID) +func ComponentLabelApplyContext(ctx context.Context, iq IQ, comp Component, appID, label string) error { + app, err := GetApplicationByPublicIDContext(ctx, iq, appID) if err != nil { return fmt.Errorf("could not retrieve application with ID %s: %v", appID, err) } endpoint := fmt.Sprintf(restLabelComponent, comp.Hash, url.PathEscape(label), app.ID) - _, resp, err := iq.Post(endpoint, nil) + _, resp, err := iq.Post(ctx, endpoint, nil) if err != nil { if resp == nil || resp.StatusCode != http.StatusNoContent { return fmt.Errorf("could not apply label: %v", err) @@ -44,15 +44,19 @@ func ComponentLabelApply(iq IQ, comp Component, appID, label string) error { return nil } -// ComponentLabelUnapply removes an existing association between a label and a component -func ComponentLabelUnapply(iq IQ, comp Component, appID, label string) error { - app, err := GetApplicationByPublicID(iq, appID) +// ComponentLabelApply adds an existing label to a component for a given application +func ComponentLabelApply(iq IQ, comp Component, appID, label string) error { + return ComponentLabelApplyContext(context.Background(), iq, comp, appID, label) +} + +func ComponentLabelUnapplyContext(ctx context.Context, iq IQ, comp Component, appID, label string) error { + app, err := GetApplicationByPublicIDContext(ctx, iq, appID) if err != nil { return fmt.Errorf("could not retrieve application with ID %s: %v", appID, err) } endpoint := fmt.Sprintf(restLabelComponent, comp.Hash, url.PathEscape(label), app.ID) - resp, err := iq.Del(endpoint) + resp, err := iq.Del(ctx, endpoint) if err != nil { if resp == nil || resp.StatusCode != http.StatusNoContent { return fmt.Errorf("could not unapply label: %v", err) @@ -62,8 +66,13 @@ func ComponentLabelUnapply(iq IQ, comp Component, appID, label string) error { return nil } -func getComponentLabels(iq IQ, endpoint string) ([]IqComponentLabel, error) { - body, _, err := iq.Get(endpoint) +// ComponentLabelUnapply removes an existing association between a label and a component +func ComponentLabelUnapply(iq IQ, comp Component, appID, label string) error { + return ComponentLabelUnapplyContext(context.Background(), iq, comp, appID, label) +} + +func getComponentLabels(ctx context.Context, iq IQ, endpoint string) ([]IqComponentLabel, error) { + body, _, err := iq.Get(ctx, endpoint) if err != nil { return nil, err } @@ -77,26 +86,34 @@ func getComponentLabels(iq IQ, endpoint string) ([]IqComponentLabel, error) { return labels, nil } +func GetComponentLabelsByOrganizationContext(ctx context.Context, iq IQ, organization string) ([]IqComponentLabel, error) { + endpoint := fmt.Sprintf(restLabelComponentByOrg, organization) + return getComponentLabels(ctx, iq, endpoint) +} + // GetComponentLabelsByOrganization retrieves an array of an organization's component label func GetComponentLabelsByOrganization(iq IQ, organization string) ([]IqComponentLabel, error) { - endpoint := fmt.Sprintf(restLabelComponentByOrg, organization) - return getComponentLabels(iq, endpoint) + return GetComponentLabelsByOrganizationContext(context.Background(), iq, organization) +} + +func GetComponentLabelsByAppIDContext(ctx context.Context, iq IQ, appID string) ([]IqComponentLabel, error) { + endpoint := fmt.Sprintf(restLabelComponentByApp, appID) + return getComponentLabels(ctx, iq, endpoint) } // GetComponentLabelsByAppID retrieves an array of an organization's component label func GetComponentLabelsByAppID(iq IQ, appID string) ([]IqComponentLabel, error) { - endpoint := fmt.Sprintf(restLabelComponentByApp, appID) - return getComponentLabels(iq, endpoint) + return GetComponentLabelsByAppIDContext(context.Background(), iq, appID) } -func createLabel(iq IQ, endpoint, label, description, color string) (IqComponentLabel, error) { +func createLabel(ctx context.Context, iq IQ, endpoint, label, description, color string) (IqComponentLabel, error) { var labelResponse IqComponentLabel request, err := json.Marshal(IqComponentLabel{Label: label, Description: description, Color: color}) if err != nil { return labelResponse, fmt.Errorf("could not marshal label: %v", err) } - body, resp, err := iq.Post(endpoint, bytes.NewBuffer(request)) + body, resp, err := iq.Post(ctx, endpoint, bytes.NewBuffer(request)) if resp.StatusCode != http.StatusOK { return labelResponse, fmt.Errorf("did not succeeed in creating label: %v", err) } @@ -109,22 +126,29 @@ func createLabel(iq IQ, endpoint, label, description, color string) (IqComponent return labelResponse, nil } +func CreateComponentLabelForOrganizationContext(ctx context.Context, iq IQ, organization, label, description, color string) (IqComponentLabel, error) { + endpoint := fmt.Sprintf(restLabelComponentByOrg, organization) + return createLabel(ctx, iq, endpoint, label, description, color) +} + // CreateComponentLabelForOrganization creates a label for an organization func CreateComponentLabelForOrganization(iq IQ, organization, label, description, color string) (IqComponentLabel, error) { - endpoint := fmt.Sprintf(restLabelComponentByOrg, organization) - return createLabel(iq, endpoint, label, description, color) + return CreateComponentLabelForOrganizationContext(context.Background(), iq, organization, label, description, color) +} + +func CreateComponentLabelForApplicationContext(ctx context.Context, iq IQ, appID, label, description, color string) (IqComponentLabel, error) { + endpoint := fmt.Sprintf(restLabelComponentByApp, appID) + return createLabel(ctx, iq, endpoint, label, description, color) } // CreateComponentLabelForApplication creates a label for an application func CreateComponentLabelForApplication(iq IQ, appID, label, description, color string) (IqComponentLabel, error) { - endpoint := fmt.Sprintf(restLabelComponentByApp, appID) - return createLabel(iq, endpoint, label, description, color) + return CreateComponentLabelForApplicationContext(context.Background(), iq, appID, label, description, color) } -// DeleteComponentLabelForOrganization deletes a label from an organization -func DeleteComponentLabelForOrganization(iq IQ, organization, label string) error { +func DeleteComponentLabelForOrganizationContext(ctx context.Context, iq IQ, organization, label string) error { endpoint := fmt.Sprintf(restLabelComponentByOrgDel, organization, label) - resp, err := iq.Del(endpoint) + resp, err := iq.Del(ctx, endpoint) if resp.StatusCode != http.StatusOK { return fmt.Errorf("did not succeeed in deleting label: %v", err) } @@ -133,10 +157,14 @@ func DeleteComponentLabelForOrganization(iq IQ, organization, label string) erro return nil } -// DeleteComponentLabelForApplication deletes a label from an application -func DeleteComponentLabelForApplication(iq IQ, appID, label string) error { +// DeleteComponentLabelForOrganization deletes a label from an organization +func DeleteComponentLabelForOrganization(iq IQ, organization, label string) error { + return DeleteComponentLabelForOrganizationContext(context.Background(), iq, organization, label) +} + +func DeleteComponentLabelForApplicationContext(ctx context.Context, iq IQ, appID, label string) error { endpoint := fmt.Sprintf(restLabelComponentByAppDel, appID, label) - resp, err := iq.Del(endpoint) + resp, err := iq.Del(ctx, endpoint) if resp.StatusCode != http.StatusOK { return fmt.Errorf("did not succeeed in deleting label: %v", err) } @@ -144,3 +172,8 @@ func DeleteComponentLabelForApplication(iq IQ, appID, label string) error { return nil } + +// DeleteComponentLabelForApplication deletes a label from an application +func DeleteComponentLabelForApplication(iq IQ, appID, label string) error { + return DeleteComponentLabelForApplicationContext(context.Background(), iq, appID, label) +} diff --git a/iq/componentLabels_test.go b/iq/componentLabels_test.go index 14e7e2d..416de29 100644 --- a/iq/componentLabels_test.go +++ b/iq/componentLabels_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -55,7 +56,7 @@ func TestComponentLabelApply(t *testing.T) { label, component, appID := dummyLabels[0], dummyComponent, dummyApps[0].PublicID - if err := ComponentLabelApply(iq, component, appID, label); err != nil { + if err := ComponentLabelApplyContext(context.Background(), iq, component, appID, label); err != nil { t.Error(err) } } @@ -66,11 +67,11 @@ func TestComponentLabelUnapply(t *testing.T) { label, component, appID := dummyLabels[0], dummyComponent, dummyApps[0].PublicID - if err := ComponentLabelApply(iq, component, appID, label); err != nil { + if err := ComponentLabelApplyContext(context.Background(), iq, component, appID, label); err != nil { t.Fatal(err) } - if err := ComponentLabelUnapply(iq, component, appID, label); err != nil { + if err := ComponentLabelUnapplyContext(context.Background(), iq, component, appID, label); err != nil { t.Error(err) } } diff --git a/iq/componentVersions.go b/iq/componentVersions.go index 0021d01..08bc44e 100644 --- a/iq/componentVersions.go +++ b/iq/componentVersions.go @@ -2,20 +2,20 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" ) const restComponentVersions = "api/v2/components/versions" -// ComponentVersions returns all known versions of a given component -func ComponentVersions(iq IQ, comp Component) (versions []string, err error) { +func ComponentVersionsContext(ctx context.Context, iq IQ, comp Component) (versions []string, err error) { str, err := json.Marshal(comp) if err != nil { return nil, fmt.Errorf("could not process component: %v", err) } - body, _, err := iq.Post(restComponentVersions, bytes.NewBuffer(str)) + body, _, err := iq.Post(ctx, restComponentVersions, bytes.NewBuffer(str)) if err != nil { return nil, fmt.Errorf("could not request component: %v", err) } @@ -26,3 +26,8 @@ func ComponentVersions(iq IQ, comp Component) (versions []string, err error) { return } + +// ComponentVersions returns all known versions of a given component +func ComponentVersions(iq IQ, comp Component) (versions []string, err error) { + return ComponentVersionsContext(context.Background(), iq, comp) +} diff --git a/iq/componentVersions_test.go b/iq/componentVersions_test.go index b543fb6..1f86b71 100644 --- a/iq/componentVersions_test.go +++ b/iq/componentVersions_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -62,7 +63,7 @@ func TestComponentVersions(t *testing.T) { iq, mock := componentVersionsTestIQ(t) defer mock.Close() - versions, err := ComponentVersions(iq, dummyComponent) + versions, err := ComponentVersionsContext(context.Background(), iq, dummyComponent) if err != nil { t.Error(err) } diff --git a/iq/componentsRemediation.go b/iq/componentsRemediation.go index b43e4e7..ace12b7 100644 --- a/iq/componentsRemediation.go +++ b/iq/componentsRemediation.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" "sync" @@ -60,13 +61,13 @@ func createRemediationEndpoint(base, id, stage string) string { return buf.String() } -func getRemediation(iq IQ, component Component, endpoint string) (Remediation, error) { +func getRemediation(ctx context.Context, iq IQ, component Component, endpoint string) (Remediation, error) { request, err := json.Marshal(component) if err != nil { return Remediation{}, fmt.Errorf("could not build the request: %v", err) } - body, _, err := iq.Post(endpoint, bytes.NewBuffer(request)) + body, _, err := iq.Post(ctx, endpoint, bytes.NewBuffer(request)) if err != nil { return Remediation{}, fmt.Errorf("could not get remediation: %v", err) } @@ -80,40 +81,47 @@ func getRemediation(iq IQ, component Component, endpoint string) (Remediation, e return results.Remediation, nil } -func getRemediationByAppInternalID(iq IQ, component Component, stage, appInternalID string) (Remediation, error) { - return getRemediation(iq, component, createRemediationEndpoint(restRemediationByApp, appInternalID, stage)) +func getRemediationByAppInternalID(ctx context.Context, iq IQ, component Component, stage, appInternalID string) (Remediation, error) { + return getRemediation(ctx, iq, component, createRemediationEndpoint(restRemediationByApp, appInternalID, stage)) } -// GetRemediationByApp retrieves the remediation information on a component based on an application's policies -func GetRemediationByApp(iq IQ, component Component, stage, applicationID string) (Remediation, error) { - app, err := GetApplicationByPublicID(iq, applicationID) +func GetRemediationByAppContext(ctx context.Context, iq IQ, component Component, stage, applicationID string) (Remediation, error) { + app, err := GetApplicationByPublicIDContext(ctx, iq, applicationID) if err != nil { return Remediation{}, fmt.Errorf("could not get application: %v", err) } - return getRemediationByAppInternalID(iq, component, stage, app.ID) + return getRemediationByAppInternalID(ctx, iq, component, stage, app.ID) } -// GetRemediationByOrg retrieves the remediation information on a component based on an organization's policies -func GetRemediationByOrg(iq IQ, component Component, stage, organizationName string) (Remediation, error) { - org, err := GetOrganizationByName(iq, organizationName) +// GetRemediationByApp retrieves the remediation information on a component based on an application's policies +func GetRemediationByApp(iq IQ, component Component, stage, applicationID string) (Remediation, error) { + return GetRemediationByAppContext(context.Background(), iq, component, stage, applicationID) +} + +func GetRemediationByOrgContext(ctx context.Context, iq IQ, component Component, stage, organizationName string) (Remediation, error) { + org, err := GetOrganizationByNameContext(ctx, iq, organizationName) if err != nil { return Remediation{}, fmt.Errorf("could not get organization: %v", err) } endpoint := createRemediationEndpoint(restRemediationByOrg, org.ID, stage) - return getRemediation(iq, component, endpoint) + return getRemediation(ctx, iq, component, endpoint) } -// GetRemediationsByAppReport retrieves the remediation information on each component of a report -func GetRemediationsByAppReport(iq IQ, applicationID, reportID string) (remediations []Remediation, err error) { - report, err := getRawReportByAppReportID(iq, applicationID, reportID) +// GetRemediationByOrg retrieves the remediation information on a component based on an organization's policies +func GetRemediationByOrg(iq IQ, component Component, stage, organizationName string) (Remediation, error) { + return GetRemediationByOrgContext(context.Background(), iq, component, stage, organizationName) +} + +func GetRemediationsByAppReportContext(ctx context.Context, iq IQ, applicationID, reportID string) (remediations []Remediation, err error) { + report, err := getRawReportByAppReportID(ctx, iq, applicationID, reportID) if err != nil { return nil, fmt.Errorf("could not get report %s for app %s: %v", reportID, applicationID, err) } - app, err := GetApplicationByPublicID(iq, applicationID) + app, err := GetApplicationByPublicIDContext(ctx, iq, applicationID) if err != nil { return nil, fmt.Errorf("could not get application: %v", err) } @@ -131,7 +139,7 @@ func GetRemediationsByAppReport(iq IQ, applicationID, reportID string) (remediat PackageURL: c.PackageURL, } var remediation Remediation - remediation, err = getRemediationByAppInternalID(iq, purl, report.ReportInfo.Stage, app.ID) + remediation, err = getRemediationByAppInternalID(ctx, iq, purl, report.ReportInfo.Stage, app.ID) if err != nil { err = fmt.Errorf("did not find remediation for '%v': %v", c, err) break @@ -157,3 +165,8 @@ func GetRemediationsByAppReport(iq IQ, applicationID, reportID string) (remediat return } + +// GetRemediationsByAppReport retrieves the remediation information on each component of a report +func GetRemediationsByAppReport(iq IQ, applicationID, reportID string) (remediations []Remediation, err error) { + return GetRemediationsByAppReportContext(context.Background(), iq, applicationID, reportID) +} diff --git a/iq/componentsRemediation_test.go b/iq/componentsRemediation_test.go index 6c7061e..83f36cb 100644 --- a/iq/componentsRemediation_test.go +++ b/iq/componentsRemediation_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "net/http" @@ -103,7 +104,7 @@ func TestRemediationByApp(t *testing.T) { id, stage := dummyApps[0].PublicID, "build" - remediation, err := GetRemediationByApp(iq, dummyComponent, stage, id) + remediation, err := GetRemediationByAppContext(context.Background(), iq, dummyComponent, stage, id) if err != nil { t.Error(err) } @@ -120,7 +121,7 @@ func TestRemediationByOrg(t *testing.T) { id, stage := dummyOrgs[0].Name, "build" - remediation, err := GetRemediationByOrg(iq, dummyComponent, stage, id) + remediation, err := GetRemediationByOrgContext(context.Background(), iq, dummyComponent, stage, id) if err != nil { t.Error(err) } @@ -138,7 +139,7 @@ func TestRemediationByAppReport(t *testing.T) { appIdx, reportID := 0, "0" - got, err := GetRemediationsByAppReport(iq, dummyApps[appIdx].PublicID, reportID) + got, err := GetRemediationsByAppReportContext(context.Background(), iq, dummyApps[appIdx].PublicID, reportID) if err != nil { t.Error(err) } diff --git a/iq/dataRetentionPolicies.go b/iq/dataRetentionPolicies.go index a7f38c6..fe6fa1c 100644 --- a/iq/dataRetentionPolicies.go +++ b/iq/dataRetentionPolicies.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" ) @@ -26,16 +27,15 @@ type DataRetentionPolicy struct { MaxAge string `json:"maxAge"` } -// GetRetentionPolicies returns the current retention policies -func GetRetentionPolicies(iq IQ, orgName string) (policies DataRetentionPolicies, err error) { - org, err := GetOrganizationByName(iq, orgName) +func GetRetentionPoliciesContext(ctx context.Context, iq IQ, orgName string) (policies DataRetentionPolicies, err error) { + org, err := GetOrganizationByNameContext(ctx, iq, orgName) if err != nil { return policies, fmt.Errorf("could not retrieve organization named %s: %v", orgName, err) } endpoint := fmt.Sprintf(restDataRetentionPolicies, org.ID) - body, _, err := iq.Get(endpoint) + body, _, err := iq.Get(ctx, endpoint) if err != nil { return policies, fmt.Errorf("did not retrieve retention policies for organization %s: %v", orgName, err) } @@ -45,9 +45,13 @@ func GetRetentionPolicies(iq IQ, orgName string) (policies DataRetentionPolicies return } -// SetRetentionPolicies updates the retention policies -func SetRetentionPolicies(iq IQ, orgName string, policies DataRetentionPolicies) error { - org, err := GetOrganizationByName(iq, orgName) +// GetRetentionPolicies returns the current retention policies +func GetRetentionPolicies(iq IQ, orgName string) (policies DataRetentionPolicies, err error) { + return GetRetentionPoliciesContext(context.Background(), iq, orgName) +} + +func SetRetentionPoliciesContext(ctx context.Context, iq IQ, orgName string, policies DataRetentionPolicies) error { + org, err := GetOrganizationByNameContext(ctx, iq, orgName) if err != nil { return fmt.Errorf("could not retrieve organization named %s: %v", orgName, err) } @@ -59,10 +63,15 @@ func SetRetentionPolicies(iq IQ, orgName string, policies DataRetentionPolicies) endpoint := fmt.Sprintf(restDataRetentionPolicies, org.ID) - _, _, err = iq.Put(endpoint, bytes.NewBuffer(request)) + _, _, err = iq.Put(ctx, endpoint, bytes.NewBuffer(request)) if err != nil { return fmt.Errorf("did not set retention policies for organization %s: %v", orgName, err) } return nil } + +// SetRetentionPolicies updates the retention policies +func SetRetentionPolicies(iq IQ, orgName string, policies DataRetentionPolicies) error { + return SetRetentionPoliciesContext(context.Background(), iq, orgName, policies) +} diff --git a/iq/dataRetentionPolicies_test.go b/iq/dataRetentionPolicies_test.go index 60b15c0..b993119 100644 --- a/iq/dataRetentionPolicies_test.go +++ b/iq/dataRetentionPolicies_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -98,7 +99,7 @@ func TestGetRetentionPolicies(t *testing.T) { iq, mock := dataRetentionPoliciesTestIQ(t) defer mock.Close() - policies, err := GetRetentionPolicies(iq, dummyOrgs[0].Name) + policies, err := GetRetentionPoliciesContext(context.Background(), iq, dummyOrgs[0].Name) if err != nil { t.Error(err) } @@ -135,12 +136,12 @@ func TestSetRetentionPolicies(t *testing.T) { SuccessMetrics: expected.SuccessMetrics, } - err := SetRetentionPolicies(iq, dummyOrgs[0].Name, retentionRequest) + err := SetRetentionPoliciesContext(context.Background(), iq, dummyOrgs[0].Name, retentionRequest) if err != nil { t.Error(err) } - got, err := GetRetentionPolicies(iq, dummyOrgs[0].Name) + got, err := GetRetentionPoliciesContext(context.Background(), iq, dummyOrgs[0].Name) if err != nil { t.Error(err) } diff --git a/iq/evaluation.go b/iq/evaluation.go index 93a543c..f2df081 100644 --- a/iq/evaluation.go +++ b/iq/evaluation.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -230,15 +231,14 @@ type iqEvaluationRequest struct { Components []Component `json:"components"` } -// EvaluateComponents evaluates the list of components -func EvaluateComponents(iq IQ, components []Component, applicationID string) (*Evaluation, error) { +func EvaluateComponentsContext(ctx context.Context, iq IQ, components []Component, applicationID string) (*Evaluation, error) { request, err := json.Marshal(iqEvaluationRequest{Components: components}) if err != nil { return nil, fmt.Errorf("could not build the request: %v", err) } requestEndpoint := fmt.Sprintf(restEvaluation, applicationID) - body, _, err := iq.Post(requestEndpoint, bytes.NewBuffer(request)) + body, _, err := iq.Post(ctx, requestEndpoint, bytes.NewBuffer(request)) if err != nil { return nil, fmt.Errorf("components not evaluated: %v", err) } @@ -249,7 +249,7 @@ func EvaluateComponents(iq IQ, components []Component, applicationID string) (*E } getEvaluationResults := func() (*Evaluation, error) { - body, resp, e := iq.Get(results.ResultsURL) + body, resp, e := iq.Get(ctx, results.ResultsURL) if e != nil { if resp.StatusCode != http.StatusNotFound { return nil, fmt.Errorf("could not retrieve evaluation results: %v", err) @@ -280,3 +280,8 @@ func EvaluateComponents(iq IQ, components []Component, applicationID string) (*E } } } + +// EvaluateComponents evaluates the list of components +func EvaluateComponents(iq IQ, components []Component, applicationID string) (*Evaluation, error) { + return EvaluateComponentsContext(context.Background(), iq, components, applicationID) +} diff --git a/iq/evaluation_test.go b/iq/evaluation_test.go index 1bd65d2..d6ebbcb 100644 --- a/iq/evaluation_test.go +++ b/iq/evaluation_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -98,7 +99,7 @@ func TestEvaluateComponents(t *testing.T) { appID := "dummyAppId" - report, err := EvaluateComponents(iq, []Component{dummyComponent}, appID) + report, err := EvaluateComponentsContext(context.Background(), iq, []Component{dummyComponent}, appID) if err != nil { t.Error(err) } diff --git a/iq/organization.go b/iq/organization.go index 16de0f5..96c4abb 100644 --- a/iq/organization.go +++ b/iq/organization.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" ) @@ -35,9 +36,8 @@ type Organization struct { Tags []IQCategory `json:"tags,omitempty"` } -// GetOrganizationByName returns details on the named IQ organization -func GetOrganizationByName(iq IQ, organizationName string) (*Organization, error) { - orgs, err := GetAllOrganizations(iq) +func GetOrganizationByNameContext(ctx context.Context, iq IQ, organizationName string) (*Organization, error) { + orgs, err := GetAllOrganizationsContext(ctx, iq) if err != nil { return nil, fmt.Errorf("organization '%s' not found: %v", organizationName, err) } @@ -50,8 +50,12 @@ func GetOrganizationByName(iq IQ, organizationName string) (*Organization, error return nil, fmt.Errorf("organization '%s' not found", organizationName) } -// CreateOrganization creates an organization in IQ with the given name -func CreateOrganization(iq IQ, name string) (string, error) { +// GetOrganizationByName returns details on the named IQ organization +func GetOrganizationByName(iq IQ, organizationName string) (*Organization, error) { + return GetOrganizationByNameContext(context.Background(), iq, organizationName) +} + +func CreateOrganizationContext(ctx context.Context, iq IQ, name string) (string, error) { doError := func(err error) error { return fmt.Errorf("organization '%s' not created: %v", name, err) } @@ -61,7 +65,7 @@ func CreateOrganization(iq IQ, name string) (string, error) { return "", doError(err) } - body, _, err := iq.Post(restOrganization, bytes.NewBuffer(request)) + body, _, err := iq.Post(ctx, restOrganization, bytes.NewBuffer(request)) if err != nil { return "", doError(err) } @@ -74,13 +78,17 @@ func CreateOrganization(iq IQ, name string) (string, error) { return org.ID, nil } -// GetAllOrganizations returns a slice of all of the organizations in an IQ instance -func GetAllOrganizations(iq IQ) ([]Organization, error) { +// CreateOrganization creates an organization in IQ with the given name +func CreateOrganization(iq IQ, name string) (string, error) { + return CreateOrganizationContext(context.Background(), iq, name) +} + +func GetAllOrganizationsContext(ctx context.Context, iq IQ) ([]Organization, error) { doError := func(err error) error { return fmt.Errorf("organizations not found: %v", err) } - body, _, err := iq.Get(restOrganization) + body, _, err := iq.Get(ctx, restOrganization) if err != nil { return nil, doError(err) } @@ -92,3 +100,8 @@ func GetAllOrganizations(iq IQ) ([]Organization, error) { return resp.Organizations, nil } + +// GetAllOrganizations returns a slice of all of the organizations in an IQ instance +func GetAllOrganizations(iq IQ) ([]Organization, error) { + return GetAllOrganizationsContext(context.Background(), iq) +} diff --git a/iq/organization_test.go b/iq/organization_test.go index 1c5169d..2ce3ab8 100644 --- a/iq/organization_test.go +++ b/iq/organization_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -68,7 +69,7 @@ func TestGetOranizationByName(t *testing.T) { dummyOrgsIdx := 2 - org, err := GetOrganizationByName(iq, dummyOrgs[dummyOrgsIdx].Name) + org, err := GetOrganizationByNameContext(context.Background(), iq, dummyOrgs[dummyOrgsIdx].Name) if err != nil { t.Error(err) } @@ -90,12 +91,12 @@ func TestCreateOrganization(t *testing.T) { createdOrg := Organization{Name: "createdOrg"} var err error - createdOrg.ID, err = CreateOrganization(iq, createdOrg.Name) + createdOrg.ID, err = CreateOrganizationContext(context.Background(), iq, createdOrg.Name) if err != nil { t.Fatal(err) } - org, err := GetOrganizationByName(iq, createdOrg.Name) + org, err := GetOrganizationByNameContext(context.Background(), iq, createdOrg.Name) if err != nil { t.Fatal(err) } @@ -111,7 +112,7 @@ func TestGetAllOrganizations(t *testing.T) { iq, mock := organizationTestIQ(t) defer mock.Close() - organizations, err := GetAllOrganizations(iq) + organizations, err := GetAllOrganizationsContext(context.Background(), iq) if err != nil { panic(err) } @@ -125,7 +126,7 @@ func ExampleCreateOrganization() { panic(err) } - orgID, err := CreateOrganization(iq, "DatabaseTeam") + orgID, err := CreateOrganizationContext(context.Background(), iq, "DatabaseTeam") if err != nil { panic(err) } diff --git a/iq/policies.go b/iq/policies.go index 91821a1..364a1b7 100644 --- a/iq/policies.go +++ b/iq/policies.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" ) @@ -21,9 +22,8 @@ type policiesList struct { Policies []PolicyInfo `json:"policies"` } -// GetPolicies returns a list of all of the policies in IQ -func GetPolicies(iq IQ) ([]PolicyInfo, error) { - body, _, err := iq.Get(restPolicies) +func GetPoliciesContext(ctx context.Context, iq IQ) ([]PolicyInfo, error) { + body, _, err := iq.Get(ctx, restPolicies) if err != nil { return nil, fmt.Errorf("could not get list of policies: %v", err) } @@ -36,9 +36,13 @@ func GetPolicies(iq IQ) ([]PolicyInfo, error) { return resp.Policies, nil } -// GetPolicyInfoByName returns an information object for the named policy -func GetPolicyInfoByName(iq IQ, policyName string) (PolicyInfo, error) { - policies, err := GetPolicies(iq) +// GetPolicies returns a list of all of the policies in IQ +func GetPolicies(iq IQ) ([]PolicyInfo, error) { + return GetPoliciesContext(context.Background(), iq) +} + +func GetPolicyInfoByNameContext(ctx context.Context, iq IQ, policyName string) (PolicyInfo, error) { + policies, err := GetPoliciesContext(ctx, iq) if err != nil { return PolicyInfo{}, fmt.Errorf("did not find policy with name %s: %v", policyName, err) } @@ -51,3 +55,8 @@ func GetPolicyInfoByName(iq IQ, policyName string) (PolicyInfo, error) { return PolicyInfo{}, fmt.Errorf("did not find policy with name %s", policyName) } + +// GetPolicyInfoByName returns an information object for the named policy +func GetPolicyInfoByName(iq IQ, policyName string) (PolicyInfo, error) { + return GetPolicyInfoByNameContext(context.Background(), iq, policyName) +} diff --git a/iq/policies_test.go b/iq/policies_test.go index 7232eeb..1f7cece 100644 --- a/iq/policies_test.go +++ b/iq/policies_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "net/http" @@ -50,7 +51,7 @@ func TestGetPolicies(t *testing.T) { iq, mock := policiesTestIQ(t) defer mock.Close() - infos, err := GetPolicies(iq) + infos, err := GetPoliciesContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -72,7 +73,7 @@ func TestGetPolicyInfoByName(t *testing.T) { expected := dummyPolicyInfos[0] - info, err := GetPolicyInfoByName(iq, expected.Name) + info, err := GetPolicyInfoByNameContext(context.Background(), iq, expected.Name) if err != nil { t.Error(err) } diff --git a/iq/policyViolations.go b/iq/policyViolations.go index e8f68e3..86e3ae5 100644 --- a/iq/policyViolations.go +++ b/iq/policyViolations.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" ) @@ -18,9 +19,8 @@ type violationResponse struct { ApplicationViolations []ApplicationViolation `json:"applicationViolations"` } -// GetAllPolicyViolations returns all policy violations -func GetAllPolicyViolations(iq IQ) ([]ApplicationViolation, error) { - policyInfos, err := GetPolicies(iq) +func GetAllPolicyViolationsContext(ctx context.Context, iq IQ) ([]ApplicationViolation, error) { + policyInfos, err := GetPoliciesContext(ctx, iq) if err != nil { return nil, fmt.Errorf("could not get policies: %v", err) } @@ -33,7 +33,7 @@ func GetAllPolicyViolations(iq IQ) ([]ApplicationViolation, error) { endpoint.WriteString(i.ID) } - body, _, err := iq.Get(endpoint.String()) + body, _, err := iq.Get(ctx, endpoint.String()) if err != nil { return nil, fmt.Errorf("could not get policy violations: %v", err) } @@ -47,9 +47,13 @@ func GetAllPolicyViolations(iq IQ) ([]ApplicationViolation, error) { return resp.ApplicationViolations, nil } -// GetPolicyViolationsByName returns the policy violations by policy name -func GetPolicyViolationsByName(iq IQ, policyNames ...string) ([]ApplicationViolation, error) { - policies, err := GetPolicies(iq) +// GetAllPolicyViolations returns all policy violations +func GetAllPolicyViolations(iq IQ) ([]ApplicationViolation, error) { + return GetAllPolicyViolationsContext(context.Background(), iq) +} + +func GetPolicyViolationsByNameContext(ctx context.Context, iq IQ, policyNames ...string) ([]ApplicationViolation, error) { + policies, err := GetPoliciesContext(ctx, iq) if err != nil { return nil, fmt.Errorf("did not find policy: %v", err) } @@ -67,7 +71,7 @@ func GetPolicyViolationsByName(iq IQ, policyNames ...string) ([]ApplicationViola } } - body, _, err := iq.Get(endpoint.String()) + body, _, err := iq.Get(ctx, endpoint.String()) if err != nil { return nil, fmt.Errorf("could not get policy violations: %v", err) } @@ -80,3 +84,8 @@ func GetPolicyViolationsByName(iq IQ, policyNames ...string) ([]ApplicationViola return resp.ApplicationViolations, nil } + +// GetPolicyViolationsByName returns the policy violations by policy name +func GetPolicyViolationsByName(iq IQ, policyNames ...string) ([]ApplicationViolation, error) { + return GetPolicyViolationsByNameContext(context.Background(), iq, policyNames...) +} diff --git a/iq/policyViolations_test.go b/iq/policyViolations_test.go index 9fe65a8..4cc53b2 100644 --- a/iq/policyViolations_test.go +++ b/iq/policyViolations_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "net/http" @@ -80,7 +81,7 @@ func TestGetAllPolicyViolations(t *testing.T) { iq, mock := policyViolationsTestIQ(t) defer mock.Close() - violations, err := GetAllPolicyViolations(iq) + violations, err := GetAllPolicyViolationsContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -102,7 +103,7 @@ func TestGetPolicyViolationsByName(t *testing.T) { expected := dummyPolicyViolations[0] - violations, err := GetPolicyViolationsByName(iq, expected.PolicyViolations[0].PolicyName) + violations, err := GetPolicyViolationsByNameContext(context.Background(), iq, expected.PolicyViolations[0].PolicyName) if err != nil { t.Error(err) } diff --git a/iq/reportMetrics.go b/iq/reportMetrics.go index b671d29..009b3ec 100644 --- a/iq/reportMetrics.go +++ b/iq/reportMetrics.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -126,7 +127,7 @@ func (b *MetricsRequestBuilder) WithOrganization(v string) *MetricsRequestBuilde return b } -func (b *MetricsRequestBuilder) build(iq IQ) (req metricRequest, err error) { +func (b *MetricsRequestBuilder) build(ctx context.Context, iq IQ) (req metricRequest, err error) { // If timePeriod is MONTH - an ISO 8601 year and month without timezone. // If timePeriod is WEEK - an ISO 8601 week year and week (e.g. week of 29 December 2008 is "2009-W01") formatTime := func(t time.Time) string { @@ -163,7 +164,7 @@ func (b *MetricsRequestBuilder) build(iq IQ) (req metricRequest, err error) { if b.apps != nil { req.ApplicationIDS = make([]string, len(b.apps)) for i, a := range b.apps { - app, er := GetApplicationByPublicID(iq, a) + app, er := GetApplicationByPublicIDContext(ctx, iq, a) if er != nil { return req, fmt.Errorf("could not find application with public id %s: %v", a, er) } @@ -174,7 +175,7 @@ func (b *MetricsRequestBuilder) build(iq IQ) (req metricRequest, err error) { if b.orgs != nil { req.OrganizationIDS = make([]string, len(b.orgs)) for i, o := range b.orgs { - org, er := GetOrganizationByName(iq, o) + org, er := GetOrganizationByNameContext(ctx, iq, o) if er != nil { return req, fmt.Errorf("could not find organization with name %s: %v", o, er) } @@ -190,11 +191,10 @@ func NewMetricsRequestBuilder() *MetricsRequestBuilder { return new(MetricsRequestBuilder) } -// GenerateMetrics creates metrics from the given qualifiers -func GenerateMetrics(iq IQ, builder *MetricsRequestBuilder) ([]Metrics, error) { +func GenerateMetricsContext(ctx context.Context, iq IQ, builder *MetricsRequestBuilder) ([]Metrics, error) { // TODO: Accept header: application/json or text/csv - req, err := builder.build(iq) + req, err := builder.build(ctx, iq) if err != nil { return nil, fmt.Errorf("could not build request: %v", err) } @@ -204,7 +204,7 @@ func GenerateMetrics(iq IQ, builder *MetricsRequestBuilder) ([]Metrics, error) { return nil, fmt.Errorf("could not marshal request: %v", err) } - body, _, err := iq.Post(restMetrics, bytes.NewBuffer(buf)) + body, _, err := iq.Post(ctx, restMetrics, bytes.NewBuffer(buf)) if err != nil { return nil, fmt.Errorf("could not issue request to IQ: %v", err) } @@ -217,3 +217,8 @@ func GenerateMetrics(iq IQ, builder *MetricsRequestBuilder) ([]Metrics, error) { return metrics, nil } + +// GenerateMetrics creates metrics from the given qualifiers +func GenerateMetrics(iq IQ, builder *MetricsRequestBuilder) ([]Metrics, error) { + return GenerateMetricsContext(context.Background(), iq, builder) +} diff --git a/iq/reportMetrics_test.go b/iq/reportMetrics_test.go index c932304..61b8a1e 100644 --- a/iq/reportMetrics_test.go +++ b/iq/reportMetrics_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -197,7 +198,7 @@ func TestMetricsRequestBuilder(t *testing.T) { defer mock.Close() for _, test := range tests { - got, err := test.input.build(iq) + got, err := test.input.build(context.Background(), iq) if err != nil { t.Errorf("Unexpected error building metrics request: %v", err) t.Error("input", test.input) @@ -230,7 +231,7 @@ func TestGenerateMetrics(t *testing.T) { } for _, test := range tests { - got, err := GenerateMetrics(iq, test.input) + got, err := GenerateMetricsContext(context.Background(), iq, test.input) if err != nil { t.Error(err) } @@ -251,7 +252,7 @@ func ExampleGenerateMetrics() { reqLastYear := NewMetricsRequestBuilder().Monthly().StartingOn(time.Now().Add(-(24 * time.Hour) * 365)).WithApplication("WebGoat") - metrics, err := GenerateMetrics(iq, reqLastYear) + metrics, err := GenerateMetricsContext(context.Background(), iq, reqLastYear) if err != nil { panic(err) } diff --git a/iq/reports.go b/iq/reports.go index 277db76..f9dac15 100644 --- a/iq/reports.go +++ b/iq/reports.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "log" @@ -124,9 +125,8 @@ type Report struct { Raw ReportRaw `json:"rawReport"` } -// GetAllReportInfos returns all report infos -func GetAllReportInfos(iq IQ) ([]ReportInfo, error) { - body, _, err := iq.Get(restReports) +func GetAllReportInfosContext(ctx context.Context, iq IQ) ([]ReportInfo, error) { + body, _, err := iq.Get(ctx, restReports) if err != nil { return nil, fmt.Errorf("could not get report info: %v", err) } @@ -137,9 +137,13 @@ func GetAllReportInfos(iq IQ) ([]ReportInfo, error) { return infos, err } -// GetAllReports returns all policy and raw reports -func GetAllReports(iq IQ) ([]Report, error) { - infos, err := GetAllReportInfos(iq) +// GetAllReportInfos returns all report infos +func GetAllReportInfos(iq IQ) ([]ReportInfo, error) { + return GetAllReportInfosContext(context.Background(), iq) +} + +func GetAllReportsContext(ctx context.Context, iq IQ) ([]Report, error) { + infos, err := GetAllReportInfosContext(ctx, iq) if err != nil { return nil, fmt.Errorf("could not get report infos: %v", err) } @@ -147,8 +151,8 @@ func GetAllReports(iq IQ) ([]Report, error) { reports := make([]Report, 0) for _, info := range infos { - raw, _ := getRawReportByURL(iq, info.ReportDataURL) - policy, _ := getPolicyReportByURL(iq, strings.Replace(info.ReportDataURL, "/raw", "/policy", 1)) + raw, _ := getRawReportByURL(ctx, iq, info.ReportDataURL) + policy, _ := getPolicyReportByURL(ctx, iq, strings.Replace(info.ReportDataURL, "/raw", "/policy", 1)) raw.ReportInfo = info policy.ReportInfo = info @@ -163,15 +167,19 @@ func GetAllReports(iq IQ) ([]Report, error) { return reports, err } -// GetReportInfosByAppID returns report information by application public ID -func GetReportInfosByAppID(iq IQ, appID string) (infos []ReportInfo, err error) { - app, err := GetApplicationByPublicID(iq, appID) +// GetAllReports returns all policy and raw reports +func GetAllReports(iq IQ) ([]Report, error) { + return GetAllReportsContext(context.Background(), iq) +} + +func GetReportInfosByAppIDContext(ctx context.Context, iq IQ, appID string) (infos []ReportInfo, err error) { + app, err := GetApplicationByPublicIDContext(ctx, iq, appID) if err != nil { return nil, fmt.Errorf("could not get info for application: %v", err) } endpoint := fmt.Sprintf("%s/%s", restReports, app.ID) - body, _, err := iq.Get(endpoint) + body, _, err := iq.Get(ctx, endpoint) if err != nil { return nil, fmt.Errorf("could not get report infos: %v", err) } @@ -184,9 +192,13 @@ func GetReportInfosByAppID(iq IQ, appID string) (infos []ReportInfo, err error) return } -// GetReportInfoByAppIDStage returns report information by application public ID and stage -func GetReportInfoByAppIDStage(iq IQ, appID, stage string) (ReportInfo, error) { - if infos, err := GetReportInfosByAppID(iq, appID); err == nil { +// GetReportInfosByAppID returns report information by application public ID +func GetReportInfosByAppID(iq IQ, appID string) (infos []ReportInfo, err error) { + return GetReportInfosByAppIDContext(context.Background(), iq, appID) +} + +func GetReportInfoByAppIDStageContext(ctx context.Context, iq IQ, appID, stage string) (ReportInfo, error) { + if infos, err := GetReportInfosByAppIDContext(ctx, iq, appID); err == nil { for _, info := range infos { if info.Stage == stage { return info, nil @@ -197,8 +209,13 @@ func GetReportInfoByAppIDStage(iq IQ, appID, stage string) (ReportInfo, error) { return ReportInfo{}, fmt.Errorf("did not find report for '%s'", appID) } -func getRawReportByURL(iq IQ, URL string) (ReportRaw, error) { - body, resp, err := iq.Get(URL) +// GetReportInfoByAppIDStage returns report information by application public ID and stage +func GetReportInfoByAppIDStage(iq IQ, appID, stage string) (ReportInfo, error) { + return GetReportInfoByAppIDStageContext(context.Background(), iq, appID, stage) +} + +func getRawReportByURL(ctx context.Context, iq IQ, URL string) (ReportRaw, error) { + body, resp, err := iq.Get(ctx, URL) if err != nil { log.Printf("error: could not retrieve raw report: %v\n", err) dump, _ := httputil.DumpRequest(resp.Request, true) @@ -213,20 +230,19 @@ func getRawReportByURL(iq IQ, URL string) (ReportRaw, error) { return report, nil } -func getRawReportByAppReportID(iq IQ, appID, reportID string) (ReportRaw, error) { - return getRawReportByURL(iq, fmt.Sprintf(restReportsRaw, appID, reportID)) +func getRawReportByAppReportID(ctx context.Context, iq IQ, appID, reportID string) (ReportRaw, error) { + return getRawReportByURL(ctx, iq, fmt.Sprintf(restReportsRaw, appID, reportID)) } -// GetRawReportByAppID returns report information by application public ID -func GetRawReportByAppID(iq IQ, appID, stage string) (ReportRaw, error) { - infos, err := GetReportInfosByAppID(iq, appID) +func GetRawReportByAppIDContext(ctx context.Context, iq IQ, appID, stage string) (ReportRaw, error) { + infos, err := GetReportInfosByAppIDContext(ctx, iq, appID) if err != nil { return ReportRaw{}, fmt.Errorf("could not get report info for app '%s': %v", appID, err) } for _, info := range infos { if info.Stage == stage { - report, err := getRawReportByURL(iq, info.ReportDataURL) + report, err := getRawReportByURL(ctx, iq, info.ReportDataURL) report.ReportInfo = info return report, err } @@ -235,8 +251,13 @@ func GetRawReportByAppID(iq IQ, appID, stage string) (ReportRaw, error) { return ReportRaw{}, fmt.Errorf("could not find raw report for stage %s", stage) } -func getPolicyReportByURL(iq IQ, URL string) (ReportPolicy, error) { - body, _, err := iq.Get(URL) +// GetRawReportByAppID returns report information by application public ID +func GetRawReportByAppID(iq IQ, appID, stage string) (ReportRaw, error) { + return GetRawReportByAppIDContext(context.Background(), iq, appID, stage) +} + +func getPolicyReportByURL(ctx context.Context, iq IQ, URL string) (ReportPolicy, error) { + body, _, err := iq.Get(ctx, URL) if err != nil { return ReportPolicy{}, fmt.Errorf("could not get policy report at URL %s: %v", URL, err) } @@ -248,16 +269,15 @@ func getPolicyReportByURL(iq IQ, URL string) (ReportPolicy, error) { return report, nil } -// GetPolicyReportByAppID returns report information by application public ID -func GetPolicyReportByAppID(iq IQ, appID, stage string) (ReportPolicy, error) { - infos, err := GetReportInfosByAppID(iq, appID) +func GetPolicyReportByAppIDContext(ctx context.Context, iq IQ, appID, stage string) (ReportPolicy, error) { + infos, err := GetReportInfosByAppIDContext(ctx, iq, appID) if err != nil { return ReportPolicy{}, fmt.Errorf("could not get report info for app '%s': %v", appID, err) } for _, info := range infos { if info.Stage == stage { - report, err := getPolicyReportByURL(iq, strings.Replace(infos[0].ReportDataURL, "/raw", "/policy", 1)) + report, err := getPolicyReportByURL(ctx, iq, strings.Replace(infos[0].ReportDataURL, "/raw", "/policy", 1)) report.ReportInfo = info return report, err } @@ -266,14 +286,18 @@ func GetPolicyReportByAppID(iq IQ, appID, stage string) (ReportPolicy, error) { return ReportPolicy{}, fmt.Errorf("could not find policy report for stage %s", stage) } -// GetReportByAppID returns report information by application public ID -func GetReportByAppID(iq IQ, appID, stage string) (report Report, err error) { - report.Policy, err = GetPolicyReportByAppID(iq, appID, stage) +// GetPolicyReportByAppID returns report information by application public ID +func GetPolicyReportByAppID(iq IQ, appID, stage string) (ReportPolicy, error) { + return GetPolicyReportByAppIDContext(context.Background(), iq, appID, stage) +} + +func GetReportByAppIDContext(ctx context.Context, iq IQ, appID, stage string) (report Report, err error) { + report.Policy, err = GetPolicyReportByAppIDContext(ctx, iq, appID, stage) if err != nil { return report, fmt.Errorf("could not retrieve policy report: %v", err) } - report.Raw, err = GetRawReportByAppID(iq, appID, stage) + report.Raw, err = GetRawReportByAppIDContext(ctx, iq, appID, stage) if err != nil { return report, fmt.Errorf("could not retrieve raw report: %v", err) } @@ -281,19 +305,23 @@ func GetReportByAppID(iq IQ, appID, stage string) (report Report, err error) { return report, nil } -// GetReportByAppReportID returns raw and policy report information for a given report ID -func GetReportByAppReportID(iq IQ, appID, reportID string) (report Report, err error) { - report.Policy, err = getPolicyReportByURL(iq, fmt.Sprintf(restReportsPolicy, appID, reportID)) +// GetReportByAppID returns report information by application public ID +func GetReportByAppID(iq IQ, appID, stage string) (report Report, err error) { + return GetReportByAppIDContext(context.Background(), iq, appID, stage) +} + +func GetReportByAppReportIDContext(ctx context.Context, iq IQ, appID, reportID string) (report Report, err error) { + report.Policy, err = getPolicyReportByURL(ctx, iq, fmt.Sprintf(restReportsPolicy, appID, reportID)) if err != nil { return report, fmt.Errorf("could not retrieve policy report: %v", err) } - report.Raw, err = getRawReportByURL(iq, fmt.Sprintf(restReportsRaw, appID, reportID)) + report.Raw, err = getRawReportByURL(ctx, iq, fmt.Sprintf(restReportsRaw, appID, reportID)) if err != nil { return report, fmt.Errorf("could not retrieve raw report: %v", err) } - infos, err := GetReportInfosByAppID(iq, appID) + infos, err := GetReportInfosByAppIDContext(ctx, iq, appID) if err != nil { return report, fmt.Errorf("could not retrieve report infos: %v", err) } @@ -307,16 +335,20 @@ func GetReportByAppReportID(iq IQ, appID, reportID string) (report Report, err e return report, nil } -// GetReportInfosByOrganization returns report information by organization name -func GetReportInfosByOrganization(iq IQ, organizationName string) (infos []ReportInfo, err error) { - apps, err := GetApplicationsByOrganization(iq, organizationName) +// GetReportByAppReportID returns raw and policy report information for a given report ID +func GetReportByAppReportID(iq IQ, appID, reportID string) (report Report, err error) { + return GetReportByAppReportIDContext(context.Background(), iq, appID, reportID) +} + +func GetReportInfosByOrganizationContext(ctx context.Context, iq IQ, organizationName string) (infos []ReportInfo, err error) { + apps, err := GetApplicationsByOrganizationContext(ctx, iq, organizationName) if err != nil { return nil, fmt.Errorf("could not get applications for organization '%s': %v", organizationName, err) } infos = make([]ReportInfo, 0) for _, app := range apps { - if appInfos, err := GetReportInfosByAppID(iq, app.PublicID); err == nil { + if appInfos, err := GetReportInfosByAppIDContext(ctx, iq, app.PublicID); err == nil { infos = append(infos, appInfos...) } } @@ -324,9 +356,13 @@ func GetReportInfosByOrganization(iq IQ, organizationName string) (infos []Repor return infos, nil } -// GetReportsByOrganization returns all reports for an given organization -func GetReportsByOrganization(iq IQ, organizationName string) (reports []Report, err error) { - apps, err := GetApplicationsByOrganization(iq, organizationName) +// GetReportInfosByOrganization returns report information by organization name +func GetReportInfosByOrganization(iq IQ, organizationName string) (infos []ReportInfo, err error) { + return GetReportInfosByOrganizationContext(context.Background(), iq, organizationName) +} + +func GetReportsByOrganizationContext(ctx context.Context, iq IQ, organizationName string) (reports []Report, err error) { + apps, err := GetApplicationsByOrganizationContext(ctx, iq, organizationName) if err != nil { return nil, fmt.Errorf("could not get applications for organization '%s': %v", organizationName, err) } @@ -336,7 +372,7 @@ func GetReportsByOrganization(iq IQ, organizationName string) (reports []Report, reports = make([]Report, 0) for _, app := range apps { for _, s := range stages { - if appReport, err := GetReportByAppID(iq, app.PublicID, string(s)); err == nil { + if appReport, err := GetReportByAppIDContext(ctx, iq, app.PublicID, string(s)); err == nil { reports = append(reports, appReport) } } @@ -345,6 +381,11 @@ func GetReportsByOrganization(iq IQ, organizationName string) (reports []Report, return reports, nil } +// GetReportsByOrganization returns all reports for an given organization +func GetReportsByOrganization(iq IQ, organizationName string) (reports []Report, err error) { + return GetReportsByOrganizationContext(context.Background(), iq, organizationName) +} + // ReportDiff encapsulates the differences between reports type ReportDiff struct { Reports []Report `json:"reports"` @@ -352,16 +393,15 @@ type ReportDiff struct { Fixed []PolicyReportComponent `json:"fixed,omitempty"` } -// ReportsDiff returns a structure describing various differences between two reports -func ReportsDiff(iq IQ, appID, report1ID, report2ID string) (ReportDiff, error) { +func ReportsDiffContext(ctx context.Context, iq IQ, appID, report1ID, report2ID string) (ReportDiff, error) { var ( report1, report2 Report err error ) - report1, err = GetReportByAppReportID(iq, appID, report1ID) + report1, err = GetReportByAppReportIDContext(ctx, iq, appID, report1ID) if err == nil { - report2, err = GetReportByAppReportID(iq, appID, report2ID) + report2, err = GetReportByAppReportIDContext(ctx, iq, appID, report2ID) } if err != nil { return ReportDiff{}, fmt.Errorf("could not retrieve raw reports: %v", err) @@ -414,3 +454,8 @@ func ReportsDiff(iq IQ, appID, report1ID, report2ID string) (ReportDiff, error) return diff(iq, report2, report1) } + +// ReportsDiff returns a structure describing various differences between two reports +func ReportsDiff(iq IQ, appID, report1ID, report2ID string) (ReportDiff, error) { + return ReportsDiffContext(context.Background(), iq, appID, report1ID, report2ID) +} diff --git a/iq/reports_test.go b/iq/reports_test.go index be1841f..00b81c0 100644 --- a/iq/reports_test.go +++ b/iq/reports_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "net/http" @@ -162,7 +163,7 @@ func TestGetAllReportInfos(t *testing.T) { iq, mock := reportsTestIQ(t) defer mock.Close() - infos, err := GetAllReportInfos(iq) + infos, err := GetAllReportInfosContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -184,7 +185,7 @@ func TestGetReportInfosByAppID(t *testing.T) { testIdx := 0 - infos, err := GetReportInfosByAppID(iq, dummyApps[testIdx].PublicID) + infos, err := GetReportInfosByAppIDContext(context.Background(), iq, dummyApps[testIdx].PublicID) if err != nil { t.Error(err) } @@ -204,7 +205,7 @@ func Test_getRawReportByAppReportID(t *testing.T) { testIdx := 0 - report, err := getRawReportByAppReportID(iq, dummyApps[testIdx].PublicID, fmt.Sprintf("%d", testIdx)) + report, err := getRawReportByAppReportID(context.Background(), iq, dummyApps[testIdx].PublicID, fmt.Sprintf("%d", testIdx)) if err != nil { t.Fatal(err) } @@ -221,7 +222,7 @@ func TestGetRawReportByAppID(t *testing.T) { testIdx := 0 - report, err := GetRawReportByAppID(iq, dummyApps[testIdx].PublicID, dummyReportInfos[testIdx].Stage) + report, err := GetRawReportByAppIDContext(context.Background(), iq, dummyApps[testIdx].PublicID, dummyReportInfos[testIdx].Stage) if err != nil { t.Fatal(err) } @@ -238,7 +239,7 @@ func TestGetPolicyReportByAppID(t *testing.T) { testIdx := 0 - report, err := GetPolicyReportByAppID(iq, dummyApps[testIdx].PublicID, dummyReportInfos[testIdx].Stage) + report, err := GetPolicyReportByAppIDContext(context.Background(), iq, dummyApps[testIdx].PublicID, dummyReportInfos[testIdx].Stage) if err != nil { t.Fatal(err) } @@ -255,7 +256,7 @@ func TestGetReportByAppID(t *testing.T) { testIdx := 0 - report, err := GetReportByAppID(iq, dummyApps[testIdx].PublicID, dummyReportInfos[testIdx].Stage) + report, err := GetReportByAppIDContext(context.Background(), iq, dummyApps[testIdx].PublicID, dummyReportInfos[testIdx].Stage) if err != nil { t.Fatal(err) } @@ -294,7 +295,7 @@ func TestGetReportInfosByOrganization(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotInfos, err := GetReportInfosByOrganization(tt.args.iq, tt.args.organizationName) + gotInfos, err := GetReportInfosByOrganizationContext(context.Background(), tt.args.iq, tt.args.organizationName) if (err != nil) != tt.wantErr { t.Errorf("GetReportInfosByOrganization() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/iq/roleMemberships.go b/iq/roleMemberships.go index c6c76d1..d2c485c 100644 --- a/iq/roleMemberships.go +++ b/iq/roleMemberships.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -53,9 +54,9 @@ type Member struct { UserOrGroupName string `json:"userOrGroupName"` } -func hasRev70API(iq IQ) bool { +func hasRev70API(ctx context.Context, iq IQ) bool { api := fmt.Sprintf(restRoleMembersOrgGet, RootOrganization) - request, _ := iq.NewRequest("HEAD", api, nil) + request, _ := iq.NewRequest(ctx, "HEAD", api, nil) _, resp, _ := iq.Do(request) return resp.StatusCode != http.StatusNotFound } @@ -78,15 +79,15 @@ func newMappings(roleID, memberType, memberName string) memberMappings { } } -func organizationAuthorizationsByID(iq IQ, orgID string) ([]MemberMapping, error) { +func organizationAuthorizationsByID(ctx context.Context, iq IQ, orgID string) ([]MemberMapping, error) { var endpoint string - if hasRev70API(iq) { + if hasRev70API(ctx, iq) { endpoint = fmt.Sprintf(restRoleMembersOrgGet, orgID) } else { endpoint = fmt.Sprintf(restRoleMembersOrgDeprecated, orgID) } - body, _, err := iq.Get(endpoint) + body, _, err := iq.Get(ctx, endpoint) if err != nil { return nil, fmt.Errorf("could not retrieve role mapping for organization %s: %v", orgID, err) } @@ -97,15 +98,15 @@ func organizationAuthorizationsByID(iq IQ, orgID string) ([]MemberMapping, error return mappings.MemberMappings, err } -func organizationAuthorizationsByRoleID(iq IQ, roleID string) ([]MemberMapping, error) { - orgs, err := GetAllOrganizations(iq) +func organizationAuthorizationsByRoleID(ctx context.Context, iq IQ, roleID string) ([]MemberMapping, error) { + orgs, err := GetAllOrganizationsContext(ctx, iq) if err != nil { return nil, fmt.Errorf("could not find organizations: %v", err) } mappings := make([]MemberMapping, 0) for _, org := range orgs { - orgMaps, _ := organizationAuthorizationsByID(iq, org.ID) + orgMaps, _ := organizationAuthorizationsByID(ctx, iq, org.ID) for _, m := range orgMaps { if m.RoleID == roleID { mappings = append(mappings, m) @@ -116,40 +117,48 @@ func organizationAuthorizationsByRoleID(iq IQ, roleID string) ([]MemberMapping, return mappings, nil } -// OrganizationAuthorizations returns the member mappings of an organization -func OrganizationAuthorizations(iq IQ, name string) ([]MemberMapping, error) { - org, err := GetOrganizationByName(iq, name) +func OrganizationAuthorizationsContext(ctx context.Context, iq IQ, name string) ([]MemberMapping, error) { + org, err := GetOrganizationByNameContext(ctx, iq, name) if err != nil { return nil, fmt.Errorf("could not find organization with name %s: %v", name, err) } - return organizationAuthorizationsByID(iq, org.ID) + return organizationAuthorizationsByID(ctx, iq, org.ID) } -// OrganizationAuthorizationsByRole returns the member mappings of all organizations which match the given role -func OrganizationAuthorizationsByRole(iq IQ, roleName string) ([]MemberMapping, error) { - role, err := RoleByName(iq, roleName) +// OrganizationAuthorizations returns the member mappings of an organization +func OrganizationAuthorizations(iq IQ, name string) ([]MemberMapping, error) { + return OrganizationAuthorizationsContext(context.Background(), iq, name) +} + +func OrganizationAuthorizationsByRoleContext(ctx context.Context, iq IQ, roleName string) ([]MemberMapping, error) { + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return nil, fmt.Errorf("could not find role with name %s: %v", roleName, err) } - return organizationAuthorizationsByRoleID(iq, role.ID) + return organizationAuthorizationsByRoleID(ctx, iq, role.ID) } -func setOrganizationAuth(iq IQ, name, roleName, member, memberType string) error { - org, err := GetOrganizationByName(iq, name) +// OrganizationAuthorizationsByRole returns the member mappings of all organizations which match the given role +func OrganizationAuthorizationsByRole(iq IQ, roleName string) ([]MemberMapping, error) { + return OrganizationAuthorizationsByRoleContext(context.Background(), iq, roleName) +} + +func setOrganizationAuth(ctx context.Context, iq IQ, name, roleName, member, memberType string) error { + org, err := GetOrganizationByNameContext(ctx, iq, name) if err != nil { return fmt.Errorf("could not find organization with name %s: %v", name, err) } - role, err := RoleByName(iq, roleName) + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return fmt.Errorf("could not find role with name %s: %v", roleName, err) } var endpoint string var payload io.Reader - if hasRev70API(iq) { + if hasRev70API(ctx, iq) { switch memberType { case MemberTypeUser: endpoint = fmt.Sprintf(restRoleMembersOrgUser, org.ID, role.ID, member) @@ -158,7 +167,7 @@ func setOrganizationAuth(iq IQ, name, roleName, member, memberType string) error } } else { endpoint = fmt.Sprintf(restRoleMembersOrgDeprecated, org.ID) - current, err := OrganizationAuthorizations(iq, name) + current, err := OrganizationAuthorizationsContext(ctx, iq, name) if err != nil && current == nil { current = make([]MemberMapping, 0) } @@ -171,7 +180,7 @@ func setOrganizationAuth(iq IQ, name, roleName, member, memberType string) error payload = bytes.NewBuffer(buf) } - _, _, err = iq.Put(endpoint, payload) + _, _, err = iq.Put(ctx, endpoint, payload) if err != nil { return fmt.Errorf("could not update organization role mapping: %v", err) } @@ -179,25 +188,33 @@ func setOrganizationAuth(iq IQ, name, roleName, member, memberType string) error return nil } +func SetOrganizationUserContext(ctx context.Context, iq IQ, name, roleName, user string) error { + return setOrganizationAuth(ctx, iq, name, roleName, user, MemberTypeUser) +} + // SetOrganizationUser sets the role and user that can have access to an organization func SetOrganizationUser(iq IQ, name, roleName, user string) error { - return setOrganizationAuth(iq, name, roleName, user, MemberTypeUser) + return SetOrganizationUserContext(context.Background(), iq, name, roleName, user) +} + +func SetOrganizationGroupContext(ctx context.Context, iq IQ, name, roleName, group string) error { + return setOrganizationAuth(ctx, iq, name, roleName, group, MemberTypeGroup) } // SetOrganizationGroup sets the role and group that can have access to an organization func SetOrganizationGroup(iq IQ, name, roleName, group string) error { - return setOrganizationAuth(iq, name, roleName, group, MemberTypeGroup) + return SetOrganizationGroupContext(context.Background(), iq, name, roleName, group) } -func applicationAuthorizationsByID(iq IQ, appID string) ([]MemberMapping, error) { +func applicationAuthorizationsByID(ctx context.Context, iq IQ, appID string) ([]MemberMapping, error) { var endpoint string - if hasRev70API(iq) { + if hasRev70API(ctx, iq) { endpoint = fmt.Sprintf(restRoleMembersAppGet, appID) } else { endpoint = fmt.Sprintf(restRoleMembersAppDeprecated, appID) } - body, _, err := iq.Get(endpoint) + body, _, err := iq.Get(ctx, endpoint) if err != nil { return nil, fmt.Errorf("could not retrieve role mapping for application %s: %v", appID, err) } @@ -208,15 +225,15 @@ func applicationAuthorizationsByID(iq IQ, appID string) ([]MemberMapping, error) return mappings.MemberMappings, err } -func applicationAuthorizationsByRoleID(iq IQ, roleID string) ([]MemberMapping, error) { - apps, err := GetAllApplications(iq) +func applicationAuthorizationsByRoleID(ctx context.Context, iq IQ, roleID string) ([]MemberMapping, error) { + apps, err := GetAllApplicationsContext(ctx, iq) if err != nil { return nil, fmt.Errorf("could not find applications: %v", err) } mappings := make([]MemberMapping, 0) for _, app := range apps { - appMaps, _ := applicationAuthorizationsByID(iq, app.ID) + appMaps, _ := applicationAuthorizationsByID(ctx, iq, app.ID) for _, m := range appMaps { if m.RoleID == roleID { mappings = append(mappings, m) @@ -227,40 +244,48 @@ func applicationAuthorizationsByRoleID(iq IQ, roleID string) ([]MemberMapping, e return mappings, nil } -// ApplicationAuthorizations returns the member mappings of an application -func ApplicationAuthorizations(iq IQ, name string) ([]MemberMapping, error) { - app, err := GetApplicationByPublicID(iq, name) +func ApplicationAuthorizationsContext(ctx context.Context, iq IQ, name string) ([]MemberMapping, error) { + app, err := GetApplicationByPublicIDContext(ctx, iq, name) if err != nil { return nil, fmt.Errorf("could not find application with name %s: %v", name, err) } - return applicationAuthorizationsByID(iq, app.ID) + return applicationAuthorizationsByID(ctx, iq, app.ID) } -// ApplicationAuthorizationsByRole returns the member mappings of all applications which match the given role -func ApplicationAuthorizationsByRole(iq IQ, roleName string) ([]MemberMapping, error) { - role, err := RoleByName(iq, roleName) +// ApplicationAuthorizations returns the member mappings of an application +func ApplicationAuthorizations(iq IQ, name string) ([]MemberMapping, error) { + return ApplicationAuthorizationsContext(context.Background(), iq, name) +} + +func ApplicationAuthorizationsByRoleContext(ctx context.Context, iq IQ, roleName string) ([]MemberMapping, error) { + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return nil, fmt.Errorf("could not find role with name %s: %v", roleName, err) } - return applicationAuthorizationsByRoleID(iq, role.ID) + return applicationAuthorizationsByRoleID(ctx, iq, role.ID) } -func setApplicationAuth(iq IQ, name, roleName, member, memberType string) error { - app, err := GetApplicationByPublicID(iq, name) +// ApplicationAuthorizationsByRole returns the member mappings of all applications which match the given role +func ApplicationAuthorizationsByRole(iq IQ, roleName string) ([]MemberMapping, error) { + return ApplicationAuthorizationsByRoleContext(context.Background(), iq, roleName) +} + +func setApplicationAuth(ctx context.Context, iq IQ, name, roleName, member, memberType string) error { + app, err := GetApplicationByPublicIDContext(ctx, iq, name) if err != nil { return fmt.Errorf("could not find application with name %s: %v", name, err) } - role, err := RoleByName(iq, roleName) + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return fmt.Errorf("could not find role with name %s: %v", roleName, err) } var endpoint string var payload io.Reader - if hasRev70API(iq) { + if hasRev70API(ctx, iq) { switch memberType { case MemberTypeUser: endpoint = fmt.Sprintf(restRoleMembersAppUser, app.ID, role.ID, member) @@ -269,7 +294,7 @@ func setApplicationAuth(iq IQ, name, roleName, member, memberType string) error } } else { endpoint = fmt.Sprintf(restRoleMembersAppDeprecated, app.ID) - current, err := ApplicationAuthorizations(iq, name) + current, err := ApplicationAuthorizationsContext(ctx, iq, name) if err != nil && current == nil { current = make([]MemberMapping, 0) } @@ -282,7 +307,7 @@ func setApplicationAuth(iq IQ, name, roleName, member, memberType string) error payload = bytes.NewBuffer(buf) } - _, _, err = iq.Put(endpoint, payload) + _, _, err = iq.Put(ctx, endpoint, payload) if err != nil { return fmt.Errorf("could not update organization role mapping: %v", err) } @@ -290,19 +315,27 @@ func setApplicationAuth(iq IQ, name, roleName, member, memberType string) error return nil } +func SetApplicationUserContext(ctx context.Context, iq IQ, name, roleName, user string) error { + return setApplicationAuth(ctx, iq, name, roleName, user, MemberTypeUser) +} + // SetApplicationUser sets the role and user that can have access to an application func SetApplicationUser(iq IQ, name, roleName, user string) error { - return setApplicationAuth(iq, name, roleName, user, MemberTypeUser) + return SetApplicationUserContext(context.Background(), iq, name, roleName, user) +} + +func SetApplicationGroupContext(ctx context.Context, iq IQ, name, roleName, group string) error { + return setApplicationAuth(ctx, iq, name, roleName, group, MemberTypeGroup) } // SetApplicationGroup sets the role and group that can have access to an application func SetApplicationGroup(iq IQ, name, roleName, group string) error { - return setApplicationAuth(iq, name, roleName, group, MemberTypeGroup) + return SetApplicationGroupContext(context.Background(), iq, name, roleName, group) } -func revokeLT70(iq IQ, authType, authName, roleName, memberType, memberName string) error { +func revokeLT70(ctx context.Context, iq IQ, authType, authName, roleName, memberType, memberName string) error { var err error - role, err := RoleByName(iq, roleName) + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return fmt.Errorf("could not find role with name %s: %v", roleName, err) } @@ -313,18 +346,18 @@ func revokeLT70(iq IQ, authType, authName, roleName, memberType, memberName stri ) switch authType { case "organization": - org, err := GetOrganizationByName(iq, authName) + org, err := GetOrganizationByNameContext(ctx, iq, authName) if err == nil { authID = org.ID baseEndpoint = restRoleMembersOrgDeprecated - mapping, err = OrganizationAuthorizations(iq, authName) + mapping, err = OrganizationAuthorizationsContext(ctx, iq, authName) } case "application": - app, err := GetApplicationByPublicID(iq, authName) + app, err := GetApplicationByPublicIDContext(ctx, iq, authName) if err == nil { authID = app.ID baseEndpoint = restRoleMembersAppDeprecated - mapping, err = ApplicationAuthorizations(iq, authName) + mapping, err = ApplicationAuthorizationsContext(ctx, iq, authName) } } if err != nil && mapping != nil { @@ -349,7 +382,7 @@ func revokeLT70(iq IQ, authType, authName, roleName, memberType, memberName stri } endpoint := fmt.Sprintf(baseEndpoint, authID) - _, _, err = iq.Put(endpoint, bytes.NewBuffer(buf)) + _, _, err = iq.Put(ctx, endpoint, bytes.NewBuffer(buf)) if err != nil { return fmt.Errorf("could not remove role mapping: %v", err) } @@ -357,8 +390,8 @@ func revokeLT70(iq IQ, authType, authName, roleName, memberType, memberName stri return nil } -func revoke(iq IQ, authType, authName, roleName, memberType, memberName string) error { - role, err := RoleByName(iq, roleName) +func revoke(ctx context.Context, iq IQ, authType, authName, roleName, memberType, memberName string) error { + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return fmt.Errorf("could not find role with name %s: %v", roleName, err) } @@ -368,7 +401,7 @@ func revoke(iq IQ, authType, authName, roleName, memberType, memberName string) ) switch authType { case "organization": - org, err := GetOrganizationByName(iq, authName) + org, err := GetOrganizationByNameContext(ctx, iq, authName) if err == nil { authID = org.ID switch memberType { @@ -379,7 +412,7 @@ func revoke(iq IQ, authType, authName, roleName, memberType, memberName string) } } case "application": - app, err := GetApplicationByPublicID(iq, authName) + app, err := GetApplicationByPublicIDContext(ctx, iq, authName) if err == nil { authID = app.ID switch memberType { @@ -392,48 +425,64 @@ func revoke(iq IQ, authType, authName, roleName, memberType, memberName string) } endpoint := fmt.Sprintf(baseEndpoint, authID, role.ID, memberName) - _, err = iq.Del(endpoint) + _, err = iq.Del(ctx, endpoint) return err } -// RevokeOrganizationUser removes a user and role from the named organization +func RevokeOrganizationUserContext(ctx context.Context, iq IQ, name, roleName, user string) error { + if !hasRev70API(ctx, iq) { + return revokeLT70(ctx, iq, "organization", name, roleName, MemberTypeUser, user) + } + return revoke(ctx, iq, "organization", name, roleName, MemberTypeUser, user) +} + func RevokeOrganizationUser(iq IQ, name, roleName, user string) error { - if !hasRev70API(iq) { - return revokeLT70(iq, "organization", name, roleName, MemberTypeUser, user) + return RevokeOrganizationUserContext(context.Background(), iq, name, roleName, user) +} + +func RevokeOrganizationGroupContext(ctx context.Context, iq IQ, name, roleName, group string) error { + if !hasRev70API(ctx, iq) { + return revokeLT70(ctx, iq, "organization", name, roleName, MemberTypeGroup, group) } - return revoke(iq, "organization", name, roleName, MemberTypeUser, user) + return revoke(ctx, iq, "organization", name, roleName, MemberTypeGroup, group) } // RevokeOrganizationGroup removes a group and role from the named organization func RevokeOrganizationGroup(iq IQ, name, roleName, group string) error { - if !hasRev70API(iq) { - return revokeLT70(iq, "organization", name, roleName, MemberTypeGroup, group) + return RevokeOrganizationGroupContext(context.Background(), iq, name, roleName, group) +} + +func RevokeApplicationUserContext(ctx context.Context, iq IQ, name, roleName, user string) error { + if !hasRev70API(ctx, iq) { + return revokeLT70(ctx, iq, "application", name, roleName, MemberTypeUser, user) } - return revoke(iq, "organization", name, roleName, MemberTypeGroup, group) + return revoke(ctx, iq, "application", name, roleName, MemberTypeUser, user) } // RevokeApplicationUser removes a user and role from the named application func RevokeApplicationUser(iq IQ, name, roleName, user string) error { - if !hasRev70API(iq) { - return revokeLT70(iq, "application", name, roleName, MemberTypeUser, user) + return RevokeApplicationUserContext(context.Background(), iq, name, roleName, user) + +} + +func RevokeApplicationGroupContext(ctx context.Context, iq IQ, name, roleName, group string) error { + if !hasRev70API(ctx, iq) { + return revokeLT70(ctx, iq, "application", name, roleName, MemberTypeGroup, group) } - return revoke(iq, "application", name, roleName, MemberTypeUser, user) + return revoke(ctx, iq, "application", name, roleName, MemberTypeGroup, group) } // RevokeApplicationGroup removes a group and role from the named application func RevokeApplicationGroup(iq IQ, name, roleName, group string) error { - if !hasRev70API(iq) { - return revokeLT70(iq, "application", name, roleName, MemberTypeGroup, group) - } - return revoke(iq, "application", name, roleName, MemberTypeGroup, group) + return RevokeApplicationGroupContext(context.Background(), iq, name, roleName, group) } -func repositoriesAuth(iq IQ, method, roleName, memberType, member string) error { - if !hasRev70API(iq) { +func repositoriesAuth(ctx context.Context, iq IQ, method, roleName, memberType, member string) error { + if !hasRev70API(ctx, iq) { return fmt.Errorf("did not find revision 70 API") } - role, err := RoleByName(iq, roleName) + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return fmt.Errorf("could not find role with name %s: %v", roleName, err) } @@ -448,9 +497,9 @@ func repositoriesAuth(iq IQ, method, roleName, memberType, member string) error switch method { case http.MethodPut: - _, _, err = iq.Put(endpoint, nil) + _, _, err = iq.Put(ctx, endpoint, nil) case http.MethodDelete: - _, err = iq.Del(endpoint) + _, err = iq.Del(ctx, endpoint) } if err != nil { return fmt.Errorf("could not affect repositories role mapping: %v", err) @@ -459,8 +508,8 @@ func repositoriesAuth(iq IQ, method, roleName, memberType, member string) error return nil } -func repositoriesAuthorizationsByRoleID(iq IQ, roleID string) ([]MemberMapping, error) { - auths, err := RepositoriesAuthorizations(iq) +func repositoriesAuthorizationsByRoleID(ctx context.Context, iq IQ, roleID string) ([]MemberMapping, error) { + auths, err := RepositoriesAuthorizationsContext(ctx, iq) if err != nil { return nil, fmt.Errorf("could not find authorization mappings for repositories: %v", err) } @@ -475,9 +524,8 @@ func repositoriesAuthorizationsByRoleID(iq IQ, roleID string) ([]MemberMapping, return mappings, nil } -// RepositoriesAuthorizations returns the member mappings of all repositories -func RepositoriesAuthorizations(iq IQ) ([]MemberMapping, error) { - body, _, err := iq.Get(restRoleMembersReposGet) +func RepositoriesAuthorizationsContext(ctx context.Context, iq IQ) ([]MemberMapping, error) { + body, _, err := iq.Get(ctx, restRoleMembersReposGet) if err != nil { return nil, fmt.Errorf("could not get repositories mappings: %v", err) } @@ -491,49 +539,74 @@ func RepositoriesAuthorizations(iq IQ) ([]MemberMapping, error) { return mappings.MemberMappings, nil } -// RepositoriesAuthorizationsByRole returns the member mappings of all repositories which match the given role -func RepositoriesAuthorizationsByRole(iq IQ, roleName string) ([]MemberMapping, error) { - role, err := RoleByName(iq, roleName) +// RepositoriesAuthorizations returns the member mappings of all repositories +func RepositoriesAuthorizations(iq IQ) ([]MemberMapping, error) { + return RepositoriesAuthorizationsContext(context.Background(), iq) +} + +func RepositoriesAuthorizationsByRoleContext(ctx context.Context, iq IQ, roleName string) ([]MemberMapping, error) { + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return nil, fmt.Errorf("could not find role with name %s: %v", roleName, err) } - return repositoriesAuthorizationsByRoleID(iq, role.ID) + return repositoriesAuthorizationsByRoleID(ctx, iq, role.ID) +} + +// RepositoriesAuthorizationsByRole returns the member mappings of all repositories which match the given role +func RepositoriesAuthorizationsByRole(iq IQ, roleName string) ([]MemberMapping, error) { + return RepositoriesAuthorizationsByRoleContext(context.Background(), iq, roleName) +} + +func SetRepositoriesUserContext(ctx context.Context, iq IQ, roleName, user string) error { + return repositoriesAuth(ctx, iq, http.MethodPut, roleName, MemberTypeUser, user) } // SetRepositoriesUser sets the role and user that can have access to the repositories func SetRepositoriesUser(iq IQ, roleName, user string) error { - return repositoriesAuth(iq, http.MethodPut, roleName, MemberTypeUser, user) + return SetRepositoriesUserContext(context.Background(), iq, roleName, user) +} + +func SetRepositoriesGroupContext(ctx context.Context, iq IQ, roleName, group string) error { + return repositoriesAuth(ctx, iq, http.MethodPut, roleName, MemberTypeGroup, group) } // SetRepositoriesGroup sets the role and group that can have access to the repositories func SetRepositoriesGroup(iq IQ, roleName, group string) error { - return repositoriesAuth(iq, http.MethodPut, roleName, MemberTypeGroup, group) + return SetRepositoriesGroupContext(context.Background(), iq, roleName, group) +} + +func RevokeRepositoriesUserContext(ctx context.Context, iq IQ, roleName, user string) error { + return repositoriesAuth(ctx, iq, http.MethodDelete, roleName, MemberTypeUser, user) } // RevokeRepositoriesUser revoke the role and user that can have access to the repositories func RevokeRepositoriesUser(iq IQ, roleName, user string) error { - return repositoriesAuth(iq, http.MethodDelete, roleName, MemberTypeUser, user) + return RevokeRepositoriesUserContext(context.Background(), iq, roleName, user) +} + +func RevokeRepositoriesGroupContext(ctx context.Context, iq IQ, roleName, group string) error { + return repositoriesAuth(ctx, iq, http.MethodDelete, roleName, MemberTypeGroup, group) } // RevokeRepositoriesGroup revoke the role and group that can have access to the repositories func RevokeRepositoriesGroup(iq IQ, roleName, group string) error { - return repositoriesAuth(iq, http.MethodDelete, roleName, MemberTypeGroup, group) + return RevokeRepositoriesGroupContext(context.Background(), iq, roleName, group) } -func membersByRoleID(iq IQ, roleID string) ([]MemberMapping, error) { +func membersByRoleID(ctx context.Context, iq IQ, roleID string) ([]MemberMapping, error) { members := make([]MemberMapping, 0) - if m, err := organizationAuthorizationsByRoleID(iq, roleID); err == nil && len(m) > 0 { + if m, err := organizationAuthorizationsByRoleID(ctx, iq, roleID); err == nil && len(m) > 0 { members = append(members, m...) } - if m, err := applicationAuthorizationsByRoleID(iq, roleID); err == nil && len(m) > 0 { + if m, err := applicationAuthorizationsByRoleID(ctx, iq, roleID); err == nil && len(m) > 0 { members = append(members, m...) } - if hasRev70API(iq) { - if m, err := repositoriesAuthorizationsByRoleID(iq, roleID); err == nil && len(m) > 0 { + if hasRev70API(ctx, iq) { + if m, err := repositoriesAuthorizationsByRoleID(ctx, iq, roleID); err == nil && len(m) > 0 { members = append(members, m...) } } @@ -541,18 +614,21 @@ func membersByRoleID(iq IQ, roleID string) ([]MemberMapping, error) { return members, nil } -// MembersByRole returns all users and groups by role name -func MembersByRole(iq IQ, roleName string) ([]MemberMapping, error) { - role, err := RoleByName(iq, roleName) +func MembersByRoleContext(ctx context.Context, iq IQ, roleName string) ([]MemberMapping, error) { + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return nil, fmt.Errorf("could not find role with name %s: %v", roleName, err) } - return membersByRoleID(iq, role.ID) + return membersByRoleID(ctx, iq, role.ID) } -// GlobalAuthorizations returns all of the users and roles who have the administrator role across all of IQ -func GlobalAuthorizations(iq IQ) ([]MemberMapping, error) { - body, _, err := iq.Get(restRoleMembersGlobalGet) +// MembersByRole returns all users and groups by role name +func MembersByRole(iq IQ, roleName string) ([]MemberMapping, error) { + return MembersByRoleContext(context.Background(), iq, roleName) +} + +func GlobalAuthorizationsContext(ctx context.Context, iq IQ) ([]MemberMapping, error) { + body, _, err := iq.Get(ctx, restRoleMembersGlobalGet) if err != nil { return nil, fmt.Errorf("could not get global members: %v", err) } @@ -566,12 +642,17 @@ func GlobalAuthorizations(iq IQ) ([]MemberMapping, error) { return mappings.MemberMappings, nil } -func globalAuth(iq IQ, method, roleName, memberType, member string) error { - if !hasRev70API(iq) { +// GlobalAuthorizations returns all of the users and roles who have the administrator role across all of IQ +func GlobalAuthorizations(iq IQ) ([]MemberMapping, error) { + return GlobalAuthorizationsContext(context.Background(), iq) +} + +func globalAuth(ctx context.Context, iq IQ, method, roleName, memberType, member string) error { + if !hasRev70API(ctx, iq) { return fmt.Errorf("did not find revision 70 API") } - role, err := RoleByName(iq, roleName) + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return fmt.Errorf("could not find role with name %s: %v", roleName, err) } @@ -586,9 +667,9 @@ func globalAuth(iq IQ, method, roleName, memberType, member string) error { switch method { case http.MethodPut: - _, _, err = iq.Put(endpoint, nil) + _, _, err = iq.Put(ctx, endpoint, nil) case http.MethodDelete: - _, err = iq.Del(endpoint) + _, err = iq.Del(ctx, endpoint) } if err != nil { return fmt.Errorf("could not affect global role mapping: %v", err) @@ -597,22 +678,38 @@ func globalAuth(iq IQ, method, roleName, memberType, member string) error { return nil } +func SetGlobalUserContext(ctx context.Context, iq IQ, roleName, user string) error { + return globalAuth(ctx, iq, http.MethodPut, roleName, MemberTypeUser, user) +} + // SetGlobalUser sets the role and user that can have access to the repositories func SetGlobalUser(iq IQ, roleName, user string) error { - return globalAuth(iq, http.MethodPut, roleName, MemberTypeUser, user) + return SetGlobalUserContext(context.Background(), iq, roleName, user) +} + +func SetGlobalGroupContext(ctx context.Context, iq IQ, roleName, group string) error { + return globalAuth(ctx, iq, http.MethodPut, roleName, MemberTypeGroup, group) } // SetGlobalGroup sets the role and group that can have access to the global func SetGlobalGroup(iq IQ, roleName, group string) error { - return globalAuth(iq, http.MethodPut, roleName, MemberTypeGroup, group) + return SetGlobalGroupContext(context.Background(), iq, roleName, group) +} + +func RevokeGlobalUserContext(ctx context.Context, iq IQ, roleName, user string) error { + return globalAuth(ctx, iq, http.MethodDelete, roleName, MemberTypeUser, user) } // RevokeGlobalUser revoke the role and user that can have access to the global func RevokeGlobalUser(iq IQ, roleName, user string) error { - return globalAuth(iq, http.MethodDelete, roleName, MemberTypeUser, user) + return RevokeGlobalUserContext(context.Background(), iq, roleName, user) +} + +func RevokeGlobalGroupContext(ctx context.Context, iq IQ, roleName, group string) error { + return globalAuth(ctx, iq, http.MethodDelete, roleName, MemberTypeGroup, group) } // RevokeGlobalGroup revoke the role and group that can have access to the global func RevokeGlobalGroup(iq IQ, roleName, group string) error { - return globalAuth(iq, http.MethodDelete, roleName, MemberTypeGroup, group) + return RevokeGlobalGroupContext(context.Background(), iq, roleName, group) } diff --git a/iq/roleMemberships_test.go b/iq/roleMemberships_test.go index 7bf8a36..89a0814 100644 --- a/iq/roleMemberships_test.go +++ b/iq/roleMemberships_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -397,7 +398,7 @@ func testGetOrganizationAuthorizations(t *testing.T, iq IQ) { t.Helper() dummyIdx := 0 - got, err := OrganizationAuthorizations(iq, dummyOrgs[dummyIdx].Name) + got, err := OrganizationAuthorizationsContext(context.Background(), iq, dummyOrgs[dummyIdx].Name) if err != nil { t.Error(err) } @@ -424,7 +425,7 @@ func testGetOrganizationAuthorizationsByRole(t *testing.T, iq IQ) { } } - got, err := OrganizationAuthorizationsByRole(iq, role.Name) + got, err := OrganizationAuthorizationsByRoleContext(context.Background(), iq, role.Name) if err != nil { t.Error(err) } @@ -448,7 +449,7 @@ func testGetApplicationAuthorizations(t *testing.T, iq IQ) { t.Helper() dummyIdx := 0 - got, err := ApplicationAuthorizations(iq, dummyApps[dummyIdx].PublicID) + got, err := ApplicationAuthorizationsContext(context.Background(), iq, dummyApps[dummyIdx].PublicID) if err != nil { t.Error(err) } @@ -474,7 +475,7 @@ func testGetApplicationAuthorizationsByRole(t *testing.T, iq IQ) { } } - got, err := ApplicationAuthorizationsByRole(iq, role.Name) + got, err := ApplicationAuthorizationsByRoleContext(context.Background(), iq, role.Name) if err != nil { t.Error(err) } @@ -517,22 +518,22 @@ func testSetAuth(t *testing.T, iq IQ, authTarget string, memberType string) { case "organization": switch memberType { case MemberTypeUser: - err = SetOrganizationUser(iq, dummyOrgs[dummyIdx].Name, dummyRoles[role].Name, memberName) + err = SetOrganizationUserContext(context.Background(), iq, dummyOrgs[dummyIdx].Name, dummyRoles[role].Name, memberName) case MemberTypeGroup: - err = SetOrganizationGroup(iq, dummyOrgs[dummyIdx].Name, dummyRoles[role].Name, memberName) + err = SetOrganizationGroupContext(context.Background(), iq, dummyOrgs[dummyIdx].Name, dummyRoles[role].Name, memberName) } if err == nil { - got, err = OrganizationAuthorizations(iq, dummyOrgs[dummyIdx].Name) + got, err = OrganizationAuthorizationsContext(context.Background(), iq, dummyOrgs[dummyIdx].Name) } case "application": switch memberType { case MemberTypeUser: - err = SetApplicationUser(iq, dummyApps[dummyIdx].PublicID, dummyRoles[role].Name, memberName) + err = SetApplicationUserContext(context.Background(), iq, dummyApps[dummyIdx].PublicID, dummyRoles[role].Name, memberName) case MemberTypeGroup: - err = SetApplicationGroup(iq, dummyApps[dummyIdx].PublicID, dummyRoles[role].Name, memberName) + err = SetApplicationGroupContext(context.Background(), iq, dummyApps[dummyIdx].PublicID, dummyRoles[role].Name, memberName) } if err == nil { - got, err = ApplicationAuthorizations(iq, dummyApps[dummyIdx].PublicID) + got, err = ApplicationAuthorizationsContext(context.Background(), iq, dummyApps[dummyIdx].PublicID) } } if err != nil { @@ -613,68 +614,68 @@ func testRevoke(t *testing.T, iq IQ, authType, memberType string) { dummyOrgName := dummyOrgs[0].Name switch memberType { case MemberTypeUser: - err = SetOrganizationUser(iq, dummyOrgName, role.Name, name) + err = SetOrganizationUserContext(context.Background(), iq, dummyOrgName, role.Name, name) if err == nil { - err = RevokeOrganizationUser(iq, dummyOrgName, role.Name, name) + err = RevokeOrganizationUserContext(context.Background(), iq, dummyOrgName, role.Name, name) } case MemberTypeGroup: - err = SetOrganizationGroup(iq, dummyOrgName, role.Name, name) + err = SetOrganizationGroupContext(context.Background(), iq, dummyOrgName, role.Name, name) if err == nil { t.Log("HERE1") - err = RevokeOrganizationGroup(iq, dummyOrgName, role.Name, name) + err = RevokeOrganizationGroupContext(context.Background(), iq, dummyOrgName, role.Name, name) } } if err == nil { - mappings, err = OrganizationAuthorizations(iq, dummyOrgName) + mappings, err = OrganizationAuthorizationsContext(context.Background(), iq, dummyOrgName) } case "application": dummyAppName := dummyApps[0].PublicID switch memberType { case MemberTypeUser: - err = SetApplicationUser(iq, dummyAppName, role.Name, name) + err = SetApplicationUserContext(context.Background(), iq, dummyAppName, role.Name, name) if err == nil { - err = RevokeApplicationUser(iq, dummyAppName, role.Name, name) + err = RevokeApplicationUserContext(context.Background(), iq, dummyAppName, role.Name, name) } case MemberTypeGroup: - err = SetApplicationGroup(iq, dummyAppName, role.Name, name) + err = SetApplicationGroupContext(context.Background(), iq, dummyAppName, role.Name, name) if err == nil { - err = RevokeApplicationGroup(iq, dummyAppName, role.Name, name) + err = RevokeApplicationGroupContext(context.Background(), iq, dummyAppName, role.Name, name) } } if err == nil { - mappings, err = ApplicationAuthorizations(iq, dummyAppName) + mappings, err = ApplicationAuthorizationsContext(context.Background(), iq, dummyAppName) } case "repository_container": switch memberType { case MemberTypeUser: - err = SetRepositoriesUser(iq, role.Name, name) + err = SetRepositoriesUserContext(context.Background(), iq, role.Name, name) if err == nil { - err = RevokeRepositoriesUser(iq, role.Name, name) + err = RevokeRepositoriesUserContext(context.Background(), iq, role.Name, name) } case MemberTypeGroup: - err = SetRepositoriesGroup(iq, role.Name, name) + err = SetRepositoriesGroupContext(context.Background(), iq, role.Name, name) if err == nil { - err = RevokeRepositoriesGroup(iq, role.Name, name) + err = RevokeRepositoriesGroupContext(context.Background(), iq, role.Name, name) } } if err == nil { - mappings, err = RepositoriesAuthorizations(iq) + mappings, err = RepositoriesAuthorizationsContext(context.Background(), iq) } case "global": switch memberType { case MemberTypeUser: - err = SetGlobalUser(iq, role.Name, name) + err = SetGlobalUserContext(context.Background(), iq, role.Name, name) if err == nil { - err = RevokeGlobalUser(iq, role.Name, name) + err = RevokeGlobalUserContext(context.Background(), iq, role.Name, name) } case MemberTypeGroup: - err = SetGlobalGroup(iq, role.Name, name) + err = SetGlobalGroupContext(context.Background(), iq, role.Name, name) if err == nil { - err = RevokeGlobalGroup(iq, role.Name, name) + err = RevokeGlobalGroupContext(context.Background(), iq, role.Name, name) } } if err == nil { - mappings, err = GlobalAuthorizations(iq) + mappings, err = GlobalAuthorizationsContext(context.Background(), iq) } } if err != nil { @@ -732,7 +733,7 @@ func TestRepositoriesAuthorizations(t *testing.T) { iq, mock := roleMembershipsTestIQ(t, false) defer mock.Close() - got, err := RepositoriesAuthorizations(iq) + got, err := RepositoriesAuthorizationsContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -758,7 +759,7 @@ func TestGetApplicationAuthorizationsByRole(t *testing.T) { } } - got, err := RepositoriesAuthorizationsByRole(iq, role.Name) + got, err := RepositoriesAuthorizationsByRoleContext(context.Background(), iq, role.Name) if err != nil { t.Error(err) } @@ -791,15 +792,15 @@ func testSetRepositories(t *testing.T, memberType string) { var err error switch memberType { case MemberTypeUser: - err = SetRepositoriesUser(iq, role.Name, memberName) + err = SetRepositoriesUserContext(context.Background(), iq, role.Name, memberName) case MemberTypeGroup: - err = SetRepositoriesGroup(iq, role.Name, memberName) + err = SetRepositoriesGroupContext(context.Background(), iq, role.Name, memberName) } if err != nil { t.Error(err) } - got, err := RepositoriesAuthorizations(iq) + got, err := RepositoriesAuthorizationsContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -860,7 +861,7 @@ func testMembersByRole(t *testing.T, iq IQ) { } } } - if hasRev70API(iq) { + if hasRev70API(context.Background(), iq) { for _, m := range dummyRoleMappingsRepos { if m.RoleID == role.ID { want = append(want, m) @@ -868,7 +869,7 @@ func testMembersByRole(t *testing.T, iq IQ) { } } - got, err := MembersByRole(iq, role.Name) + got, err := MembersByRoleContext(context.Background(), iq, role.Name) if err != nil { t.Error(err) } @@ -888,7 +889,7 @@ func TestGlobalAuthorizations(t *testing.T) { iq, mock := roleMembershipsTestIQ(t, false) defer mock.Close() - got, err := GlobalAuthorizations(iq) + got, err := GlobalAuthorizationsContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -924,15 +925,15 @@ func testSetGlobal(t *testing.T, memberType string) { var err error switch memberType { case MemberTypeUser: - err = SetGlobalUser(iq, role.Name, memberName) + err = SetGlobalUserContext(context.Background(), iq, role.Name, memberName) case MemberTypeGroup: - err = SetGlobalGroup(iq, role.Name, memberName) + err = SetGlobalGroupContext(context.Background(), iq, role.Name, memberName) } if err != nil { t.Error(err) } - got, err := GlobalAuthorizations(iq) + got, err := GlobalAuthorizationsContext(context.Background(), iq) if err != nil { t.Error(err) } diff --git a/iq/roles.go b/iq/roles.go index 830bddd..2fc9ff2 100644 --- a/iq/roles.go +++ b/iq/roles.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "net/http" @@ -22,11 +23,10 @@ type Role struct { Description string `json:"description"` } -// Roles returns a slice of all the roles in the IQ instance -func Roles(iq IQ) ([]Role, error) { - body, resp, err := iq.Get(restRoles) +func RolesContext(ctx context.Context, iq IQ) ([]Role, error) { + body, resp, err := iq.Get(ctx, restRoles) if resp.StatusCode == http.StatusNotFound { - body, _, err = iq.Get(restRolesDeprecated) + body, _, err = iq.Get(ctx, restRolesDeprecated) } if err != nil { return nil, fmt.Errorf("could not retrieve roles: %v", err) @@ -40,9 +40,13 @@ func Roles(iq IQ) ([]Role, error) { return list.Roles, nil } -// RoleByName returns the named role -func RoleByName(iq IQ, name string) (Role, error) { - roles, err := Roles(iq) +// Roles returns a slice of all the roles in the IQ instance +func Roles(iq IQ) ([]Role, error) { + return RolesContext(context.Background(), iq) +} + +func RoleByNameContext(ctx context.Context, iq IQ, name string) (Role, error) { + roles, err := RolesContext(ctx, iq) if err != nil { return Role{}, fmt.Errorf("did not find role with name %s: %v", name, err) } @@ -56,12 +60,21 @@ func RoleByName(iq IQ, name string) (Role, error) { return Role{}, fmt.Errorf("did not find role with name %s", name) } -// GetSystemAdminID returns the identifier of the System Administrator role -func GetSystemAdminID(iq IQ) (string, error) { - role, err := RoleByName(iq, "System Administrator") +// RoleByName returns the named role +func RoleByName(iq IQ, name string) (Role, error) { + return RoleByNameContext(context.Background(), iq, name) +} + +func GetSystemAdminIDContext(ctx context.Context, iq IQ) (string, error) { + role, err := RoleByNameContext(ctx, iq, "System Administrator") if err != nil { return "", fmt.Errorf("did not get admin role: %v", err) } return role.ID, nil } + +// GetSystemAdminID returns the identifier of the System Administrator role +func GetSystemAdminID(iq IQ) (string, error) { + return GetSystemAdminIDContext(context.Background(), iq) +} diff --git a/iq/roles_test.go b/iq/roles_test.go index 73bdaa2..0d524d2 100644 --- a/iq/roles_test.go +++ b/iq/roles_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "net/http" @@ -80,7 +81,7 @@ func TestRoles(t *testing.T) { iq, mock := rolesTestIQ(t) defer mock.Close() - got, err := Roles(iq) + got, err := RolesContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -96,7 +97,7 @@ func TestRoleByName(t *testing.T) { want := dummyRoles[0] - got, err := RoleByName(iq, want.Name) + got, err := RoleByNameContext(context.Background(), iq, want.Name) if err != nil { t.Error(err) } diff --git a/iq/search.go b/iq/search.go index df65041..2757857 100644 --- a/iq/search.go +++ b/iq/search.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -117,10 +118,9 @@ func NewSearchQueryBuilder() *SearchQueryBuilder { return b } -// SearchComponents allows searching the indicated IQ instance for specific components -func SearchComponents(iq IQ, query nexus.SearchQueryBuilder) ([]SearchResult, error) { +func SearchComponentsContext(ctx context.Context, iq IQ, query nexus.SearchQueryBuilder) ([]SearchResult, error) { endpoint := restSearchComponent + "?" + query.Build() - body, resp, err := iq.Get(endpoint) + body, resp, err := iq.Get(ctx, endpoint) if err != nil || resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("could not find component: %v", err) } @@ -132,3 +132,8 @@ func SearchComponents(iq IQ, query nexus.SearchQueryBuilder) ([]SearchResult, er return searchResp.Results, nil } + +// SearchComponents allows searching the indicated IQ instance for specific components +func SearchComponents(iq IQ, query nexus.SearchQueryBuilder) ([]SearchResult, error) { + return SearchComponentsContext(context.Background(), iq, query) +} diff --git a/iq/search_test.go b/iq/search_test.go index 9747b00..7c5b339 100644 --- a/iq/search_test.go +++ b/iq/search_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "net/http" @@ -57,7 +58,7 @@ func TestSearchComponent(t *testing.T) { defer mock.Close() query := NewSearchQueryBuilder().Coordinates(dummyComponent.ComponentID.Coordinates) - components, err := SearchComponents(iq, query) + components, err := SearchComponentsContext(context.Background(), iq, query) if err != nil { t.Fatalf("Did not complete search: %v", err) } @@ -88,7 +89,7 @@ func ExampleSearchComponents() { query = query.Stage(StageBuild) query = query.PackageURL("pkg:maven/commons-collections/commons-collections@3.2") - components, err := SearchComponents(iq, query) + components, err := SearchComponentsContext(context.Background(), iq, query) if err != nil { panic(fmt.Sprintf("Did not complete search: %v", err)) } diff --git a/iq/sourceControl.go b/iq/sourceControl.go index 4514246..5c01bd4 100644 --- a/iq/sourceControl.go +++ b/iq/sourceControl.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -20,10 +21,10 @@ type SourceControlEntry struct { Token string `json:"token"` } -func getSourceControlEntryByInternalID(iq IQ, applicationID string) (entry SourceControlEntry, err error) { +func getSourceControlEntryByInternalID(ctx context.Context, iq IQ, applicationID string) (entry SourceControlEntry, err error) { endpoint := fmt.Sprintf(restSourceControl, applicationID) - body, _, err := iq.Get(endpoint) + body, _, err := iq.Get(ctx, endpoint) if err != nil { return } @@ -33,26 +34,29 @@ func getSourceControlEntryByInternalID(iq IQ, applicationID string) (entry Sourc return } -// GetSourceControlEntry lists of all of the Source Control entries for the given application -func GetSourceControlEntry(iq IQ, applicationID string) (SourceControlEntry, error) { - appInfo, err := GetApplicationByPublicID(iq, applicationID) +func GetSourceControlEntryContext(ctx context.Context, iq IQ, applicationID string) (SourceControlEntry, error) { + appInfo, err := GetApplicationByPublicIDContext(ctx, iq, applicationID) if err != nil { return SourceControlEntry{}, fmt.Errorf("no source control entry for '%s': %v", applicationID, err) } - return getSourceControlEntryByInternalID(iq, appInfo.ID) + return getSourceControlEntryByInternalID(ctx, iq, appInfo.ID) } -// GetAllSourceControlEntries lists of all of the Source Control entries in the IQ instance -func GetAllSourceControlEntries(iq IQ) ([]SourceControlEntry, error) { - apps, err := GetAllApplications(iq) +// GetSourceControlEntry lists of all of the Source Control entries for the given application +func GetSourceControlEntry(iq IQ, applicationID string) (SourceControlEntry, error) { + return GetSourceControlEntryContext(context.Background(), iq, applicationID) +} + +func GetAllSourceControlEntriesContext(ctx context.Context, iq IQ) ([]SourceControlEntry, error) { + apps, err := GetAllApplicationsContext(ctx, iq) if err != nil { return nil, fmt.Errorf("no source control entries: %v", err) } entries := make([]SourceControlEntry, 0) for _, app := range apps { - if entry, err := getSourceControlEntryByInternalID(iq, app.ID); err == nil { + if entry, err := getSourceControlEntryByInternalID(ctx, iq, app.ID); err == nil { entries = append(entries, entry) } } @@ -60,13 +64,17 @@ func GetAllSourceControlEntries(iq IQ) ([]SourceControlEntry, error) { return entries, nil } -// CreateSourceControlEntry creates a source control entry in IQ -func CreateSourceControlEntry(iq IQ, applicationID, repositoryURL, token string) error { +// GetAllSourceControlEntries lists of all of the Source Control entries in the IQ instance +func GetAllSourceControlEntries(iq IQ) ([]SourceControlEntry, error) { + return GetAllSourceControlEntriesContext(context.Background(), iq) +} + +func CreateSourceControlEntryContext(ctx context.Context, iq IQ, applicationID, repositoryURL, token string) error { doError := func(err error) error { return fmt.Errorf("source control entry not created for '%s': %v", applicationID, err) } - appInfo, err := GetApplicationByPublicID(iq, applicationID) + appInfo, err := GetApplicationByPublicIDContext(ctx, iq, applicationID) if err != nil { return doError(err) } @@ -77,20 +85,24 @@ func CreateSourceControlEntry(iq IQ, applicationID, repositoryURL, token string) } endpoint := fmt.Sprintf(restSourceControl, appInfo.ID) - if _, _, err = iq.Post(endpoint, bytes.NewBuffer(request)); err != nil { + if _, _, err = iq.Post(ctx, endpoint, bytes.NewBuffer(request)); err != nil { return doError(err) } return nil } -// UpdateSourceControlEntry updates a source control entry in IQ -func UpdateSourceControlEntry(iq IQ, applicationID, repositoryURL, token string) error { +// CreateSourceControlEntry creates a source control entry in IQ +func CreateSourceControlEntry(iq IQ, applicationID, repositoryURL, token string) error { + return CreateSourceControlEntryContext(context.Background(), iq, applicationID, repositoryURL, token) +} + +func UpdateSourceControlEntryContext(ctx context.Context, iq IQ, applicationID, repositoryURL, token string) error { doError := func(err error) error { return fmt.Errorf("source control entry not updated for '%s': %v", applicationID, err) } - appInfo, err := GetApplicationByPublicID(iq, applicationID) + appInfo, err := GetApplicationByPublicIDContext(ctx, iq, applicationID) if err != nil { return doError(err) } @@ -101,17 +113,22 @@ func UpdateSourceControlEntry(iq IQ, applicationID, repositoryURL, token string) } endpoint := fmt.Sprintf(restSourceControl, appInfo.ID) - if _, _, err = iq.Put(endpoint, bytes.NewBuffer(request)); err != nil { + if _, _, err = iq.Put(ctx, endpoint, bytes.NewBuffer(request)); err != nil { return doError(err) } return nil } -func deleteSourceControlEntry(iq IQ, appInternalID, sourceControlID string) error { +// UpdateSourceControlEntry updates a source control entry in IQ +func UpdateSourceControlEntry(iq IQ, applicationID, repositoryURL, token string) error { + return UpdateSourceControlEntryContext(context.Background(), iq, applicationID, repositoryURL, token) +} + +func deleteSourceControlEntry(ctx context.Context, iq IQ, appInternalID, sourceControlID string) error { endpoint := fmt.Sprintf(restSourceControlDelete, appInternalID, sourceControlID) - resp, err := iq.Del(endpoint) + resp, err := iq.Del(ctx, endpoint) if err != nil && resp.StatusCode != http.StatusNoContent { return err } @@ -119,33 +136,41 @@ func deleteSourceControlEntry(iq IQ, appInternalID, sourceControlID string) erro return nil } -// DeleteSourceControlEntry deletes a source control entry in IQ -func DeleteSourceControlEntry(iq IQ, applicationID, sourceControlID string) error { - appInfo, err := GetApplicationByPublicID(iq, applicationID) +func DeleteSourceControlEntryContext(ctx context.Context, iq IQ, applicationID, sourceControlID string) error { + appInfo, err := GetApplicationByPublicIDContext(ctx, iq, applicationID) if err != nil { return fmt.Errorf("source control entry not deleted from '%s': %v", applicationID, err) } - return deleteSourceControlEntry(iq, appInfo.ID, sourceControlID) + return deleteSourceControlEntry(ctx, iq, appInfo.ID, sourceControlID) } -// DeleteSourceControlEntryByApp deletes a source control entry in IQ for the given application -func DeleteSourceControlEntryByApp(iq IQ, applicationID string) error { +// DeleteSourceControlEntry deletes a source control entry in IQ +func DeleteSourceControlEntry(iq IQ, applicationID, sourceControlID string) error { + return DeleteSourceControlEntryContext(context.Background(), iq, applicationID, sourceControlID) +} + +func DeleteSourceControlEntryByAppContext(ctx context.Context, iq IQ, applicationID string) error { doError := func(err error) error { return fmt.Errorf("source control entry not deleted from '%s': %v", applicationID, err) } - appInfo, err := GetApplicationByPublicID(iq, applicationID) + appInfo, err := GetApplicationByPublicIDContext(ctx, iq, applicationID) if err != nil { return doError(err) } - entry, err := getSourceControlEntryByInternalID(iq, appInfo.ID) + entry, err := getSourceControlEntryByInternalID(ctx, iq, appInfo.ID) if err != nil { return doError(err) } - return deleteSourceControlEntry(iq, appInfo.ID, entry.ID) + return deleteSourceControlEntry(ctx, iq, appInfo.ID, entry.ID) +} + +// DeleteSourceControlEntryByApp deletes a source control entry in IQ for the given application +func DeleteSourceControlEntryByApp(iq IQ, applicationID string) error { + return DeleteSourceControlEntryByAppContext(context.Background(), iq, applicationID) } // DeleteSourceControlEntryByEntry deletes a source control entry in IQ for the given entry ID diff --git a/iq/sourceControl_test.go b/iq/sourceControl_test.go index 7af2c93..a96c6e5 100644 --- a/iq/sourceControl_test.go +++ b/iq/sourceControl_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -115,7 +116,7 @@ func TestGetSourceControlEntryByInternalID(t *testing.T) { dummyEntryIdx := 2 - entry, err := getSourceControlEntryByInternalID(iq, dummyEntries[dummyEntryIdx].ApplicationID) + entry, err := getSourceControlEntryByInternalID(context.Background(), iq, dummyEntries[dummyEntryIdx].ApplicationID) if err != nil { t.Error(err) } @@ -131,7 +132,7 @@ func TestGetAllSourceControlEntries(t *testing.T) { iq, mock := sourceControlTestIQ(t) defer mock.Close() - entries, err := GetAllSourceControlEntries(iq) + entries, err := GetAllSourceControlEntriesContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -161,7 +162,7 @@ func TestGetSourceControlEntry(t *testing.T) { dummyEntryIdx := 0 - entry, err := GetSourceControlEntry(iq, dummyApps[dummyEntryIdx].PublicID) + entry, err := GetSourceControlEntryContext(context.Background(), iq, dummyApps[dummyEntryIdx].PublicID) if err != nil { t.Error(err) } @@ -179,12 +180,12 @@ func TestCreateSourceControlEntry(t *testing.T) { createdEntry := SourceControlEntry{newEntryID, dummyApps[len(dummyApps)-1].ID, "createdEntryURL", "createEntryToken"} - err := CreateSourceControlEntry(iq, dummyApps[len(dummyApps)-1].PublicID, createdEntry.RepositoryURL, createdEntry.Token) + err := CreateSourceControlEntryContext(context.Background(), iq, dummyApps[len(dummyApps)-1].PublicID, createdEntry.RepositoryURL, createdEntry.Token) if err != nil { t.Error(err) } - entry, err := GetSourceControlEntry(iq, dummyApps[len(dummyApps)-1].PublicID) + entry, err := GetSourceControlEntryContext(context.Background(), iq, dummyApps[len(dummyApps)-1].PublicID) if err != nil { t.Error(err) } @@ -202,12 +203,12 @@ func TestUpdateSourceControlEntry(t *testing.T) { updatedEntryRepositoryURL := "updatedRepoURL" updatedEntryToken := "updatedToken" - err := UpdateSourceControlEntry(iq, dummyApps[len(dummyApps)-2].PublicID, updatedEntryRepositoryURL, updatedEntryToken) + err := UpdateSourceControlEntryContext(context.Background(), iq, dummyApps[len(dummyApps)-2].PublicID, updatedEntryRepositoryURL, updatedEntryToken) if err != nil { t.Error(err) } - entry, err := GetSourceControlEntry(iq, dummyApps[len(dummyApps)-2].PublicID) + entry, err := GetSourceControlEntryContext(context.Background(), iq, dummyApps[len(dummyApps)-2].PublicID) if err != nil { t.Error(err) } @@ -229,15 +230,15 @@ func TestDeleteSourceControlEntry(t *testing.T) { app := dummyApps[len(dummyApps)-1] deleteMe := SourceControlEntry{newEntryID, app.ID, "deleteMeURL", "deleteMeToken"} - if err := CreateSourceControlEntry(iq, app.PublicID, deleteMe.RepositoryURL, deleteMe.Token); err != nil { + if err := CreateSourceControlEntryContext(context.Background(), iq, app.PublicID, deleteMe.RepositoryURL, deleteMe.Token); err != nil { t.Error(err) } - if err := DeleteSourceControlEntry(iq, app.PublicID, newEntryID); err != nil { + if err := DeleteSourceControlEntryContext(context.Background(), iq, app.PublicID, newEntryID); err != nil { t.Error(err) } - if _, err := GetSourceControlEntry(iq, app.PublicID); err == nil { + if _, err := GetSourceControlEntryContext(context.Background(), iq, app.PublicID); err == nil { t.Error("Unexpectedly found entry which should have been deleted") } } @@ -249,15 +250,15 @@ func TestDeleteSourceControlEntryByApp(t *testing.T) { app := dummyApps[len(dummyApps)-1] deleteMe := SourceControlEntry{newEntryID, app.ID, "deleteMeURL", "deleteMeToken"} - if err := CreateSourceControlEntry(iq, app.PublicID, deleteMe.RepositoryURL, deleteMe.Token); err != nil { + if err := CreateSourceControlEntryContext(context.Background(), iq, app.PublicID, deleteMe.RepositoryURL, deleteMe.Token); err != nil { t.Error(err) } - if err := DeleteSourceControlEntryByApp(iq, app.PublicID); err != nil { + if err := DeleteSourceControlEntryByAppContext(context.Background(), iq, app.PublicID); err != nil { t.Error(err) } - if _, err := GetSourceControlEntry(iq, app.PublicID); err == nil { + if _, err := GetSourceControlEntryContext(context.Background(), iq, app.PublicID); err == nil { t.Error("Unexpectedly found entry which should have been deleted") } } diff --git a/iq/users.go b/iq/users.go index 332d813..2550710 100644 --- a/iq/users.go +++ b/iq/users.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -21,10 +22,9 @@ type User struct { Password string `json:"password,omitempty"` } -// GetUser returns user details for the given name -func GetUser(iq IQ, username string) (user User, err error) { +func GetUserContext(ctx context.Context, iq IQ, username string) (user User, err error) { endpoint := fmt.Sprintf(restUsers, username) - body, _, err := iq.Get(endpoint) + body, _, err := iq.Get(ctx, endpoint) if err != nil { return user, fmt.Errorf("could not retrieve details on username %s: %v", username, err) } @@ -34,32 +34,45 @@ func GetUser(iq IQ, username string) (user User, err error) { return user, err } -// SetUser creates a new user -func SetUser(iq IQ, user User) (err error) { +// GetUser returns user details for the given name +func GetUser(iq IQ, username string) (user User, err error) { + return GetUserContext(context.Background(), iq, username) +} + +func SetUserContext(ctx context.Context, iq IQ, user User) (err error) { buf, err := json.Marshal(user) if err != nil { return fmt.Errorf("could not read user details: %v", err) } str := bytes.NewBuffer(buf) - if _, er := GetUser(iq, user.Username); er != nil { - _, resp, er := iq.Post(restUsersPost, str) + if _, er := GetUserContext(ctx, iq, user.Username); er != nil { + _, resp, er := iq.Post(ctx, restUsersPost, str) if er != nil && resp.StatusCode != http.StatusNoContent { return er } } else { endpoint := fmt.Sprintf(restUsers, user.Username) - _, _, err = iq.Put(endpoint, str) + _, _, err = iq.Put(ctx, endpoint, str) } return err } -// DeleteUser removes the named user -func DeleteUser(iq IQ, username string) error { +// SetUser creates a new user +func SetUser(iq IQ, user User) (err error) { + return SetUserContext(context.Background(), iq, user) +} + +func DeleteUserContext(ctx context.Context, iq IQ, username string) error { endpoint := fmt.Sprintf(restUsers, username) - if resp, err := iq.Del(endpoint); err != nil && resp.StatusCode != http.StatusNoContent { + if resp, err := iq.Del(ctx, endpoint); err != nil && resp.StatusCode != http.StatusNoContent { return err } return nil } + +// DeleteUser removes the named user +func DeleteUser(iq IQ, username string) error { + return DeleteUserContext(context.Background(), iq, username) +} diff --git a/iq/users_test.go b/iq/users_test.go index a644797..35c9f8f 100644 --- a/iq/users_test.go +++ b/iq/users_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -145,7 +146,7 @@ func usersTestIQ(t *testing.T, useDeprecated bool) (IQ, *httptest.Server) { func checkExists(t *testing.T, iq IQ, want User) { t.Helper() - got, err := GetUser(iq, want.Username) + got, err := GetUserContext(context.Background(), iq, want.Username) if err != nil { t.Error(err) } @@ -165,7 +166,7 @@ func TestGetUser(t *testing.T) { } func setUser(t *testing.T, iq IQ, want User) { - err := SetUser(iq, want) + err := SetUserContext(context.Background(), iq, want) if err != nil { t.Error(err) } @@ -234,12 +235,12 @@ func TestDeleteUser(t *testing.T) { // Create new dummy user setUser(t, iq, want) - err := DeleteUser(iq, want.Username) + err := DeleteUserContext(context.Background(), iq, want.Username) if err != nil { t.Error(err) } - if _, err := GetUser(iq, want.Username); err == nil { + if _, err := GetUserContext(context.Background(), iq, want.Username); err == nil { t.Error("Found user which I tried to delete") } } diff --git a/nexus.go b/nexus.go index 7ab00f9..9059d2a 100644 --- a/nexus.go +++ b/nexus.go @@ -1,6 +1,7 @@ package nexus import ( + "context" "crypto/tls" "crypto/x509" "errors" @@ -20,12 +21,12 @@ type ServerInfo struct { // Client is the interface which allows interacting with an IQ server type Client interface { - NewRequest(method, endpoint string, payload io.Reader) (*http.Request, error) + NewRequest(ctx context.Context, method, endpoint string, payload io.Reader) (*http.Request, error) Do(request *http.Request) ([]byte, *http.Response, error) - Get(endpoint string) ([]byte, *http.Response, error) - Post(endpoint string, payload io.Reader) ([]byte, *http.Response, error) - Put(endpoint string, payload io.Reader) ([]byte, *http.Response, error) - Del(endpoint string) (*http.Response, error) + Get(ctx context.Context, endpoint string) ([]byte, *http.Response, error) + Post(ctx context.Context, endpoint string, payload io.Reader) ([]byte, *http.Response, error) + Put(ctx context.Context, endpoint string, payload io.Reader) ([]byte, *http.Response, error) + Del(ctx context.Context, endpoint string) (*http.Response, error) Info() ServerInfo SetDebug(enable bool) SetCertFile(certFile string) @@ -38,9 +39,9 @@ type DefaultClient struct { } // NewRequest created an http.Request object based on an endpoint and fills in basic auth -func (s *DefaultClient) NewRequest(method, endpoint string, payload io.Reader) (request *http.Request, err error) { +func (s *DefaultClient) NewRequest(ctx context.Context, method, endpoint string, payload io.Reader) (request *http.Request, err error) { url := fmt.Sprintf("%s/%s", s.Host, endpoint) - request, err = http.NewRequest(method, url, payload) + request, err = http.NewRequestWithContext(ctx, method, url, payload) if err != nil { return } @@ -104,8 +105,8 @@ func (s *DefaultClient) Do(request *http.Request) (body []byte, resp *http.Respo return } -func (s *DefaultClient) http(method, endpoint string, payload io.Reader) ([]byte, *http.Response, error) { - request, err := s.NewRequest(method, endpoint, payload) +func (s *DefaultClient) http(ctx context.Context, method, endpoint string, payload io.Reader) ([]byte, *http.Response, error) { + request, err := s.NewRequest(ctx, method, endpoint, payload) if err != nil { return nil, nil, err } @@ -114,23 +115,23 @@ func (s *DefaultClient) http(method, endpoint string, payload io.Reader) ([]byte } // Get performs an HTTP GET against the indicated endpoint -func (s *DefaultClient) Get(endpoint string) ([]byte, *http.Response, error) { - return s.http(http.MethodGet, endpoint, nil) +func (s *DefaultClient) Get(ctx context.Context, endpoint string) ([]byte, *http.Response, error) { + return s.http(ctx, http.MethodGet, endpoint, nil) } // Post performs an HTTP POST against the indicated endpoint -func (s *DefaultClient) Post(endpoint string, payload io.Reader) ([]byte, *http.Response, error) { - return s.http(http.MethodPost, endpoint, payload) +func (s *DefaultClient) Post(ctx context.Context, endpoint string, payload io.Reader) ([]byte, *http.Response, error) { + return s.http(ctx, http.MethodPost, endpoint, payload) } // Put performs an HTTP PUT against the indicated endpoint -func (s *DefaultClient) Put(endpoint string, payload io.Reader) ([]byte, *http.Response, error) { - return s.http(http.MethodPut, endpoint, payload) +func (s *DefaultClient) Put(ctx context.Context, endpoint string, payload io.Reader) ([]byte, *http.Response, error) { + return s.http(ctx, http.MethodPut, endpoint, payload) } // Del performs an HTTP DELETE against the indicated endpoint -func (s *DefaultClient) Del(endpoint string) (resp *http.Response, err error) { - _, resp, err = s.http(http.MethodDelete, endpoint, nil) +func (s *DefaultClient) Del(ctx context.Context, endpoint string) (resp *http.Response, err error) { + _, resp, err = s.http(ctx, http.MethodDelete, endpoint, nil) return } diff --git a/rm/anonymous.go b/rm/anonymous.go index 36e6a0d..29bd9e9 100644 --- a/rm/anonymous.go +++ b/rm/anonymous.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -15,10 +16,10 @@ type SettingsAnonAccess struct { RealmName string `json:"realmName"` } -func GetAnonAccess(rm RM) (SettingsAnonAccess, error) { +func GetAnonAccessContext(ctx context.Context, rm RM) (SettingsAnonAccess, error) { var settings SettingsAnonAccess - body, resp, err := rm.Get(restAnonymous) + body, resp, err := rm.Get(ctx, restAnonymous) if err != nil && resp.StatusCode != http.StatusNoContent { return SettingsAnonAccess{}, fmt.Errorf("anonymous access settings can't getting: %v", err) } @@ -30,15 +31,23 @@ func GetAnonAccess(rm RM) (SettingsAnonAccess, error) { return settings, nil } -func SetAnonAccess(rm RM, settings SettingsAnonAccess) error { +func GetAnonAccess(rm RM) (SettingsAnonAccess, error) { + return GetAnonAccessContext(context.Background(), rm) +} + +func SetAnonAccessContext(ctx context.Context, rm RM, settings SettingsAnonAccess) error { json, err := json.Marshal(settings) if err != nil { return err } - if _, resp, err := rm.Put(restAnonymous, bytes.NewBuffer(json)); err != nil && resp.StatusCode != http.StatusNoContent { + if _, resp, err := rm.Put(ctx, restAnonymous, bytes.NewBuffer(json)); err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("email config not set: %v", err) } return nil } + +func SetAnonAccess(rm RM, settings SettingsAnonAccess) error { + return SetAnonAccessContext(context.Background(), rm, settings) +} diff --git a/rm/assets.go b/rm/assets.go index 25b3cdb..fd9c16f 100644 --- a/rm/assets.go +++ b/rm/assets.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "net/http" @@ -31,8 +32,7 @@ type listAssetsResponse struct { ContinuationToken string `json:"continuationToken"` } -// GetAssets returns a list of assets in the indicated repository -func GetAssets(rm RM, repo string) (items []RepositoryItemAsset, err error) { +func GetAssetsContext(ctx context.Context, rm RM, repo string) (items []RepositoryItemAsset, err error) { continuation := "" get := func() (listResp listAssetsResponse, err error) { @@ -42,7 +42,7 @@ func GetAssets(rm RM, repo string) (items []RepositoryItemAsset, err error) { url += "&continuationToken=" + continuation } - body, resp, err := rm.Get(url) + body, resp, err := rm.Get(ctx, url) if err != nil || resp.StatusCode != http.StatusOK { return } @@ -71,8 +71,12 @@ func GetAssets(rm RM, repo string) (items []RepositoryItemAsset, err error) { return items, nil } -// GetAssetByID returns an asset by ID -func GetAssetByID(rm RM, id string) (items RepositoryItemAsset, err error) { +// GetAssets returns a list of assets in the indicated repository +func GetAssets(rm RM, repo string) (items []RepositoryItemAsset, err error) { + return GetAssetsContext(context.Background(), rm, repo) +} + +func GetAssetByIDContext(ctx context.Context, rm RM, id string) (items RepositoryItemAsset, err error) { doError := func(err error) error { return fmt.Errorf("no asset with id '%s': %v", id, err) } @@ -80,7 +84,7 @@ func GetAssetByID(rm RM, id string) (items RepositoryItemAsset, err error) { var item RepositoryItemAsset url := fmt.Sprintf("%s/%s", restAssets, id) - body, resp, err := rm.Get(url) + body, resp, err := rm.Get(ctx, url) if err != nil || resp.StatusCode != http.StatusOK { return item, doError(err) } @@ -92,13 +96,22 @@ func GetAssetByID(rm RM, id string) (items RepositoryItemAsset, err error) { return item, nil } -// DeleteAssetByID deletes the asset indicated by ID -func DeleteAssetByID(rm RM, id string) error { +// GetAssetByID returns an asset by ID +func GetAssetByID(rm RM, id string) (items RepositoryItemAsset, err error) { + return GetAssetByIDContext(context.Background(), rm, id) +} + +func DeleteAssetByIDContext(ctx context.Context, rm RM, id string) error { url := fmt.Sprintf("%s/%s", restAssets, id) - if resp, err := rm.Del(url); err != nil && resp.StatusCode != http.StatusNoContent { + if resp, err := rm.Del(ctx, url); err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("asset not deleted '%s': %v", id, err) } return nil } + +// DeleteAssetByID deletes the asset indicated by ID +func DeleteAssetByID(rm RM, id string) error { + return DeleteAssetByIDContext(context.Background(), rm, id) +} diff --git a/rm/assets_test.go b/rm/assets_test.go index cedcfc9..1c26ec2 100644 --- a/rm/assets_test.go +++ b/rm/assets_test.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "net/http" @@ -130,7 +131,7 @@ func getAssetsTester(t *testing.T, repo string) { rm, mock := assetsTestRM(t) defer mock.Close() - assets, err := GetAssets(rm, repo) + assets, err := GetAssetsContext(context.Background(), rm, repo) if err != nil { panic(err) } @@ -162,7 +163,7 @@ func TestGetAssetByID(t *testing.T) { expectedAsset := dummyAssets["repo-maven"][0] - asset, err := GetAssetByID(rm, expectedAsset.ID) + asset, err := GetAssetByIDContext(context.Background(), rm, expectedAsset.ID) if err != nil { t.Error(err) } @@ -189,15 +190,15 @@ func TestDeleteAssetByID(t *testing.T) { dummyAssets[deleteMe.Repository] = append(dummyAssets[deleteMe.Repository], deleteMe) - if _, err := GetAssetByID(rm, deleteMe.ID); err != nil { + if _, err := GetAssetByIDContext(context.Background(), rm, deleteMe.ID); err != nil { t.Errorf("Error getting component: %v\n", err) } - if err := DeleteAssetByID(rm, deleteMe.ID); err != nil { + if err := DeleteAssetByIDContext(context.Background(), rm, deleteMe.ID); err != nil { t.Fatal(err) } - if _, err := GetAssetByID(rm, deleteMe.ID); err == nil { + if _, err := GetAssetByIDContext(context.Background(), rm, deleteMe.ID); err == nil { t.Errorf("Asset not deleted: %v\n", err) } } diff --git a/rm/components.go b/rm/components.go index 65f41b1..5d5e9eb 100644 --- a/rm/components.go +++ b/rm/components.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "io" @@ -286,8 +287,7 @@ func (a UploadComponentApt) write(w *multipart.Writer) error { return nil } -// GetComponents returns a list of components in the indicated repository -func GetComponents(rm RM, repo string) ([]RepositoryItem, error) { +func GetComponentsContext(ctx context.Context, rm RM, repo string) ([]RepositoryItem, error) { continuation := "" getComponents := func() (listResp listComponentsResponse, err error) { @@ -297,7 +297,7 @@ func GetComponents(rm RM, repo string) ([]RepositoryItem, error) { url += "&continuationToken=" + continuation } - body, resp, err := rm.Get(url) + body, resp, err := rm.Get(ctx, url) if err != nil || resp.StatusCode != http.StatusOK { return } @@ -326,8 +326,12 @@ func GetComponents(rm RM, repo string) ([]RepositoryItem, error) { return items, nil } -// GetComponentByID returns a component by ID -func GetComponentByID(rm RM, id string) (RepositoryItem, error) { +// GetComponents returns a list of components in the indicated repository +func GetComponents(rm RM, repo string) ([]RepositoryItem, error) { + return GetComponentsContext(context.Background(), rm, repo) +} + +func GetComponentByIDContext(ctx context.Context, rm RM, id string) (RepositoryItem, error) { doError := func(err error) error { return fmt.Errorf("no component with id '%s': %v", id, err) } @@ -335,7 +339,7 @@ func GetComponentByID(rm RM, id string) (RepositoryItem, error) { var item RepositoryItem url := fmt.Sprintf("%s/%s", restComponents, id) - body, resp, err := rm.Get(url) + body, resp, err := rm.Get(ctx, url) if err != nil || resp.StatusCode != http.StatusOK { return item, doError(err) } @@ -347,20 +351,28 @@ func GetComponentByID(rm RM, id string) (RepositoryItem, error) { return item, nil } -// DeleteComponentByID deletes the indicated component -func DeleteComponentByID(rm RM, id string) error { +// GetComponentByID returns a component by ID +func GetComponentByID(rm RM, id string) (RepositoryItem, error) { + return GetComponentByIDContext(context.Background(), rm, id) +} + +func DeleteComponentByIDContext(ctx context.Context, rm RM, id string) error { url := fmt.Sprintf("%s/%s", restComponents, id) - if resp, err := rm.Del(url); err != nil && resp.StatusCode != http.StatusNoContent { + if resp, err := rm.Del(ctx, url); err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("component not deleted '%s': %v", id, err) } return nil } -// UploadComponent uploads a component to repository manager -func UploadComponent(rm RM, repo string, component UploadComponentWriter) error { - if _, err := GetRepositoryByName(rm, repo); err != nil { +// DeleteComponentByID deletes the indicated component +func DeleteComponentByID(rm RM, id string) error { + return DeleteComponentByIDContext(context.Background(), rm, id) +} + +func UploadComponentContext(ctx context.Context, rm RM, repo string, component UploadComponentWriter) error { + if _, err := GetRepositoryByNameContext(ctx, rm, repo); err != nil { return fmt.Errorf("could not find repository: %v", err) } @@ -379,7 +391,7 @@ func UploadComponent(rm RM, repo string, component UploadComponentWriter) error }() url := fmt.Sprintf(restListComponentsByRepo, repo) - req, err := rm.NewRequest("POST", url, b) + req, err := rm.NewRequest(ctx, "POST", url, b) req.Header.Set("Content-Type", m.FormDataContentType()) if err != nil { return doError(err) @@ -391,3 +403,8 @@ func UploadComponent(rm RM, repo string, component UploadComponentWriter) error return nil } + +// UploadComponent uploads a component to repository manager +func UploadComponent(rm RM, repo string, component UploadComponentWriter) error { + return UploadComponentContext(context.Background(), rm, repo, component) +} diff --git a/rm/components_test.go b/rm/components_test.go index 7fca5c8..58b045c 100644 --- a/rm/components_test.go +++ b/rm/components_test.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "net/http" @@ -168,7 +169,7 @@ func getComponentsTester(t *testing.T, repo string) { rm, mock := componentsTestRM(t) defer mock.Close() - components, err := GetComponents(rm, repo) + components, err := GetComponentsContext(context.Background(), rm, repo) if err != nil { panic(err) } @@ -200,7 +201,7 @@ func TestGetComponentByID(t *testing.T) { expectedComponent := dummyComponents["repo-maven"][0] - component, err := GetComponentByID(rm, expectedComponent.ID) + component, err := GetComponentByIDContext(context.Background(), rm, expectedComponent.ID) if err != nil { t.Error(err) } @@ -218,13 +219,13 @@ func componentUploader(t *testing.T, expected RepositoryItem, upload UploadCompo defer mock.Close() // if err := UploadComponent(rm, expected.Repository, coordinate, file); err != nil { - if err := UploadComponent(rm, expected.Repository, upload); err != nil { + if err := UploadComponentContext(context.Background(), rm, expected.Repository, upload); err != nil { t.Error(err) } expected.ID = dummyNewComponentID - component, err := GetComponentByID(rm, expected.ID) + component, err := GetComponentByIDContext(context.Background(), rm, expected.ID) if err != nil { t.Error(err) } @@ -314,17 +315,17 @@ func TestDeleteComponentByID(t *testing.T) { } // if err = UploadComponent(rm, deleteMe.Repository, coord, nil); err != nil { - if err = UploadComponent(rm, deleteMe.Repository, upload); err != nil { + if err = UploadComponentContext(context.Background(), rm, deleteMe.Repository, upload); err != nil { t.Error(err) } deleteMe.ID = dummyNewComponentID - if err = DeleteComponentByID(rm, deleteMe.ID); err != nil { + if err = DeleteComponentByIDContext(context.Background(), rm, deleteMe.ID); err != nil { t.Fatal(err) } - if _, err := GetComponentByID(rm, deleteMe.ID); err == nil { + if _, err := GetComponentByIDContext(context.Background(), rm, deleteMe.ID); err == nil { t.Errorf("Component not deleted: %v\n", err) } } @@ -335,7 +336,7 @@ func ExampleGetComponents() { panic(err) } - items, err := GetComponents(rm, "maven-central") + items, err := GetComponentsContext(context.Background(), rm, "maven-central") if err != nil { panic(err) } diff --git a/rm/email.go b/rm/email.go index fc7622b..0992bbb 100644 --- a/rm/email.go +++ b/rm/email.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -24,24 +25,27 @@ type EmailConfig struct { NexusTrustStoreEnabled bool `json:"nexusTrustStoreEnabled"` } -func SetEmailConfig(rm RM, config EmailConfig) error { - +func SetEmailConfigContext(ctx context.Context, rm RM, config EmailConfig) error { json, err := json.Marshal(config) if err != nil { return err } - if _, resp, err := rm.Put(restEmail, bytes.NewBuffer(json)); err != nil && resp.StatusCode != http.StatusNoContent { + if _, resp, err := rm.Put(ctx, restEmail, bytes.NewBuffer(json)); err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("email config not set: %v", err) } return nil } -func GetEmailConfig(rm RM) (EmailConfig, error) { +func SetEmailConfig(rm RM, config EmailConfig) error { + return SetEmailConfigContext(context.Background(), rm, config) +} + +func GetEmailConfigContext(ctx context.Context, rm RM) (EmailConfig, error) { var config EmailConfig - body, resp, err := rm.Get(restEmail) + body, resp, err := rm.Get(ctx, restEmail) if err != nil && resp.StatusCode != http.StatusNoContent { return EmailConfig{}, fmt.Errorf("email config can't getting: %v", err) } @@ -53,11 +57,18 @@ func GetEmailConfig(rm RM) (EmailConfig, error) { return config, nil } -func DeleteEmailConfig(rm RM) error { +func GetEmailConfig(rm RM) (EmailConfig, error) { + return GetEmailConfigContext(context.Background(), rm) +} - if resp, err := rm.Del(restEmail); err != nil && resp.StatusCode != http.StatusNoContent { +func DeleteEmailConfigContext(ctx context.Context, rm RM) error { + if resp, err := rm.Del(ctx, restEmail); err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("email config not deleted: %v", err) } return nil } + +func DeleteEmailConfig(rm RM) error { + return DeleteEmailConfigContext(context.Background(), rm) +} diff --git a/rm/groovyBlobStore.go b/rm/groovyBlobStore.go index 3ddc1ef..a41649e 100644 --- a/rm/groovyBlobStore.go +++ b/rm/groovyBlobStore.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "fmt" "text/template" ) @@ -54,13 +55,12 @@ func DeleteBlobStore(rm RM, name string) error { return err } - _, err = ScriptRunOnce(rm, newAnonGroovyScript(buf.String()), nil) + _, err = ScriptRunOnceContext(rm, newAnonGroovyScript(buf.String()), nil) return err } */ -// CreateFileBlobStore creates a blobstore -func CreateFileBlobStore(rm RM, name, path string) error { +func CreateFileBlobStoreContext(ctx context.Context, rm RM, name, path string) error { tmpl, err := template.New("fbs").Parse(groovyCreateFileBlobStore) if err != nil { return fmt.Errorf("could not parse template: %v", err) @@ -72,12 +72,16 @@ func CreateFileBlobStore(rm RM, name, path string) error { return fmt.Errorf("could not create file blobstore from template: %v", err) } - _, err = ScriptRunOnce(rm, newAnonGroovyScript(buf.String()), nil) + _, err = ScriptRunOnceContext(ctx, rm, newAnonGroovyScript(buf.String()), nil) return fmt.Errorf("could not create file blobstore: %v", err) } -// CreateBlobStoreGroup creates a blobstore -func CreateBlobStoreGroup(rm RM, name string, blobStores []string) error { +// CreateFileBlobStore creates a blobstore +func CreateFileBlobStore(rm RM, name, path string) error { + return CreateFileBlobStoreContext(context.Background(), rm, name, path) +} + +func CreateBlobStoreGroupContext(ctx context.Context, rm RM, name string, blobStores []string) error { tmpl, err := template.New("group").Parse(groovyCreateBlobStoreGroup) if err != nil { return fmt.Errorf("could not parse template: %v", err) @@ -89,6 +93,11 @@ func CreateBlobStoreGroup(rm RM, name string, blobStores []string) error { return fmt.Errorf("could not create group blobstore from template: %v", err) } - _, err = ScriptRunOnce(rm, newAnonGroovyScript(buf.String()), nil) + _, err = ScriptRunOnceContext(ctx, rm, newAnonGroovyScript(buf.String()), nil) return fmt.Errorf("could not create group blobstore: %v", err) } + +// CreateBlobStoreGroup creates a blobstore group +func CreateBlobStoreGroup(rm RM, name string, blobStores []string) error { + return CreateBlobStoreGroupContext(context.Background(), rm, name, blobStores) +} diff --git a/rm/groovyBlobStore_test.go b/rm/groovyBlobStore_test.go index 22415df..da959e3 100644 --- a/rm/groovyBlobStore_test.go +++ b/rm/groovyBlobStore_test.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "testing" ) @@ -9,7 +10,7 @@ func TestCreateFileBlobStore(t *testing.T) { rm, mock := repositoriesTestRM(t) defer mock.Close() - err := CreateFileBlobStore(rm, "testname", "testpath") + err := CreateFileBlobStoreContext(context.Background(), rm, "testname", "testpath") if err != nil { t.Error(err) } @@ -22,11 +23,11 @@ func TestCreateBlobStoreGroup(t *testing.T) { rm, mock := repositoriesTestRM(t) defer mock.Close() - CreateFileBlobStore(rm, "f1", "pathf1") - CreateFileBlobStore(rm, "f2", "pathf2") - CreateFileBlobStore(rm, "f3", "pathf3") + CreateFileBlobStoreContext(context.Background(), rm, "f1", "pathf1") + CreateFileBlobStoreContext(context.Background(), rm, "f2", "pathf2") + CreateFileBlobStoreContext(context.Background(), rm, "f3", "pathf3") - err := CreateBlobStoreGroup(rm, "grpname", []string{"f1", "f2", "f3"}) + err := CreateBlobStoreGroupContext(context.Background(), rm, "grpname", []string{"f1", "f2", "f3"}) if err != nil { t.Error(err) } diff --git a/rm/groovyRepository.go b/rm/groovyRepository.go index 50b6d94..af08934 100644 --- a/rm/groovyRepository.go +++ b/rm/groovyRepository.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "fmt" "text/template" ) @@ -71,8 +72,7 @@ type repositoryGroup struct { Members []string } -// CreateHostedRepository creates a hosted repository of the indicated format -func CreateHostedRepository(rm RM, format repositoryFormat, config repositoryHosted) error { +func CreateHostedRepositoryContext(ctx context.Context, rm RM, format repositoryFormat, config repositoryHosted) error { var groovyTmpl string switch format { case Maven: @@ -112,12 +112,16 @@ func CreateHostedRepository(rm RM, format repositoryFormat, config repositoryHos return fmt.Errorf("could not create hosted repository from template: %v", err) } - _, err = ScriptRunOnce(rm, newAnonGroovyScript(buf.String()), nil) + _, err = ScriptRunOnceContext(ctx, rm, newAnonGroovyScript(buf.String()), nil) return fmt.Errorf("could not create hosted repository: %v", err) } -// CreateProxyRepository creates a proxy repository of the indicated format -func CreateProxyRepository(rm RM, format repositoryFormat, config repositoryProxy) error { +// CreateHostedRepository creates a hosted repository of the indicated format +func CreateHostedRepository(rm RM, format repositoryFormat, config repositoryHosted) error { + return CreateHostedRepositoryContext(context.Background(), rm, format, config) +} + +func CreateProxyRepositoryContext(ctx context.Context, rm RM, format repositoryFormat, config repositoryProxy) error { var groovyTmpl string switch format { case Maven: @@ -157,12 +161,16 @@ func CreateProxyRepository(rm RM, format repositoryFormat, config repositoryProx return fmt.Errorf("could not create proxy repository from template: %v", err) } - _, err = ScriptRunOnce(rm, newAnonGroovyScript(buf.String()), nil) + _, err = ScriptRunOnceContext(ctx, rm, newAnonGroovyScript(buf.String()), nil) return fmt.Errorf("could not create proxy repository: %v", err) } -// CreateGroupRepository creates a group repository of the indicated format -func CreateGroupRepository(rm RM, format repositoryFormat, config repositoryGroup) error { +// CreateProxyRepository creates a proxy repository of the indicated format +func CreateProxyRepository(rm RM, format repositoryFormat, config repositoryProxy) error { + return CreateProxyRepositoryContext(context.Background(), rm, format, config) +} + +func CreateGroupRepositoryContext(ctx context.Context, rm RM, format repositoryFormat, config repositoryGroup) error { var groovyTmpl string switch format { case Maven: @@ -202,6 +210,11 @@ func CreateGroupRepository(rm RM, format repositoryFormat, config repositoryGrou return fmt.Errorf("could not create group repository from template: %v", err) } - _, err = ScriptRunOnce(rm, newAnonGroovyScript(buf.String()), nil) + _, err = ScriptRunOnceContext(ctx, rm, newAnonGroovyScript(buf.String()), nil) return fmt.Errorf("could not create group repository: %v", err) } + +// CreateGroupRepository creates a group repository of the indicated format +func CreateGroupRepository(rm RM, format repositoryFormat, config repositoryGroup) error { + return CreateGroupRepositoryContext(context.Background(), rm, format, config) +} diff --git a/rm/groovyRepository_test.go b/rm/groovyRepository_test.go index c817be5..74233b8 100644 --- a/rm/groovyRepository_test.go +++ b/rm/groovyRepository_test.go @@ -1,16 +1,12 @@ package nexusrm -import ( -// "testing" -) - /* func TestCreateFileBlobStore(t *testing.T) { t.Skip("Needs new framework") rm, mock := repositoriesTestRM(t) defer mock.Close() - err := CreateFileBlobStore(rm, "testname", "testpath") + err := CreateFileBlobStoreContext(rm, "testname", "testpath") if err != nil { t.Error(err) } diff --git a/rm/maintenance.go b/rm/maintenance.go index 105ee94..0f751c7 100644 --- a/rm/maintenance.go +++ b/rm/maintenance.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "net/http" @@ -22,8 +23,7 @@ type DatabaseState struct { IndexErrors int `json:"indexErrors"` } -// CheckDatabase returns the state of the named database -func CheckDatabase(rm RM, dbName string) (DatabaseState, error) { +func CheckDatabaseContext(ctx context.Context, rm RM, dbName string) (DatabaseState, error) { doError := func(err error) error { return fmt.Errorf("error checking status of database '%s': %v", dbName, err) } @@ -31,7 +31,7 @@ func CheckDatabase(rm RM, dbName string) (DatabaseState, error) { var state DatabaseState url := fmt.Sprintf(restMaintenanceDBCheck, dbName) - body, resp, err := rm.Put(url, nil) + body, resp, err := rm.Put(ctx, url, nil) if err != nil || resp.StatusCode != http.StatusOK { return state, doError(err) } @@ -43,8 +43,12 @@ func CheckDatabase(rm RM, dbName string) (DatabaseState, error) { return state, nil } -// CheckAllDatabases returns state on all of the databases -func CheckAllDatabases(rm RM) (states map[string]DatabaseState, err error) { +// CheckDatabase returns the state of the named database +func CheckDatabase(rm RM, dbName string) (DatabaseState, error) { + return CheckDatabaseContext(context.Background(), rm, dbName) +} + +func CheckAllDatabasesContext(ctx context.Context, rm RM) (states map[string]DatabaseState, err error) { states = make(map[string]DatabaseState) check := func(dbName string) { @@ -52,7 +56,7 @@ func CheckAllDatabases(rm RM) (states map[string]DatabaseState, err error) { return } - if state, er := CheckDatabase(rm, dbName); er != nil { + if state, er := CheckDatabaseContext(ctx, rm, dbName); er != nil { err = fmt.Errorf("error with '%s' database when all states: %v", dbName, er) } else { states[dbName] = state @@ -66,3 +70,8 @@ func CheckAllDatabases(rm RM) (states map[string]DatabaseState, err error) { return } + +// CheckAllDatabases returns state on all of the databases +func CheckAllDatabases(rm RM) (states map[string]DatabaseState, err error) { + return CheckAllDatabasesContext(context.Background(), rm) +} diff --git a/rm/maintenance_test.go b/rm/maintenance_test.go index c125193..9420de9 100644 --- a/rm/maintenance_test.go +++ b/rm/maintenance_test.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "net/http" @@ -50,7 +51,7 @@ func TestCheckDatabase(t *testing.T) { db := ComponentDB - state, err := CheckDatabase(rm, db) + state, err := CheckDatabaseContext(context.Background(), rm, db) if err != nil { panic(err) } @@ -68,7 +69,7 @@ func TestCheckAllDatabases(t *testing.T) { rm, mock := maintenanceTestRM(t) defer mock.Close() - states, err := CheckAllDatabases(rm) + states, err := CheckAllDatabasesContext(context.Background(), rm) if err != nil { panic(err) } diff --git a/rm/readOnly.go b/rm/readOnly.go index 103d057..49035c3 100644 --- a/rm/readOnly.go +++ b/rm/readOnly.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -39,9 +40,8 @@ func (s ReadOnlyState) String() string { return buf.String() } -// GetReadOnlyState returns the read-only state of the RM instance -func GetReadOnlyState(rm RM) (state ReadOnlyState, err error) { - body, resp, err := rm.Get(restReadOnly) +func GetReadOnlyStateContext(ctx context.Context, rm RM) (state ReadOnlyState, err error) { + body, resp, err := rm.Get(ctx, restReadOnly) if err != nil { return state, fmt.Errorf("could not get read-only state: %v", err) } @@ -55,9 +55,13 @@ func GetReadOnlyState(rm RM) (state ReadOnlyState, err error) { return } -// ReadOnlyEnable enables read-only mode for the RM instance -func ReadOnlyEnable(rm RM) (state ReadOnlyState, err error) { - body, resp, err := rm.Post(restReadOnlyFreeze, nil) +// GetReadOnlyState returns the read-only state of the RM instance +func GetReadOnlyState(rm RM) (state ReadOnlyState, err error) { + return GetReadOnlyStateContext(context.Background(), rm) +} + +func ReadOnlyEnableContext(ctx context.Context, rm RM) (state ReadOnlyState, err error) { + body, resp, err := rm.Post(ctx, restReadOnlyFreeze, nil) if err != nil && resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusNotFound { return } @@ -67,14 +71,18 @@ func ReadOnlyEnable(rm RM) (state ReadOnlyState, err error) { return } -// ReadOnlyRelease disables read-only mode for the RM instance -func ReadOnlyRelease(rm RM, force bool) (state ReadOnlyState, err error) { +// ReadOnlyEnable enables read-only mode for the RM instance +func ReadOnlyEnable(rm RM) (state ReadOnlyState, err error) { + return ReadOnlyEnableContext(context.Background(), rm) +} + +func ReadOnlyReleaseContext(ctx context.Context, rm RM, force bool) (state ReadOnlyState, err error) { endpoint := restReadOnlyRelease if force { endpoint = restReadOnlyForceRelease } - body, resp, err := rm.Post(endpoint, nil) + body, resp, err := rm.Post(ctx, endpoint, nil) if err != nil && resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusNotFound { return } @@ -83,3 +91,8 @@ func ReadOnlyRelease(rm RM, force bool) (state ReadOnlyState, err error) { return } + +// ReadOnlyRelease disables read-only mode for the RM instance +func ReadOnlyRelease(rm RM, force bool) (state ReadOnlyState, err error) { + return ReadOnlyReleaseContext(context.Background(), rm, force) +} diff --git a/rm/repositories.go b/rm/repositories.go index ff0678f..353ff45 100644 --- a/rm/repositories.go +++ b/rm/repositories.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -161,7 +162,7 @@ type RepositoryRawHosted struct { Raw AttributesRaw `json:"raw"` } -func CreateRepositoryHosted(rm RM, format repositoryFormat, r interface{}) error { +func CreateRepositoryHostedContext(ctx context.Context, rm RM, format repositoryFormat, r interface{}) error { buf, err := json.Marshal(r) if err != nil { return fmt.Errorf("could not marshal: %v", err) @@ -196,7 +197,7 @@ func CreateRepositoryHosted(rm RM, format repositoryFormat, r interface{}) error case Yum: restEndpointRepository = restRepositoriesHostedYum } - _, resp, err := rm.Post(restEndpointRepository, bytes.NewBuffer(buf)) + _, resp, err := rm.Post(ctx, restEndpointRepository, bytes.NewBuffer(buf)) if err != nil && resp == nil { return fmt.Errorf("could not create repository: %v", err) } @@ -204,7 +205,11 @@ func CreateRepositoryHosted(rm RM, format repositoryFormat, r interface{}) error return nil } -func CreateRepositoryProxy(rm RM, format repositoryFormat, r interface{}) error { +func CreateRepositoryHosted(rm RM, format repositoryFormat, r interface{}) error { + return CreateRepositoryHostedContext(context.Background(), rm, format, r) +} + +func CreateRepositoryProxyContext(ctx context.Context, rm RM, format repositoryFormat, r interface{}) error { buf, err := json.Marshal(r) if err != nil { return fmt.Errorf("could not marshal: %v", err) @@ -247,7 +252,7 @@ func CreateRepositoryProxy(rm RM, format repositoryFormat, r interface{}) error case Yum: restEndpointRepository = restRepositoriesProxyYum } - _, resp, err := rm.Post(restEndpointRepository, bytes.NewBuffer(buf)) + _, resp, err := rm.Post(ctx, restEndpointRepository, bytes.NewBuffer(buf)) if err != nil && resp == nil { return fmt.Errorf("could not create repository: %v", err) } @@ -255,23 +260,30 @@ func CreateRepositoryProxy(rm RM, format repositoryFormat, r interface{}) error return nil } -func DeleteRepositoryByName(rm RM, name string) error { +func CreateRepositoryProxy(rm RM, format repositoryFormat, r interface{}) error { + return CreateRepositoryProxyContext(context.Background(), rm, format, r) +} + +func DeleteRepositoryByNameContext(ctx context.Context, rm RM, name string) error { url := fmt.Sprintf("%s/%s", restRepositories, name) - if resp, err := rm.Del(url); err != nil && resp.StatusCode != http.StatusNoContent { + if resp, err := rm.Del(ctx, url); err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("repository not deleted '%s': %v", name, err) } return nil } -// GetRepositories returns a list of components in the indicated repository -func GetRepositories(rm RM) ([]Repository, error) { +func DeleteRepositoryByName(rm RM, name string) error { + return DeleteRepositoryByNameContext(context.Background(), rm, name) +} + +func GetRepositoriesContext(ctx context.Context, rm RM) ([]Repository, error) { doError := func(err error) error { return fmt.Errorf("could not find repositories: %v", err) } - body, resp, err := rm.Get(restRepositories) + body, resp, err := rm.Get(ctx, restRepositories) if err != nil || resp.StatusCode != http.StatusOK { return nil, doError(err) } @@ -284,9 +296,13 @@ func GetRepositories(rm RM) ([]Repository, error) { return repos, nil } -// GetRepositoryByName returns information on a named repository -func GetRepositoryByName(rm RM, name string) (repo Repository, err error) { - repos, err := GetRepositories(rm) +// GetRepositories returns a list of components in the indicated repository +func GetRepositories(rm RM) ([]Repository, error) { + return GetRepositoriesContext(context.Background(), rm) +} + +func GetRepositoryByNameContext(ctx context.Context, rm RM, name string) (repo Repository, err error) { + repos, err := GetRepositoriesContext(ctx, rm) if err != nil { return repo, fmt.Errorf("could not get list of repositories: %v", err) } @@ -299,3 +315,8 @@ func GetRepositoryByName(rm RM, name string) (repo Repository, err error) { return repo, fmt.Errorf("did not find repository '%s': %v", name, err) } + +// GetRepositoryByName returns information on a named repository +func GetRepositoryByName(rm RM, name string) (repo Repository, err error) { + return GetRepositoryByNameContext(context.Background(), rm, name) +} diff --git a/rm/repositories_test.go b/rm/repositories_test.go index ad97184..0f1dea0 100644 --- a/rm/repositories_test.go +++ b/rm/repositories_test.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "net/http" @@ -38,7 +39,7 @@ func TestGetRepositories(t *testing.T) { rm, mock := repositoriesTestRM(t) defer mock.Close() - repos, err := GetRepositories(rm) + repos, err := GetRepositoriesContext(context.Background(), rm) if err != nil { t.Error(err) } @@ -57,7 +58,7 @@ func TestGetRepositoryByName(t *testing.T) { dummyRepoIdx := 0 - repo, err := GetRepositoryByName(rm, dummyRepos[dummyRepoIdx].Name) + repo, err := GetRepositoryByNameContext(context.Background(), rm, dummyRepos[dummyRepoIdx].Name) if err != nil { t.Error(err) } diff --git a/rm/roles.go b/rm/roles.go index 47b8dd6..8ecc650 100644 --- a/rm/roles.go +++ b/rm/roles.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -17,13 +18,13 @@ type Role struct { Roles []string `json:"roles"` } -func CreateRole(rm RM, role Role) error { +func CreateRoleContext(ctx context.Context, rm RM, role Role) error { json, err := json.Marshal(role) if err != nil { return err } - _, resp, err := rm.Post(restRole, bytes.NewBuffer(json)) + _, resp, err := rm.Post(ctx, restRole, bytes.NewBuffer(json)) if err != nil && resp.StatusCode != http.StatusNoContent { return err } @@ -31,12 +32,20 @@ func CreateRole(rm RM, role Role) error { return nil } -func DeleteRoleById(rm RM, id string) error { +func CreateRole(rm RM, role Role) error { + return CreateRoleContext(context.Background(), rm, role) +} + +func DeleteRoleByIdContext(ctx context.Context, rm RM, id string) error { url := fmt.Sprintf("%s/%s", restRole, id) - if resp, err := rm.Del(url); err != nil && resp.StatusCode != http.StatusNoContent { + if resp, err := rm.Del(ctx, url); err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("role not deleted '%s': %v", id, err) } return nil } + +func DeleteRoleById(rm RM, id string) error { + return DeleteRoleByIdContext(context.Background(), rm, id) +} diff --git a/rm/scripts.go b/rm/scripts.go index b685176..a041c95 100644 --- a/rm/scripts.go +++ b/rm/scripts.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -24,13 +25,12 @@ type runResponse struct { Result string `json:"result"` } -// ScriptList lists all of the uploaded scripts in Repository Manager -func ScriptList(rm RM) ([]Script, error) { +func ScriptListContext(ctx context.Context, rm RM) ([]Script, error) { doError := func(err error) error { return fmt.Errorf("could not list scripts: %v", err) } - body, _, err := rm.Get(restScript) + body, _, err := rm.Get(ctx, restScript) if err != nil { return nil, doError(err) } @@ -43,8 +43,12 @@ func ScriptList(rm RM) ([]Script, error) { return scripts, nil } -// ScriptGet returns the named script -func ScriptGet(rm RM, name string) (Script, error) { +// ScriptList lists all of the uploaded scripts in Repository Manager +func ScriptList(rm RM) ([]Script, error) { + return ScriptListContext(context.Background(), rm) +} + +func ScriptGetContext(ctx context.Context, rm RM, name string) (Script, error) { doError := func(err error) error { return fmt.Errorf("could not find script '%s': %v", name, err) } @@ -52,7 +56,7 @@ func ScriptGet(rm RM, name string) (Script, error) { var script Script endpoint := fmt.Sprintf("%s/%s", restScript, name) - body, _, err := rm.Get(endpoint) + body, _, err := rm.Get(ctx, endpoint) if err != nil { return script, doError(err) } @@ -64,8 +68,12 @@ func ScriptGet(rm RM, name string) (Script, error) { return script, nil } -// ScriptUpload uploads the given Script to Repository Manager -func ScriptUpload(rm RM, script Script) error { +// ScriptGet returns the named script +func ScriptGet(rm RM, name string) (Script, error) { + return ScriptGetContext(context.Background(), rm, name) +} + +func ScriptUploadContext(ctx context.Context, rm RM, script Script) error { doError := func(err error) error { return fmt.Errorf("could not upload script '%s': %v", script.Name, err) } @@ -75,7 +83,7 @@ func ScriptUpload(rm RM, script Script) error { return doError(err) } - _, resp, err := rm.Post(restScript, bytes.NewBuffer(json)) + _, resp, err := rm.Post(ctx, restScript, bytes.NewBuffer(json)) if err != nil && resp.StatusCode != http.StatusNoContent { return doError(err) } @@ -83,8 +91,12 @@ func ScriptUpload(rm RM, script Script) error { return nil } -// ScriptUpdate update the contents of the given script -func ScriptUpdate(rm RM, script Script) error { +// ScriptUpload uploads the given Script to Repository Manager +func ScriptUpload(rm RM, script Script) error { + return ScriptUploadContext(context.Background(), rm, script) +} + +func ScriptUpdateContext(ctx context.Context, rm RM, script Script) error { doError := func(err error) error { return fmt.Errorf("could not update script '%s': %v", script.Name, err) } @@ -95,7 +107,7 @@ func ScriptUpdate(rm RM, script Script) error { } endpoint := fmt.Sprintf("%s/%s", restScript, script.Name) - _, resp, err := rm.Put(endpoint, bytes.NewBuffer(json)) + _, resp, err := rm.Put(ctx, endpoint, bytes.NewBuffer(json)) if err != nil && resp.StatusCode != http.StatusNoContent { return doError(err) } @@ -103,14 +115,18 @@ func ScriptUpdate(rm RM, script Script) error { return nil } -// ScriptRun executes the named Script -func ScriptRun(rm RM, name string, arguments []byte) (string, error) { +// ScriptUpdate update the contents of the given script +func ScriptUpdate(rm RM, script Script) error { + return ScriptUpdateContext(context.Background(), rm, script) +} + +func ScriptRunContext(ctx context.Context, rm RM, name string, arguments []byte) (string, error) { doError := func(err error) error { return fmt.Errorf("could not run script '%s': %v", name, err) } endpoint := fmt.Sprintf(restScriptRun, name) - body, _, err := rm.Post(endpoint, bytes.NewBuffer(arguments)) // TODO: Better response handling + body, _, err := rm.Post(ctx, endpoint, bytes.NewBuffer(arguments)) // TODO: Better response handling if err != nil { return "", doError(err) } @@ -123,22 +139,35 @@ func ScriptRun(rm RM, name string, arguments []byte) (string, error) { return resp.Result, nil } -// ScriptRunOnce takes the given Script, uploads it, executes it, and deletes it -func ScriptRunOnce(rm RM, script Script, arguments []byte) (string, error) { - if err := ScriptUpload(rm, script); err != nil { +// ScriptRun executes the named Script +func ScriptRun(rm RM, name string, arguments []byte) (string, error) { + return ScriptRunContext(context.Background(), rm, name, arguments) +} + +func ScriptRunOnceContext(ctx context.Context, rm RM, script Script, arguments []byte) (string, error) { + if err := ScriptUploadContext(ctx, rm, script); err != nil { return "", err } - defer ScriptDelete(rm, script.Name) + defer ScriptDeleteContext(ctx, rm, script.Name) - return ScriptRun(rm, script.Name, arguments) + return ScriptRunContext(ctx, rm, script.Name, arguments) } -// ScriptDelete removes the name, uploaded script -func ScriptDelete(rm RM, name string) error { +// ScriptRunOnce takes the given Script, uploads it, executes it, and deletes it +func ScriptRunOnce(rm RM, script Script, arguments []byte) (string, error) { + return ScriptRunOnceContext(context.Background(), rm, script, arguments) +} + +func ScriptDeleteContext(ctx context.Context, rm RM, name string) error { endpoint := fmt.Sprintf("%s/%s", restScript, name) - resp, err := rm.Del(endpoint) + resp, err := rm.Del(ctx, endpoint) if err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("could not delete '%s': %v", name, err) } return nil } + +// ScriptDelete removes the name, uploaded script +func ScriptDelete(rm RM, name string) error { + return ScriptDeleteContext(context.Background(), rm, name) +} diff --git a/rm/scripts_test.go b/rm/scripts_test.go index f02ffea..896ba89 100644 --- a/rm/scripts_test.go +++ b/rm/scripts_test.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -128,7 +129,7 @@ func TestScriptList(t *testing.T) { rm, mock := scriptsTestRM(t) defer mock.Close() - scripts, err := ScriptList(rm) + scripts, err := ScriptListContext(context.Background(), rm) if err != nil { t.Error(err) } @@ -151,7 +152,7 @@ func TestScriptGet(t *testing.T) { dummyScriptsIdx := 1 - script, err := ScriptGet(rm, dummyScripts[dummyScriptsIdx].Name) + script, err := ScriptGetContext(context.Background(), rm, dummyScripts[dummyScriptsIdx].Name) if err != nil { t.Error(err) } @@ -168,11 +169,11 @@ func TestScriptUpload(t *testing.T) { newScript := Script{Name: "newScript", Content: "log.info('I am new!')", Type: "groovy"} - if err := ScriptUpload(rm, newScript); err != nil { + if err := ScriptUploadContext(context.Background(), rm, newScript); err != nil { t.Error(err) } - script, err := ScriptGet(rm, newScript.Name) + script, err := ScriptGetContext(context.Background(), rm, newScript.Name) if err != nil { t.Error(err) } @@ -197,11 +198,11 @@ func TestScriptUpdate(t *testing.T) { t.Fatal("I am an idiot") } - if err := ScriptUpdate(rm, updatedScript); err != nil { + if err := ScriptUpdateContext(context.Background(), rm, updatedScript); err != nil { t.Error(err) } - script, err := ScriptGet(rm, updatedScript.Name) + script, err := ScriptGetContext(context.Background(), rm, updatedScript.Name) if err != nil { t.Error(err) } @@ -218,15 +219,15 @@ func TestScriptDelete(t *testing.T) { deleteMe := Script{Name: "deleteMe", Content: "log.info('Existence is pain!')", Type: "groovy"} - if err := ScriptUpload(rm, deleteMe); err != nil { + if err := ScriptUploadContext(context.Background(), rm, deleteMe); err != nil { t.Error(err) } - if err := ScriptDelete(rm, deleteMe.Name); err != nil { + if err := ScriptDeleteContext(context.Background(), rm, deleteMe.Name); err != nil { t.Error(err) } - if _, err := ScriptGet(rm, deleteMe.Name); err == nil { + if _, err := ScriptGetContext(context.Background(), rm, deleteMe.Name); err == nil { t.Error("Found script which should have been deleted") } } @@ -238,11 +239,11 @@ func TestScriptRun(t *testing.T) { script := Script{Name: "scriptArgsTest", Content: "return args", Type: "groovy"} input := "this is a test" - if err := ScriptUpload(rm, script); err != nil { + if err := ScriptUploadContext(context.Background(), rm, script); err != nil { t.Error(err) } - ret, err := ScriptRun(rm, script.Name, []byte(input)) + ret, err := ScriptRunContext(context.Background(), rm, script.Name, []byte(input)) if err != nil { t.Error(err) } @@ -251,7 +252,7 @@ func TestScriptRun(t *testing.T) { t.Errorf("Did not get expected script output: %s\n", ret) } - if err = ScriptDelete(rm, script.Name); err != nil { + if err = ScriptDeleteContext(context.Background(), rm, script.Name); err != nil { t.Error(err) } } @@ -263,7 +264,7 @@ func TestScriptRunOnce(t *testing.T) { script := Script{Name: "scriptArgsTest", Content: "return args", Type: "groovy"} input := "this is a test" - ret, err := ScriptRunOnce(rm, script, []byte(input)) + ret, err := ScriptRunOnceContext(context.Background(), rm, script, []byte(input)) if err != nil { t.Error(err) } @@ -272,7 +273,7 @@ func TestScriptRunOnce(t *testing.T) { t.Errorf("Did not get expected script output: %s\n", ret) } - if _, err = ScriptGet(rm, script.Name); err == nil { + if _, err = ScriptGetContext(context.Background(), rm, script.Name); err == nil { t.Error("Found script which should have been deleted") } } diff --git a/rm/search.go b/rm/search.go index c8b7c94..da67166 100644 --- a/rm/search.go +++ b/rm/search.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -104,7 +105,7 @@ func NewSearchQueryBuilder() *SearchQueryBuilder { return b } -func search(rm RM, endpoint string, queryBuilder nexus.SearchQueryBuilder, responseHandler func([]byte) (string, error)) error { +func search(ctx context.Context, rm RM, endpoint string, queryBuilder nexus.SearchQueryBuilder, responseHandler func([]byte) (string, error)) error { continuation := "" queryEndpoint := fmt.Sprintf("%s?%s", endpoint, queryBuilder.Build()) @@ -115,7 +116,7 @@ func search(rm RM, endpoint string, queryBuilder nexus.SearchQueryBuilder, respo url += "&continuationToken=" + continuation } - body, resp, err := rm.Get(url) + body, resp, err := rm.Get(ctx, url) if err != nil || resp.StatusCode != http.StatusOK { return } @@ -142,11 +143,10 @@ func search(rm RM, endpoint string, queryBuilder nexus.SearchQueryBuilder, respo return nil } -// SearchComponents allows searching the indicated RM instance for specific components -func SearchComponents(rm RM, query nexus.SearchQueryBuilder) ([]RepositoryItem, error) { +func SearchComponentsContext(ctx context.Context, rm RM, query nexus.SearchQueryBuilder) ([]RepositoryItem, error) { items := make([]RepositoryItem, 0) - err := search(rm, restSearchComponents, query, func(body []byte) (string, error) { + err := search(ctx, rm, restSearchComponents, query, func(body []byte) (string, error) { var resp searchComponentsResponse if er := json.Unmarshal(body, &resp); er != nil { return "", er @@ -160,11 +160,15 @@ func SearchComponents(rm RM, query nexus.SearchQueryBuilder) ([]RepositoryItem, return items, err } -// SearchAssets allows searching the indicated RM instance for specific assets -func SearchAssets(rm RM, query nexus.SearchQueryBuilder) ([]RepositoryItemAsset, error) { +// SearchComponents allows searching the indicated RM instance for specific components +func SearchComponents(rm RM, query nexus.SearchQueryBuilder) ([]RepositoryItem, error) { + return SearchComponentsContext(context.Background(), rm, query) +} + +func SearchAssetsContext(ctx context.Context, rm RM, query nexus.SearchQueryBuilder) ([]RepositoryItemAsset, error) { items := make([]RepositoryItemAsset, 0) - err := search(rm, restSearchAssets, query, func(body []byte) (string, error) { + err := search(ctx, rm, restSearchAssets, query, func(body []byte) (string, error) { var resp searchAssetsResponse if er := json.Unmarshal(body, &resp); er != nil { return "", er @@ -177,3 +181,8 @@ func SearchAssets(rm RM, query nexus.SearchQueryBuilder) ([]RepositoryItemAsset, return items, err } + +// SearchAssets allows searching the indicated RM instance for specific assets +func SearchAssets(rm RM, query nexus.SearchQueryBuilder) ([]RepositoryItemAsset, error) { + return SearchAssetsContext(context.Background(), rm, query) +} diff --git a/rm/search_test.go b/rm/search_test.go index a2cd3de..9613827 100644 --- a/rm/search_test.go +++ b/rm/search_test.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "net/http" @@ -78,7 +79,7 @@ func TestSearchComponents(t *testing.T) { repo := "repo-maven" query := NewSearchQueryBuilder().Repository(repo) - components, err := SearchComponents(rm, query) + components, err := SearchComponentsContext(context.Background(), rm, query) if err != nil { t.Fatalf("Did not complete search: %v", err) } @@ -103,7 +104,7 @@ func TestSearchAssets(t *testing.T) { repo := "repo-maven" query := NewSearchQueryBuilder().Repository(repo) - assets, err := SearchAssets(rm, query) + assets, err := SearchAssetsContext(context.Background(), rm, query) if err != nil { t.Error(err) } @@ -128,7 +129,7 @@ func ExampleSearchComponents() { } query := NewSearchQueryBuilder().Repository("maven-releases") - components, err := SearchComponents(rm, query) + components, err := SearchComponentsContext(context.Background(), rm, query) if err != nil { panic(err) } diff --git a/rm/staging.go b/rm/staging.go index 50eef0c..a1c237c 100644 --- a/rm/staging.go +++ b/rm/staging.go @@ -1,6 +1,9 @@ package nexusrm -import "fmt" +import ( + "context" + "fmt" +) // service/rest/v1/staging/move/{repository} const ( @@ -38,19 +41,27 @@ type componentsDeleted struct { Version string `json:"version"` } -// StagingMove promotes components which match a set of criteria -func StagingMove(rm RM, query QueryBuilder) error { +func StagingMoveContext(ctx context.Context, rm RM, query QueryBuilder) error { endpoint := fmt.Sprintf("%s?%s", restStaging, query.Build()) // TODO: handle response - _, _, err := rm.Post(endpoint, nil) + _, _, err := rm.Post(ctx, endpoint, nil) return err } -// StagingDelete removes components which have been staged -func StagingDelete(rm RM, query QueryBuilder) error { +// StagingMove promotes components which match a set of criteria +func StagingMove(rm RM, query QueryBuilder) error { + return StagingMoveContext(context.Background(), rm, query) +} + +func StagingDeleteContext(ctx context.Context, rm RM, query QueryBuilder) error { endpoint := fmt.Sprintf("%s?%s", restStaging, query.Build()) - _, err := rm.Del(endpoint) + _, err := rm.Del(ctx, endpoint) return err } + +// StagingDelete removes components which have been staged +func StagingDelete(rm RM, query QueryBuilder) error { + return StagingDeleteContext(context.Background(), rm, query) +} diff --git a/rm/status.go b/rm/status.go index 6db09c3..c8a7306 100644 --- a/rm/status.go +++ b/rm/status.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "net/http" ) @@ -9,14 +10,22 @@ const ( restStatusWritable = "service/rest/v1/status/writable" ) +func StatusReadableContext(ctx context.Context, rm RM) (_ bool) { + _, resp, err := rm.Get(ctx, restStatusReadable) + return err == nil && resp.StatusCode == http.StatusOK +} + // StatusReadable returns true if the RM instance can serve read requests func StatusReadable(rm RM) (_ bool) { - _, resp, err := rm.Get(restStatusReadable) + return StatusReadableContext(context.Background(), rm) +} + +func StatusWritableContext(ctx context.Context, rm RM) (_ bool) { + _, resp, err := rm.Get(ctx, restStatusWritable) return err == nil && resp.StatusCode == http.StatusOK } // StatusWritable returns true if the RM instance can serve read requests func StatusWritable(rm RM) (_ bool) { - _, resp, err := rm.Get(restStatusWritable) - return err == nil && resp.StatusCode == http.StatusOK + return StatusWritableContext(context.Background(), rm) } diff --git a/rm/support.go b/rm/support.go index 88e6b3e..355a977 100644 --- a/rm/support.go +++ b/rm/support.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" "mime" @@ -41,14 +42,13 @@ func NewSupportZipOptions() (o SupportZipOptions) { return } -// GetSupportZip generates a support zip with the given options -func GetSupportZip(rm RM, options SupportZipOptions) ([]byte, string, error) { +func GetSupportZipContext(ctx context.Context, rm RM, options SupportZipOptions) ([]byte, string, error) { request, err := json.Marshal(options) if err != nil { return nil, "", fmt.Errorf("error retrieving support zip: %v", err) } - body, resp, err := rm.Post(restSupportZip, bytes.NewBuffer(request)) + body, resp, err := rm.Post(ctx, restSupportZip, bytes.NewBuffer(request)) if err != nil { return nil, "", fmt.Errorf("error retrieving support zip: %v", err) } @@ -63,3 +63,8 @@ func GetSupportZip(rm RM, options SupportZipOptions) ([]byte, string, error) { return body, params["filename"], nil } + +// GetSupportZip generates a support zip with the given options +func GetSupportZip(rm RM, options SupportZipOptions) ([]byte, string, error) { + return GetSupportZipContext(context.Background(), rm, options) +} diff --git a/rm/tagging.go b/rm/tagging.go index dc53353..65f4945 100644 --- a/rm/tagging.go +++ b/rm/tagging.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" ) @@ -35,8 +36,7 @@ type componentsAssociated struct { Version string `json:"version"` } -// TagsList returns a list of tags in the given RM instance -func TagsList(rm RM) ([]Tag, error) { +func TagsListContext(ctx context.Context, rm RM) ([]Tag, error) { continuation := "" tags := make([]Tag, 0) @@ -47,7 +47,7 @@ func TagsList(rm RM) ([]Tag, error) { url += "&continuationToken=" + continuation } - body, _, err := rm.Get(url) + body, _, err := rm.Get(ctx, url) if err != nil { return fmt.Errorf("could not get list of tags: %v", err) } @@ -76,8 +76,12 @@ func TagsList(rm RM) ([]Tag, error) { return tags, nil } -// AddTag adds a tag to the given instance -func AddTag(rm RM, tagName string, attributes map[string]string) (Tag, error) { +// TagsList returns a list of tags in the given RM instance +func TagsList(rm RM) ([]Tag, error) { + return TagsListContext(context.Background(), rm) +} + +func AddTagContext(ctx context.Context, rm RM, tagName string, attributes map[string]string) (Tag, error) { tag := Tag{Name: tagName} //TODO: attributes @@ -86,7 +90,7 @@ func AddTag(rm RM, tagName string, attributes map[string]string) (Tag, error) { return Tag{}, fmt.Errorf("could not marshal tag: %v", err) } - body, _, err := rm.Post(restTagging, bytes.NewBuffer(buf)) + body, _, err := rm.Post(ctx, restTagging, bytes.NewBuffer(buf)) if err != nil { return Tag{}, fmt.Errorf("could not create tag %s: %v", tagName, err) } @@ -99,11 +103,15 @@ func AddTag(rm RM, tagName string, attributes map[string]string) (Tag, error) { return createdTag, nil } -// GetTag retrieve the named tag -func GetTag(rm RM, tagName string) (Tag, error) { +// AddTag adds a tag to the given instance +func AddTag(rm RM, tagName string, attributes map[string]string) (Tag, error) { + return AddTagContext(context.Background(), rm, tagName, attributes) +} + +func GetTagContext(ctx context.Context, rm RM, tagName string) (Tag, error) { endpoint := fmt.Sprintf("%s/%s", restTagging, tagName) - body, _, err := rm.Get(endpoint) + body, _, err := rm.Get(ctx, endpoint) if err != nil { return Tag{}, fmt.Errorf("could not find tag %s: %v", tagName, err) } @@ -116,19 +124,32 @@ func GetTag(rm RM, tagName string) (Tag, error) { return tag, nil } -// AssociateTag associates a tag to any component which matches the search criteria -func AssociateTag(rm RM, query QueryBuilder) error { +// GetTag retrieve the named tag +func GetTag(rm RM, tagName string) (Tag, error) { + return GetTagContext(context.Background(), rm, tagName) +} + +func AssociateTagContext(ctx context.Context, rm RM, query QueryBuilder) error { endpoint := fmt.Sprintf("%s?%s", restTagging, query.Build()) // TODO: handle response - _, _, err := rm.Post(endpoint, nil) + _, _, err := rm.Post(ctx, endpoint, nil) return err } -// DisassociateTag associates a tag to any component which matches the search criteria -func DisassociateTag(rm RM, query QueryBuilder) error { +// AssociateTag associates a tag to any component which matches the search criteria +func AssociateTag(rm RM, query QueryBuilder) error { + return AssociateTagContext(context.Background(), rm, query) +} + +func DisassociateTagContext(ctx context.Context, rm RM, query QueryBuilder) error { endpoint := fmt.Sprintf("%s?%s", restTagging, query.Build()) - _, err := rm.Del(endpoint) + _, err := rm.Del(ctx, endpoint) return err } + +// DisassociateTag associates a tag to any component which matches the search criteria +func DisassociateTag(rm RM, query QueryBuilder) error { + return DisassociateTagContext(context.Background(), rm, query) +} diff --git a/rm/tagging_test.go b/rm/tagging_test.go index 6c3fd47..40d0ce1 100644 --- a/rm/tagging_test.go +++ b/rm/tagging_test.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -99,7 +100,7 @@ func TestTagsList(t *testing.T) { rm, mock := taggingTestRM(t) defer mock.Close() - tags, err := TagsList(rm) + tags, err := TagsListContext(context.Background(), rm) if err != nil { t.Error(err) } @@ -121,7 +122,7 @@ func TestGetTag(t *testing.T) { want := dummyTags[0] - got, err := GetTag(rm, want.Name) + got, err := GetTagContext(context.Background(), rm, want.Name) if err != nil { t.Error(err) } @@ -139,7 +140,7 @@ func TestAddTag(t *testing.T) { newName := "newTestTag" - got, err := AddTag(rm, newName, nil) + got, err := AddTagContext(context.Background(), rm, newName, nil) if err != nil { t.Error(err) } @@ -148,7 +149,7 @@ func TestAddTag(t *testing.T) { t.Error("Did not get tag with expected name") } - gotAgain, err := GetTag(rm, newName) + gotAgain, err := GetTagContext(context.Background(), rm, newName) if err != nil { t.Error(err) }