diff --git a/cmd/cluster/changevolumetype.go b/cmd/cluster/changevolumetype.go new file mode 100644 index 000000000..c506b94c5 --- /dev/null +++ b/cmd/cluster/changevolumetype.go @@ -0,0 +1,480 @@ +package cluster + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "strings" + "time" + + cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1" + machinev1 "github.com/openshift/api/machine/v1" + machinev1beta1 "github.com/openshift/api/machine/v1beta1" + hivev1 "github.com/openshift/hive/apis/hive/v1" + "github.com/openshift/osdctl/cmd/servicelog" + infraPkg "github.com/openshift/osdctl/pkg/infra" + "github.com/openshift/osdctl/pkg/k8s" + "github.com/openshift/osdctl/pkg/printer" + "github.com/openshift/osdctl/pkg/utils" + "github.com/spf13/cobra" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/wait" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +const ( + changeVolumeTypeCPMSNamespace = "openshift-machine-api" + changeVolumeTypeCPMSName = "cluster" + + pollInterval = 30 * time.Second + rolloutPollTimeout = 45 * time.Minute +) + +var validVolumeTypes = []string{"gp3"} + +type changeVolumeTypeOptions struct { + clusterID string + cluster *cmv1.Cluster + reason string + targetType string + role string // "control-plane", "infra", or "" (both) + + client client.Client + clientAdmin client.Client + + hiveClient client.Client + hiveAdminClient client.Client +} + +func newCmdChangeVolumeType() *cobra.Command { + ops := &changeVolumeTypeOptions{} + cmd := &cobra.Command{ + Use: "change-ebs-volume-type", + Short: "Change EBS volume type for control plane and/or infra nodes by replacing machines", + Long: `Change the EBS volume type for control plane and/or infra nodes on a ROSA/OSD cluster. + +This command replaces machines to change volume types (not in-place modification). +For control plane nodes, it patches the ControlPlaneMachineSet (CPMS) which automatically +rolls nodes one at a time. 
For infra nodes, it uses the Hive MachinePool dance to safely +replace all infra nodes with new ones using the target volume type. + +Pre-flight checks are performed automatically before making changes.`, + Example: ` # Change both control plane and infra volumes to gp3 + osdctl cluster change-ebs-volume-type -C ${CLUSTER_ID} --type gp3 --reason "SREP-3811" + + # Change only control plane volumes to gp3 + osdctl cluster change-ebs-volume-type -C ${CLUSTER_ID} --type gp3 --role control-plane --reason "SREP-3811" + + # Change only infra volumes to gp3 + osdctl cluster change-ebs-volume-type -C ${CLUSTER_ID} --type gp3 --role infra --reason "SREP-3811"`, + Args: cobra.NoArgs, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return ops.run(context.Background()) + }, + } + + cmd.Flags().StringVarP(&ops.clusterID, "cluster-id", "C", "", "The internal/external ID of the cluster") + cmd.Flags().StringVar(&ops.targetType, "type", "", "Target EBS volume type (gp3)") + cmd.Flags().StringVar(&ops.role, "role", "", "Node role to change: control-plane, infra (default: both)") + cmd.Flags().StringVar(&ops.reason, "reason", "", "Reason for elevation (OHSS/PD/JIRA ticket)") + + _ = cmd.MarkFlagRequired("cluster-id") + _ = cmd.MarkFlagRequired("type") + _ = cmd.MarkFlagRequired("reason") + + return cmd +} + +func (o *changeVolumeTypeOptions) validate() error { + if err := utils.IsValidClusterKey(o.clusterID); err != nil { + return err + } + + valid := false + for _, t := range validVolumeTypes { + if o.targetType == t { + valid = true + break + } + } + if !valid { + return fmt.Errorf("invalid volume type: %s (must be one of: %s)", o.targetType, strings.Join(validVolumeTypes, ", ")) + } + + if o.role != "" && o.role != "control-plane" && o.role != "infra" { + return fmt.Errorf("invalid role: %s (must be 'control-plane' or 'infra')", o.role) + } + + return nil +} + +func (o *changeVolumeTypeOptions) init() error { + connection, err := 
utils.CreateConnection() + if err != nil { + return err + } + defer connection.Close() + + cluster, err := utils.GetCluster(connection, o.clusterID) + if err != nil { + return err + } + o.cluster = cluster + o.clusterID = cluster.ID() + + if strings.ToLower(cluster.CloudProvider().ID()) != "aws" { + return fmt.Errorf("this command only supports AWS clusters (cluster is %s)", cluster.CloudProvider().ID()) + } + + if cluster.Hypershift().Enabled() { + return errors.New("this command does not support HCP clusters") + } + + scheme := runtime.NewScheme() + if err := machinev1.Install(scheme); err != nil { + return err + } + if err := machinev1beta1.Install(scheme); err != nil { + return err + } + if err := corev1.AddToScheme(scheme); err != nil { + return err + } + + c, err := k8s.New(o.clusterID, client.Options{Scheme: scheme}) + if err != nil { + return err + } + o.client = c + + cAdmin, err := k8s.NewAsBackplaneClusterAdmin(o.clusterID, client.Options{Scheme: scheme}, []string{ + o.reason, + fmt.Sprintf("Changing EBS volume type to %s for cluster %s", o.targetType, o.clusterID), + }...) + if err != nil { + return err + } + o.clientAdmin = cAdmin + + // Set up Hive clients for infra node replacement via machinepool dance + if o.role == "" || o.role == "infra" { + hiveScheme := runtime.NewScheme() + if err := hivev1.AddToScheme(hiveScheme); err != nil { + return err + } + if err := corev1.AddToScheme(hiveScheme); err != nil { + return err + } + + hive, err := utils.GetHiveCluster(o.clusterID) + if err != nil { + return fmt.Errorf("failed to get hive cluster: %v", err) + } + + hc, err := k8s.New(hive.ID(), client.Options{Scheme: hiveScheme}) + if err != nil { + return fmt.Errorf("failed to create hive client: %v", err) + } + o.hiveClient = hc + + hac, err := k8s.NewAsBackplaneClusterAdmin(hive.ID(), client.Options{Scheme: hiveScheme}, []string{ + o.reason, + fmt.Sprintf("Changing EBS volume type to %s for cluster %s", o.targetType, o.clusterID), + }...) 
+ if err != nil { + return fmt.Errorf("failed to create hive admin client: %v", err) + } + o.hiveAdminClient = hac + } + + return nil +} + +func (o *changeVolumeTypeOptions) run(ctx context.Context) error { + if err := o.validate(); err != nil { + return err + } + + if err := o.init(); err != nil { + return err + } + + fmt.Printf("Cluster: %s (%s)\n", o.cluster.Name(), o.clusterID) + fmt.Printf("Target volume type: %s\n", o.targetType) + fmt.Printf("Role: %s\n", roleDisplay(o.role)) + fmt.Printf("Reason: %s\n\n", o.reason) + + // Pre-flight checks + if err := o.preFlightChecks(ctx); err != nil { + return fmt.Errorf("pre-flight checks failed: %v", err) + } + + doControlPlane := o.role == "" || o.role == "control-plane" + doInfra := o.role == "" || o.role == "infra" + + // Control plane + if doControlPlane { + if err := o.changeControlPlaneVolumeType(ctx); err != nil { + return fmt.Errorf("control plane volume type change failed: %v", err) + } + } + + // Infra + if doInfra { + if err := o.changeInfraVolumeType(ctx); err != nil { + return fmt.Errorf("infra volume type change failed: %v", err) + } + } + + printer.PrintlnGreen("\nVolume type change completed successfully!") + return nil +} + +// preFlightChecks verifies cluster health before making changes. +func (o *changeVolumeTypeOptions) preFlightChecks(ctx context.Context) error { + fmt.Println("Running pre-flight checks...") + + // Check 1: CPMS state (if changing control plane) + if o.role == "" || o.role == "control-plane" { + cpms := &machinev1.ControlPlaneMachineSet{} + if err := o.client.Get(ctx, client.ObjectKey{Namespace: changeVolumeTypeCPMSNamespace, Name: changeVolumeTypeCPMSName}, cpms); err != nil { + return fmt.Errorf("failed to get CPMS: %v", err) + } + + if cpms.Spec.State != machinev1.ControlPlaneMachineSetStateActive { + return fmt.Errorf("CPMS is not Active (state: %s). 
Cannot proceed with control plane changes", cpms.Spec.State) + } + + if cpms.Status.ReadyReplicas != 3 { + return fmt.Errorf("CPMS does not have 3 ready replicas (ready: %d)", cpms.Status.ReadyReplicas) + } + fmt.Printf(" CPMS: Active, %d/3 ready\n", cpms.Status.ReadyReplicas) + } + + // Check 2: Master nodes ready + masterNodes := &corev1.NodeList{} + if err := o.client.List(ctx, masterNodes, client.MatchingLabels{"node-role.kubernetes.io/master": ""}); err != nil { + return fmt.Errorf("failed to list master nodes: %v", err) + } + readyMasters := countReadyNodes(masterNodes) + if readyMasters != 3 { + return fmt.Errorf("expected 3 ready master nodes, found %d", readyMasters) + } + fmt.Printf(" Master nodes: %d/3 Ready\n", readyMasters) + + // Check 3: Infra nodes ready (if changing infra) + if o.role == "" || o.role == "infra" { + infraNodes := &corev1.NodeList{} + if err := o.client.List(ctx, infraNodes, client.MatchingLabels{"node-role.kubernetes.io/infra": ""}); err != nil { + return fmt.Errorf("failed to list infra nodes: %v", err) + } + readyInfra := countReadyNodes(infraNodes) + totalInfra := len(infraNodes.Items) + if totalInfra == 0 { + return fmt.Errorf("no infra nodes found") + } + if readyInfra != totalInfra { + return fmt.Errorf("not all infra nodes are ready (%d/%d)", readyInfra, totalInfra) + } + fmt.Printf(" Infra nodes: %d/%d Ready\n", readyInfra, totalInfra) + } + + // Check 4: etcd pods running + etcdPods := &corev1.PodList{} + if err := o.client.List(ctx, etcdPods, client.InNamespace("openshift-etcd"), client.MatchingLabels{"app": "etcd"}); err != nil { + return fmt.Errorf("failed to list etcd pods: %v", err) + } + runningEtcd := 0 + for _, pod := range etcdPods.Items { + if pod.Status.Phase == corev1.PodRunning { + runningEtcd++ + } + } + if runningEtcd != 3 { + return fmt.Errorf("expected 3 running etcd pods, found %d", runningEtcd) + } + fmt.Printf(" etcd: %d/3 Running\n", runningEtcd) + + printer.PrintlnGreen(" All pre-flight checks 
passed!") + fmt.Println() + return nil +} + +// changeControlPlaneVolumeType patches the CPMS to trigger a rolling replacement. +func (o *changeVolumeTypeOptions) changeControlPlaneVolumeType(ctx context.Context) error { + printer.PrintlnGreen("=== Changing control plane volume type ===") + + cpms := &machinev1.ControlPlaneMachineSet{} + if err := o.client.Get(ctx, client.ObjectKey{Namespace: changeVolumeTypeCPMSNamespace, Name: changeVolumeTypeCPMSName}, cpms); err != nil { + return fmt.Errorf("failed to get CPMS: %v", err) + } + + // Unmarshal the provider spec to read current blockDevices + awsSpec := &machinev1beta1.AWSMachineProviderConfig{} + if err := json.Unmarshal(cpms.Spec.Template.OpenShiftMachineV1Beta1Machine.Spec.ProviderSpec.Value.Raw, awsSpec); err != nil { + return fmt.Errorf("failed to unmarshal CPMS provider spec: %v", err) + } + + if len(awsSpec.BlockDevices) == 0 || awsSpec.BlockDevices[0].EBS == nil { + return fmt.Errorf("CPMS has no EBS blockDevices configured") + } + + currentType := "" + if awsSpec.BlockDevices[0].EBS != nil && awsSpec.BlockDevices[0].EBS.VolumeType != nil { + currentType = *awsSpec.BlockDevices[0].EBS.VolumeType + } + + if currentType == o.targetType { + fmt.Printf("Control plane volumes are already %s - skipping\n", o.targetType) + return nil + } + + fmt.Printf("Current control plane volume type: %s\n", currentType) + fmt.Printf("Target volume type: %s\n", o.targetType) + + // Update volume type, preserving all other EBS settings (volumeSize, encrypted, kmsKey) + targetType := o.targetType + awsSpec.BlockDevices[0].EBS.VolumeType = &targetType + awsSpec.BlockDevices[0].EBS.Iops = nil + + // Confirm + fmt.Printf("\nThis will replace all 3 control plane nodes one at a time (~35-45 min).\n") + if !utils.ConfirmPrompt() { + return errors.New("aborted by user") + } + + // Marshal and patch + rawBytes, err := json.Marshal(awsSpec) + if err != nil { + return fmt.Errorf("failed to marshal updated provider spec: %v", err) + } + + patch := client.MergeFrom(cpms.DeepCopy()) 
+ cpms.Spec.Template.OpenShiftMachineV1Beta1Machine.Spec.ProviderSpec.Value = &runtime.RawExtension{Raw: rawBytes} + + if err := o.clientAdmin.Patch(ctx, cpms, patch); err != nil { + return fmt.Errorf("failed to patch CPMS: %v", err) + } + + printer.PrintlnGreen("CPMS patched successfully. Rolling replacement in progress...") + fmt.Println("Monitoring rollout (this will take ~35-45 minutes)...") + + // Monitor the rollout + if err := o.monitorCPMSRollout(ctx); err != nil { + return err + } + + printer.PrintlnGreen("Control plane volume type change complete!") + return nil +} + +// monitorCPMSRollout polls the CPMS until all replicas are updated. +func (o *changeVolumeTypeOptions) monitorCPMSRollout(ctx context.Context) error { + pollCtx, cancel := context.WithTimeout(ctx, rolloutPollTimeout) + defer cancel() + return wait.PollUntilContextTimeout(pollCtx, pollInterval, rolloutPollTimeout, true, func(ctx context.Context) (bool, error) { + cpms := &machinev1.ControlPlaneMachineSet{} + if err := o.client.Get(ctx, client.ObjectKey{Namespace: changeVolumeTypeCPMSNamespace, Name: changeVolumeTypeCPMSName}, cpms); err != nil { + log.Printf("Error checking CPMS status: %v", err) + return false, nil + } + + updated := cpms.Status.UpdatedReplicas + ready := cpms.Status.ReadyReplicas + + log.Printf("[%s] CPMS: %d/3 updated, %d ready", time.Now().Format("15:04:05"), updated, ready) + + if updated == 3 && ready >= 3 { + return true, nil + } + return false, nil + }) +} + +const ( + volumeTypeChangedServiceLogTemplate = "https://raw.githubusercontent.com/openshift/managed-notifications/master/osd/infranode_volume_type_changed.json" +) + +// changeInfraVolumeType uses the Hive MachinePool dance from pkg/infra +// to replace infra nodes with new ones using the target volume type. 
+func (o *changeVolumeTypeOptions) changeInfraVolumeType(ctx context.Context) error { + printer.PrintlnGreen("\n=== Changing infra node volume type ===") + + targetType := o.targetType + previousType := "" + + originalMp, err := infraPkg.GetInfraMachinePool(ctx, o.hiveClient, o.clusterID) + if err != nil { + return err + } + + newMp, err := infraPkg.CloneMachinePool(originalMp, func(mp *hivev1.MachinePool) error { + if mp.Spec.Platform.AWS == nil { + return fmt.Errorf("infra MachinePool has no AWS platform configuration") + } + previousType = mp.Spec.Platform.AWS.EC2RootVolume.Type + if previousType == targetType { + return fmt.Errorf("infra volumes are already %s", targetType) + } + fmt.Printf("Current infra volume type: %s\n", previousType) + fmt.Printf("Target volume type: %s\n", targetType) + mp.Spec.Platform.AWS.EC2RootVolume.Type = targetType + mp.Spec.Platform.AWS.EC2RootVolume.IOPS = 0 + return nil + }) + if err != nil { + return err + } + + clients := infraPkg.DanceClients{ + ClusterClient: o.client, + HiveClient: o.hiveClient, + HiveAdmin: o.hiveAdminClient, + } + + if err := infraPkg.RunMachinePoolDance(ctx, clients, originalMp, newMp, nil); err != nil { + return err + } + + // Post service log + postCmd := servicelog.PostCmdOptions{ + Template: volumeTypeChangedServiceLogTemplate, + ClusterId: o.clusterID, + TemplateParams: []string{ + fmt.Sprintf("PREVIOUS_VOLUME_TYPE=%s", previousType), + fmt.Sprintf("NEW_VOLUME_TYPE=%s", targetType), + fmt.Sprintf("REASON=%s", o.reason), + }, + } + if err := postCmd.Run(); err != nil { + fmt.Println("Failed to post service log. 
Please manually send a service log with:") + fmt.Printf("osdctl servicelog post %s -t %s -p %s\n", + o.clusterID, volumeTypeChangedServiceLogTemplate, strings.Join(postCmd.TemplateParams, " -p ")) + } + + printer.PrintlnGreen("Infra volume type change complete!") + return nil +} + +func countReadyNodes(nodes *corev1.NodeList) int { + ready := 0 + for _, node := range nodes.Items { + for _, cond := range node.Status.Conditions { + if cond.Type == corev1.NodeReady && cond.Status == corev1.ConditionTrue { + ready++ + } + } + } + return ready +} + +func roleDisplay(role string) string { + if role == "" { + return "control-plane + infra" + } + return role +} diff --git a/cmd/cluster/changevolumetype_test.go b/cmd/cluster/changevolumetype_test.go new file mode 100644 index 000000000..48d59e156 --- /dev/null +++ b/cmd/cluster/changevolumetype_test.go @@ -0,0 +1,120 @@ +package cluster + +import ( + "testing" + + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" +) + +func TestChangeVolumeType_ValidateTargetType(t *testing.T) { + tests := []struct { + name string + targetType string + wantErr bool + }{ + {"valid gp3", "gp3", false}, + {"invalid io1", "io1", true}, + {"invalid gp2", "gp2", true}, + {"invalid io2", "io2", true}, + {"invalid st1", "st1", true}, + {"invalid sc1", "sc1", true}, + {"invalid type", "invalid", true}, + {"empty type", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ops := &changeVolumeTypeOptions{ + clusterID: "test-cluster", + targetType: tt.targetType, + reason: "test", + } + err := ops.validate() + if tt.wantErr { + assert.Error(t, err) + } else { + // validate also checks clusterID which will fail for non-real IDs, + // so just verify the type validation logic + valid := false + for _, v := range validVolumeTypes { + if tt.targetType == v { + valid = true + break + } + } + assert.True(t, valid) + } + }) + } +} + +func TestChangeVolumeType_ValidateRole(t *testing.T) { + tests := []struct { + 
name string + role string + wantErr bool + }{ + {"empty role (both)", "", false}, + {"control-plane", "control-plane", false}, + {"infra", "infra", false}, + {"invalid worker", "worker", true}, + {"invalid master", "master", true}, + {"invalid random", "random", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ops := &changeVolumeTypeOptions{ + clusterID: "test-cluster", + targetType: "gp3", + reason: "test", + role: tt.role, + } + err := ops.validate() + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid role") + } + // For valid roles, validate would still fail on clusterID check, + // but the role validation should pass + if !tt.wantErr { + validRole := tt.role == "" || tt.role == "control-plane" || tt.role == "infra" + assert.True(t, validRole) + } + }) + } +} + +func TestChangeVolumeType_CommandCreation(t *testing.T) { + cmd := newCmdChangeVolumeType() + + assert.NotNil(t, cmd) + assert.Equal(t, "change-ebs-volume-type", cmd.Use) + assert.NotEmpty(t, cmd.Short) + assert.NotEmpty(t, cmd.Long) + assert.NotEmpty(t, cmd.Example) + + // Required flags + requiredFlags := []string{"cluster-id", "type", "reason"} + for _, flagName := range requiredFlags { + flag := cmd.Flag(flagName) + assert.NotNilf(t, flag, "required flag %q should exist", flagName) + } + + // Optional flags + flag := cmd.Flag("role") + assert.NotNil(t, flag, "optional flag 'role' should exist") +} + +func TestChangeVolumeType_RoleDisplay(t *testing.T) { + assert.Equal(t, "control-plane + infra", roleDisplay("")) + assert.Equal(t, "control-plane", roleDisplay("control-plane")) + assert.Equal(t, "infra", roleDisplay("infra")) +} + +func TestChangeVolumeType_CountReadyNodes(t *testing.T) { + // Empty list + nodes := &corev1.NodeList{} + assert.Equal(t, 0, countReadyNodes(nodes)) +} diff --git a/cmd/cluster/cmd.go b/cmd/cluster/cmd.go index d32765222..b72e56552 100644 --- a/cmd/cluster/cmd.go +++ b/cmd/cluster/cmd.go @@ -44,6 +44,7 @@ func 
NewCmdCluster(streams genericclioptions.IOStreams, client *k8s.LazyClient, clusterCmd.AddCommand(NewCmdHypershiftInfo(streams)) clusterCmd.AddCommand(newCmdOrgId()) clusterCmd.AddCommand(newCmdDetachStuckVolume()) + clusterCmd.AddCommand(newCmdChangeVolumeType()) clusterCmd.AddCommand(NewCmdVerifyDNS(streams)) clusterCmd.AddCommand(ssh.NewCmdSSH()) clusterCmd.AddCommand(sre_operators.NewCmdSREOperators(streams, client)) diff --git a/cmd/cluster/resize/infra_node.go b/cmd/cluster/resize/infra_node.go index d8187d665..17a656764 100644 --- a/cmd/cluster/resize/infra_node.go +++ b/cmd/cluster/resize/infra_node.go @@ -8,7 +8,6 @@ import ( "log" "slices" "strings" - "time" awssdk "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2" @@ -18,27 +17,19 @@ import ( machinev1beta1 "github.com/openshift/api/machine/v1beta1" hivev1 "github.com/openshift/hive/apis/hive/v1" "github.com/openshift/osdctl/cmd/servicelog" + infraPkg "github.com/openshift/osdctl/pkg/infra" "github.com/openshift/osdctl/pkg/k8s" "github.com/openshift/osdctl/pkg/osdCloud" "github.com/openshift/osdctl/pkg/utils" "github.com/spf13/cobra" corev1 "k8s.io/api/core/v1" - apierrors "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/selection" - "k8s.io/apimachinery/pkg/util/wait" "sigs.k8s.io/controller-runtime/pkg/client" ) const ( - twentyMinuteTimeout = 20 * time.Minute - twentySecondIncrement = 20 * time.Second resizedInfraNodeServiceLogTemplate = "https://raw.githubusercontent.com/openshift/managed-notifications/master/osd/infranode_resized.json" resizedInfraNodeServiceLogTemplateGCP = "https://raw.githubusercontent.com/openshift/managed-notifications/master/osd/gcp/GCP_infranode_resized_auto.json" - infraNodeLabel = "node-role.kubernetes.io/infra" - temporaryInfraNodeLabel = "osdctl.openshift.io/infra-resize-temporary-machinepool" ) type Infra struct { 
@@ -234,7 +225,7 @@ func (r *Infra) RunInfra(ctx context.Context) error { } log.Printf("resizing infra nodes for %s - %s", r.cluster.Name(), r.clusterId) - originalMp, err := r.getInfraMachinePool(ctx) + originalMp, err := infraPkg.GetInfraMachinePool(ctx, r.hive, r.clusterId) if err != nil { return err } @@ -247,270 +238,28 @@ func (r *Infra) RunInfra(ctx context.Context) error { if err != nil { return err } - tempMp := newMp.DeepCopy() - tempMp.Name = fmt.Sprintf("%s2", tempMp.Name) - tempMp.Spec.Name = fmt.Sprintf("%s2", tempMp.Spec.Name) - tempMp.Spec.Labels[temporaryInfraNodeLabel] = "" - - instanceType, err := getInstanceType(tempMp) + instanceType, err := getInstanceType(newMp) if err != nil { return fmt.Errorf("failed to parse instance type from machinepool: %v", err) } - // Create the temporary machinepool log.Printf("planning to resize to instance type from %s to %s", originalInstanceType, instanceType) if !utils.ConfirmPrompt() { log.Printf("exiting") return nil } - log.Printf("creating temporary machinepool %s, with instance type %s", tempMp.Name, instanceType) - if err := r.hiveAdmin.Create(ctx, tempMp); err != nil { - return err - } - - // This selector will match all infra nodes - selector, err := labels.Parse(infraNodeLabel) - if err != nil { - return err - } - - if err := wait.PollImmediate(twentySecondIncrement, twentyMinuteTimeout, func() (bool, error) { - nodes := &corev1.NodeList{} - - if err := r.client.List(ctx, nodes, &client.ListOptions{LabelSelector: selector}); err != nil { - log.Printf("error retrieving nodes list, continuing to wait: %s", err) - return false, nil - } - - readyNodes := 0 - log.Printf("waiting for %d infra nodes to be reporting Ready", int(*originalMp.Spec.Replicas)*2) - for _, node := range nodes.Items { - for _, cond := range node.Status.Conditions { - if cond.Type == corev1.NodeReady { - if cond.Status == corev1.ConditionTrue { - readyNodes++ - log.Printf("found node %s reporting Ready", node.Name) - } - } - } - } - - 
switch { - case readyNodes >= int(*originalMp.Spec.Replicas)*2: - return true, nil - default: - log.Printf("found %d infra nodes reporting Ready, continuing to wait", readyNodes) - return false, nil - } - }); err != nil { - return err - } - - // Identify the original nodes and temp nodes - // requireInfra matches all infra nodes using the selector from above - requireInfra, err := labels.NewRequirement(infraNodeLabel, selection.Exists, nil) - if err != nil { - return err - } - - // requireNotTempNode matches all nodes that do not have the temporaryInfraNodeLabel, created with the new (temporary) machine pool - requireNotTempNode, err := labels.NewRequirement(temporaryInfraNodeLabel, selection.DoesNotExist, nil) - if err != nil { - return err - } - - // requireTempNode matches the opposite of above, all nodes that *do* have the temporaryInfraNodeLabel - requireTempNode, err := labels.NewRequirement(temporaryInfraNodeLabel, selection.Exists, nil) - if err != nil { - return err - } - - // infraNode + notTempNode = original nodes - originalNodeSelector := selector.Add(*requireInfra, *requireNotTempNode) - - // infraNode + tempNode = temp nodes - tempNodeSelector := selector.Add(*requireInfra, *requireTempNode) - - originalNodes := &corev1.NodeList{} - if err := r.client.List(ctx, originalNodes, &client.ListOptions{LabelSelector: originalNodeSelector}); err != nil { - return err - } - - // Delete original machinepool - log.Printf("deleting original machinepool %s, with instance type %s", originalMp.Name, originalInstanceType) - if err := r.hiveAdmin.Delete(ctx, originalMp); err != nil { - return err - } - - // Wait for original machinepool to delete - if err := wait.PollImmediate(twentySecondIncrement, twentyMinuteTimeout, func() (bool, error) { - mp := &hivev1.MachinePool{} - err := r.hive.Get(ctx, client.ObjectKey{Namespace: originalMp.Namespace, Name: originalMp.Name}, mp) - if err != nil { - if apierrors.IsNotFound(err) { - return true, nil - } - log.Printf("error 
retrieving machines list, continuing to wait: %s", err) - return false, nil - } - - log.Printf("original machinepool %s/%s still exists, continuing to wait", originalMp.Namespace, originalMp.Name) - return false, nil - }); err != nil { - return err - } - - // Wait for original nodes to delete - if err := wait.PollImmediate(twentySecondIncrement, twentyMinuteTimeout, func() (bool, error) { - // Re-check for originalNodes to see if they have been deleted - return skipError(wrapResult(r.nodesMatchExpectedCount(ctx, originalNodeSelector, 0)), "error matching expected count") - }); err != nil { - switch { - case errors.Is(err, wait.ErrWaitTimeout): - log.Printf("Warning: timed out waiting for nodes to drain: %v. Terminating backing cloud instances.", err.Error()) - - // Terminate the backing cloud instances if they are not removed by the 20 minute timeout - err := r.terminateCloudInstances(ctx, originalNodes) - if err != nil { - return err - } - - if err := wait.PollImmediate(twentySecondIncrement, twentyMinuteTimeout, func() (bool, error) { - log.Printf("waiting for nodes to terminate") - return skipError(wrapResult(r.nodesMatchExpectedCount(ctx, originalNodeSelector, 0)), "error matching expected count") - }); err != nil { - if errors.Is(err, wait.ErrWaitTimeout) { - log.Printf("timed out waiting for nodes to terminate: %v.", err.Error()) - } - return err - } - default: - return err - } - } - - // Create new permanent machinepool - log.Printf("creating new machinepool %s, with instance type %s", newMp.Name, instanceType) - if err := r.hiveAdmin.Create(ctx, newMp); err != nil { - return err - } - - // Wait for new permanent machines to become nodes - if err := wait.PollImmediate(twentySecondIncrement, twentyMinuteTimeout, func() (bool, error) { - nodes := &corev1.NodeList{} - selector, err := labels.Parse("node-role.kubernetes.io/infra=") - if err != nil { - // This should never happen, so we do not have to skip this error - return false, err - } - - if err := 
r.client.List(ctx, nodes, &client.ListOptions{LabelSelector: selector}); err != nil { - log.Printf("error retrieving nodes list, continuing to wait: %s", err) - return false, nil - } - - readyNodes := 0 - log.Printf("waiting for %d infra nodes to be reporting Ready", int(*originalMp.Spec.Replicas)*2) - for _, node := range nodes.Items { - for _, cond := range node.Status.Conditions { - if cond.Type == corev1.NodeReady { - if cond.Status == corev1.ConditionTrue { - readyNodes++ - log.Printf("found node %s reporting Ready", node.Name) - } - } - } - } - - switch { - case readyNodes >= int(*originalMp.Spec.Replicas)*2: - return true, nil - default: - log.Printf("found %d infra nodes reporting Ready, continuing to wait", readyNodes) - return false, nil - } - }); err != nil { - return err - } - - tempNodes := &corev1.NodeList{} - if err := r.client.List(ctx, tempNodes, &client.ListOptions{LabelSelector: tempNodeSelector}); err != nil { - return err - } - - // Delete temp machinepool - log.Printf("deleting temporary machinepool %s, with instance type %s", tempMp.Name, instanceType) - if err := r.hiveAdmin.Delete(ctx, tempMp); err != nil { - return err + clients := infraPkg.DanceClients{ + ClusterClient: r.client, + HiveClient: r.hive, + HiveAdmin: r.hiveAdmin, } - // Wait for temporary machinepool to delete - if err := wait.PollImmediate(twentySecondIncrement, twentyMinuteTimeout, func() (bool, error) { - mp := &hivev1.MachinePool{} - err := r.hive.Get(ctx, client.ObjectKey{Namespace: tempMp.Namespace, Name: tempMp.Name}, mp) - if err != nil { - if apierrors.IsNotFound(err) { - return true, nil - } - log.Printf("error retrieving old machine details, continuing to wait: %s", err) - return false, nil - } - - log.Printf("temporary machinepool %s/%s still exists, continuing to wait", tempMp.Namespace, tempMp.Name) - return false, nil - }); err != nil { + if err := infraPkg.RunMachinePoolDance(ctx, clients, originalMp, newMp, r.terminateCloudInstances); err != nil { return err 
} - // Wait for infra node count to return to normal - log.Printf("waiting for infra node count to return to: %d", int(*originalMp.Spec.Replicas)) - if err := wait.PollImmediate(twentySecondIncrement, twentyMinuteTimeout, func() (bool, error) { - nodes := &corev1.NodeList{} - selector, err := labels.Parse("node-role.kubernetes.io/infra=") - if err != nil { - // This should never happen, so we do not have to skip this errorreturn false, err - return false, err - } - - if err := r.client.List(ctx, nodes, &client.ListOptions{LabelSelector: selector}); err != nil { - log.Printf("error retrieving nodes list, continuing to wait: %s", err) - return false, nil - } - - switch len(nodes.Items) { - case int(*originalMp.Spec.Replicas): - log.Printf("found %d infra nodes, infra resize complete", len(nodes.Items)) - return true, nil - default: - log.Printf("found %d infra nodes, continuing to wait", len(nodes.Items)) - return false, nil - } - }); err != nil { - switch { - case errors.Is(err, wait.ErrWaitTimeout): - log.Printf("Warning: timed out waiting for nodes to drain: %v. Terminating backing cloud instances.", err.Error()) - - err := r.terminateCloudInstances(ctx, tempNodes) - if err != nil { - return err - } - - if err := wait.PollImmediate(twentySecondIncrement, twentyMinuteTimeout, func() (bool, error) { - log.Printf("waiting for nodes to terminate") - return skipError(wrapResult(r.nodesMatchExpectedCount(ctx, tempNodeSelector, 0)), "error matching expected count") - }); err != nil { - if errors.Is(err, wait.ErrWaitTimeout) { - log.Printf("timed out waiting for nodes to terminate: %v.", err.Error()) - } - return err - } - default: - return err - } - } - - postCmd := generateServiceLog(tempMp, r.instanceType, r.justification, r.clusterId, r.ohss) + postCmd := generateServiceLog(newMp, r.instanceType, r.justification, r.clusterId, r.ohss) if err := postCmd.Run(); err != nil { fmt.Println("Failed to generate service log. 
Please manually send a service log to the customer for the blocked egresses with:") fmt.Printf("osdctl servicelog post %v -t %v -p %v\n", @@ -520,37 +269,6 @@ func (r *Infra) RunInfra(ctx context.Context) error { return nil } -func (r *Infra) getInfraMachinePool(ctx context.Context) (*hivev1.MachinePool, error) { - ns := &corev1.NamespaceList{} - selector, err := labels.Parse(fmt.Sprintf("api.openshift.com/id=%s", r.clusterId)) - if err != nil { - return nil, err - } - - if err := r.hive.List(ctx, ns, &client.ListOptions{LabelSelector: selector, Limit: 1}); err != nil { - return nil, err - } - if len(ns.Items) != 1 { - return nil, fmt.Errorf("expected 1 namespace, found %d namespaces with tag: api.openshift.com/id=%s", len(ns.Items), r.clusterId) - } - - log.Printf("found namespace: %s", ns.Items[0].Name) - - mpList := &hivev1.MachinePoolList{} - if err := r.hive.List(ctx, mpList, &client.ListOptions{Namespace: ns.Items[0].Name}); err != nil { - return nil, err - } - - for _, mp := range mpList.Items { - if mp.Spec.Name == "infra" { - log.Printf("found machinepool %s", mp.Name) - return &mp, nil - } - } - - return nil, fmt.Errorf("did not find the infra machinepool in namespace: %s", ns.Items[0].Name) -} - func (r *Infra) embiggenMachinePool(mp *hivev1.MachinePool) (*hivev1.MachinePool, error) { embiggen := map[string]string{ "m5.xlarge": "r5.xlarge", @@ -565,18 +283,6 @@ func (r *Infra) embiggenMachinePool(mp *hivev1.MachinePool) (*hivev1.MachinePool "n2-highmem-8": "n2-highmem-16", } - newMp := &hivev1.MachinePool{} - mp.DeepCopyInto(newMp) - - // Unset fields we want to be regenerated - newMp.CreationTimestamp = metav1.Time{} - newMp.Finalizers = []string{} - newMp.ResourceVersion = "" - newMp.Generation = 0 - newMp.SelfLink = "" - newMp.UID = "" - newMp.Status = hivev1.MachinePoolStatus{} - // Update instance type sizing if r.instanceType != "" { log.Printf("using override instance type: %s", r.instanceType) @@ -592,16 +298,20 @@ func (r *Infra) 
embiggenMachinePool(mp *hivev1.MachinePool) (*hivev1.MachinePool r.instanceType = embiggen[instanceType] } - switch r.cluster.CloudProvider().ID() { - case "aws": - newMp.Spec.Platform.AWS.InstanceType = r.instanceType - case "gcp": - newMp.Spec.Platform.GCP.InstanceType = r.instanceType - default: - return nil, fmt.Errorf("cloud provider not supported: %s, only AWS and GCP are supported", r.cluster.CloudProvider().ID()) - } + cloudProvider := r.cluster.CloudProvider().ID() + newInstanceType := r.instanceType - return newMp, nil + return infraPkg.CloneMachinePool(mp, func(newMp *hivev1.MachinePool) error { + switch cloudProvider { + case "aws": + newMp.Spec.Platform.AWS.InstanceType = newInstanceType + case "gcp": + newMp.Spec.Platform.GCP.InstanceType = newInstanceType + default: + return fmt.Errorf("cloud provider not supported: %s, only AWS and GCP are supported", cloudProvider) + } + return nil + }) } func getInstanceType(mp *hivev1.MachinePool) (string, error) { @@ -728,22 +438,6 @@ func convertProviderIDtoInstanceID(providerID string) string { return providerIDSplit[len(providerIDSplit)-1] } -// nodesMatchExpectedCount accepts a context, labelselector and count of expected nodes, and ] -// returns true if the nodelist matching the labelselector is equal to the expected count -func (r *Infra) nodesMatchExpectedCount(ctx context.Context, labelSelector labels.Selector, count int) (bool, error) { - nodeList := &corev1.NodeList{} - - if err := r.client.List(ctx, nodeList, &client.ListOptions{LabelSelector: labelSelector}); err != nil { - return false, err - } - - if len(nodeList.Items) == count { - return true, nil - } - - return false, nil -} - // validateInstanceSize accepts a string for the requested new instance type and returns an error // if the instance type is invalid func validateInstanceSize(newInstanceSize string, nodeType string) error { @@ -752,19 +446,3 @@ func validateInstanceSize(newInstanceSize string, nodeType string) error { } return nil } - 
-type result struct { - condition bool - err error -} - -func wrapResult(condition bool, err error) result { - return result{condition, err} -} - -func skipError(res result, msg string) (bool, error) { - if res.err != nil { - log.Printf("%s, continuing to wait: %s", msg, res.err) - } - return res.condition, nil -} diff --git a/cmd/cluster/resize/infra_node_test.go b/cmd/cluster/resize/infra_node_test.go index 0162873c4..0ab0c6daa 100644 --- a/cmd/cluster/resize/infra_node_test.go +++ b/cmd/cluster/resize/infra_node_test.go @@ -1,8 +1,6 @@ package resize import ( - "context" - "fmt" "strings" "testing" @@ -11,11 +9,6 @@ import ( hivev1aws "github.com/openshift/hive/apis/hive/v1/aws" hivev1gcp "github.com/openshift/hive/apis/hive/v1/gcp" "github.com/openshift/osdctl/pkg/utils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/labels" ) // newTestCluster assembles a *cmv1.Cluster while handling the error to help out with inline test-case generation @@ -223,287 +216,6 @@ func TestConvertProviderIDtoInstanceID(t *testing.T) { } } -func TestSkipError(t *testing.T) { - tests := []struct { - name string - result result - msg string - expected bool - }{ - { - name: "no error", - result: result{ - condition: true, - err: nil, - }, - msg: "test message", - expected: true, - }, - { - name: "with error", - result: result{ - condition: false, - err: fmt.Errorf("test error"), - }, - msg: "test message", - expected: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - actual, err := skipError(test.result, test.msg) - if err != nil { - t.Errorf("expected nil error, got %v", err) - } - if actual != test.expected { - t.Errorf("expected condition %v, got %v", test.expected, actual) - } - }) - } -} - -func TestNodesMatchExpectedCount(t *testing.T) { - tests := []struct { - name string - labelSelector labels.Selector - 
expectedCount int - mockNodeList *corev1.NodeList - mockListError error - expectedMatch bool - expectedError error - }{ - { - name: "matching count", - labelSelector: labels.NewSelector(), - expectedCount: 2, - mockNodeList: &corev1.NodeList{ - Items: []corev1.Node{ - {ObjectMeta: metav1.ObjectMeta{Name: "node1"}}, - {ObjectMeta: metav1.ObjectMeta{Name: "node2"}}, - }, - }, - expectedMatch: true, - }, - { - name: "non-matching count", - labelSelector: labels.NewSelector(), - expectedCount: 2, - mockNodeList: &corev1.NodeList{ - Items: []corev1.Node{ - {ObjectMeta: metav1.ObjectMeta{Name: "node1"}}, - }, - }, - expectedMatch: false, - }, - { - name: "list error", - labelSelector: labels.NewSelector(), - expectedCount: 2, - mockListError: fmt.Errorf("list error"), - expectedError: fmt.Errorf("list error"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - // Create mock client - mockClient := &MockClient{} - mockClient.On("List", mock.Anything, mock.Anything, mock.Anything). - Return(test.mockListError). 
- Run(func(args mock.Arguments) { - if test.mockNodeList != nil { - arg := args.Get(1).(*corev1.NodeList) - *arg = *test.mockNodeList - } - }) - - // Create Infra instance with mock client - r := &Infra{ - client: mockClient, - } - - // Call the function - match, err := r.nodesMatchExpectedCount(context.Background(), test.labelSelector, test.expectedCount) - - // Verify results - if test.expectedError != nil { - if err == nil || err.Error() != test.expectedError.Error() { - t.Errorf("expected error %v, got %v", test.expectedError, err) - } - } else { - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if match != test.expectedMatch { - t.Errorf("expected match %v, got %v", test.expectedMatch, match) - } - } - - // Verify mock was called correctly - mockClient.AssertExpectations(t) - }) - } -} - -func TestGetInfraMachinePool(t *testing.T) { - testNamespace := &corev1.Namespace{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-namespace", - Labels: map[string]string{ - "api.openshift.com/id": "test-cluster", - }, - }, - } - - testMachinePool := &hivev1.MachinePool{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-cluster-infra", - Namespace: "test-namespace", - }, - Spec: hivev1.MachinePoolSpec{ - Name: "infra", - Platform: hivev1.MachinePoolPlatform{ - AWS: &hivev1aws.MachinePoolPlatform{ - InstanceType: "r5.xlarge", - }, - }, - }, - } - - mockHive := &MockClient{} - firstCall := mockHive.On("List", mock.Anything, mock.MatchedBy(func(obj interface{}) bool { - _, ok := obj.(*corev1.NamespaceList) - return ok - }), mock.Anything) - firstCall.Return(nil).Run(func(args mock.Arguments) { - nsList := args.Get(1).(*corev1.NamespaceList) - nsList.Items = []corev1.Namespace{*testNamespace} - }) - - secondCall := mockHive.On("List", mock.Anything, mock.MatchedBy(func(obj interface{}) bool { - _, ok := obj.(*hivev1.MachinePoolList) - return ok - }), mock.Anything) - secondCall.Return(nil).Run(func(args mock.Arguments) { - mpList := args.Get(1).(*hivev1.MachinePoolList) 
- mpList.Items = []hivev1.MachinePool{*testMachinePool} - }) - - infra := &Infra{ - clusterId: "test-cluster", - hive: mockHive, - } - - mp, err := infra.getInfraMachinePool(context.Background()) - - assert.NoError(t, err) - assert.NotNil(t, mp) - assert.Equal(t, "infra", mp.Spec.Name) - assert.Equal(t, "r5.xlarge", mp.Spec.Platform.AWS.InstanceType) - mockHive.AssertExpectations(t) -} - -func TestGetInfraMachinePoolNoNamespace(t *testing.T) { - // Create mock client - mockHive := &MockClient{} - - // Set up mock expectations for namespace list - empty list - firstCall := mockHive.On("List", mock.Anything, mock.MatchedBy(func(obj interface{}) bool { - _, ok := obj.(*corev1.NamespaceList) - return ok - }), mock.Anything) - firstCall.Return(nil).Run(func(args mock.Arguments) { - nsList := args.Get(1).(*corev1.NamespaceList) - nsList.Items = []corev1.Namespace{} // Empty list - }) - - // Create Infra instance - infra := &Infra{ - clusterId: "test-cluster", - hive: mockHive, - } - - // Call the function - mp, err := infra.getInfraMachinePool(context.Background()) - - // Verify results - assert.Error(t, err) - assert.Contains(t, err.Error(), "expected 1 namespace, found 0 namespaces with tag: api.openshift.com/id=test-cluster") - assert.Nil(t, mp) - - // Verify mock was called correctly - mockHive.AssertExpectations(t) -} - -func TestGetInfraMachinePoolNoInfraPool(t *testing.T) { - // Create a test namespace - testNamespace := &corev1.Namespace{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-namespace", - Labels: map[string]string{ - "api.openshift.com/id": "test-cluster", - }, - }, - } - - // Create a test machine pool (worker, not infra) - testMachinePool := &hivev1.MachinePool{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-cluster-worker", - Namespace: "test-namespace", - }, - Spec: hivev1.MachinePoolSpec{ - Name: "worker", // Not "infra" - Platform: hivev1.MachinePoolPlatform{ - AWS: &hivev1aws.MachinePoolPlatform{ - InstanceType: "r5.xlarge", - }, - }, - }, - } - 
- // Create mock client - mockHive := &MockClient{} - - // Set up mock expectations for namespace list - first call - firstCall := mockHive.On("List", mock.Anything, mock.MatchedBy(func(obj interface{}) bool { - _, ok := obj.(*corev1.NamespaceList) - return ok - }), mock.Anything) - firstCall.Return(nil).Run(func(args mock.Arguments) { - nsList := args.Get(1).(*corev1.NamespaceList) - nsList.Items = []corev1.Namespace{*testNamespace} - }) - - // Set up mock expectations for machine pool list - second call - secondCall := mockHive.On("List", mock.Anything, mock.MatchedBy(func(obj interface{}) bool { - _, ok := obj.(*hivev1.MachinePoolList) - return ok - }), mock.Anything) - secondCall.Return(nil).Run(func(args mock.Arguments) { - mpList := args.Get(1).(*hivev1.MachinePoolList) - mpList.Items = []hivev1.MachinePool{*testMachinePool} - }) - - // Create Infra instance - infra := &Infra{ - clusterId: "test-cluster", - hive: mockHive, - } - - // Call the function - mp, err := infra.getInfraMachinePool(context.Background()) - - // Verify results - assert.Error(t, err) - assert.Contains(t, err.Error(), "did not find the infra machinepool in namespace: test-namespace") - assert.Nil(t, mp) - - // Verify mock was called correctly - mockHive.AssertExpectations(t) -} - // TestHiveOcmUrlValidation tests the early validation of --hive-ocm-url flag in the infra resize command func TestHiveOcmUrlValidation(t *testing.T) { tests := []struct { diff --git a/docs/README.md b/docs/README.md index 280f0f170..be8f512bb 100644 --- a/docs/README.md +++ b/docs/README.md @@ -44,6 +44,7 @@ - `cleanup --cluster-id ` - Drop emergency access to a cluster - `cad` - Provides commands to run CAD tasks - `run` - Run a manual investigation on the CAD cluster + - `change-ebs-volume-type` - Change EBS volume type for control plane and/or infra nodes by replacing machines - `check-banned-user --cluster-id ` - Checks if the cluster owner is a banned user. 
- `context --cluster-id ` - Shows the context of a specified cluster - `cpd` - Runs diagnostic for a Cluster Provisioning Delay (CPD) @@ -1355,6 +1356,41 @@ osdctl cluster cad run [flags] -S, --skip-version-check skip checking to see if this is the most recent release ``` +### osdctl cluster change-ebs-volume-type + +Change the EBS volume type for control plane and/or infra nodes on a ROSA/OSD cluster. + +This command replaces machines to change volume types (not in-place modification). +For control plane nodes, it patches the ControlPlaneMachineSet (CPMS) which automatically +rolls nodes one at a time. For infra nodes, it uses the Hive MachinePool dance to safely +replace all infra nodes with new ones using the target volume type. + +Pre-flight checks are performed automatically before making changes. + +``` +osdctl cluster change-ebs-volume-type [flags] +``` + +#### Flags + +``` + --as string Username to impersonate for the operation. User could be a regular user or a service account in a namespace. + --cluster string The name of the kubeconfig cluster to use + -C, --cluster-id string The internal/external ID of the cluster + --context string The name of the kubeconfig context to use + -h, --help help for change-ebs-volume-type + --insecure-skip-tls-verify If true, the server's certificate will not be checked for validity. This will make your HTTPS connections insecure + --kubeconfig string Path to the kubeconfig file to use for CLI requests. + -o, --output string Valid formats are ['', 'json', 'yaml', 'env'] + --reason string Reason for elevation (OHSS/PD/JIRA ticket) + --request-timeout string The length of time to wait before giving up on a single server request. Non-zero values should contain a corresponding time unit (e.g. 1s, 2m, 3h). A value of zero means don't timeout requests. 
(default "0") + --role string Node role to change: control-plane, infra (default: both) + -s, --server string The address and port of the Kubernetes API server + --skip-aws-proxy-check aws_proxy Don't use the configured aws_proxy value + -S, --skip-version-check skip checking to see if this is the most recent release + --type string Target EBS volume type (gp3) +``` + ### osdctl cluster check-banned-user Checks if the cluster owner is a banned user. diff --git a/docs/osdctl_cluster.md b/docs/osdctl_cluster.md index d4e8e315f..d83147931 100644 --- a/docs/osdctl_cluster.md +++ b/docs/osdctl_cluster.md @@ -28,6 +28,7 @@ Provides information for a specified cluster * [osdctl](osdctl.md) - OSD CLI * [osdctl cluster break-glass](osdctl_cluster_break-glass.md) - Emergency access to a cluster * [osdctl cluster cad](osdctl_cluster_cad.md) - Provides commands to run CAD tasks +* [osdctl cluster change-ebs-volume-type](osdctl_cluster_change-ebs-volume-type.md) - Change EBS volume type for control plane and/or infra nodes by replacing machines * [osdctl cluster check-banned-user](osdctl_cluster_check-banned-user.md) - Checks if the cluster owner is a banned user. * [osdctl cluster context](osdctl_cluster_context.md) - Shows the context of a specified cluster * [osdctl cluster cpd](osdctl_cluster_cpd.md) - Runs diagnostic for a Cluster Provisioning Delay (CPD) diff --git a/docs/osdctl_cluster_change-ebs-volume-type.md b/docs/osdctl_cluster_change-ebs-volume-type.md new file mode 100644 index 000000000..1b64d3f9d --- /dev/null +++ b/docs/osdctl_cluster_change-ebs-volume-type.md @@ -0,0 +1,61 @@ +## osdctl cluster change-ebs-volume-type + +Change EBS volume type for control plane and/or infra nodes by replacing machines + +### Synopsis + +Change the EBS volume type for control plane and/or infra nodes on a ROSA/OSD cluster. + +This command replaces machines to change volume types (not in-place modification). 
+For control plane nodes, it patches the ControlPlaneMachineSet (CPMS) which automatically +rolls nodes one at a time. For infra nodes, it uses the Hive MachinePool dance to safely +replace all infra nodes with new ones using the target volume type. + +Pre-flight checks are performed automatically before making changes. + +``` +osdctl cluster change-ebs-volume-type [flags] +``` + +### Examples + +``` + # Change both control plane and infra volumes to gp3 + osdctl cluster change-ebs-volume-type -C ${CLUSTER_ID} --type gp3 --reason "SREP-3811" + + # Change only control plane volumes to gp3 + osdctl cluster change-ebs-volume-type -C ${CLUSTER_ID} --type gp3 --role control-plane --reason "SREP-3811" + + # Change only infra volumes to gp3 + osdctl cluster change-ebs-volume-type -C ${CLUSTER_ID} --type gp3 --role infra --reason "SREP-3811" +``` + +### Options + +``` + -C, --cluster-id string The internal/external ID of the cluster + -h, --help help for change-ebs-volume-type + --reason string Reason for elevation (OHSS/PD/JIRA ticket) + --role string Node role to change: control-plane, infra (default: both) + --type string Target EBS volume type (gp3) +``` + +### Options inherited from parent commands + +``` + --as string Username to impersonate for the operation. User could be a regular user or a service account in a namespace. + --cluster string The name of the kubeconfig cluster to use + --context string The name of the kubeconfig context to use + --insecure-skip-tls-verify If true, the server's certificate will not be checked for validity. This will make your HTTPS connections insecure + --kubeconfig string Path to the kubeconfig file to use for CLI requests. + -o, --output string Valid formats are ['', 'json', 'yaml', 'env'] + --request-timeout string The length of time to wait before giving up on a single server request. Non-zero values should contain a corresponding time unit (e.g. 1s, 2m, 3h). A value of zero means don't timeout requests. 
(default "0") + -s, --server string The address and port of the Kubernetes API server + --skip-aws-proxy-check aws_proxy Don't use the configured aws_proxy value + -S, --skip-version-check skip checking to see if this is the most recent release +``` + +### SEE ALSO + +* [osdctl cluster](osdctl_cluster.md) - Provides information for a specified cluster + diff --git a/pkg/infra/machinepool.go b/pkg/infra/machinepool.go new file mode 100644 index 000000000..753b66adc --- /dev/null +++ b/pkg/infra/machinepool.go @@ -0,0 +1,323 @@ +package infra + +import ( + "context" + "fmt" + "log" + "time" + + hivev1 "github.com/openshift/hive/apis/hive/v1" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/selection" + "k8s.io/apimachinery/pkg/util/wait" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +const ( + pollTimeout = 20 * time.Minute + pollInterval = 20 * time.Second + InfraNodeLabel = "node-role.kubernetes.io/infra" + TemporaryInfraNodeLabel = "osdctl.openshift.io/infra-resize-temporary-machinepool" +) + +// GetInfraMachinePool finds the infra MachinePool in Hive for the given cluster ID. 
+func GetInfraMachinePool(ctx context.Context, hiveClient client.Client, clusterID string) (*hivev1.MachinePool, error) { + ns := &corev1.NamespaceList{} + selector, err := labels.Parse(fmt.Sprintf("api.openshift.com/id=%s", clusterID)) + if err != nil { + return nil, err + } + + if err := hiveClient.List(ctx, ns, &client.ListOptions{LabelSelector: selector, Limit: 1}); err != nil { + return nil, err + } + if len(ns.Items) != 1 { + return nil, fmt.Errorf("expected 1 namespace, found %d namespaces with tag: api.openshift.com/id=%s", len(ns.Items), clusterID) + } + + log.Printf("found namespace: %s", ns.Items[0].Name) + + mpList := &hivev1.MachinePoolList{} + if err := hiveClient.List(ctx, mpList, &client.ListOptions{Namespace: ns.Items[0].Name}); err != nil { + return nil, err + } + + for _, mp := range mpList.Items { + if mp.Spec.Name == "infra" { + log.Printf("found machinepool %s", mp.Name) + return &mp, nil + } + } + + return nil, fmt.Errorf("did not find the infra machinepool in namespace: %s", ns.Items[0].Name) +} + +// CloneMachinePool deep copies a MachinePool, resets metadata fields so it can +// be created as a new resource, and applies the modifier function. +func CloneMachinePool(mp *hivev1.MachinePool, modifyFn func(*hivev1.MachinePool) error) (*hivev1.MachinePool, error) { + newMp := &hivev1.MachinePool{} + mp.DeepCopyInto(newMp) + + newMp.CreationTimestamp = metav1.Time{} + newMp.Finalizers = []string{} + newMp.ResourceVersion = "" + newMp.Generation = 0 + newMp.UID = "" + newMp.Status = hivev1.MachinePoolStatus{} + + if modifyFn != nil { + if err := modifyFn(newMp); err != nil { + return nil, err + } + } + + return newMp, nil +} + +// DanceClients holds the k8s clients needed for the machinepool dance. +type DanceClients struct { + ClusterClient client.Client + HiveClient client.Client + HiveAdmin client.Client +} + +// RunMachinePoolDance performs the machinepool dance to replace infra nodes. 
+// It takes the original MachinePool and an already-modified new MachinePool. +// The dance creates a temporary pool, waits for nodes, deletes the original, +// creates a permanent replacement, then removes the temporary pool. +// +// The onTimeout callback is called when nodes fail to drain within the timeout. +// It receives the list of stuck nodes and should terminate the backing instances. +// If onTimeout is nil, the dance will return an error on timeout. +func RunMachinePoolDance(ctx context.Context, clients DanceClients, originalMp, newMp *hivev1.MachinePool, onTimeout func(ctx context.Context, nodes *corev1.NodeList) error) error { + tempMp := newMp.DeepCopy() + tempMp.Name = fmt.Sprintf("%s2", tempMp.Name) + tempMp.Spec.Name = fmt.Sprintf("%s2", tempMp.Spec.Name) + tempMp.Spec.Labels[TemporaryInfraNodeLabel] = "" + + // Create the temporary machinepool + log.Printf("creating temporary machinepool %s", tempMp.Name) + if err := clients.HiveAdmin.Create(ctx, tempMp); err != nil { + return err + } + + // Wait for 2x infra nodes to be Ready + selector, err := labels.Parse(InfraNodeLabel) + if err != nil { + return err + } + + pollCtx, cancel := context.WithTimeout(ctx, pollTimeout) + defer cancel() + if err := wait.PollUntilContextTimeout(pollCtx, pollInterval, pollTimeout, true, func(ctx context.Context) (bool, error) { + nodes := &corev1.NodeList{} + if err := clients.ClusterClient.List(ctx, nodes, &client.ListOptions{LabelSelector: selector}); err != nil { + log.Printf("error retrieving nodes list, continuing to wait: %s", err) + return false, nil + } + + readyNodes := countReadyNodes(nodes) + expected := int(*originalMp.Spec.Replicas) * 2 + log.Printf("waiting for %d infra nodes to be reporting Ready, found %d", expected, readyNodes) + + return readyNodes >= expected, nil + }); err != nil { + return err + } + + // Build selectors for original vs temp nodes + requireInfra, err := labels.NewRequirement(InfraNodeLabel, selection.Exists, nil) + if err != nil { 
+ return err + } + requireNotTempNode, err := labels.NewRequirement(TemporaryInfraNodeLabel, selection.DoesNotExist, nil) + if err != nil { + return err + } + requireTempNode, err := labels.NewRequirement(TemporaryInfraNodeLabel, selection.Exists, nil) + if err != nil { + return err + } + + originalNodeSelector := selector.Add(*requireInfra, *requireNotTempNode) + tempNodeSelector := selector.Add(*requireInfra, *requireTempNode) + + originalNodes := &corev1.NodeList{} + if err := clients.ClusterClient.List(ctx, originalNodes, &client.ListOptions{LabelSelector: originalNodeSelector}); err != nil { + return err + } + + // Delete original machinepool + log.Printf("deleting original machinepool %s", originalMp.Name) + if err := clients.HiveAdmin.Delete(ctx, originalMp); err != nil { + return err + } + + // Wait for original machinepool to delete + if err := waitForMachinePoolDeletion(ctx, clients.HiveClient, originalMp); err != nil { + return err + } + + // Wait for original nodes to delete + if err := waitForNodesDeletion(ctx, clients.ClusterClient, originalNodeSelector, onTimeout, originalNodes); err != nil { + return err + } + + // Create new permanent machinepool + log.Printf("creating new permanent machinepool %s", newMp.Name) + if err := clients.HiveAdmin.Create(ctx, newMp); err != nil { + return err + } + + // Wait for new permanent machines to become nodes + pollCtx2, cancel2 := context.WithTimeout(ctx, pollTimeout) + defer cancel2() + if err := wait.PollUntilContextTimeout(pollCtx2, pollInterval, pollTimeout, true, func(ctx context.Context) (bool, error) { + nodes := &corev1.NodeList{} + infraSelector, err := labels.Parse("node-role.kubernetes.io/infra=") + if err != nil { + return false, err + } + if err := clients.ClusterClient.List(ctx, nodes, &client.ListOptions{LabelSelector: infraSelector}); err != nil { + log.Printf("error retrieving nodes list, continuing to wait: %s", err) + return false, nil + } + + readyNodes := countReadyNodes(nodes) + expected := 
int(*originalMp.Spec.Replicas) * 2 + log.Printf("waiting for %d infra nodes to be reporting Ready, found %d", expected, readyNodes) + + return readyNodes >= expected, nil + }); err != nil { + return err + } + + tempNodes := &corev1.NodeList{} + if err := clients.ClusterClient.List(ctx, tempNodes, &client.ListOptions{LabelSelector: tempNodeSelector}); err != nil { + return err + } + + // Delete temp machinepool + log.Printf("deleting temporary machinepool %s", tempMp.Name) + if err := clients.HiveAdmin.Delete(ctx, tempMp); err != nil { + return err + } + + // Wait for temporary machinepool to delete + if err := waitForMachinePoolDeletion(ctx, clients.HiveClient, tempMp); err != nil { + return err + } + + // Wait for infra node count to return to normal + log.Printf("waiting for infra node count to return to: %d", int(*originalMp.Spec.Replicas)) + pollCtx3, cancel3 := context.WithTimeout(ctx, pollTimeout) + defer cancel3() + if err := wait.PollUntilContextTimeout(pollCtx3, pollInterval, pollTimeout, true, func(ctx context.Context) (bool, error) { + nodes := &corev1.NodeList{} + infraSelector, err := labels.Parse("node-role.kubernetes.io/infra=") + if err != nil { + return false, err + } + if err := clients.ClusterClient.List(ctx, nodes, &client.ListOptions{LabelSelector: infraSelector}); err != nil { + log.Printf("error retrieving nodes list, continuing to wait: %s", err) + return false, nil + } + + switch len(nodes.Items) { + case int(*originalMp.Spec.Replicas): + log.Printf("found %d infra nodes, replacement complete", len(nodes.Items)) + return true, nil + default: + log.Printf("found %d infra nodes, continuing to wait", len(nodes.Items)) + return false, nil + } + }); err != nil { + if wait.Interrupted(err) && onTimeout != nil { + log.Printf("Warning: timed out waiting for nodes to drain: %v. 
Terminating backing cloud instances.", err.Error()) + if err := onTimeout(ctx, tempNodes); err != nil { + return err + } + if err := waitForNodesGone(ctx, clients.ClusterClient, tempNodeSelector); err != nil { + return err + } + } else { + return err + } + } + + return nil +} + +func waitForMachinePoolDeletion(ctx context.Context, hiveClient client.Client, mp *hivev1.MachinePool) error { + pollCtx, cancel := context.WithTimeout(ctx, pollTimeout) + defer cancel() + return wait.PollUntilContextTimeout(pollCtx, pollInterval, pollTimeout, true, func(ctx context.Context) (bool, error) { + existing := &hivev1.MachinePool{} + err := hiveClient.Get(ctx, client.ObjectKey{Namespace: mp.Namespace, Name: mp.Name}, existing) + if err != nil { + if apierrors.IsNotFound(err) { + return true, nil + } + log.Printf("error checking machinepool %s/%s, continuing to wait: %s", mp.Namespace, mp.Name, err) + return false, nil + } + log.Printf("machinepool %s/%s still exists, continuing to wait", mp.Namespace, mp.Name) + return false, nil + }) +} + +func waitForNodesDeletion(ctx context.Context, clusterClient client.Client, selector labels.Selector, onTimeout func(ctx context.Context, nodes *corev1.NodeList) error, originalNodes *corev1.NodeList) error { + pollCtx, cancel := context.WithTimeout(ctx, pollTimeout) + defer cancel() + if err := wait.PollUntilContextTimeout(pollCtx, pollInterval, pollTimeout, true, func(ctx context.Context) (bool, error) { + return nodesMatchExpectedCount(ctx, clusterClient, selector, 0) + }); err != nil { + if wait.Interrupted(err) && onTimeout != nil { + log.Printf("Warning: timed out waiting for nodes to drain: %v. 
Terminating backing cloud instances.", err.Error()) + if err := onTimeout(ctx, originalNodes); err != nil { + return err + } + return waitForNodesGone(ctx, clusterClient, selector) + } + return err + } + return nil +} + +func waitForNodesGone(ctx context.Context, clusterClient client.Client, selector labels.Selector) error { + pollCtx, cancel := context.WithTimeout(ctx, pollTimeout) + defer cancel() + return wait.PollUntilContextTimeout(pollCtx, pollInterval, pollTimeout, true, func(ctx context.Context) (bool, error) { + log.Printf("waiting for nodes to terminate") + match, err := nodesMatchExpectedCount(ctx, clusterClient, selector, 0) + if err != nil { + log.Printf("error matching expected count, continuing to wait: %s", err) + return false, nil + } + return match, nil + }) +} + +func nodesMatchExpectedCount(ctx context.Context, clusterClient client.Client, labelSelector labels.Selector, count int) (bool, error) { + nodeList := &corev1.NodeList{} + if err := clusterClient.List(ctx, nodeList, &client.ListOptions{LabelSelector: labelSelector}); err != nil { + return false, err + } + return len(nodeList.Items) == count, nil +} + +func countReadyNodes(nodes *corev1.NodeList) int { + ready := 0 + for _, node := range nodes.Items { + for _, cond := range node.Status.Conditions { + if cond.Type == corev1.NodeReady && cond.Status == corev1.ConditionTrue { + ready++ + log.Printf("found node %s reporting Ready", node.Name) + } + } + } + return ready +} diff --git a/pkg/infra/machinepool_test.go b/pkg/infra/machinepool_test.go new file mode 100644 index 000000000..92eecf287 --- /dev/null +++ b/pkg/infra/machinepool_test.go @@ -0,0 +1,350 @@ +package infra + +import ( + "context" + "fmt" + "testing" + + hivev1 "github.com/openshift/hive/apis/hive/v1" + hivev1aws "github.com/openshift/hive/apis/hive/v1/aws" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/meta" + metav1 
"k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +// MockClient is a mock implementation of the client.Client interface +type MockClient struct { + mock.Mock +} + +func (m *MockClient) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error { + args := m.Called(ctx, list, opts) + return args.Error(0) +} + +func (m *MockClient) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + args := m.Called(ctx, key, obj, opts) + return args.Error(0) +} + +func (m *MockClient) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error { + args := m.Called(ctx, obj, opts) + return args.Error(0) +} + +func (m *MockClient) Delete(ctx context.Context, obj client.Object, opts ...client.DeleteOption) error { + args := m.Called(ctx, obj, opts) + return args.Error(0) +} + +func (m *MockClient) Update(ctx context.Context, obj client.Object, opts ...client.UpdateOption) error { + args := m.Called(ctx, obj, opts) + return args.Error(0) +} + +func (m *MockClient) Patch(ctx context.Context, obj client.Object, patch client.Patch, opts ...client.PatchOption) error { + args := m.Called(ctx, obj, patch, opts) + return args.Error(0) +} + +func (m *MockClient) DeleteAllOf(ctx context.Context, obj client.Object, opts ...client.DeleteAllOfOption) error { + args := m.Called(ctx, obj, opts) + return args.Error(0) +} + +func (m *MockClient) GroupVersionKindFor(obj runtime.Object) (schema.GroupVersionKind, error) { + args := m.Called(obj) + return args.Get(0).(schema.GroupVersionKind), args.Error(1) +} + +func (m *MockClient) IsObjectNamespaced(obj runtime.Object) (bool, error) { + args := m.Called(obj) + return args.Bool(0), args.Error(1) +} + +func (m *MockClient) RESTMapper() meta.RESTMapper { + args := m.Called() + return 
args.Get(0).(meta.RESTMapper) +} + +func (m *MockClient) Scheme() *runtime.Scheme { + args := m.Called() + return args.Get(0).(*runtime.Scheme) +} + +func (m *MockClient) Status() client.StatusWriter { + args := m.Called() + return args.Get(0).(client.StatusWriter) +} + +func (m *MockClient) SubResource(subResource string) client.SubResourceClient { + args := m.Called(subResource) + return args.Get(0).(client.SubResourceClient) +} + +func TestGetInfraMachinePool(t *testing.T) { + testNamespace := &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-namespace", + Labels: map[string]string{ + "api.openshift.com/id": "test-cluster", + }, + }, + } + + testMachinePool := &hivev1.MachinePool{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-infra", + Namespace: "test-namespace", + }, + Spec: hivev1.MachinePoolSpec{ + Name: "infra", + Platform: hivev1.MachinePoolPlatform{ + AWS: &hivev1aws.MachinePoolPlatform{ + InstanceType: "r5.xlarge", + }, + }, + }, + } + + mockHive := &MockClient{} + mockHive.On("List", mock.Anything, mock.MatchedBy(func(obj interface{}) bool { + _, ok := obj.(*corev1.NamespaceList) + return ok + }), mock.Anything).Return(nil).Run(func(args mock.Arguments) { + nsList := args.Get(1).(*corev1.NamespaceList) + nsList.Items = []corev1.Namespace{*testNamespace} + }) + + mockHive.On("List", mock.Anything, mock.MatchedBy(func(obj interface{}) bool { + _, ok := obj.(*hivev1.MachinePoolList) + return ok + }), mock.Anything).Return(nil).Run(func(args mock.Arguments) { + mpList := args.Get(1).(*hivev1.MachinePoolList) + mpList.Items = []hivev1.MachinePool{*testMachinePool} + }) + + mp, err := GetInfraMachinePool(context.Background(), mockHive, "test-cluster") + + assert.NoError(t, err) + assert.NotNil(t, mp) + assert.Equal(t, "infra", mp.Spec.Name) + assert.Equal(t, "r5.xlarge", mp.Spec.Platform.AWS.InstanceType) + mockHive.AssertExpectations(t) +} + +func TestGetInfraMachinePoolNoNamespace(t *testing.T) { + mockHive := &MockClient{} + 
mockHive.On("List", mock.Anything, mock.MatchedBy(func(obj interface{}) bool { + _, ok := obj.(*corev1.NamespaceList) + return ok + }), mock.Anything).Return(nil).Run(func(args mock.Arguments) { + nsList := args.Get(1).(*corev1.NamespaceList) + nsList.Items = []corev1.Namespace{} + }) + + mp, err := GetInfraMachinePool(context.Background(), mockHive, "test-cluster") + + assert.Error(t, err) + assert.Contains(t, err.Error(), "expected 1 namespace, found 0") + assert.Nil(t, mp) + mockHive.AssertExpectations(t) +} + +func TestGetInfraMachinePoolNoInfraPool(t *testing.T) { + testNamespace := &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-namespace", + Labels: map[string]string{"api.openshift.com/id": "test-cluster"}, + }, + } + + mockHive := &MockClient{} + mockHive.On("List", mock.Anything, mock.MatchedBy(func(obj interface{}) bool { + _, ok := obj.(*corev1.NamespaceList) + return ok + }), mock.Anything).Return(nil).Run(func(args mock.Arguments) { + nsList := args.Get(1).(*corev1.NamespaceList) + nsList.Items = []corev1.Namespace{*testNamespace} + }) + + mockHive.On("List", mock.Anything, mock.MatchedBy(func(obj interface{}) bool { + _, ok := obj.(*hivev1.MachinePoolList) + return ok + }), mock.Anything).Return(nil).Run(func(args mock.Arguments) { + mpList := args.Get(1).(*hivev1.MachinePoolList) + mpList.Items = []hivev1.MachinePool{ + { + ObjectMeta: metav1.ObjectMeta{Name: "test-cluster-worker", Namespace: "test-namespace"}, + Spec: hivev1.MachinePoolSpec{Name: "worker"}, + }, + } + }) + + mp, err := GetInfraMachinePool(context.Background(), mockHive, "test-cluster") + + assert.Error(t, err) + assert.Contains(t, err.Error(), "did not find the infra machinepool") + assert.Nil(t, mp) + mockHive.AssertExpectations(t) +} + +func TestCloneMachinePool(t *testing.T) { + originalMp := &hivev1.MachinePool{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-infra", + Namespace: "test-namespace", + ResourceVersion: "12345", + Generation: 3, + UID: 
"abc-123", + Finalizers: []string{"hive.openshift.io/machinepool"}, + }, + Spec: hivev1.MachinePoolSpec{ + Name: "infra", + Replicas: int64Ptr(2), + Labels: map[string]string{ + "node-role.kubernetes.io/infra": "", + }, + Platform: hivev1.MachinePoolPlatform{ + AWS: &hivev1aws.MachinePoolPlatform{ + InstanceType: "r5.xlarge", + EC2RootVolume: hivev1aws.EC2RootVolume{ + IOPS: 3000, + Size: 300, + Type: "io1", + }, + }, + }, + }, + } + + t.Run("success - changes volume type", func(t *testing.T) { + result, err := CloneMachinePool(originalMp, func(mp *hivev1.MachinePool) error { + mp.Spec.Platform.AWS.Type = "gp3" + mp.Spec.Platform.AWS.IOPS = 0 + return nil + }) + assert.NoError(t, err) + assert.NotNil(t, result) + + // Verify modifier was applied + assert.Equal(t, "gp3", result.Spec.Platform.AWS.Type) + assert.Equal(t, 0, result.Spec.Platform.AWS.IOPS) + + // Verify metadata was reset + assert.Empty(t, result.ResourceVersion) + assert.Equal(t, int64(0), result.Generation) + assert.Empty(t, string(result.UID)) + assert.Empty(t, result.Finalizers) + assert.Equal(t, metav1.Time{}, result.CreationTimestamp) + + // Verify other fields preserved + assert.Equal(t, "test-cluster-infra", result.Name) + assert.Equal(t, "test-namespace", result.Namespace) + assert.Equal(t, "infra", result.Spec.Name) + assert.Equal(t, int64(2), *result.Spec.Replicas) + assert.Equal(t, "r5.xlarge", result.Spec.Platform.AWS.InstanceType) + assert.Equal(t, 300, result.Spec.Platform.AWS.Size) + + // Verify original is unchanged + assert.Equal(t, "io1", originalMp.Spec.Platform.AWS.Type) + assert.Equal(t, 3000, originalMp.Spec.Platform.AWS.IOPS) + }) + + t.Run("nil modifier just clones", func(t *testing.T) { + result, err := CloneMachinePool(originalMp, nil) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Empty(t, result.ResourceVersion) + assert.Equal(t, "io1", result.Spec.Platform.AWS.Type) + }) + + t.Run("modifier error is propagated", func(t *testing.T) { + result, err := 
CloneMachinePool(originalMp, func(mp *hivev1.MachinePool) error { + return fmt.Errorf("infra volumes are already gp3") + }) + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "infra volumes are already gp3") + }) +} + +func TestNodesMatchExpectedCount(t *testing.T) { + tests := []struct { + name string + expectedCount int + mockNodeList *corev1.NodeList + mockListError error + expectedMatch bool + expectedError error + }{ + { + name: "matching count", + expectedCount: 2, + mockNodeList: &corev1.NodeList{ + Items: []corev1.Node{ + {ObjectMeta: metav1.ObjectMeta{Name: "node1"}}, + {ObjectMeta: metav1.ObjectMeta{Name: "node2"}}, + }, + }, + expectedMatch: true, + }, + { + name: "non-matching count", + expectedCount: 2, + mockNodeList: &corev1.NodeList{ + Items: []corev1.Node{ + {ObjectMeta: metav1.ObjectMeta{Name: "node1"}}, + }, + }, + expectedMatch: false, + }, + { + name: "list error", + expectedCount: 2, + mockListError: fmt.Errorf("list error"), + expectedError: fmt.Errorf("list error"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + mockClient := &MockClient{} + mockClient.On("List", mock.Anything, mock.Anything, mock.Anything). + Return(test.mockListError). + Run(func(args mock.Arguments) { + if test.mockNodeList != nil { + arg := args.Get(1).(*corev1.NodeList) + *arg = *test.mockNodeList + } + }) + + match, err := nodesMatchExpectedCount(context.Background(), mockClient, labels.NewSelector(), test.expectedCount) + + if test.expectedError != nil { + if err == nil || err.Error() != test.expectedError.Error() { + t.Errorf("expected error %v, got %v", test.expectedError, err) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if match != test.expectedMatch { + t.Errorf("expected match %v, got %v", test.expectedMatch, match) + } + } + + mockClient.AssertExpectations(t) + }) + } +} + +func int64Ptr(i int64) *int64 { + return &i +}