diff --git a/internal/ballot/ballot_test.go b/internal/ballot/ballot_test.go index 7a5017a..2605f7e 100644 --- a/internal/ballot/ballot_test.go +++ b/internal/ballot/ballot_test.go @@ -2356,3 +2356,560 @@ func TestHandleServiceCriticalState_WarningState(t *testing.T) { err := b.handleServiceCriticalState() assert.NoError(t, err) } + +func TestConsulClientWrappers(t *testing.T) { + // Test consulClient wrapper methods + // These tests verify that the wrapper methods correctly delegate to the underlying api.Client + // Note: These tests use a nil api.Client because we're testing the wrapper structure, + // not the actual Consul API calls + + t.Run("consulClient Agent returns AgentWrapper", func(t *testing.T) { + // Create a consul client with default config (will work even without running Consul) + apiClient, err := api.NewClient(api.DefaultConfig()) + assert.NoError(t, err) + + cc := &consulClient{client: apiClient} + agent := cc.Agent() + assert.NotNil(t, agent) + + // Verify it returns an AgentWrapper + _, ok := agent.(*AgentWrapper) + assert.True(t, ok, "Agent() should return *AgentWrapper") + }) + + t.Run("consulClient Catalog returns CatalogWrapper", func(t *testing.T) { + apiClient, err := api.NewClient(api.DefaultConfig()) + assert.NoError(t, err) + + cc := &consulClient{client: apiClient} + catalog := cc.Catalog() + assert.NotNil(t, catalog) + + _, ok := catalog.(*CatalogWrapper) + assert.True(t, ok, "Catalog() should return *CatalogWrapper") + }) + + t.Run("consulClient Health returns HealthWrapper", func(t *testing.T) { + apiClient, err := api.NewClient(api.DefaultConfig()) + assert.NoError(t, err) + + cc := &consulClient{client: apiClient} + health := cc.Health() + assert.NotNil(t, health) + + _, ok := health.(*HealthWrapper) + assert.True(t, ok, "Health() should return *HealthWrapper") + }) + + t.Run("consulClient Session returns SessionWrapper", func(t *testing.T) { + apiClient, err := api.NewClient(api.DefaultConfig()) + assert.NoError(t, err) + + cc := &consulClient{client: apiClient} + session := cc.Session() + assert.NotNil(t, session) + + _, ok := session.(*SessionWrapper) + assert.True(t, ok, "Session() should return *SessionWrapper") + }) + + t.Run("consulClient KV returns KVWrapper", func(t *testing.T) { + apiClient, err := api.NewClient(api.DefaultConfig()) + assert.NoError(t, err) + + cc := &consulClient{client: apiClient} + kv := cc.KV() + assert.NotNil(t, kv) + + _, ok := kv.(*KVWrapper) + assert.True(t, ok, "KV() should return *KVWrapper") + }) +} + +func TestCommandExecutor(t *testing.T) { + t.Run("CommandContext creates exec.Cmd", func(t *testing.T) { + executor := &commandExecutor{} + ctx := context.Background() + + cmd := executor.CommandContext(ctx, "echo", "hello") + assert.NotNil(t, cmd) + // Path is resolved to full path by exec.LookPath, check it contains "echo" + assert.Contains(t, cmd.Path, "echo") + assert.Contains(t, cmd.Args, "echo") + assert.Contains(t, cmd.Args, "hello") + }) + + t.Run("CommandContext with no args", func(t *testing.T) { + executor := &commandExecutor{} + ctx := context.Background() + + cmd := executor.CommandContext(ctx, "true") + assert.NotNil(t, cmd) + }) +} + +func TestHealthWrapper_Checks(t *testing.T) { + // Test HealthWrapper directly with a mock Health + apiClient, err := api.NewClient(api.DefaultConfig()) + assert.NoError(t, err) + + hw := &HealthWrapper{health: apiClient.Health()} + assert.NotNil(t, hw) + + // We can't actually call Checks without a running Consul, + // but we verify the wrapper is correctly structured +} + +func TestUpdateServiceTags_WithCommandExecution(t *testing.T) { + t.Run("Executes ExecOnPromote when becoming leader", func(t *testing.T) { + serviceID := "test_service_id" + serviceName := "test_service" + primaryTag := "primary" + sessionID := "session_id" + + mockAgent := new(MockAgent) + mockCatalog := new(MockCatalog) + mockKV := new(MockKV) + mockClient := &MockConsulClient{} + + // Service without primary tag + baseService := &api.AgentService{ + ID: serviceID, + Service: serviceName, + Tags: []string{"tag1"}, + Port: 8080, + Address: "127.0.0.1", + } + + mockAgent.On("Service", serviceID, mock.Anything).Return(baseService, nil, nil) + mockCatalog.On("Service", serviceName, primaryTag, mock.Anything).Return([]*api.CatalogService{}, nil, nil) + mockAgent.On("ServiceRegister", mock.Anything).Return(nil) + + // Mock KV for session data retrieval + payload := &ElectionPayload{ + Address: "127.0.0.1", + Port: 8080, + SessionID: sessionID, + } + data, _ := json.Marshal(payload) + mockKV.On("Get", "election/test/leader", mock.Anything).Return(&api.KVPair{ + Key: "election/test/leader", + Value: data, + }, nil, nil) + + mockClient.On("Agent").Return(mockAgent) + mockClient.On("Catalog").Return(mockCatalog) + mockClient.On("KV").Return(mockKV) + + // Use a mock executor that tracks calls + mockExecutor := new(MockCommandExecutor) + mockCmd := exec.Command("echo", "promoted") + mockExecutor.On("CommandContext", mock.Anything, "echo", []string{"promoted"}).Return(mockCmd) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + b := &Ballot{ + client: mockClient, + ID: serviceID, + Name: serviceName, + PrimaryTag: primaryTag, + Key: "election/test/leader", + ctx: ctx, + ExecOnPromote: "echo promoted", + executor: mockExecutor, + TTL: 10 * time.Second, + LockDelay: 3 * time.Second, + } + b.sessionID.Store(&sessionID) + + err := b.updateServiceTags(true) + assert.NoError(t, err) + + // Give goroutine time to execute + time.Sleep(1500 * time.Millisecond) + + mockAgent.AssertCalled(t, "ServiceRegister", mock.Anything) + }) + + t.Run("Executes ExecOnDemote when losing leadership", func(t *testing.T) { + serviceID := "test_service_id" + serviceName := "test_service" + primaryTag := "primary" + sessionID := "session_id" + + mockAgent := new(MockAgent) + mockCatalog := new(MockCatalog) + mockKV := new(MockKV) + mockClient := &MockConsulClient{} + + // Service with primary tag + serviceWithTag := &api.AgentService{ + ID: serviceID, + Service: serviceName, + Tags: []string{"tag1", primaryTag}, + Port: 8080, + Address: "127.0.0.1", + } + + mockAgent.On("Service", serviceID, mock.Anything).Return(serviceWithTag, nil, nil) + mockCatalog.On("Service", serviceName, primaryTag, mock.Anything).Return([]*api.CatalogService{}, nil, nil) + mockAgent.On("ServiceRegister", mock.Anything).Return(nil) + + // Mock KV for session data retrieval + payload := &ElectionPayload{ + Address: "127.0.0.1", + Port: 8080, + SessionID: sessionID, + } + data, _ := json.Marshal(payload) + mockKV.On("Get", "election/test/leader", mock.Anything).Return(&api.KVPair{ + Key: "election/test/leader", + Value: data, + }, nil, nil) + + mockClient.On("Agent").Return(mockAgent) + mockClient.On("Catalog").Return(mockCatalog) + mockClient.On("KV").Return(mockKV) + + // Use a mock executor + mockExecutor := new(MockCommandExecutor) + mockCmd := exec.Command("echo", "demoted") + mockExecutor.On("CommandContext", mock.Anything, "echo", []string{"demoted"}).Return(mockCmd) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + b := &Ballot{ + client: mockClient, + ID: serviceID, + Name: serviceName, + PrimaryTag: primaryTag, + Key: "election/test/leader", + ctx: ctx, + ExecOnDemote: "echo demoted", + executor: mockExecutor, + TTL: 10 * time.Second, + LockDelay: 3 * time.Second, + } + b.sessionID.Store(&sessionID) + + err := b.updateServiceTags(false) + assert.NoError(t, err) + + // Give goroutine time to execute + time.Sleep(1500 * time.Millisecond) + + mockAgent.AssertCalled(t, "ServiceRegister", mock.Anything) + }) + + t.Run("Handles command execution error gracefully", func(t *testing.T) { + serviceID := "test_service_id" + serviceName := "test_service" + primaryTag := "primary" + sessionID := "session_id" + + mockAgent := new(MockAgent) + mockCatalog := new(MockCatalog) + mockKV := new(MockKV) + mockClient := &MockConsulClient{} + + baseService := &api.AgentService{ + ID: serviceID, + Service: serviceName, + Tags: []string{"tag1"}, + Port: 8080, + Address: "127.0.0.1", + } + + mockAgent.On("Service", serviceID, mock.Anything).Return(baseService, nil, nil) + mockCatalog.On("Service", serviceName, primaryTag, mock.Anything).Return([]*api.CatalogService{}, nil, nil) + mockAgent.On("ServiceRegister", mock.Anything).Return(nil) + + // Mock KV for session data retrieval + payload := &ElectionPayload{ + Address: "127.0.0.1", + Port: 8080, + SessionID: sessionID, + } + data, _ := json.Marshal(payload) + mockKV.On("Get", "election/test/leader", mock.Anything).Return(&api.KVPair{ + Key: "election/test/leader", + Value: data, + }, nil, nil) + + mockClient.On("Agent").Return(mockAgent) + mockClient.On("Catalog").Return(mockCatalog) + mockClient.On("KV").Return(mockKV) + + // Use a mock executor that returns a failing command + mockExecutor := new(MockCommandExecutor) + mockCmd := exec.Command("false") // 'false' command exits with code 1 + mockExecutor.On("CommandContext", mock.Anything, "false", []string{}).Return(mockCmd) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + b := &Ballot{ + client: mockClient, + ID: serviceID, + Name: serviceName, + PrimaryTag: primaryTag, + Key: "election/test/leader", + ctx: ctx, + ExecOnPromote: "false", + executor: mockExecutor, + TTL: 10 * time.Second, + LockDelay: 3 * time.Second, + } + b.sessionID.Store(&sessionID) + + // Should not return error even if command fails + err := b.updateServiceTags(true) + assert.NoError(t, err) + + // Give goroutine time to execute + time.Sleep(1500 * time.Millisecond) + }) + + t.Run("Handles session data retrieval error in command goroutine", func(t *testing.T) { + serviceID := "test_service_id" + serviceName := "test_service" + primaryTag := "primary" + sessionID := "session_id" + + mockAgent := new(MockAgent) + mockCatalog := new(MockCatalog) + mockKV := new(MockKV) + mockClient := &MockConsulClient{} + + baseService := &api.AgentService{ + ID: serviceID, + Service: serviceName, + Tags: []string{"tag1"}, + Port: 8080, + Address: "127.0.0.1", + } + + mockAgent.On("Service", serviceID, mock.Anything).Return(baseService, nil, nil) + mockCatalog.On("Service", serviceName, primaryTag, mock.Anything).Return([]*api.CatalogService{}, nil, nil) + mockAgent.On("ServiceRegister", mock.Anything).Return(nil) + + // Mock KV to return error + mockKV.On("Get", "election/test/leader", mock.Anything).Return(nil, nil, errors.New("kv error")) + + mockClient.On("Agent").Return(mockAgent) + mockClient.On("Catalog").Return(mockCatalog) + mockClient.On("KV").Return(mockKV) + + mockExecutor := new(MockCommandExecutor) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + b := &Ballot{ + client: mockClient, + ID: serviceID, + Name: serviceName, + PrimaryTag: primaryTag, + Key: "election/test/leader", + ctx: ctx, + ExecOnPromote: "echo test", + executor: mockExecutor, + TTL: 10 * time.Millisecond, + LockDelay: 3 * time.Millisecond, + } + b.sessionID.Store(&sessionID) + + err := b.updateServiceTags(true) + assert.NoError(t, err) + + // Wait for goroutine to handle the error + time.Sleep(200 * time.Millisecond) + }) +} + +func TestRun_SmallTTL(t *testing.T) { + t.Run("Run with TTL less than 2 seconds uses 1 second interval", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + sessionID := "session_id" + b := &Ballot{ + ID: "test_service_id", + Name: "test_service", + Key: "election/test_service/leader", + PrimaryTag: "primary", + TTL: 500 * time.Millisecond, // TTL/2 = 250ms < 1s, so interval becomes 1s + ctx: ctx, + } + b.sessionID.Store(&sessionID) + + mockHealth := new(MockHealth) + mockHealth.On("Checks", b.Name, mock.Anything).Return([]*api.HealthCheck{ + {Status: "passing"}, + }, nil, nil) + + mockSession := new(MockSession) + mockSession.On("Create", mock.Anything, mock.Anything).Return(sessionID, nil, nil) + mockSession.On("RenewPeriodic", mock.Anything, sessionID, mock.Anything, mock.Anything).Return(nil) + mockSession.On("Info", sessionID, mock.Anything).Return(&api.SessionEntry{ID: sessionID}, &api.QueryMeta{}, nil) + + payload := &ElectionPayload{ + Address: "127.0.0.1", + Port: 8080, + SessionID: sessionID, + } + data, _ := json.Marshal(payload) + mockKV := new(MockKV) + mockKV.On("Acquire", mock.Anything, mock.Anything).Return(true, nil, nil) + mockKV.On("Get", b.Key, mock.Anything).Return(&api.KVPair{ + Key: b.Key, + Value: data, + Session: sessionID, + }, nil, nil) + + service := &api.AgentService{ + ID: b.ID, + Service: b.Name, + Address: "127.0.0.1", + Port: 8080, + Tags: []string{}, + } + mockAgent := new(MockAgent) + mockAgent.On("Service", b.ID, mock.Anything).Return(service, nil, nil) + mockAgent.On("ServiceRegister", mock.Anything).Return(nil) + + mockCatalog := new(MockCatalog) + mockCatalog.On("Service", b.Name, b.PrimaryTag, mock.Anything).Return([]*api.CatalogService{}, nil, nil) + mockCatalog.On("Service", b.Name, "", mock.Anything).Return([]*api.CatalogService{}, nil, nil) + mockCatalog.On("Register", mock.Anything, mock.Anything).Return(nil, nil) + + mockClient := &MockConsulClient{} + mockClient.On("Health").Return(mockHealth) + mockClient.On("Session").Return(mockSession) + mockClient.On("KV").Return(mockKV) + mockClient.On("Agent").Return(mockAgent) + mockClient.On("Catalog").Return(mockCatalog) + + b.client = mockClient + + done := make(chan error, 1) + go func() { + done <- b.Run() + }() + + // Let it run briefly then cancel + time.Sleep(100 * time.Millisecond) + cancel() + + err := <-done + assert.NoError(t, err) + }) +} + +func TestRun_ElectionErrorInLoop(t *testing.T) { + t.Run("Run handles election errors in ticker loop", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + sessionID := "session_id" + b := &Ballot{ + ID: "test_service_id", + Name: "test_service", + Key: "election/test_service/leader", + PrimaryTag: "primary", + TTL: 100 * time.Millisecond, + ctx: ctx, + } + b.sessionID.Store(&sessionID) + + // Mock health to return error (triggers election error) + mockHealth := new(MockHealth) + electionErr := errors.New("health check failed") + mockHealth.On("Checks", b.Name, mock.Anything).Return(nil, nil, electionErr) + + mockClient := &MockConsulClient{} + mockClient.On("Health").Return(mockHealth) + + b.client = mockClient + + done := make(chan error, 1) + go func() { + done <- b.Run() + }() + + // Let it run through at least one ticker cycle with error + time.Sleep(200 * time.Millisecond) + cancel() + + err := <-done + assert.NoError(t, err) // Run returns nil on context cancellation, errors are logged + }) +} + +func TestUpdateLeadershipStatus_Error(t *testing.T) { + t.Run("updateLeadershipStatus returns error from updateServiceTags", func(t *testing.T) { + mockAgent := new(MockAgent) + expectedErr := errors.New("service tags update failed") + mockAgent.On("Service", "test_id", mock.Anything).Return(nil, nil, expectedErr) + + mockClient := &MockConsulClient{} + mockClient.On("Agent").Return(mockAgent) + + b := &Ballot{ + ID: "test_id", + client: mockClient, + } + + err := b.updateLeadershipStatus(true) + assert.Error(t, err) + assert.Equal(t, expectedErr, err) + }) +} + +func TestNew_ViperUnmarshalError(t *testing.T) { + t.Run("New returns error when viper unmarshal fails", func(t *testing.T) { + viper.Reset() + // Set an invalid type that will cause unmarshal to fail + // TTL expects a duration string but we give it an invalid map + viper.Set("election.services.badconfig.ttl", map[string]string{"invalid": "type"}) + viper.Set("election.services.badconfig.id", "test_id") + viper.Set("election.services.badconfig.key", "test/key") + + defer viper.Reset() + + b, err := New(context.Background(), "badconfig") + assert.Error(t, err) + assert.Nil(t, b) + }) +} + +func TestNew_DefaultValues(t *testing.T) { + t.Run("New sets default LockDelay and TTL when not specified", func(t *testing.T) { + viper.Reset() + viper.Set("election.services.defaults.id", "test_service_id") + viper.Set("election.services.defaults.key", "election/test/leader") + // Don't set TTL or LockDelay + + defer viper.Reset() + + b, err := New(context.Background(), "defaults") + assert.NoError(t, err) + assert.NotNil(t, b) + assert.Equal(t, 3*time.Second, b.LockDelay) + assert.Equal(t, 10*time.Second, b.TTL) + }) + + t.Run("New sets Name from parameter when not in config", func(t *testing.T) { + viper.Reset() + viper.Set("election.services.myservice.id", "test_service_id") + viper.Set("election.services.myservice.key", "election/test/leader") + // Don't set Name + + defer viper.Reset() + + b, err := New(context.Background(), "myservice") + assert.NoError(t, err) + assert.NotNil(t, b) + assert.Equal(t, "myservice", b.Name) + }) +}