diff --git a/server/cmd/fleetd/main.go b/server/cmd/fleetd/main.go index e9a456588..dda28e02e 100644 --- a/server/cmd/fleetd/main.go +++ b/server/cmd/fleetd/main.go @@ -71,6 +71,7 @@ import ( fleetmanagementDomain "github.com/block/proto-fleet/server/internal/domain/fleetmanagement" fleetnodeauth "github.com/block/proto-fleet/server/internal/domain/fleetnode/auth" "github.com/block/proto-fleet/server/internal/domain/fleetnode/control" + fleetnodediscovery "github.com/block/proto-fleet/server/internal/domain/fleetnode/discovery" "github.com/block/proto-fleet/server/internal/domain/fleetnode/enrollment" fleetnodepairing "github.com/block/proto-fleet/server/internal/domain/fleetnode/pairing" "github.com/block/proto-fleet/server/internal/domain/fleetoptions" @@ -220,6 +221,7 @@ func start(config *Config) error { fleetNodePairingStore := sqlstores.NewSQLFleetNodePairingStore(conn) fleetNodePairingSvc := fleetnodepairing.NewService(fleetNodePairingStore, fleetNodeEnrollmentStore, transactor) fleetNodeControlRegistry := control.NewRegistry() + fleetNodeDiscoverySvc := fleetnodediscovery.NewService(fleetNodeControlRegistry, fleetNodeEnrollmentSvc) fleetNodeAuthStore := sqlstores.NewSQLFleetNodeAuthStore(conn) fleetNodeAuthSvc := fleetnodeauth.NewService(fleetNodeAuthStore, fleetNodeEnrollmentStore, apiKeySvc) @@ -522,7 +524,7 @@ func start(config *Config) error { mux.Handle(authv1connect.NewAuthServiceHandler(auth.NewHandler(authSvc), li)) mux.Handle(onboardingv1connect.NewOnboardingServiceHandler(onboarding.NewHandler(authSvc, onboardingSvc), li)) - mux.Handle(pairingv1connect.NewPairingServiceHandler(pairing.NewHandler(pairingSvc), li)) + mux.Handle(pairingv1connect.NewPairingServiceHandler(pairing.NewHandler(pairingSvc, fleetNodeDiscoverySvc), li)) mux.Handle(networkinfov1connect.NewNetworkInfoServiceHandler(networkinfo.NewHandler(pairingSvc), li)) mux.Handle(fleetmanagementv1connect.NewFleetManagementServiceHandler(fleetmanagement.NewHandler(fleetMgmtSvc), li)) mux.Handle(minercommandv1connect.NewMinerCommandServiceHandler(command.NewHandler(commandSvc), li)) @@ -532,7 +534,7 @@ func start(config *Config) error { mux.Handle(sitesv1connect.NewSiteServiceHandler(sitesHandler.NewHandler(sitesSvc), li)) mux.Handle(buildingsv1connect.NewBuildingServiceHandler(buildingsHandler.NewHandler(buildingsSvc), li)) mux.Handle(fleetnodegatewayv1connect.NewFleetNodeGatewayServiceHandler(gateway.NewHandler(fleetNodeEnrollmentSvc, fleetNodeAuthSvc, fleetNodePairingSvc, fleetNodeControlRegistry), li)) - mux.Handle(fleetnodeadminv1connect.NewFleetNodeAdminServiceHandler(admin.NewHandler(fleetNodeEnrollmentSvc, fleetNodePairingSvc, fleetNodeControlRegistry), li)) + mux.Handle(fleetnodeadminv1connect.NewFleetNodeAdminServiceHandler(admin.NewHandler(fleetNodeEnrollmentSvc, fleetNodePairingSvc, fleetNodeDiscoverySvc), li)) mux.Handle(collectionv1connect.NewDeviceCollectionServiceHandler(collectionHandler.NewHandler(collectionSvc), li)) mux.Handle(device_setv1connect.NewDeviceSetServiceHandler(devicesetHandler.NewHandler(collectionSvc), li)) mux.Handle(telemetryv1connect.NewTelemetryServiceHandler(telemetryHandler.NewHandler(telemetryService), li)) diff --git a/server/cmd/fleetnode/nmap.go b/server/cmd/fleetnode/nmap.go index 734f10144..96f0fe0df 100644 --- a/server/cmd/fleetnode/nmap.go +++ b/server/cmd/fleetnode/nmap.go @@ -20,13 +20,68 @@ import ( pairingpb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" "github.com/block/proto-fleet/server/internal/domain/netutil" "github.com/block/proto-fleet/server/internal/domain/nmaptarget" + "github.com/block/proto-fleet/server/internal/infrastructure/networking" ) +// errNoLocalSubnet means the host has no usable IPv4 subnet to scan for a +// LocalSubnetTarget command. Surfaces as AGENT_INCAPABLE so a fan-out skips this +// node and tries the others. +var errNoLocalSubnet = errors.New("no local IPv4 subnet found") + +// detectLocalSubnets returns the subnet(s) the agent scans for a local-subnet +// nmap command (the nmaptarget.LocalSubnetTarget sentinel). +// +// It reuses the same primary-interface detection the cloud Discover path uses +// (networking.GetLocalNetworkInfo). That is intentionally less robust than +// per-NIC private filtering — it picks one interface, doesn't skip +// virtual/container NICs, and returns the raw OS mask. The caller +// (buildNmapOptions) rejects non-private or over-broad results before scanning +// (see validateLocalSubnetTarget). The localSubnets seam lets tests inject +// canned CIDRs. +func (r *RunCmd) detectLocalSubnets() ([]string, error) { + if r.localSubnets != nil { + return r.localSubnets() + } + info, err := networking.GetLocalNetworkInfo() + if err != nil { + return nil, fmt.Errorf("get local network info: %w", err) + } + if info.Subnet == "" { + return nil, errNoLocalSubnet + } + return []string{info.Subnet}, nil +} + // validateNmapTarget enforces the shared nmap target grammar (see nmaptarget). func validateNmapTarget(s string) error { return nmaptarget.Validate(s) } +// validateLocalSubnetTarget guards the local-subnet sentinel: a detected subnet +// must be a private (RFC1918) IPv4 prefix no broader than the shared scan-size +// cap before it is scanned. Primary-interface detection doesn't filter for +// RFC1918 and returns the raw OS interface mask, so a public NIC or an +// over-broad prefix (e.g. 10.0.0.0/16) would otherwise reach nmap. The breadth +// limit mirrors nmaptarget.Validate so the fan-out can't sweep more hosts per +// node than an operator-supplied target is allowed to. +func validateLocalSubnetTarget(s string) error { + prefix, err := netip.ParsePrefix(s) + if err != nil { + return fmt.Errorf("not a CIDR: %w", err) + } + addr := prefix.Addr() + if !addr.Is4() { + return errors.New("not IPv4") + } + if !addr.IsPrivate() { + return errors.New("not in an RFC1918 private range") + } + if prefix.Bits() < nmaptarget.MinIPv4PrefixBits { + return fmt.Errorf("prefix /%d is broader than the supported maximum /%d", prefix.Bits(), nmaptarget.MinIPv4PrefixBits) + } + return nil +} + const ( nmapHostTimeout = 10 * time.Second nmapMinRTT = 100 * time.Millisecond @@ -107,6 +162,29 @@ func validateNmapBinary(path string) (string, error) { func (r *RunCmd) buildNmapOptions(ctx context.Context, req *pairingpb.NmapModeRequest, ports []string) ([]nmap.Option, error) { target := strings.TrimSpace(req.GetTarget()) + + // The LocalSubnetTarget sentinel means the server couldn't know the node's + // network, so the agent enumerates its own private IPv4 subnet(s) and scans + // those (IPv4 only, same as the manual path's IPv6-CIDR rejection). Matched + // exactly, before any hostname resolution. + if target == nmaptarget.LocalSubnetTarget { + subnets, err := r.detectLocalSubnets() + if err != nil { + return nil, cmdErr(pb.AckCode_ACK_CODE_AGENT_INCAPABLE, "no connected private IPv4 subnet for local-subnet scan: %s", err) + } + // detectLocalSubnets reuses primary-interface detection that doesn't + // filter for RFC1918 or cap breadth, so a public NIC or an over-broad + // prefix would otherwise be scanned. Refuse such targets here so an + // automatic "Scan your network" can never probe a public network or sweep + // more hosts than an operator-supplied target may. + for _, s := range subnets { + if err := validateLocalSubnetTarget(s); err != nil { + return nil, cmdErr(pb.AckCode_ACK_CODE_AGENT_INCAPABLE, "local-subnet scan target %q is not scannable: %s", s, err) + } + } + return append(baseNmapOptions(r.nmapPath, ports), nmap.WithTargets(subnets...)), nil + } + if err := validateNmapTarget(target); err != nil { return nil, cmdErr(pb.AckCode_ACK_CODE_BAD_REQUEST, "%s", err) } @@ -114,9 +192,18 @@ func (r *RunCmd) buildNmapOptions(ctx context.Context, req *pairingpb.NmapModeRe if err != nil { return nil, cmdErr(pb.AckCode_ACK_CODE_BAD_REQUEST, "%s", err) } - opts := []nmap.Option{ - nmap.WithBinaryPath(r.nmapPath), - nmap.WithTargets(resolved), + opts := append(baseNmapOptions(r.nmapPath, ports), nmap.WithTargets(resolved)) + if useIPv6 { + opts = append(opts, nmap.WithIPv6Scanning()) + } + return opts, nil +} + +// baseNmapOptions are the timing/safety options shared by targeted and +// local-subnet scans; callers append the target(s) (and -6 if needed). +func baseNmapOptions(binaryPath string, ports []string) []nmap.Option { + return []nmap.Option{ + nmap.WithBinaryPath(binaryPath), nmap.WithPorts(strings.Join(ports, ",")), nmap.WithUnique(), nmap.WithDisabledDNSResolution(), @@ -125,10 +212,6 @@ func (r *RunCmd) buildNmapOptions(ctx context.Context, req *pairingpb.NmapModeRe nmap.WithHostTimeout(nmapHostTimeout), nmap.WithMinRTTTimeout(nmapMinRTT), } - if useIPv6 { - opts = append(opts, nmap.WithIPv6Scanning()) - } - return opts, nil } // Mirrors pairing-service validateNmapTargets so agent and server feed diff --git a/server/cmd/fleetnode/nmap_test.go b/server/cmd/fleetnode/nmap_test.go index e8dd6dd58..b7b8b44cf 100644 --- a/server/cmd/fleetnode/nmap_test.go +++ b/server/cmd/fleetnode/nmap_test.go @@ -16,7 +16,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + pb "github.com/block/proto-fleet/server/generated/grpc/fleetnodegateway/v1" pairingpb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" + "github.com/block/proto-fleet/server/internal/domain/nmaptarget" ) func testLogger() *slog.Logger { return slog.New(slog.DiscardHandler) } @@ -246,6 +248,79 @@ func TestBuildNmapOptions_AddsIPv6Scanning(t *testing.T) { assert.True(t, slices.Contains(v6Scanner.Args(), "-6"), "IPv6 target must carry -6") } +func TestBuildNmapOptions_LocalSubnetTarget_UsesDetectedCIDRs(t *testing.T) { + // Arrange: inject the detected subnets so the test doesn't depend on the host. + r := &RunCmd{ + nmapPath: "/usr/bin/nmap", + discoverer: &stubDiscoverer{}, + localSubnets: func() ([]string, error) { return []string{"192.168.1.0/24"}, nil }, + } + req := &pairingpb.NmapModeRequest{Target: nmaptarget.LocalSubnetTarget, Ports: []string{"4028"}} + + // Act + opts, err := r.buildNmapOptions(context.Background(), req, req.Ports) + require.NoError(t, err) + scanner, err := nmap.NewScanner(context.Background(), opts...) + require.NoError(t, err) + + // Assert: the detected subnet reaches nmap and the scan stays IPv4-only. + args := scanner.Args() + assert.True(t, slices.Contains(args, "192.168.1.0/24"), "expected detected subnet in argv: %v", args) + assert.False(t, slices.Contains(args, "-6"), "local-subnet scan must be IPv4-only: %v", args) + assert.False(t, slices.Contains(args, nmaptarget.LocalSubnetTarget), "sentinel must not reach nmap as a literal target: %v", args) +} + +func TestBuildNmapOptions_LocalSubnetTarget_RejectsUnscannableSubnet(t *testing.T) { + tests := []struct { + name string + subnet string + }{ + // A node whose primary NIC is public must never have it probed. + {name: "public subnet", subnet: "8.8.8.0/24"}, + // A private prefix broader than the /22 cap would sweep tens of + // thousands of hosts; reject it like an operator-supplied target. + {name: "over-broad private subnet", subnet: "10.0.0.0/16"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Arrange + r := &RunCmd{ + nmapPath: "/usr/bin/nmap", + discoverer: &stubDiscoverer{}, + localSubnets: func() ([]string, error) { return []string{tc.subnet}, nil }, + } + req := &pairingpb.NmapModeRequest{Target: nmaptarget.LocalSubnetTarget, Ports: []string{"4028"}} + + // Act + _, err := r.buildNmapOptions(context.Background(), req, req.Ports) + + // Assert: refused before any scan, mapped so a fan-out skips this node. + var ce *commandError + require.ErrorAs(t, err, &ce) + assert.Equal(t, pb.AckCode_ACK_CODE_AGENT_INCAPABLE, ce.code) + }) + } +} + +func TestBuildNmapOptions_LocalSubnetTarget_NoSubnetIsAgentIncapable(t *testing.T) { + // Arrange: detection finds no private subnet. + r := &RunCmd{ + nmapPath: "/usr/bin/nmap", + discoverer: &stubDiscoverer{}, + localSubnets: func() ([]string, error) { return nil, errNoLocalSubnet }, + } + req := &pairingpb.NmapModeRequest{Target: nmaptarget.LocalSubnetTarget, Ports: []string{"4028"}} + + // Act + _, err := r.buildNmapOptions(context.Background(), req, req.Ports) + + // Assert: a fan-out should skip this node, so the ack maps to FailedPrecondition. + var ce *commandError + require.ErrorAs(t, err, &ce) + assert.Equal(t, pb.AckCode_ACK_CODE_AGENT_INCAPABLE, ce.code) +} + func TestDiscoverForCommand_NmapPathEmptyFailsClosed(t *testing.T) { // Arrange r := &RunCmd{nmapPath: "", discoverer: &stubDiscoverer{}} diff --git a/server/cmd/fleetnode/run.go b/server/cmd/fleetnode/run.go index 295b8b1f1..decf0178d 100644 --- a/server/cmd/fleetnode/run.go +++ b/server/cmd/fleetnode/run.go @@ -34,6 +34,7 @@ type RunCmd struct { discoverer discoverer `kong:"-"` nmapPath string `kong:"-"` resolver ipResolver `kong:"-"` + localSubnets func() ([]string, error) `kong:"-"` // test seam for local-subnet detection stateMu sync.Mutex `kong:"-"` // guards st.SessionToken across refreshAndSave + tokenSource. } diff --git a/server/internal/domain/fleetnode/control/registry.go b/server/internal/domain/fleetnode/control/registry.go index c96ee2eb7..283b81ff6 100644 --- a/server/internal/domain/fleetnode/control/registry.go +++ b/server/internal/domain/fleetnode/control/registry.go @@ -93,6 +93,19 @@ func NewRegistry() *Registry { return &Registry{conns: make(map[int64]*connection)} } +// ConnectedFleetNodeIDs returns the fleet_node IDs with an active ControlStream +// right now. Used by fan-out discovery to target only nodes the server can reach; +// callers intersect this with the org's CONFIRMED nodes. Order is unspecified. +func (r *Registry) ConnectedFleetNodeIDs() []int64 { + r.mu.Lock() + defer r.mu.Unlock() + ids := make([]int64, 0, len(r.conns)) + for id := range r.conns { + ids = append(ids, id) + } + return ids +} + // teardown closes connection.done and cmd.done (if any). Caller holds Registry.mu // and must then remove/replace the conn so teardown can't run twice. func teardown(conn *connection) { diff --git a/server/internal/domain/fleetnode/control/registry_test.go b/server/internal/domain/fleetnode/control/registry_test.go index bae4bc704..a3a66bc25 100644 --- a/server/internal/domain/fleetnode/control/registry_test.go +++ b/server/internal/domain/fleetnode/control/registry_test.go @@ -327,3 +327,21 @@ func receive(t *testing.T, ch <-chan CommandEvent) CommandEvent { return CommandEvent{} } } + +func TestConnectedFleetNodeIDs_ReflectsRegisterAndUnregister(t *testing.T) { + // Arrange + r := NewRegistry() + s1 := r.Register(1) + s2 := r.Register(2) + + // Act + Assert: both connected. + assert.ElementsMatch(t, []int64{1, 2}, r.ConnectedFleetNodeIDs()) + + // Act + Assert: unregistering one drops it. + s1.Unregister() + assert.ElementsMatch(t, []int64{2}, r.ConnectedFleetNodeIDs()) + + // Act + Assert: empty once all are gone. + s2.Unregister() + assert.Empty(t, r.ConnectedFleetNodeIDs()) +} diff --git a/server/internal/handlers/fleetnode/admin/handler_ackfailure_test.go b/server/internal/domain/fleetnode/discovery/ackfailure_test.go similarity index 98% rename from server/internal/handlers/fleetnode/admin/handler_ackfailure_test.go rename to server/internal/domain/fleetnode/discovery/ackfailure_test.go index 4fea4bb42..f8caaeb37 100644 --- a/server/internal/handlers/fleetnode/admin/handler_ackfailure_test.go +++ b/server/internal/domain/fleetnode/discovery/ackfailure_test.go @@ -1,4 +1,4 @@ -package admin +package discovery import ( "testing" diff --git a/server/internal/handlers/fleetnode/admin/handler_iprange_test.go b/server/internal/domain/fleetnode/discovery/iprange_test.go similarity index 99% rename from server/internal/handlers/fleetnode/admin/handler_iprange_test.go rename to server/internal/domain/fleetnode/discovery/iprange_test.go index 62f607d9b..f46a3957b 100644 --- a/server/internal/handlers/fleetnode/admin/handler_iprange_test.go +++ b/server/internal/domain/fleetnode/discovery/iprange_test.go @@ -1,4 +1,4 @@ -package admin +package discovery import ( "testing" diff --git a/server/internal/handlers/fleetnode/admin/reportscope.go b/server/internal/domain/fleetnode/discovery/reportscope.go similarity index 91% rename from server/internal/handlers/fleetnode/admin/reportscope.go rename to server/internal/domain/fleetnode/discovery/reportscope.go index 9c4a05d24..01255a67c 100644 --- a/server/internal/handlers/fleetnode/admin/reportscope.go +++ b/server/internal/domain/fleetnode/discovery/reportscope.go @@ -1,4 +1,4 @@ -package admin +package discovery import ( "net/netip" @@ -25,8 +25,21 @@ func buildReportScope(req *pairingpb.DiscoverRequest) control.ReportScope { return inPort(port) && inIP(ip) } case *pairingpb.DiscoverRequest_Nmap: - inTarget := nmapTargetMatcher(m.Nmap.GetTarget()) inPort := portMatcher(m.Nmap.GetPorts()) + // The LocalSubnetTarget sentinel lets the agent pick its own subnet, so + // the server can't predict the IPs. Degrade the IP scope to the + // private-only invariant (RFC1918/RFC4193) that validateReport + // independently enforces; port scoping is still applied. + if m.Nmap.GetTarget() == nmaptarget.LocalSubnetTarget { + return func(ip, port string) bool { + if !inPort(port) { + return false + } + a, ok := parseScopeAddr(ip) + return ok && a.IsPrivate() + } + } + inTarget := nmapTargetMatcher(m.Nmap.GetTarget()) return func(ip, port string) bool { return inPort(port) && inTarget(ip) } diff --git a/server/internal/handlers/fleetnode/admin/reportscope_test.go b/server/internal/domain/fleetnode/discovery/reportscope_test.go similarity index 82% rename from server/internal/handlers/fleetnode/admin/reportscope_test.go rename to server/internal/domain/fleetnode/discovery/reportscope_test.go index a98eb6536..8d5d83699 100644 --- a/server/internal/handlers/fleetnode/admin/reportscope_test.go +++ b/server/internal/domain/fleetnode/discovery/reportscope_test.go @@ -1,4 +1,4 @@ -package admin +package discovery import ( "testing" @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" pairingpb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" + "github.com/block/proto-fleet/server/internal/domain/nmaptarget" ) func ipListReq(ips, ports []string) *pairingpb.DiscoverRequest { @@ -21,6 +22,12 @@ func nmapReq(target string, ports []string) *pairingpb.DiscoverRequest { }} } +func autoNmapReq(ports []string) *pairingpb.DiscoverRequest { + return &pairingpb.DiscoverRequest{Mode: &pairingpb.DiscoverRequest_Nmap{ + Nmap: &pairingpb.NmapModeRequest{Target: nmaptarget.LocalSubnetTarget, Ports: ports}, + }} +} + func TestBuildReportScope(t *testing.T) { tests := []struct { name string @@ -49,6 +56,10 @@ func TestBuildReportScope(t *testing.T) { {"nmap literal mismatch", nmapReq("192.168.1.10", []string{"80"}), "192.168.1.11", "80", false}, {"nmap hostname leaves ip unconstrained", nmapReq("miner.lan", []string{"80"}), "10.1.2.3", "80", true}, {"nmap hostname still enforces ports", nmapReq("miner.lan", []string{"80"}), "10.1.2.3", "22", false}, + {"auto accepts private ip on in-scope port", autoNmapReq([]string{"80"}), "192.168.5.9", "80", true}, + {"auto rejects public ip", autoNmapReq([]string{"80"}), "8.8.8.8", "80", false}, + {"auto rejects out-of-scope port", autoNmapReq([]string{"80"}), "192.168.5.9", "22", false}, + {"auto empty ports allows any private ip and port", autoNmapReq(nil), "10.2.3.4", "31337", true}, } for _, tc := range tests { @@ -147,6 +158,29 @@ func TestNormalizeDiscoverRequest_RejectsPublicNmapTarget(t *testing.T) { } } +func TestNormalizeDiscoverRequest_LocalSubnetTarget_Accepts(t *testing.T) { + // Arrange: the local-subnet sentinel target with valid ports. + req := autoNmapReq([]string{"80", "4028"}) + + // Act + _, err := normalizeDiscoverRequest(req) + + // Assert + require.NoError(t, err) +} + +func TestNormalizeDiscoverRequest_LocalSubnetTarget_RejectsInvalidPort(t *testing.T) { + // Arrange + req := autoNmapReq([]string{"80/tcp"}) + + // Act + _, err := normalizeDiscoverRequest(req) + + // Assert + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid port") +} + func TestNmapTargetIsPrivate(t *testing.T) { tests := []struct { name string diff --git a/server/internal/domain/fleetnode/discovery/service.go b/server/internal/domain/fleetnode/discovery/service.go new file mode 100644 index 000000000..c18ad1e78 --- /dev/null +++ b/server/internal/domain/fleetnode/discovery/service.go @@ -0,0 +1,354 @@ +// Package discovery dispatches server-initiated miner discovery to fleet nodes +// over the ControlStream and streams the results back. It owns the per-node +// run loop (normalize -> send command -> drain batches until ack) shared by the +// operator-facing single-node RPC (handlers/fleetnode/admin) and the cloud +// "Find miners" fan-out (handlers/pairing), plus the helpers that decide which +// nodes a fan-out should target. +package discovery + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/netip" + "strconv" + "time" + + "connectrpc.com/connect" + "google.golang.org/protobuf/proto" + + gatewaypb "github.com/block/proto-fleet/server/generated/grpc/fleetnodegateway/v1" + pairingpb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" + "github.com/block/proto-fleet/server/internal/domain/discoverylimits" + "github.com/block/proto-fleet/server/internal/domain/fleeterror" + "github.com/block/proto-fleet/server/internal/domain/fleetnode/control" + "github.com/block/proto-fleet/server/internal/domain/fleetnode/enrollment" + "github.com/block/proto-fleet/server/internal/domain/netutil" + "github.com/block/proto-fleet/server/internal/domain/nmaptarget" + "github.com/block/proto-fleet/server/internal/infrastructure/id" +) + +// DiscoverCommandTimeout bounds how long RunOnNode waits for the agent's batches +// and ack, so a silent node can't pin operator streams and registry slots. Must +// exceed the agent's scan budget (commandTimeout, 10m) plus report/ack slack: too +// short frees the slot mid-scan, the agent's ack is rejected as stale, and a new +// command dispatches while the node is still busy. Var for tests. +var DiscoverCommandTimeout = 12 * time.Minute + +// nodeLister is the subset of enrollment.Service that fan-out targeting needs. +type nodeLister interface { + ListFleetNodes(ctx context.Context, orgID int64) ([]enrollment.FleetNodeListing, error) +} + +// nodeRegistry is the slice of control.Registry this service needs: enumerate +// connected nodes and dispatch a command to one. Narrowing it (like nodeLister) +// makes the coupling explicit and lets tests inject a fake without a Registry. +type nodeRegistry interface { + ConnectedFleetNodeIDs() []int64 + Send(ctx context.Context, fleetNodeID int64, cmd *gatewaypb.ControlCommand, scope control.ReportScope) (*control.Session, error) +} + +// Service runs discovery commands against connected fleet nodes. +type Service struct { + registry nodeRegistry + enrollment nodeLister +} + +func NewService(registry nodeRegistry, enrollmentSvc nodeLister) *Service { + return &Service{registry: registry, enrollment: enrollmentSvc} +} + +// ConfirmedConnectedNodeIDs returns the IDs of fleet nodes in orgID that are both +// CONFIRMED and currently connected (active ControlStream) — the set a fan-out +// can dispatch to. A node with a live stream but a non-CONFIRMED enrollment +// status is excluded. +func (s *Service) ConfirmedConnectedNodeIDs(ctx context.Context, orgID int64) ([]int64, error) { + nodes, err := s.enrollment.ListFleetNodes(ctx, orgID) + if err != nil { + return nil, err + } + confirmed := make(map[int64]struct{}, len(nodes)) + for _, n := range nodes { + if n.EnrollmentStatus == enrollment.FleetNodeStatusConfirmed { + confirmed[n.ID] = struct{}{} + } + } + connected := s.registry.ConnectedFleetNodeIDs() + out := make([]int64, 0, len(connected)) + for _, nodeID := range connected { + if _, ok := confirmed[nodeID]; ok { + out = append(out, nodeID) + } + } + return out, nil +} + +// RunOnNode normalizes req, builds the report scope, dispatches the command over +// the node's ControlStream, and invokes onBatch for each discovered-device batch +// until the node acks (or the command times out / the stream drops). It returns +// nil on an OK or PARTIAL ack, and an error otherwise — including any non-nil +// error returned by onBatch, which is treated as terminal (the caller's stream +// is gone, so there is nothing left to forward). +func (s *Service) RunOnNode(ctx context.Context, fleetNodeID int64, req *pairingpb.DiscoverRequest, onBatch func(*pairingpb.DiscoverResponse) error) error { + normalized, err := normalizeDiscoverRequest(req) + if err != nil { + return err + } + + commandID := id.GenerateID() + payload, err := proto.Marshal(normalized) + if err != nil { + return fleeterror.NewInternalErrorf("marshal discover payload: %v", err) + } + + ctx, cancel := context.WithTimeout(ctx, DiscoverCommandTimeout) + defer cancel() + + session, err := s.registry.Send(ctx, fleetNodeID, &gatewaypb.ControlCommand{ + CommandId: commandID, + Payload: payload, + }, buildReportScope(normalized)) + if err != nil { + if errors.Is(err, control.ErrNoActiveStream) { + return fleeterror.NewFailedPreconditionError("fleet node has no active control stream") + } + return err + } + defer session.Close() + + // terminal=true stops the loop whether or not err is set — an OK/PARTIAL ack + // is terminal with a nil err. + handleEvent := func(ev control.CommandEvent) (terminal bool, err error) { + switch { + case ev.Batch != nil: + if sendErr := onBatch(ev.Batch); sendErr != nil { + return true, sendErr + } + return false, nil + case ev.Ack != nil: + // PARTIAL carries succeeded=false but its reports already streamed; + // treat it as a usable (incomplete) result, not a failure. + if ev.Ack.GetCode() == gatewaypb.AckCode_ACK_CODE_PARTIAL { + slog.Warn("fleet node discovery completed partially", + "fleet_node_id", fleetNodeID, "detail", ev.Ack.GetErrorMessage()) + return true, nil + } + // Require the structured OK code, not just the boolean, so an + // inconsistent ack (succeeded=true with a non-OK/unset code) can't + // pass a failed scan off as success. + if ev.Ack.GetCode() != gatewaypb.AckCode_ACK_CODE_OK || !ev.Ack.GetSucceeded() { + return true, discoverAckFailure(ev.Ack) + } + return true, nil + default: + return false, nil + } + } + + events := session.Events() + for { + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + return connect.NewError(connect.CodeDeadlineExceeded, fmt.Errorf("discovery command timed out after %s", DiscoverCommandTimeout)) + } + // Caller (operator or fan-out) cancelled; report it as such rather + // than a server-side Internal failure. + return fleeterror.NewCanceledError() + case ev := <-events: + if terminal, err := handleEvent(ev); terminal { + return err + } + case <-session.Done(): + // Stream died before an ack. Drain buffered events first (a final + // ack or last batch) so select randomness doesn't drop them. + for { + select { + case ev := <-events: + if terminal, err := handleEvent(ev); terminal { + return err + } + default: + return fleeterror.NewFailedPreconditionError("fleet node control stream closed before command completed") + } + } + } + } +} + +// discoverAckFailure maps a non-OK ack to an operator-facing error, even when +// error_message is empty. The structured AckCode drives the gRPC code so the +// operator can tell a retryable condition (BUSY) and a capability gap +// (AGENT_INCAPABLE) apart from a malformed request (BAD_REQUEST); anything else +// is an opaque Internal failure. +func discoverAckFailure(ack *gatewaypb.ControlAck) error { + reason := ack.GetErrorMessage() + if reason == "" { + reason = "code " + ack.GetCode().String() + } + // if/else (not switch) so the exhaustive linter doesn't demand a case per + // AckCode; everything outside these three is an opaque Internal failure. + code := ack.GetCode() + if code == gatewaypb.AckCode_ACK_CODE_BAD_REQUEST { + return fleeterror.NewInvalidArgumentErrorf("fleet node rejected discovery command: %s", reason) + } + if code == gatewaypb.AckCode_ACK_CODE_BUSY { + return fleeterror.NewPlainError( + fmt.Sprintf("fleet node is busy with another command; retry shortly: %s", reason), + connect.CodeResourceExhausted, + ) + } + if code == gatewaypb.AckCode_ACK_CODE_AGENT_INCAPABLE { + return fleeterror.NewFailedPreconditionErrorf("fleet node cannot service this discovery request; try another node: %s", reason) + } + return fleeterror.NewInternalErrorf("fleet node reported discovery failure: %s", reason) +} + +func normalizeDiscoverRequest(in *pairingpb.DiscoverRequest) (*pairingpb.DiscoverRequest, error) { + switch m := in.GetMode().(type) { + case *pairingpb.DiscoverRequest_IpList: + if m.IpList == nil || len(m.IpList.GetIpAddresses()) == 0 { + return nil, fleeterror.NewInvalidArgumentError("ip_list.ip_addresses must not be empty") + } + if err := checkScanLimits(m.IpList.GetIpAddresses(), m.IpList.GetPorts()); err != nil { + return nil, err + } + // Every entry must be a valid IP or hostname, and IP literals must be + // private. A malformed token (e.g. "bad/entry") is unresolvable for the + // agent yet trips the scope matcher's hostname fallback, widening the + // command to port-only scope. A public literal scans fine but every report + // is rejected by validateReport (private-only), surfacing as a late + // REPORT_FAILED. Hostnames resolve agent-side to an IP the server can't + // check here, so they pass through. + for _, e := range m.IpList.GetIpAddresses() { + addr, perr := netip.ParseAddr(e) + if perr != nil { + if !nmaptarget.IsHostname(e) { + return nil, fleeterror.NewInvalidArgumentErrorf("ip_list entry %q is not a valid IP address or hostname", e) + } + continue + } + if !addr.Unmap().IsPrivate() { + return nil, fleeterror.NewInvalidArgumentErrorf("ip_list entry %q is not a private (RFC1918/RFC4193) address", e) + } + } + return in, nil + case *pairingpb.DiscoverRequest_IpRange: + ips, err := expandIPv4Range(m.IpRange.GetStartIp(), m.IpRange.GetEndIp()) + if err != nil { + return nil, err + } + if err := checkScanLimits(ips, m.IpRange.GetPorts()); err != nil { + return nil, err + } + return &pairingpb.DiscoverRequest{ + Mode: &pairingpb.DiscoverRequest_IpList{ + IpList: &pairingpb.IPListModeRequest{ + IpAddresses: ips, + Ports: m.IpRange.GetPorts(), + }, + }, + }, nil + case *pairingpb.DiscoverRequest_Nmap: + target := m.Nmap.GetTarget() + // The LocalSubnetTarget sentinel defers the target to the agent (it scans + // its own private subnet(s)), so there is nothing to validate here; the + // report scope (buildReportScope) and validateReport still confine reports + // to private addresses. + if target == nmaptarget.LocalSubnetTarget { + if err := checkScanLimits(nil, m.Nmap.GetPorts()); err != nil { + return nil, err + } + return in, nil + } + // Validate against the shared grammar (incl. the /22 CIDR cap), then + // reject IPv6 CIDR — both rejections the agent makes — so an unsupported + // target fails fast here instead of as a late agent BAD_REQUEST ack. + if err := nmaptarget.Validate(target); err != nil { + return nil, fleeterror.NewInvalidArgumentError(err.Error()) + } + if prefix, perr := netip.ParsePrefix(target); perr == nil && prefix.Addr().Is6() { + return nil, fleeterror.NewInvalidArgumentError("nmap IPv6 CIDR is not supported; use ip_list for IPv6 devices") + } + // A public target scans fine but every report comes back non-private and + // is rejected by validateReport, so fail fast. Hostnames resolve agent-side + // and pass through (the report validator still guards what they return). + if !nmapTargetIsPrivate(target) { + return nil, fleeterror.NewInvalidArgumentError("nmap target must be within a private (RFC1918/RFC4193) range") + } + if err := checkScanLimits(nil, m.Nmap.GetPorts()); err != nil { + return nil, err + } + return in, nil + case *pairingpb.DiscoverRequest_Mdns: + return nil, fleeterror.NewInvalidArgumentError("mdns discovery is not supported on fleet nodes") + default: + return nil, fleeterror.NewInvalidArgumentError("discover request mode is required") + } +} + +// checkScanLimits enforces the agent's per-command caps (via discoverylimits) +// and rejects malformed ports before dispatch, so an over-cap or invalid request +// fails fast with a validation error instead of a late agent BAD_REQUEST ack. +// The proto caps are the wire ceiling; these are the real limits. +func checkScanLimits(ipAddresses, ports []string) error { + if len(ipAddresses) > discoverylimits.MaxScanTargets { + return fleeterror.NewInvalidArgumentErrorf("too many targets: %d exceeds the limit of %d", len(ipAddresses), discoverylimits.MaxScanTargets) + } + if len(ports) > discoverylimits.MaxPortsPerIP { + return fleeterror.NewInvalidArgumentErrorf("too many ports: %d exceeds the limit of %d", len(ports), discoverylimits.MaxPortsPerIP) + } + // Each port must be a bare decimal in 1-65535, matching the agent's + // resolveAndValidatePorts; otherwise a token like "80/tcp" or "70000" + // dispatches and returns as a late agent BAD_REQUEST ack. + for _, p := range ports { + if n, err := strconv.Atoi(p); err != nil || n < 1 || n > 65535 { + return fleeterror.NewInvalidArgumentErrorf("invalid port %q: must be a decimal in 1-65535", p) + } + } + return nil +} + +func expandIPv4Range(startStr, endStr string) ([]string, error) { + startAddr, err := netutil.ParseIPv4(startStr) + if err != nil { + return nil, fleeterror.NewInvalidArgumentErrorf("invalid start_ip: %v", err) + } + endAddr, err := netutil.ParseIPv4(endStr) + if err != nil { + return nil, fleeterror.NewInvalidArgumentErrorf("invalid end_ip: %v", err) + } + // Both ends must be private. The MaxScanTargets cap below keeps the range far + // smaller than the gap between RFC1918 blocks, so private endpoints imply a + // fully private range. A public range scans fine but every report is rejected + // by validateReport, surfacing as a late REPORT_FAILED. + if !startAddr.IsPrivate() || !endAddr.IsPrivate() { + return nil, fleeterror.NewInvalidArgumentError("ip range must be within a private (RFC1918) range") + } + start, end := netutil.IPv4ToUint32(startAddr), netutil.IPv4ToUint32(endAddr) + if end < start { + return nil, fleeterror.NewInvalidArgumentError("end_ip must be >= start_ip") + } + // Skip the network (.0) and gateway (.1) start addresses, matching the agent + // and cloud pairing. Otherwise expanding to an IP list would scan .0/.1 as + // literal targets — gateways answer on many ports and look like miners. + start = netutil.AdjustIPv4RangeStart(start) + if end < start { + return nil, fleeterror.NewInvalidArgumentError("ip range covers only network/gateway addresses") + } + // uint64 math so a range ending at 255.255.255.255 can't wrap (in uint32, + // end-start+1 would overflow to 0, bypassing the cap and never terminating). + size := uint64(end) - uint64(start) + 1 + if size > discoverylimits.MaxScanTargets { + return nil, fleeterror.NewInvalidArgumentErrorf("ip range exceeds %d addresses", discoverylimits.MaxScanTargets) + } + out := make([]string, 0, size) + for v := start; ; v++ { + out = append(out, netutil.Uint32ToIPv4(v)) + if v == end { + break + } + } + return out, nil +} diff --git a/server/internal/domain/fleetnode/discovery/service_test.go b/server/internal/domain/fleetnode/discovery/service_test.go new file mode 100644 index 000000000..e6cfe7a75 --- /dev/null +++ b/server/internal/domain/fleetnode/discovery/service_test.go @@ -0,0 +1,229 @@ +package discovery + +import ( + "context" + "errors" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + gatewaypb "github.com/block/proto-fleet/server/generated/grpc/fleetnodegateway/v1" + pairingpb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" + "github.com/block/proto-fleet/server/internal/domain/fleeterror" + "github.com/block/proto-fleet/server/internal/domain/fleetnode/control" + "github.com/block/proto-fleet/server/internal/domain/fleetnode/enrollment" + "github.com/block/proto-fleet/server/internal/domain/nmaptarget" +) + +type stubLister struct { + nodes []enrollment.FleetNodeListing + err error +} + +func (s stubLister) ListFleetNodes(context.Context, int64) ([]enrollment.FleetNodeListing, error) { + return s.nodes, s.err +} + +func collectBatches(dst *[]*pairingpb.Device) func(*pairingpb.DiscoverResponse) error { + return func(b *pairingpb.DiscoverResponse) error { + *dst = append(*dst, b.GetDevices()...) + return nil + } +} + +func TestRunOnNode_ForwardsBatchesUntilAck(t *testing.T) { + // Arrange + reg := control.NewRegistry() + svc := NewService(reg, stubLister{}) + const nodeID = int64(7) + stream := reg.Register(nodeID) + defer stream.Unregister() + go func() { + cmd := <-stream.Outgoing + reg.PublishBatch(nodeID, cmd.GetCommandId(), &pairingpb.DiscoverResponse{ + Devices: []*pairingpb.Device{{DeviceIdentifier: "auto:1", IpAddress: "10.0.0.5", Port: "4028"}}, + }) + stream.PublishAck(&gatewaypb.ControlAck{CommandId: cmd.GetCommandId(), Succeeded: true, Code: gatewaypb.AckCode_ACK_CODE_OK}) + }() + var got []*pairingpb.Device + + // Act + err := svc.RunOnNode(context.Background(), nodeID, ipListReq([]string{"10.0.0.5"}, []string{"4028"}), collectBatches(&got)) + + // Assert + require.NoError(t, err) + require.Len(t, got, 1) + assert.Equal(t, "auto:1", got[0].GetDeviceIdentifier()) +} + +func TestRunOnNode_PartialAckIsNotFailure(t *testing.T) { + // Arrange + reg := control.NewRegistry() + svc := NewService(reg, stubLister{}) + const nodeID = int64(8) + stream := reg.Register(nodeID) + defer stream.Unregister() + go func() { + cmd := <-stream.Outgoing + reg.PublishBatch(nodeID, cmd.GetCommandId(), &pairingpb.DiscoverResponse{ + Devices: []*pairingpb.Device{{DeviceIdentifier: "auto:2", IpAddress: "10.0.0.6", Port: "4028"}}, + }) + stream.PublishAck(&gatewaypb.ControlAck{CommandId: cmd.GetCommandId(), Code: gatewaypb.AckCode_ACK_CODE_PARTIAL}) + }() + var got []*pairingpb.Device + + // Act + err := svc.RunOnNode(context.Background(), nodeID, ipListReq([]string{"10.0.0.6"}, []string{"4028"}), collectBatches(&got)) + + // Assert: PARTIAL is a usable (incomplete) result, and its batch still streamed. + require.NoError(t, err) + require.Len(t, got, 1) +} + +func TestRunOnNode_DisconnectBeforeAckReturnsError(t *testing.T) { + // Arrange: agent takes the command then drops without acking. + reg := control.NewRegistry() + svc := NewService(reg, stubLister{}) + const nodeID = int64(9) + stream := reg.Register(nodeID) + go func() { + <-stream.Outgoing + stream.Unregister() + }() + + // Act + err := svc.RunOnNode(context.Background(), nodeID, ipListReq([]string{"10.0.0.7"}, nil), func(*pairingpb.DiscoverResponse) error { return nil }) + + // Assert + require.Error(t, err) + var fe fleeterror.FleetError + require.ErrorAs(t, err, &fe) + assert.Equal(t, connect.CodeFailedPrecondition, fe.ConnectError().Code()) +} + +func TestRunOnNode_NoActiveStreamReturnsFailedPrecondition(t *testing.T) { + // Arrange: no stream registered for the target node. + reg := control.NewRegistry() + svc := NewService(reg, stubLister{}) + + // Act + err := svc.RunOnNode(context.Background(), 404, ipListReq([]string{"10.0.0.8"}, nil), func(*pairingpb.DiscoverResponse) error { return nil }) + + // Assert + require.Error(t, err) + var fe fleeterror.FleetError + require.ErrorAs(t, err, &fe) + assert.Equal(t, connect.CodeFailedPrecondition, fe.ConnectError().Code()) +} + +func TestConfirmedConnectedNodeIDs_IntersectsStatusAndConnection(t *testing.T) { + // Arrange: 1 = confirmed+connected, 2 = confirmed+disconnected, 3 = pending+connected. + reg := control.NewRegistry() + lister := stubLister{nodes: []enrollment.FleetNodeListing{ + {FleetNode: enrollment.FleetNode{ID: 1, EnrollmentStatus: enrollment.FleetNodeStatusConfirmed}}, + {FleetNode: enrollment.FleetNode{ID: 2, EnrollmentStatus: enrollment.FleetNodeStatusConfirmed}}, + {FleetNode: enrollment.FleetNode{ID: 3, EnrollmentStatus: enrollment.FleetNodeStatusPending}}, + }} + svc := NewService(reg, lister) + s1 := reg.Register(1) + defer s1.Unregister() + s3 := reg.Register(3) + defer s3.Unregister() + + // Act + got, err := svc.ConfirmedConnectedNodeIDs(context.Background(), 1) + + // Assert: only the confirmed AND connected node (order is unspecified). + require.NoError(t, err) + assert.ElementsMatch(t, []int64{1}, got) +} + +func TestRunOnNode_OnBatchErrorIsTerminal(t *testing.T) { + // Arrange: the agent emits a batch; the caller's onBatch reports its stream gone. + reg := control.NewRegistry() + svc := NewService(reg, stubLister{}) + const nodeID = int64(11) + stream := reg.Register(nodeID) + defer stream.Unregister() + go func() { + cmd := <-stream.Outgoing + reg.PublishBatch(nodeID, cmd.GetCommandId(), &pairingpb.DiscoverResponse{ + Devices: []*pairingpb.Device{{DeviceIdentifier: "auto:x", IpAddress: "10.0.0.5", Port: "4028"}}, + }) + }() + sentinel := errors.New("operator stream gone") + + // Act + err := svc.RunOnNode(context.Background(), nodeID, ipListReq([]string{"10.0.0.5"}, []string{"4028"}), func(*pairingpb.DiscoverResponse) error { + return sentinel + }) + + // Assert: an onBatch error terminates RunOnNode with that error. + require.ErrorIs(t, err, sentinel) +} + +func TestRunOnNode_TimesOutWhenAgentNeverAcks(t *testing.T) { + // Arrange: shrink the command timeout, drain the command, but never ack. + prev := DiscoverCommandTimeout + DiscoverCommandTimeout = 100 * time.Millisecond + t.Cleanup(func() { DiscoverCommandTimeout = prev }) + reg := control.NewRegistry() + svc := NewService(reg, stubLister{}) + const nodeID = int64(12) + stream := reg.Register(nodeID) + defer stream.Unregister() + go func() { <-stream.Outgoing }() + + // Act + err := svc.RunOnNode(context.Background(), nodeID, ipListReq([]string{"10.0.0.5"}, nil), func(*pairingpb.DiscoverResponse) error { return nil }) + + // Assert + require.Error(t, err) + var ce *connect.Error + require.True(t, errors.As(err, &ce)) + assert.Equal(t, connect.CodeDeadlineExceeded, ce.Code()) +} + +func TestConfirmedConnectedNodeIDs_PropagatesListError(t *testing.T) { + // Arrange: the enrollment lookup fails. + reg := control.NewRegistry() + svc := NewService(reg, stubLister{err: errors.New("db unavailable")}) + + // Act + _, err := svc.ConfirmedConnectedNodeIDs(context.Background(), 1) + + // Assert: the error propagates (fan-out treats it as best-effort upstream). + require.Error(t, err) +} + +func TestRunOnNode_DispatchesLocalSubnetTargetSentinel(t *testing.T) { + // Arrange: the LocalSubnetTarget sentinel (operator single-node scan or fan-out) + // must reach the agent unchanged so the node scans its own subnet. + reg := control.NewRegistry() + svc := NewService(reg, stubLister{}) + const nodeID = int64(21) + stream := reg.Register(nodeID) + defer stream.Unregister() + gotTarget := make(chan string, 1) + go func() { + cmd := <-stream.Outgoing + var req pairingpb.DiscoverRequest + _ = proto.Unmarshal(cmd.GetPayload(), &req) + gotTarget <- req.GetNmap().GetTarget() + stream.PublishAck(&gatewaypb.ControlAck{CommandId: cmd.GetCommandId(), Succeeded: true, Code: gatewaypb.AckCode_ACK_CODE_OK}) + }() + req := &pairingpb.DiscoverRequest{Mode: &pairingpb.DiscoverRequest_Nmap{ + Nmap: &pairingpb.NmapModeRequest{Target: nmaptarget.LocalSubnetTarget}, + }} + + // Act + err := svc.RunOnNode(context.Background(), nodeID, req, func(*pairingpb.DiscoverResponse) error { return nil }) + + // Assert + require.NoError(t, err) + assert.Equal(t, nmaptarget.LocalSubnetTarget, <-gotTarget) +} diff --git a/server/internal/domain/nmaptarget/nmaptarget.go b/server/internal/domain/nmaptarget/nmaptarget.go index cbd3aab7b..7a61781c5 100644 --- a/server/internal/domain/nmaptarget/nmaptarget.go +++ b/server/internal/domain/nmaptarget/nmaptarget.go @@ -29,6 +29,13 @@ var ( // multi-hour scan at the public internet. Operators split larger scopes. const MinIPv4PrefixBits = 22 +// LocalSubnetTarget is a reserved nmap target value meaning "the fleet node +// should scan the private (RFC1918) IPv4 subnet(s) of its own interfaces." The +// cloud fan-out sets it because it cannot know each node's local network. It is +// matched exactly (before any hostname resolution); the dashed, namespaced name +// keeps it from colliding with a real hostname. +const LocalSubnetTarget = "fleetnode-local-subnet" + // IsIPv4Range reports whether s is an "A.B.C.D-N" nmap range. func IsIPv4Range(s string) bool { return ipv4RangeRE.MatchString(s) } diff --git a/server/internal/domain/pairing/service.go b/server/internal/domain/pairing/service.go index ac7830f69..1b7766fcf 100644 --- a/server/internal/domain/pairing/service.go +++ b/server/internal/domain/pairing/service.go @@ -98,6 +98,17 @@ func shouldSkipNetworkOrGatewayAddress(ip net.IP) bool { return lastOctet == networkAddressLastOctet || lastOctet == gatewayAddressLastOctet } +// DeviceDedupKey is the identity used to dedupe discovered devices: the +// device_identifier, or "ip:port" when the plugin hasn't resolved one yet. +// Exported so the Discover handler's cross-source dedup stays in lockstep with +// dedupeDiscoverResponses instead of re-deriving the key. +func DeviceDedupKey(d *pb.Device) string { + if id := d.GetDeviceIdentifier(); id != "" { + return id + } + return d.GetIpAddress() + ":" + d.GetPort() +} + func dedupeDiscoverResponses(source <-chan *pb.DiscoverResponse) <-chan *pb.DiscoverResponse { resultChan := make(chan *pb.DiscoverResponse) @@ -117,10 +128,7 @@ func dedupeDiscoverResponses(source <-chan *pb.DiscoverResponse) <-chan *pb.Disc continue } - identity := device.DeviceIdentifier - if identity == "" { - identity = fmt.Sprintf("%s:%s", device.IpAddress, device.Port) - } + identity := DeviceDedupKey(device) if _, alreadySeen := seenDevices[identity]; alreadySeen { continue @@ -289,15 +297,19 @@ func mergeAutoDiscoveryTargets(baseTarget string, knownSubnets []string) []strin return targets } -func (s *Service) resolveNmapTargets(ctx context.Context, target string) ([]string, error) { - targets := []string{target} +// resolveNmapTargets returns the scan targets and whether `target` is the cloud +// host's own local subnet (isLocalSubnet) — the same condition that drives +// known-subnet expansion. Callers reuse isLocalSubnet to decide fleet-node +// fan-out without recomputing the local network. +func (s *Service) resolveNmapTargets(ctx context.Context, target string) (targets []string, isLocalSubnet bool, err error) { + targets = []string{target} localNetworkInfo, err := s.GetLocalNetworkInfo(ctx) if err != nil { slog.Debug("Skipping known-subnet expansion for nmap discovery because local network info is unavailable", "target", target, "error", err) - return targets, nil + return targets, false, nil } maskBits, shouldExpand := maskBitsForLocalSubnetTarget(target, localNetworkInfo.Subnet) @@ -305,14 +317,14 @@ func (s *Service) resolveNmapTargets(ctx context.Context, target string) ([]stri slog.Debug("Skipping known-subnet expansion because target does not match local subnet", "target", target, "local_subnet", localNetworkInfo.Subnet) - return targets, nil + return targets, false, nil } // Subnet expansion only runs for IPv4 targets matching the local subnet // (the guard above ensures this). Pass isIPv4=true directly. info, err := session.GetInfo(ctx) if err != nil { - return nil, err + return nil, false, err } knownSubnets, err := s.deviceStore.GetKnownSubnets(ctx, info.OrganizationID, maskBits, true) @@ -320,7 +332,7 @@ func (s *Service) resolveNmapTargets(ctx context.Context, target string) ([]stri slog.Debug("Skipping known-subnet expansion because subnet query failed", "target", target, "error", err) - return targets, nil + return targets, true, nil } expandedTargets := mergeAutoDiscoveryTargets(target, knownSubnets) @@ -331,7 +343,7 @@ func (s *Service) resolveNmapTargets(ctx context.Context, target string) ([]stri "organization_id", info.OrganizationID) } - return expandedTargets, nil + return expandedTargets, true, nil } // validateNmapTargets validates targets and resolves hostnames to IP literals @@ -471,18 +483,20 @@ func (s *Service) DiscoverWithMDNS(ctx context.Context, r *pb.MDNSModeRequest) ( return resultChan, nil } -// DiscoverWithNmap discovers devices using Nmap -func (s *Service) DiscoverWithNmap(ctx context.Context, r *pb.NmapModeRequest) (<-chan *pb.DiscoverResponse, error) { +// DiscoverWithNmap discovers devices using Nmap. isLocalSubnet reports whether +// the target is the cloud host's own local subnet (the "Scan your network" +// action), which the Discover handler uses to gate fleet-node fan-out. +func (s *Service) DiscoverWithNmap(ctx context.Context, r *pb.NmapModeRequest) (results <-chan *pb.DiscoverResponse, isLocalSubnet bool, err error) { if r.Target == "" { - return nil, fleeterror.NewInvalidArgumentError("nmap discovery target is required") + return nil, false, fleeterror.NewInvalidArgumentError("nmap discovery target is required") } ports, err := s.resolveDiscoveryPorts(ctx, r.Ports) if err != nil { - return nil, err + return nil, false, err } - targets, err := s.resolveNmapTargets(ctx, r.Target) + targets, isLocalSubnet, err := s.resolveNmapTargets(ctx, r.Target) if err != nil { - return nil, err + return nil, false, err } // Apply server-controlled timeout before any DNS work so hostname @@ -492,7 +506,7 @@ func (s *Service) DiscoverWithNmap(ctx context.Context, r *pb.NmapModeRequest) ( targets, useIPv6Scanning, err := validateNmapTargets(timeoutCtx, targets, net.DefaultResolver.LookupIPAddr) if err != nil { cancel() - return nil, err + return nil, false, err } // Create channels after validation to avoid leaking the dedupe goroutine on early returns. @@ -658,7 +672,7 @@ func (s *Service) DiscoverWithNmap(ctx context.Context, r *pb.NmapModeRequest) ( wg.Wait() }() - return resultChan, nil + return resultChan, isLocalSubnet, nil } // DiscoverWithIPRange discovers devices using an IPv4 IP range. diff --git a/server/internal/domain/pairing/service_internal_test.go b/server/internal/domain/pairing/service_internal_test.go index 8fe6e5c17..2c2d29389 100644 --- a/server/internal/domain/pairing/service_internal_test.go +++ b/server/internal/domain/pairing/service_internal_test.go @@ -240,9 +240,10 @@ func TestResolveNmapTargets_ExpandsLocalSubnetWithKnownSubnets(t *testing.T) { GetKnownSubnets(gomock.Any(), int64(42), 24, true). Return([]string{"192.168.25.0/24", "192.168.1.0/24", "not-a-cidr"}, nil) - targets, err := service.resolveNmapTargets(ctx, "192.168.1.0/24") + targets, isLocalSubnet, err := service.resolveNmapTargets(ctx, "192.168.1.0/24") require.NoError(t, err) require.Equal(t, []string{"192.168.1.0/24", "192.168.25.0/24"}, targets) + require.True(t, isLocalSubnet) } func TestResolveNmapTargets_SkipsExpansionForNonLocalTargets(t *testing.T) { @@ -259,9 +260,10 @@ func TestResolveNmapTargets_SkipsExpansionForNonLocalTargets(t *testing.T) { ctx := mockSessionContext(t.Context(), 1, 42) - targets, err := service.resolveNmapTargets(ctx, "192.168.25.0/24") + targets, isLocalSubnet, err := service.resolveNmapTargets(ctx, "192.168.25.0/24") require.NoError(t, err) require.Equal(t, []string{"192.168.25.0/24"}, targets) + require.False(t, isLocalSubnet) } func TestResolveNmapTargets_FallsBackWhenLocalNetworkInfoFails(t *testing.T) { @@ -278,9 +280,10 @@ func TestResolveNmapTargets_FallsBackWhenLocalNetworkInfoFails(t *testing.T) { ctx := mockSessionContext(t.Context(), 1, 42) - targets, err := service.resolveNmapTargets(ctx, "192.168.1.0/24") + targets, isLocalSubnet, err := service.resolveNmapTargets(ctx, "192.168.1.0/24") require.NoError(t, err) require.Equal(t, []string{"192.168.1.0/24"}, targets) + require.False(t, isLocalSubnet, "no local network info means we can't confirm a local-subnet scan") } func TestResolveNmapTargets_DoesNotExpandIPv6Targets(t *testing.T) { @@ -302,9 +305,10 @@ func TestResolveNmapTargets_DoesNotExpandIPv6Targets(t *testing.T) { // IPv6 targets should not be auto-expanded because IPv6 subnets are // too large for nmap sweeps. - targets, err := service.resolveNmapTargets(ctx, "fd00::/64") + targets, isLocalSubnet, err := service.resolveNmapTargets(ctx, "fd00::/64") require.NoError(t, err) require.Equal(t, []string{"fd00::/64"}, targets) + require.False(t, isLocalSubnet, "IPv6 target is never treated as the local-subnet scan") } func TestValidateNmapTargets(t *testing.T) { diff --git a/server/internal/handlers/fleetnode/admin/handler.go b/server/internal/handlers/fleetnode/admin/handler.go index 2f2898fa8..715a9ef06 100644 --- a/server/internal/handlers/fleetnode/admin/handler.go +++ b/server/internal/handlers/fleetnode/admin/handler.go @@ -2,31 +2,19 @@ package admin import ( "context" - "errors" - "fmt" - "log/slog" - "net/netip" - "strconv" - "time" "connectrpc.com/connect" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" pb "github.com/block/proto-fleet/server/generated/grpc/fleetnodeadmin/v1" "github.com/block/proto-fleet/server/generated/grpc/fleetnodeadmin/v1/fleetnodeadminv1connect" - gatewaypb "github.com/block/proto-fleet/server/generated/grpc/fleetnodegateway/v1" pairingpb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" "github.com/block/proto-fleet/server/internal/domain/authz" - "github.com/block/proto-fleet/server/internal/domain/discoverylimits" "github.com/block/proto-fleet/server/internal/domain/fleeterror" - "github.com/block/proto-fleet/server/internal/domain/fleetnode/control" + "github.com/block/proto-fleet/server/internal/domain/fleetnode/discovery" "github.com/block/proto-fleet/server/internal/domain/fleetnode/enrollment" "github.com/block/proto-fleet/server/internal/domain/fleetnode/pairing" - "github.com/block/proto-fleet/server/internal/domain/netutil" - "github.com/block/proto-fleet/server/internal/domain/nmaptarget" "github.com/block/proto-fleet/server/internal/handlers/middleware" - "github.com/block/proto-fleet/server/internal/infrastructure/id" ) type Handler struct { @@ -34,13 +22,13 @@ type Handler struct { enrollment *enrollment.Service pairing *pairing.Service - registry *control.Registry + discovery *discovery.Service } var _ fleetnodeadminv1connect.FleetNodeAdminServiceHandler = &Handler{} -func NewHandler(enrollment *enrollment.Service, pairing *pairing.Service, registry *control.Registry) *Handler { - return &Handler{enrollment: enrollment, pairing: pairing, registry: registry} +func NewHandler(enrollment *enrollment.Service, pairing *pairing.Service, discoverySvc *discovery.Service) *Handler { + return &Handler{enrollment: enrollment, pairing: pairing, discovery: discoverySvc} } func (h *Handler) CreateEnrollmentCode(ctx context.Context, _ *connect.Request[pb.CreateEnrollmentCodeRequest]) (*connect.Response[pb.CreateEnrollmentCodeResponse], error) { @@ -165,14 +153,9 @@ func (h *Handler) ListFleetNodeDevices(ctx context.Context, req *connect.Request return connect.NewResponse(resp), nil } -// DiscoverCommandTimeout bounds how long DiscoverOnFleetNode waits for the -// agent's batches and ack, so a silent node can't pin operator streams and -// registry slots. Must exceed the agent's scan budget (commandTimeout, 10m) -// plus report/ack slack: too short frees the slot mid-scan, the agent's ack is -// rejected as stale, and a new command dispatches while the node is still busy. -// Var for tests. -var DiscoverCommandTimeout = 12 * time.Minute - +// DiscoverOnFleetNode runs discovery on a single CONFIRMED node and streams the +// node's device batches back to the operator. See discovery.RunOnNode for the +// dispatch/drain loop. func (h *Handler) DiscoverOnFleetNode(ctx context.Context, req *connect.Request[pb.DiscoverOnFleetNodeRequest], stream *connect.ServerStream[pb.DiscoverOnFleetNodeResponse]) error { info, err := middleware.RequirePermission(ctx, authz.PermFleetnodeManage, authz.ResourceContext{}) if err != nil { @@ -195,254 +178,12 @@ func (h *Handler) DiscoverOnFleetNode(ctx context.Context, req *connect.Request[ return fleeterror.NewFailedPreconditionError("fleet node is not CONFIRMED") } - normalized, err := normalizeDiscoverRequest(discoverReq) - if err != nil { - return err - } - - commandID := id.GenerateID() - payload, err := proto.Marshal(normalized) - if err != nil { - return fleeterror.NewInternalErrorf("marshal discover payload: %v", err) - } - - ctx, cancel := context.WithTimeout(ctx, DiscoverCommandTimeout) - defer cancel() - - session, err := h.registry.Send(ctx, fleetNodeID, &gatewaypb.ControlCommand{ - CommandId: commandID, - Payload: payload, - }, buildReportScope(normalized)) - if err != nil { - if errors.Is(err, control.ErrNoActiveStream) { - return fleeterror.NewFailedPreconditionError("fleet node has no active control stream") - } - return err - } - defer session.Close() - - // handleEvent forwards a batch or resolves the command on an ack. terminal - // reports whether the command is finished; err is set only on failure. - handleEvent := func(ev control.CommandEvent) (terminal bool, err error) { - switch { - case ev.Batch != nil: - if sendErr := stream.Send(&pb.DiscoverOnFleetNodeResponse{Response: ev.Batch}); sendErr != nil { - return true, fleeterror.NewInternalErrorf("send batch to operator: %v", sendErr) - } - return false, nil - case ev.Ack != nil: - // PARTIAL carries succeeded=false but its reports already streamed; - // treat it as a usable (incomplete) result, not a failure. - if ev.Ack.GetCode() == gatewaypb.AckCode_ACK_CODE_PARTIAL { - slog.Warn("fleet node discovery completed partially", - "fleet_node_id", fleetNodeID, "detail", ev.Ack.GetErrorMessage()) - return true, nil - } - // Require the structured OK code, not just the boolean, so an - // inconsistent ack (succeeded=true with a non-OK/unset code) can't - // pass a failed scan off as success. - if ev.Ack.GetCode() != gatewaypb.AckCode_ACK_CODE_OK || !ev.Ack.GetSucceeded() { - return true, discoverAckFailure(ev.Ack) - } - return true, nil - default: - return false, nil - } - } - - events := session.Events() - for { - select { - case <-ctx.Done(): - if errors.Is(ctx.Err(), context.DeadlineExceeded) { - return connect.NewError(connect.CodeDeadlineExceeded, fmt.Errorf("discovery command timed out after %s", DiscoverCommandTimeout)) - } - return fleeterror.NewInternalErrorf("operator stream cancelled: %v", ctx.Err()) - case ev := <-events: - if terminal, err := handleEvent(ev); terminal { - return err - } - case <-session.Done(): - // Stream died before an ack. Drain buffered events first (a final - // ack or last batch) so select randomness doesn't drop them. - for { - select { - case ev := <-events: - if terminal, err := handleEvent(ev); terminal { - return err - } - default: - return fleeterror.NewFailedPreconditionError("fleet node control stream closed before command completed") - } - } - } - } -} - -// discoverAckFailure maps a non-OK ack to an operator-facing error, even when -// error_message is empty. The structured AckCode drives the gRPC code so the -// operator can tell a retryable condition (BUSY) and a capability gap -// (AGENT_INCAPABLE) apart from a malformed request (BAD_REQUEST); anything else -// is an opaque Internal failure. -func discoverAckFailure(ack *gatewaypb.ControlAck) error { - reason := ack.GetErrorMessage() - if reason == "" { - reason = "code " + ack.GetCode().String() - } - // if/else (not switch) so the exhaustive linter doesn't demand a case per - // AckCode; everything outside these three is an opaque Internal failure. - code := ack.GetCode() - if code == gatewaypb.AckCode_ACK_CODE_BAD_REQUEST { - return fleeterror.NewInvalidArgumentErrorf("fleet node rejected discovery command: %s", reason) - } - if code == gatewaypb.AckCode_ACK_CODE_BUSY { - return fleeterror.NewPlainError( - fmt.Sprintf("fleet node is busy with another command; retry shortly: %s", reason), - connect.CodeResourceExhausted, - ) - } - if code == gatewaypb.AckCode_ACK_CODE_AGENT_INCAPABLE { - return fleeterror.NewFailedPreconditionErrorf("fleet node cannot service this discovery request; try another node: %s", reason) - } - return fleeterror.NewInternalErrorf("fleet node reported discovery failure: %s", reason) -} - -func normalizeDiscoverRequest(in *pairingpb.DiscoverRequest) (*pairingpb.DiscoverRequest, error) { - switch m := in.GetMode().(type) { - case *pairingpb.DiscoverRequest_IpList: - if m.IpList == nil || len(m.IpList.GetIpAddresses()) == 0 { - return nil, fleeterror.NewInvalidArgumentError("ip_list.ip_addresses must not be empty") - } - if err := checkScanLimits(m.IpList.GetIpAddresses(), m.IpList.GetPorts()); err != nil { - return nil, err - } - // Every entry must be a valid IP or hostname, and IP literals must be - // private. A malformed token (e.g. "bad/entry") is unresolvable for the - // agent yet trips the scope matcher's hostname fallback, widening the - // command to port-only scope. A public literal scans fine but every report - // is rejected by validateReport (private-only), surfacing as a late - // REPORT_FAILED. Hostnames resolve agent-side to an IP the server can't - // check here, so they pass through. - for _, e := range m.IpList.GetIpAddresses() { - addr, perr := netip.ParseAddr(e) - if perr != nil { - if !nmaptarget.IsHostname(e) { - return nil, fleeterror.NewInvalidArgumentErrorf("ip_list entry %q is not a valid IP address or hostname", e) - } - continue - } - if !addr.Unmap().IsPrivate() { - return nil, fleeterror.NewInvalidArgumentErrorf("ip_list entry %q is not a private (RFC1918/RFC4193) address", e) - } - } - return in, nil - case *pairingpb.DiscoverRequest_IpRange: - ips, err := expandIPv4Range(m.IpRange.GetStartIp(), m.IpRange.GetEndIp()) - if err != nil { - return nil, err - } - if err := checkScanLimits(ips, m.IpRange.GetPorts()); err != nil { - return nil, err - } - return &pairingpb.DiscoverRequest{ - Mode: &pairingpb.DiscoverRequest_IpList{ - IpList: &pairingpb.IPListModeRequest{ - IpAddresses: ips, - Ports: m.IpRange.GetPorts(), - }, - }, - }, nil - case *pairingpb.DiscoverRequest_Nmap: - target := m.Nmap.GetTarget() - // Validate against the shared grammar (incl. the /22 CIDR cap), then - // reject IPv6 CIDR — both rejections the agent makes — so an unsupported - // target fails fast here instead of as a late agent BAD_REQUEST ack. - if err := nmaptarget.Validate(target); err != nil { - return nil, fleeterror.NewInvalidArgumentError(err.Error()) - } - if prefix, perr := netip.ParsePrefix(target); perr == nil && prefix.Addr().Is6() { - return nil, fleeterror.NewInvalidArgumentError("nmap IPv6 CIDR is not supported; use ip_list for IPv6 devices") - } - // A public target scans fine but every report comes back non-private and - // is rejected by validateReport, so fail fast. Hostnames resolve agent-side - // and pass through (the report validator still guards what they return). - if !nmapTargetIsPrivate(target) { - return nil, fleeterror.NewInvalidArgumentError("nmap target must be within a private (RFC1918/RFC4193) range") - } - if err := checkScanLimits(nil, m.Nmap.GetPorts()); err != nil { - return nil, err - } - return in, nil - case *pairingpb.DiscoverRequest_Mdns: - return nil, fleeterror.NewInvalidArgumentError("mdns discovery is not supported on fleet nodes") - default: - return nil, fleeterror.NewInvalidArgumentError("discover request mode is required") - } -} - -// checkScanLimits enforces the agent's per-command caps (via discoverylimits) -// and rejects malformed ports before dispatch, so an over-cap or invalid request -// fails fast with a validation error instead of a late agent BAD_REQUEST ack. -// The proto caps are the wire ceiling; these are the real limits. -func checkScanLimits(ipAddresses, ports []string) error { - if len(ipAddresses) > discoverylimits.MaxScanTargets { - return fleeterror.NewInvalidArgumentErrorf("too many targets: %d exceeds the limit of %d", len(ipAddresses), discoverylimits.MaxScanTargets) - } - if len(ports) > discoverylimits.MaxPortsPerIP { - return fleeterror.NewInvalidArgumentErrorf("too many ports: %d exceeds the limit of %d", len(ports), discoverylimits.MaxPortsPerIP) - } - // Each port must be a bare decimal in 1-65535, matching the agent's - // resolveAndValidatePorts; otherwise a token like "80/tcp" or "70000" - // dispatches and returns as a late agent BAD_REQUEST ack. - for _, p := range ports { - if n, err := strconv.Atoi(p); err != nil || n < 1 || n > 65535 { - return fleeterror.NewInvalidArgumentErrorf("invalid port %q: must be a decimal in 1-65535", p) + return h.discovery.RunOnNode(ctx, fleetNodeID, discoverReq, func(batch *pairingpb.DiscoverResponse) error { + if sendErr := stream.Send(&pb.DiscoverOnFleetNodeResponse{Response: batch}); sendErr != nil { + return fleeterror.NewInternalErrorf("send batch to operator: %v", sendErr) } - } - return nil -} - -func expandIPv4Range(startStr, endStr string) ([]string, error) { - startAddr, err := netutil.ParseIPv4(startStr) - if err != nil { - return nil, fleeterror.NewInvalidArgumentErrorf("invalid start_ip: %v", err) - } - endAddr, err := netutil.ParseIPv4(endStr) - if err != nil { - return nil, fleeterror.NewInvalidArgumentErrorf("invalid end_ip: %v", err) - } - // Both ends must be private. The MaxScanTargets cap below keeps the range far - // smaller than the gap between RFC1918 blocks, so private endpoints imply a - // fully private range. A public range scans fine but every report is rejected - // by validateReport, surfacing as a late REPORT_FAILED. - if !startAddr.IsPrivate() || !endAddr.IsPrivate() { - return nil, fleeterror.NewInvalidArgumentError("ip range must be within a private (RFC1918) range") - } - start, end := netutil.IPv4ToUint32(startAddr), netutil.IPv4ToUint32(endAddr) - if end < start { - return nil, fleeterror.NewInvalidArgumentError("end_ip must be >= start_ip") - } - // Skip the network (.0) and gateway (.1) start addresses, matching the agent - // and cloud pairing. Otherwise expanding to an IP list would scan .0/.1 as - // literal targets — gateways answer on many ports and look like miners. - start = netutil.AdjustIPv4RangeStart(start) - if end < start { - return nil, fleeterror.NewInvalidArgumentError("ip range covers only network/gateway addresses") - } - // uint64 math so a range ending at 255.255.255.255 can't wrap (in uint32, - // end-start+1 would overflow to 0, bypassing the cap and never terminating). - size := uint64(end) - uint64(start) + 1 - if size > discoverylimits.MaxScanTargets { - return nil, fleeterror.NewInvalidArgumentErrorf("ip range exceeds %d addresses", discoverylimits.MaxScanTargets) - } - out := make([]string, 0, size) - for v := start; ; v++ { - out = append(out, netutil.Uint32ToIPv4(v)) - if v == end { - break - } - } - return out, nil + return nil + }) } // AWAITING_CONFIRMATION lives only on pending_enrollment, so a PENDING fleet diff --git a/server/internal/handlers/fleetnode/admin/handler_discover_test.go b/server/internal/handlers/fleetnode/admin/handler_discover_test.go index 825365035..98bcf6323 100644 --- a/server/internal/handlers/fleetnode/admin/handler_discover_test.go +++ b/server/internal/handlers/fleetnode/admin/handler_discover_test.go @@ -19,8 +19,8 @@ import ( gatewaypb "github.com/block/proto-fleet/server/generated/grpc/fleetnodegateway/v1" pairingpb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" "github.com/block/proto-fleet/server/internal/domain/authz" + "github.com/block/proto-fleet/server/internal/domain/fleetnode/discovery" "github.com/block/proto-fleet/server/internal/domain/session" - "github.com/block/proto-fleet/server/internal/handlers/fleetnode/admin" "github.com/block/proto-fleet/server/internal/handlers/interceptors" "github.com/block/proto-fleet/server/internal/handlers/middleware" ) @@ -557,9 +557,9 @@ func TestDiscoverOnFleetNode_TimesOutWhenAgentNeverResponds(t *testing.T) { // Arrange: register an agent stream but never publish batch or ack. // Override DiscoverCommandTimeout to a short window so the test // terminates quickly. - prev := admin.DiscoverCommandTimeout - admin.DiscoverCommandTimeout = 200 * time.Millisecond - t.Cleanup(func() { admin.DiscoverCommandTimeout = prev }) + prev := discovery.DiscoverCommandTimeout + discovery.DiscoverCommandTimeout = 200 * time.Millisecond + t.Cleanup(func() { discovery.DiscoverCommandTimeout = prev }) h := newPairingHarness(t) fleetNodeID := h.createFleetNode(t, "admin-discover-timeout") diff --git a/server/internal/handlers/fleetnode/admin/handler_pairing_test.go b/server/internal/handlers/fleetnode/admin/handler_pairing_test.go index 4a3462b58..c6d4602b5 100644 --- a/server/internal/handlers/fleetnode/admin/handler_pairing_test.go +++ b/server/internal/handlers/fleetnode/admin/handler_pairing_test.go @@ -19,6 +19,7 @@ import ( "github.com/block/proto-fleet/server/internal/domain/authz" "github.com/block/proto-fleet/server/internal/domain/fleeterror" "github.com/block/proto-fleet/server/internal/domain/fleetnode/control" + "github.com/block/proto-fleet/server/internal/domain/fleetnode/discovery" "github.com/block/proto-fleet/server/internal/domain/fleetnode/enrollment" "github.com/block/proto-fleet/server/internal/domain/fleetnode/pairing" "github.com/block/proto-fleet/server/internal/domain/session" @@ -57,8 +58,9 @@ func newPairingHarness(t *testing.T) *pairingHarness { pairingSvc := pairing.NewService(pairingStore, enrollmentStore, transactor) registry := control.NewRegistry() + discoverySvc := discovery.NewService(registry, enrollmentSvc) return &pairingHarness{ - handler: admin.NewHandler(enrollmentSvc, pairingSvc, registry), + handler: admin.NewHandler(enrollmentSvc, pairingSvc, discoverySvc), db: db, orgID: 1, enrollment: enrollmentSvc, diff --git a/server/internal/handlers/pairing/forwarder.go b/server/internal/handlers/pairing/forwarder.go new file mode 100644 index 000000000..706063fda --- /dev/null +++ b/server/internal/handlers/pairing/forwarder.go @@ -0,0 +1,69 @@ +package pairing + +import ( + "sync" + + pb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" + "github.com/block/proto-fleet/server/internal/domain/pairing" +) + +// dedupForwarder serializes concurrent Discover sources (the cloud scan and each +// fleet node) onto one server stream — Connect streams are not safe for +// concurrent Send — and dedupes devices across sources by pairing.DeviceDedupKey. +// The first Send error is recorded and onErr is invoked once so the caller can +// cancel the remaining sources. +type dedupForwarder struct { + mu sync.Mutex + seen map[string]struct{} + send func(*pb.DiscoverResponse) error + onErr func() + err error +} + +func newDedupForwarder(send func(*pb.DiscoverResponse) error, onErr func()) *dedupForwarder { + return &dedupForwarder{seen: make(map[string]struct{}), send: send, onErr: onErr} +} + +// forward dedupes resp's devices across sources and forwards it. A batch reduced +// entirely to duplicates (with no error payload) is dropped. Once any Send has +// failed, forward returns that error without sending again. +func (f *dedupForwarder) forward(resp *pb.DiscoverResponse) error { + f.mu.Lock() + defer f.mu.Unlock() + if f.err != nil { + return f.err + } + out := resp + if len(resp.GetDevices()) > 0 { + deduped := make([]*pb.Device, 0, len(resp.GetDevices())) + for _, d := range resp.GetDevices() { + key := pairing.DeviceDedupKey(d) + if _, dup := f.seen[key]; dup { + continue + } + f.seen[key] = struct{}{} + deduped = append(deduped, d) + } + if len(deduped) == 0 && resp.GetError() == "" { + return nil // whole batch was duplicates; nothing to forward + } + if len(deduped) < len(resp.GetDevices()) { + out = &pb.DiscoverResponse{Devices: deduped, Error: resp.GetError()} + } + } + if err := f.send(out); err != nil { + f.err = err + if f.onErr != nil { + f.onErr() + } + return err + } + return nil +} + +// failure returns the first Send error, if any. +func (f *dedupForwarder) failure() error { + f.mu.Lock() + defer f.mu.Unlock() + return f.err +} diff --git a/server/internal/handlers/pairing/forwarder_test.go b/server/internal/handlers/pairing/forwarder_test.go new file mode 100644 index 000000000..fd836ab2b --- /dev/null +++ b/server/internal/handlers/pairing/forwarder_test.go @@ -0,0 +1,104 @@ +package pairing + +import ( + "errors" + "fmt" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + pb "github.com/block/proto-fleet/server/generated/grpc/pairing/v1" +) + +func dev(id, ip, port string) *pb.Device { + return &pb.Device{DeviceIdentifier: id, IpAddress: ip, Port: port} +} + +func TestDedupForwarder_DedupesAcrossSources(t *testing.T) { + // Arrange + var sent []*pb.DiscoverResponse + fwd := newDedupForwarder(func(r *pb.DiscoverResponse) error { sent = append(sent, r); return nil }, nil) + + // Act: two sources report a shared device (mac:a) plus distinct ones. + require.NoError(t, fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev("mac:a", "10.0.0.1", "80"), dev("mac:b", "10.0.0.2", "80")}})) + require.NoError(t, fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev("mac:a", "10.0.0.1", "80"), dev("mac:c", "10.0.0.3", "80")}})) + + // Assert: mac:a forwarded once; the second batch is reduced to just mac:c. + require.Len(t, sent, 2) + assert.Len(t, sent[0].GetDevices(), 2) + require.Len(t, sent[1].GetDevices(), 1) + assert.Equal(t, "mac:c", sent[1].GetDevices()[0].GetDeviceIdentifier()) +} + +func TestDedupForwarder_IPPortFallbackKey(t *testing.T) { + // Arrange: devices without an identifier dedupe by ip:port. + var sent []*pb.DiscoverResponse + fwd := newDedupForwarder(func(r *pb.DiscoverResponse) error { sent = append(sent, r); return nil }, nil) + + // Act + require.NoError(t, fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev("", "10.0.0.5", "4028")}})) + require.NoError(t, fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev("", "10.0.0.5", "4028")}})) + + // Assert: the second (all-duplicate) batch is dropped. + require.Len(t, sent, 1) +} + +func TestDedupForwarder_DropsAllDuplicateBatchButKeepsErrorResponse(t *testing.T) { + // Arrange + var sent []*pb.DiscoverResponse + fwd := newDedupForwarder(func(r *pb.DiscoverResponse) error { sent = append(sent, r); return nil }, nil) + require.NoError(t, fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev("mac:a", "10.0.0.1", "80")}})) + + // Act: a fully-duplicate batch is dropped; an error-only response is forwarded. + require.NoError(t, fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev("mac:a", "10.0.0.1", "80")}})) + require.NoError(t, fwd.forward(&pb.DiscoverResponse{Error: "scan failed"})) + + // Assert + require.Len(t, sent, 2) + assert.Equal(t, "scan failed", sent[1].GetError()) +} + +func TestDedupForwarder_SendErrorRecordedAndCancels(t *testing.T) { + // Arrange + sendErr := errors.New("stream gone") + var cancelled bool + fwd := newDedupForwarder(func(*pb.DiscoverResponse) error { return sendErr }, func() { cancelled = true }) + + // Act + err1 := fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev("mac:a", "10.0.0.1", "80")}}) + err2 := fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev("mac:b", "10.0.0.2", "80")}}) + + // Assert: the first failure records the error and cancels; the second short-circuits. + require.ErrorIs(t, err1, sendErr) + require.ErrorIs(t, err2, sendErr) + assert.True(t, cancelled) + require.ErrorIs(t, fwd.failure(), sendErr) +} + +func TestDedupForwarder_ConcurrentForwardIsSerialized(t *testing.T) { + // Arrange: many goroutines forward distinct devices; -race verifies safety. + var mu sync.Mutex + count := 0 + fwd := newDedupForwarder(func(r *pb.DiscoverResponse) error { + mu.Lock() + count += len(r.GetDevices()) + mu.Unlock() + return nil + }, nil) + var wg sync.WaitGroup + + // Act + for i := range 50 { + wg.Add(1) + go func(i int) { + defer wg.Done() + _ = fwd.forward(&pb.DiscoverResponse{Devices: []*pb.Device{dev(fmt.Sprintf("mac:%d", i), "10.0.0.1", "80")}}) + }(i) + } + wg.Wait() + + // Assert: all 50 distinct devices forwarded exactly once. + assert.Equal(t, 50, count) +} diff --git a/server/internal/handlers/pairing/handler.go b/server/internal/handlers/pairing/handler.go index b0c0024d6..0d1698e96 100644 --- a/server/internal/handlers/pairing/handler.go +++ b/server/internal/handlers/pairing/handler.go @@ -2,10 +2,14 @@ package pairing import ( "context" + "errors" "log/slog" + "sync" "github.com/block/proto-fleet/server/internal/domain/authz" "github.com/block/proto-fleet/server/internal/domain/fleeterror" + "github.com/block/proto-fleet/server/internal/domain/fleetnode/discovery" + "github.com/block/proto-fleet/server/internal/domain/nmaptarget" "github.com/block/proto-fleet/server/internal/handlers/middleware" "connectrpc.com/connect" @@ -17,59 +21,133 @@ import ( // Handler handles the Connect-RPC endpoints type Handler struct { pairingSvc *pairing.Service + // discovery fans the "Scan your network" nmap action out to connected fleet + // nodes; nil disables fan-out (cloud-only discovery). + discovery *discovery.Service } var _ pairingv1connect.PairingServiceHandler = &Handler{} // NewHandler creates a new instance of Handler -func NewHandler(pairingSvc *pairing.Service) *Handler { +func NewHandler(pairingSvc *pairing.Service, discoverySvc *discovery.Service) *Handler { return &Handler{ pairingSvc: pairingSvc, + discovery: discoverySvc, } } -// Discover implements pairingv1connect.DeviceDiscoveryServiceHandler. +// Discover implements pairingv1connect.PairingServiceHandler. An nmap "Scan your +// network" request also fans out to every CONFIRMED + connected fleet node and +// merges their LAN-local results into this stream; other modes are cloud-only. func (h *Handler) Discover(ctx context.Context, r *connect.Request[pb.DiscoverRequest], s *connect.ServerStream[pb.DiscoverResponse]) error { - if _, err := middleware.RequirePermission(ctx, authz.PermMinerPair, authz.ResourceContext{}); err != nil { + info, err := middleware.RequirePermission(ctx, authz.PermMinerPair, authz.ResourceContext{}) + if err != nil { return err } slog.Debug("Discover: handling discover request", "payload", r.Msg) + + // A send failure (operator disconnected) cancels every source. + streamCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Serialize the concurrent sources (cloud scan + each node) onto the one + // stream and dedupe devices across them; a Send failure cancels the rest. + fwd := newDedupForwarder(s.Send, cancel) + var resultChan <-chan *pb.DiscoverResponse - var err error + var isLocalSubnetNmap bool switch r.Msg.Mode.(type) { case *pb.DiscoverRequest_IpList: - resultChan, err = h.pairingSvc.DiscoverWithIPList(ctx, r.Msg.GetIpList()) + resultChan, err = h.pairingSvc.DiscoverWithIPList(streamCtx, r.Msg.GetIpList()) case *pb.DiscoverRequest_IpRange: - resultChan, err = h.pairingSvc.DiscoverWithIPRange(ctx, r.Msg.GetIpRange()) + resultChan, err = h.pairingSvc.DiscoverWithIPRange(streamCtx, r.Msg.GetIpRange()) case *pb.DiscoverRequest_Nmap: - resultChan, err = h.pairingSvc.DiscoverWithNmap(ctx, r.Msg.GetNmap()) + resultChan, isLocalSubnetNmap, err = h.pairingSvc.DiscoverWithNmap(streamCtx, r.Msg.GetNmap()) case *pb.DiscoverRequest_Mdns: - resultChan, err = h.pairingSvc.DiscoverWithMDNS(ctx, r.Msg.GetMdns()) + resultChan, err = h.pairingSvc.DiscoverWithMDNS(streamCtx, r.Msg.GetMdns()) default: return fleeterror.NewInternalError("unsupported mode") } - if err != nil { return err } - for { - select { - case result, ok := <-resultChan: - if !ok { - return nil - } - res := &pb.DiscoverResponse{ - Devices: result.Devices, + var wg sync.WaitGroup + + // Cloud discovery source. + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case result, ok := <-resultChan: + if !ok { + return + } + if err := fwd.forward(result); err != nil { + return + } + case <-streamCtx.Done(): + return } - if err := s.Send(res); err != nil { - // nolint:wrapcheck - return err + } + }() + + // Fan out only for the automatic "Scan your network" action (nmap target == + // the cloud's own local subnet), never a manual/explicit target, and only for + // callers who also hold fleetnode:manage — the same permission the single-node + // DiscoverOnFleetNode path requires. Without it, discovery stays cloud-only so + // the weaker miner:pair grant can't drive discovery commands on fleet nodes. + if isLocalSubnetNmap && h.discovery != nil && callerCanManageFleetNodes(ctx) { + nodeIDs, listErr := h.discovery.ConfirmedConnectedNodeIDs(streamCtx, info.OrganizationID) + if listErr != nil { + // Fan-out is best-effort; a lookup failure must never break the + // cloud scan. With zero connected nodes this is the same path. + slog.Warn("skipping fleet node discovery fan-out", "error", listErr) + } else { + autoReq := &pb.DiscoverRequest{Mode: &pb.DiscoverRequest_Nmap{Nmap: &pb.NmapModeRequest{ + Target: nmaptarget.LocalSubnetTarget, + Ports: r.Msg.GetNmap().GetPorts(), + }}} + for _, nodeID := range nodeIDs { + wg.Add(1) + go func(nodeID int64) { + defer wg.Done() + // Each node is bounded by RunOnNode's per-node timeout. + runErr := h.discovery.RunOnNode(streamCtx, nodeID, autoReq, fwd.forward) + // One node failing must not fail the scan, and is expected on + // operator disconnect — stay quiet once streamCtx is done. + if runErr != nil && streamCtx.Err() == nil { + slog.Warn("fleet node discovery failed during cloud fan-out", + "fleet_node_id", nodeID, "error", runErr) + } + }(nodeID) } - case <-ctx.Done(): - return fleeterror.NewCanceledError() } } + + wg.Wait() + if err := fwd.failure(); err != nil { + return err + } + // A client cancel/deadline drains the sources without a Send error; report it + // rather than success. (A fan-out-budget expiry is not a client error.) + if ctxErr := ctx.Err(); ctxErr != nil { + if errors.Is(ctxErr, context.DeadlineExceeded) { + return connect.NewError(connect.CodeDeadlineExceeded, ctxErr) + } + return fleeterror.NewCanceledError() + } + return nil +} + +// callerCanManageFleetNodes reports whether the request holds fleetnode:manage. +// It reuses the canonical permission path (so the synthesized-actor and +// fail-closed semantics match) but treats absence as a soft signal to skip +// fan-out rather than an error to return. +func callerCanManageFleetNodes(ctx context.Context) bool { + _, err := middleware.RequirePermission(ctx, authz.PermFleetnodeManage, authz.ResourceContext{}) + return err == nil } // Pair implements pairingv1connect.PairingServiceHandler. diff --git a/server/internal/handlers/pairing/handler_internal_test.go b/server/internal/handlers/pairing/handler_internal_test.go new file mode 100644 index 000000000..71f1e131b --- /dev/null +++ b/server/internal/handlers/pairing/handler_internal_test.go @@ -0,0 +1,74 @@ +package pairing + +import ( + "context" + "testing" + + "connectrpc.com/authn" + "github.com/stretchr/testify/assert" + + "github.com/block/proto-fleet/server/internal/domain/authz" + "github.com/block/proto-fleet/server/internal/domain/session" + "github.com/block/proto-fleet/server/internal/handlers/middleware" +) + +// ctxWithPerms builds the request context the auth interceptor would produce: +// session info plus the caller's effective org-scoped permissions. +func ctxWithPerms(perms ...string) context.Context { + info := &session.Info{ + AuthMethod: session.AuthMethodSession, + SessionID: "sess-1", + UserID: 1, + OrganizationID: 1, + ExternalUserID: "user-1", + Username: "alice", + } + ctx := authn.SetInfo(context.Background(), info) + return middleware.WithEffectivePermissions(ctx, authz.NewEffectivePermissions( + []authz.Assignment{{AssignmentID: 1, ScopeType: authz.ScopeOrg, Permissions: perms}}, + )) +} + +func TestCallerCanManageFleetNodes(t *testing.T) { + tests := []struct { + name string + perms []string + want bool + }{ + { + // The fan-out regression: miner:pair alone (no fleetnode:manage) must + // NOT unlock fleet-node discovery commands. + name: "miner:pair only does not grant fleet-node management", + perms: []string{authz.PermMinerPair}, + want: false, + }, + { + name: "fleetnode:manage grants it", + perms: []string{authz.PermMinerPair, authz.PermFleetnodeManage}, + want: true, + }, + { + name: "fleetnode:read alone does not grant it", + perms: []string{authz.PermMinerPair, authz.PermFleetnodeRead}, + want: false, + }, + { + name: "no permissions", + perms: nil, + want: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Arrange + ctx := ctxWithPerms(tc.perms...) + + // Act + got := callerCanManageFleetNodes(ctx) + + // Assert + assert.Equal(t, tc.want, got) + }) + } +} diff --git a/server/internal/testutil/infrastructure_provider.go b/server/internal/testutil/infrastructure_provider.go index 2587af397..7269a9741 100644 --- a/server/internal/testutil/infrastructure_provider.go +++ b/server/internal/testutil/infrastructure_provider.go @@ -54,7 +54,8 @@ func NewInfrastructureProvider(t *testing.T, serviceProvider *ServiceProvider, a authHandler := auth.NewHandler(serviceProvider.AuthService) mux.Handle(authv1connect.NewAuthServiceHandler(authHandler, interceptorsOption)) - pairingHandler := pairing.NewHandler(serviceProvider.PairingService) + // nil discovery service: no fleet node fan-out in this test harness. + pairingHandler := pairing.NewHandler(serviceProvider.PairingService, nil) mux.Handle(pairingv1connect.NewPairingServiceHandler(pairingHandler, interceptorsOption)) onboardingHandler := onboarding.NewHandler(serviceProvider.AuthService, serviceProvider.OnboardingService)