From 589e0800f59eedcc8e04e8bc889c30510241dfd4 Mon Sep 17 00:00:00 2001 From: Chris Henzie Date: Mon, 23 Mar 2026 16:36:06 -0700 Subject: [PATCH] Implement zero-downtime NRI plugin upgrades Track duplicate NRI plugin connections during updates to prevent dropping container lifecycle events. Route requests only to newest connection, determined by socket connection time. Shadow older plugins and skip them in event routing. Remove older plugins when they disconnect (e.g., during DaemonSet rollout). Tested by running make test. Assisted-by: Antigravity Signed-off-by: Chris Henzie --- pkg/adaptation/adaptation.go | 66 +++++++++++++++++++++-- pkg/adaptation/adaptation_suite_test.go | 71 +++++++++++++++++++++++++ pkg/adaptation/plugin.go | 45 +++++++++------- 3 files changed, 158 insertions(+), 24 deletions(-) diff --git a/pkg/adaptation/adaptation.go b/pkg/adaptation/adaptation.go index db8e59f6..72c3770a 100644 --- a/pkg/adaptation/adaptation.go +++ b/pkg/adaptation/adaptation.go @@ -26,6 +26,7 @@ import ( "path/filepath" "sort" "sync" + "time" "github.com/containerd/nri/pkg/adaptation/builtin" "github.com/containerd/nri/pkg/api" @@ -229,6 +230,9 @@ func (r *Adaptation) UpdatePodSandbox(ctx context.Context, req *UpdatePodSandbox defer r.removeClosedPlugins() for _, plugin := range r.plugins { + if plugin.shadowed { + continue + } _, err := plugin.updatePodSandbox(ctx, req) if err != nil { return nil, err @@ -275,6 +279,9 @@ func (r *Adaptation) CreateContainer(ctx context.Context, req *CreateContainerRe } for _, plugin := range r.plugins { + if plugin.shadowed { + continue + } if validate != nil { validate.AddPlugin(plugin.base, plugin.idx) } @@ -317,6 +324,9 @@ func (r *Adaptation) UpdateContainer(ctx context.Context, req *UpdateContainerRe result := collectUpdateContainerResult(req) for _, plugin := range r.plugins { + if plugin.shadowed { + continue + } rpl, err := plugin.updateContainer(ctx, req) if err != nil { return nil, err @@ -344,6 +354,9 @@ func (r *Adaptation) StopContainer(ctx context.Context, req *StopContainerReques result := collectStopContainerResult() for _, plugin := range r.plugins { + if plugin.shadowed { + continue + } rpl, err := plugin.stopContainer(ctx, req) if err != nil { return nil, err @@ -374,6 +387,9 @@ func (r *Adaptation) StateChange(ctx context.Context, evt *StateChangeEvent) err defer r.removeClosedPlugins() for _, plugin := range r.plugins { + if plugin.shadowed { + continue + } err := plugin.StateChange(ctx, evt) if err != nil { return err @@ -406,6 +422,9 @@ func (r *Adaptation) validateContainerAdjustment(ctx context.Context, req *Valid wg := sync.WaitGroup{} for _, p := range r.validators { + if p.shadowed { + continue + } wg.Add(1) go func(p *plugin) { defer wg.Done() @@ -537,6 +556,7 @@ func (r *Adaptation) removeClosedPlugins() { r.plugins = active r.validators = validators + r.markShadowedPlugins() } func (r *Adaptation) startListener() error { @@ -666,24 +686,62 @@ func (r *Adaptation) discoverPlugins() ([]string, []string, []string, error) { return indices, plugins, configs, nil } +func comparePlugins(p1, p2 *plugin) bool { + if p1.idx != p2.idx { + return p1.idx < p2.idx + } + if p1.base != p2.base { + return p1.base < p2.base + } + return p1.connectedAt.Before(p2.connectedAt) +} + +func (r *Adaptation) markShadowedPlugins() { + for _, p := range r.plugins { + p.shadowed = false + } + + // Duplicate plugins are guaranteed to be adjacent based on sort order. + for i := 1; i < len(r.plugins); i++ { + prev := r.plugins[i-1] + curr := r.plugins[i] + if curr.idx == prev.idx && curr.base == prev.base { + prev.shadowed = true + pTime := prev.connectedAt.Format(time.RFC3339) + cTime := curr.connectedAt.Format(time.RFC3339) + log.Warnf(noCtx, "plugin %q (connected %s) is shadowed by %q (connected %s)", prev.name(), pTime, curr.name(), cTime) + } + } +} + func (r *Adaptation) sortPlugins() { r.removeClosedPlugins() sort.Slice(r.plugins, func(i, j int) bool { - return r.plugins[i].idx < r.plugins[j].idx + return comparePlugins(r.plugins[i], r.plugins[j]) }) sort.Slice(r.validators, func(i, j int) bool { - return r.validators[i].idx < r.validators[j].idx + return comparePlugins(r.validators[i], r.validators[j]) }) + r.markShadowedPlugins() + if len(r.plugins) > 0 { log.Infof(noCtx, "plugin invocation order") for i, p := range r.plugins { - log.Infof(noCtx, " #%d: %q (%s)", i+1, p.name(), p.qualifiedName()) + status := "" + if p.shadowed { + status = " (shadowed)" + } + log.Infof(noCtx, " #%d: %q (%s)%s", i+1, p.name(), p.qualifiedName(), status) } } if len(r.validators) > 0 { log.Infof(noCtx, "validator plugins") for _, p := range r.validators { - log.Infof(noCtx, " %q (%s)", p.name(), p.qualifiedName()) + status := "" + if p.shadowed { + status = " (shadowed)" + } + log.Infof(noCtx, " %q (%s)%s", p.name(), p.qualifiedName(), status) } } } diff --git a/pkg/adaptation/adaptation_suite_test.go b/pkg/adaptation/adaptation_suite_test.go index 98c7d586..ef1a0a99 100644 --- a/pkg/adaptation/adaptation_suite_test.go +++ b/pkg/adaptation/adaptation_suite_test.go @@ -282,6 +282,77 @@ var _ = Describe("Plugin connection", func() { }) }) +var _ = Describe("Plugin shadowing", func() { + var ( + s = &Suite{} + ) + + BeforeEach(func() { + s.Prepare( + &mockRuntime{}, + &mockPlugin{idx: "10", name: "test"}, + ) + }) + + AfterEach(func() { + s.Cleanup() + }) + + It("should route events only to the newest plugin of the same name/index", func() { + var ( + runtime = s.runtime + plugin1 = s.plugins[0] + plugin2 = &mockPlugin{idx: plugin1.idx, name: plugin1.name} + ctx = context.Background() + pod = &api.PodSandbox{Id: "pod0", Name: "pod0", Namespace: "default"} + ) + + s.Startup() + + Expect(plugin1.Events()).Should(ContainElement(PluginSynchronized)) + + s.StartPlugins(plugin2) + s.WaitForPluginsToSync(plugin2) + + Expect(plugin2.Events()).Should(ContainElement(PluginSynchronized)) + + plugin1.EventQ().Reset() + plugin2.EventQ().Reset() + + Expect(runtime.RunPodSandbox(ctx, &api.StateChangeEvent{Pod: pod})).To(Succeed()) + + Expect(plugin2.EventQ().Has(&Event{Type: RunPodSandbox})).To(BeTrue()) + Expect(plugin1.EventQ().Has(&Event{Type: RunPodSandbox})).To(BeFalse()) + }) + + It("should fallback to older plugin if the newer one disconnects", func() { + var ( + runtime = s.runtime + plugin1 = s.plugins[0] + plugin2 = &mockPlugin{idx: plugin1.idx, name: plugin1.name} + ctx = context.Background() + pod = &api.PodSandbox{Id: "pod0", Name: "pod0", Namespace: "default"} + ) + + s.Startup() + s.StartPlugins(plugin2) + s.WaitForPluginsToSync(plugin2) + + plugin1.EventQ().Reset() + plugin2.EventQ().Reset() + + plugin2.Stop() + Expect(plugin2.Wait(PluginDisconnected, time.After(startupTimeout))).To(Succeed()) + + Eventually(func() bool { + if err := runtime.RunPodSandbox(ctx, &api.StateChangeEvent{Pod: pod}); err != nil { + return false + } + return plugin1.EventQ().Has(&Event{Type: RunPodSandbox}) + }, 5*time.Second).Should(BeTrue()) + }) +}) + var _ = Describe("Pod and container requests and events", func() { var ( s = &Suite{} diff --git a/pkg/adaptation/plugin.go b/pkg/adaptation/plugin.go index 0441a0c3..01b58436 100644 --- a/pkg/adaptation/plugin.go +++ b/pkg/adaptation/plugin.go @@ -53,21 +53,23 @@ var ( type plugin struct { sync.Mutex - idx string - base string - cfg string - pid int - cmd *exec.Cmd - mux multiplex.Mux - rpcc *ttrpc.Client - rpcl stdnet.Listener - rpcs *ttrpc.Server - events EventMask - closed bool - regC chan error - closeC chan struct{} - r *Adaptation - impl *pluginType + idx string + base string + cfg string + pid int + cmd *exec.Cmd + mux multiplex.Mux + rpcc *ttrpc.Client + rpcl stdnet.Listener + rpcs *ttrpc.Server + events EventMask + closed bool + regC chan error + closeC chan struct{} + r *Adaptation + impl *pluginType + shadowed bool + connectedAt time.Time } // SetPluginRegistrationTimeout sets the timeout for plugin registration. @@ -176,11 +178,12 @@ func (r *Adaptation) newBuiltinPlugin(b *builtin.BuiltinPlugin) (*plugin, error) } return &plugin{ - idx: b.Index, - base: b.Base, - closeC: make(chan struct{}), - r: r, - impl: &pluginType{builtinImpl: b}, + idx: b.Index, + base: b.Base, + closeC: make(chan struct{}), + r: r, + impl: &pluginType{builtinImpl: b}, + connectedAt: time.Now(), }, nil } @@ -310,6 +313,8 @@ func (p *plugin) connect(conn stdnet.Conn) (retErr error) { api.RegisterRuntimeService(p.rpcs, p) + p.connectedAt = time.Now() + return nil }