From a309f6cf268ce009abab6311d4196901e3c79535 Mon Sep 17 00:00:00 2001 From: Ankit Goswami Date: Tue, 2 Jun 2026 15:18:18 -0700 Subject: [PATCH 01/13] feat(discovery): fan out nmap discovery to fleet nodes and pair them inline Extend the existing "Find miners / Scan your network" flow to also discover and pair miners behind remote fleet nodes, with no client/UI changes. The existing client keeps calling Discover (nmap) and Pair; the server now does the fan-out and routing transparently. - New domain/fleetnode/discovery package owns the per-node "send command + drain batches until ack" loop (RunOnNode) and node targeting (ConfirmedConnectedNodeIDs). The admin DiscoverOnFleetNode handler now delegates to it; adds control.Registry.ConnectedFleetNodeIDs(). - PairingService.Discover fans an nmap scan out to every CONFIRMED + connected fleet node in parallel, merging all sources into one response stream (mutex-guarded send + cross-source dedup). Zero connected nodes == unchanged behavior; one node failing never fails the scan. - Fleet nodes scan their own private IPv4 subnet(s) via the reserved nmaptarget.LocalSubnetTarget ("fleetnode-local-subnet") target sentinel, since the cloud can't know each node's network. Detection narrows masks wider than /22, dedupes, caps at 8 subnets, and is IPv4-only. - PairingService.Pair routes fleet-node-discovered devices to the operator-confirmed ownership assignment (FleetNodeAssigner) instead of refusing them, so mixed selections pair in one action. Metadata only; the owning node dials and credentials the miner per RFC 0001. Co-Authored-By: Claude Opus 4.8 (1M context) --- server/cmd/fleetd/main.go | 9 +- server/cmd/fleetnode/localsubnet.go | 115 ++++++ server/cmd/fleetnode/localsubnet_test.go | 145 ++++++++ server/cmd/fleetnode/nmap.go | 32 +- server/cmd/fleetnode/nmap_test.go | 42 +++ server/cmd/fleetnode/run.go | 1 + .../domain/fleetnode/control/registry.go | 13 + .../domain/fleetnode/control/registry_test.go | 18 + .../fleetnode/discovery/ackfailure_test.go} | 2 +- .../fleetnode/discovery/iprange_test.go} | 2 +- .../fleetnode/discovery}/reportscope.go | 17 +- .../fleetnode/discovery}/reportscope_test.go | 36 +- .../domain/fleetnode/discovery/service.go | 344 ++++++++++++++++++ .../fleetnode/discovery/service_test.go | 136 +++++++ .../internal/domain/nmaptarget/nmaptarget.go | 7 + .../domain/pairing/mocks/mock_service.go | 38 ++ server/internal/domain/pairing/service.go | 62 +++- .../pairing/service_pairrouting_test.go | 128 +++++++ .../handlers/fleetnode/admin/handler.go | 283 +------------- .../fleetnode/admin/handler_discover_test.go | 8 +- .../fleetnode/admin/handler_pairing_test.go | 4 +- server/internal/handlers/pairing/handler.go | 137 +++++-- .../testutil/infrastructure_provider.go | 3 +- 23 files changed, 1262 insertions(+), 320 deletions(-) create mode 100644 server/cmd/fleetnode/localsubnet.go create mode 100644 server/cmd/fleetnode/localsubnet_test.go rename server/internal/{handlers/fleetnode/admin/handler_ackfailure_test.go => domain/fleetnode/discovery/ackfailure_test.go} (98%) rename server/internal/{handlers/fleetnode/admin/handler_iprange_test.go => domain/fleetnode/discovery/iprange_test.go} (99%) rename server/internal/{handlers/fleetnode/admin => domain/fleetnode/discovery}/reportscope.go (91%) rename server/internal/{handlers/fleetnode/admin => domain/fleetnode/discovery}/reportscope_test.go (82%) create mode 100644 server/internal/domain/fleetnode/discovery/service.go create mode 100644 server/internal/domain/fleetnode/discovery/service_test.go create mode 100644 server/internal/domain/pairing/service_pairrouting_test.go diff --git a/server/cmd/fleetd/main.go b/server/cmd/fleetd/main.go index e9a456588..4515f7229 100644 --- a/server/cmd/fleetd/main.go +++ b/server/cmd/fleetd/main.go @@ -71,6 +71,7 @@ import ( fleetmanagementDomain "github.com/block/proto-fleet/server/internal/domain/fleetmanagement" fleetnodeauth "github.com/block/proto-fleet/server/internal/domain/fleetnode/auth" "github.com/block/proto-fleet/server/internal/domain/fleetnode/control" + fleetnodediscovery "github.com/block/proto-fleet/server/internal/domain/fleetnode/discovery" "github.com/block/proto-fleet/server/internal/domain/fleetnode/enrollment" fleetnodepairing "github.com/block/proto-fleet/server/internal/domain/fleetnode/pairing" "github.com/block/proto-fleet/server/internal/domain/fleetoptions" @@ -220,6 +221,7 @@ func start(config *Config) error { fleetNodePairingStore := sqlstores.NewSQLFleetNodePairingStore(conn) fleetNodePairingSvc := fleetnodepairing.NewService(fleetNodePairingStore, fleetNodeEnrollmentStore, transactor) fleetNodeControlRegistry := control.NewRegistry() + fleetNodeDiscoverySvc := fleetnodediscovery.NewService(fleetNodeControlRegistry, fleetNodeEnrollmentSvc) fleetNodeAuthStore := sqlstores.NewSQLFleetNodeAuthStore(conn) fleetNodeAuthSvc := fleetnodeauth.NewService(fleetNodeAuthStore, fleetNodeEnrollmentStore, apiKeySvc) @@ -369,6 +371,9 @@ func start(config *Config) error { ) pairingSvc.WithMinerInvalidator(minerService.InvalidateMiner) pairingSvc.WithOptionsCache(fleetOptionsCache) + // Route fleet-node-discovered devices through ownership assignment so the + // cloud Pair flow can pair them alongside directly-discovered miners. + pairingSvc.WithFleetNodeAssigner(fleetNodePairingSvc) // Initialize IP scanner service ipScannerService := ipscanner.NewIPScannerService( @@ -522,7 +527,7 @@ func start(config *Config) error { mux.Handle(authv1connect.NewAuthServiceHandler(auth.NewHandler(authSvc), li)) mux.Handle(onboardingv1connect.NewOnboardingServiceHandler(onboarding.NewHandler(authSvc, onboardingSvc), li)) - mux.Handle(pairingv1connect.NewPairingServiceHandler(pairing.NewHandler(pairingSvc), li)) + mux.Handle(pairingv1connect.NewPairingServiceHandler(pairing.NewHandler(pairingSvc, fleetNodeDiscoverySvc), li)) mux.Handle(networkinfov1connect.NewNetworkInfoServiceHandler(networkinfo.NewHandler(pairingSvc), li)) mux.Handle(fleetmanagementv1connect.NewFleetManagementServiceHandler(fleetmanagement.NewHandler(fleetMgmtSvc), li)) mux.Handle(minercommandv1connect.NewMinerCommandServiceHandler(command.NewHandler(commandSvc), li)) @@ -532,7 +537,7 @@ func start(config *Config) error { mux.Handle(sitesv1connect.NewSiteServiceHandler(sitesHandler.NewHandler(sitesSvc), li)) mux.Handle(buildingsv1connect.NewBuildingServiceHandler(buildingsHandler.NewHandler(buildingsSvc), li)) mux.Handle(fleetnodegatewayv1connect.NewFleetNodeGatewayServiceHandler(gateway.NewHandler(fleetNodeEnrollmentSvc, fleetNodeAuthSvc, fleetNodePairingSvc, fleetNodeControlRegistry), li)) - mux.Handle(fleetnodeadminv1connect.NewFleetNodeAdminServiceHandler(admin.NewHandler(fleetNodeEnrollmentSvc, fleetNodePairingSvc, fleetNodeControlRegistry), li)) + mux.Handle(fleetnodeadminv1connect.NewFleetNodeAdminServiceHandler(admin.NewHandler(fleetNodeEnrollmentSvc, fleetNodePairingSvc, fleetNodeDiscoverySvc), li)) mux.Handle(collectionv1connect.NewDeviceCollectionServiceHandler(collectionHandler.NewHandler(collectionSvc), li)) mux.Handle(device_setv1connect.NewDeviceSetServiceHandler(devicesetHandler.NewHandler(collectionSvc), li)) mux.Handle(telemetryv1connect.NewTelemetryServiceHandler(telemetryHandler.NewHandler(telemetryService), li)) diff --git a/server/cmd/fleetnode/localsubnet.go b/server/cmd/fleetnode/localsubnet.go new file mode 100644 index 000000000..274fb1bbd --- /dev/null +++ b/server/cmd/fleetnode/localsubnet.go @@ -0,0 +1,115 @@ +package main + +import ( + "errors" + "fmt" + "net" + "net/netip" + "strings" +) + +// autoSubnetMinPrefixBits caps an auto-detected subnet at /22 (<=1024 hosts) — +// the same ceiling the manual nmap path enforces (nmaptarget.MinIPv4PrefixBits). +// A NIC configured with a wider mask (e.g. /16) is narrowed around its own host +// address so an auto scan stays bounded and finishes inside the command timeout. +const autoSubnetMinPrefixBits = 22 + +// maxAutoSubnets caps how many distinct subnets one auto command scans, so a +// multi-homed host with many interfaces can't fan one command into a huge sweep. +const maxAutoSubnets = 8 + +// errNoLocalPrivateSubnet means no connected, non-virtual interface had a private +// IPv4 address — the agent has nothing to auto-scan. Surfaces as AGENT_INCAPABLE +// so a fan-out skips this node and tries the others. +var errNoLocalPrivateSubnet = errors.New("no connected private IPv4 subnet found") + +// virtualIfacePrefixes are name prefixes for container/VPN/virtual adapters whose +// subnets aren't the miner LAN. Best-effort: a miss only means a virtual private +// subnet might be scanned (still port-probed, still private), never a public scan. +var virtualIfacePrefixes = []string{ + "docker", "br-", "veth", "virbr", "vmnet", "vboxnet", + "tun", "tap", "utun", "cni", "cali", "flannel", "kube", + "zt", "tailscale", "ts", "wg", +} + +// detectLocalSubnets returns the private IPv4 subnet(s) the agent should scan for +// an auto_local_subnet nmap command. The localSubnets seam lets tests inject +// canned CIDRs; production enumerates the host's interfaces. +func (r *RunCmd) detectLocalSubnets() ([]string, error) { + if r.localSubnets != nil { + return r.localSubnets() + } + ifaces, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("list network interfaces: %w", err) + } + return selectLocalPrivateSubnets(ifaces, (*net.Interface).Addrs) +} + +// selectLocalPrivateSubnets returns the canonical CIDR(s) of the connected, +// non-virtual, private IPv4 subnet(s) of the given interfaces. addrsOf is +// injected for testing (net.Interface.Addrs in production). Subnets wider than +// /22 are narrowed around the host address, results are deduped and capped at +// maxAutoSubnets, and IPv6 is ignored (the manual nmap path rejects IPv6 CIDR +// too). Returns errNoLocalPrivateSubnet when none qualify. +func selectLocalPrivateSubnets(ifaces []net.Interface, addrsOf func(*net.Interface) ([]net.Addr, error)) ([]string, error) { + seen := make(map[string]struct{}) + out := make([]string, 0, maxAutoSubnets) + for i := range ifaces { + iface := ifaces[i] + if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagRunning == 0 { + continue + } + if iface.Flags&net.FlagLoopback != 0 || isVirtualIface(iface.Name) { + continue + } + addrs, err := addrsOf(&iface) + if err != nil { + continue + } + for _, a := range addrs { + ipNet, ok := a.(*net.IPNet) + if !ok { + continue + } + addr, ok := netip.AddrFromSlice(ipNet.IP) + if !ok { + continue + } + addr = addr.Unmap() + if !addr.Is4() || !addr.IsPrivate() { + continue + } + ones, _ := ipNet.Mask.Size() + if ones <= 0 || ones > addr.BitLen() { + continue // non-canonical mask + } + if ones < autoSubnetMinPrefixBits { + ones = autoSubnetMinPrefixBits + } + cidr := netip.PrefixFrom(addr, ones).Masked().String() + if _, dup := seen[cidr]; dup { + continue + } + seen[cidr] = struct{}{} + out = append(out, cidr) + if len(out) >= maxAutoSubnets { + return out, nil + } + } + } + if len(out) == 0 { + return nil, errNoLocalPrivateSubnet + } + return out, nil +} + +func isVirtualIface(name string) bool { + lower := strings.ToLower(name) + for _, p := range virtualIfacePrefixes { + if strings.HasPrefix(lower, p) { + return true + } + } + return false +} diff --git a/server/cmd/fleetnode/localsubnet_test.go b/server/cmd/fleetnode/localsubnet_test.go new file mode 100644 index 000000000..37f0945ae --- /dev/null +++ b/server/cmd/fleetnode/localsubnet_test.go @@ -0,0 +1,145 @@ +package main + +import ( + "fmt" + "net" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// hostIPNet builds the *net.IPNet shape an interface reports for a host address +// (host IP + the subnet mask), e.g. "192.168.1.50/24". +func hostIPNet(cidr string) *net.IPNet { + ip, n, err := net.ParseCIDR(cidr) + if err != nil { + panic(err) + } + return &net.IPNet{IP: ip, Mask: n.Mask} +} + +func stubAddrs(byName map[string][]net.Addr) func(*net.Interface) ([]net.Addr, error) { + return func(i *net.Interface) ([]net.Addr, error) { return byName[i.Name], nil } +} + +func TestSelectLocalPrivateSubnets_Typical24(t *testing.T) { + // Arrange + ifaces := []net.Interface{{Name: "eth0", Flags: net.FlagUp | net.FlagRunning}} + addrs := stubAddrs(map[string][]net.Addr{"eth0": {hostIPNet("192.168.1.50/24")}}) + + // Act + got, err := selectLocalPrivateSubnets(ifaces, addrs) + + // Assert + require.NoError(t, err) + assert.Equal(t, []string{"192.168.1.0/24"}, got) +} + +func TestSelectLocalPrivateSubnets_OversizedMaskNarrowedTo22(t *testing.T) { + // Arrange: a /8-masked NIC must narrow to /22 around its own host address so + // the auto scan stays within the manual-path host ceiling. + ifaces := []net.Interface{{Name: "eth0", Flags: net.FlagUp | net.FlagRunning}} + addrs := stubAddrs(map[string][]net.Addr{"eth0": {hostIPNet("10.1.2.3/8")}}) + + // Act + got, err := selectLocalPrivateSubnets(ifaces, addrs) + + // Assert + require.NoError(t, err) + require.Len(t, got, 1) + prefix, perr := netip.ParsePrefix(got[0]) + require.NoError(t, perr) + assert.Equal(t, 22, prefix.Bits(), "oversized mask must narrow to /22") + assert.True(t, prefix.Contains(netip.MustParseAddr("10.1.2.3")), "narrowed subnet must contain the host: %s", got[0]) +} + +func TestSelectLocalPrivateSubnets_FiltersLoopbackDownAndVirtual(t *testing.T) { + // Arrange: loopback, a not-running NIC, and a docker bridge all excluded. + ifaces := []net.Interface{ + {Name: "lo", Flags: net.FlagUp | net.FlagRunning | net.FlagLoopback}, + {Name: "eth1", Flags: net.FlagUp}, // up but not running + {Name: "docker0", Flags: net.FlagUp | net.FlagRunning}, + } + addrs := stubAddrs(map[string][]net.Addr{ + "lo": {hostIPNet("127.0.0.1/8")}, + "eth1": {hostIPNet("192.168.5.5/24")}, + "docker0": {hostIPNet("172.17.0.1/16")}, + }) + + // Act + _, err := selectLocalPrivateSubnets(ifaces, addrs) + + // Assert + require.ErrorIs(t, err, errNoLocalPrivateSubnet) +} + +func TestSelectLocalPrivateSubnets_SkipsPublicAddress(t *testing.T) { + // Arrange + ifaces := []net.Interface{{Name: "eth0", Flags: net.FlagUp | net.FlagRunning}} + addrs := stubAddrs(map[string][]net.Addr{"eth0": {hostIPNet("8.8.8.8/24")}}) + + // Act + _, err := selectLocalPrivateSubnets(ifaces, addrs) + + // Assert + require.ErrorIs(t, err, errNoLocalPrivateSubnet) +} + +func TestSelectLocalPrivateSubnets_DedupesSameSubnetAcrossNICs(t *testing.T) { + // Arrange + ifaces := []net.Interface{ + {Name: "eth0", Flags: net.FlagUp | net.FlagRunning}, + {Name: "eth1", Flags: net.FlagUp | net.FlagRunning}, + } + addrs := stubAddrs(map[string][]net.Addr{ + "eth0": {hostIPNet("192.168.1.10/24")}, + "eth1": {hostIPNet("192.168.1.20/24")}, + }) + + // Act + got, err := selectLocalPrivateSubnets(ifaces, addrs) + + // Assert + require.NoError(t, err) + assert.Equal(t, []string{"192.168.1.0/24"}, got) +} + +func TestSelectLocalPrivateSubnets_CapsResultCount(t *testing.T) { + // Arrange: more distinct private subnets than the cap allows. + ifaces := make([]net.Interface, 0, maxAutoSubnets+2) + byName := make(map[string][]net.Addr) + for i := range maxAutoSubnets + 2 { + name := fmt.Sprintf("eth%d", i) + ifaces = append(ifaces, net.Interface{Name: name, Flags: net.FlagUp | net.FlagRunning}) + byName[name] = []net.Addr{hostIPNet(fmt.Sprintf("192.168.%d.5/24", i))} + } + + // Act + got, err := selectLocalPrivateSubnets(ifaces, stubAddrs(byName)) + + // Assert + require.NoError(t, err) + assert.Len(t, got, maxAutoSubnets) +} + +func TestSelectLocalPrivateSubnets_IgnoresIPv6ULA(t *testing.T) { + // Arrange + ifaces := []net.Interface{{Name: "eth0", Flags: net.FlagUp | net.FlagRunning}} + addrs := stubAddrs(map[string][]net.Addr{"eth0": {hostIPNet("fd00::1/64")}}) + + // Act + _, err := selectLocalPrivateSubnets(ifaces, addrs) + + // Assert + require.ErrorIs(t, err, errNoLocalPrivateSubnet) +} + +func TestSelectLocalPrivateSubnets_NoInterfaces(t *testing.T) { + // Act + _, err := selectLocalPrivateSubnets(nil, stubAddrs(nil)) + + // Assert + require.ErrorIs(t, err, errNoLocalPrivateSubnet) +} diff --git a/server/cmd/fleetnode/nmap.go b/server/cmd/fleetnode/nmap.go index 734f10144..c04d4fb77 100644 --- a/server/cmd/fleetnode/nmap.go +++ b/server/cmd/fleetnode/nmap.go @@ -107,6 +107,19 @@ func validateNmapBinary(path string) (string, error) { func (r *RunCmd) buildNmapOptions(ctx context.Context, req *pairingpb.NmapModeRequest, ports []string) ([]nmap.Option, error) { target := strings.TrimSpace(req.GetTarget()) + + // The LocalSubnetTarget sentinel means the server couldn't know the node's + // network, so the agent enumerates its own private IPv4 subnet(s) and scans + // those (IPv4 only, same as the manual path's IPv6-CIDR rejection). Matched + // exactly, before any hostname resolution. + if target == nmaptarget.LocalSubnetTarget { + subnets, err := r.detectLocalSubnets() + if err != nil { + return nil, cmdErr(pb.AckCode_ACK_CODE_AGENT_INCAPABLE, "no connected private IPv4 subnet for local-subnet scan: %s", err) + } + return append(baseNmapOptions(r.nmapPath, ports), nmap.WithTargets(subnets...)), nil + } + if err := validateNmapTarget(target); err != nil { return nil, cmdErr(pb.AckCode_ACK_CODE_BAD_REQUEST, "%s", err) } @@ -114,9 +127,18 @@ func (r *RunCmd) buildNmapOptions(ctx context.Context, req *pairingpb.NmapModeRe if err != nil { return nil, cmdErr(pb.AckCode_ACK_CODE_BAD_REQUEST, "%s", err) } - opts := []nmap.Option{ - nmap.WithBinaryPath(r.nmapPath), - nmap.WithTargets(resolved), + opts := append(baseNmapOptions(r.nmapPath, ports), nmap.WithTargets(resolved)) + if useIPv6 { + opts = append(opts, nmap.WithIPv6Scanning()) + } + return opts, nil +} + +// baseNmapOptions are the timing/safety options shared by targeted and +// auto-local-subnet scans; callers append the target(s) (and -6 if needed). +func baseNmapOptions(binaryPath string, ports []string) []nmap.Option { + return []nmap.Option{ + nmap.WithBinaryPath(binaryPath), nmap.WithPorts(strings.Join(ports, ",")), nmap.WithUnique(), nmap.WithDisabledDNSResolution(), @@ -125,10 +147,6 @@ func (r *RunCmd) buildNmapOptions(ctx context.Context, req *pairingpb.NmapModeRe nmap.WithHostTimeout(nmapHostTimeout), nmap.WithMinRTTTimeout(nmapMinRTT), } - if useIPv6 { - opts = append(opts, nmap.WithIPv6Scanning()) - } - return opts, nil } // Mirrors pairing-service validateNmapTargets so agent and server feed diff --git a/server/cmd/fleetnode/nmap_test.go b/server/cmd/fleetnode/nmap_test.go index e8dd6dd58..c625bb25f 100644 --- a/server/cmd/fleetnode/nmap_test.go +++ b/server/cmd/fleetnode/nmap_test.go @@ -16,7 +16,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + pb "github.com/block/proto-fleet/server/generated/grpc/fleetnodegateway/v1" pairingpb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" + "github.com/block/proto-fleet/server/internal/domain/nmaptarget" ) func testLogger() *slog.Logger { return slog.New(slog.DiscardHandler) } @@ -246,6 +248,46 @@ func TestBuildNmapOptions_AddsIPv6Scanning(t *testing.T) { assert.True(t, slices.Contains(v6Scanner.Args(), "-6"), "IPv6 target must carry -6") } +func TestBuildNmapOptions_LocalSubnetTarget_UsesDetectedCIDRs(t *testing.T) { + // Arrange: inject the detected subnets so the test doesn't depend on the host. + r := &RunCmd{ + nmapPath: "/usr/bin/nmap", + discoverer: &stubDiscoverer{}, + localSubnets: func() ([]string, error) { return []string{"192.168.1.0/24"}, nil }, + } + req := &pairingpb.NmapModeRequest{Target: nmaptarget.LocalSubnetTarget, Ports: []string{"4028"}} + + // Act + opts, err := r.buildNmapOptions(context.Background(), req, req.Ports) + require.NoError(t, err) + scanner, err := nmap.NewScanner(context.Background(), opts...) + require.NoError(t, err) + + // Assert: the detected subnet reaches nmap and the scan stays IPv4-only. + args := scanner.Args() + assert.True(t, slices.Contains(args, "192.168.1.0/24"), "expected detected subnet in argv: %v", args) + assert.False(t, slices.Contains(args, "-6"), "local-subnet scan must be IPv4-only: %v", args) + assert.False(t, slices.Contains(args, nmaptarget.LocalSubnetTarget), "sentinel must not reach nmap as a literal target: %v", args) +} + +func TestBuildNmapOptions_LocalSubnetTarget_NoSubnetIsAgentIncapable(t *testing.T) { + // Arrange: detection finds no private subnet. + r := &RunCmd{ + nmapPath: "/usr/bin/nmap", + discoverer: &stubDiscoverer{}, + localSubnets: func() ([]string, error) { return nil, errNoLocalPrivateSubnet }, + } + req := &pairingpb.NmapModeRequest{Target: nmaptarget.LocalSubnetTarget, Ports: []string{"4028"}} + + // Act + _, err := r.buildNmapOptions(context.Background(), req, req.Ports) + + // Assert: a fan-out should skip this node, so the ack maps to FailedPrecondition. + var ce *commandError + require.ErrorAs(t, err, &ce) + assert.Equal(t, pb.AckCode_ACK_CODE_AGENT_INCAPABLE, ce.code) +} + func TestDiscoverForCommand_NmapPathEmptyFailsClosed(t *testing.T) { // Arrange r := &RunCmd{nmapPath: "", discoverer: &stubDiscoverer{}} diff --git a/server/cmd/fleetnode/run.go b/server/cmd/fleetnode/run.go index 295b8b1f1..119d06d4f 100644 --- a/server/cmd/fleetnode/run.go +++ b/server/cmd/fleetnode/run.go @@ -34,6 +34,7 @@ type RunCmd struct { discoverer discoverer `kong:"-"` nmapPath string `kong:"-"` resolver ipResolver `kong:"-"` + localSubnets func() ([]string, error) `kong:"-"` // test seam for auto_local_subnet detection stateMu sync.Mutex `kong:"-"` // guards st.SessionToken across refreshAndSave + tokenSource. } diff --git a/server/internal/domain/fleetnode/control/registry.go b/server/internal/domain/fleetnode/control/registry.go index c96ee2eb7..283b81ff6 100644 --- a/server/internal/domain/fleetnode/control/registry.go +++ b/server/internal/domain/fleetnode/control/registry.go @@ -93,6 +93,19 @@ func NewRegistry() *Registry { return &Registry{conns: make(map[int64]*connection)} } +// ConnectedFleetNodeIDs returns the fleet_node IDs with an active ControlStream +// right now. Used by fan-out discovery to target only nodes the server can reach; +// callers intersect this with the org's CONFIRMED nodes. Order is unspecified. +func (r *Registry) ConnectedFleetNodeIDs() []int64 { + r.mu.Lock() + defer r.mu.Unlock() + ids := make([]int64, 0, len(r.conns)) + for id := range r.conns { + ids = append(ids, id) + } + return ids +} + // teardown closes connection.done and cmd.done (if any). Caller holds Registry.mu // and must then remove/replace the conn so teardown can't run twice. func teardown(conn *connection) { diff --git a/server/internal/domain/fleetnode/control/registry_test.go b/server/internal/domain/fleetnode/control/registry_test.go index bae4bc704..a3a66bc25 100644 --- a/server/internal/domain/fleetnode/control/registry_test.go +++ b/server/internal/domain/fleetnode/control/registry_test.go @@ -327,3 +327,21 @@ func receive(t *testing.T, ch <-chan CommandEvent) CommandEvent { return CommandEvent{} } } + +func TestConnectedFleetNodeIDs_ReflectsRegisterAndUnregister(t *testing.T) { + // Arrange + r := NewRegistry() + s1 := r.Register(1) + s2 := r.Register(2) + + // Act + Assert: both connected. + assert.ElementsMatch(t, []int64{1, 2}, r.ConnectedFleetNodeIDs()) + + // Act + Assert: unregistering one drops it. + s1.Unregister() + assert.ElementsMatch(t, []int64{2}, r.ConnectedFleetNodeIDs()) + + // Act + Assert: empty once all are gone. + s2.Unregister() + assert.Empty(t, r.ConnectedFleetNodeIDs()) +} diff --git a/server/internal/handlers/fleetnode/admin/handler_ackfailure_test.go b/server/internal/domain/fleetnode/discovery/ackfailure_test.go similarity index 98% rename from server/internal/handlers/fleetnode/admin/handler_ackfailure_test.go rename to server/internal/domain/fleetnode/discovery/ackfailure_test.go index 4fea4bb42..f8caaeb37 100644 --- a/server/internal/handlers/fleetnode/admin/handler_ackfailure_test.go +++ b/server/internal/domain/fleetnode/discovery/ackfailure_test.go @@ -1,4 +1,4 @@ -package admin +package discovery import ( "testing" diff --git a/server/internal/handlers/fleetnode/admin/handler_iprange_test.go b/server/internal/domain/fleetnode/discovery/iprange_test.go similarity index 99% rename from server/internal/handlers/fleetnode/admin/handler_iprange_test.go rename to server/internal/domain/fleetnode/discovery/iprange_test.go index 62f607d9b..f46a3957b 100644 --- a/server/internal/handlers/fleetnode/admin/handler_iprange_test.go +++ b/server/internal/domain/fleetnode/discovery/iprange_test.go @@ -1,4 +1,4 @@ -package admin +package discovery import ( "testing" diff --git a/server/internal/handlers/fleetnode/admin/reportscope.go b/server/internal/domain/fleetnode/discovery/reportscope.go similarity index 91% rename from server/internal/handlers/fleetnode/admin/reportscope.go rename to server/internal/domain/fleetnode/discovery/reportscope.go index 9c4a05d24..01255a67c 100644 --- a/server/internal/handlers/fleetnode/admin/reportscope.go +++ b/server/internal/domain/fleetnode/discovery/reportscope.go @@ -1,4 +1,4 @@ -package admin +package discovery import ( "net/netip" @@ -25,8 +25,21 @@ func buildReportScope(req *pairingpb.DiscoverRequest) control.ReportScope { return inPort(port) && inIP(ip) } case *pairingpb.DiscoverRequest_Nmap: - inTarget := nmapTargetMatcher(m.Nmap.GetTarget()) inPort := portMatcher(m.Nmap.GetPorts()) + // The LocalSubnetTarget sentinel lets the agent pick its own subnet, so + // the server can't predict the IPs. Degrade the IP scope to the + // private-only invariant (RFC1918/RFC4193) that validateReport + // independently enforces; port scoping is still applied. + if m.Nmap.GetTarget() == nmaptarget.LocalSubnetTarget { + return func(ip, port string) bool { + if !inPort(port) { + return false + } + a, ok := parseScopeAddr(ip) + return ok && a.IsPrivate() + } + } + inTarget := nmapTargetMatcher(m.Nmap.GetTarget()) return func(ip, port string) bool { return inPort(port) && inTarget(ip) } diff --git a/server/internal/handlers/fleetnode/admin/reportscope_test.go b/server/internal/domain/fleetnode/discovery/reportscope_test.go similarity index 82% rename from server/internal/handlers/fleetnode/admin/reportscope_test.go rename to server/internal/domain/fleetnode/discovery/reportscope_test.go index a98eb6536..29e7c30d9 100644 --- a/server/internal/handlers/fleetnode/admin/reportscope_test.go +++ b/server/internal/domain/fleetnode/discovery/reportscope_test.go @@ -1,4 +1,4 @@ -package admin +package discovery import ( "testing" @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" pairingpb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" + "github.com/block/proto-fleet/server/internal/domain/nmaptarget" ) func ipListReq(ips, ports []string) *pairingpb.DiscoverRequest { @@ -21,6 +22,12 @@ func nmapReq(target string, ports []string) *pairingpb.DiscoverRequest { }} } +func autoNmapReq(ports []string) *pairingpb.DiscoverRequest { + return &pairingpb.DiscoverRequest{Mode: &pairingpb.DiscoverRequest_Nmap{ + Nmap: &pairingpb.NmapModeRequest{Target: nmaptarget.LocalSubnetTarget, Ports: ports}, + }} +} + func TestBuildReportScope(t *testing.T) { tests := []struct { name string @@ -49,6 +56,10 @@ func TestBuildReportScope(t *testing.T) { {"nmap literal mismatch", nmapReq("192.168.1.10", []string{"80"}), "192.168.1.11", "80", false}, {"nmap hostname leaves ip unconstrained", nmapReq("miner.lan", []string{"80"}), "10.1.2.3", "80", true}, {"nmap hostname still enforces ports", nmapReq("miner.lan", []string{"80"}), "10.1.2.3", "22", false}, + {"auto accepts private ip on in-scope port", autoNmapReq([]string{"80"}), "192.168.5.9", "80", true}, + {"auto rejects public ip", autoNmapReq([]string{"80"}), "8.8.8.8", "80", false}, + {"auto rejects out-of-scope port", autoNmapReq([]string{"80"}), "192.168.5.9", "22", false}, + {"auto empty ports allows any private ip and port", autoNmapReq(nil), "10.2.3.4", "31337", true}, } for _, tc := range tests { @@ -147,6 +158,29 @@ func TestNormalizeDiscoverRequest_RejectsPublicNmapTarget(t *testing.T) { } } +func TestNormalizeDiscoverRequest_AutoLocalSubnet_Accepts(t *testing.T) { + // Arrange: auto mode with no target and valid ports. + req := autoNmapReq([]string{"80", "4028"}) + + // Act + _, err := normalizeDiscoverRequest(req) + + // Assert + require.NoError(t, err) +} + +func TestNormalizeDiscoverRequest_AutoLocalSubnet_RejectsInvalidPort(t *testing.T) { + // Arrange + req := autoNmapReq([]string{"80/tcp"}) + + // Act + _, err := normalizeDiscoverRequest(req) + + // Assert + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid port") +} + func TestNmapTargetIsPrivate(t *testing.T) { tests := []struct { name string diff --git a/server/internal/domain/fleetnode/discovery/service.go b/server/internal/domain/fleetnode/discovery/service.go new file mode 100644 index 000000000..31e9e6eec --- /dev/null +++ b/server/internal/domain/fleetnode/discovery/service.go @@ -0,0 +1,344 @@ +// Package discovery dispatches server-initiated miner discovery to fleet nodes +// over the ControlStream and streams the results back. It owns the per-node +// run loop (normalize -> send command -> drain batches until ack) shared by the +// operator-facing single-node RPC (handlers/fleetnode/admin) and the cloud +// "Find miners" fan-out (handlers/pairing), plus the helpers that decide which +// nodes a fan-out should target. +package discovery + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/netip" + "strconv" + "time" + + "connectrpc.com/connect" + "google.golang.org/protobuf/proto" + + gatewaypb "github.com/block/proto-fleet/server/generated/grpc/fleetnodegateway/v1" + pairingpb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" + "github.com/block/proto-fleet/server/internal/domain/discoverylimits" + "github.com/block/proto-fleet/server/internal/domain/fleeterror" + "github.com/block/proto-fleet/server/internal/domain/fleetnode/control" + "github.com/block/proto-fleet/server/internal/domain/fleetnode/enrollment" + "github.com/block/proto-fleet/server/internal/domain/netutil" + "github.com/block/proto-fleet/server/internal/domain/nmaptarget" + "github.com/block/proto-fleet/server/internal/infrastructure/id" +) + +// DiscoverCommandTimeout bounds how long RunOnNode waits for the agent's batches +// and ack, so a silent node can't pin operator streams and registry slots. Must +// exceed the agent's scan budget (commandTimeout, 10m) plus report/ack slack: too +// short frees the slot mid-scan, the agent's ack is rejected as stale, and a new +// command dispatches while the node is still busy. Var for tests. +var DiscoverCommandTimeout = 12 * time.Minute + +// nodeLister is the subset of enrollment.Service that fan-out targeting needs. +type nodeLister interface { + ListFleetNodes(ctx context.Context, orgID int64) ([]enrollment.FleetNodeListing, error) +} + +// Service runs discovery commands against connected fleet nodes. +type Service struct { + registry *control.Registry + enrollment nodeLister +} + +func NewService(registry *control.Registry, enrollmentSvc nodeLister) *Service { + return &Service{registry: registry, enrollment: enrollmentSvc} +} + +// ConfirmedConnectedNodeIDs returns the IDs of fleet nodes in orgID that are both +// CONFIRMED and currently connected (active ControlStream) — the set a fan-out +// can dispatch to. A node with a live stream but a non-CONFIRMED enrollment +// status is excluded. +func (s *Service) ConfirmedConnectedNodeIDs(ctx context.Context, orgID int64) ([]int64, error) { + nodes, err := s.enrollment.ListFleetNodes(ctx, orgID) + if err != nil { + return nil, err + } + confirmed := make(map[int64]struct{}, len(nodes)) + for _, n := range nodes { + if n.EnrollmentStatus == enrollment.FleetNodeStatusConfirmed { + confirmed[n.ID] = struct{}{} + } + } + connected := s.registry.ConnectedFleetNodeIDs() + out := make([]int64, 0, len(connected)) + for _, nodeID := range connected { + if _, ok := confirmed[nodeID]; ok { + out = append(out, nodeID) + } + } + return out, nil +} + +// RunOnNode normalizes req, builds the report scope, dispatches the command over +// the node's ControlStream, and invokes onBatch for each discovered-device batch +// until the node acks (or the command times out / the stream drops). It returns +// nil on an OK or PARTIAL ack, and an error otherwise — including any non-nil +// error returned by onBatch, which is treated as terminal (the caller's stream +// is gone, so there is nothing left to forward). +func (s *Service) RunOnNode(ctx context.Context, fleetNodeID int64, req *pairingpb.DiscoverRequest, onBatch func(*pairingpb.DiscoverResponse) error) error { + normalized, err := normalizeDiscoverRequest(req) + if err != nil { + return err + } + + commandID := id.GenerateID() + payload, err := proto.Marshal(normalized) + if err != nil { + return fleeterror.NewInternalErrorf("marshal discover payload: %v", err) + } + + ctx, cancel := context.WithTimeout(ctx, DiscoverCommandTimeout) + defer cancel() + + session, err := s.registry.Send(ctx, fleetNodeID, &gatewaypb.ControlCommand{ + CommandId: commandID, + Payload: payload, + }, buildReportScope(normalized)) + if err != nil { + if errors.Is(err, control.ErrNoActiveStream) { + return fleeterror.NewFailedPreconditionError("fleet node has no active control stream") + } + return err + } + defer session.Close() + + // handleEvent forwards a batch or resolves the command on an ack. terminal + // reports whether the command is finished; err is set only on failure. + handleEvent := func(ev control.CommandEvent) (terminal bool, err error) { + switch { + case ev.Batch != nil: + if sendErr := onBatch(ev.Batch); sendErr != nil { + return true, sendErr + } + return false, nil + case ev.Ack != nil: + // PARTIAL carries succeeded=false but its reports already streamed; + // treat it as a usable (incomplete) result, not a failure. + if ev.Ack.GetCode() == gatewaypb.AckCode_ACK_CODE_PARTIAL { + slog.Warn("fleet node discovery completed partially", + "fleet_node_id", fleetNodeID, "detail", ev.Ack.GetErrorMessage()) + return true, nil + } + // Require the structured OK code, not just the boolean, so an + // inconsistent ack (succeeded=true with a non-OK/unset code) can't + // pass a failed scan off as success. + if ev.Ack.GetCode() != gatewaypb.AckCode_ACK_CODE_OK || !ev.Ack.GetSucceeded() { + return true, discoverAckFailure(ev.Ack) + } + return true, nil + default: + return false, nil + } + } + + events := session.Events() + for { + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + return connect.NewError(connect.CodeDeadlineExceeded, fmt.Errorf("discovery command timed out after %s", DiscoverCommandTimeout)) + } + return fleeterror.NewInternalErrorf("operator stream cancelled: %v", ctx.Err()) + case ev := <-events: + if terminal, err := handleEvent(ev); terminal { + return err + } + case <-session.Done(): + // Stream died before an ack. Drain buffered events first (a final + // ack or last batch) so select randomness doesn't drop them. + for { + select { + case ev := <-events: + if terminal, err := handleEvent(ev); terminal { + return err + } + default: + return fleeterror.NewFailedPreconditionError("fleet node control stream closed before command completed") + } + } + } + } +} + +// discoverAckFailure maps a non-OK ack to an operator-facing error, even when +// error_message is empty. The structured AckCode drives the gRPC code so the +// operator can tell a retryable condition (BUSY) and a capability gap +// (AGENT_INCAPABLE) apart from a malformed request (BAD_REQUEST); anything else +// is an opaque Internal failure. +func discoverAckFailure(ack *gatewaypb.ControlAck) error { + reason := ack.GetErrorMessage() + if reason == "" { + reason = "code " + ack.GetCode().String() + } + // if/else (not switch) so the exhaustive linter doesn't demand a case per + // AckCode; everything outside these three is an opaque Internal failure. + code := ack.GetCode() + if code == gatewaypb.AckCode_ACK_CODE_BAD_REQUEST { + return fleeterror.NewInvalidArgumentErrorf("fleet node rejected discovery command: %s", reason) + } + if code == gatewaypb.AckCode_ACK_CODE_BUSY { + return fleeterror.NewPlainError( + fmt.Sprintf("fleet node is busy with another command; retry shortly: %s", reason), + connect.CodeResourceExhausted, + ) + } + if code == gatewaypb.AckCode_ACK_CODE_AGENT_INCAPABLE { + return fleeterror.NewFailedPreconditionErrorf("fleet node cannot service this discovery request; try another node: %s", reason) + } + return fleeterror.NewInternalErrorf("fleet node reported discovery failure: %s", reason) +} + +func normalizeDiscoverRequest(in *pairingpb.DiscoverRequest) (*pairingpb.DiscoverRequest, error) { + switch m := in.GetMode().(type) { + case *pairingpb.DiscoverRequest_IpList: + if m.IpList == nil || len(m.IpList.GetIpAddresses()) == 0 { + return nil, fleeterror.NewInvalidArgumentError("ip_list.ip_addresses must not be empty") + } + if err := checkScanLimits(m.IpList.GetIpAddresses(), m.IpList.GetPorts()); err != nil { + return nil, err + } + // Every entry must be a valid IP or hostname, and IP literals must be + // private. A malformed token (e.g. "bad/entry") is unresolvable for the + // agent yet trips the scope matcher's hostname fallback, widening the + // command to port-only scope. A public literal scans fine but every report + // is rejected by validateReport (private-only), surfacing as a late + // REPORT_FAILED. Hostnames resolve agent-side to an IP the server can't + // check here, so they pass through. + for _, e := range m.IpList.GetIpAddresses() { + addr, perr := netip.ParseAddr(e) + if perr != nil { + if !nmaptarget.IsHostname(e) { + return nil, fleeterror.NewInvalidArgumentErrorf("ip_list entry %q is not a valid IP address or hostname", e) + } + continue + } + if !addr.Unmap().IsPrivate() { + return nil, fleeterror.NewInvalidArgumentErrorf("ip_list entry %q is not a private (RFC1918/RFC4193) address", e) + } + } + return in, nil + case *pairingpb.DiscoverRequest_IpRange: + ips, err := expandIPv4Range(m.IpRange.GetStartIp(), m.IpRange.GetEndIp()) + if err != nil { + return nil, err + } + if err := checkScanLimits(ips, m.IpRange.GetPorts()); err != nil { + return nil, err + } + return &pairingpb.DiscoverRequest{ + Mode: &pairingpb.DiscoverRequest_IpList{ + IpList: &pairingpb.IPListModeRequest{ + IpAddresses: ips, + Ports: m.IpRange.GetPorts(), + }, + }, + }, nil + case *pairingpb.DiscoverRequest_Nmap: + target := m.Nmap.GetTarget() + // The LocalSubnetTarget sentinel defers the target to the agent (it scans + // its own private subnet(s)), so there is nothing to validate here; the + // report scope (buildReportScope) and validateReport still confine reports + // to private addresses. + if target == nmaptarget.LocalSubnetTarget { + if err := checkScanLimits(nil, m.Nmap.GetPorts()); err != nil { + return nil, err + } + return in, nil + } + // Validate against the shared grammar (incl. the /22 CIDR cap), then + // reject IPv6 CIDR — both rejections the agent makes — so an unsupported + // target fails fast here instead of as a late agent BAD_REQUEST ack. + if err := nmaptarget.Validate(target); err != nil { + return nil, fleeterror.NewInvalidArgumentError(err.Error()) + } + if prefix, perr := netip.ParsePrefix(target); perr == nil && prefix.Addr().Is6() { + return nil, fleeterror.NewInvalidArgumentError("nmap IPv6 CIDR is not supported; use ip_list for IPv6 devices") + } + // A public target scans fine but every report comes back non-private and + // is rejected by validateReport, so fail fast. Hostnames resolve agent-side + // and pass through (the report validator still guards what they return). + if !nmapTargetIsPrivate(target) { + return nil, fleeterror.NewInvalidArgumentError("nmap target must be within a private (RFC1918/RFC4193) range") + } + if err := checkScanLimits(nil, m.Nmap.GetPorts()); err != nil { + return nil, err + } + return in, nil + case *pairingpb.DiscoverRequest_Mdns: + return nil, fleeterror.NewInvalidArgumentError("mdns discovery is not supported on fleet nodes") + default: + return nil, fleeterror.NewInvalidArgumentError("discover request mode is required") + } +} + +// checkScanLimits enforces the agent's per-command caps (via discoverylimits) +// and rejects malformed ports before dispatch, so an over-cap or invalid request +// fails fast with a validation error instead of a late agent BAD_REQUEST ack. +// The proto caps are the wire ceiling; these are the real limits. +func checkScanLimits(ipAddresses, ports []string) error { + if len(ipAddresses) > discoverylimits.MaxScanTargets { + return fleeterror.NewInvalidArgumentErrorf("too many targets: %d exceeds the limit of %d", len(ipAddresses), discoverylimits.MaxScanTargets) + } + if len(ports) > discoverylimits.MaxPortsPerIP { + return fleeterror.NewInvalidArgumentErrorf("too many ports: %d exceeds the limit of %d", len(ports), discoverylimits.MaxPortsPerIP) + } + // Each port must be a bare decimal in 1-65535, matching the agent's + // resolveAndValidatePorts; otherwise a token like "80/tcp" or "70000" + // dispatches and returns as a late agent BAD_REQUEST ack. + for _, p := range ports { + if n, err := strconv.Atoi(p); err != nil || n < 1 || n > 65535 { + return fleeterror.NewInvalidArgumentErrorf("invalid port %q: must be a decimal in 1-65535", p) + } + } + return nil +} + +func expandIPv4Range(startStr, endStr string) ([]string, error) { + startAddr, err := netutil.ParseIPv4(startStr) + if err != nil { + return nil, fleeterror.NewInvalidArgumentErrorf("invalid start_ip: %v", err) + } + endAddr, err := netutil.ParseIPv4(endStr) + if err != nil { + return nil, fleeterror.NewInvalidArgumentErrorf("invalid end_ip: %v", err) + } + // Both ends must be private. The MaxScanTargets cap below keeps the range far + // smaller than the gap between RFC1918 blocks, so private endpoints imply a + // fully private range. A public range scans fine but every report is rejected + // by validateReport, surfacing as a late REPORT_FAILED. + if !startAddr.IsPrivate() || !endAddr.IsPrivate() { + return nil, fleeterror.NewInvalidArgumentError("ip range must be within a private (RFC1918) range") + } + start, end := netutil.IPv4ToUint32(startAddr), netutil.IPv4ToUint32(endAddr) + if end < start { + return nil, fleeterror.NewInvalidArgumentError("end_ip must be >= start_ip") + } + // Skip the network (.0) and gateway (.1) start addresses, matching the agent + // and cloud pairing. Otherwise expanding to an IP list would scan .0/.1 as + // literal targets — gateways answer on many ports and look like miners. + start = netutil.AdjustIPv4RangeStart(start) + if end < start { + return nil, fleeterror.NewInvalidArgumentError("ip range covers only network/gateway addresses") + } + // uint64 math so a range ending at 255.255.255.255 can't wrap (in uint32, + // end-start+1 would overflow to 0, bypassing the cap and never terminating). + size := uint64(end) - uint64(start) + 1 + if size > discoverylimits.MaxScanTargets { + return nil, fleeterror.NewInvalidArgumentErrorf("ip range exceeds %d addresses", discoverylimits.MaxScanTargets) + } + out := make([]string, 0, size) + for v := start; ; v++ { + out = append(out, netutil.Uint32ToIPv4(v)) + if v == end { + break + } + } + return out, nil +} diff --git a/server/internal/domain/fleetnode/discovery/service_test.go b/server/internal/domain/fleetnode/discovery/service_test.go new file mode 100644 index 000000000..338ac535f --- /dev/null +++ b/server/internal/domain/fleetnode/discovery/service_test.go @@ -0,0 +1,136 @@ +package discovery + +import ( + "context" + "testing" + + "connectrpc.com/connect" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + gatewaypb "github.com/block/proto-fleet/server/generated/grpc/fleetnodegateway/v1" + pairingpb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" + "github.com/block/proto-fleet/server/internal/domain/fleeterror" + "github.com/block/proto-fleet/server/internal/domain/fleetnode/control" + "github.com/block/proto-fleet/server/internal/domain/fleetnode/enrollment" +) + +type stubLister struct{ nodes []enrollment.FleetNodeListing } + +func (s stubLister) ListFleetNodes(context.Context, int64) ([]enrollment.FleetNodeListing, error) { + return s.nodes, nil +} + +func collectBatches(dst *[]*pairingpb.Device) func(*pairingpb.DiscoverResponse) error { + return func(b *pairingpb.DiscoverResponse) error { + *dst = append(*dst, b.GetDevices()...) + return nil + } +} + +func TestRunOnNode_ForwardsBatchesUntilAck(t *testing.T) { + // Arrange + reg := control.NewRegistry() + svc := NewService(reg, stubLister{}) + const nodeID = int64(7) + stream := reg.Register(nodeID) + defer stream.Unregister() + go func() { + cmd := <-stream.Outgoing + reg.PublishBatch(nodeID, cmd.GetCommandId(), &pairingpb.DiscoverResponse{ + Devices: []*pairingpb.Device{{DeviceIdentifier: "auto:1", IpAddress: "10.0.0.5", Port: "4028"}}, + }) + stream.PublishAck(&gatewaypb.ControlAck{CommandId: cmd.GetCommandId(), Succeeded: true, Code: gatewaypb.AckCode_ACK_CODE_OK}) + }() + var got []*pairingpb.Device + + // Act + err := svc.RunOnNode(context.Background(), nodeID, ipListReq([]string{"10.0.0.5"}, []string{"4028"}), collectBatches(&got)) + + // Assert + require.NoError(t, err) + require.Len(t, got, 1) + assert.Equal(t, "auto:1", got[0].GetDeviceIdentifier()) +} + +func TestRunOnNode_PartialAckIsNotFailure(t *testing.T) { + // Arrange + reg := control.NewRegistry() + svc := NewService(reg, stubLister{}) + const nodeID = int64(8) + stream := reg.Register(nodeID) + defer stream.Unregister() + go func() { + cmd := <-stream.Outgoing + reg.PublishBatch(nodeID, cmd.GetCommandId(), &pairingpb.DiscoverResponse{ + Devices: []*pairingpb.Device{{DeviceIdentifier: "auto:2", IpAddress: "10.0.0.6", Port: "4028"}}, + }) + stream.PublishAck(&gatewaypb.ControlAck{CommandId: cmd.GetCommandId(), Code: gatewaypb.AckCode_ACK_CODE_PARTIAL}) + }() + var got []*pairingpb.Device + + // Act + err := svc.RunOnNode(context.Background(), nodeID, ipListReq([]string{"10.0.0.6"}, []string{"4028"}), collectBatches(&got)) + + // Assert: PARTIAL is a usable (incomplete) result, and its batch still streamed. + require.NoError(t, err) + require.Len(t, got, 1) +} + +func TestRunOnNode_DisconnectBeforeAckReturnsError(t *testing.T) { + // Arrange: agent takes the command then drops without acking. + reg := control.NewRegistry() + svc := NewService(reg, stubLister{}) + const nodeID = int64(9) + stream := reg.Register(nodeID) + go func() { + <-stream.Outgoing + stream.Unregister() + }() + + // Act + err := svc.RunOnNode(context.Background(), nodeID, ipListReq([]string{"10.0.0.7"}, nil), func(*pairingpb.DiscoverResponse) error { return nil }) + + // Assert + require.Error(t, err) + var fe fleeterror.FleetError + require.ErrorAs(t, err, &fe) + assert.Equal(t, connect.CodeFailedPrecondition, fe.ConnectError().Code()) +} + +func TestRunOnNode_NoActiveStreamReturnsFailedPrecondition(t *testing.T) { + // Arrange: no stream registered for the target node. + reg := control.NewRegistry() + svc := NewService(reg, stubLister{}) + + // Act + err := svc.RunOnNode(context.Background(), 404, ipListReq([]string{"10.0.0.8"}, nil), func(*pairingpb.DiscoverResponse) error { return nil }) + + // Assert + require.Error(t, err) + var fe fleeterror.FleetError + require.ErrorAs(t, err, &fe) + assert.Equal(t, connect.CodeFailedPrecondition, fe.ConnectError().Code()) +} + +func TestConfirmedConnectedNodeIDs_IntersectsStatusAndConnection(t *testing.T) { + // Arrange: 1 = confirmed+connected, 2 = confirmed+disconnected, 3 = pending+connected. + reg := control.NewRegistry() + lister := stubLister{nodes: []enrollment.FleetNodeListing{ + {FleetNode: enrollment.FleetNode{ID: 1, EnrollmentStatus: enrollment.FleetNodeStatusConfirmed}}, + {FleetNode: enrollment.FleetNode{ID: 2, EnrollmentStatus: enrollment.FleetNodeStatusConfirmed}}, + {FleetNode: enrollment.FleetNode{ID: 3, EnrollmentStatus: enrollment.FleetNodeStatusPending}}, + }} + svc := NewService(reg, lister) + s1 := reg.Register(1) + defer s1.Unregister() + s3 := reg.Register(3) + defer s3.Unregister() + + // Act + got, err := svc.ConfirmedConnectedNodeIDs(context.Background(), 1) + + // Assert: only the confirmed AND connected node. + require.NoError(t, err) + assert.Equal(t, []int64{1}, got) +} diff --git a/server/internal/domain/nmaptarget/nmaptarget.go b/server/internal/domain/nmaptarget/nmaptarget.go index cbd3aab7b..7c426a7c5 100644 --- a/server/internal/domain/nmaptarget/nmaptarget.go +++ b/server/internal/domain/nmaptarget/nmaptarget.go @@ -29,6 +29,13 @@ var ( // multi-hour scan at the public internet. Operators split larger scopes. const MinIPv4PrefixBits = 22 +// LocalSubnetTarget is a reserved nmap target value meaning "the fleet node +// should scan the private (RFC1918/RFC4193) IPv4 subnet(s) of its own +// interfaces." The cloud fan-out sets it because it cannot know each node's +// local network. It is matched exactly (before any hostname resolution); the +// dashed, namespaced name keeps it from colliding with a real hostname. +const LocalSubnetTarget = "fleetnode-local-subnet" + // IsIPv4Range reports whether s is an "A.B.C.D-N" nmap range. func IsIPv4Range(s string) bool { return ipv4RangeRE.MatchString(s) } diff --git a/server/internal/domain/pairing/mocks/mock_service.go b/server/internal/domain/pairing/mocks/mock_service.go index 3d9cda0e3..daa411f6a 100644 --- a/server/internal/domain/pairing/mocks/mock_service.go +++ b/server/internal/domain/pairing/mocks/mock_service.go @@ -166,3 +166,41 @@ func (mr *MockCapabilitiesProviderMockRecorder) GetMinerCapabilitiesForDevice(ct mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMinerCapabilitiesForDevice", reflect.TypeOf((*MockCapabilitiesProvider)(nil).GetMinerCapabilitiesForDevice), ctx, device) } + +// MockFleetNodeAssigner is a mock of FleetNodeAssigner interface. +type MockFleetNodeAssigner struct { + ctrl *gomock.Controller + recorder *MockFleetNodeAssignerMockRecorder + isgomock struct{} +} + +// MockFleetNodeAssignerMockRecorder is the mock recorder for MockFleetNodeAssigner. +type MockFleetNodeAssignerMockRecorder struct { + mock *MockFleetNodeAssigner +} + +// NewMockFleetNodeAssigner creates a new mock instance. +func NewMockFleetNodeAssigner(ctrl *gomock.Controller) *MockFleetNodeAssigner { + mock := &MockFleetNodeAssigner{ctrl: ctrl} + mock.recorder = &MockFleetNodeAssignerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockFleetNodeAssigner) EXPECT() *MockFleetNodeAssignerMockRecorder { + return m.recorder +} + +// PairDevice mocks base method. +func (m *MockFleetNodeAssigner) PairDevice(ctx context.Context, fleetNodeID, deviceID, orgID int64, assignedBy *int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PairDevice", ctx, fleetNodeID, deviceID, orgID, assignedBy) + ret0, _ := ret[0].(error) + return ret0 +} + +// PairDevice indicates an expected call of PairDevice. +func (mr *MockFleetNodeAssignerMockRecorder) PairDevice(ctx, fleetNodeID, deviceID, orgID, assignedBy any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PairDevice", reflect.TypeOf((*MockFleetNodeAssigner)(nil).PairDevice), ctx, fleetNodeID, deviceID, orgID, assignedBy) +} diff --git a/server/internal/domain/pairing/service.go b/server/internal/domain/pairing/service.go index ac7830f69..b23c1a9ab 100644 --- a/server/internal/domain/pairing/service.go +++ b/server/internal/domain/pairing/service.go @@ -176,6 +176,17 @@ type Service struct { probeSemaphore chan struct{} invalidateMiner func(models.DeviceIdentifier) optionsCache *fleetoptions.Cache + fleetNodeAssigner FleetNodeAssigner +} + +// FleetNodeAssigner records operator-confirmed ownership of a fleet-node-discovered +// device. The cloud Pair flow routes remote-origin devices (discovered_device rows +// with discovered_by_fleet_node_id set) here instead of dialing them: per RFC 0001 +// the owning node dials and credentials the miner, so this is a metadata-only +// assignment. *fleetnode/pairing.Service satisfies this. Optional; when nil, cloud +// pairing refuses remote-origin devices (the pre-fan-out behavior). +type FleetNodeAssigner interface { + PairDevice(ctx context.Context, fleetNodeID, deviceID, orgID int64, assignedBy *int64) error } func NewService( @@ -214,6 +225,13 @@ func (s *Service) WithOptionsCache(cache *fleetoptions.Cache) { s.optionsCache = cache } +// WithFleetNodeAssigner wires the collaborator that records ownership of +// fleet-node-discovered devices so Pair can include them. Pass nil to keep the +// default behavior of refusing remote-origin devices. +func (s *Service) WithFleetNodeAssigner(a FleetNodeAssigner) { + s.fleetNodeAssigner = a +} + type NetworkInfo struct { networking.NetworkInfo } @@ -1181,6 +1199,7 @@ func (s *Service) PairDevices(ctx context.Context, r *pb.PairRequest) (*pb.PairR failedIDs := make([]string, 0, len(deviceIdentifiers)) credentials := r.Credentials + assignedBy := info.UserID // Deduplicate to prevent concurrent pairDevice calls against the same physical device. // We check both exact identifier strings and IP+port because different identifiers can @@ -1201,14 +1220,43 @@ func (s *Service) PairDevices(ctx context.Context, r *pb.PairRequest) (*pb.PairR OrgID: info.OrganizationID, }) if ddErr == nil { - // Cloud pairing dials the IP via plugin RPC; remote-origin - // rows must route through PairDeviceToFleetNode instead. + // Cloud pairing dials the IP via plugin RPC, so a remote-origin row + // can't take that path. With a fleet-node assigner wired, route it to + // the operator-confirmed ownership assignment (metadata only — the + // owning node dials and credentials the miner per RFC 0001), so the + // operator can pair fleet-node-discovered miners alongside direct ones + // in a single request. Without an assigner, refuse as before. if dd.DiscoveredByFleetNodeID != nil { - slog.Warn("refusing to pair remote-fleet-node-reported device via cloud pairing; use PairDeviceToFleetNode", - "device_identifier", id, - "fleet_node_id", *dd.DiscoveredByFleetNodeID, - ) - failedIDs = append(failedIDs, id) + if s.fleetNodeAssigner == nil { + slog.Warn("refusing to pair remote-fleet-node-reported device via cloud pairing; use PairDeviceToFleetNode", + "device_identifier", id, + "fleet_node_id", *dd.DiscoveredByFleetNodeID, + ) + failedIDs = append(failedIDs, id) + continue + } + dbID, idErr := s.discoveredDeviceStore.GetDatabaseID(ctx, discoverymodels.DeviceOrgIdentifier{ + DeviceIdentifier: id, + OrgID: info.OrganizationID, + }) + if idErr != nil { + slog.Error("failed to resolve discovered device id for fleet-node assignment", + "device_identifier", id, "error", idErr) + failedIDs = append(failedIDs, id) + continue + } + if assignErr := s.fleetNodeAssigner.PairDevice(ctx, *dd.DiscoveredByFleetNodeID, dbID, info.OrganizationID, &assignedBy); assignErr != nil { + slog.Warn("failed to assign fleet-node-discovered device to its node", + "device_identifier", id, + "fleet_node_id", *dd.DiscoveredByFleetNodeID, + "error", assignErr, + ) + failedIDs = append(failedIDs, id) + continue + } + // Assignment is metadata only: no plugin dial, credentials, + // handle invalidation, or telemetry scheduling here. + successfulIDs = append(successfulIDs, models.DeviceIdentifier(id)) continue } endpoint := dd.IpAddress + ":" + dd.Port diff --git a/server/internal/domain/pairing/service_pairrouting_test.go b/server/internal/domain/pairing/service_pairrouting_test.go new file mode 100644 index 000000000..d973e9fc9 --- /dev/null +++ b/server/internal/domain/pairing/service_pairrouting_test.go @@ -0,0 +1,128 @@ +package pairing + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + commonv1 "github.com/block/proto-fleet/server/generated/grpc/common/v1" + commandpb "github.com/block/proto-fleet/server/generated/grpc/minercommand/v1" + pb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" + "github.com/block/proto-fleet/server/internal/domain/fleeterror" + discoverymodels "github.com/block/proto-fleet/server/internal/domain/minerdiscovery/models" + "github.com/block/proto-fleet/server/internal/domain/stores/interfaces/mocks" +) + +type assignCall struct { + fleetNodeID int64 + deviceID int64 + orgID int64 + assignedBy *int64 +} + +type fakeAssigner struct { + calls []assignCall + failDeviceIDs map[int64]bool +} + +func (f *fakeAssigner) PairDevice(_ context.Context, fleetNodeID, deviceID, orgID int64, assignedBy *int64) error { + f.calls = append(f.calls, assignCall{fleetNodeID, deviceID, orgID, assignedBy}) + if f.failDeviceIDs[deviceID] { + return fleeterror.NewFailedPreconditionError("device already paired; unpair first") + } + return nil +} + +func includeReq(ids ...string) *pb.PairRequest { + return &pb.PairRequest{DeviceSelector: &commandpb.DeviceSelector{ + SelectionType: &commandpb.DeviceSelector_IncludeDevices{ + IncludeDevices: &commonv1.DeviceIdentifierList{DeviceIdentifiers: ids}, + }, + }} +} + +func fleetNodeDiscoveredDevice(identifier string, orgID, nodeID int64) *discoverymodels.DiscoveredDevice { + return &discoverymodels.DiscoveredDevice{ + Device: pb.Device{DeviceIdentifier: identifier, IpAddress: "10.0.0.5", Port: "80"}, + OrgID: orgID, + DiscoveredByFleetNodeID: &nodeID, + } +} + +func TestPairDevices_FleetNodeDeviceRoutesToAssigner(t *testing.T) { + // Arrange + ctrl := gomock.NewController(t) + defer ctrl.Finish() + const ( + orgID = int64(7) + userID = int64(3) + nodeID = int64(55) + dbID = int64(900) + ) + doi := discoverymodels.DeviceOrgIdentifier{DeviceIdentifier: "dev-1", OrgID: orgID} + mockDD := mocks.NewMockDiscoveredDeviceStore(ctrl) + mockDD.EXPECT().GetDevice(gomock.Any(), doi).Return(fleetNodeDiscoveredDevice("dev-1", orgID, nodeID), nil) + mockDD.EXPECT().GetDatabaseID(gomock.Any(), doi).Return(dbID, nil) + assigner := &fakeAssigner{} + svc := &Service{discoveredDeviceStore: mockDD, fleetNodeAssigner: assigner} + ctx := mockSessionContext(t.Context(), userID, orgID) + + // Act + resp, err := svc.PairDevices(ctx, includeReq("dev-1")) + + // Assert: routed to the assigner with the resolved DB id + caller, not dialed. + require.NoError(t, err) + assert.Empty(t, resp.GetFailedDeviceIds()) + require.Len(t, assigner.calls, 1) + assert.Equal(t, nodeID, assigner.calls[0].fleetNodeID) + assert.Equal(t, dbID, assigner.calls[0].deviceID) + assert.Equal(t, orgID, assigner.calls[0].orgID) + require.NotNil(t, assigner.calls[0].assignedBy) + assert.Equal(t, userID, *assigner.calls[0].assignedBy) +} + +func TestPairDevices_FleetNodeDeviceRefusedWithoutAssigner(t *testing.T) { + // Arrange: no assigner wired keeps the pre-fan-out refusal behavior. + ctrl := gomock.NewController(t) + defer ctrl.Finish() + const orgID = int64(7) + doi := discoverymodels.DeviceOrgIdentifier{DeviceIdentifier: "dev-1", OrgID: orgID} + mockDD := mocks.NewMockDiscoveredDeviceStore(ctrl) + mockDD.EXPECT().GetDevice(gomock.Any(), doi).Return(fleetNodeDiscoveredDevice("dev-1", orgID, 55), nil) + svc := &Service{discoveredDeviceStore: mockDD} // fleetNodeAssigner nil + ctx := mockSessionContext(t.Context(), 3, orgID) + + // Act + _, err := svc.PairDevices(ctx, includeReq("dev-1")) + + // Assert: the only device was refused, so nothing paired. + require.Error(t, err) + assert.Contains(t, err.Error(), "Failed to pair any devices") +} + +func TestPairDevices_FleetNodeAssignPartialSuccess(t *testing.T) { + // Arrange: two fleet-node devices; the assigner fails the second only. + ctrl := gomock.NewController(t) + defer ctrl.Finish() + const orgID = int64(7) + doi1 := discoverymodels.DeviceOrgIdentifier{DeviceIdentifier: "dev-1", OrgID: orgID} + doi2 := discoverymodels.DeviceOrgIdentifier{DeviceIdentifier: "dev-2", OrgID: orgID} + mockDD := mocks.NewMockDiscoveredDeviceStore(ctrl) + mockDD.EXPECT().GetDevice(gomock.Any(), doi1).Return(fleetNodeDiscoveredDevice("dev-1", orgID, 55), nil) + mockDD.EXPECT().GetDatabaseID(gomock.Any(), doi1).Return(int64(900), nil) + mockDD.EXPECT().GetDevice(gomock.Any(), doi2).Return(fleetNodeDiscoveredDevice("dev-2", orgID, 55), nil) + mockDD.EXPECT().GetDatabaseID(gomock.Any(), doi2).Return(int64(901), nil) + assigner := &fakeAssigner{failDeviceIDs: map[int64]bool{901: true}} + svc := &Service{discoveredDeviceStore: mockDD, fleetNodeAssigner: assigner} + ctx := mockSessionContext(t.Context(), 3, orgID) + + // Act + resp, err := svc.PairDevices(ctx, includeReq("dev-1", "dev-2")) + + // Assert: dev-1 paired, dev-2 failed; partial success returns no top-level error. + require.NoError(t, err) + assert.Equal(t, []string{"dev-2"}, resp.GetFailedDeviceIds()) +} diff --git a/server/internal/handlers/fleetnode/admin/handler.go b/server/internal/handlers/fleetnode/admin/handler.go index 2f2898fa8..c9a2cd6b7 100644 --- a/server/internal/handlers/fleetnode/admin/handler.go +++ b/server/internal/handlers/fleetnode/admin/handler.go @@ -2,31 +2,19 @@ package admin import ( "context" - "errors" - "fmt" - "log/slog" - "net/netip" - "strconv" - "time" "connectrpc.com/connect" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" pb "github.com/block/proto-fleet/server/generated/grpc/fleetnodeadmin/v1" "github.com/block/proto-fleet/server/generated/grpc/fleetnodeadmin/v1/fleetnodeadminv1connect" - gatewaypb "github.com/block/proto-fleet/server/generated/grpc/fleetnodegateway/v1" pairingpb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" "github.com/block/proto-fleet/server/internal/domain/authz" - "github.com/block/proto-fleet/server/internal/domain/discoverylimits" "github.com/block/proto-fleet/server/internal/domain/fleeterror" - "github.com/block/proto-fleet/server/internal/domain/fleetnode/control" + "github.com/block/proto-fleet/server/internal/domain/fleetnode/discovery" "github.com/block/proto-fleet/server/internal/domain/fleetnode/enrollment" "github.com/block/proto-fleet/server/internal/domain/fleetnode/pairing" - "github.com/block/proto-fleet/server/internal/domain/netutil" - "github.com/block/proto-fleet/server/internal/domain/nmaptarget" "github.com/block/proto-fleet/server/internal/handlers/middleware" - "github.com/block/proto-fleet/server/internal/infrastructure/id" ) type Handler struct { @@ -34,13 +22,13 @@ type Handler struct { enrollment *enrollment.Service pairing *pairing.Service - registry *control.Registry + discovery *discovery.Service } var _ fleetnodeadminv1connect.FleetNodeAdminServiceHandler = &Handler{} -func NewHandler(enrollment *enrollment.Service, pairing *pairing.Service, registry *control.Registry) *Handler { - return &Handler{enrollment: enrollment, pairing: pairing, registry: registry} +func NewHandler(enrollment *enrollment.Service, pairing *pairing.Service, discoverySvc *discovery.Service) *Handler { + return &Handler{enrollment: enrollment, pairing: pairing, discovery: discoverySvc} } func (h *Handler) CreateEnrollmentCode(ctx context.Context, _ *connect.Request[pb.CreateEnrollmentCodeRequest]) (*connect.Response[pb.CreateEnrollmentCodeResponse], error) { @@ -165,14 +153,9 @@ func (h *Handler) ListFleetNodeDevices(ctx context.Context, req *connect.Request return connect.NewResponse(resp), nil } -// DiscoverCommandTimeout bounds how long DiscoverOnFleetNode waits for the -// agent's batches and ack, so a silent node can't pin operator streams and -// registry slots. Must exceed the agent's scan budget (commandTimeout, 10m) -// plus report/ack slack: too short frees the slot mid-scan, the agent's ack is -// rejected as stale, and a new command dispatches while the node is still busy. -// Var for tests. -var DiscoverCommandTimeout = 12 * time.Minute - +// DiscoverOnFleetNode runs discovery on a single CONFIRMED node and streams the +// node's device batches back to the operator. The dispatch/drain loop lives in +// the discovery service so the cloud "Find miners" fan-out can reuse it. func (h *Handler) DiscoverOnFleetNode(ctx context.Context, req *connect.Request[pb.DiscoverOnFleetNodeRequest], stream *connect.ServerStream[pb.DiscoverOnFleetNodeResponse]) error { info, err := middleware.RequirePermission(ctx, authz.PermFleetnodeManage, authz.ResourceContext{}) if err != nil { @@ -195,254 +178,12 @@ func (h *Handler) DiscoverOnFleetNode(ctx context.Context, req *connect.Request[ return fleeterror.NewFailedPreconditionError("fleet node is not CONFIRMED") } - normalized, err := normalizeDiscoverRequest(discoverReq) - if err != nil { - return err - } - - commandID := id.GenerateID() - payload, err := proto.Marshal(normalized) - if err != nil { - return fleeterror.NewInternalErrorf("marshal discover payload: %v", err) - } - - ctx, cancel := context.WithTimeout(ctx, DiscoverCommandTimeout) - defer cancel() - - session, err := h.registry.Send(ctx, fleetNodeID, &gatewaypb.ControlCommand{ - CommandId: commandID, - Payload: payload, - }, buildReportScope(normalized)) - if err != nil { - if errors.Is(err, control.ErrNoActiveStream) { - return fleeterror.NewFailedPreconditionError("fleet node has no active control stream") - } - return err - } - defer session.Close() - - // handleEvent forwards a batch or resolves the command on an ack. terminal - // reports whether the command is finished; err is set only on failure. - handleEvent := func(ev control.CommandEvent) (terminal bool, err error) { - switch { - case ev.Batch != nil: - if sendErr := stream.Send(&pb.DiscoverOnFleetNodeResponse{Response: ev.Batch}); sendErr != nil { - return true, fleeterror.NewInternalErrorf("send batch to operator: %v", sendErr) - } - return false, nil - case ev.Ack != nil: - // PARTIAL carries succeeded=false but its reports already streamed; - // treat it as a usable (incomplete) result, not a failure. - if ev.Ack.GetCode() == gatewaypb.AckCode_ACK_CODE_PARTIAL { - slog.Warn("fleet node discovery completed partially", - "fleet_node_id", fleetNodeID, "detail", ev.Ack.GetErrorMessage()) - return true, nil - } - // Require the structured OK code, not just the boolean, so an - // inconsistent ack (succeeded=true with a non-OK/unset code) can't - // pass a failed scan off as success. - if ev.Ack.GetCode() != gatewaypb.AckCode_ACK_CODE_OK || !ev.Ack.GetSucceeded() { - return true, discoverAckFailure(ev.Ack) - } - return true, nil - default: - return false, nil - } - } - - events := session.Events() - for { - select { - case <-ctx.Done(): - if errors.Is(ctx.Err(), context.DeadlineExceeded) { - return connect.NewError(connect.CodeDeadlineExceeded, fmt.Errorf("discovery command timed out after %s", DiscoverCommandTimeout)) - } - return fleeterror.NewInternalErrorf("operator stream cancelled: %v", ctx.Err()) - case ev := <-events: - if terminal, err := handleEvent(ev); terminal { - return err - } - case <-session.Done(): - // Stream died before an ack. Drain buffered events first (a final - // ack or last batch) so select randomness doesn't drop them. - for { - select { - case ev := <-events: - if terminal, err := handleEvent(ev); terminal { - return err - } - default: - return fleeterror.NewFailedPreconditionError("fleet node control stream closed before command completed") - } - } - } - } -} - -// discoverAckFailure maps a non-OK ack to an operator-facing error, even when -// error_message is empty. The structured AckCode drives the gRPC code so the -// operator can tell a retryable condition (BUSY) and a capability gap -// (AGENT_INCAPABLE) apart from a malformed request (BAD_REQUEST); anything else -// is an opaque Internal failure. -func discoverAckFailure(ack *gatewaypb.ControlAck) error { - reason := ack.GetErrorMessage() - if reason == "" { - reason = "code " + ack.GetCode().String() - } - // if/else (not switch) so the exhaustive linter doesn't demand a case per - // AckCode; everything outside these three is an opaque Internal failure. - code := ack.GetCode() - if code == gatewaypb.AckCode_ACK_CODE_BAD_REQUEST { - return fleeterror.NewInvalidArgumentErrorf("fleet node rejected discovery command: %s", reason) - } - if code == gatewaypb.AckCode_ACK_CODE_BUSY { - return fleeterror.NewPlainError( - fmt.Sprintf("fleet node is busy with another command; retry shortly: %s", reason), - connect.CodeResourceExhausted, - ) - } - if code == gatewaypb.AckCode_ACK_CODE_AGENT_INCAPABLE { - return fleeterror.NewFailedPreconditionErrorf("fleet node cannot service this discovery request; try another node: %s", reason) - } - return fleeterror.NewInternalErrorf("fleet node reported discovery failure: %s", reason) -} - -func normalizeDiscoverRequest(in *pairingpb.DiscoverRequest) (*pairingpb.DiscoverRequest, error) { - switch m := in.GetMode().(type) { - case *pairingpb.DiscoverRequest_IpList: - if m.IpList == nil || len(m.IpList.GetIpAddresses()) == 0 { - return nil, fleeterror.NewInvalidArgumentError("ip_list.ip_addresses must not be empty") - } - if err := checkScanLimits(m.IpList.GetIpAddresses(), m.IpList.GetPorts()); err != nil { - return nil, err - } - // Every entry must be a valid IP or hostname, and IP literals must be - // private. A malformed token (e.g. "bad/entry") is unresolvable for the - // agent yet trips the scope matcher's hostname fallback, widening the - // command to port-only scope. A public literal scans fine but every report - // is rejected by validateReport (private-only), surfacing as a late - // REPORT_FAILED. Hostnames resolve agent-side to an IP the server can't - // check here, so they pass through. - for _, e := range m.IpList.GetIpAddresses() { - addr, perr := netip.ParseAddr(e) - if perr != nil { - if !nmaptarget.IsHostname(e) { - return nil, fleeterror.NewInvalidArgumentErrorf("ip_list entry %q is not a valid IP address or hostname", e) - } - continue - } - if !addr.Unmap().IsPrivate() { - return nil, fleeterror.NewInvalidArgumentErrorf("ip_list entry %q is not a private (RFC1918/RFC4193) address", e) - } - } - return in, nil - case *pairingpb.DiscoverRequest_IpRange: - ips, err := expandIPv4Range(m.IpRange.GetStartIp(), m.IpRange.GetEndIp()) - if err != nil { - return nil, err - } - if err := checkScanLimits(ips, m.IpRange.GetPorts()); err != nil { - return nil, err - } - return &pairingpb.DiscoverRequest{ - Mode: &pairingpb.DiscoverRequest_IpList{ - IpList: &pairingpb.IPListModeRequest{ - IpAddresses: ips, - Ports: m.IpRange.GetPorts(), - }, - }, - }, nil - case *pairingpb.DiscoverRequest_Nmap: - target := m.Nmap.GetTarget() - // Validate against the shared grammar (incl. the /22 CIDR cap), then - // reject IPv6 CIDR — both rejections the agent makes — so an unsupported - // target fails fast here instead of as a late agent BAD_REQUEST ack. - if err := nmaptarget.Validate(target); err != nil { - return nil, fleeterror.NewInvalidArgumentError(err.Error()) - } - if prefix, perr := netip.ParsePrefix(target); perr == nil && prefix.Addr().Is6() { - return nil, fleeterror.NewInvalidArgumentError("nmap IPv6 CIDR is not supported; use ip_list for IPv6 devices") - } - // A public target scans fine but every report comes back non-private and - // is rejected by validateReport, so fail fast. Hostnames resolve agent-side - // and pass through (the report validator still guards what they return). - if !nmapTargetIsPrivate(target) { - return nil, fleeterror.NewInvalidArgumentError("nmap target must be within a private (RFC1918/RFC4193) range") - } - if err := checkScanLimits(nil, m.Nmap.GetPorts()); err != nil { - return nil, err - } - return in, nil - case *pairingpb.DiscoverRequest_Mdns: - return nil, fleeterror.NewInvalidArgumentError("mdns discovery is not supported on fleet nodes") - default: - return nil, fleeterror.NewInvalidArgumentError("discover request mode is required") - } -} - -// checkScanLimits enforces the agent's per-command caps (via discoverylimits) -// and rejects malformed ports before dispatch, so an over-cap or invalid request -// fails fast with a validation error instead of a late agent BAD_REQUEST ack. -// The proto caps are the wire ceiling; these are the real limits. -func checkScanLimits(ipAddresses, ports []string) error { - if len(ipAddresses) > discoverylimits.MaxScanTargets { - return fleeterror.NewInvalidArgumentErrorf("too many targets: %d exceeds the limit of %d", len(ipAddresses), discoverylimits.MaxScanTargets) - } - if len(ports) > discoverylimits.MaxPortsPerIP { - return fleeterror.NewInvalidArgumentErrorf("too many ports: %d exceeds the limit of %d", len(ports), discoverylimits.MaxPortsPerIP) - } - // Each port must be a bare decimal in 1-65535, matching the agent's - // resolveAndValidatePorts; otherwise a token like "80/tcp" or "70000" - // dispatches and returns as a late agent BAD_REQUEST ack. - for _, p := range ports { - if n, err := strconv.Atoi(p); err != nil || n < 1 || n > 65535 { - return fleeterror.NewInvalidArgumentErrorf("invalid port %q: must be a decimal in 1-65535", p) + return h.discovery.RunOnNode(ctx, fleetNodeID, discoverReq, func(batch *pairingpb.DiscoverResponse) error { + if sendErr := stream.Send(&pb.DiscoverOnFleetNodeResponse{Response: batch}); sendErr != nil { + return fleeterror.NewInternalErrorf("send batch to operator: %v", sendErr) } - } - return nil -} - -func expandIPv4Range(startStr, endStr string) ([]string, error) { - startAddr, err := netutil.ParseIPv4(startStr) - if err != nil { - return nil, fleeterror.NewInvalidArgumentErrorf("invalid start_ip: %v", err) - } - endAddr, err := netutil.ParseIPv4(endStr) - if err != nil { - return nil, fleeterror.NewInvalidArgumentErrorf("invalid end_ip: %v", err) - } - // Both ends must be private. The MaxScanTargets cap below keeps the range far - // smaller than the gap between RFC1918 blocks, so private endpoints imply a - // fully private range. A public range scans fine but every report is rejected - // by validateReport, surfacing as a late REPORT_FAILED. - if !startAddr.IsPrivate() || !endAddr.IsPrivate() { - return nil, fleeterror.NewInvalidArgumentError("ip range must be within a private (RFC1918) range") - } - start, end := netutil.IPv4ToUint32(startAddr), netutil.IPv4ToUint32(endAddr) - if end < start { - return nil, fleeterror.NewInvalidArgumentError("end_ip must be >= start_ip") - } - // Skip the network (.0) and gateway (.1) start addresses, matching the agent - // and cloud pairing. Otherwise expanding to an IP list would scan .0/.1 as - // literal targets — gateways answer on many ports and look like miners. - start = netutil.AdjustIPv4RangeStart(start) - if end < start { - return nil, fleeterror.NewInvalidArgumentError("ip range covers only network/gateway addresses") - } - // uint64 math so a range ending at 255.255.255.255 can't wrap (in uint32, - // end-start+1 would overflow to 0, bypassing the cap and never terminating). - size := uint64(end) - uint64(start) + 1 - if size > discoverylimits.MaxScanTargets { - return nil, fleeterror.NewInvalidArgumentErrorf("ip range exceeds %d addresses", discoverylimits.MaxScanTargets) - } - out := make([]string, 0, size) - for v := start; ; v++ { - out = append(out, netutil.Uint32ToIPv4(v)) - if v == end { - break - } - } - return out, nil + return nil + }) } // AWAITING_CONFIRMATION lives only on pending_enrollment, so a PENDING fleet diff --git a/server/internal/handlers/fleetnode/admin/handler_discover_test.go b/server/internal/handlers/fleetnode/admin/handler_discover_test.go index 825365035..98bcf6323 100644 --- a/server/internal/handlers/fleetnode/admin/handler_discover_test.go +++ b/server/internal/handlers/fleetnode/admin/handler_discover_test.go @@ -19,8 +19,8 @@ import ( gatewaypb "github.com/block/proto-fleet/server/generated/grpc/fleetnodegateway/v1" pairingpb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" "github.com/block/proto-fleet/server/internal/domain/authz" + "github.com/block/proto-fleet/server/internal/domain/fleetnode/discovery" "github.com/block/proto-fleet/server/internal/domain/session" - "github.com/block/proto-fleet/server/internal/handlers/fleetnode/admin" "github.com/block/proto-fleet/server/internal/handlers/interceptors" "github.com/block/proto-fleet/server/internal/handlers/middleware" ) @@ -557,9 +557,9 @@ func TestDiscoverOnFleetNode_TimesOutWhenAgentNeverResponds(t *testing.T) { // Arrange: register an agent stream but never publish batch or ack. // Override DiscoverCommandTimeout to a short window so the test // terminates quickly. - prev := admin.DiscoverCommandTimeout - admin.DiscoverCommandTimeout = 200 * time.Millisecond - t.Cleanup(func() { admin.DiscoverCommandTimeout = prev }) + prev := discovery.DiscoverCommandTimeout + discovery.DiscoverCommandTimeout = 200 * time.Millisecond + t.Cleanup(func() { discovery.DiscoverCommandTimeout = prev }) h := newPairingHarness(t) fleetNodeID := h.createFleetNode(t, "admin-discover-timeout") diff --git a/server/internal/handlers/fleetnode/admin/handler_pairing_test.go b/server/internal/handlers/fleetnode/admin/handler_pairing_test.go index 4a3462b58..c6d4602b5 100644 --- a/server/internal/handlers/fleetnode/admin/handler_pairing_test.go +++ b/server/internal/handlers/fleetnode/admin/handler_pairing_test.go @@ -19,6 +19,7 @@ import ( "github.com/block/proto-fleet/server/internal/domain/authz" "github.com/block/proto-fleet/server/internal/domain/fleeterror" "github.com/block/proto-fleet/server/internal/domain/fleetnode/control" + "github.com/block/proto-fleet/server/internal/domain/fleetnode/discovery" "github.com/block/proto-fleet/server/internal/domain/fleetnode/enrollment" "github.com/block/proto-fleet/server/internal/domain/fleetnode/pairing" "github.com/block/proto-fleet/server/internal/domain/session" @@ -57,8 +58,9 @@ func newPairingHarness(t *testing.T) *pairingHarness { pairingSvc := pairing.NewService(pairingStore, enrollmentStore, transactor) registry := control.NewRegistry() + discoverySvc := discovery.NewService(registry, enrollmentSvc) return &pairingHarness{ - handler: admin.NewHandler(enrollmentSvc, pairingSvc, registry), + handler: admin.NewHandler(enrollmentSvc, pairingSvc, discoverySvc), db: db, orgID: 1, enrollment: enrollmentSvc, diff --git a/server/internal/handlers/pairing/handler.go b/server/internal/handlers/pairing/handler.go index b0c0024d6..9be1f1d8a 100644 --- a/server/internal/handlers/pairing/handler.go +++ b/server/internal/handlers/pairing/handler.go @@ -3,9 +3,12 @@ package pairing import ( "context" "log/slog" + "sync" "github.com/block/proto-fleet/server/internal/domain/authz" "github.com/block/proto-fleet/server/internal/domain/fleeterror" + "github.com/block/proto-fleet/server/internal/domain/fleetnode/discovery" + "github.com/block/proto-fleet/server/internal/domain/nmaptarget" "github.com/block/proto-fleet/server/internal/handlers/middleware" "connectrpc.com/connect" @@ -17,59 +20,149 @@ import ( // Handler handles the Connect-RPC endpoints type Handler struct { pairingSvc *pairing.Service + // discovery fans the "Scan your network" nmap action out to connected fleet + // nodes so their LAN-local miners surface alongside the cloud's own scan. + // Optional; nil disables fan-out (cloud-only discovery). + discovery *discovery.Service } var _ pairingv1connect.PairingServiceHandler = &Handler{} // NewHandler creates a new instance of Handler -func NewHandler(pairingSvc *pairing.Service) *Handler { +func NewHandler(pairingSvc *pairing.Service, discoverySvc *discovery.Service) *Handler { return &Handler{ pairingSvc: pairingSvc, + discovery: discoverySvc, } } -// Discover implements pairingv1connect.DeviceDiscoveryServiceHandler. +// Discover implements pairingv1connect.PairingServiceHandler. +// +// Beyond the cloud's own network scan, an nmap ("Scan your network") request +// also fans out to every CONFIRMED + connected fleet node, which scan their own +// local subnets and report back. All sources merge into this single response +// stream so the operator pairs LAN-local and cloud-local miners together with no +// client change. Manual modes (ipList/ipRange/mdns) target the cloud's own +// network only. func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRequest], s *connect.ServerStream[pb.DiscoverResponse]) error { - if _, err := middleware.RequirePermission(ctx, authz.PermMinerPair, authz.ResourceContext{}); err != nil { + info, err := middleware.RequirePermission(ctx, authz.PermMinerPair, authz.ResourceContext{}) + if err != nil { return err } slog.Debug("Discover: handling discover request", "payload", r.Msg) + + // A send failure (operator disconnected) cancels every source. + streamCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Connect server streams are not safe for concurrent Send, and the cloud + // scan + each node write to this one stream. Serialize through send, which + // also dedupes devices across sources by identifier (each source dedupes + // internally, but not against the others). + var ( + sendMu sync.Mutex + seen = make(map[string]struct{}) + sendErr error + ) + send := func(resp *pb.DiscoverResponse) error { + sendMu.Lock() + defer sendMu.Unlock() + if sendErr != nil { + return sendErr + } + out := resp + if len(resp.GetDevices()) > 0 { + deduped := make([]*pb.Device, 0, len(resp.GetDevices())) + for _, d := range resp.GetDevices() { + key := d.GetDeviceIdentifier() + if key != "" { + if _, dup := seen[key]; dup { + continue + } + seen[key] = struct{}{} + } + deduped = append(deduped, d) + } + if len(deduped) == 0 && resp.GetError() == "" { + return nil // whole batch was duplicates; nothing to forward + } + out = &pb.DiscoverResponse{Devices: deduped, Error: resp.GetError()} + } + if sErr := s.Send(out); sErr != nil { + sendErr = sErr + cancel() + return sErr //nolint:wrapcheck // a connect stream Send error is already a connect error + } + return nil + } + var resultChan <-chan *pb.DiscoverResponse - var err error switch r.Msg.Mode.(type) { case *pb.DiscoverRequest_IpList: - resultChan, err = h.pairingSvc.DiscoverWithIPList(ctx, r.Msg.GetIpList()) + resultChan, err = h.pairingSvc.DiscoverWithIPList(streamCtx, r.Msg.GetIpList()) case *pb.DiscoverRequest_IpRange: - resultChan, err = h.pairingSvc.DiscoverWithIPRange(ctx, r.Msg.GetIpRange()) + resultChan, err = h.pairingSvc.DiscoverWithIPRange(streamCtx, r.Msg.GetIpRange()) case *pb.DiscoverRequest_Nmap: - resultChan, err = h.pairingSvc.DiscoverWithNmap(ctx, r.Msg.GetNmap()) + resultChan, err = h.pairingSvc.DiscoverWithNmap(streamCtx, r.Msg.GetNmap()) case *pb.DiscoverRequest_Mdns: - resultChan, err = h.pairingSvc.DiscoverWithMDNS(ctx, r.Msg.GetMdns()) + resultChan, err = h.pairingSvc.DiscoverWithMDNS(streamCtx, r.Msg.GetMdns()) default: return fleeterror.NewInternalError("unsupported mode") } - if err != nil { return err } - for { - select { - case result, ok := <-resultChan: - if !ok { - return nil - } - res := &pb.DiscoverResponse{ - Devices: result.Devices, + var wg sync.WaitGroup + + // Cloud discovery source. + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case result, ok := <-resultChan: + if !ok { + return + } + if err := send(result); err != nil { + return + } + case <-streamCtx.Done(): + return } - if err := s.Send(res); err != nil { - // nolint:wrapcheck - return err + } + }() + + // Fleet node fan-out (nmap only). + if _, ok := r.Msg.Mode.(*pb.DiscoverRequest_Nmap); ok && h.discovery != nil { + nodeIDs, listErr := h.discovery.ConfirmedConnectedNodeIDs(streamCtx, info.OrganizationID) + if listErr != nil { + // Fan-out is best-effort; a lookup failure must never break the + // cloud scan. With zero connected nodes this is the same path. + slog.Warn("skipping fleet node discovery fan-out", "error", listErr) + } else { + autoReq := &pb.DiscoverRequest{Mode: &pb.DiscoverRequest_Nmap{Nmap: &pb.NmapModeRequest{ + Target: nmaptarget.LocalSubnetTarget, + Ports: r.Msg.GetNmap().GetPorts(), + }}} + for _, nodeID := range nodeIDs { + wg.Add(1) + go func(nodeID int64) { + defer wg.Done() + runErr := h.discovery.RunOnNode(streamCtx, nodeID, autoReq, send) + if runErr != nil { + // One node failing must not fail the whole scan. + slog.Warn("fleet node discovery failed during cloud fan-out", + "fleet_node_id", nodeID, "error", runErr) + } + }(nodeID) } - case <-ctx.Done(): - return fleeterror.NewCanceledError() } } + + wg.Wait() + return sendErr } // Pair implements pairingv1connect.PairingServiceHandler. diff --git a/server/internal/testutil/infrastructure_provider.go b/server/internal/testutil/infrastructure_provider.go index 2587af397..7269a9741 100644 --- a/server/internal/testutil/infrastructure_provider.go +++ b/server/internal/testutil/infrastructure_provider.go @@ -54,7 +54,8 @@ func NewInfrastructureProvider(t *testing.T, serviceProvider *ServiceProvider, a authHandler := auth.NewHandler(serviceProvider.AuthService) mux.Handle(authv1connect.NewAuthServiceHandler(authHandler, interceptorsOption)) - pairingHandler := pairing.NewHandler(serviceProvider.PairingService) + // nil discovery service: no fleet node fan-out in this test harness. + pairingHandler := pairing.NewHandler(serviceProvider.PairingService, nil) mux.Handle(pairingv1connect.NewPairingServiceHandler(pairingHandler, interceptorsOption)) onboardingHandler := onboarding.NewHandler(serviceProvider.AuthService, serviceProvider.OnboardingService) From faab077fd4fed05d0e0ab4c8fa3b12843a465295 Mon Sep 17 00:00:00 2001 From: Ankit Goswami Date: Tue, 2 Jun 2026 15:25:47 -0700 Subject: [PATCH 02/13] docs(discovery): fix stale auto_local_subnet comment references The auto_local_subnet bool was replaced by the LocalSubnetTarget sentinel; update the comments and two test names that still referenced the old field. No behavior change. Co-Authored-By: Claude Opus 4.8 (1M context) --- server/cmd/fleetnode/localsubnet.go | 21 ++++++++++--------- server/cmd/fleetnode/localsubnet_test.go | 2 +- server/cmd/fleetnode/nmap.go | 2 +- server/cmd/fleetnode/run.go | 2 +- .../fleetnode/discovery/reportscope_test.go | 6 +++--- 5 files changed, 17 insertions(+), 16 deletions(-) diff --git a/server/cmd/fleetnode/localsubnet.go b/server/cmd/fleetnode/localsubnet.go index 274fb1bbd..f2ee3459a 100644 --- a/server/cmd/fleetnode/localsubnet.go +++ b/server/cmd/fleetnode/localsubnet.go @@ -8,19 +8,19 @@ import ( "strings" ) -// autoSubnetMinPrefixBits caps an auto-detected subnet at /22 (<=1024 hosts) — -// the same ceiling the manual nmap path enforces (nmaptarget.MinIPv4PrefixBits). -// A NIC configured with a wider mask (e.g. /16) is narrowed around its own host -// address so an auto scan stays bounded and finishes inside the command timeout. +// autoSubnetMinPrefixBits caps a detected subnet at /22 (<=1024 hosts) — the +// same ceiling the manual nmap path enforces (nmaptarget.MinIPv4PrefixBits). A +// NIC configured with a wider mask (e.g. /16) is narrowed around its own host +// address so the scan stays bounded and finishes inside the command timeout. const autoSubnetMinPrefixBits = 22 -// maxAutoSubnets caps how many distinct subnets one auto command scans, so a +// maxAutoSubnets caps how many distinct subnets one command scans, so a // multi-homed host with many interfaces can't fan one command into a huge sweep. const maxAutoSubnets = 8 // errNoLocalPrivateSubnet means no connected, non-virtual interface had a private -// IPv4 address — the agent has nothing to auto-scan. Surfaces as AGENT_INCAPABLE -// so a fan-out skips this node and tries the others. +// IPv4 address — the agent has nothing to scan. Surfaces as AGENT_INCAPABLE so a +// fan-out skips this node and tries the others. var errNoLocalPrivateSubnet = errors.New("no connected private IPv4 subnet found") // virtualIfacePrefixes are name prefixes for container/VPN/virtual adapters whose @@ -33,8 +33,9 @@ var virtualIfacePrefixes = []string{ } // detectLocalSubnets returns the private IPv4 subnet(s) the agent should scan for -// an auto_local_subnet nmap command. The localSubnets seam lets tests inject -// canned CIDRs; production enumerates the host's interfaces. +// a local-subnet nmap command (the nmaptarget.LocalSubnetTarget sentinel). The +// localSubnets seam lets tests inject canned CIDRs; production enumerates the +// host's interfaces. func (r *RunCmd) detectLocalSubnets() ([]string, error) { if r.localSubnets != nil { return r.localSubnets() @@ -48,7 +49,7 @@ func (r *RunCmd) detectLocalSubnets() ([]string, error) { // selectLocalPrivateSubnets returns the canonical CIDR(s) of the connected, // non-virtual, private IPv4 subnet(s) of the given interfaces. addrsOf is -// injected for testing (net.Interface.Addrs in production). Subnets wider than +// injected for testing ((*net.Interface).Addrs in production). Subnets wider than // /22 are narrowed around the host address, results are deduped and capped at // maxAutoSubnets, and IPv6 is ignored (the manual nmap path rejects IPv6 CIDR // too). Returns errNoLocalPrivateSubnet when none qualify. diff --git a/server/cmd/fleetnode/localsubnet_test.go b/server/cmd/fleetnode/localsubnet_test.go index 37f0945ae..d1800ebdd 100644 --- a/server/cmd/fleetnode/localsubnet_test.go +++ b/server/cmd/fleetnode/localsubnet_test.go @@ -39,7 +39,7 @@ func TestSelectLocalPrivateSubnets_Typical24(t *testing.T) { func TestSelectLocalPrivateSubnets_OversizedMaskNarrowedTo22(t *testing.T) { // Arrange: a /8-masked NIC must narrow to /22 around its own host address so - // the auto scan stays within the manual-path host ceiling. + // the local-subnet scan stays within the manual-path host ceiling. ifaces := []net.Interface{{Name: "eth0", Flags: net.FlagUp | net.FlagRunning}} addrs := stubAddrs(map[string][]net.Addr{"eth0": {hostIPNet("10.1.2.3/8")}}) diff --git a/server/cmd/fleetnode/nmap.go b/server/cmd/fleetnode/nmap.go index c04d4fb77..27a8ad62a 100644 --- a/server/cmd/fleetnode/nmap.go +++ b/server/cmd/fleetnode/nmap.go @@ -135,7 +135,7 @@ func (r *RunCmd) buildNmapOptions(ctx context.Context, req *pairingpb.NmapModeRe } // baseNmapOptions are the timing/safety options shared by targeted and -// auto-local-subnet scans; callers append the target(s) (and -6 if needed). +// local-subnet scans; callers append the target(s) (and -6 if needed). func baseNmapOptions(binaryPath string, ports []string) []nmap.Option { return []nmap.Option{ nmap.WithBinaryPath(binaryPath), diff --git a/server/cmd/fleetnode/run.go b/server/cmd/fleetnode/run.go index 119d06d4f..decf0178d 100644 --- a/server/cmd/fleetnode/run.go +++ b/server/cmd/fleetnode/run.go @@ -34,7 +34,7 @@ type RunCmd struct { discoverer discoverer `kong:"-"` nmapPath string `kong:"-"` resolver ipResolver `kong:"-"` - localSubnets func() ([]string, error) `kong:"-"` // test seam for auto_local_subnet detection + localSubnets func() ([]string, error) `kong:"-"` // test seam for local-subnet detection stateMu sync.Mutex `kong:"-"` // guards st.SessionToken across refreshAndSave + tokenSource. } diff --git a/server/internal/domain/fleetnode/discovery/reportscope_test.go b/server/internal/domain/fleetnode/discovery/reportscope_test.go index 29e7c30d9..8d5d83699 100644 --- a/server/internal/domain/fleetnode/discovery/reportscope_test.go +++ b/server/internal/domain/fleetnode/discovery/reportscope_test.go @@ -158,8 +158,8 @@ func TestNormalizeDiscoverRequest_RejectsPublicNmapTarget(t *testing.T) { } } -func TestNormalizeDiscoverRequest_AutoLocalSubnet_Accepts(t *testing.T) { - // Arrange: auto mode with no target and valid ports. +func TestNormalizeDiscoverRequest_LocalSubnetTarget_Accepts(t *testing.T) { + // Arrange: the local-subnet sentinel target with valid ports. req := autoNmapReq([]string{"80", "4028"}) // Act @@ -169,7 +169,7 @@ func TestNormalizeDiscoverRequest_AutoLocalSubnet_Accepts(t *testing.T) { require.NoError(t, err) } -func TestNormalizeDiscoverRequest_AutoLocalSubnet_RejectsInvalidPort(t *testing.T) { +func TestNormalizeDiscoverRequest_LocalSubnetTarget_RejectsInvalidPort(t *testing.T) { // Arrange req := autoNmapReq([]string{"80/tcp"}) From 3efca5cb8791bb726db82cf15be3d6cdbdcddc33 Mon Sep 17 00:00:00 2001 From: Ankit Goswami Date: Tue, 2 Jun 2026 15:42:15 -0700 Subject: [PATCH 03/13] fix(discovery): address review feedback; revert unworkable fleet-node Pair routing - Revert Part C: the cloud Pair routing passed discovered_device.id where fleetnode/pairing.PairDevice expects a device.id, and fleet-node discoveries have no device row to assign (the cloud can't dial them to create one, and reports carry no MAC). Restore the safe refusal; inline pairing of fleet-node miners needs the node-side device-creation flow (follow-up). - Gate the fleet-node fan-out to the automatic "Scan your network" action (nmap target == the cloud's own local subnet) so a manual/explicit nmap scan no longer also sweeps every connected node's LAN. Adds Service.IsLocalSubnetScan. - Cap the agent's local-subnet auto-scan to a total host budget (discoverylimits.MaxScanTargets) across all detected subnets, not just /22 per subnet, so a multi-homed node can't 8x-oversize one command. - Map caller cancellation to CodeCanceled in RunOnNode and in Discover (was Internal / silent nil), and suppress the fan-out WARN when the stream is already cancelled (expected operator disconnect). - Fix the LocalSubnetTarget comment: IPv4 private space is RFC1918, not RFC4193. Co-Authored-By: Claude Opus 4.8 (1M context) --- server/cmd/fleetd/main.go | 3 - server/cmd/fleetnode/localsubnet.go | 18 ++- server/cmd/fleetnode/localsubnet_test.go | 34 ++++- .../domain/fleetnode/discovery/service.go | 4 +- .../internal/domain/nmaptarget/nmaptarget.go | 8 +- .../domain/pairing/mocks/mock_service.go | 38 ------ server/internal/domain/pairing/service.go | 76 +++-------- .../domain/pairing/service_internal_test.go | 29 ++++ .../pairing/service_pairrouting_test.go | 128 ------------------ server/internal/handlers/pairing/handler.go | 27 +++- 10 files changed, 125 insertions(+), 240 deletions(-) delete mode 100644 server/internal/domain/pairing/service_pairrouting_test.go diff --git a/server/cmd/fleetd/main.go b/server/cmd/fleetd/main.go index 4515f7229..dda28e02e 100644 --- a/server/cmd/fleetd/main.go +++ b/server/cmd/fleetd/main.go @@ -371,9 +371,6 @@ func start(config *Config) error { ) pairingSvc.WithMinerInvalidator(minerService.InvalidateMiner) pairingSvc.WithOptionsCache(fleetOptionsCache) - // Route fleet-node-discovered devices through ownership assignment so the - // cloud Pair flow can pair them alongside directly-discovered miners. - pairingSvc.WithFleetNodeAssigner(fleetNodePairingSvc) // Initialize IP scanner service ipScannerService := ipscanner.NewIPScannerService( diff --git a/server/cmd/fleetnode/localsubnet.go b/server/cmd/fleetnode/localsubnet.go index f2ee3459a..bc5de4ca9 100644 --- a/server/cmd/fleetnode/localsubnet.go +++ b/server/cmd/fleetnode/localsubnet.go @@ -6,6 +6,8 @@ import ( "net" "net/netip" "strings" + + "github.com/block/proto-fleet/server/internal/domain/discoverylimits" ) // autoSubnetMinPrefixBits caps a detected subnet at /22 (<=1024 hosts) — the @@ -18,6 +20,12 @@ const autoSubnetMinPrefixBits = 22 // multi-homed host with many interfaces can't fan one command into a huge sweep. const maxAutoSubnets = 8 +// maxAutoScanHosts bounds the TOTAL addresses across all detected subnets to the +// same per-command ceiling the manual path enforces (discoverylimits.MaxScanTargets), +// so a multi-homed node can't turn one command into an 8x-oversized sweep. A +// single /22 already consumes the whole budget. +const maxAutoScanHosts = discoverylimits.MaxScanTargets + // errNoLocalPrivateSubnet means no connected, non-virtual interface had a private // IPv4 address — the agent has nothing to scan. Surfaces as AGENT_INCAPABLE so a // fan-out skips this node and tries the others. @@ -56,6 +64,7 @@ func (r *RunCmd) detectLocalSubnets() ([]string, error) { func selectLocalPrivateSubnets(ifaces []net.Interface, addrsOf func(*net.Interface) ([]net.Addr, error)) ([]string, error) { seen := make(map[string]struct{}) out := make([]string, 0, maxAutoSubnets) + hostBudget := maxAutoScanHosts for i := range ifaces { iface := ifaces[i] if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagRunning == 0 { @@ -92,9 +101,16 @@ func selectLocalPrivateSubnets(ifaces []net.Interface, addrsOf func(*net.Interfa if _, dup := seen[cidr]; dup { continue } + // Skip a subnet that wouldn't fit the remaining host budget so the + // total swept address space never exceeds maxAutoScanHosts. + hosts := 1 << (addr.BitLen() - ones) + if hosts > hostBudget { + continue + } seen[cidr] = struct{}{} out = append(out, cidr) - if len(out) >= maxAutoSubnets { + hostBudget -= hosts + if len(out) >= maxAutoSubnets || hostBudget <= 0 { return out, nil } } diff --git a/server/cmd/fleetnode/localsubnet_test.go b/server/cmd/fleetnode/localsubnet_test.go index d1800ebdd..974a1ad1f 100644 --- a/server/cmd/fleetnode/localsubnet_test.go +++ b/server/cmd/fleetnode/localsubnet_test.go @@ -106,11 +106,11 @@ func TestSelectLocalPrivateSubnets_DedupesSameSubnetAcrossNICs(t *testing.T) { assert.Equal(t, []string{"192.168.1.0/24"}, got) } -func TestSelectLocalPrivateSubnets_CapsResultCount(t *testing.T) { - // Arrange: more distinct private subnets than the cap allows. - ifaces := make([]net.Interface, 0, maxAutoSubnets+2) +func TestSelectLocalPrivateSubnets_CapsTotalHostBudget(t *testing.T) { + // Arrange: ten /24s (256 hosts each); the host budget admits only floor(1024/256). + ifaces := make([]net.Interface, 0, 10) byName := make(map[string][]net.Addr) - for i := range maxAutoSubnets + 2 { + for i := range 10 { name := fmt.Sprintf("eth%d", i) ifaces = append(ifaces, net.Interface{Name: name, Flags: net.FlagUp | net.FlagRunning}) byName[name] = []net.Addr{hostIPNet(fmt.Sprintf("192.168.%d.5/24", i))} @@ -119,9 +119,33 @@ func TestSelectLocalPrivateSubnets_CapsResultCount(t *testing.T) { // Act got, err := selectLocalPrivateSubnets(ifaces, stubAddrs(byName)) + // Assert: total swept addresses stay within the per-command budget. + require.NoError(t, err) + total := 0 + for _, c := range got { + total += 1 << (32 - netip.MustParsePrefix(c).Bits()) + } + assert.LessOrEqual(t, total, maxAutoScanHosts) + assert.Len(t, got, maxAutoScanHosts/256) +} + +func TestSelectLocalPrivateSubnets_SingleWideSubnetConsumesBudget(t *testing.T) { + // Arrange: a /22 (1024 hosts) exhausts the budget, so a second subnet is dropped. + ifaces := []net.Interface{ + {Name: "eth0", Flags: net.FlagUp | net.FlagRunning}, + {Name: "eth1", Flags: net.FlagUp | net.FlagRunning}, + } + addrs := stubAddrs(map[string][]net.Addr{ + "eth0": {hostIPNet("10.1.0.5/22")}, + "eth1": {hostIPNet("192.168.9.5/24")}, + }) + + // Act + got, err := selectLocalPrivateSubnets(ifaces, addrs) + // Assert require.NoError(t, err) - assert.Len(t, got, maxAutoSubnets) + assert.Equal(t, []string{"10.1.0.0/22"}, got) } func TestSelectLocalPrivateSubnets_IgnoresIPv6ULA(t *testing.T) { diff --git a/server/internal/domain/fleetnode/discovery/service.go b/server/internal/domain/fleetnode/discovery/service.go index 31e9e6eec..a3701ef1d 100644 --- a/server/internal/domain/fleetnode/discovery/service.go +++ b/server/internal/domain/fleetnode/discovery/service.go @@ -145,7 +145,9 @@ func (s *Service) RunOnNode(ctx context.Context, fleetNodeID int64, req *pairing if errors.Is(ctx.Err(), context.DeadlineExceeded) { return connect.NewError(connect.CodeDeadlineExceeded, fmt.Errorf("discovery command timed out after %s", DiscoverCommandTimeout)) } - return fleeterror.NewInternalErrorf("operator stream cancelled: %v", ctx.Err()) + // Caller (operator or fan-out) cancelled; report it as such rather + // than a server-side Internal failure. + return fleeterror.NewCanceledError() case ev := <-events: if terminal, err := handleEvent(ev); terminal { return err diff --git a/server/internal/domain/nmaptarget/nmaptarget.go b/server/internal/domain/nmaptarget/nmaptarget.go index 7c426a7c5..7a61781c5 100644 --- a/server/internal/domain/nmaptarget/nmaptarget.go +++ b/server/internal/domain/nmaptarget/nmaptarget.go @@ -30,10 +30,10 @@ var ( const MinIPv4PrefixBits = 22 // LocalSubnetTarget is a reserved nmap target value meaning "the fleet node -// should scan the private (RFC1918/RFC4193) IPv4 subnet(s) of its own -// interfaces." The cloud fan-out sets it because it cannot know each node's -// local network. It is matched exactly (before any hostname resolution); the -// dashed, namespaced name keeps it from colliding with a real hostname. +// should scan the private (RFC1918) IPv4 subnet(s) of its own interfaces." The +// cloud fan-out sets it because it cannot know each node's local network. It is +// matched exactly (before any hostname resolution); the dashed, namespaced name +// keeps it from colliding with a real hostname. const LocalSubnetTarget = "fleetnode-local-subnet" // IsIPv4Range reports whether s is an "A.B.C.D-N" nmap range. diff --git a/server/internal/domain/pairing/mocks/mock_service.go b/server/internal/domain/pairing/mocks/mock_service.go index daa411f6a..3d9cda0e3 100644 --- a/server/internal/domain/pairing/mocks/mock_service.go +++ b/server/internal/domain/pairing/mocks/mock_service.go @@ -166,41 +166,3 @@ func (mr *MockCapabilitiesProviderMockRecorder) GetMinerCapabilitiesForDevice(ct mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMinerCapabilitiesForDevice", reflect.TypeOf((*MockCapabilitiesProvider)(nil).GetMinerCapabilitiesForDevice), ctx, device) } - -// MockFleetNodeAssigner is a mock of FleetNodeAssigner interface. -type MockFleetNodeAssigner struct { - ctrl *gomock.Controller - recorder *MockFleetNodeAssignerMockRecorder - isgomock struct{} -} - -// MockFleetNodeAssignerMockRecorder is the mock recorder for MockFleetNodeAssigner. -type MockFleetNodeAssignerMockRecorder struct { - mock *MockFleetNodeAssigner -} - -// NewMockFleetNodeAssigner creates a new mock instance. -func NewMockFleetNodeAssigner(ctrl *gomock.Controller) *MockFleetNodeAssigner { - mock := &MockFleetNodeAssigner{ctrl: ctrl} - mock.recorder = &MockFleetNodeAssignerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockFleetNodeAssigner) EXPECT() *MockFleetNodeAssignerMockRecorder { - return m.recorder -} - -// PairDevice mocks base method. -func (m *MockFleetNodeAssigner) PairDevice(ctx context.Context, fleetNodeID, deviceID, orgID int64, assignedBy *int64) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PairDevice", ctx, fleetNodeID, deviceID, orgID, assignedBy) - ret0, _ := ret[0].(error) - return ret0 -} - -// PairDevice indicates an expected call of PairDevice. -func (mr *MockFleetNodeAssignerMockRecorder) PairDevice(ctx, fleetNodeID, deviceID, orgID, assignedBy any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PairDevice", reflect.TypeOf((*MockFleetNodeAssigner)(nil).PairDevice), ctx, fleetNodeID, deviceID, orgID, assignedBy) -} diff --git a/server/internal/domain/pairing/service.go b/server/internal/domain/pairing/service.go index b23c1a9ab..4e87757af 100644 --- a/server/internal/domain/pairing/service.go +++ b/server/internal/domain/pairing/service.go @@ -176,17 +176,6 @@ type Service struct { probeSemaphore chan struct{} invalidateMiner func(models.DeviceIdentifier) optionsCache *fleetoptions.Cache - fleetNodeAssigner FleetNodeAssigner -} - -// FleetNodeAssigner records operator-confirmed ownership of a fleet-node-discovered -// device. The cloud Pair flow routes remote-origin devices (discovered_device rows -// with discovered_by_fleet_node_id set) here instead of dialing them: per RFC 0001 -// the owning node dials and credentials the miner, so this is a metadata-only -// assignment. *fleetnode/pairing.Service satisfies this. Optional; when nil, cloud -// pairing refuses remote-origin devices (the pre-fan-out behavior). -type FleetNodeAssigner interface { - PairDevice(ctx context.Context, fleetNodeID, deviceID, orgID int64, assignedBy *int64) error } func NewService( @@ -225,13 +214,6 @@ func (s *Service) WithOptionsCache(cache *fleetoptions.Cache) { s.optionsCache = cache } -// WithFleetNodeAssigner wires the collaborator that records ownership of -// fleet-node-discovered devices so Pair can include them. Pass nil to keep the -// default behavior of refusing remote-origin devices. -func (s *Service) WithFleetNodeAssigner(a FleetNodeAssigner) { - s.fleetNodeAssigner = a -} - type NetworkInfo struct { networking.NetworkInfo } @@ -251,6 +233,20 @@ func (s *Service) GetLocalNetworkInfo(ctx context.Context) (*NetworkInfo, error) return defaultLocalNetworkInfo(ctx) } +// IsLocalSubnetScan reports whether target is the cloud host's own local subnet, +// i.e. the automatic "Scan your network" action rather than an operator-typed +// target. Fleet-node fan-out is gated on this (the same signal that triggers +// known-subnet auto-expansion) so a manual/explicit cloud scan doesn't also +// sweep every connected node's LAN. False when the host has no local subnet. +func (s *Service) IsLocalSubnetScan(ctx context.Context, target string) bool { + info, err := s.GetLocalNetworkInfo(ctx) + if err != nil { + return false + } + _, ok := maskBitsForLocalSubnetTarget(target, info.Subnet) + return ok +} + func canonicalCIDR(cidr string) (canonical string, maskBits int, isIPv4 bool, ok bool) { _, ipNet, err := net.ParseCIDR(cidr) if err != nil { @@ -1199,7 +1195,6 @@ func (s *Service) PairDevices(ctx context.Context, r *pb.PairRequest) (*pb.PairR failedIDs := make([]string, 0, len(deviceIdentifiers)) credentials := r.Credentials - assignedBy := info.UserID // Deduplicate to prevent concurrent pairDevice calls against the same physical device. // We check both exact identifier strings and IP+port because different identifiers can @@ -1220,43 +1215,14 @@ func (s *Service) PairDevices(ctx context.Context, r *pb.PairRequest) (*pb.PairR OrgID: info.OrganizationID, }) if ddErr == nil { - // Cloud pairing dials the IP via plugin RPC, so a remote-origin row - // can't take that path. With a fleet-node assigner wired, route it to - // the operator-confirmed ownership assignment (metadata only — the - // owning node dials and credentials the miner per RFC 0001), so the - // operator can pair fleet-node-discovered miners alongside direct ones - // in a single request. Without an assigner, refuse as before. + // Cloud pairing dials the IP via plugin RPC; remote-origin + // rows must route through PairDeviceToFleetNode instead. if dd.DiscoveredByFleetNodeID != nil { - if s.fleetNodeAssigner == nil { - slog.Warn("refusing to pair remote-fleet-node-reported device via cloud pairing; use PairDeviceToFleetNode", - "device_identifier", id, - "fleet_node_id", *dd.DiscoveredByFleetNodeID, - ) - failedIDs = append(failedIDs, id) - continue - } - dbID, idErr := s.discoveredDeviceStore.GetDatabaseID(ctx, discoverymodels.DeviceOrgIdentifier{ - DeviceIdentifier: id, - OrgID: info.OrganizationID, - }) - if idErr != nil { - slog.Error("failed to resolve discovered device id for fleet-node assignment", - "device_identifier", id, "error", idErr) - failedIDs = append(failedIDs, id) - continue - } - if assignErr := s.fleetNodeAssigner.PairDevice(ctx, *dd.DiscoveredByFleetNodeID, dbID, info.OrganizationID, &assignedBy); assignErr != nil { - slog.Warn("failed to assign fleet-node-discovered device to its node", - "device_identifier", id, - "fleet_node_id", *dd.DiscoveredByFleetNodeID, - "error", assignErr, - ) - failedIDs = append(failedIDs, id) - continue - } - // Assignment is metadata only: no plugin dial, credentials, - // handle invalidation, or telemetry scheduling here. - successfulIDs = append(successfulIDs, models.DeviceIdentifier(id)) + slog.Warn("refusing to pair remote-fleet-node-reported device via cloud pairing; use PairDeviceToFleetNode", + "device_identifier", id, + "fleet_node_id", *dd.DiscoveredByFleetNodeID, + ) + failedIDs = append(failedIDs, id) continue } endpoint := dd.IpAddress + ":" + dd.Port diff --git a/server/internal/domain/pairing/service_internal_test.go b/server/internal/domain/pairing/service_internal_test.go index 8fe6e5c17..fb2ea73e2 100644 --- a/server/internal/domain/pairing/service_internal_test.go +++ b/server/internal/domain/pairing/service_internal_test.go @@ -307,6 +307,35 @@ func TestResolveNmapTargets_DoesNotExpandIPv6Targets(t *testing.T) { require.Equal(t, []string{"fd00::/64"}, targets) } +func TestIsLocalSubnetScan(t *testing.T) { + // Arrange + svc := &Service{ + localNetworkInfo: func(context.Context) (*NetworkInfo, error) { + return &NetworkInfo{NetworkInfo: networking.NetworkInfo{Subnet: "192.168.1.0/24"}}, nil + }, + } + ctx := t.Context() + + // Act + Assert: the host's own subnet (canonical or with host bits) is the auto scan. + require.True(t, svc.IsLocalSubnetScan(ctx, "192.168.1.0/24")) + require.True(t, svc.IsLocalSubnetScan(ctx, "192.168.1.50/24")) + // A different/explicit target (manual scan) or the fan-out sentinel is not. + require.False(t, svc.IsLocalSubnetScan(ctx, "10.0.0.0/24")) + require.False(t, svc.IsLocalSubnetScan(ctx, "fleetnode-local-subnet")) +} + +func TestIsLocalSubnetScan_NoLocalNetworkIsFalse(t *testing.T) { + // Arrange: cloud-mode host with no local subnet. + svc := &Service{ + localNetworkInfo: func(context.Context) (*NetworkInfo, error) { + return nil, errors.New("no local network") + }, + } + + // Act + Assert + require.False(t, svc.IsLocalSubnetScan(t.Context(), "192.168.1.0/24")) +} + func TestValidateNmapTargets(t *testing.T) { noopLookup := func(context.Context, string) ([]net.IPAddr, error) { return nil, errors.New("no DNS") diff --git a/server/internal/domain/pairing/service_pairrouting_test.go b/server/internal/domain/pairing/service_pairrouting_test.go deleted file mode 100644 index d973e9fc9..000000000 --- a/server/internal/domain/pairing/service_pairrouting_test.go +++ /dev/null @@ -1,128 +0,0 @@ -package pairing - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - - commonv1 "github.com/block/proto-fleet/server/generated/grpc/common/v1" - commandpb "github.com/block/proto-fleet/server/generated/grpc/minercommand/v1" - pb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" - "github.com/block/proto-fleet/server/internal/domain/fleeterror" - discoverymodels "github.com/block/proto-fleet/server/internal/domain/minerdiscovery/models" - "github.com/block/proto-fleet/server/internal/domain/stores/interfaces/mocks" -) - -type assignCall struct { - fleetNodeID int64 - deviceID int64 - orgID int64 - assignedBy *int64 -} - -type fakeAssigner struct { - calls []assignCall - failDeviceIDs map[int64]bool -} - -func (f *fakeAssigner) PairDevice(_ context.Context, fleetNodeID, deviceID, orgID int64, assignedBy *int64) error { - f.calls = append(f.calls, assignCall{fleetNodeID, deviceID, orgID, assignedBy}) - if f.failDeviceIDs[deviceID] { - return fleeterror.NewFailedPreconditionError("device already paired; unpair first") - } - return nil -} - -func includeReq(ids ...string) *pb.PairRequest { - return &pb.PairRequest{DeviceSelector: &commandpb.DeviceSelector{ - SelectionType: &commandpb.DeviceSelector_IncludeDevices{ - IncludeDevices: &commonv1.DeviceIdentifierList{DeviceIdentifiers: ids}, - }, - }} -} - -func fleetNodeDiscoveredDevice(identifier string, orgID, nodeID int64) *discoverymodels.DiscoveredDevice { - return &discoverymodels.DiscoveredDevice{ - Device: pb.Device{DeviceIdentifier: identifier, IpAddress: "10.0.0.5", Port: "80"}, - OrgID: orgID, - DiscoveredByFleetNodeID: &nodeID, - } -} - -func TestPairDevices_FleetNodeDeviceRoutesToAssigner(t *testing.T) { - // Arrange - ctrl := gomock.NewController(t) - defer ctrl.Finish() - const ( - orgID = int64(7) - userID = int64(3) - nodeID = int64(55) - dbID = int64(900) - ) - doi := discoverymodels.DeviceOrgIdentifier{DeviceIdentifier: "dev-1", OrgID: orgID} - mockDD := mocks.NewMockDiscoveredDeviceStore(ctrl) - mockDD.EXPECT().GetDevice(gomock.Any(), doi).Return(fleetNodeDiscoveredDevice("dev-1", orgID, nodeID), nil) - mockDD.EXPECT().GetDatabaseID(gomock.Any(), doi).Return(dbID, nil) - assigner := &fakeAssigner{} - svc := &Service{discoveredDeviceStore: mockDD, fleetNodeAssigner: assigner} - ctx := mockSessionContext(t.Context(), userID, orgID) - - // Act - resp, err := svc.PairDevices(ctx, includeReq("dev-1")) - - // Assert: routed to the assigner with the resolved DB id + caller, not dialed. - require.NoError(t, err) - assert.Empty(t, resp.GetFailedDeviceIds()) - require.Len(t, assigner.calls, 1) - assert.Equal(t, nodeID, assigner.calls[0].fleetNodeID) - assert.Equal(t, dbID, assigner.calls[0].deviceID) - assert.Equal(t, orgID, assigner.calls[0].orgID) - require.NotNil(t, assigner.calls[0].assignedBy) - assert.Equal(t, userID, *assigner.calls[0].assignedBy) -} - -func TestPairDevices_FleetNodeDeviceRefusedWithoutAssigner(t *testing.T) { - // Arrange: no assigner wired keeps the pre-fan-out refusal behavior. - ctrl := gomock.NewController(t) - defer ctrl.Finish() - const orgID = int64(7) - doi := discoverymodels.DeviceOrgIdentifier{DeviceIdentifier: "dev-1", OrgID: orgID} - mockDD := mocks.NewMockDiscoveredDeviceStore(ctrl) - mockDD.EXPECT().GetDevice(gomock.Any(), doi).Return(fleetNodeDiscoveredDevice("dev-1", orgID, 55), nil) - svc := &Service{discoveredDeviceStore: mockDD} // fleetNodeAssigner nil - ctx := mockSessionContext(t.Context(), 3, orgID) - - // Act - _, err := svc.PairDevices(ctx, includeReq("dev-1")) - - // Assert: the only device was refused, so nothing paired. - require.Error(t, err) - assert.Contains(t, err.Error(), "Failed to pair any devices") -} - -func TestPairDevices_FleetNodeAssignPartialSuccess(t *testing.T) { - // Arrange: two fleet-node devices; the assigner fails the second only. - ctrl := gomock.NewController(t) - defer ctrl.Finish() - const orgID = int64(7) - doi1 := discoverymodels.DeviceOrgIdentifier{DeviceIdentifier: "dev-1", OrgID: orgID} - doi2 := discoverymodels.DeviceOrgIdentifier{DeviceIdentifier: "dev-2", OrgID: orgID} - mockDD := mocks.NewMockDiscoveredDeviceStore(ctrl) - mockDD.EXPECT().GetDevice(gomock.Any(), doi1).Return(fleetNodeDiscoveredDevice("dev-1", orgID, 55), nil) - mockDD.EXPECT().GetDatabaseID(gomock.Any(), doi1).Return(int64(900), nil) - mockDD.EXPECT().GetDevice(gomock.Any(), doi2).Return(fleetNodeDiscoveredDevice("dev-2", orgID, 55), nil) - mockDD.EXPECT().GetDatabaseID(gomock.Any(), doi2).Return(int64(901), nil) - assigner := &fakeAssigner{failDeviceIDs: map[int64]bool{901: true}} - svc := &Service{discoveredDeviceStore: mockDD, fleetNodeAssigner: assigner} - ctx := mockSessionContext(t.Context(), 3, orgID) - - // Act - resp, err := svc.PairDevices(ctx, includeReq("dev-1", "dev-2")) - - // Assert: dev-1 paired, dev-2 failed; partial success returns no top-level error. - require.NoError(t, err) - assert.Equal(t, []string{"dev-2"}, resp.GetFailedDeviceIds()) -} diff --git a/server/internal/handlers/pairing/handler.go b/server/internal/handlers/pairing/handler.go index 9be1f1d8a..7913745d9 100644 --- a/server/internal/handlers/pairing/handler.go +++ b/server/internal/handlers/pairing/handler.go @@ -2,6 +2,7 @@ package pairing import ( "context" + "errors" "log/slog" "sync" @@ -134,8 +135,11 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe } }() - // Fleet node fan-out (nmap only). - if _, ok := r.Msg.Mode.(*pb.DiscoverRequest_Nmap); ok && h.discovery != nil { + // Fleet node fan-out, gated to the automatic "Scan your network" action: an + // nmap request whose target is the cloud's own local subnet. A manual/explicit + // nmap target (a user-typed subnet/IP) must NOT also sweep every node's LAN. + if _, ok := r.Msg.Mode.(*pb.DiscoverRequest_Nmap); ok && h.discovery != nil && + h.pairingSvc.IsLocalSubnetScan(streamCtx, r.Msg.GetNmap().GetTarget()) { nodeIDs, listErr := h.discovery.ConfirmedConnectedNodeIDs(streamCtx, info.OrganizationID) if listErr != nil { // Fan-out is best-effort; a lookup failure must never break the @@ -151,8 +155,10 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe go func(nodeID int64) { defer wg.Done() runErr := h.discovery.RunOnNode(streamCtx, nodeID, autoReq, send) - if runErr != nil { - // One node failing must not fail the whole scan. + // One node failing must not fail the whole scan. Stay quiet + // when streamCtx is already cancelled (operator disconnected) — + // that's expected, not a node fault. + if runErr != nil && streamCtx.Err() == nil { slog.Warn("fleet node discovery failed during cloud fan-out", "fleet_node_id", nodeID, "error", runErr) } @@ -162,7 +168,18 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe } wg.Wait() - return sendErr + if sendErr != nil { + return sendErr + } + // A client cancel/deadline drains the sources without a Send error; surface + // it as canceled/deadline rather than a successful completion. + if ctxErr := ctx.Err(); ctxErr != nil { + if errors.Is(ctxErr, context.DeadlineExceeded) { + return connect.NewError(connect.CodeDeadlineExceeded, ctxErr) + } + return fleeterror.NewCanceledError() + } + return nil } // Pair implements pairingv1connect.PairingServiceHandler. From 08dba3823725cd9ad84b1a5c685e31c400e0224f Mon Sep 17 00:00:00 2001 From: Ankit Goswami Date: Tue, 2 Jun 2026 15:54:47 -0700 Subject: [PATCH 04/13] refactor(fleetnode): use combined-mode local-subnet detection for parity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the agent's bespoke per-NIC private-subnet enumeration with networking.GetLocalNetworkInfo() — the same primary-interface detection the cloud/combined Discover path already uses, so the node scans the same way combined mode does. Intentionally less robust (picks one interface, no virtual-NIC filtering, no /22 narrowing or total-host cap); parity is enough for now and hardening is a follow-up. Removes selectLocalPrivateSubnets and its tests. Co-Authored-By: Claude Opus 4.8 (1M context) --- server/cmd/fleetnode/localsubnet.go | 136 +++--------------- server/cmd/fleetnode/localsubnet_test.go | 169 ----------------------- server/cmd/fleetnode/nmap_test.go | 2 +- 3 files changed, 21 insertions(+), 286 deletions(-) delete mode 100644 server/cmd/fleetnode/localsubnet_test.go diff --git a/server/cmd/fleetnode/localsubnet.go b/server/cmd/fleetnode/localsubnet.go index bc5de4ca9..8826ad995 100644 --- a/server/cmd/fleetnode/localsubnet.go +++ b/server/cmd/fleetnode/localsubnet.go @@ -3,130 +3,34 @@ package main import ( "errors" "fmt" - "net" - "net/netip" - "strings" - "github.com/block/proto-fleet/server/internal/domain/discoverylimits" + "github.com/block/proto-fleet/server/internal/infrastructure/networking" ) -// autoSubnetMinPrefixBits caps a detected subnet at /22 (<=1024 hosts) — the -// same ceiling the manual nmap path enforces (nmaptarget.MinIPv4PrefixBits). A -// NIC configured with a wider mask (e.g. /16) is narrowed around its own host -// address so the scan stays bounded and finishes inside the command timeout. -const autoSubnetMinPrefixBits = 22 - -// maxAutoSubnets caps how many distinct subnets one command scans, so a -// multi-homed host with many interfaces can't fan one command into a huge sweep. -const maxAutoSubnets = 8 - -// maxAutoScanHosts bounds the TOTAL addresses across all detected subnets to the -// same per-command ceiling the manual path enforces (discoverylimits.MaxScanTargets), -// so a multi-homed node can't turn one command into an 8x-oversized sweep. A -// single /22 already consumes the whole budget. -const maxAutoScanHosts = discoverylimits.MaxScanTargets - -// errNoLocalPrivateSubnet means no connected, non-virtual interface had a private -// IPv4 address — the agent has nothing to scan. Surfaces as AGENT_INCAPABLE so a -// fan-out skips this node and tries the others. -var errNoLocalPrivateSubnet = errors.New("no connected private IPv4 subnet found") - -// virtualIfacePrefixes are name prefixes for container/VPN/virtual adapters whose -// subnets aren't the miner LAN. Best-effort: a miss only means a virtual private -// subnet might be scanned (still port-probed, still private), never a public scan. -var virtualIfacePrefixes = []string{ - "docker", "br-", "veth", "virbr", "vmnet", "vboxnet", - "tun", "tap", "utun", "cni", "cali", "flannel", "kube", - "zt", "tailscale", "ts", "wg", -} - -// detectLocalSubnets returns the private IPv4 subnet(s) the agent should scan for -// a local-subnet nmap command (the nmaptarget.LocalSubnetTarget sentinel). The -// localSubnets seam lets tests inject canned CIDRs; production enumerates the -// host's interfaces. +// errNoLocalSubnet means the host has no usable IPv4 subnet to scan for a +// LocalSubnetTarget command. Surfaces as AGENT_INCAPABLE so a fan-out skips this +// node and tries the others. +var errNoLocalSubnet = errors.New("no local IPv4 subnet found") + +// detectLocalSubnets returns the subnet(s) the agent scans for a local-subnet +// nmap command (the nmaptarget.LocalSubnetTarget sentinel). +// +// For parity with combined mode it reuses the server's own primary-interface +// detection (networking.GetLocalNetworkInfo) — the same logic the cloud Discover +// path scans. This is intentionally less robust than per-NIC private filtering +// (it picks one interface, doesn't skip virtual/container NICs, and doesn't +// narrow or cap the mask); hardening is a follow-up. The localSubnets seam lets +// tests inject canned CIDRs. func (r *RunCmd) detectLocalSubnets() ([]string, error) { if r.localSubnets != nil { return r.localSubnets() } - ifaces, err := net.Interfaces() + info, err := networking.GetLocalNetworkInfo() if err != nil { - return nil, fmt.Errorf("list network interfaces: %w", err) + return nil, fmt.Errorf("get local network info: %w", err) } - return selectLocalPrivateSubnets(ifaces, (*net.Interface).Addrs) -} - -// selectLocalPrivateSubnets returns the canonical CIDR(s) of the connected, -// non-virtual, private IPv4 subnet(s) of the given interfaces. addrsOf is -// injected for testing ((*net.Interface).Addrs in production). Subnets wider than -// /22 are narrowed around the host address, results are deduped and capped at -// maxAutoSubnets, and IPv6 is ignored (the manual nmap path rejects IPv6 CIDR -// too). Returns errNoLocalPrivateSubnet when none qualify. -func selectLocalPrivateSubnets(ifaces []net.Interface, addrsOf func(*net.Interface) ([]net.Addr, error)) ([]string, error) { - seen := make(map[string]struct{}) - out := make([]string, 0, maxAutoSubnets) - hostBudget := maxAutoScanHosts - for i := range ifaces { - iface := ifaces[i] - if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagRunning == 0 { - continue - } - if iface.Flags&net.FlagLoopback != 0 || isVirtualIface(iface.Name) { - continue - } - addrs, err := addrsOf(&iface) - if err != nil { - continue - } - for _, a := range addrs { - ipNet, ok := a.(*net.IPNet) - if !ok { - continue - } - addr, ok := netip.AddrFromSlice(ipNet.IP) - if !ok { - continue - } - addr = addr.Unmap() - if !addr.Is4() || !addr.IsPrivate() { - continue - } - ones, _ := ipNet.Mask.Size() - if ones <= 0 || ones > addr.BitLen() { - continue // non-canonical mask - } - if ones < autoSubnetMinPrefixBits { - ones = autoSubnetMinPrefixBits - } - cidr := netip.PrefixFrom(addr, ones).Masked().String() - if _, dup := seen[cidr]; dup { - continue - } - // Skip a subnet that wouldn't fit the remaining host budget so the - // total swept address space never exceeds maxAutoScanHosts. - hosts := 1 << (addr.BitLen() - ones) - if hosts > hostBudget { - continue - } - seen[cidr] = struct{}{} - out = append(out, cidr) - hostBudget -= hosts - if len(out) >= maxAutoSubnets || hostBudget <= 0 { - return out, nil - } - } - } - if len(out) == 0 { - return nil, errNoLocalPrivateSubnet - } - return out, nil -} - -func isVirtualIface(name string) bool { - lower := strings.ToLower(name) - for _, p := range virtualIfacePrefixes { - if strings.HasPrefix(lower, p) { - return true - } + if info.Subnet == "" { + return nil, errNoLocalSubnet } - return false + return []string{info.Subnet}, nil } diff --git a/server/cmd/fleetnode/localsubnet_test.go b/server/cmd/fleetnode/localsubnet_test.go deleted file mode 100644 index 974a1ad1f..000000000 --- a/server/cmd/fleetnode/localsubnet_test.go +++ /dev/null @@ -1,169 +0,0 @@ -package main - -import ( - "fmt" - "net" - "net/netip" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// hostIPNet builds the *net.IPNet shape an interface reports for a host address -// (host IP + the subnet mask), e.g. "192.168.1.50/24". -func hostIPNet(cidr string) *net.IPNet { - ip, n, err := net.ParseCIDR(cidr) - if err != nil { - panic(err) - } - return &net.IPNet{IP: ip, Mask: n.Mask} -} - -func stubAddrs(byName map[string][]net.Addr) func(*net.Interface) ([]net.Addr, error) { - return func(i *net.Interface) ([]net.Addr, error) { return byName[i.Name], nil } -} - -func TestSelectLocalPrivateSubnets_Typical24(t *testing.T) { - // Arrange - ifaces := []net.Interface{{Name: "eth0", Flags: net.FlagUp | net.FlagRunning}} - addrs := stubAddrs(map[string][]net.Addr{"eth0": {hostIPNet("192.168.1.50/24")}}) - - // Act - got, err := selectLocalPrivateSubnets(ifaces, addrs) - - // Assert - require.NoError(t, err) - assert.Equal(t, []string{"192.168.1.0/24"}, got) -} - -func TestSelectLocalPrivateSubnets_OversizedMaskNarrowedTo22(t *testing.T) { - // Arrange: a /8-masked NIC must narrow to /22 around its own host address so - // the local-subnet scan stays within the manual-path host ceiling. - ifaces := []net.Interface{{Name: "eth0", Flags: net.FlagUp | net.FlagRunning}} - addrs := stubAddrs(map[string][]net.Addr{"eth0": {hostIPNet("10.1.2.3/8")}}) - - // Act - got, err := selectLocalPrivateSubnets(ifaces, addrs) - - // Assert - require.NoError(t, err) - require.Len(t, got, 1) - prefix, perr := netip.ParsePrefix(got[0]) - require.NoError(t, perr) - assert.Equal(t, 22, prefix.Bits(), "oversized mask must narrow to /22") - assert.True(t, prefix.Contains(netip.MustParseAddr("10.1.2.3")), "narrowed subnet must contain the host: %s", got[0]) -} - -func TestSelectLocalPrivateSubnets_FiltersLoopbackDownAndVirtual(t *testing.T) { - // Arrange: loopback, a not-running NIC, and a docker bridge all excluded. - ifaces := []net.Interface{ - {Name: "lo", Flags: net.FlagUp | net.FlagRunning | net.FlagLoopback}, - {Name: "eth1", Flags: net.FlagUp}, // up but not running - {Name: "docker0", Flags: net.FlagUp | net.FlagRunning}, - } - addrs := stubAddrs(map[string][]net.Addr{ - "lo": {hostIPNet("127.0.0.1/8")}, - "eth1": {hostIPNet("192.168.5.5/24")}, - "docker0": {hostIPNet("172.17.0.1/16")}, - }) - - // Act - _, err := selectLocalPrivateSubnets(ifaces, addrs) - - // Assert - require.ErrorIs(t, err, errNoLocalPrivateSubnet) -} - -func TestSelectLocalPrivateSubnets_SkipsPublicAddress(t *testing.T) { - // Arrange - ifaces := []net.Interface{{Name: "eth0", Flags: net.FlagUp | net.FlagRunning}} - addrs := stubAddrs(map[string][]net.Addr{"eth0": {hostIPNet("8.8.8.8/24")}}) - - // Act - _, err := selectLocalPrivateSubnets(ifaces, addrs) - - // Assert - require.ErrorIs(t, err, errNoLocalPrivateSubnet) -} - -func TestSelectLocalPrivateSubnets_DedupesSameSubnetAcrossNICs(t *testing.T) { - // Arrange - ifaces := []net.Interface{ - {Name: "eth0", Flags: net.FlagUp | net.FlagRunning}, - {Name: "eth1", Flags: net.FlagUp | net.FlagRunning}, - } - addrs := stubAddrs(map[string][]net.Addr{ - "eth0": {hostIPNet("192.168.1.10/24")}, - "eth1": {hostIPNet("192.168.1.20/24")}, - }) - - // Act - got, err := selectLocalPrivateSubnets(ifaces, addrs) - - // Assert - require.NoError(t, err) - assert.Equal(t, []string{"192.168.1.0/24"}, got) -} - -func TestSelectLocalPrivateSubnets_CapsTotalHostBudget(t *testing.T) { - // Arrange: ten /24s (256 hosts each); the host budget admits only floor(1024/256). - ifaces := make([]net.Interface, 0, 10) - byName := make(map[string][]net.Addr) - for i := range 10 { - name := fmt.Sprintf("eth%d", i) - ifaces = append(ifaces, net.Interface{Name: name, Flags: net.FlagUp | net.FlagRunning}) - byName[name] = []net.Addr{hostIPNet(fmt.Sprintf("192.168.%d.5/24", i))} - } - - // Act - got, err := selectLocalPrivateSubnets(ifaces, stubAddrs(byName)) - - // Assert: total swept addresses stay within the per-command budget. - require.NoError(t, err) - total := 0 - for _, c := range got { - total += 1 << (32 - netip.MustParsePrefix(c).Bits()) - } - assert.LessOrEqual(t, total, maxAutoScanHosts) - assert.Len(t, got, maxAutoScanHosts/256) -} - -func TestSelectLocalPrivateSubnets_SingleWideSubnetConsumesBudget(t *testing.T) { - // Arrange: a /22 (1024 hosts) exhausts the budget, so a second subnet is dropped. - ifaces := []net.Interface{ - {Name: "eth0", Flags: net.FlagUp | net.FlagRunning}, - {Name: "eth1", Flags: net.FlagUp | net.FlagRunning}, - } - addrs := stubAddrs(map[string][]net.Addr{ - "eth0": {hostIPNet("10.1.0.5/22")}, - "eth1": {hostIPNet("192.168.9.5/24")}, - }) - - // Act - got, err := selectLocalPrivateSubnets(ifaces, addrs) - - // Assert - require.NoError(t, err) - assert.Equal(t, []string{"10.1.0.0/22"}, got) -} - -func TestSelectLocalPrivateSubnets_IgnoresIPv6ULA(t *testing.T) { - // Arrange - ifaces := []net.Interface{{Name: "eth0", Flags: net.FlagUp | net.FlagRunning}} - addrs := stubAddrs(map[string][]net.Addr{"eth0": {hostIPNet("fd00::1/64")}}) - - // Act - _, err := selectLocalPrivateSubnets(ifaces, addrs) - - // Assert - require.ErrorIs(t, err, errNoLocalPrivateSubnet) -} - -func TestSelectLocalPrivateSubnets_NoInterfaces(t *testing.T) { - // Act - _, err := selectLocalPrivateSubnets(nil, stubAddrs(nil)) - - // Assert - require.ErrorIs(t, err, errNoLocalPrivateSubnet) -} diff --git a/server/cmd/fleetnode/nmap_test.go b/server/cmd/fleetnode/nmap_test.go index c625bb25f..9a9244dab 100644 --- a/server/cmd/fleetnode/nmap_test.go +++ b/server/cmd/fleetnode/nmap_test.go @@ -275,7 +275,7 @@ func TestBuildNmapOptions_LocalSubnetTarget_NoSubnetIsAgentIncapable(t *testing. r := &RunCmd{ nmapPath: "/usr/bin/nmap", discoverer: &stubDiscoverer{}, - localSubnets: func() ([]string, error) { return nil, errNoLocalPrivateSubnet }, + localSubnets: func() ([]string, error) { return nil, errNoLocalSubnet }, } req := &pairingpb.NmapModeRequest{Target: nmaptarget.LocalSubnetTarget, Ports: []string{"4028"}} From 81ea038861dabbf40b631ec54fefa3587d2e10f3 Mon Sep 17 00:00:00 2001 From: Ankit Goswami Date: Tue, 2 Jun 2026 15:58:59 -0700 Subject: [PATCH 05/13] refactor(fleetnode): fold detectLocalSubnets into nmap.go After the combined-mode parity change, localsubnet.go held only errNoLocalSubnet and a small detectLocalSubnets wrapper used solely by buildNmapOptions. Move both next to their caller in nmap.go and drop the now-empty file. Co-Authored-By: Claude Opus 4.8 (1M context) --- server/cmd/fleetnode/localsubnet.go | 36 ----------------------------- server/cmd/fleetnode/nmap.go | 29 +++++++++++++++++++++++ 2 files changed, 29 insertions(+), 36 deletions(-) delete mode 100644 server/cmd/fleetnode/localsubnet.go diff --git a/server/cmd/fleetnode/localsubnet.go b/server/cmd/fleetnode/localsubnet.go deleted file mode 100644 index 8826ad995..000000000 --- a/server/cmd/fleetnode/localsubnet.go +++ /dev/null @@ -1,36 +0,0 @@ -package main - -import ( - "errors" - "fmt" - - "github.com/block/proto-fleet/server/internal/infrastructure/networking" -) - -// errNoLocalSubnet means the host has no usable IPv4 subnet to scan for a -// LocalSubnetTarget command. Surfaces as AGENT_INCAPABLE so a fan-out skips this -// node and tries the others. -var errNoLocalSubnet = errors.New("no local IPv4 subnet found") - -// detectLocalSubnets returns the subnet(s) the agent scans for a local-subnet -// nmap command (the nmaptarget.LocalSubnetTarget sentinel). -// -// For parity with combined mode it reuses the server's own primary-interface -// detection (networking.GetLocalNetworkInfo) — the same logic the cloud Discover -// path scans. This is intentionally less robust than per-NIC private filtering -// (it picks one interface, doesn't skip virtual/container NICs, and doesn't -// narrow or cap the mask); hardening is a follow-up. The localSubnets seam lets -// tests inject canned CIDRs. -func (r *RunCmd) detectLocalSubnets() ([]string, error) { - if r.localSubnets != nil { - return r.localSubnets() - } - info, err := networking.GetLocalNetworkInfo() - if err != nil { - return nil, fmt.Errorf("get local network info: %w", err) - } - if info.Subnet == "" { - return nil, errNoLocalSubnet - } - return []string{info.Subnet}, nil -} diff --git a/server/cmd/fleetnode/nmap.go b/server/cmd/fleetnode/nmap.go index 27a8ad62a..00036b26a 100644 --- a/server/cmd/fleetnode/nmap.go +++ b/server/cmd/fleetnode/nmap.go @@ -20,8 +20,37 @@ import ( pairingpb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" "github.com/block/proto-fleet/server/internal/domain/netutil" "github.com/block/proto-fleet/server/internal/domain/nmaptarget" + "github.com/block/proto-fleet/server/internal/infrastructure/networking" ) +// errNoLocalSubnet means the host has no usable IPv4 subnet to scan for a +// LocalSubnetTarget command. Surfaces as AGENT_INCAPABLE so a fan-out skips this +// node and tries the others. +var errNoLocalSubnet = errors.New("no local IPv4 subnet found") + +// detectLocalSubnets returns the subnet(s) the agent scans for a local-subnet +// nmap command (the nmaptarget.LocalSubnetTarget sentinel). +// +// For parity with combined mode it reuses the server's own primary-interface +// detection (networking.GetLocalNetworkInfo) — the same logic the cloud Discover +// path scans. This is intentionally less robust than per-NIC private filtering +// (it picks one interface, doesn't skip virtual/container NICs, and doesn't +// narrow or cap the mask); hardening is a follow-up. The localSubnets seam lets +// tests inject canned CIDRs. +func (r *RunCmd) detectLocalSubnets() ([]string, error) { + if r.localSubnets != nil { + return r.localSubnets() + } + info, err := networking.GetLocalNetworkInfo() + if err != nil { + return nil, fmt.Errorf("get local network info: %w", err) + } + if info.Subnet == "" { + return nil, errNoLocalSubnet + } + return []string{info.Subnet}, nil +} + // validateNmapTarget enforces the shared nmap target grammar (see nmaptarget). func validateNmapTarget(s string) error { return nmaptarget.Validate(s) From caadb11d9a9a674521be2b4727500d1de5a62292 Mon Sep 17 00:00:00 2001 From: Ankit Goswami Date: Wed, 3 Jun 2026 09:23:00 -0700 Subject: [PATCH 06/13] refactor(discovery): apply code-review fixes (concurrency cap, shared dedup key, tests) From the multi-agent code review (no correctness defects found): - Cap fan-out concurrency: a maxConcurrentFleetNodeScans (32) semaphore bounds in-flight per-node ControlStream commands so a large fleet can't spawn an unbounded number of slots; goroutines acquire via a streamCtx-aware select so they exit on operator disconnect. - Extract pairing.DeviceDedupKey and use it in both dedupeDiscoverResponses and the Discover handler's cross-source send closure, so the dedup identity stays in lockstep instead of being re-derived. - Add service-level tests: RunOnNode treats an onBatch error as terminal, RunOnNode returns CodeDeadlineExceeded when the agent never acks, and ConfirmedConnectedNodeIDs propagates a ListFleetNodes error. - Minor comment polish (detectLocalSubnets doc, DiscoverOnFleetNode godoc). Co-Authored-By: Claude Opus 4.8 (1M context) --- server/cmd/fleetnode/nmap.go | 11 ++- .../fleetnode/discovery/service_test.go | 67 ++++++++++++++++++- server/internal/domain/pairing/service.go | 16 +++-- .../handlers/fleetnode/admin/handler.go | 4 +- server/internal/handlers/pairing/handler.go | 33 ++++++--- 5 files changed, 109 insertions(+), 22 deletions(-) diff --git a/server/cmd/fleetnode/nmap.go b/server/cmd/fleetnode/nmap.go index 00036b26a..bcf335063 100644 --- a/server/cmd/fleetnode/nmap.go +++ b/server/cmd/fleetnode/nmap.go @@ -31,12 +31,11 @@ var errNoLocalSubnet = errors.New("no local IPv4 subnet found") // detectLocalSubnets returns the subnet(s) the agent scans for a local-subnet // nmap command (the nmaptarget.LocalSubnetTarget sentinel). // -// For parity with combined mode it reuses the server's own primary-interface -// detection (networking.GetLocalNetworkInfo) — the same logic the cloud Discover -// path scans. This is intentionally less robust than per-NIC private filtering -// (it picks one interface, doesn't skip virtual/container NICs, and doesn't -// narrow or cap the mask); hardening is a follow-up. The localSubnets seam lets -// tests inject canned CIDRs. +// It reuses the same primary-interface detection the cloud Discover path uses +// (networking.GetLocalNetworkInfo). That is intentionally less robust than +// per-NIC private filtering — it picks one interface, doesn't skip +// virtual/container NICs, and doesn't narrow or cap the mask; hardening is a +// follow-up. The localSubnets seam lets tests inject canned CIDRs. func (r *RunCmd) detectLocalSubnets() ([]string, error) { if r.localSubnets != nil { return r.localSubnets() diff --git a/server/internal/domain/fleetnode/discovery/service_test.go b/server/internal/domain/fleetnode/discovery/service_test.go index 338ac535f..08ebbb39e 100644 --- a/server/internal/domain/fleetnode/discovery/service_test.go +++ b/server/internal/domain/fleetnode/discovery/service_test.go @@ -2,7 +2,9 @@ package discovery import ( "context" + "errors" "testing" + "time" "connectrpc.com/connect" "github.com/stretchr/testify/assert" @@ -15,10 +17,13 @@ import ( "github.com/block/proto-fleet/server/internal/domain/fleetnode/enrollment" ) -type stubLister struct{ nodes []enrollment.FleetNodeListing } +type stubLister struct { + nodes []enrollment.FleetNodeListing + err error +} func (s stubLister) ListFleetNodes(context.Context, int64) ([]enrollment.FleetNodeListing, error) { - return s.nodes, nil + return s.nodes, s.err } func collectBatches(dst *[]*pairingpb.Device) func(*pairingpb.DiscoverResponse) error { @@ -134,3 +139,61 @@ func TestConfirmedConnectedNodeIDs_IntersectsStatusAndConnection(t *testing.T) { require.NoError(t, err) assert.Equal(t, []int64{1}, got) } + +func TestRunOnNode_OnBatchErrorIsTerminal(t *testing.T) { + // Arrange: the agent emits a batch; the caller's onBatch reports its stream gone. + reg := control.NewRegistry() + svc := NewService(reg, stubLister{}) + const nodeID = int64(11) + stream := reg.Register(nodeID) + defer stream.Unregister() + go func() { + cmd := <-stream.Outgoing + reg.PublishBatch(nodeID, cmd.GetCommandId(), &pairingpb.DiscoverResponse{ + Devices: []*pairingpb.Device{{DeviceIdentifier: "auto:x", IpAddress: "10.0.0.5", Port: "4028"}}, + }) + }() + sentinel := errors.New("operator stream gone") + + // Act + err := svc.RunOnNode(context.Background(), nodeID, ipListReq([]string{"10.0.0.5"}, []string{"4028"}), func(*pairingpb.DiscoverResponse) error { + return sentinel + }) + + // Assert: an onBatch error terminates RunOnNode with that error. + require.ErrorIs(t, err, sentinel) +} + +func TestRunOnNode_TimesOutWhenAgentNeverAcks(t *testing.T) { + // Arrange: shrink the command timeout, drain the command, but never ack. + prev := DiscoverCommandTimeout + DiscoverCommandTimeout = 100 * time.Millisecond + t.Cleanup(func() { DiscoverCommandTimeout = prev }) + reg := control.NewRegistry() + svc := NewService(reg, stubLister{}) + const nodeID = int64(12) + stream := reg.Register(nodeID) + defer stream.Unregister() + go func() { <-stream.Outgoing }() + + // Act + err := svc.RunOnNode(context.Background(), nodeID, ipListReq([]string{"10.0.0.5"}, nil), func(*pairingpb.DiscoverResponse) error { return nil }) + + // Assert + require.Error(t, err) + var ce *connect.Error + require.True(t, errors.As(err, &ce)) + assert.Equal(t, connect.CodeDeadlineExceeded, ce.Code()) +} + +func TestConfirmedConnectedNodeIDs_PropagatesListError(t *testing.T) { + // Arrange: the enrollment lookup fails. + reg := control.NewRegistry() + svc := NewService(reg, stubLister{err: errors.New("db unavailable")}) + + // Act + _, err := svc.ConfirmedConnectedNodeIDs(context.Background(), 1) + + // Assert: the error propagates (fan-out treats it as best-effort upstream). + require.Error(t, err) +} diff --git a/server/internal/domain/pairing/service.go b/server/internal/domain/pairing/service.go index 4e87757af..21aa9b5d3 100644 --- a/server/internal/domain/pairing/service.go +++ b/server/internal/domain/pairing/service.go @@ -98,6 +98,17 @@ func shouldSkipNetworkOrGatewayAddress(ip net.IP) bool { return lastOctet == networkAddressLastOctet || lastOctet == gatewayAddressLastOctet } +// DeviceDedupKey is the identity used to dedupe discovered devices: the +// device_identifier, or "ip:port" when the plugin hasn't resolved one yet. +// Exported so the Discover handler's cross-source dedup stays in lockstep with +// dedupeDiscoverResponses instead of re-deriving the key. +func DeviceDedupKey(d *pb.Device) string { + if id := d.GetDeviceIdentifier(); id != "" { + return id + } + return d.GetIpAddress() + ":" + d.GetPort() +} + func dedupeDiscoverResponses(source <-chan *pb.DiscoverResponse) <-chan *pb.DiscoverResponse { resultChan := make(chan *pb.DiscoverResponse) @@ -117,10 +128,7 @@ func dedupeDiscoverResponses(source <-chan *pb.DiscoverResponse) <-chan *pb.Disc continue } - identity := device.DeviceIdentifier - if identity == "" { - identity = fmt.Sprintf("%s:%s", device.IpAddress, device.Port) - } + identity := DeviceDedupKey(device) if _, alreadySeen := seenDevices[identity]; alreadySeen { continue diff --git a/server/internal/handlers/fleetnode/admin/handler.go b/server/internal/handlers/fleetnode/admin/handler.go index c9a2cd6b7..715a9ef06 100644 --- a/server/internal/handlers/fleetnode/admin/handler.go +++ b/server/internal/handlers/fleetnode/admin/handler.go @@ -154,8 +154,8 @@ func (h *Handler) ListFleetNodeDevices(ctx context.Context, req *connect.Request } // DiscoverOnFleetNode runs discovery on a single CONFIRMED node and streams the -// node's device batches back to the operator. The dispatch/drain loop lives in -// the discovery service so the cloud "Find miners" fan-out can reuse it. +// node's device batches back to the operator. See discovery.RunOnNode for the +// dispatch/drain loop. func (h *Handler) DiscoverOnFleetNode(ctx context.Context, req *connect.Request[pb.DiscoverOnFleetNodeRequest], stream *connect.ServerStream[pb.DiscoverOnFleetNodeResponse]) error { info, err := middleware.RequirePermission(ctx, authz.PermFleetnodeManage, authz.ResourceContext{}) if err != nil { diff --git a/server/internal/handlers/pairing/handler.go b/server/internal/handlers/pairing/handler.go index 7913745d9..a4098e054 100644 --- a/server/internal/handlers/pairing/handler.go +++ b/server/internal/handlers/pairing/handler.go @@ -27,6 +27,12 @@ type Handler struct { discovery *discovery.Service } +// maxConcurrentFleetNodeScans bounds how many fleet nodes one fan-out scans at +// once, so a large fleet can't spawn an unbounded number of in-flight +// ControlStream commands (each held until its ack or the per-node timeout). It +// sits comfortably above typical fleet sizes; only the pathological case binds. +const maxConcurrentFleetNodeScans = 32 + var _ pairingv1connect.PairingServiceHandler = &Handler{} // NewHandler creates a new instance of Handler @@ -75,19 +81,19 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe if len(resp.GetDevices()) > 0 { deduped := make([]*pb.Device, 0, len(resp.GetDevices())) for _, d := range resp.GetDevices() { - key := d.GetDeviceIdentifier() - if key != "" { - if _, dup := seen[key]; dup { - continue - } - seen[key] = struct{}{} + key := pairing.DeviceDedupKey(d) + if _, dup := seen[key]; dup { + continue } + seen[key] = struct{}{} deduped = append(deduped, d) } if len(deduped) == 0 && resp.GetError() == "" { return nil // whole batch was duplicates; nothing to forward } - out = &pb.DiscoverResponse{Devices: deduped, Error: resp.GetError()} + if len(deduped) < len(resp.GetDevices()) { + out = &pb.DiscoverResponse{Devices: deduped, Error: resp.GetError()} + } } if sErr := s.Send(out); sErr != nil { sendErr = sErr @@ -98,12 +104,14 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe } var resultChan <-chan *pb.DiscoverResponse + var isNmap bool switch r.Msg.Mode.(type) { case *pb.DiscoverRequest_IpList: resultChan, err = h.pairingSvc.DiscoverWithIPList(streamCtx, r.Msg.GetIpList()) case *pb.DiscoverRequest_IpRange: resultChan, err = h.pairingSvc.DiscoverWithIPRange(streamCtx, r.Msg.GetIpRange()) case *pb.DiscoverRequest_Nmap: + isNmap = true resultChan, err = h.pairingSvc.DiscoverWithNmap(streamCtx, r.Msg.GetNmap()) case *pb.DiscoverRequest_Mdns: resultChan, err = h.pairingSvc.DiscoverWithMDNS(streamCtx, r.Msg.GetMdns()) @@ -138,7 +146,7 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe // Fleet node fan-out, gated to the automatic "Scan your network" action: an // nmap request whose target is the cloud's own local subnet. A manual/explicit // nmap target (a user-typed subnet/IP) must NOT also sweep every node's LAN. - if _, ok := r.Msg.Mode.(*pb.DiscoverRequest_Nmap); ok && h.discovery != nil && + if isNmap && h.discovery != nil && h.pairingSvc.IsLocalSubnetScan(streamCtx, r.Msg.GetNmap().GetTarget()) { nodeIDs, listErr := h.discovery.ConfirmedConnectedNodeIDs(streamCtx, info.OrganizationID) if listErr != nil { @@ -150,10 +158,19 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe Target: nmaptarget.LocalSubnetTarget, Ports: r.Msg.GetNmap().GetPorts(), }}} + sem := make(chan struct{}, maxConcurrentFleetNodeScans) for _, nodeID := range nodeIDs { wg.Add(1) go func(nodeID int64) { defer wg.Done() + // Cap concurrent in-flight node commands; exit early if the + // operator disconnected while we were queued behind the cap. + select { + case sem <- struct{}{}: + defer func() { <-sem }() + case <-streamCtx.Done(): + return + } runErr := h.discovery.RunOnNode(streamCtx, nodeID, autoReq, send) // One node failing must not fail the whole scan. Stay quiet // when streamCtx is already cancelled (operator disconnected) — From 25a12f983dd1d9f663c2dbed9af786eae797cb30 Mon Sep 17 00:00:00 2001 From: Ankit Goswami Date: Wed, 3 Jun 2026 09:41:13 -0700 Subject: [PATCH 07/13] refactor(discovery): resolve remaining code-review items Addresses the residual findings from the multi-agent review: - #4/#7: DiscoverWithNmap now reports whether the target is the cloud's own local subnet (resolveNmapTargets already computed it); the Discover handler gates fan-out on that return value instead of a second GetLocalNetworkInfo call. Removes IsLocalSubnetScan and the duplicate ip-route fork. - #1: extract the mutex-guarded dedup/serialize send into dedupForwarder and unit-test it (cross-source dedup, ip:port fallback, all-duplicate drop, send-error cancel, concurrency under -race). - #3: bound the fan-out with fleetNodeFanOutTimeout (5m) so one wedged node can't extend the operator's wait to the full per-node 12m budget. - #8: discovery.Service depends on a narrow nodeRegistry interface instead of *control.Registry (mirrors the nodeLister idiom; eases testing). - #13: control.Registry.mu is now sync.RWMutex; ConnectedFleetNodeIDs and ReportScopeFor take read locks. - #12: test that the LocalSubnetTarget sentinel dispatches through RunOnNode (the shared single-node + fan-out path). Plus service-level tests for the new DiscoverWithNmap/resolveNmapTargets bool. No behavior change for cloud-only discovery or existing single-node discovery. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../domain/fleetnode/control/registry.go | 6 +- .../domain/fleetnode/control/stream.go | 4 +- .../domain/fleetnode/discovery/service.go | 12 +- .../fleetnode/discovery/service_test.go | 30 +++++ server/internal/domain/pairing/service.go | 50 ++++----- .../domain/pairing/service_internal_test.go | 41 ++----- server/internal/handlers/pairing/forwarder.go | 69 ++++++++++++ .../handlers/pairing/forwarder_test.go | 104 ++++++++++++++++++ server/internal/handlers/pairing/handler.go | 93 ++++++---------- 9 files changed, 281 insertions(+), 128 deletions(-) create mode 100644 server/internal/handlers/pairing/forwarder.go create mode 100644 server/internal/handlers/pairing/forwarder_test.go diff --git a/server/internal/domain/fleetnode/control/registry.go b/server/internal/domain/fleetnode/control/registry.go index 283b81ff6..eee02f4fa 100644 --- a/server/internal/domain/fleetnode/control/registry.go +++ b/server/internal/domain/fleetnode/control/registry.go @@ -85,7 +85,7 @@ type inflightCommand struct { } type Registry struct { - mu sync.Mutex + mu sync.RWMutex conns map[int64]*connection } @@ -97,8 +97,8 @@ func NewRegistry() *Registry { // right now. Used by fan-out discovery to target only nodes the server can reach; // callers intersect this with the org's CONFIRMED nodes. Order is unspecified. func (r *Registry) ConnectedFleetNodeIDs() []int64 { - r.mu.Lock() - defer r.mu.Unlock() + r.mu.RLock() + defer r.mu.RUnlock() ids := make([]int64, 0, len(r.conns)) for id := range r.conns { ids = append(ids, id) diff --git a/server/internal/domain/fleetnode/control/stream.go b/server/internal/domain/fleetnode/control/stream.go index e577a121a..4e042248c 100644 --- a/server/internal/domain/fleetnode/control/stream.go +++ b/server/internal/domain/fleetnode/control/stream.go @@ -81,8 +81,8 @@ func (r *Registry) AdmitReport(fleetNodeID int64, commandID string, deviceCount // the command is in flight but unconstrained. Callers filter reported devices // through the matcher so a node can't report outside the requested scope. func (r *Registry) ReportScopeFor(fleetNodeID int64, commandID string) (ReportScope, bool) { - r.mu.Lock() - defer r.mu.Unlock() + r.mu.RLock() + defer r.mu.RUnlock() cmd := r.inflightFor(fleetNodeID, commandID) if cmd == nil { return nil, false diff --git a/server/internal/domain/fleetnode/discovery/service.go b/server/internal/domain/fleetnode/discovery/service.go index a3701ef1d..d861fc3c3 100644 --- a/server/internal/domain/fleetnode/discovery/service.go +++ b/server/internal/domain/fleetnode/discovery/service.go @@ -41,13 +41,21 @@ type nodeLister interface { ListFleetNodes(ctx context.Context, orgID int64) ([]enrollment.FleetNodeListing, error) } +// nodeRegistry is the slice of control.Registry this service needs: enumerate +// connected nodes and dispatch a command to one. Narrowing it (like nodeLister) +// makes the coupling explicit and lets tests inject a fake without a Registry. +type nodeRegistry interface { + ConnectedFleetNodeIDs() []int64 + Send(ctx context.Context, fleetNodeID int64, cmd *gatewaypb.ControlCommand, scope control.ReportScope) (*control.Session, error) +} + // Service runs discovery commands against connected fleet nodes. type Service struct { - registry *control.Registry + registry nodeRegistry enrollment nodeLister } -func NewService(registry *control.Registry, enrollmentSvc nodeLister) *Service { +func NewService(registry nodeRegistry, enrollmentSvc nodeLister) *Service { return &Service{registry: registry, enrollment: enrollmentSvc} } diff --git a/server/internal/domain/fleetnode/discovery/service_test.go b/server/internal/domain/fleetnode/discovery/service_test.go index 08ebbb39e..0f1f66f56 100644 --- a/server/internal/domain/fleetnode/discovery/service_test.go +++ b/server/internal/domain/fleetnode/discovery/service_test.go @@ -9,12 +9,14 @@ import ( "connectrpc.com/connect" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" gatewaypb "github.com/block/proto-fleet/server/generated/grpc/fleetnodegateway/v1" pairingpb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" "github.com/block/proto-fleet/server/internal/domain/fleeterror" "github.com/block/proto-fleet/server/internal/domain/fleetnode/control" "github.com/block/proto-fleet/server/internal/domain/fleetnode/enrollment" + "github.com/block/proto-fleet/server/internal/domain/nmaptarget" ) type stubLister struct { @@ -197,3 +199,31 @@ func TestConfirmedConnectedNodeIDs_PropagatesListError(t *testing.T) { // Assert: the error propagates (fan-out treats it as best-effort upstream). require.Error(t, err) } + +func TestRunOnNode_DispatchesLocalSubnetTargetSentinel(t *testing.T) { + // Arrange: the LocalSubnetTarget sentinel (operator single-node scan or fan-out) + // must reach the agent unchanged so the node scans its own subnet. + reg := control.NewRegistry() + svc := NewService(reg, stubLister{}) + const nodeID = int64(21) + stream := reg.Register(nodeID) + defer stream.Unregister() + gotTarget := make(chan string, 1) + go func() { + cmd := <-stream.Outgoing + var req pairingpb.DiscoverRequest + _ = proto.Unmarshal(cmd.GetPayload(), &req) + gotTarget <- req.GetNmap().GetTarget() + stream.PublishAck(&gatewaypb.ControlAck{CommandId: cmd.GetCommandId(), Succeeded: true, Code: gatewaypb.AckCode_ACK_CODE_OK}) + }() + req := &pairingpb.DiscoverRequest{Mode: &pairingpb.DiscoverRequest_Nmap{ + Nmap: &pairingpb.NmapModeRequest{Target: nmaptarget.LocalSubnetTarget}, + }} + + // Act + err := svc.RunOnNode(context.Background(), nodeID, req, func(*pairingpb.DiscoverResponse) error { return nil }) + + // Assert + require.NoError(t, err) + assert.Equal(t, nmaptarget.LocalSubnetTarget, <-gotTarget) +} diff --git a/server/internal/domain/pairing/service.go b/server/internal/domain/pairing/service.go index 21aa9b5d3..1b7766fcf 100644 --- a/server/internal/domain/pairing/service.go +++ b/server/internal/domain/pairing/service.go @@ -241,20 +241,6 @@ func (s *Service) GetLocalNetworkInfo(ctx context.Context) (*NetworkInfo, error) return defaultLocalNetworkInfo(ctx) } -// IsLocalSubnetScan reports whether target is the cloud host's own local subnet, -// i.e. the automatic "Scan your network" action rather than an operator-typed -// target. Fleet-node fan-out is gated on this (the same signal that triggers -// known-subnet auto-expansion) so a manual/explicit cloud scan doesn't also -// sweep every connected node's LAN. False when the host has no local subnet. -func (s *Service) IsLocalSubnetScan(ctx context.Context, target string) bool { - info, err := s.GetLocalNetworkInfo(ctx) - if err != nil { - return false - } - _, ok := maskBitsForLocalSubnetTarget(target, info.Subnet) - return ok -} - func canonicalCIDR(cidr string) (canonical string, maskBits int, isIPv4 bool, ok bool) { _, ipNet, err := net.ParseCIDR(cidr) if err != nil { @@ -311,15 +297,19 @@ func mergeAutoDiscoveryTargets(baseTarget string, knownSubnets []string) []strin return targets } -func (s *Service) resolveNmapTargets(ctx context.Context, target string) ([]string, error) { - targets := []string{target} +// resolveNmapTargets returns the scan targets and whether `target` is the cloud +// host's own local subnet (isLocalSubnet) — the same condition that drives +// known-subnet expansion. Callers reuse isLocalSubnet to decide fleet-node +// fan-out without recomputing the local network. +func (s *Service) resolveNmapTargets(ctx context.Context, target string) (targets []string, isLocalSubnet bool, err error) { + targets = []string{target} localNetworkInfo, err := s.GetLocalNetworkInfo(ctx) if err != nil { slog.Debug("Skipping known-subnet expansion for nmap discovery because local network info is unavailable", "target", target, "error", err) - return targets, nil + return targets, false, nil } maskBits, shouldExpand := maskBitsForLocalSubnetTarget(target, localNetworkInfo.Subnet) @@ -327,14 +317,14 @@ func (s *Service) resolveNmapTargets(ctx context.Context, target string) ([]stri slog.Debug("Skipping known-subnet expansion because target does not match local subnet", "target", target, "local_subnet", localNetworkInfo.Subnet) - return targets, nil + return targets, false, nil } // Subnet expansion only runs for IPv4 targets matching the local subnet // (the guard above ensures this). Pass isIPv4=true directly. info, err := session.GetInfo(ctx) if err != nil { - return nil, err + return nil, false, err } knownSubnets, err := s.deviceStore.GetKnownSubnets(ctx, info.OrganizationID, maskBits, true) @@ -342,7 +332,7 @@ func (s *Service) resolveNmapTargets(ctx context.Context, target string) ([]stri slog.Debug("Skipping known-subnet expansion because subnet query failed", "target", target, "error", err) - return targets, nil + return targets, true, nil } expandedTargets := mergeAutoDiscoveryTargets(target, knownSubnets) @@ -353,7 +343,7 @@ func (s *Service) resolveNmapTargets(ctx context.Context, target string) ([]stri "organization_id", info.OrganizationID) } - return expandedTargets, nil + return expandedTargets, true, nil } // validateNmapTargets validates targets and resolves hostnames to IP literals @@ -493,18 +483,20 @@ func (s *Service) DiscoverWithMDNS(ctx context.Context, r *pb.MDNSModeRequest) ( return resultChan, nil } -// DiscoverWithNmap discovers devices using Nmap -func (s *Service) DiscoverWithNmap(ctx context.Context, r *pb.NmapModeRequest) (<-chan *pb.DiscoverResponse, error) { +// DiscoverWithNmap discovers devices using Nmap. isLocalSubnet reports whether +// the target is the cloud host's own local subnet (the "Scan your network" +// action), which the Discover handler uses to gate fleet-node fan-out. +func (s *Service) DiscoverWithNmap(ctx context.Context, r *pb.NmapModeRequest) (results <-chan *pb.DiscoverResponse, isLocalSubnet bool, err error) { if r.Target == "" { - return nil, fleeterror.NewInvalidArgumentError("nmap discovery target is required") + return nil, false, fleeterror.NewInvalidArgumentError("nmap discovery target is required") } ports, err := s.resolveDiscoveryPorts(ctx, r.Ports) if err != nil { - return nil, err + return nil, false, err } - targets, err := s.resolveNmapTargets(ctx, r.Target) + targets, isLocalSubnet, err := s.resolveNmapTargets(ctx, r.Target) if err != nil { - return nil, err + return nil, false, err } // Apply server-controlled timeout before any DNS work so hostname @@ -514,7 +506,7 @@ func (s *Service) DiscoverWithNmap(ctx context.Context, r *pb.NmapModeRequest) ( targets, useIPv6Scanning, err := validateNmapTargets(timeoutCtx, targets, net.DefaultResolver.LookupIPAddr) if err != nil { cancel() - return nil, err + return nil, false, err } // Create channels after validation to avoid leaking the dedupe goroutine on early returns. @@ -680,7 +672,7 @@ func (s *Service) DiscoverWithNmap(ctx context.Context, r *pb.NmapModeRequest) ( wg.Wait() }() - return resultChan, nil + return resultChan, isLocalSubnet, nil } // DiscoverWithIPRange discovers devices using an IPv4 IP range. diff --git a/server/internal/domain/pairing/service_internal_test.go b/server/internal/domain/pairing/service_internal_test.go index fb2ea73e2..2c2d29389 100644 --- a/server/internal/domain/pairing/service_internal_test.go +++ b/server/internal/domain/pairing/service_internal_test.go @@ -240,9 +240,10 @@ func TestResolveNmapTargets_ExpandsLocalSubnetWithKnownSubnets(t *testing.T) { GetKnownSubnets(gomock.Any(), int64(42), 24, true). Return([]string{"192.168.25.0/24", "192.168.1.0/24", "not-a-cidr"}, nil) - targets, err := service.resolveNmapTargets(ctx, "192.168.1.0/24") + targets, isLocalSubnet, err := service.resolveNmapTargets(ctx, "192.168.1.0/24") require.NoError(t, err) require.Equal(t, []string{"192.168.1.0/24", "192.168.25.0/24"}, targets) + require.True(t, isLocalSubnet) } func TestResolveNmapTargets_SkipsExpansionForNonLocalTargets(t *testing.T) { @@ -259,9 +260,10 @@ func TestResolveNmapTargets_SkipsExpansionForNonLocalTargets(t *testing.T) { ctx := mockSessionContext(t.Context(), 1, 42) - targets, err := service.resolveNmapTargets(ctx, "192.168.25.0/24") + targets, isLocalSubnet, err := service.resolveNmapTargets(ctx, "192.168.25.0/24") require.NoError(t, err) require.Equal(t, []string{"192.168.25.0/24"}, targets) + require.False(t, isLocalSubnet) } func TestResolveNmapTargets_FallsBackWhenLocalNetworkInfoFails(t *testing.T) { @@ -278,9 +280,10 @@ func TestResolveNmapTargets_FallsBackWhenLocalNetworkInfoFails(t *testing.T) { ctx := mockSessionContext(t.Context(), 1, 42) - targets, err := service.resolveNmapTargets(ctx, "192.168.1.0/24") + targets, isLocalSubnet, err := service.resolveNmapTargets(ctx, "192.168.1.0/24") require.NoError(t, err) require.Equal(t, []string{"192.168.1.0/24"}, targets) + require.False(t, isLocalSubnet, "no local network info means we can't confirm a local-subnet scan") } func TestResolveNmapTargets_DoesNotExpandIPv6Targets(t *testing.T) { @@ -302,38 +305,10 @@ func TestResolveNmapTargets_DoesNotExpandIPv6Targets(t *testing.T) { // IPv6 targets should not be auto-expanded because IPv6 subnets are // too large for nmap sweeps. - targets, err := service.resolveNmapTargets(ctx, "fd00::/64") + targets, isLocalSubnet, err := service.resolveNmapTargets(ctx, "fd00::/64") require.NoError(t, err) require.Equal(t, []string{"fd00::/64"}, targets) -} - -func TestIsLocalSubnetScan(t *testing.T) { - // Arrange - svc := &Service{ - localNetworkInfo: func(context.Context) (*NetworkInfo, error) { - return &NetworkInfo{NetworkInfo: networking.NetworkInfo{Subnet: "192.168.1.0/24"}}, nil - }, - } - ctx := t.Context() - - // Act + Assert: the host's own subnet (canonical or with host bits) is the auto scan. - require.True(t, svc.IsLocalSubnetScan(ctx, "192.168.1.0/24")) - require.True(t, svc.IsLocalSubnetScan(ctx, "192.168.1.50/24")) - // A different/explicit target (manual scan) or the fan-out sentinel is not. - require.False(t, svc.IsLocalSubnetScan(ctx, "10.0.0.0/24")) - require.False(t, svc.IsLocalSubnetScan(ctx, "fleetnode-local-subnet")) -} - -func TestIsLocalSubnetScan_NoLocalNetworkIsFalse(t *testing.T) { - // Arrange: cloud-mode host with no local subnet. - svc := &Service{ - localNetworkInfo: func(context.Context) (*NetworkInfo, error) { - return nil, errors.New("no local network") - }, - } - - // Act + Assert - require.False(t, svc.IsLocalSubnetScan(t.Context(), "192.168.1.0/24")) + require.False(t, isLocalSubnet, "IPv6 target is never treated as the local-subnet scan") } func TestValidateNmapTargets(t *testing.T) { diff --git a/server/internal/handlers/pairing/forwarder.go b/server/internal/handlers/pairing/forwarder.go new file mode 100644 index 000000000..706063fda --- /dev/null +++ b/server/internal/handlers/pairing/forwarder.go @@ -0,0 +1,69 @@ +package pairing + +import ( + "sync" + + pb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" + "github.com/block/proto-fleet/server/internal/domain/pairing" +) + +// dedupForwarder serializes concurrent Discover sources (the cloud scan and each +// fleet node) onto one server stream — Connect streams are not safe for +// concurrent Send — and dedupes devices across sources by pairing.DeviceDedupKey. +// The first Send error is recorded and onErr is invoked once so the caller can +// cancel the remaining sources. +type dedupForwarder struct { + mu sync.Mutex + seen map[string]struct{} + send func(*pb.DiscoverResponse) error + onErr func() + err error +} + +func newDedupForwarder(send func(*pb.DiscoverResponse) error, onErr func()) *dedupForwarder { + return &dedupForwarder{seen: make(map[string]struct{}), send: send, onErr: onErr} +} + +// forward dedupes resp's devices across sources and forwards it. A batch reduced +// entirely to duplicates (with no error payload) is dropped. Once any Send has +// failed, forward returns that error without sending again. +func (f *dedupForwarder) forward(resp *pb.DiscoverResponse) error { + f.mu.Lock() + defer f.mu.Unlock() + if f.err != nil { + return f.err + } + out := resp + if len(resp.GetDevices()) > 0 { + deduped := make([]*pb.Device, 0, len(resp.GetDevices())) + for _, d := range resp.GetDevices() { + key := pairing.DeviceDedupKey(d) + if _, dup := f.seen[key]; dup { + continue + } + f.seen[key] = struct{}{} + deduped = append(deduped, d) + } + if len(deduped) == 0 && resp.GetError() == "" { + return nil // whole batch was duplicates; nothing to forward + } + if len(deduped) < len(resp.GetDevices()) { + out = &pb.DiscoverResponse{Devices: deduped, Error: resp.GetError()} + } + } + if err := f.send(out); err != nil { + f.err = err + if f.onErr != nil { + f.onErr() + } + return err + } + return nil +} + +// failure returns the first Send error, if any. +func (f *dedupForwarder) failure() error { + f.mu.Lock() + defer f.mu.Unlock() + return f.err +} diff --git a/server/internal/handlers/pairing/forwarder_test.go b/server/internal/handlers/pairing/forwarder_test.go new file mode 100644 index 000000000..fd836ab2b --- /dev/null +++ b/server/internal/handlers/pairing/forwarder_test.go @@ -0,0 +1,104 @@ +package pairing + +import ( + "errors" + "fmt" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + pb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" +) + +func dev(id, ip, port string) *pb.Device { + return &pb.Device{DeviceIdentifier: id, IpAddress: ip, Port: port} +} + +func TestDedupForwarder_DedupesAcrossSources(t *testing.T) { + // Arrange + var sent []*pb.DiscoverResponse + fwd := newDedupForwarder(func(r *pb.DiscoverResponse) error { sent = append(sent, r); return nil }, nil) + + // Act: two sources report a shared device (mac:a) plus distinct ones. + require.NoError(t, fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev("mac:a", "10.0.0.1", "80"), dev("mac:b", "10.0.0.2", "80")}})) + require.NoError(t, fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev("mac:a", "10.0.0.1", "80"), dev("mac:c", "10.0.0.3", "80")}})) + + // Assert: mac:a forwarded once; the second batch is reduced to just mac:c. + require.Len(t, sent, 2) + assert.Len(t, sent[0].GetDevices(), 2) + require.Len(t, sent[1].GetDevices(), 1) + assert.Equal(t, "mac:c", sent[1].GetDevices()[0].GetDeviceIdentifier()) +} + +func TestDedupForwarder_IPPortFallbackKey(t *testing.T) { + // Arrange: devices without an identifier dedupe by ip:port. + var sent []*pb.DiscoverResponse + fwd := newDedupForwarder(func(r *pb.DiscoverResponse) error { sent = append(sent, r); return nil }, nil) + + // Act + require.NoError(t, fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev("", "10.0.0.5", "4028")}})) + require.NoError(t, fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev("", "10.0.0.5", "4028")}})) + + // Assert: the second (all-duplicate) batch is dropped. + require.Len(t, sent, 1) +} + +func TestDedupForwarder_DropsAllDuplicateBatchButKeepsErrorResponse(t *testing.T) { + // Arrange + var sent []*pb.DiscoverResponse + fwd := newDedupForwarder(func(r *pb.DiscoverResponse) error { sent = append(sent, r); return nil }, nil) + require.NoError(t, fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev("mac:a", "10.0.0.1", "80")}})) + + // Act: a fully-duplicate batch is dropped; an error-only response is forwarded. + require.NoError(t, fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev("mac:a", "10.0.0.1", "80")}})) + require.NoError(t, fwd.forward(&pb.DiscoverResponse{Error: "scan failed"})) + + // Assert + require.Len(t, sent, 2) + assert.Equal(t, "scan failed", sent[1].GetError()) +} + +func TestDedupForwarder_SendErrorRecordedAndCancels(t *testing.T) { + // Arrange + sendErr := errors.New("stream gone") + var cancelled bool + fwd := newDedupForwarder(func(*pb.DiscoverResponse) error { return sendErr }, func() { cancelled = true }) + + // Act + err1 := fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev("mac:a", "10.0.0.1", "80")}}) + err2 := fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev("mac:b", "10.0.0.2", "80")}}) + + // Assert: the first failure records the error and cancels; the second short-circuits. + require.ErrorIs(t, err1, sendErr) + require.ErrorIs(t, err2, sendErr) + assert.True(t, cancelled) + require.ErrorIs(t, fwd.failure(), sendErr) +} + +func TestDedupForwarder_ConcurrentForwardIsSerialized(t *testing.T) { + // Arrange: many goroutines forward distinct devices; -race verifies safety. + var mu sync.Mutex + count := 0 + fwd := newDedupForwarder(func(r *pb.DiscoverResponse) error { + mu.Lock() + count += len(r.GetDevices()) + mu.Unlock() + return nil + }, nil) + var wg sync.WaitGroup + + // Act + for i := range 50 { + wg.Add(1) + go func(i int) { + defer wg.Done() + _ = fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev(fmt.Sprintf("mac:%d", i), "10.0.0.1", "80")}}) + }(i) + } + wg.Wait() + + // Assert: all 50 distinct devices forwarded exactly once. + assert.Equal(t, 50, count) +} diff --git a/server/internal/handlers/pairing/handler.go b/server/internal/handlers/pairing/handler.go index a4098e054..03cf056a4 100644 --- a/server/internal/handlers/pairing/handler.go +++ b/server/internal/handlers/pairing/handler.go @@ -5,6 +5,7 @@ import ( "errors" "log/slog" "sync" + "time" "github.com/block/proto-fleet/server/internal/domain/authz" "github.com/block/proto-fleet/server/internal/domain/fleeterror" @@ -33,6 +34,13 @@ type Handler struct { // sits comfortably above typical fleet sizes; only the pathological case binds. const maxConcurrentFleetNodeScans = 32 +// fleetNodeFanOutTimeout caps how long the fleet-node fan-out can extend the +// Discover stream. RunOnNode's 12m timeout is the dedicated single-node budget; +// for the opportunistic fan-out a tighter ceiling keeps a wedged node from making +// the operator wait minutes past the cloud scan. A LAN subnet scan finishes well +// within this. +const fleetNodeFanOutTimeout = 5 * time.Minute + var _ pairingv1connect.PairingServiceHandler = &Handler{} // NewHandler creates a new instance of Handler @@ -62,57 +70,19 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe streamCtx, cancel := context.WithCancel(ctx) defer cancel() - // Connect server streams are not safe for concurrent Send, and the cloud - // scan + each node write to this one stream. Serialize through send, which - // also dedupes devices across sources by identifier (each source dedupes - // internally, but not against the others). - var ( - sendMu sync.Mutex - seen = make(map[string]struct{}) - sendErr error - ) - send := func(resp *pb.DiscoverResponse) error { - sendMu.Lock() - defer sendMu.Unlock() - if sendErr != nil { - return sendErr - } - out := resp - if len(resp.GetDevices()) > 0 { - deduped := make([]*pb.Device, 0, len(resp.GetDevices())) - for _, d := range resp.GetDevices() { - key := pairing.DeviceDedupKey(d) - if _, dup := seen[key]; dup { - continue - } - seen[key] = struct{}{} - deduped = append(deduped, d) - } - if len(deduped) == 0 && resp.GetError() == "" { - return nil // whole batch was duplicates; nothing to forward - } - if len(deduped) < len(resp.GetDevices()) { - out = &pb.DiscoverResponse{Devices: deduped, Error: resp.GetError()} - } - } - if sErr := s.Send(out); sErr != nil { - sendErr = sErr - cancel() - return sErr //nolint:wrapcheck // a connect stream Send error is already a connect error - } - return nil - } + // Serialize the concurrent sources (cloud scan + each node) onto the one + // stream and dedupe devices across them; a Send failure cancels the rest. + fwd := newDedupForwarder(s.Send, cancel) var resultChan <-chan *pb.DiscoverResponse - var isNmap bool + var isLocalSubnetNmap bool switch r.Msg.Mode.(type) { case *pb.DiscoverRequest_IpList: resultChan, err = h.pairingSvc.DiscoverWithIPList(streamCtx, r.Msg.GetIpList()) case *pb.DiscoverRequest_IpRange: resultChan, err = h.pairingSvc.DiscoverWithIPRange(streamCtx, r.Msg.GetIpRange()) case *pb.DiscoverRequest_Nmap: - isNmap = true - resultChan, err = h.pairingSvc.DiscoverWithNmap(streamCtx, r.Msg.GetNmap()) + resultChan, isLocalSubnetNmap, err = h.pairingSvc.DiscoverWithNmap(streamCtx, r.Msg.GetNmap()) case *pb.DiscoverRequest_Mdns: resultChan, err = h.pairingSvc.DiscoverWithMDNS(streamCtx, r.Msg.GetMdns()) default: @@ -134,7 +104,7 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe if !ok { return } - if err := send(result); err != nil { + if err := fwd.forward(result); err != nil { return } case <-streamCtx.Done(): @@ -143,17 +113,21 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe } }() - // Fleet node fan-out, gated to the automatic "Scan your network" action: an - // nmap request whose target is the cloud's own local subnet. A manual/explicit - // nmap target (a user-typed subnet/IP) must NOT also sweep every node's LAN. - if isNmap && h.discovery != nil && - h.pairingSvc.IsLocalSubnetScan(streamCtx, r.Msg.GetNmap().GetTarget()) { + // Fleet node fan-out, only for the automatic "Scan your network" action — an + // nmap target equal to the cloud's own local subnet (reported by + // DiscoverWithNmap). A manual/explicit nmap target must NOT sweep every node's + // LAN. + if isLocalSubnetNmap && h.discovery != nil { nodeIDs, listErr := h.discovery.ConfirmedConnectedNodeIDs(streamCtx, info.OrganizationID) if listErr != nil { // Fan-out is best-effort; a lookup failure must never break the // cloud scan. With zero connected nodes this is the same path. slog.Warn("skipping fleet node discovery fan-out", "error", listErr) } else { + // Bound the fan-out's contribution to the stream so one wedged node + // can't extend the operator's wait to the full per-node timeout. + fanOutCtx, fanOutCancel := context.WithTimeout(streamCtx, fleetNodeFanOutTimeout) + defer fanOutCancel() autoReq := &pb.DiscoverRequest{Mode: &pb.DiscoverRequest_Nmap{Nmap: &pb.NmapModeRequest{ Target: nmaptarget.LocalSubnetTarget, Ports: r.Msg.GetNmap().GetPorts(), @@ -164,18 +138,18 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe go func(nodeID int64) { defer wg.Done() // Cap concurrent in-flight node commands; exit early if the - // operator disconnected while we were queued behind the cap. + // stream closed or the fan-out budget expired while queued. select { case sem <- struct{}{}: defer func() { <-sem }() - case <-streamCtx.Done(): + case <-fanOutCtx.Done(): return } - runErr := h.discovery.RunOnNode(streamCtx, nodeID, autoReq, send) - // One node failing must not fail the whole scan. Stay quiet - // when streamCtx is already cancelled (operator disconnected) — - // that's expected, not a node fault. - if runErr != nil && streamCtx.Err() == nil { + runErr := h.discovery.RunOnNode(fanOutCtx, nodeID, autoReq, fwd.forward) + // One node failing (or hitting the fan-out budget) must not + // fail the scan; it's expected on disconnect/budget, so stay + // quiet once fanOutCtx is done. + if runErr != nil && fanOutCtx.Err() == nil { slog.Warn("fleet node discovery failed during cloud fan-out", "fleet_node_id", nodeID, "error", runErr) } @@ -185,11 +159,12 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe } wg.Wait() - if sendErr != nil { - return sendErr + if err := fwd.failure(); err != nil { + return err } // A client cancel/deadline drains the sources without a Send error; surface - // it as canceled/deadline rather than a successful completion. + // it as canceled/deadline rather than a successful completion. (The fan-out + // budget firing is not a client error — it returns whatever streamed.) if ctxErr := ctx.Err(); ctxErr != nil { if errors.Is(ctxErr, context.DeadlineExceeded) { return connect.NewError(connect.CodeDeadlineExceeded, ctxErr) From e5dc1d905e14499890204978dcf019b4d30522ca Mon Sep 17 00:00:00 2001 From: Ankit Goswami Date: Wed, 3 Jun 2026 09:44:15 -0700 Subject: [PATCH 08/13] docs(discovery): trim verbose comments in the Discover handler MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Compress the fan-out comments and const docs to the load-bearing "why" — no behavior change. Co-Authored-By: Claude Opus 4.8 (1M context) --- server/internal/handlers/pairing/handler.go | 44 +++++++-------------- 1 file changed, 15 insertions(+), 29 deletions(-) diff --git a/server/internal/handlers/pairing/handler.go b/server/internal/handlers/pairing/handler.go index 03cf056a4..f4ccc4796 100644 --- a/server/internal/handlers/pairing/handler.go +++ b/server/internal/handlers/pairing/handler.go @@ -23,22 +23,17 @@ import ( type Handler struct { pairingSvc *pairing.Service // discovery fans the "Scan your network" nmap action out to connected fleet - // nodes so their LAN-local miners surface alongside the cloud's own scan. - // Optional; nil disables fan-out (cloud-only discovery). + // nodes; nil disables fan-out (cloud-only discovery). discovery *discovery.Service } -// maxConcurrentFleetNodeScans bounds how many fleet nodes one fan-out scans at -// once, so a large fleet can't spawn an unbounded number of in-flight -// ControlStream commands (each held until its ack or the per-node timeout). It -// sits comfortably above typical fleet sizes; only the pathological case binds. +// maxConcurrentFleetNodeScans bounds in-flight per-node commands so a large fleet +// can't spawn an unbounded number of ControlStream slots. Above typical fleet sizes. const maxConcurrentFleetNodeScans = 32 -// fleetNodeFanOutTimeout caps how long the fleet-node fan-out can extend the -// Discover stream. RunOnNode's 12m timeout is the dedicated single-node budget; -// for the opportunistic fan-out a tighter ceiling keeps a wedged node from making -// the operator wait minutes past the cloud scan. A LAN subnet scan finishes well -// within this. +// fleetNodeFanOutTimeout caps how long the opportunistic fan-out can extend the +// Discover stream — tighter than RunOnNode's 12m budget so one wedged node can't +// make the operator wait minutes past the cloud scan. const fleetNodeFanOutTimeout = 5 * time.Minute var _ pairingv1connect.PairingServiceHandler = &Handler{} @@ -51,14 +46,9 @@ func NewHandler(pairingSvc *pairing.Service, discoverySvc *discovery.Service) *H } } -// Discover implements pairingv1connect.PairingServiceHandler. -// -// Beyond the cloud's own network scan, an nmap ("Scan your network") request -// also fans out to every CONFIRMED + connected fleet node, which scan their own -// local subnets and report back. All sources merge into this single response -// stream so the operator pairs LAN-local and cloud-local miners together with no -// client change. Manual modes (ipList/ipRange/mdns) target the cloud's own -// network only. +// Discover implements pairingv1connect.PairingServiceHandler. An nmap "Scan your +// network" request also fans out to every CONFIRMED + connected fleet node and +// merges their LAN-local results into this stream; other modes are cloud-only. func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRequest], s *connect.ServerStream[pb.DiscoverResponse]) error { info, err := middleware.RequirePermission(ctx, authz.PermMinerPair, authz.ResourceContext{}) if err != nil { @@ -113,10 +103,8 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe } }() - // Fleet node fan-out, only for the automatic "Scan your network" action — an - // nmap target equal to the cloud's own local subnet (reported by - // DiscoverWithNmap). A manual/explicit nmap target must NOT sweep every node's - // LAN. + // Fan out only for the automatic "Scan your network" action (nmap target == + // the cloud's own local subnet), never a manual/explicit target. if isLocalSubnetNmap && h.discovery != nil { nodeIDs, listErr := h.discovery.ConfirmedConnectedNodeIDs(streamCtx, info.OrganizationID) if listErr != nil { @@ -146,9 +134,8 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe return } runErr := h.discovery.RunOnNode(fanOutCtx, nodeID, autoReq, fwd.forward) - // One node failing (or hitting the fan-out budget) must not - // fail the scan; it's expected on disconnect/budget, so stay - // quiet once fanOutCtx is done. + // One node failing must not fail the scan, and is expected + // once fanOutCtx is done (disconnect/budget) — stay quiet then. if runErr != nil && fanOutCtx.Err() == nil { slog.Warn("fleet node discovery failed during cloud fan-out", "fleet_node_id", nodeID, "error", runErr) @@ -162,9 +149,8 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe if err := fwd.failure(); err != nil { return err } - // A client cancel/deadline drains the sources without a Send error; surface - // it as canceled/deadline rather than a successful completion. (The fan-out - // budget firing is not a client error — it returns whatever streamed.) + // A client cancel/deadline drains the sources without a Send error; report it + // rather than success. (A fan-out-budget expiry is not a client error.) if ctxErr := ctx.Err(); ctxErr != nil { if errors.Is(ctxErr, context.DeadlineExceeded) { return connect.NewError(connect.CodeDeadlineExceeded, ctxErr) From 15d93f18febb0d7ce9d28dd401a21942068d0ade Mon Sep 17 00:00:00 2001 From: Ankit Goswami Date: Wed, 3 Jun 2026 09:49:44 -0700 Subject: [PATCH 09/13] refactor(discovery): drop over-built fan-out guards Peel back machinery added reactively to low-priority review findings, none of which earns its complexity at the actual fleet scale: - Remove fleetNodeFanOutTimeout: it was a third timeout layer (atop the agent's 10m and RunOnNode's 12m) with an arbitrary value that second-guessed the agent's own budget. Discover streams incrementally, so a wedged node only delays the stream close, not results. RunOnNode's per-node timeout is the single, sufficient bound. - Remove the maxConcurrentFleetNodeScans semaphore: inert at typical fleet sizes (tens of nodes); the per-node timeout already bounds each node. - Revert control.Registry to sync.Mutex: the RWMutex was a P3 micro-opt for contention that doesn't exist at this scale and added lock-discipline risk. No behavior change for the realistic path; fan-out goroutines run on streamCtx and remain bounded by RunOnNode's per-node timeout. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../domain/fleetnode/control/registry.go | 6 ++-- .../domain/fleetnode/control/stream.go | 4 +-- server/internal/handlers/pairing/handler.go | 32 +++---------------- 3 files changed, 10 insertions(+), 32 deletions(-) diff --git a/server/internal/domain/fleetnode/control/registry.go b/server/internal/domain/fleetnode/control/registry.go index eee02f4fa..283b81ff6 100644 --- a/server/internal/domain/fleetnode/control/registry.go +++ b/server/internal/domain/fleetnode/control/registry.go @@ -85,7 +85,7 @@ type inflightCommand struct { } type Registry struct { - mu sync.RWMutex + mu sync.Mutex conns map[int64]*connection } @@ -97,8 +97,8 @@ func NewRegistry() *Registry { // right now. Used by fan-out discovery to target only nodes the server can reach; // callers intersect this with the org's CONFIRMED nodes. Order is unspecified. func (r *Registry) ConnectedFleetNodeIDs() []int64 { - r.mu.RLock() - defer r.mu.RUnlock() + r.mu.Lock() + defer r.mu.Unlock() ids := make([]int64, 0, len(r.conns)) for id := range r.conns { ids = append(ids, id) diff --git a/server/internal/domain/fleetnode/control/stream.go b/server/internal/domain/fleetnode/control/stream.go index 4e042248c..e577a121a 100644 --- a/server/internal/domain/fleetnode/control/stream.go +++ b/server/internal/domain/fleetnode/control/stream.go @@ -81,8 +81,8 @@ func (r *Registry) AdmitReport(fleetNodeID int64, commandID string, deviceCount // the command is in flight but unconstrained. Callers filter reported devices // through the matcher so a node can't report outside the requested scope. func (r *Registry) ReportScopeFor(fleetNodeID int64, commandID string) (ReportScope, bool) { - r.mu.RLock() - defer r.mu.RUnlock() + r.mu.Lock() + defer r.mu.Unlock() cmd := r.inflightFor(fleetNodeID, commandID) if cmd == nil { return nil, false diff --git a/server/internal/handlers/pairing/handler.go b/server/internal/handlers/pairing/handler.go index f4ccc4796..b4abe14d5 100644 --- a/server/internal/handlers/pairing/handler.go +++ b/server/internal/handlers/pairing/handler.go @@ -5,7 +5,6 @@ import ( "errors" "log/slog" "sync" - "time" "github.com/block/proto-fleet/server/internal/domain/authz" "github.com/block/proto-fleet/server/internal/domain/fleeterror" @@ -27,15 +26,6 @@ type Handler struct { discovery *discovery.Service } -// maxConcurrentFleetNodeScans bounds in-flight per-node commands so a large fleet -// can't spawn an unbounded number of ControlStream slots. Above typical fleet sizes. -const maxConcurrentFleetNodeScans = 32 - -// fleetNodeFanOutTimeout caps how long the opportunistic fan-out can extend the -// Discover stream — tighter than RunOnNode's 12m budget so one wedged node can't -// make the operator wait minutes past the cloud scan. -const fleetNodeFanOutTimeout = 5 * time.Minute - var _ pairingv1connect.PairingServiceHandler = &Handler{} // NewHandler creates a new instance of Handler @@ -112,31 +102,19 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe // cloud scan. With zero connected nodes this is the same path. slog.Warn("skipping fleet node discovery fan-out", "error", listErr) } else { - // Bound the fan-out's contribution to the stream so one wedged node - // can't extend the operator's wait to the full per-node timeout. - fanOutCtx, fanOutCancel := context.WithTimeout(streamCtx, fleetNodeFanOutTimeout) - defer fanOutCancel() autoReq := &pb.DiscoverRequest{Mode: &pb.DiscoverRequest_Nmap{Nmap: &pb.NmapModeRequest{ Target: nmaptarget.LocalSubnetTarget, Ports: r.Msg.GetNmap().GetPorts(), }}} - sem := make(chan struct{}, maxConcurrentFleetNodeScans) for _, nodeID := range nodeIDs { wg.Add(1) go func(nodeID int64) { defer wg.Done() - // Cap concurrent in-flight node commands; exit early if the - // stream closed or the fan-out budget expired while queued. - select { - case sem <- struct{}{}: - defer func() { <-sem }() - case <-fanOutCtx.Done(): - return - } - runErr := h.discovery.RunOnNode(fanOutCtx, nodeID, autoReq, fwd.forward) - // One node failing must not fail the scan, and is expected - // once fanOutCtx is done (disconnect/budget) — stay quiet then. - if runErr != nil && fanOutCtx.Err() == nil { + // Each node is bounded by RunOnNode's per-node timeout. + runErr := h.discovery.RunOnNode(streamCtx, nodeID, autoReq, fwd.forward) + // One node failing must not fail the scan, and is expected on + // operator disconnect — stay quiet once streamCtx is done. + if runErr != nil && streamCtx.Err() == nil { slog.Warn("fleet node discovery failed during cloud fan-out", "fleet_node_id", nodeID, "error", runErr) } From 14d28d8a64976817d8c7ab5f3e3761bd68fe8a0f Mon Sep 17 00:00:00 2001 From: Ankit Goswami Date: Wed, 3 Jun 2026 09:58:50 -0700 Subject: [PATCH 10/13] test(discovery): assert unordered node IDs with ElementsMatch ConnectedFleetNodeIDs iterates a map (order unspecified), so assert the ConfirmedConnectedNodeIDs result with ElementsMatch rather than Equal to match the contract and avoid future flakiness. Also tighten the RunOnNode handleEvent comment to the non-obvious terminal-with-nil-err semantic. Co-Authored-By: Claude Opus 4.8 (1M context) --- server/internal/domain/fleetnode/discovery/service.go | 4 ++-- server/internal/domain/fleetnode/discovery/service_test.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/server/internal/domain/fleetnode/discovery/service.go b/server/internal/domain/fleetnode/discovery/service.go index d861fc3c3..c18ad1e78 100644 --- a/server/internal/domain/fleetnode/discovery/service.go +++ b/server/internal/domain/fleetnode/discovery/service.go @@ -117,8 +117,8 @@ func (s *Service) RunOnNode(ctx context.Context, fleetNodeID int64, req *pairing } defer session.Close() - // handleEvent forwards a batch or resolves the command on an ack. terminal - // reports whether the command is finished; err is set only on failure. + // terminal=true stops the loop whether or not err is set — an OK/PARTIAL ack + // is terminal with a nil err. handleEvent := func(ev control.CommandEvent) (terminal bool, err error) { switch { case ev.Batch != nil: diff --git a/server/internal/domain/fleetnode/discovery/service_test.go b/server/internal/domain/fleetnode/discovery/service_test.go index 0f1f66f56..e6cfe7a75 100644 --- a/server/internal/domain/fleetnode/discovery/service_test.go +++ b/server/internal/domain/fleetnode/discovery/service_test.go @@ -137,9 +137,9 @@ func TestConfirmedConnectedNodeIDs_IntersectsStatusAndConnection(t *testing.T) { // Act got, err := svc.ConfirmedConnectedNodeIDs(context.Background(), 1) - // Assert: only the confirmed AND connected node. + // Assert: only the confirmed AND connected node (order is unspecified). require.NoError(t, err) - assert.Equal(t, []int64{1}, got) + assert.ElementsMatch(t, []int64{1}, got) } func TestRunOnNode_OnBatchErrorIsTerminal(t *testing.T) { From ada7392f4f6660d450c6652c1d43887690f9c223 Mon Sep 17 00:00:00 2001 From: Ankit Goswami Date: Wed, 3 Jun 2026 11:35:37 -0700 Subject: [PATCH 11/13] fix: gate fleet-node discovery fan-out behind fleetnode:manage PairingService.Discover fan-out enumerated all confirmed connected fleet nodes and issued discovery commands over their control streams while gated only by miner:pair. The single-node DiscoverOnFleetNode path requires fleetnode:manage, so this was a weaker authorization path to the same fleet-node command surface. Fan-out now also requires fleetnode:manage via a non-failing check, so miner:pair-only callers still get cloud-only discovery but cannot drive discovery commands on fleet nodes. Co-Authored-By: Claude Opus 4.8 (1M context) --- server/internal/handlers/pairing/handler.go | 16 +++- .../handlers/pairing/handler_internal_test.go | 74 +++++++++++++++++++ 2 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 server/internal/handlers/pairing/handler_internal_test.go diff --git a/server/internal/handlers/pairing/handler.go b/server/internal/handlers/pairing/handler.go index b4abe14d5..0d1698e96 100644 --- a/server/internal/handlers/pairing/handler.go +++ b/server/internal/handlers/pairing/handler.go @@ -94,8 +94,11 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe }() // Fan out only for the automatic "Scan your network" action (nmap target == - // the cloud's own local subnet), never a manual/explicit target. - if isLocalSubnetNmap && h.discovery != nil { + // the cloud's own local subnet), never a manual/explicit target, and only for + // callers who also hold fleetnode:manage — the same permission the single-node + // DiscoverOnFleetNode path requires. Without it, discovery stays cloud-only so + // the weaker miner:pair grant can't drive discovery commands on fleet nodes. + if isLocalSubnetNmap && h.discovery != nil && callerCanManageFleetNodes(ctx) { nodeIDs, listErr := h.discovery.ConfirmedConnectedNodeIDs(streamCtx, info.OrganizationID) if listErr != nil { // Fan-out is best-effort; a lookup failure must never break the @@ -138,6 +141,15 @@ func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRe return nil } +// callerCanManageFleetNodes reports whether the request holds fleetnode:manage. +// It reuses the canonical permission path (so the synthesized-actor and +// fail-closed semantics match) but treats absence as a soft signal to skip +// fan-out rather than an error to return. +func callerCanManageFleetNodes(ctx context.Context) bool { + _, err := middleware.RequirePermission(ctx, authz.PermFleetnodeManage, authz.ResourceContext{}) + return err == nil +} + // Pair implements pairingv1connect.PairingServiceHandler. func (h *Handler) Pair(ctx context.Context, r *connect.Request[pb.PairRequest]) (*connect.Response[pb.PairResponse], error) { if _, err := middleware.RequirePermission(ctx, authz.PermMinerPair, authz.ResourceContext{}); err != nil { diff --git a/server/internal/handlers/pairing/handler_internal_test.go b/server/internal/handlers/pairing/handler_internal_test.go new file mode 100644 index 000000000..71f1e131b --- /dev/null +++ b/server/internal/handlers/pairing/handler_internal_test.go @@ -0,0 +1,74 @@ +package pairing + +import ( + "context" + "testing" + + "connectrpc.com/authn" + "github.com/stretchr/testify/assert" + + "github.com/block/proto-fleet/server/internal/domain/authz" + "github.com/block/proto-fleet/server/internal/domain/session" + "github.com/block/proto-fleet/server/internal/handlers/middleware" +) + +// ctxWithPerms builds the request context the auth interceptor would produce: +// session info plus the caller's effective org-scoped permissions. +func ctxWithPerms(perms ...string) context.Context { + info := &session.Info{ + AuthMethod: session.AuthMethodSession, + SessionID: "sess-1", + UserID: 1, + OrganizationID: 1, + ExternalUserID: "user-1", + Username: "alice", + } + ctx := authn.SetInfo(context.Background(), info) + return middleware.WithEffectivePermissions(ctx, authz.NewEffectivePermissions( + []authz.Assignment{{AssignmentID: 1, ScopeType: authz.ScopeOrg, Permissions: perms}}, + )) +} + +func TestCallerCanManageFleetNodes(t *testing.T) { + tests := []struct { + name string + perms []string + want bool + }{ + { + // The fan-out regression: miner:pair alone (no fleetnode:manage) must + // NOT unlock fleet-node discovery commands. + name: "miner:pair only does not grant fleet-node management", + perms: []string{authz.PermMinerPair}, + want: false, + }, + { + name: "fleetnode:manage grants it", + perms: []string{authz.PermMinerPair, authz.PermFleetnodeManage}, + want: true, + }, + { + name: "fleetnode:read alone does not grant it", + perms: []string{authz.PermMinerPair, authz.PermFleetnodeRead}, + want: false, + }, + { + name: "no permissions", + perms: nil, + want: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Arrange + ctx := ctxWithPerms(tc.perms...) + + // Act + got := callerCanManageFleetNodes(ctx) + + // Assert + assert.Equal(t, tc.want, got) + }) + } +} From 35dc9be7d60a385caa4bab768d6ae0a7634a9424 Mon Sep 17 00:00:00 2001 From: Ankit Goswami Date: Wed, 3 Jun 2026 12:01:47 -0700 Subject: [PATCH 12/13] fix: reject non-private local subnets before fan-out nmap scan The local-subnet sentinel passed detectLocalSubnets() output straight to nmap. That detection reuses primary-interface logic that doesn't filter for RFC1918, so a node whose primary NIC is public could have that subnet scanned by an automatic "Scan your network". The server drops non-private reports, but only after the node has already sent the scan traffic. Validate each detected prefix is a private IPv4 subnet before scanning and fail AGENT_INCAPABLE otherwise, so a fan-out skips such a node. Co-Authored-By: Claude Opus 4.8 (1M context) --- server/cmd/fleetnode/nmap.go | 35 +++++++++++++++++++++++++++++-- server/cmd/fleetnode/nmap_test.go | 19 +++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/server/cmd/fleetnode/nmap.go b/server/cmd/fleetnode/nmap.go index bcf335063..29817edf3 100644 --- a/server/cmd/fleetnode/nmap.go +++ b/server/cmd/fleetnode/nmap.go @@ -34,8 +34,10 @@ var errNoLocalSubnet = errors.New("no local IPv4 subnet found") // It reuses the same primary-interface detection the cloud Discover path uses // (networking.GetLocalNetworkInfo). That is intentionally less robust than // per-NIC private filtering — it picks one interface, doesn't skip -// virtual/container NICs, and doesn't narrow or cap the mask; hardening is a -// follow-up. The localSubnets seam lets tests inject canned CIDRs. +// virtual/container NICs, and doesn't narrow or cap the mask. The caller +// (buildNmapOptions) rejects non-private results before scanning; narrowing an +// over-broad private mask is a follow-up. The localSubnets seam lets tests +// inject canned CIDRs. func (r *RunCmd) detectLocalSubnets() ([]string, error) { if r.localSubnets != nil { return r.localSubnets() @@ -55,6 +57,24 @@ func validateNmapTarget(s string) error { return nmaptarget.Validate(s) } +// requirePrivateIPv4Subnet guards the local-subnet sentinel: a detected subnet +// must be a private (RFC1918) IPv4 prefix before it is scanned, since +// primary-interface detection itself doesn't filter for RFC1918. +func requirePrivateIPv4Subnet(s string) error { + prefix, err := netip.ParsePrefix(s) + if err != nil { + return fmt.Errorf("not a CIDR: %w", err) + } + addr := prefix.Addr() + if !addr.Is4() { + return errors.New("not IPv4") + } + if !addr.IsPrivate() { + return errors.New("not in an RFC1918 private range") + } + return nil +} + const ( nmapHostTimeout = 10 * time.Second nmapMinRTT = 100 * time.Millisecond @@ -145,6 +165,17 @@ func (r *RunCmd) buildNmapOptions(ctx context.Context, req *pairingpb.NmapModeRe if err != nil { return nil, cmdErr(pb.AckCode_ACK_CODE_AGENT_INCAPABLE, "no connected private IPv4 subnet for local-subnet scan: %s", err) } + // detectLocalSubnets reuses primary-interface detection that does not + // filter for RFC1918, so a node whose primary NIC is public would + // otherwise have that subnet scanned. Refuse non-private targets here so + // an automatic "Scan your network" can never probe a public network; the + // server's report scope drops public IPs, but only after the node has + // already sent the scan traffic. + for _, s := range subnets { + if err := requirePrivateIPv4Subnet(s); err != nil { + return nil, cmdErr(pb.AckCode_ACK_CODE_AGENT_INCAPABLE, "local-subnet scan target %q is not a private IPv4 subnet: %s", s, err) + } + } return append(baseNmapOptions(r.nmapPath, ports), nmap.WithTargets(subnets...)), nil } diff --git a/server/cmd/fleetnode/nmap_test.go b/server/cmd/fleetnode/nmap_test.go index 9a9244dab..a9d888f4b 100644 --- a/server/cmd/fleetnode/nmap_test.go +++ b/server/cmd/fleetnode/nmap_test.go @@ -270,6 +270,25 @@ func TestBuildNmapOptions_LocalSubnetTarget_UsesDetectedCIDRs(t *testing.T) { assert.False(t, slices.Contains(args, nmaptarget.LocalSubnetTarget), "sentinel must not reach nmap as a literal target: %v", args) } +func TestBuildNmapOptions_LocalSubnetTarget_RejectsNonPrivateSubnet(t *testing.T) { + // Arrange: detection returns a public subnet (e.g. a node whose primary NIC + // is public). An automatic scan must never probe it. + r := &RunCmd{ + nmapPath: "/usr/bin/nmap", + discoverer: &stubDiscoverer{}, + localSubnets: func() ([]string, error) { return []string{"8.8.8.0/24"}, nil }, + } + req := &pairingpb.NmapModeRequest{Target: nmaptarget.LocalSubnetTarget, Ports: []string{"4028"}} + + // Act + _, err := r.buildNmapOptions(context.Background(), req, req.Ports) + + // Assert: refused before any scan, mapped so a fan-out skips this node. + var ce *commandError + require.ErrorAs(t, err, &ce) + assert.Equal(t, pb.AckCode_ACK_CODE_AGENT_INCAPABLE, ce.code) +} + func TestBuildNmapOptions_LocalSubnetTarget_NoSubnetIsAgentIncapable(t *testing.T) { // Arrange: detection finds no private subnet. r := &RunCmd{ From c87b7022587c44736eda511ea4b9704ce8ab506f Mon Sep 17 00:00:00 2001 From: Ankit Goswami Date: Wed, 3 Jun 2026 12:07:03 -0700 Subject: [PATCH 13/13] fix: cap local-subnet fan-out scan breadth at /22 The local-subnet sentinel validated detected subnets were private but not their breadth, so a node whose primary NIC has a prefix broader than the /22 nmap limit (e.g. 10.0.0.0/16, returned directly from the OS interface mask) could be made to sweep tens of thousands of hosts per fan-out, defeating the cap that applies to operator-supplied targets. validateLocalSubnetTarget now also rejects prefixes broader than nmaptarget.MinIPv4PrefixBits (AGENT_INCAPABLE), mirroring nmaptarget.Validate. Co-Authored-By: Claude Opus 4.8 (1M context) --- server/cmd/fleetnode/nmap.go | 38 ++++++++++++++++------------ server/cmd/fleetnode/nmap_test.go | 42 ++++++++++++++++++++----------- 2 files changed, 50 insertions(+), 30 deletions(-) diff --git a/server/cmd/fleetnode/nmap.go b/server/cmd/fleetnode/nmap.go index 29817edf3..96f0fe0df 100644 --- a/server/cmd/fleetnode/nmap.go +++ b/server/cmd/fleetnode/nmap.go @@ -34,10 +34,10 @@ var errNoLocalSubnet = errors.New("no local IPv4 subnet found") // It reuses the same primary-interface detection the cloud Discover path uses // (networking.GetLocalNetworkInfo). That is intentionally less robust than // per-NIC private filtering — it picks one interface, doesn't skip -// virtual/container NICs, and doesn't narrow or cap the mask. The caller -// (buildNmapOptions) rejects non-private results before scanning; narrowing an -// over-broad private mask is a follow-up. The localSubnets seam lets tests -// inject canned CIDRs. +// virtual/container NICs, and returns the raw OS mask. The caller +// (buildNmapOptions) rejects non-private or over-broad results before scanning +// (see validateLocalSubnetTarget). The localSubnets seam lets tests inject +// canned CIDRs. func (r *RunCmd) detectLocalSubnets() ([]string, error) { if r.localSubnets != nil { return r.localSubnets() @@ -57,10 +57,14 @@ func validateNmapTarget(s string) error { return nmaptarget.Validate(s) } -// requirePrivateIPv4Subnet guards the local-subnet sentinel: a detected subnet -// must be a private (RFC1918) IPv4 prefix before it is scanned, since -// primary-interface detection itself doesn't filter for RFC1918. -func requirePrivateIPv4Subnet(s string) error { +// validateLocalSubnetTarget guards the local-subnet sentinel: a detected subnet +// must be a private (RFC1918) IPv4 prefix no broader than the shared scan-size +// cap before it is scanned. Primary-interface detection doesn't filter for +// RFC1918 and returns the raw OS interface mask, so a public NIC or an +// over-broad prefix (e.g. 10.0.0.0/16) would otherwise reach nmap. The breadth +// limit mirrors nmaptarget.Validate so the fan-out can't sweep more hosts per +// node than an operator-supplied target is allowed to. +func validateLocalSubnetTarget(s string) error { prefix, err := netip.ParsePrefix(s) if err != nil { return fmt.Errorf("not a CIDR: %w", err) @@ -72,6 +76,9 @@ func requirePrivateIPv4Subnet(s string) error { if !addr.IsPrivate() { return errors.New("not in an RFC1918 private range") } + if prefix.Bits() < nmaptarget.MinIPv4PrefixBits { + return fmt.Errorf("prefix /%d is broader than the supported maximum /%d", prefix.Bits(), nmaptarget.MinIPv4PrefixBits) + } return nil } @@ -165,15 +172,14 @@ func (r *RunCmd) buildNmapOptions(ctx context.Context, req *pairingpb.NmapModeRe if err != nil { return nil, cmdErr(pb.AckCode_ACK_CODE_AGENT_INCAPABLE, "no connected private IPv4 subnet for local-subnet scan: %s", err) } - // detectLocalSubnets reuses primary-interface detection that does not - // filter for RFC1918, so a node whose primary NIC is public would - // otherwise have that subnet scanned. Refuse non-private targets here so - // an automatic "Scan your network" can never probe a public network; the - // server's report scope drops public IPs, but only after the node has - // already sent the scan traffic. + // detectLocalSubnets reuses primary-interface detection that doesn't + // filter for RFC1918 or cap breadth, so a public NIC or an over-broad + // prefix would otherwise be scanned. Refuse such targets here so an + // automatic "Scan your network" can never probe a public network or sweep + // more hosts than an operator-supplied target may. for _, s := range subnets { - if err := requirePrivateIPv4Subnet(s); err != nil { - return nil, cmdErr(pb.AckCode_ACK_CODE_AGENT_INCAPABLE, "local-subnet scan target %q is not a private IPv4 subnet: %s", s, err) + if err := validateLocalSubnetTarget(s); err != nil { + return nil, cmdErr(pb.AckCode_ACK_CODE_AGENT_INCAPABLE, "local-subnet scan target %q is not scannable: %s", s, err) } } return append(baseNmapOptions(r.nmapPath, ports), nmap.WithTargets(subnets...)), nil diff --git a/server/cmd/fleetnode/nmap_test.go b/server/cmd/fleetnode/nmap_test.go index a9d888f4b..b7b8b44cf 100644 --- a/server/cmd/fleetnode/nmap_test.go +++ b/server/cmd/fleetnode/nmap_test.go @@ -270,23 +270,37 @@ func TestBuildNmapOptions_LocalSubnetTarget_UsesDetectedCIDRs(t *testing.T) { assert.False(t, slices.Contains(args, nmaptarget.LocalSubnetTarget), "sentinel must not reach nmap as a literal target: %v", args) } -func TestBuildNmapOptions_LocalSubnetTarget_RejectsNonPrivateSubnet(t *testing.T) { - // Arrange: detection returns a public subnet (e.g. a node whose primary NIC - // is public). An automatic scan must never probe it. - r := &RunCmd{ - nmapPath: "/usr/bin/nmap", - discoverer: &stubDiscoverer{}, - localSubnets: func() ([]string, error) { return []string{"8.8.8.0/24"}, nil }, +func TestBuildNmapOptions_LocalSubnetTarget_RejectsUnscannableSubnet(t *testing.T) { + tests := []struct { + name string + subnet string + }{ + // A node whose primary NIC is public must never have it probed. + {name: "public subnet", subnet: "8.8.8.0/24"}, + // A private prefix broader than the /22 cap would sweep tens of + // thousands of hosts; reject it like an operator-supplied target. + {name: "over-broad private subnet", subnet: "10.0.0.0/16"}, } - req := &pairingpb.NmapModeRequest{Target: nmaptarget.LocalSubnetTarget, Ports: []string{"4028"}} - // Act - _, err := r.buildNmapOptions(context.Background(), req, req.Ports) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Arrange + r := &RunCmd{ + nmapPath: "/usr/bin/nmap", + discoverer: &stubDiscoverer{}, + localSubnets: func() ([]string, error) { return []string{tc.subnet}, nil }, + } + req := &pairingpb.NmapModeRequest{Target: nmaptarget.LocalSubnetTarget, Ports: []string{"4028"}} - // Assert: refused before any scan, mapped so a fan-out skips this node. - var ce *commandError - require.ErrorAs(t, err, &ce) - assert.Equal(t, pb.AckCode_ACK_CODE_AGENT_INCAPABLE, ce.code) + // Act + _, err := r.buildNmapOptions(context.Background(), req, req.Ports) + + // Assert: refused before any scan, mapped so a fan-out skips this node. + var ce *commandError + require.ErrorAs(t, err, &ce) + assert.Equal(t, pb.AckCode_ACK_CODE_AGENT_INCAPABLE, ce.code) + }) + } } func TestBuildNmapOptions_LocalSubnetTarget_NoSubnetIsAgentIncapable(t *testing.T) {