Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ package coreweave

import (
"fmt"
"sync"

apiv1 "k8s.io/api/core/v1"
"k8s.io/autoscaler/cluster-autoscaler/cloudprovider"
"k8s.io/autoscaler/cluster-autoscaler/config"
"k8s.io/autoscaler/cluster-autoscaler/simulator/framework"
"k8s.io/klog/v2"
"sync"
)

// CoreWeaveNodeGroup represents a node group in the CoreWeave cloud provider.
Expand Down Expand Up @@ -84,13 +84,6 @@ func (ng *CoreWeaveNodeGroup) DeleteNodes(nodes []*apiv1.Node) error {
if err != nil {
return fmt.Errorf("some nodes do not belong to node group %s: %v", ng.Name, err)
}
// If we reach here, it means we can delete the nodes
for _, node := range nodes {
// Mark the node for removal
if err := ng.nodepool.MarkNodeForRemoval(node); err != nil {
return fmt.Errorf("failed to mark node %s for removal: %v", node.Name, err)
}
}
//update target size
if err := ng.nodepool.SetSize(ng.nodepool.GetTargetSize() - len(nodes)); err != nil {
return fmt.Errorf("failed to update target size after marking nodes for removal: %v", err)
Expand All @@ -107,6 +100,9 @@ func (ng *CoreWeaveNodeGroup) ForceDeleteNodes(nodes []*apiv1.Node) error {
// DecreaseTargetSize decreases the target size of the node group by the specified delta.
func (ng *CoreWeaveNodeGroup) DecreaseTargetSize(delta int) error {
klog.V(4).Infof("Decreasing target size of node group %s by %d", ng.Name, delta)
if delta < 0 {
delta = -delta
}
return ng.nodepool.SetSize(ng.nodepool.GetTargetSize() - delta)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"context"
"testing"

"github.com/stretchr/testify/require"

apiv1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
Expand Down Expand Up @@ -109,18 +111,92 @@ func TestIncreaseSize(t *testing.T) {
}

func TestDeleteNodes(t *testing.T) {
ng := makeTestNodeGroup("ng-1", "uid-1", 0, 5, 3)
validNode := &apiv1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "node1",
Labels: map[string]string{coreWeaveNodePoolUID: "uid-1"},
initialTargetSize := int64(3)

testCases := map[string]struct {
nodesToDelete []*apiv1.Node
expectedTargetSize int
expectedError error
}{
"reduce-target-size-by-one-node": {
nodesToDelete: []*apiv1.Node{
{
ObjectMeta: metav1.ObjectMeta{
Name: "node1",
Labels: map[string]string{coreWeaveNodePoolUID: "uid-1"},
},
},
},
expectedTargetSize: 2,
},
"reduce-target-size-by-three-node": {
nodesToDelete: []*apiv1.Node{
{
ObjectMeta: metav1.ObjectMeta{
Name: "node1",
Labels: map[string]string{coreWeaveNodePoolUID: "uid-1"},
},
},
{
ObjectMeta: metav1.ObjectMeta{
Name: "node2",
Labels: map[string]string{coreWeaveNodePoolUID: "uid-1"},
},
},
{
ObjectMeta: metav1.ObjectMeta{
Name: "node3",
Labels: map[string]string{coreWeaveNodePoolUID: "uid-1"},
},
},
},
expectedTargetSize: 0,
},
}
nodes := []*apiv1.Node{
validNode,

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
ng := makeTestNodeGroup("ng-1", "uid-1", 0, 5, initialTargetSize)

err := ng.DeleteNodes(tc.nodesToDelete)
if tc.expectedError != nil {
require.Equal(t, tc.expectedError, err)
return
}
require.NoError(t, err)
require.Equal(t, ng.nodepool.GetTargetSize(), tc.expectedTargetSize)
})
}
err := ng.DeleteNodes(nodes)
if err != nil && err != cloudprovider.ErrNotImplemented {
t.Errorf("expected ErrNotImplemented or nil, got %v", err)
}

func TestDecreaseTargetSize(t *testing.T) {
testCases := map[string]struct {
delta int
expectedTargetSize int
expectedError error
}{
"positive-delta": {
delta: 2,
expectedTargetSize: 1,
},
"negative-delta": {
delta: -2,
expectedTargetSize: 1,
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
ng := makeTestNodeGroup("ng-1", "uid-1", 1, 5, 3)

err := ng.DecreaseTargetSize(tc.delta)
if tc.expectedError != nil {
require.Error(t, err)
require.Equal(t, tc.expectedError, err)
return
}
require.NoError(t, err)
require.Equal(t, tc.expectedTargetSize, ng.nodepool.GetTargetSize())
})
}
}
36 changes: 0 additions & 36 deletions cluster-autoscaler/cloudprovider/coreweave/coreweave_nodepool.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,42 +187,6 @@ func (np *CoreWeaveNodePool) SetSize(size int) error {
return nil
}

// MarkNodeForRemoval marks a node for removal from the node pool.
func (np *CoreWeaveNodePool) MarkNodeForRemoval(node *apiv1.Node) error {
ctx, cancel := GetCoreWeaveContext()
defer cancel()
if node == nil {
return fmt.Errorf("node cannot be nil")
}
if node.Name == "" {
return fmt.Errorf("node name cannot be empty")
}
// Log the node being marked for removal
klog.V(4).Infof("Marking node %s for removal from node pool %s", node.Name, np.GetName())
// Fetch the current node object
currentNode, err := np.client.CoreV1().Nodes().Get(ctx, node.Name, metav1.GetOptions{})
if err != nil {
return fmt.Errorf("failed to get node %s: %v", node.Name, err)
}
// Check if the node belongs to this node pool
if currentNode.Labels == nil || currentNode.Labels[coreWeaveNodePoolUID] != np.GetUID() {
return fmt.Errorf("node %s does not belong to node pool %s", node.Name, np.GetName())
}
// Check if the node is already marked for removal
if currentNode.Labels != nil && currentNode.Labels[coreWeaveRemoveNode] == "true" {
klog.V(4).Infof("Node %s is already marked for removal", currentNode.Name)
return nil // Node is already marked for removal, no action needed
}
// Set the label to indicate the node should be removed
currentNode.Labels[coreWeaveRemoveNode] = "true"
// Update the node using the client
_, err = np.client.CoreV1().Nodes().Update(ctx, currentNode, metav1.UpdateOptions{})
if err != nil {
return fmt.Errorf("failed to mark node %s for removal: %v", node.Name, err)
}
return nil
}

// ValidateNodes checks if the provided nodes belong to the node pool.
func (np *CoreWeaveNodePool) ValidateNodes(nodes []*apiv1.Node) error {
if len(nodes) == 0 {
Expand Down