From b39f3760efbb38970ae45add43161df1d955833f Mon Sep 17 00:00:00 2001 From: Torben Schmitz Date: Mon, 16 Feb 2026 08:10:09 -0800 Subject: [PATCH] Fix bug where the client would always exit ungracefully after receiving a signal PiperOrigin-RevId: 870898716 --- cmd/fleetspeak_client/fleetspeak_client.go | 51 +++++++++++----------- fleetspeak/src/client/entry/entry.go | 19 ++++++++ 2 files changed, 45 insertions(+), 25 deletions(-) diff --git a/cmd/fleetspeak_client/fleetspeak_client.go b/cmd/fleetspeak_client/fleetspeak_client.go index 5ebf84d0..d09b03d2 100644 --- a/cmd/fleetspeak_client/fleetspeak_client.go +++ b/cmd/fleetspeak_client/fleetspeak_client.go @@ -7,7 +7,6 @@ import ( "os" "time" - log "github.com/golang/glog" "google.golang.org/protobuf/encoding/prototext" "github.com/google/fleetspeak/fleetspeak/src/client" @@ -20,6 +19,7 @@ import ( "github.com/google/fleetspeak/fleetspeak/src/client/socketservice" "github.com/google/fleetspeak/fleetspeak/src/client/stats" "github.com/google/fleetspeak/fleetspeak/src/client/stdinservice" + "github.com/google/fleetspeak/fleetspeak/src/common/fscontext" gpb "github.com/google/fleetspeak/fleetspeak/src/client/generic/proto/fleetspeak_client_generic" ) @@ -29,42 +29,34 @@ const stopTimeout = time.Minute var configFile = flag.String("config", "", "Client configuration file, required.") func innerMain(ctx context.Context, cfgReloadSignals <-chan os.Signal) error { - for { - cl, err := createClient() + for ctx.Err() == nil { + ctx, cancel := entry.ContextWithSignals(ctx, cfgReloadSignals) + err := runClient(ctx) + cancel() if err != nil { - return fmt.Errorf("error starting client: %v", err) - } - - select { - case <-cfgReloadSignals: - // We implement config reloading by tearing down the client and creating a - // new one. - log.Info("Config reload requested") - time.AfterFunc(stopTimeout, func() { - entry.ExitUngracefully(fmt.Errorf("client did not stop within %s", stopTimeout)) - }) - cl.Stop() - continue - case <-ctx.Done(): - // A timeout for process termination is handled higher up. - cl.Stop() - return nil + return err } } + return nil } -func createClient() (*client.Client, error) { +func runClient(ctx context.Context) error { + stop := fscontext.AfterDelayFunc(ctx, stopTimeout, func() { + entry.ExitUngracefully(fmt.Errorf("client did not stop within %s", stopTimeout)) + }) + defer stop() + b, err := os.ReadFile(*configFile) if err != nil { - return nil, fmt.Errorf("unable to read configuration file %q: %v", *configFile, err) + return fmt.Errorf("unable to read configuration file %q: %v", *configFile, err) } cfgPB := &gpb.Config{} if err := prototext.Unmarshal(b, cfgPB); err != nil { - return nil, fmt.Errorf("unable to parse configuration file %q: %v", *configFile, err) + return fmt.Errorf("unable to parse configuration file %q: %v", *configFile, err) } cfg, err := generic.MakeConfiguration(cfgPB) if err != nil { - return nil, fmt.Errorf("error in configuration file: %v", err) + return fmt.Errorf("error in configuration file: %v", err) } var com comms.Communicator @@ -74,7 +66,7 @@ func createClient() (*client.Client, error) { com = &https.Communicator{} } - return client.New( + cl, err := client.New( cfg, client.Components{ ServiceFactories: map[string]service.Factory{ @@ -87,6 +79,15 @@ func createClient() (*client.Client, error) { Stats: stats.NoopCollector{}, }, ) + if err != nil { + return fmt.Errorf("error creating client: %v", err) + } + + select { + case <-ctx.Done(): + cl.Stop() + } + return nil } func main() { diff --git a/fleetspeak/src/client/entry/entry.go b/fleetspeak/src/client/entry/entry.go index 19b4d471..99f3d53b 100644 --- a/fleetspeak/src/client/entry/entry.go +++ b/fleetspeak/src/client/entry/entry.go @@ -6,6 +6,8 @@ import ( "context" "os" "time" + + log "github.com/golang/glog" ) // Timeout for shutting down gracefully. @@ -22,3 +24,20 @@ const shutdownTimeout = 10 * time.Second // is requested. We use UNIX conventions here, the Windows layer can send a // [syscall.SIGHUP] when appropriate. type InnerMain func(ctx context.Context, cfgReloadSignals <-chan os.Signal) error + +// ContextWithSignals returns a context that is canceled when a signal is +// received. +func ContextWithSignals(ctx context.Context, signals <-chan os.Signal) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(ctx) + go func() { + select { + case si, ok := <-signals: + if ok { + log.Infof("Signal received: %v", si) + cancel() + } + case <-ctx.Done(): + } + }() + return ctx, cancel +}