diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index b705d804914..f4e521d442e 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -1,126 +1,126 @@ name: Go on: - push: - paths-ignore: - - packaging/** - branches: - - 'master' - pull_request: - paths-ignore: - - packaging/** - branches: - - '**' + push: + paths-ignore: + - packaging/** + branches: + - "master" + pull_request: + paths-ignore: + - packaging/** + branches: + - "**" jobs: - test: - name: Test - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest, macos-latest, windows-latest] - steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Setup Go - uses: actions/setup-go@v5 - with: - cache: true - go-version-file: go.mod - - name: Build - run: make build - - name: Test with race detector (Ubuntu and MacOS) - if: matrix.os != 'windows-latest' - run: make test-ci-race - - name: Test without race detector (Windows) - if: matrix.os == 'windows-latest' - run: make test-ci - test-flaky: - name: Test (flaky) - runs-on: ubuntu-latest - continue-on-error: ${{ github.ref == 'refs/heads/master' }} - steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Setup Go - uses: actions/setup-go@v5 - with: - cache: true - go-version-file: go.mod - - name: Run flaky test - run: make test-ci-flaky - continue-on-error: ${{ github.ref == 'refs/heads/master' }} - lint: - name: Lint - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Setup Go - uses: actions/setup-go@v5 - with: - cache: false - go-version-file: go.mod - - name: Commit linting - if: github.ref != 'refs/heads/master' - uses: wagoid/commitlint-github-action@v5 - - name: GolangCI-Lint - uses: golangci/golangci-lint-action@v6 - with: - skip-cache: false - version: v1.64.5 - - name: Whitespace check - run: make check-whitespace - - name: go mod tidy check - uses: katexochen/go-tidy-check@v2 - coverage: - name: Coverage Report - if: github.ref == 'refs/heads/master' - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Setup Go - uses: actions/setup-go@v5 - with: - cache: false - go-version-file: go.mod - - name: Cache Go Modules - uses: actions/cache@v4 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go-coverage-${{ hashFiles('**/go.sum') }} - - name: Test with code coverage - run: make cover=1 test-ci - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v5 - with: - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: true - files: ./cover.out - trigger-beekeeper: - name: Trigger Beekeeper - runs-on: ubuntu-latest - needs: [test, lint, coverage] - if: github.ref == 'refs/heads/master' - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Setup Go - uses: actions/setup-go@v5 - with: - cache: false - go-version-file: go.mod - - name: Trigger Beekeeper - uses: peter-evans/repository-dispatch@v2 - with: - token: ${{ secrets.GHA_PAT_BASIC }} - repository: ${{ github.repository }} - event-type: trigger-beekeeper - client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}"}' + test: + name: Test + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + cache: true + go-version-file: go.mod + - name: Build + run: make build + - name: Test with race detector (Ubuntu and MacOS) + if: matrix.os != 'windows-latest' + run: make test-ci-race + - name: Test without race detector (Windows) + if: matrix.os == 'windows-latest' + run: make test-ci + test-flaky: + name: Test (flaky) + runs-on: ubuntu-latest + continue-on-error: ${{ github.ref == 'refs/heads/master' }} + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + cache: true + go-version-file: go.mod + - name: Run flaky test + run: make test-ci-flaky + continue-on-error: ${{ github.ref == 'refs/heads/master' }} + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Setup Go + uses: actions/setup-go@v5 + with: + cache: false + go-version-file: go.mod + - name: Commit linting + if: github.ref != 'refs/heads/master' + uses: wagoid/commitlint-github-action@v5 + - name: GolangCI-Lint + uses: golangci/golangci-lint-action@v7 + with: + skip-cache: false + version: v2.1.6 + - name: Whitespace check + run: make check-whitespace + - name: go mod tidy check + uses: katexochen/go-tidy-check@v2 + coverage: + name: Coverage Report + if: github.ref == 'refs/heads/master' + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + cache: false + go-version-file: go.mod + - name: Cache Go Modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-coverage-${{ hashFiles('**/go.sum') }} + - name: Test with code coverage + run: make cover=1 test-ci + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: true + files: ./cover.out + trigger-beekeeper: + name: Trigger Beekeeper + runs-on: ubuntu-latest + needs: [test, lint, coverage] + if: github.ref == 'refs/heads/master' + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Setup Go + uses: actions/setup-go@v5 + with: + cache: false + go-version-file: go.mod + - name: Trigger Beekeeper + uses: peter-evans/repository-dispatch@v2 + with: + token: ${{ secrets.GHA_PAT_BASIC }} + repository: ${{ github.repository }} + event-type: trigger-beekeeper + client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}"}' diff --git a/.golangci.yml b/.golangci.yml index 4f07d6fa978..ea1d78ce9af 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,5 +1,4 @@ -run: - timeout: 10m +version: "2" linters: enable: - asciicheck @@ -13,10 +12,8 @@ linters: - forbidigo - gochecknoinits - goconst - - gofmt - goheader - goprintffuncname - - gosimple - govet - importas - ineffassign @@ -29,47 +26,76 @@ linters: - promlinter - staticcheck - thelper - - typecheck - unconvert - unused - # - depguard disable temporary until this issue is resolved: https://github.com/golangci/golangci-lint/issues/3906 - -linters-settings: - govet: - enable-all: true - disable: - - fieldalignment ## temporally disabled - - shadow ## temporally disabled - goheader: - values: - regexp: - date: "20[1-2][0-9]" - template: |- - Copyright {{date}} The Swarm Authors. All rights reserved. - Use of this source code is governed by a BSD-style - license that can be found in the LICENSE file. - paralleltest: - # Ignore missing calls to `t.Parallel()` and only report incorrect uses of `t.Parallel()`. - ignore-missing: true -issues: - exclude-rules: - - linters: - - goheader - text: "go-ethereum Authors" ## disable check for other authors - - path: _test\.go - linters: - - goconst ## temporally disable goconst in test - - linters: - - forbidigo - path: cmd/bee/cmd - text: "use of `fmt.Print" ## allow fmt.Print in cmd directory - - linters: - - dogsled - path: pkg/api/(.+)_test\.go # temporally disable dogsled in api test files - - linters: - - dogsled - path: pkg/pushsync/(.+)_test\.go # temporally disable dogsled in pushsync test files - # temporally disable paralleltest in following packages - - linters: - - paralleltest - path: pkg/log + settings: + goheader: + values: + regexp: + date: 20[1-2][0-9] + template: |- + Copyright {{date}} The Swarm Authors. All rights reserved. + Use of this source code is governed by a BSD-style + license that can be found in the LICENSE file. + govet: + disable: + - fieldalignment + - shadow + enable-all: true + paralleltest: + ignore-missing: true + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + rules: + - linters: + - goheader + text: go-ethereum Authors + - linters: + - goconst + path: _test\.go + - linters: + - forbidigo + path: cmd/bee/cmd + text: use of `fmt.Print + - linters: + - dogsled + path: pkg/api/(.+)_test\.go + - linters: + - dogsled + path: pkg/pushsync/(.+)_test\.go + - linters: + - paralleltest + path: pkg/log + - linters: + - staticcheck + text: "QF1008:" + - linters: + - staticcheck + text: "QF1001:" + - linters: + - staticcheck + text: "QF1002:" + - linters: + - staticcheck + text: "QF1003:" + - linters: + - staticcheck + text: "QF1006:" + paths: + - third_party$ + - builtin$ + - examples$ +formatters: + enable: + - gofmt + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ diff --git a/Makefile b/Makefile index 01568f1136b..13e8146e2d8 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ GO ?= go GOBIN ?= $$($(GO) env GOPATH)/bin GOLANGCI_LINT ?= $(GOBIN)/golangci-lint -GOLANGCI_LINT_VERSION ?= v1.64.5 +GOLANGCI_LINT_VERSION ?= v2.1.6 GOGOPROTOBUF ?= protoc-gen-gogofaster GOGOPROTOBUF_VERSION ?= v1.3.1 BEEKEEPER_INSTALL_DIR ?= $(GOBIN) diff --git a/cmd/bee/cmd/cmd.go b/cmd/bee/cmd/cmd.go index e10e015ce81..07bae9334bd 100644 --- a/cmd/bee/cmd/cmd.go +++ b/cmd/bee/cmd/cmd.go @@ -43,6 +43,7 @@ const ( optionNameTracingHost = "tracing-host" optionNameTracingPort = "tracing-port" optionNameTracingServiceName = "tracing-service-name" + optionNameEnableTraceHeaders = "enable-trace-headers" optionNameVerbosity = "verbosity" optionNamePaymentThreshold = "payment-threshold" optionNamePaymentTolerance = "payment-tolerance-percent" @@ -253,6 +254,7 @@ func (c *command) setAllFlags(cmd *cobra.Command) { cmd.Flags().String(optionNameTracingHost, "", "host to send tracing data") cmd.Flags().String(optionNameTracingPort, "", "port to send tracing data") cmd.Flags().String(optionNameTracingServiceName, "bee", "service name identifier for tracing") + cmd.Flags().Bool(optionNameEnableTraceHeaders, false, "enable trace headers capability for p2p streams") cmd.Flags().String(optionNameVerbosity, "info", "log verbosity level 0=silent, 1=error, 2=warn, 3=info, 4=debug, 5=trace") cmd.Flags().String(optionWelcomeMessage, "", "send a welcome message string during handshakes") cmd.Flags().String(optionNamePaymentThreshold, "13500000", "threshold in BZZ where you expect to get paid from your peers") diff --git a/cmd/bee/cmd/split.go b/cmd/bee/cmd/split.go index ebe798e878d..112b32ed2fa 100644 --- a/cmd/bee/cmd/split.go +++ b/cmd/bee/cmd/split.go @@ -116,13 +116,13 @@ func splitRefs(cmd *cobra.Command) { } logger.Debug("write root", "hash", rootRef) - _, err = writer.WriteString(fmt.Sprintf("%s\n", rootRef)) + _, err = fmt.Fprintf(writer, "%s\n", rootRef) if err != nil { return fmt.Errorf("write root hash: %w", err) } for _, ref := range refs { logger.Debug("write chunk", "hash", ref) - _, err = writer.WriteString(fmt.Sprintf("%s\n", ref)) + _, err = fmt.Fprintf(writer, "%s\n", ref) if err != nil { return fmt.Errorf("write chunk address: %w", err) } diff --git a/cmd/bee/cmd/start.go b/cmd/bee/cmd/start.go index 1b1baac9503..67cf0606988 100644 --- a/cmd/bee/cmd/start.go +++ b/cmd/bee/cmd/start.go @@ -297,6 +297,7 @@ func buildBeeNode(ctx context.Context, c *command, cmd *cobra.Command, logger lo DBOpenFilesLimit: c.config.GetUint64(optionNameDBOpenFilesLimit), DBWriteBufferSize: c.config.GetUint64(optionNameDBWriteBufferSize), EnableStorageIncentives: c.config.GetBool(optionNameStorageIncentivesEnable), + EnableTraceHeaders: c.config.GetBool(optionNameEnableTraceHeaders), EnableWS: c.config.GetBool(optionNameP2PWSEnable), FullNodeMode: fullNode, Logger: logger, diff --git a/pkg/file/splitter/internal/job.go b/pkg/file/splitter/internal/job.go index db34e5764c1..f33c134bcc0 100644 --- a/pkg/file/splitter/internal/job.go +++ b/pkg/file/splitter/internal/job.go @@ -198,7 +198,7 @@ func (s *SimpleSplitterJob) hashUnfinished() error { // F F // F F F // -// F F F F S +// # F F F F S // // The result will be: // @@ -206,7 +206,7 @@ func (s *SimpleSplitterJob) hashUnfinished() error { // F F // F F F // -// F F F F +// # F F F F // // After which the SS will be hashed to obtain the final root hash func (s *SimpleSplitterJob) moveDanglingChunk() error { diff --git a/pkg/node/node.go b/pkg/node/node.go index 9893bfb361e..2266d74730a 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -142,6 +142,7 @@ type Options struct { DBOpenFilesLimit uint64 DBWriteBufferSize uint64 EnableStorageIncentives bool + EnableTraceHeaders bool EnableWS bool FullNodeMode bool Logger log.Logger @@ -638,14 +639,15 @@ func NewBee( } p2ps, err := libp2p.New(ctx, signer, networkID, swarmAddress, addr, addressbook, stateStore, lightNodes, logger, tracer, libp2p.Options{ - PrivateKey: libp2pPrivateKey, - NATAddr: o.NATAddr, - EnableWS: o.EnableWS, - WelcomeMessage: o.WelcomeMessage, - FullNode: o.FullNodeMode, - Nonce: nonce, - ValidateOverlay: chainEnabled, - Registry: registry, + PrivateKey: libp2pPrivateKey, + NATAddr: o.NATAddr, + EnableWS: o.EnableWS, + EnableTraceHeaders: o.EnableTraceHeaders, + WelcomeMessage: o.WelcomeMessage, + FullNode: o.FullNodeMode, + Nonce: nonce, + ValidateOverlay: chainEnabled, + Registry: registry, }) if err != nil { return nil, fmt.Errorf("p2p service: %w", err) diff --git a/pkg/p2p/libp2p/connections_test.go b/pkg/p2p/libp2p/connections_test.go index c85b016d801..a989b7cdf6e 100644 --- a/pkg/p2p/libp2p/connections_test.go +++ b/pkg/p2p/libp2p/connections_test.go @@ -9,7 +9,6 @@ import ( "context" "errors" "io" - "math/rand" "reflect" "strings" "sync" @@ -158,8 +157,6 @@ func TestLightPeerLimit(t *testing.T) { func TestStreamsMaxIncomingLimit(t *testing.T) { t.Parallel() - maxIncomingStreams := 5000 - ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -182,7 +179,6 @@ func TestStreamsMaxIncomingLimit(t *testing.T) { Handler: func(ctx context.Context, p p2p.Peer, s p2p.Stream) error { streamsMu.Lock() defer streamsMu.Unlock() - streams = append(streams, s) return nil }, @@ -191,22 +187,15 @@ func TestStreamsMaxIncomingLimit(t *testing.T) { } t.Cleanup(func() { + streamsMu.Lock() for _, s := range streams { if err := s.Reset(); err != nil { t.Error(err) } } + streamsMu.Unlock() }) - testProtocolClient := func() error { - _, err := s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName) - if err != nil { - return err - } - // do not close or rest the stream in defer in order to keep the stream active - return nil - } - if err := s1.AddProtocol(testProtocolSpec); err != nil { t.Fatal(err) } @@ -218,62 +207,72 @@ func TestStreamsMaxIncomingLimit(t *testing.T) { expectPeers(t, s2, overlay1) expectPeersEventually(t, s1, overlay2) - overflowStreamCount := maxIncomingStreams / 4 + // Test resource manager limits by creating streams and activating them with data + // The limit should be enforced when streams become active (data is written) + testLimit := 5100 // Slightly more than IncomingStreamCountLimit (5000) + createdStreams := make([]p2p.Stream, 0) + activeStreamCount := 0 + writeErrors := 0 - // create streams over the limit + // First, create stream placeholders + for i := 0; i < testLimit; i++ { + stream, err := s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName) + if err != nil { + t.Logf("Stream creation failed at %d: %v", i, err) + break + } + createdStreams = append(createdStreams, stream) + } + + t.Logf("Created %d stream placeholders", len(createdStreams)) - for i := 0; i < maxIncomingStreams+overflowStreamCount; i++ { - err := testProtocolClient() - if i < maxIncomingStreams { - if err != nil { - t.Errorf("test protocol client %v: %v", i, err) + // Now activate streams by writing data - this should hit the resource limit + for i, stream := range createdStreams { + _, writeErr := stream.Write([]byte("activate")) + if writeErr != nil { + writeErrors++ + if writeErrors <= 10 { + t.Logf("Stream write failed at %d: %v", i, writeErr) } } else { - if err == nil { - t.Errorf("test protocol client %v got nil error", i) - } + activeStreamCount++ } - } - if len(streams) != maxIncomingStreams { - t.Errorf("got %v streams, want %v", len(streams), maxIncomingStreams) + // Small delay to avoid overwhelming + if i%100 == 0 && i > 0 { + time.Sleep(time.Millisecond) + } } - closeStreamCount := len(streams) / 2 + // Allow time for handlers to be called + time.Sleep(200 * time.Millisecond) - // close random streams to validate new streams creation + streamsMu.Lock() + handlerCallCount := len(streams) + streamsMu.Unlock() - random := rand.New(rand.NewSource(time.Now().UnixNano())) - for i := 0; i < closeStreamCount; i++ { - n := random.Intn(len(streams)) - if err := streams[n].Reset(); err != nil { - t.Error(err) - continue - } - streams = append(streams[:n], streams[n+1:]...) - } + t.Logf("Results: %d placeholders created, %d successfully activated, %d write errors, %d handlers called", + len(createdStreams), activeStreamCount, writeErrors, handlerCallCount) - if maxIncomingStreams-len(streams) != closeStreamCount { - t.Errorf("got %v closed streams, want %v", maxIncomingStreams-len(streams), closeStreamCount) + // Verify the resource manager enforced the 5000 stream limit + expectedLimit := 5000 + if handlerCallCount > expectedLimit { + t.Errorf("Handler count %d exceeded expected limit %d - resource manager not working", handlerCallCount, expectedLimit) } - // create new streams + if writeErrors == 0 && activeStreamCount > expectedLimit { + t.Errorf("Expected some stream activations to fail due to resource limits, but all %d succeeded", activeStreamCount) + } - for i := 0; i < closeStreamCount+overflowStreamCount; i++ { - err := testProtocolClient() - if i < closeStreamCount { - if err != nil { - t.Errorf("test protocol client %v: %v", i, err) - } - } else { - if err == nil { - t.Errorf("test protocol client %v got nil error", i) - } - } + if handlerCallCount != expectedLimit { + t.Errorf("Expected exactly %d handlers to be called (the limit), got %d", expectedLimit, handlerCallCount) } - if len(streams) != maxIncomingStreams { - t.Errorf("got %v streams, want %v", len(streams), maxIncomingStreams) + // Clean up all streams + for _, stream := range createdStreams { + if err := stream.Reset(); err != nil { + t.Logf("Failed to reset stream: %v", err) + } } expectPeers(t, s2, overlay1) @@ -1054,7 +1053,6 @@ func TestWithBlocklistStreams(t *testing.T) { expectPeers(t, s2, overlay1) expectPeersEventually(t, s1, overlay2) - s, err := s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName) expectStreamReset(t, s, err) diff --git a/pkg/p2p/libp2p/headers_test.go b/pkg/p2p/libp2p/headers_test.go index 5805dafa0d4..077d38f9ef7 100644 --- a/pkg/p2p/libp2p/headers_test.go +++ b/pkg/p2p/libp2p/headers_test.go @@ -15,7 +15,8 @@ import ( "github.com/ethersphere/bee/v2/pkg/swarm" ) -func TestHeaders(t *testing.T) { +// TestHeaders_BothSupportTrace tests header exchange when both peers support trace headers +func TestHeaders_BothSupportTrace(t *testing.T) { t.Parallel() headers := p2p.Headers{ @@ -27,13 +28,19 @@ func TestHeaders(t *testing.T) { defer cancel() s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{ - FullNode: true, + FullNode: true, + EnableTraceHeaders: true, }}) - s2, overlay2 := newService(t, 1, libp2pServiceOpts{}) + s2, overlay2 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{ + EnableTraceHeaders: true, + }}) var gotHeaders p2p.Headers handled := make(chan struct{}) + testMessage := []byte("test-message") + receivedMessage := make(chan []byte, 1) + if err := s1.AddProtocol(newTestProtocol(func(ctx context.Context, p p2p.Peer, stream p2p.Stream) error { if ctx == nil { t.Fatal("missing context") @@ -42,6 +49,14 @@ func TestHeaders(t *testing.T) { t.Fatalf("got peer %v, want %v", p.Address, overlay2) } gotHeaders = stream.Headers() + + // Read test message from stream + buf := make([]byte, len(testMessage)) + _, err := stream.Read(buf) + if err != nil { + t.Errorf("failed to read from stream: %v", err) + } + receivedMessage <- buf close(handled) return nil })); err != nil { @@ -54,37 +69,64 @@ func TestHeaders(t *testing.T) { t.Fatal(err) } + // Wait for handshake to complete and capabilities to be registered + expectPeers(t, s2, overlay1) + expectPeersEventually(t, s1, overlay2) + stream, err := s2.NewStream(ctx, overlay1, headers, testProtocolName, testProtocolVersion, testStreamName) if err != nil { t.Fatal(err) } defer stream.Close() + // Send test message to verify stream is working + _, err = stream.Write(testMessage) + if err != nil { + t.Errorf("failed to write to stream: %v", err) + } + select { case <-handled: case <-time.After(30 * time.Second): t.Fatal("timeout waiting for handler") } + // Verify message was received + select { + case msg := <-receivedMessage: + if string(msg) != string(testMessage) { + t.Errorf("got message %s, want %s", string(msg), string(testMessage)) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for message") + } + if fmt.Sprint(gotHeaders) != fmt.Sprint(headers) { t.Errorf("got headers %+v, want %+v", gotHeaders, headers) } } -func TestHeaders_empty(t *testing.T) { +// TestHeaders_BothSupportTrace_EmptyHeaders tests empty header exchange when both peers support trace headers +func TestHeaders_BothSupportTrace_EmptyHeaders(t *testing.T) { t.Parallel() ctx, cancel := context.WithCancel(context.Background()) defer cancel() s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{ - FullNode: true, + FullNode: true, + EnableTraceHeaders: true, }}) - s2, overlay2 := newService(t, 1, libp2pServiceOpts{}) + s2, overlay2 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{ + EnableTraceHeaders: true, + }}) var gotHeaders p2p.Headers handled := make(chan struct{}) + testMessage := []byte("test-message-empty-headers") + receivedMessage := make(chan []byte, 1) + if err := s1.AddProtocol(newTestProtocol(func(ctx context.Context, p p2p.Peer, stream p2p.Stream) error { if ctx == nil { t.Fatal("missing context") @@ -93,6 +135,14 @@ func TestHeaders_empty(t *testing.T) { t.Fatalf("got peer %v, want %v", p.Address, overlay2) } gotHeaders = stream.Headers() + + // Read test message from stream + buf := make([]byte, len(testMessage)) + _, err := stream.Read(buf) + if err != nil { + t.Errorf("failed to read from stream: %v", err) + } + receivedMessage <- buf close(handled) return nil })); err != nil { @@ -105,23 +155,326 @@ func TestHeaders_empty(t *testing.T) { t.Fatal(err) } + // Wait for handshake to complete and capabilities to be registered + expectPeers(t, s2, overlay1) + expectPeersEventually(t, s1, overlay2) + stream, err := s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName) if err != nil { t.Fatal(err) } defer stream.Close() + // Send test message to verify stream is working + _, err = stream.Write(testMessage) + if err != nil { + t.Errorf("failed to write to stream: %v", err) + } + select { case <-handled: case <-time.After(30 * time.Second): t.Fatal("timeout waiting for handler") } + // Verify message was received + select { + case msg := <-receivedMessage: + if string(msg) != string(testMessage) { + t.Errorf("got message %s, want %s", string(msg), string(testMessage)) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for message") + } + if len(gotHeaders) != 0 { t.Errorf("got headers %+v, want none", gotHeaders) } } +// TestHeaders_LocalSupportRemoteNoSupport tests no header exchange when local supports but remote doesn't +func TestHeaders_LocalSupportRemoteNoSupport(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{ + FullNode: true, + EnableTraceHeaders: true, // Local supports + }}) + + s2, overlay2 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{ + EnableTraceHeaders: false, // Remote doesn't support + }}) + + var gotHeaders p2p.Headers + handled := make(chan struct{}) + testMessage := []byte("test-message-mixed-caps-1") + receivedMessage := make(chan []byte, 1) + + if err := s1.AddProtocol(newTestProtocol(func(ctx context.Context, p p2p.Peer, stream p2p.Stream) error { + if ctx == nil { + t.Fatal("missing context") + } + if !p.Address.Equal(overlay2) { + t.Fatalf("got peer %v, want %v", p.Address, overlay2) + } + gotHeaders = stream.Headers() + + // Read test message from stream + buf := make([]byte, len(testMessage)) + _, err := stream.Read(buf) + if err != nil { + t.Errorf("failed to read from stream: %v", err) + } + receivedMessage <- buf + close(handled) + return nil + })); err != nil { + t.Fatal(err) + } + + addr := serviceUnderlayAddress(t, s1) + + if _, err := s2.Connect(ctx, addr); err != nil { + t.Fatal(err) + } + + // Wait for handshake to complete and capabilities to be registered + expectPeers(t, s2, overlay1) + expectPeersEventually(t, s1, overlay2) + + // Try to send headers but they should be ignored due to capability mismatch + providedHeaders := p2p.Headers{ + "ignored-header": []byte("ignored-value"), + } + + stream, err := s2.NewStream(ctx, overlay1, providedHeaders, testProtocolName, testProtocolVersion, testStreamName) + if err != nil { + t.Fatal(err) + } + defer stream.Close() + + // Send test message to verify stream is working + _, err = stream.Write(testMessage) + if err != nil { + t.Errorf("failed to write to stream: %v", err) + } + + select { + case <-handled: + case <-time.After(30 * time.Second): + t.Fatal("timeout waiting for handler") + } + + // Verify message was received + select { + case msg := <-receivedMessage: + if string(msg) != string(testMessage) { + t.Errorf("got message %s, want %s", string(msg), string(testMessage)) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for message") + } + + // No headers should be exchanged due to capability mismatch + if len(gotHeaders) != 0 { + t.Errorf("expected no headers due to capability mismatch, got %+v", gotHeaders) + } + + // Verify that response headers are also empty + responseHeaders := stream.ResponseHeaders() + if len(responseHeaders) != 0 { + t.Errorf("expected no response headers when capabilities don't match, got %+v", responseHeaders) + } +} + +// TestHeaders_LocalNoSupportRemoteSupport tests no header exchange when local doesn't support but remote does +func TestHeaders_LocalNoSupportRemoteSupport(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{ + FullNode: true, + EnableTraceHeaders: false, // Local doesn't support + }}) + + s2, overlay2 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{ + EnableTraceHeaders: true, // Remote supports + }}) + + var gotHeaders p2p.Headers + handled := make(chan struct{}) + testMessage := []byte("test-message-mixed-caps-2") + receivedMessage := make(chan []byte, 1) + + if err := s1.AddProtocol(newTestProtocol(func(ctx context.Context, p p2p.Peer, stream p2p.Stream) error { + if ctx == nil { + t.Fatal("missing context") + } + if !p.Address.Equal(overlay2) { + t.Fatalf("got peer %v, want %v", p.Address, overlay2) + } + gotHeaders = stream.Headers() + + // Read test message from stream + buf := make([]byte, len(testMessage)) + _, err := stream.Read(buf) + if err != nil { + t.Errorf("failed to read from stream: %v", err) + } + receivedMessage <- buf + close(handled) + return nil + })); err != nil { + t.Fatal(err) + } + + addr := serviceUnderlayAddress(t, s1) + + if _, err := s2.Connect(ctx, addr); err != nil { + t.Fatal(err) + } + + // Wait for handshake to complete and capabilities to be registered + expectPeers(t, s2, overlay1) + expectPeersEventually(t, s1, overlay2) + + // Try to send headers but they should be ignored due to capability mismatch + providedHeaders := p2p.Headers{ + "another-ignored-header": []byte("another-ignored-value"), + } + + stream, err := s2.NewStream(ctx, overlay1, providedHeaders, testProtocolName, testProtocolVersion, testStreamName) + if err != nil { + t.Fatal(err) + } + defer stream.Close() + + // Send test message to verify stream is working + _, err = stream.Write(testMessage) + if err != nil { + t.Errorf("failed to write to stream: %v", err) + } + + select { + case <-handled: + case <-time.After(30 * time.Second): + t.Fatal("timeout waiting for handler") + } + + // Verify message was received + select { + case msg := <-receivedMessage: + if string(msg) != string(testMessage) { + t.Errorf("got message %s, want %s", string(msg), string(testMessage)) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for message") + } + + // No headers should be exchanged due to capability mismatch + if len(gotHeaders) != 0 { + t.Errorf("expected no headers due to capability mismatch, got %+v", gotHeaders) + } +} + +// TestHeaders_BothNoSupport tests no header exchange when both peers don't support trace headers +func TestHeaders_BothNoSupport(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{ + FullNode: true, + EnableTraceHeaders: false, // Local doesn't support + }}) + + s2, overlay2 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{ + EnableTraceHeaders: false, // Remote doesn't support + }}) + + var gotHeaders p2p.Headers + handled := make(chan struct{}) + testMessage := []byte("test-message-both-no-support") + receivedMessage := make(chan []byte, 1) + + if err := s1.AddProtocol(newTestProtocol(func(ctx context.Context, p p2p.Peer, stream p2p.Stream) error { + if ctx == nil { + t.Fatal("missing context") + } + if !p.Address.Equal(overlay2) { + t.Fatalf("got peer %v, want %v", p.Address, overlay2) + } + gotHeaders = stream.Headers() + + // Read test message from stream + buf := make([]byte, len(testMessage)) + _, err := stream.Read(buf) + if err != nil { + t.Errorf("failed to read from stream: %v", err) + } + receivedMessage <- buf + close(handled) + return nil + })); err != nil { + t.Fatal(err) + } + + addr := serviceUnderlayAddress(t, s1) + + if _, err := s2.Connect(ctx, addr); err != nil { + t.Fatal(err) + } + + // Wait for handshake to complete and capabilities to be registered + expectPeers(t, s2, overlay1) + expectPeersEventually(t, s1, overlay2) + + // Headers provided but should be completely ignored + providedHeaders := p2p.Headers{ + "completely-ignored": []byte("completely-ignored-value"), + } + + stream, err := s2.NewStream(ctx, overlay1, providedHeaders, testProtocolName, testProtocolVersion, testStreamName) + if err != nil { + t.Fatal(err) + } + defer stream.Close() + + // Send test message to verify stream is working + _, err = stream.Write(testMessage) + if err != nil { + t.Errorf("failed to write to stream: %v", err) + } + + select { + case <-handled: + case <-time.After(30 * time.Second): + t.Fatal("timeout waiting for handler") + } + + // Verify message was received + select { + case msg := <-receivedMessage: + if string(msg) != string(testMessage) { + t.Errorf("got message %s, want %s", string(msg), string(testMessage)) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for message") + } + + // No headers should be exchanged since both don't support trace headers + if len(gotHeaders) != 0 { + t.Errorf("expected no headers when both peers don't support trace headers, got %+v", gotHeaders) + } +} + +// TestHeadler tests header exchange with Headler function when both peers support trace headers func TestHeadler(t *testing.T) { t.Parallel() @@ -138,13 +491,19 @@ func TestHeadler(t *testing.T) { defer cancel() s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{ - FullNode: true, + FullNode: true, + EnableTraceHeaders: true, }}) - s2, _ := newService(t, 1, libp2pServiceOpts{}) + s2, overlay2 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{ + EnableTraceHeaders: true, + }}) var gotReceivedHeaders p2p.Headers handled := make(chan struct{}) + testMessage := []byte("test-message-headler") + receivedMessage := make(chan []byte, 1) + if err := s1.AddProtocol(p2p.ProtocolSpec{ Name: testProtocolName, Version: testProtocolVersion, @@ -152,6 +511,13 @@ func TestHeadler(t *testing.T) { { Name: testStreamName, Handler: func(_ context.Context, _ p2p.Peer, stream p2p.Stream) error { + // Read test message from stream + buf := make([]byte, len(testMessage)) + _, err := stream.Read(buf) + if err != nil { + t.Errorf("failed to read from stream: %v", err) + } + receivedMessage <- buf return nil }, Headler: func(headers p2p.Headers, address swarm.Address) p2p.Headers { @@ -171,18 +537,38 @@ func TestHeadler(t *testing.T) { t.Fatal(err) } + // Wait for handshake to complete and capabilities to be registered + expectPeers(t, s2, overlay1) + expectPeersEventually(t, s1, overlay2) + stream, err := s2.NewStream(ctx, overlay1, receivedHeaders, testProtocolName, testProtocolVersion, testStreamName) if err != nil { t.Fatal(err) } defer stream.Close() + // Send test message to verify stream is working + _, err = stream.Write(testMessage) + if err != nil { + t.Errorf("failed to write to stream: %v", err) + } + select { case <-handled: case <-time.After(30 * time.Second): t.Fatal("timeout waiting for handler") } + // Verify message was received + select { + case msg := <-receivedMessage: + if string(msg) != string(testMessage) { + t.Errorf("got message %s, want %s", string(msg), string(testMessage)) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for message") + } + if fmt.Sprint(gotReceivedHeaders) != fmt.Sprint(receivedHeaders) { t.Errorf("got received headers %+v, want %+v", gotReceivedHeaders, receivedHeaders) } @@ -192,3 +578,119 @@ func TestHeadler(t *testing.T) { t.Errorf("got sent headers %+v, want %+v", gotSentHeaders, sentHeaders) } } + +// TestHeadler_NoTraceCapability tests that Headler is not called when capabilities don't match +func TestHeadler_NoTraceCapability(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{ + FullNode: true, + EnableTraceHeaders: true, // Local supports + }}) + + s2, overlay2 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{ + EnableTraceHeaders: false, // Remote doesn't support + }}) + + headlerCalled := make(chan struct{}, 1) + handlerCalled := make(chan struct{}) + testMessage := []byte("test-message-headler-no-caps") + receivedMessage := make(chan []byte, 1) + + if err := s1.AddProtocol(p2p.ProtocolSpec{ + Name: testProtocolName, + Version: testProtocolVersion, + StreamSpecs: []p2p.StreamSpec{ + { + Name: testStreamName, + Handler: func(_ context.Context, _ p2p.Peer, stream p2p.Stream) error { + defer close(handlerCalled) + // Read test message from stream + buf := make([]byte, len(testMessage)) + _, err := stream.Read(buf) + if err != nil { + t.Errorf("failed to read from stream: %v", err) + } + receivedMessage <- buf + + // Verify no headers were received + headers := stream.Headers() + if len(headers) != 0 { + t.Errorf("expected no headers due to capability mismatch, got %+v", headers) + } + return nil + }, + Headler: func(headers p2p.Headers, address swarm.Address) p2p.Headers { + select { + case headlerCalled <- struct{}{}: + default: + } + t.Error("Headler should not be called when capabilities don't match") + return nil + }, + }, + }, + }); err != nil { + t.Fatal(err) + } + + addr := serviceUnderlayAddress(t, s1) + + if _, err := s2.Connect(ctx, addr); err != nil { + t.Fatal(err) + } + + // Wait for handshake to complete and capabilities to be registered + expectPeers(t, s2, overlay1) + expectPeersEventually(t, s1, overlay2) + + // Provide headers but they should be ignored + providedHeaders := p2p.Headers{ + "should-be-ignored": []byte("ignored-value"), + } + + stream, err := s2.NewStream(ctx, overlay1, providedHeaders, testProtocolName, testProtocolVersion, testStreamName) + if err != nil { + t.Fatal(err) + } + defer stream.Close() + + // Send test message to verify stream is working + _, err = stream.Write(testMessage) + if err != nil { + t.Errorf("failed to write to stream: %v", err) + } + + select { + case <-handlerCalled: + case <-time.After(30 * time.Second): + t.Fatal("timeout waiting for handler") + } + + // Verify message was received + select { + case msg := <-receivedMessage: + if string(msg) != string(testMessage) { + t.Errorf("got message %s, want %s", string(msg), string(testMessage)) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for message") + } + + // Verify Headler was NOT called + select { + case <-headlerCalled: + t.Error("Headler should not have been called when capabilities don't match") + case <-time.After(1 * time.Second): + // Expected - Headler should not be called + } + + // Verify stream has no headers + headers := stream.Headers() + if len(headers) != 0 { + t.Errorf("expected no headers when capabilities don't match, got %+v", headers) + } +} diff --git a/pkg/p2p/libp2p/internal/handshake/handshake.go b/pkg/p2p/libp2p/internal/handshake/handshake.go index a1d7cb2d587..bbe90c5456a 100644 --- a/pkg/p2p/libp2p/internal/handshake/handshake.go +++ b/pkg/p2p/libp2p/internal/handshake/handshake.go @@ -30,7 +30,7 @@ const ( // ProtocolName is the text of the name of the handshake protocol. ProtocolName = "handshake" // ProtocolVersion is the current handshake protocol version. - ProtocolVersion = "13.0.0" + ProtocolVersion = "14.0.0" // StreamName is the name of the stream used for handshake purposes. StreamName = "handshake" // MaxWelcomeMessageLength is maximum number of characters allowed in the welcome message. @@ -66,6 +66,7 @@ type Service struct { advertisableAddresser AdvertisableAddressResolver overlay swarm.Address fullNode bool + traceHeaders bool nonce []byte networkID uint64 validateOverlay bool @@ -78,12 +79,12 @@ type Service struct { // Info contains the information received from the handshake. type Info struct { - BzzAddress *bzz.Address - FullNode bool + BzzAddress *bzz.Address + Capabilities *pb.Capabilities } func (i *Info) LightString() string { - if !i.FullNode { + if !i.Capabilities.FullNode { return " (light)" } @@ -91,7 +92,7 @@ func (i *Info) LightString() string { } // New creates a new handshake Service. -func New(signer crypto.Signer, advertisableAddresser AdvertisableAddressResolver, overlay swarm.Address, networkID uint64, fullNode bool, nonce []byte, welcomeMessage string, validateOverlay bool, ownPeerID libp2ppeer.ID, logger log.Logger) (*Service, error) { +func New(signer crypto.Signer, advertisableAddresser AdvertisableAddressResolver, overlay swarm.Address, networkID uint64, fullNode bool, traceHeaders bool, nonce []byte, welcomeMessage string, validateOverlay bool, ownPeerID libp2ppeer.ID, logger log.Logger) (*Service, error) { if len(welcomeMessage) > MaxWelcomeMessageLength { return nil, ErrWelcomeMessageLength } @@ -102,6 +103,7 @@ func New(signer crypto.Signer, advertisableAddresser AdvertisableAddressResolver overlay: overlay, networkID: networkID, fullNode: fullNode, + traceHeaders: traceHeaders, validateOverlay: validateOverlay, nonce: nonce, libp2pID: ownPeerID, @@ -135,18 +137,27 @@ func (s *Service) Handshake(ctx context.Context, stream p2p.Stream, peerMultiadd return nil, err } - if err := w.WriteMsgWithContext(ctx, &pb.Syn{ - ObservedUnderlay: fullRemoteMABytes, + if err := w.WriteMsgWithContext(ctx, &pb.Handshake{ + Payload: &pb.Handshake_Syn{ + Syn: &pb.HandshakeSyn{ + ObservedUnderlay: fullRemoteMABytes, + }, + }, }); err != nil { return nil, fmt.Errorf("write syn message: %w", err) } - var resp pb.SynAck + var resp pb.Handshake if err := r.ReadMsgWithContext(ctx, &resp); err != nil { return nil, fmt.Errorf("read synack message: %w", err) } - observedUnderlay, err := ma.NewMultiaddrBytes(resp.Syn.ObservedUnderlay) + synAck := resp.GetSynAck() + if synAck == nil { + return nil, fmt.Errorf("expected syn_ack message") + } + + observedUnderlay, err := ma.NewMultiaddrBytes(synAck.ObservedUnderlay) if err != nil { return nil, ErrInvalidSyn } @@ -176,27 +187,34 @@ func (s *Service) Handshake(ctx context.Context, stream p2p.Stream, peerMultiadd return nil, err } - if resp.Ack.NetworkID != s.networkID { + if synAck.NetworkID != s.networkID { return nil, ErrNetworkIDIncompatible } - remoteBzzAddress, err := s.parseCheckAck(resp.Ack) + remoteBzzAddress, err := s.parseCheckSynAck(synAck) if err != nil { return nil, err } // Synced read: welcomeMessage := s.GetWelcomeMessage() - msg := &pb.Ack{ - Address: &pb.BzzAddress{ - Underlay: advertisableUnderlayBytes, - Overlay: bzzAddress.Overlay.Bytes(), - Signature: bzzAddress.Signature, + msg := &pb.Handshake{ + Payload: &pb.Handshake_Ack{ + Ack: &pb.HandshakeAck{ + Address: &pb.BzzAddress{ + Underlay: advertisableUnderlayBytes, + Overlay: bzzAddress.Overlay.Bytes(), + Signature: bzzAddress.Signature, + }, + NetworkID: s.networkID, + Capabilities: &pb.Capabilities{ + FullNode: s.fullNode, + TraceHeaders: s.traceHeaders, + }, + Nonce: s.nonce, + WelcomeMessage: welcomeMessage, + }, }, - NetworkID: s.networkID, - FullNode: s.fullNode, - Nonce: s.nonce, - WelcomeMessage: welcomeMessage, } if err := w.WriteMsgWithContext(ctx, msg); err != nil { @@ -204,13 +222,13 @@ func (s *Service) Handshake(ctx context.Context, stream p2p.Stream, peerMultiadd } loggerV1.Debug("handshake finished for peer (outbound)", "peer_address", remoteBzzAddress.Overlay) - if len(resp.Ack.WelcomeMessage) > 0 { - s.logger.Debug("greeting message from peer", "peer_address", remoteBzzAddress.Overlay, "message", resp.Ack.WelcomeMessage) + if len(synAck.WelcomeMessage) > 0 { + s.logger.Debug("greeting message from peer", "peer_address", remoteBzzAddress.Overlay, "message", synAck.WelcomeMessage) } return &Info{ - BzzAddress: remoteBzzAddress, - FullNode: resp.Ack.FullNode, + BzzAddress: remoteBzzAddress, + Capabilities: synAck.Capabilities, }, nil } @@ -232,13 +250,18 @@ func (s *Service) Handle(ctx context.Context, stream p2p.Stream, remoteMultiaddr return nil, err } - var syn pb.Syn - if err := r.ReadMsgWithContext(ctx, &syn); err != nil { + var synMsg pb.Handshake + if err := r.ReadMsgWithContext(ctx, &synMsg); err != nil { s.metrics.SynRxFailed.Inc() return nil, fmt.Errorf("read syn message: %w", err) } s.metrics.SynRx.Inc() + syn := synMsg.GetSyn() + if syn == nil { + return nil, fmt.Errorf("expected syn message") + } + observedUnderlay, err := ma.NewMultiaddrBytes(syn.ObservedUnderlay) if err != nil { return nil, ErrInvalidSyn @@ -261,20 +284,23 @@ func (s *Service) Handle(ctx context.Context, stream p2p.Stream, remoteMultiaddr welcomeMessage := s.GetWelcomeMessage() - if err := w.WriteMsgWithContext(ctx, &pb.SynAck{ - Syn: &pb.Syn{ - ObservedUnderlay: fullRemoteMABytes, - }, - Ack: &pb.Ack{ - Address: &pb.BzzAddress{ - Underlay: advertisableUnderlayBytes, - Overlay: bzzAddress.Overlay.Bytes(), - Signature: bzzAddress.Signature, + if err := w.WriteMsgWithContext(ctx, &pb.Handshake{ + Payload: &pb.Handshake_SynAck{ + SynAck: &pb.HandshakeSynAck{ + ObservedUnderlay: fullRemoteMABytes, + Address: &pb.BzzAddress{ + Underlay: advertisableUnderlayBytes, + Overlay: bzzAddress.Overlay.Bytes(), + Signature: bzzAddress.Signature, + }, + NetworkID: s.networkID, + Capabilities: &pb.Capabilities{ + FullNode: s.fullNode, + TraceHeaders: s.traceHeaders, + }, + Nonce: s.nonce, + WelcomeMessage: welcomeMessage, }, - NetworkID: s.networkID, - FullNode: s.fullNode, - Nonce: s.nonce, - WelcomeMessage: welcomeMessage, }, }); err != nil { s.metrics.SynAckTxFailed.Inc() @@ -282,13 +308,18 @@ func (s *Service) Handle(ctx context.Context, stream p2p.Stream, remoteMultiaddr } s.metrics.SynAckTx.Inc() - var ack pb.Ack - if err := r.ReadMsgWithContext(ctx, &ack); err != nil { + var ackMsg pb.Handshake + if err := r.ReadMsgWithContext(ctx, &ackMsg); err != nil { s.metrics.AckRxFailed.Inc() return nil, fmt.Errorf("read ack message: %w", err) } s.metrics.AckRx.Inc() + ack := ackMsg.GetAck() + if ack == nil { + return nil, fmt.Errorf("expected ack message") + } + if ack.NetworkID != s.networkID { return nil, ErrNetworkIDIncompatible } @@ -296,12 +327,12 @@ func (s *Service) Handle(ctx context.Context, stream p2p.Stream, remoteMultiaddr overlay := swarm.NewAddress(ack.Address.Overlay) if s.picker != nil { - if !s.picker.Pick(p2p.Peer{Address: overlay, FullNode: ack.FullNode}) { + if !s.picker.Pick(p2p.Peer{Address: overlay, FullNode: ack.Capabilities.FullNode}) { return nil, ErrPicker } } - remoteBzzAddress, err := s.parseCheckAck(&ack) + remoteBzzAddress, err := s.parseCheckAck(ack) if err != nil { return nil, err } @@ -312,8 +343,8 @@ func (s *Service) Handle(ctx context.Context, stream p2p.Stream, remoteMultiaddr } return &Info{ - BzzAddress: remoteBzzAddress, - FullNode: ack.FullNode, + BzzAddress: remoteBzzAddress, + Capabilities: ack.Capabilities, }, nil } @@ -331,11 +362,16 @@ func (s *Service) GetWelcomeMessage() string { return s.welcomeMessage.Load().(string) } +// SupportsTraceHeaders returns whether this node supports trace headers. +func (s *Service) SupportsTraceHeaders() bool { + return s.traceHeaders +} + func buildFullMA(addr ma.Multiaddr, peerID libp2ppeer.ID) (ma.Multiaddr, error) { return ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", addr.String(), peerID.String())) } -func (s *Service) parseCheckAck(ack *pb.Ack) (*bzz.Address, error) { +func (s *Service) parseCheckAck(ack *pb.HandshakeAck) (*bzz.Address, error) { bzzAddress, err := bzz.ParseAddress(ack.Address.Underlay, ack.Address.Overlay, ack.Address.Signature, ack.Nonce, s.validateOverlay, s.networkID) if err != nil { return nil, ErrInvalidAck @@ -343,3 +379,12 @@ func (s *Service) parseCheckAck(ack *pb.Ack) (*bzz.Address, error) { return bzzAddress, nil } + +func (s *Service) parseCheckSynAck(synAck *pb.HandshakeSynAck) (*bzz.Address, error) { + bzzAddress, err := bzz.ParseAddress(synAck.Address.Underlay, synAck.Address.Overlay, synAck.Address.Signature, synAck.Nonce, s.validateOverlay, s.networkID) + if err != nil { + return nil, ErrInvalidAck + } + + return bzzAddress, nil +} diff --git a/pkg/p2p/libp2p/internal/handshake/handshake_test.go b/pkg/p2p/libp2p/internal/handshake/handshake_test.go index e0afcc56d70..9802368c28e 100644 --- a/pkg/p2p/libp2p/internal/handshake/handshake_test.go +++ b/pkg/p2p/libp2p/internal/handshake/handshake_test.go @@ -20,6 +20,7 @@ import ( "github.com/ethersphere/bee/v2/pkg/p2p/libp2p/internal/handshake/mock" "github.com/ethersphere/bee/v2/pkg/p2p/libp2p/internal/handshake/pb" "github.com/ethersphere/bee/v2/pkg/p2p/protobuf" + "github.com/ethersphere/bee/v2/pkg/swarm" libp2ppeer "github.com/libp2p/go-libp2p/core/peer" ma "github.com/multiformats/go-multiaddr" @@ -91,16 +92,22 @@ func TestHandshake(t *testing.T) { node1Info := handshake.Info{ BzzAddress: node1BzzAddress, - FullNode: true, + Capabilities: &pb.Capabilities{ + FullNode: true, + TraceHeaders: true, + }, } node2Info := handshake.Info{ BzzAddress: node2BzzAddress, - FullNode: true, + Capabilities: &pb.Capabilities{ + FullNode: true, + TraceHeaders: true, + }, } aaddresser := &AdvertisableAddresserMock{} - handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nonce, testWelcomeMessage, true, node1AddrInfo.ID, logger) + handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, true, nonce, testWelcomeMessage, true, node1AddrInfo.ID, logger) if err != nil { t.Fatal(err) } @@ -112,20 +119,23 @@ func TestHandshake(t *testing.T) { stream2 := mock.NewStream(&buffer2, &buffer1) w, r := protobuf.NewWriterAndReader(stream2) - if err := w.WriteMsg(&pb.SynAck{ - Syn: &pb.Syn{ - ObservedUnderlay: node1maBinary, - }, - Ack: &pb.Ack{ - Address: &pb.BzzAddress{ - Underlay: node2maBinary, - Overlay: node2BzzAddress.Overlay.Bytes(), - Signature: node2BzzAddress.Signature, + if err := w.WriteMsg(&pb.Handshake{ + Payload: &pb.Handshake_SynAck{ + SynAck: &pb.HandshakeSynAck{ + ObservedUnderlay: node1maBinary, + Address: &pb.BzzAddress{ + Underlay: node2maBinary, + Overlay: node2BzzAddress.Overlay.Bytes(), + Signature: node2BzzAddress.Signature, + }, + NetworkID: networkID, + Capabilities: &pb.Capabilities{ + FullNode: true, + TraceHeaders: true, + }, + Nonce: nonce, + WelcomeMessage: testWelcomeMessage, }, - NetworkID: networkID, - FullNode: true, - Nonce: nonce, - WelcomeMessage: testWelcomeMessage, }, }); err != nil { t.Fatal(err) @@ -141,20 +151,26 @@ func TestHandshake(t *testing.T) { testInfo(t, *res, node2Info) - var syn pb.Syn - if err := r.ReadMsg(&syn); err != nil { + var synMsg pb.Handshake + if err := r.ReadMsg(&synMsg); err != nil { t.Fatal(err) } - if !bytes.Equal(syn.ObservedUnderlay, node2maBinary) { - t.Fatal("bad syn") + syn := synMsg.GetSyn() + if syn == nil || !bytes.Equal(syn.ObservedUnderlay, node2maBinary) { + t.Fatalf("got bad syn") } - var ack pb.Ack - if err := r.ReadMsg(&ack); err != nil { + var ackMsg pb.Handshake + if err := r.ReadMsg(&ackMsg); err != nil { t.Fatal(err) } + ack := ackMsg.GetAck() + if ack == nil { + t.Fatal("expected ack message") + } + if !bytes.Equal(ack.Address.Overlay, node1BzzAddress.Overlay.Bytes()) { t.Fatal("bad ack - overlay") } @@ -167,7 +183,7 @@ func TestHandshake(t *testing.T) { if ack.NetworkID != networkID { t.Fatal("bad ack - networkID") } - if ack.FullNode != true { + if ack.Capabilities.FullNode != true { t.Fatal("bad ack - full node") } @@ -177,7 +193,7 @@ func TestHandshake(t *testing.T) { }) t.Run("Handshake - picker error", func(t *testing.T) { - handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nonce, "", true, node1AddrInfo.ID, logger) + handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, true, nonce, "", true, node1AddrInfo.ID, logger) if err != nil { t.Fatal(err) } @@ -190,21 +206,32 @@ func TestHandshake(t *testing.T) { stream2 := mock.NewStream(&buffer2, &buffer1) w := protobuf.NewWriter(stream2) - if err := w.WriteMsg(&pb.Syn{ - ObservedUnderlay: node1maBinary, + if err := w.WriteMsg(&pb.Handshake{ + Payload: &pb.Handshake_Syn{ + Syn: &pb.HandshakeSyn{ + ObservedUnderlay: node1maBinary, + }, + }, }); err != nil { t.Fatal(err) } - if err := w.WriteMsg(&pb.Ack{ - Address: &pb.BzzAddress{ - Underlay: node2maBinary, - Overlay: node2BzzAddress.Overlay.Bytes(), - Signature: node2BzzAddress.Signature, + if err := w.WriteMsg(&pb.Handshake{ + Payload: &pb.Handshake_Ack{ + Ack: &pb.HandshakeAck{ + Address: &pb.BzzAddress{ + Underlay: node2maBinary, + Overlay: node2BzzAddress.Overlay.Bytes(), + Signature: node2BzzAddress.Signature, + }, + NetworkID: networkID, + Capabilities: &pb.Capabilities{ + FullNode: true, + TraceHeaders: true, + }, + Nonce: nonce, + }, }, - NetworkID: networkID, - Nonce: nonce, - FullNode: true, }); err != nil { t.Fatal(err) } @@ -220,7 +247,7 @@ func TestHandshake(t *testing.T) { const LongMessage = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Morbi consectetur urna ut lorem sollicitudin posuere. Donec sagittis laoreet sapien." expectedErr := handshake.ErrWelcomeMessageLength - _, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nil, LongMessage, true, node1AddrInfo.ID, logger) + _, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, true, nil, LongMessage, true, node1AddrInfo.ID, logger) if err == nil || err.Error() != expectedErr.Error() { t.Fatal("expected:", expectedErr, "got:", err) } @@ -289,22 +316,24 @@ func TestHandshake(t *testing.T) { stream2 := mock.NewStream(&buffer2, &buffer1) w := protobuf.NewWriter(stream2) - if err := w.WriteMsg(&pb.SynAck{ - Syn: &pb.Syn{ - ObservedUnderlay: node1maBinary, - }, - Ack: &pb.Ack{ - Address: &pb.BzzAddress{ - Underlay: node2maBinary, - Overlay: node2BzzAddress.Overlay.Bytes(), - Signature: node2BzzAddress.Signature, + if err := w.WriteMsg(&pb.Handshake{ + Payload: &pb.Handshake_SynAck{ + SynAck: &pb.HandshakeSynAck{ + ObservedUnderlay: node1maBinary, + Address: &pb.BzzAddress{ + Underlay: node2maBinary, + Overlay: node2BzzAddress.Overlay.Bytes(), + Signature: node2BzzAddress.Signature, + }, + NetworkID: networkID, + Capabilities: &pb.Capabilities{ + FullNode: true, + TraceHeaders: true, + }, + Nonce: nonce, }, - Nonce: nonce, - NetworkID: networkID, - FullNode: true, }, - }, - ); err != nil { + }); err != nil { t.Fatal(err) } @@ -325,18 +354,21 @@ func TestHandshake(t *testing.T) { stream2 := mock.NewStream(&buffer2, &buffer1) w := protobuf.NewWriter(stream2) - if err := w.WriteMsg(&pb.SynAck{ - Syn: &pb.Syn{ - ObservedUnderlay: node1maBinary, - }, - Ack: &pb.Ack{ - Address: &pb.BzzAddress{ - Underlay: node2maBinary, - Overlay: node2BzzAddress.Overlay.Bytes(), - Signature: node2BzzAddress.Signature, + if err := w.WriteMsg(&pb.Handshake{ + Payload: &pb.Handshake_SynAck{ + SynAck: &pb.HandshakeSynAck{ + ObservedUnderlay: node1maBinary, + Address: &pb.BzzAddress{ + Underlay: node2maBinary, + Overlay: node2BzzAddress.Overlay.Bytes(), + Signature: node2BzzAddress.Signature, + }, + NetworkID: 5, + Capabilities: &pb.Capabilities{ + FullNode: true, + TraceHeaders: true, + }, }, - NetworkID: 5, - FullNode: true, }, }); err != nil { t.Fatal(err) @@ -359,23 +391,26 @@ func TestHandshake(t *testing.T) { stream2 := mock.NewStream(&buffer2, &buffer1) w := protobuf.NewWriter(stream2) - if err := w.WriteMsg(&pb.SynAck{ - Syn: &pb.Syn{ - ObservedUnderlay: node1maBinary, - }, - Ack: &pb.Ack{ - Address: &pb.BzzAddress{ - Underlay: node2maBinary, - Overlay: node2BzzAddress.Overlay.Bytes(), - Signature: node1BzzAddress.Signature, + if err := w.WriteMsg(&pb.Handshake{ + Payload: &pb.Handshake_SynAck{ + SynAck: &pb.HandshakeSynAck{ + ObservedUnderlay: node1maBinary, + Address: &pb.BzzAddress{ + Underlay: node2maBinary, + Overlay: node2BzzAddress.Overlay.Bytes(), + Signature: node1BzzAddress.Signature, + }, + NetworkID: networkID, + Capabilities: &pb.Capabilities{ + FullNode: true, + TraceHeaders: true, + }, }, - NetworkID: networkID, - FullNode: true, }, }); err != nil { t.Fatal(err) } - handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nonce, testWelcomeMessage, true, node1AddrInfo.ID, logger) + handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, true, nonce, testWelcomeMessage, true, node1AddrInfo.ID, logger) if err != nil { t.Fatal(err) } @@ -402,18 +437,21 @@ func TestHandshake(t *testing.T) { }() w, _ := protobuf.NewWriterAndReader(stream2) - if err := w.WriteMsg(&pb.SynAck{ - Syn: &pb.Syn{ - ObservedUnderlay: node1maBinary, - }, - Ack: &pb.Ack{ - Address: &pb.BzzAddress{ - Underlay: node2maBinary, - Overlay: node2BzzAddress.Overlay.Bytes(), - Signature: node2BzzAddress.Signature, + if err := w.WriteMsg(&pb.Handshake{ + Payload: &pb.Handshake_SynAck{ + SynAck: &pb.HandshakeSynAck{ + ObservedUnderlay: node1maBinary, + Address: &pb.BzzAddress{ + Underlay: node2maBinary, + Overlay: node2BzzAddress.Overlay.Bytes(), + Signature: node2BzzAddress.Signature, + }, + NetworkID: networkID, + Capabilities: &pb.Capabilities{ + FullNode: true, + TraceHeaders: true, + }, }, - NetworkID: networkID, - FullNode: true, }, }); err != nil { t.Fatal(err) @@ -432,7 +470,7 @@ func TestHandshake(t *testing.T) { }) t.Run("Handle - OK", func(t *testing.T) { - handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nonce, "", true, node1AddrInfo.ID, logger) + handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, true, nonce, "", true, node1AddrInfo.ID, logger) if err != nil { t.Fatal(err) } @@ -442,21 +480,32 @@ func TestHandshake(t *testing.T) { stream2 := mock.NewStream(&buffer2, &buffer1) w := protobuf.NewWriter(stream2) - if err := w.WriteMsg(&pb.Syn{ - ObservedUnderlay: node1maBinary, + if err := w.WriteMsg(&pb.Handshake{ + Payload: &pb.Handshake_Syn{ + Syn: &pb.HandshakeSyn{ + ObservedUnderlay: node1maBinary, + }, + }, }); err != nil { t.Fatal(err) } - if err := w.WriteMsg(&pb.Ack{ - Address: &pb.BzzAddress{ - Underlay: node2maBinary, - Overlay: node2BzzAddress.Overlay.Bytes(), - Signature: node2BzzAddress.Signature, + if err := w.WriteMsg(&pb.Handshake{ + Payload: &pb.Handshake_Ack{ + Ack: &pb.HandshakeAck{ + Address: &pb.BzzAddress{ + Underlay: node2maBinary, + Overlay: node2BzzAddress.Overlay.Bytes(), + Signature: node2BzzAddress.Signature, + }, + NetworkID: networkID, + Capabilities: &pb.Capabilities{ + FullNode: true, + TraceHeaders: true, + }, + Nonce: nonce, + }, }, - NetworkID: networkID, - Nonce: nonce, - FullNode: true, }); err != nil { t.Fatal(err) } @@ -469,28 +518,29 @@ func TestHandshake(t *testing.T) { testInfo(t, *res, node2Info) _, r := protobuf.NewWriterAndReader(stream2) - var got pb.SynAck - if err := r.ReadMsg(&got); err != nil { + var gotMsg pb.Handshake + if err := r.ReadMsg(&gotMsg); err != nil { t.Fatal(err) } - if !bytes.Equal(got.Syn.ObservedUnderlay, node2maBinary) { - t.Fatalf("got bad syn") + got := gotMsg.GetSynAck() + if got == nil || !bytes.Equal(got.ObservedUnderlay, node2maBinary) { + t.Fatalf("got bad syn_ack") } - bzzAddress, err := bzz.ParseAddress(got.Ack.Address.Underlay, got.Ack.Address.Overlay, got.Ack.Address.Signature, got.Ack.Nonce, true, got.Ack.NetworkID) + bzzAddress, err := bzz.ParseAddress(got.Address.Underlay, got.Address.Overlay, got.Address.Signature, got.Nonce, true, got.NetworkID) if err != nil { t.Fatal(err) } testInfo(t, node1Info, handshake.Info{ - BzzAddress: bzzAddress, - FullNode: got.Ack.FullNode, + BzzAddress: bzzAddress, + Capabilities: got.Capabilities, }) }) t.Run("Handle - read error ", func(t *testing.T) { - handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nil, "", true, node1AddrInfo.ID, logger) + handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, true, nil, "", true, node1AddrInfo.ID, logger) if err != nil { t.Fatal(err) } @@ -509,7 +559,7 @@ func TestHandshake(t *testing.T) { }) t.Run("Handle - write error ", func(t *testing.T) { - handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nil, "", true, node1AddrInfo.ID, logger) + handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, true, nil, "", true, node1AddrInfo.ID, logger) if err != nil { t.Fatal(err) } @@ -519,8 +569,12 @@ func TestHandshake(t *testing.T) { stream := mock.NewStream(&buffer, &buffer) stream.SetWriteErr(testErr, 1) w := protobuf.NewWriter(stream) - if err := w.WriteMsg(&pb.Syn{ - ObservedUnderlay: node1maBinary, + if err := w.WriteMsg(&pb.Handshake{ + Payload: &pb.Handshake_Syn{ + Syn: &pb.HandshakeSyn{ + ObservedUnderlay: node1maBinary, + }, + }, }); err != nil { t.Fatal(err) } @@ -536,7 +590,7 @@ func TestHandshake(t *testing.T) { }) t.Run("Handle - ack read error ", func(t *testing.T) { - handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nil, "", true, node1AddrInfo.ID, logger) + handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, true, nil, "", true, node1AddrInfo.ID, logger) if err != nil { t.Fatal(err) } @@ -548,8 +602,12 @@ func TestHandshake(t *testing.T) { stream2 := mock.NewStream(&buffer2, &buffer1) stream1.SetReadErr(testErr, 1) w := protobuf.NewWriter(stream2) - if err := w.WriteMsg(&pb.Syn{ - ObservedUnderlay: node1maBinary, + if err := w.WriteMsg(&pb.Handshake{ + Payload: &pb.Handshake_Syn{ + Syn: &pb.HandshakeSyn{ + ObservedUnderlay: node1maBinary, + }, + }, }); err != nil { t.Fatal(err) } @@ -565,7 +623,7 @@ func TestHandshake(t *testing.T) { }) t.Run("Handle - networkID mismatch ", func(t *testing.T) { - handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nil, "", true, node1AddrInfo.ID, logger) + handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, true, nil, "", true, node1AddrInfo.ID, logger) if err != nil { t.Fatal(err) } @@ -575,20 +633,31 @@ func TestHandshake(t *testing.T) { stream2 := mock.NewStream(&buffer2, &buffer1) w := protobuf.NewWriter(stream2) - if err := w.WriteMsg(&pb.Syn{ - ObservedUnderlay: node1maBinary, + if err := w.WriteMsg(&pb.Handshake{ + Payload: &pb.Handshake_Syn{ + Syn: &pb.HandshakeSyn{ + ObservedUnderlay: node1maBinary, + }, + }, }); err != nil { t.Fatal(err) } - if err := w.WriteMsg(&pb.Ack{ - Address: &pb.BzzAddress{ - Underlay: node2maBinary, - Overlay: node2BzzAddress.Overlay.Bytes(), - Signature: node2BzzAddress.Signature, + if err := w.WriteMsg(&pb.Handshake{ + Payload: &pb.Handshake_Ack{ + Ack: &pb.HandshakeAck{ + Address: &pb.BzzAddress{ + Underlay: node2maBinary, + Overlay: node2BzzAddress.Overlay.Bytes(), + Signature: node2BzzAddress.Signature, + }, + NetworkID: 5, + Capabilities: &pb.Capabilities{ + FullNode: true, + TraceHeaders: true, + }, + }, }, - NetworkID: 5, - FullNode: true, }); err != nil { t.Fatal(err) } @@ -604,7 +673,7 @@ func TestHandshake(t *testing.T) { }) t.Run("Handle - invalid ack", func(t *testing.T) { - handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nil, "", true, node1AddrInfo.ID, logger) + handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, true, nil, "", true, node1AddrInfo.ID, logger) if err != nil { t.Fatal(err) } @@ -614,20 +683,31 @@ func TestHandshake(t *testing.T) { stream2 := mock.NewStream(&buffer2, &buffer1) w := protobuf.NewWriter(stream2) - if err := w.WriteMsg(&pb.Syn{ - ObservedUnderlay: node1maBinary, + if err := w.WriteMsg(&pb.Handshake{ + Payload: &pb.Handshake_Syn{ + Syn: &pb.HandshakeSyn{ + ObservedUnderlay: node1maBinary, + }, + }, }); err != nil { t.Fatal(err) } - if err := w.WriteMsg(&pb.Ack{ - Address: &pb.BzzAddress{ - Underlay: node2maBinary, - Overlay: node2BzzAddress.Overlay.Bytes(), - Signature: node1BzzAddress.Signature, + if err := w.WriteMsg(&pb.Handshake{ + Payload: &pb.Handshake_Ack{ + Ack: &pb.HandshakeAck{ + Address: &pb.BzzAddress{ + Underlay: node2maBinary, + Overlay: node2BzzAddress.Overlay.Bytes(), + Signature: node1BzzAddress.Signature, + }, + NetworkID: networkID, + Capabilities: &pb.Capabilities{ + FullNode: true, + TraceHeaders: true, + }, + }, }, - NetworkID: networkID, - FullNode: true, }); err != nil { t.Fatal(err) } @@ -639,7 +719,7 @@ func TestHandshake(t *testing.T) { }) t.Run("Handle - advertisable error", func(t *testing.T) { - handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, nil, "", true, node1AddrInfo.ID, logger) + handshakeService, err := handshake.New(signer1, aaddresser, node1Info.BzzAddress.Overlay, networkID, true, true, nil, "", true, node1AddrInfo.ID, logger) if err != nil { t.Fatal(err) } @@ -655,8 +735,12 @@ func TestHandshake(t *testing.T) { }() w := protobuf.NewWriter(stream2) - if err := w.WriteMsg(&pb.Syn{ - ObservedUnderlay: node1maBinary, + if err := w.WriteMsg(&pb.Handshake{ + Payload: &pb.Handshake_Syn{ + Syn: &pb.HandshakeSyn{ + ObservedUnderlay: node1maBinary, + }, + }, }); err != nil { t.Fatal(err) } @@ -687,11 +771,81 @@ func (p *picker) Pick(peer p2p.Peer) bool { // testInfo validates if two Info instances are equal. func testInfo(t *testing.T, got, want handshake.Info) { t.Helper() - if !got.BzzAddress.Equal(want.BzzAddress) || got.FullNode != want.FullNode { + if !got.BzzAddress.Equal(want.BzzAddress) || + got.Capabilities.FullNode != want.Capabilities.FullNode || + got.Capabilities.TraceHeaders != want.Capabilities.TraceHeaders { t.Fatalf("got info %+v, want %+v", got, want) } } +func TestInfo_CapabilitiesReuse(t *testing.T) { + // Test that Info struct reuses the protobuf Capabilities object + // instead of creating separate boolean fields, reducing allocations + + capabilities := &pb.Capabilities{ + FullNode: true, + TraceHeaders: true, + } + + // Create a valid BZZ address using existing test helper + privateKey, err := crypto.GenerateSecp256k1Key() + if err != nil { + t.Fatal(err) + } + signer := crypto.NewDefaultSigner(privateKey) + + overlay := swarm.RandAddress(t) + underlay, _ := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/1234") + nonce := make([]byte, 32) + + bzzAddr, err := bzz.NewAddress(signer, underlay, overlay, 1, nonce) + if err != nil { + t.Fatal(err) + } + + info := handshake.Info{ + BzzAddress: bzzAddr, + Capabilities: capabilities, + } + + // Verify that the same protobuf object is being used + if info.Capabilities != capabilities { + t.Fatal("Info.Capabilities should reuse the same protobuf object") + } + + // Verify that capabilities can be accessed correctly + if !info.Capabilities.FullNode { + t.Fatal("Expected FullNode to be true") + } + + if !info.Capabilities.TraceHeaders { + t.Fatal("Expected TraceHeaders to be true") + } + + // Verify LightString method works + lightStr := info.LightString() + if lightStr != "" { + t.Fatalf("Expected empty string for full node, got %q", lightStr) + } + + // Test light node + lightCapabilities := &pb.Capabilities{ + FullNode: false, + TraceHeaders: false, + } + + lightInfo := handshake.Info{ + BzzAddress: bzzAddr, + Capabilities: lightCapabilities, + } + + lightStr = lightInfo.LightString() + expected := " (light)" + if lightStr != expected { + t.Fatalf("Expected %q for light node, got %q", expected, lightStr) + } +} + type AdvertisableAddresserMock struct { advertisableAddress ma.Multiaddr err error diff --git a/pkg/p2p/libp2p/internal/handshake/pb/handshake.pb.go b/pkg/p2p/libp2p/internal/handshake/pb/handshake.pb.go index 0ae7f00ebf6..da851c98d8a 100644 --- a/pkg/p2p/libp2p/internal/handshake/pb/handshake.pb.go +++ b/pkg/p2p/libp2p/internal/handshake/pb/handshake.pb.go @@ -22,22 +22,27 @@ var _ = math.Inf // proto package needs to be updated. const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package -type Syn struct { - ObservedUnderlay []byte `protobuf:"bytes,1,opt,name=ObservedUnderlay,proto3" json:"ObservedUnderlay,omitempty"` +type Handshake struct { + // Types that are valid to be assigned to Payload: + // + // *Handshake_Syn + // *Handshake_SynAck + // *Handshake_Ack + Payload isHandshake_Payload `protobuf_oneof:"payload"` } -func (m *Syn) Reset() { *m = Syn{} } -func (m *Syn) String() string { return proto.CompactTextString(m) } -func (*Syn) ProtoMessage() {} -func (*Syn) Descriptor() ([]byte, []int) { +func (m *Handshake) Reset() { *m = Handshake{} } +func (m *Handshake) String() string { return proto.CompactTextString(m) } +func (*Handshake) ProtoMessage() {} +func (*Handshake) Descriptor() ([]byte, []int) { return fileDescriptor_a77305914d5d202f, []int{0} } -func (m *Syn) XXX_Unmarshal(b []byte) error { +func (m *Handshake) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } -func (m *Syn) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { +func (m *Handshake) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { - return xxx_messageInfo_Syn.Marshal(b, m, deterministic) + return xxx_messageInfo_Handshake.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) @@ -47,45 +52,140 @@ func (m *Syn) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return b[:n], nil } } -func (m *Syn) XXX_Merge(src proto.Message) { - xxx_messageInfo_Syn.Merge(m, src) +func (m *Handshake) XXX_Merge(src proto.Message) { + xxx_messageInfo_Handshake.Merge(m, src) } -func (m *Syn) XXX_Size() int { +func (m *Handshake) XXX_Size() int { return m.Size() } -func (m *Syn) XXX_DiscardUnknown() { - xxx_messageInfo_Syn.DiscardUnknown(m) +func (m *Handshake) XXX_DiscardUnknown() { + xxx_messageInfo_Handshake.DiscardUnknown(m) +} + +var xxx_messageInfo_Handshake proto.InternalMessageInfo + +type isHandshake_Payload interface { + isHandshake_Payload() + MarshalTo([]byte) (int, error) + Size() int +} + +type Handshake_Syn struct { + Syn *HandshakeSyn `protobuf:"bytes,1,opt,name=syn,proto3,oneof" json:"syn,omitempty"` +} +type Handshake_SynAck struct { + SynAck *HandshakeSynAck `protobuf:"bytes,2,opt,name=syn_ack,json=synAck,proto3,oneof" json:"syn_ack,omitempty"` +} +type Handshake_Ack struct { + Ack *HandshakeAck `protobuf:"bytes,3,opt,name=ack,proto3,oneof" json:"ack,omitempty"` } -var xxx_messageInfo_Syn proto.InternalMessageInfo +func (*Handshake_Syn) isHandshake_Payload() {} +func (*Handshake_SynAck) isHandshake_Payload() {} +func (*Handshake_Ack) isHandshake_Payload() {} -func (m *Syn) GetObservedUnderlay() []byte { +func (m *Handshake) GetPayload() isHandshake_Payload { if m != nil { - return m.ObservedUnderlay + return m.Payload + } + return nil +} + +func (m *Handshake) GetSyn() *HandshakeSyn { + if x, ok := m.GetPayload().(*Handshake_Syn); ok { + return x.Syn + } + return nil +} + +func (m *Handshake) GetSynAck() *HandshakeSynAck { + if x, ok := m.GetPayload().(*Handshake_SynAck); ok { + return x.SynAck + } + return nil +} + +func (m *Handshake) GetAck() *HandshakeAck { + if x, ok := m.GetPayload().(*Handshake_Ack); ok { + return x.Ack } return nil } -type Ack struct { - Address *BzzAddress `protobuf:"bytes,1,opt,name=Address,proto3" json:"Address,omitempty"` - NetworkID uint64 `protobuf:"varint,2,opt,name=NetworkID,proto3" json:"NetworkID,omitempty"` - FullNode bool `protobuf:"varint,3,opt,name=FullNode,proto3" json:"FullNode,omitempty"` - Nonce []byte `protobuf:"bytes,4,opt,name=Nonce,proto3" json:"Nonce,omitempty"` - WelcomeMessage string `protobuf:"bytes,99,opt,name=WelcomeMessage,proto3" json:"WelcomeMessage,omitempty"` +// XXX_OneofWrappers is for the internal use of the proto package. +func (*Handshake) XXX_OneofWrappers() []interface{} { + return []interface{}{ + (*Handshake_Syn)(nil), + (*Handshake_SynAck)(nil), + (*Handshake_Ack)(nil), + } +} + +type HandshakeSyn struct { + ObservedUnderlay []byte `protobuf:"bytes,1,opt,name=ObservedUnderlay,proto3" json:"ObservedUnderlay,omitempty"` } -func (m *Ack) Reset() { *m = Ack{} } -func (m *Ack) String() string { return proto.CompactTextString(m) } -func (*Ack) ProtoMessage() {} -func (*Ack) Descriptor() ([]byte, []int) { +func (m *HandshakeSyn) Reset() { *m = HandshakeSyn{} } +func (m *HandshakeSyn) String() string { return proto.CompactTextString(m) } +func (*HandshakeSyn) ProtoMessage() {} +func (*HandshakeSyn) Descriptor() ([]byte, []int) { return fileDescriptor_a77305914d5d202f, []int{1} } -func (m *Ack) XXX_Unmarshal(b []byte) error { +func (m *HandshakeSyn) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *HandshakeSyn) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_HandshakeSyn.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *HandshakeSyn) XXX_Merge(src proto.Message) { + xxx_messageInfo_HandshakeSyn.Merge(m, src) +} +func (m *HandshakeSyn) XXX_Size() int { + return m.Size() +} +func (m *HandshakeSyn) XXX_DiscardUnknown() { + xxx_messageInfo_HandshakeSyn.DiscardUnknown(m) +} + +var xxx_messageInfo_HandshakeSyn proto.InternalMessageInfo + +func (m *HandshakeSyn) GetObservedUnderlay() []byte { + if m != nil { + return m.ObservedUnderlay + } + return nil +} + +type HandshakeSynAck struct { + ObservedUnderlay []byte `protobuf:"bytes,1,opt,name=ObservedUnderlay,proto3" json:"ObservedUnderlay,omitempty"` + Address *BzzAddress `protobuf:"bytes,2,opt,name=Address,proto3" json:"Address,omitempty"` + NetworkID uint64 `protobuf:"varint,3,opt,name=NetworkID,proto3" json:"NetworkID,omitempty"` + Capabilities *Capabilities `protobuf:"bytes,4,opt,name=capabilities,proto3" json:"capabilities,omitempty"` + Nonce []byte `protobuf:"bytes,5,opt,name=Nonce,proto3" json:"Nonce,omitempty"` + WelcomeMessage string `protobuf:"bytes,99,opt,name=WelcomeMessage,proto3" json:"WelcomeMessage,omitempty"` +} + +func (m *HandshakeSynAck) Reset() { *m = HandshakeSynAck{} } +func (m *HandshakeSynAck) String() string { return proto.CompactTextString(m) } +func (*HandshakeSynAck) ProtoMessage() {} +func (*HandshakeSynAck) Descriptor() ([]byte, []int) { + return fileDescriptor_a77305914d5d202f, []int{2} +} +func (m *HandshakeSynAck) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } -func (m *Ack) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { +func (m *HandshakeSynAck) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { - return xxx_messageInfo_Ack.Marshal(b, m, deterministic) + return xxx_messageInfo_HandshakeSynAck.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) @@ -95,70 +195,80 @@ func (m *Ack) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return b[:n], nil } } -func (m *Ack) XXX_Merge(src proto.Message) { - xxx_messageInfo_Ack.Merge(m, src) +func (m *HandshakeSynAck) XXX_Merge(src proto.Message) { + xxx_messageInfo_HandshakeSynAck.Merge(m, src) } -func (m *Ack) XXX_Size() int { +func (m *HandshakeSynAck) XXX_Size() int { return m.Size() } -func (m *Ack) XXX_DiscardUnknown() { - xxx_messageInfo_Ack.DiscardUnknown(m) +func (m *HandshakeSynAck) XXX_DiscardUnknown() { + xxx_messageInfo_HandshakeSynAck.DiscardUnknown(m) } -var xxx_messageInfo_Ack proto.InternalMessageInfo +var xxx_messageInfo_HandshakeSynAck proto.InternalMessageInfo + +func (m *HandshakeSynAck) GetObservedUnderlay() []byte { + if m != nil { + return m.ObservedUnderlay + } + return nil +} -func (m *Ack) GetAddress() *BzzAddress { +func (m *HandshakeSynAck) GetAddress() *BzzAddress { if m != nil { return m.Address } return nil } -func (m *Ack) GetNetworkID() uint64 { +func (m *HandshakeSynAck) GetNetworkID() uint64 { if m != nil { return m.NetworkID } return 0 } -func (m *Ack) GetFullNode() bool { +func (m *HandshakeSynAck) GetCapabilities() *Capabilities { if m != nil { - return m.FullNode + return m.Capabilities } - return false + return nil } -func (m *Ack) GetNonce() []byte { +func (m *HandshakeSynAck) GetNonce() []byte { if m != nil { return m.Nonce } return nil } -func (m *Ack) GetWelcomeMessage() string { +func (m *HandshakeSynAck) GetWelcomeMessage() string { if m != nil { return m.WelcomeMessage } return "" } -type SynAck struct { - Syn *Syn `protobuf:"bytes,1,opt,name=Syn,proto3" json:"Syn,omitempty"` - Ack *Ack `protobuf:"bytes,2,opt,name=Ack,proto3" json:"Ack,omitempty"` +type HandshakeAck struct { + Address *BzzAddress `protobuf:"bytes,1,opt,name=Address,proto3" json:"Address,omitempty"` + NetworkID uint64 `protobuf:"varint,2,opt,name=NetworkID,proto3" json:"NetworkID,omitempty"` + Capabilities *Capabilities `protobuf:"bytes,3,opt,name=capabilities,proto3" json:"capabilities,omitempty"` + Nonce []byte `protobuf:"bytes,4,opt,name=Nonce,proto3" json:"Nonce,omitempty"` + WelcomeMessage string `protobuf:"bytes,99,opt,name=WelcomeMessage,proto3" json:"WelcomeMessage,omitempty"` } -func (m *SynAck) Reset() { *m = SynAck{} } -func (m *SynAck) String() string { return proto.CompactTextString(m) } -func (*SynAck) ProtoMessage() {} -func (*SynAck) Descriptor() ([]byte, []int) { - return fileDescriptor_a77305914d5d202f, []int{2} +func (m *HandshakeAck) Reset() { *m = HandshakeAck{} } +func (m *HandshakeAck) String() string { return proto.CompactTextString(m) } +func (*HandshakeAck) ProtoMessage() {} +func (*HandshakeAck) Descriptor() ([]byte, []int) { + return fileDescriptor_a77305914d5d202f, []int{3} } -func (m *SynAck) XXX_Unmarshal(b []byte) error { +func (m *HandshakeAck) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } -func (m *SynAck) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { +func (m *HandshakeAck) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { - return xxx_messageInfo_SynAck.Marshal(b, m, deterministic) + return xxx_messageInfo_HandshakeAck.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) @@ -168,32 +278,105 @@ func (m *SynAck) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return b[:n], nil } } -func (m *SynAck) XXX_Merge(src proto.Message) { - xxx_messageInfo_SynAck.Merge(m, src) +func (m *HandshakeAck) XXX_Merge(src proto.Message) { + xxx_messageInfo_HandshakeAck.Merge(m, src) } -func (m *SynAck) XXX_Size() int { +func (m *HandshakeAck) XXX_Size() int { return m.Size() } -func (m *SynAck) XXX_DiscardUnknown() { - xxx_messageInfo_SynAck.DiscardUnknown(m) +func (m *HandshakeAck) XXX_DiscardUnknown() { + xxx_messageInfo_HandshakeAck.DiscardUnknown(m) } -var xxx_messageInfo_SynAck proto.InternalMessageInfo +var xxx_messageInfo_HandshakeAck proto.InternalMessageInfo -func (m *SynAck) GetSyn() *Syn { +func (m *HandshakeAck) GetAddress() *BzzAddress { if m != nil { - return m.Syn + return m.Address + } + return nil +} + +func (m *HandshakeAck) GetNetworkID() uint64 { + if m != nil { + return m.NetworkID + } + return 0 +} + +func (m *HandshakeAck) GetCapabilities() *Capabilities { + if m != nil { + return m.Capabilities } return nil } -func (m *SynAck) GetAck() *Ack { +func (m *HandshakeAck) GetNonce() []byte { if m != nil { - return m.Ack + return m.Nonce } return nil } +func (m *HandshakeAck) GetWelcomeMessage() string { + if m != nil { + return m.WelcomeMessage + } + return "" +} + +type Capabilities struct { + FullNode bool `protobuf:"varint,1,opt,name=full_node,json=fullNode,proto3" json:"full_node,omitempty"` + TraceHeaders bool `protobuf:"varint,2,opt,name=trace_headers,json=traceHeaders,proto3" json:"trace_headers,omitempty"` +} + +func (m *Capabilities) Reset() { *m = Capabilities{} } +func (m *Capabilities) String() string { return proto.CompactTextString(m) } +func (*Capabilities) ProtoMessage() {} +func (*Capabilities) Descriptor() ([]byte, []int) { + return fileDescriptor_a77305914d5d202f, []int{4} +} +func (m *Capabilities) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *Capabilities) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_Capabilities.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *Capabilities) XXX_Merge(src proto.Message) { + xxx_messageInfo_Capabilities.Merge(m, src) +} +func (m *Capabilities) XXX_Size() int { + return m.Size() +} +func (m *Capabilities) XXX_DiscardUnknown() { + xxx_messageInfo_Capabilities.DiscardUnknown(m) +} + +var xxx_messageInfo_Capabilities proto.InternalMessageInfo + +func (m *Capabilities) GetFullNode() bool { + if m != nil { + return m.FullNode + } + return false +} + +func (m *Capabilities) GetTraceHeaders() bool { + if m != nil { + return m.TraceHeaders + } + return false +} + type BzzAddress struct { Underlay []byte `protobuf:"bytes,1,opt,name=Underlay,proto3" json:"Underlay,omitempty"` Signature []byte `protobuf:"bytes,2,opt,name=Signature,proto3" json:"Signature,omitempty"` @@ -204,7 +387,7 @@ func (m *BzzAddress) Reset() { *m = BzzAddress{} } func (m *BzzAddress) String() string { return proto.CompactTextString(m) } func (*BzzAddress) ProtoMessage() {} func (*BzzAddress) Descriptor() ([]byte, []int) { - return fileDescriptor_a77305914d5d202f, []int{3} + return fileDescriptor_a77305914d5d202f, []int{5} } func (m *BzzAddress) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -255,39 +438,49 @@ func (m *BzzAddress) GetOverlay() []byte { } func init() { - proto.RegisterType((*Syn)(nil), "handshake.Syn") - proto.RegisterType((*Ack)(nil), "handshake.Ack") - proto.RegisterType((*SynAck)(nil), "handshake.SynAck") + proto.RegisterType((*Handshake)(nil), "handshake.Handshake") + proto.RegisterType((*HandshakeSyn)(nil), "handshake.HandshakeSyn") + proto.RegisterType((*HandshakeSynAck)(nil), "handshake.HandshakeSynAck") + proto.RegisterType((*HandshakeAck)(nil), "handshake.HandshakeAck") + proto.RegisterType((*Capabilities)(nil), "handshake.Capabilities") proto.RegisterType((*BzzAddress)(nil), "handshake.BzzAddress") } func init() { proto.RegisterFile("handshake.proto", fileDescriptor_a77305914d5d202f) } var fileDescriptor_a77305914d5d202f = []byte{ - // 318 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x64, 0x91, 0xcd, 0x4a, 0xc3, 0x40, - 0x14, 0x85, 0x3b, 0x4d, 0xed, 0xcf, 0xb5, 0x54, 0x19, 0x14, 0x82, 0x94, 0x10, 0xb2, 0x90, 0xe0, - 0xa2, 0xa2, 0x3e, 0x41, 0x8b, 0x08, 0x82, 0xb6, 0x30, 0x41, 0x04, 0x57, 0xa6, 0x99, 0x4b, 0x2b, - 0x89, 0x33, 0x65, 0xa6, 0xad, 0xa4, 0x4f, 0xe1, 0x93, 0xf8, 0x1c, 0x2e, 0xbb, 0x74, 0x29, 0xed, - 0x8b, 0x48, 0xa6, 0x3f, 0xd1, 0xba, 0x3c, 0xe7, 0x9e, 0x99, 0xf9, 0xce, 0x1d, 0x38, 0x18, 0x86, - 0x82, 0xeb, 0x61, 0x18, 0x63, 0x6b, 0xa4, 0xe4, 0x58, 0xd2, 0xda, 0xd6, 0xf0, 0x2e, 0xc0, 0x0a, - 0x52, 0x41, 0xcf, 0xe0, 0xb0, 0xd7, 0xd7, 0xa8, 0xa6, 0xc8, 0x1f, 0x04, 0x47, 0x95, 0x84, 0xa9, - 0x4d, 0x5c, 0xe2, 0xd7, 0xd9, 0x3f, 0xdf, 0xfb, 0x20, 0x60, 0xb5, 0xa3, 0x98, 0x9e, 0x43, 0xa5, - 0xcd, 0xb9, 0x42, 0xad, 0x4d, 0x74, 0xff, 0xf2, 0xb8, 0x95, 0x3f, 0xd4, 0x99, 0xcd, 0xd6, 0x43, - 0xb6, 0x49, 0xd1, 0x26, 0xd4, 0xba, 0x38, 0x7e, 0x93, 0x2a, 0xbe, 0xbd, 0xb6, 0x8b, 0x2e, 0xf1, - 0x4b, 0x2c, 0x37, 0xe8, 0x09, 0x54, 0x6f, 0x26, 0x49, 0xd2, 0x95, 0x1c, 0x6d, 0xcb, 0x25, 0x7e, - 0x95, 0x6d, 0x35, 0x3d, 0x82, 0xbd, 0xae, 0x14, 0x11, 0xda, 0x25, 0xc3, 0xb4, 0x12, 0xf4, 0x14, - 0x1a, 0x8f, 0x98, 0x44, 0xf2, 0x15, 0xef, 0x51, 0xeb, 0x70, 0x80, 0x76, 0xe4, 0x12, 0xbf, 0xc6, - 0x76, 0x5c, 0xef, 0x0e, 0xca, 0x41, 0x2a, 0x32, 0x64, 0xd7, 0xb4, 0x5d, 0xe3, 0x36, 0x7e, 0xe1, - 0x06, 0xa9, 0x60, 0x66, 0x11, 0xae, 0xe9, 0x66, 0xe8, 0xfe, 0x26, 0xda, 0x51, 0xcc, 0xb2, 0x91, - 0xf7, 0x0c, 0x90, 0x97, 0xcb, 0xa8, 0x77, 0x16, 0xb6, 0xd5, 0x59, 0xdf, 0xe0, 0x65, 0x20, 0xc2, - 0xf1, 0x44, 0xa1, 0xb9, 0xb1, 0xce, 0x72, 0x83, 0xda, 0x50, 0xe9, 0x4d, 0x57, 0x07, 0x2d, 0x33, - 0xdb, 0xc8, 0x4e, 0xf3, 0x73, 0xe1, 0x90, 0xf9, 0xc2, 0x21, 0xdf, 0x0b, 0x87, 0xbc, 0x2f, 0x9d, - 0xc2, 0x7c, 0xe9, 0x14, 0xbe, 0x96, 0x4e, 0xe1, 0xa9, 0x38, 0xea, 0xf7, 0xcb, 0xe6, 0x0f, 0xaf, - 0x7e, 0x02, 0x00, 0x00, 0xff, 0xff, 0x87, 0xa7, 0x49, 0x00, 0xd6, 0x01, 0x00, 0x00, -} - -func (m *Syn) Marshal() (dAtA []byte, err error) { + // 438 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x53, 0xc1, 0x6e, 0xd3, 0x40, + 0x10, 0xf5, 0x26, 0x69, 0x13, 0x0f, 0x86, 0xa2, 0x15, 0x08, 0xab, 0x54, 0x56, 0x65, 0x24, 0x54, + 0x81, 0x54, 0x24, 0x10, 0x17, 0x38, 0x35, 0x70, 0x08, 0x07, 0x52, 0xb4, 0x15, 0x42, 0xe2, 0x12, + 0xd6, 0xde, 0xa1, 0xb1, 0x6c, 0x76, 0xad, 0x5d, 0xb7, 0xc8, 0xfd, 0x01, 0xae, 0xfc, 0x05, 0xbf, + 0xc2, 0xb1, 0xc7, 0x1e, 0x51, 0xf2, 0x23, 0xc8, 0xeb, 0x26, 0x36, 0xa9, 0x2a, 0xd4, 0xdc, 0x3c, + 0x6f, 0xde, 0x9b, 0x99, 0xf7, 0x6c, 0xc3, 0xd6, 0x94, 0x4b, 0x61, 0xa6, 0x3c, 0xc5, 0xfd, 0x5c, + 0xab, 0x42, 0x51, 0x77, 0x09, 0x84, 0xbf, 0x08, 0xb8, 0xa3, 0x45, 0x45, 0x9f, 0x42, 0xd7, 0x94, + 0xd2, 0x27, 0xbb, 0x64, 0xef, 0xd6, 0xf3, 0x07, 0xfb, 0x8d, 0x6e, 0x49, 0x39, 0x2a, 0xe5, 0xc8, + 0x61, 0x15, 0x8b, 0xbe, 0x84, 0xbe, 0x29, 0xe5, 0x84, 0xc7, 0xa9, 0xdf, 0xb1, 0x82, 0xed, 0x6b, + 0x04, 0x07, 0x71, 0x3a, 0x72, 0xd8, 0xa6, 0xb1, 0x4f, 0xd5, 0x8e, 0x4a, 0xd2, 0xbd, 0x7e, 0x47, + 0xcd, 0xaf, 0x58, 0x43, 0x17, 0xfa, 0x39, 0x2f, 0x33, 0xc5, 0x45, 0xf8, 0x0a, 0xbc, 0xf6, 0x50, + 0xfa, 0x04, 0xee, 0x1e, 0x46, 0x06, 0xf5, 0x29, 0x8a, 0x8f, 0x52, 0xa0, 0xce, 0x78, 0x69, 0x0f, + 0xf7, 0xd8, 0x15, 0x3c, 0xfc, 0xd1, 0x81, 0xad, 0x95, 0x8b, 0x6e, 0xa2, 0xa7, 0xcf, 0xa0, 0x7f, + 0x20, 0x84, 0x46, 0x63, 0x2e, 0xad, 0xde, 0x6f, 0xdd, 0x3d, 0x3c, 0x3b, 0xbb, 0x6c, 0xb2, 0x05, + 0x8b, 0xee, 0x80, 0x3b, 0xc6, 0xe2, 0xbb, 0xd2, 0xe9, 0xbb, 0xb7, 0xd6, 0x6a, 0x8f, 0x35, 0x00, + 0x7d, 0x0d, 0x5e, 0xcc, 0x73, 0x1e, 0x25, 0x59, 0x52, 0x24, 0x68, 0xfc, 0xde, 0x95, 0x2c, 0xde, + 0xb4, 0xda, 0xec, 0x1f, 0x32, 0xbd, 0x07, 0x1b, 0x63, 0x25, 0x63, 0xf4, 0x37, 0xec, 0xb1, 0x75, + 0x41, 0x1f, 0xc3, 0x9d, 0x4f, 0x98, 0xc5, 0xea, 0x1b, 0xbe, 0x47, 0x63, 0xf8, 0x31, 0xfa, 0xf1, + 0x2e, 0xd9, 0x73, 0xd9, 0x0a, 0x1a, 0x5e, 0x90, 0x56, 0x8c, 0x55, 0x0c, 0x2d, 0x6b, 0xe4, 0xe6, + 0xd6, 0x3a, 0xff, 0xb3, 0xd6, 0x5d, 0xcb, 0x5a, 0x6f, 0x1d, 0x6b, 0x1f, 0xc0, 0x6b, 0xcf, 0xa6, + 0x0f, 0xc1, 0xfd, 0x7a, 0x92, 0x65, 0x13, 0xa9, 0x04, 0x5a, 0x6f, 0x03, 0x36, 0xa8, 0x80, 0xb1, + 0x12, 0x48, 0x1f, 0xc1, 0xed, 0x42, 0xf3, 0x18, 0x27, 0x53, 0xe4, 0x02, 0x75, 0xfd, 0x5e, 0x07, + 0xcc, 0xb3, 0xe0, 0xa8, 0xc6, 0xc2, 0x2f, 0x00, 0x4d, 0x02, 0x74, 0x1b, 0x06, 0x2b, 0x1f, 0xca, + 0xb2, 0xae, 0x42, 0x39, 0x4a, 0x8e, 0x25, 0x2f, 0x4e, 0x34, 0xda, 0x51, 0x1e, 0x6b, 0x00, 0xea, + 0x43, 0xff, 0xf0, 0xb4, 0x16, 0x76, 0x6d, 0x6f, 0x51, 0x0e, 0x77, 0x7e, 0xcf, 0x02, 0x72, 0x3e, + 0x0b, 0xc8, 0x9f, 0x59, 0x40, 0x7e, 0xce, 0x03, 0xe7, 0x7c, 0x1e, 0x38, 0x17, 0xf3, 0xc0, 0xf9, + 0xdc, 0xc9, 0xa3, 0x68, 0xd3, 0xfe, 0xae, 0x2f, 0xfe, 0x06, 0x00, 0x00, 0xff, 0xff, 0xb7, 0x0a, + 0x0e, 0x1c, 0xc1, 0x03, 0x00, 0x00, +} + +func (m *Handshake) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) @@ -297,12 +490,107 @@ func (m *Syn) Marshal() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *Syn) MarshalTo(dAtA []byte) (int, error) { +func (m *Handshake) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } -func (m *Syn) MarshalToSizedBuffer(dAtA []byte) (int, error) { +func (m *Handshake) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Payload != nil { + { + size := m.Payload.Size() + i -= size + if _, err := m.Payload.MarshalTo(dAtA[i:]); err != nil { + return 0, err + } + } + } + return len(dAtA) - i, nil +} + +func (m *Handshake_Syn) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Handshake_Syn) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + if m.Syn != nil { + { + size, err := m.Syn.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintHandshake(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} +func (m *Handshake_SynAck) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Handshake_SynAck) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + if m.SynAck != nil { + { + size, err := m.SynAck.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintHandshake(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x12 + } + return len(dAtA) - i, nil +} +func (m *Handshake_Ack) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Handshake_Ack) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + if m.Ack != nil { + { + size, err := m.Ack.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintHandshake(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x1a + } + return len(dAtA) - i, nil +} +func (m *HandshakeSyn) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *HandshakeSyn) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *HandshakeSyn) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int @@ -317,7 +605,7 @@ func (m *Syn) MarshalToSizedBuffer(dAtA []byte) (int, error) { return len(dAtA) - i, nil } -func (m *Ack) Marshal() (dAtA []byte, err error) { +func (m *HandshakeSynAck) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) @@ -327,12 +615,12 @@ func (m *Ack) Marshal() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *Ack) MarshalTo(dAtA []byte) (int, error) { +func (m *HandshakeSynAck) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } -func (m *Ack) MarshalToSizedBuffer(dAtA []byte) (int, error) { +func (m *HandshakeSynAck) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int @@ -351,22 +639,24 @@ func (m *Ack) MarshalToSizedBuffer(dAtA []byte) (int, error) { copy(dAtA[i:], m.Nonce) i = encodeVarintHandshake(dAtA, i, uint64(len(m.Nonce))) i-- - dAtA[i] = 0x22 + dAtA[i] = 0x2a } - if m.FullNode { - i-- - if m.FullNode { - dAtA[i] = 1 - } else { - dAtA[i] = 0 + if m.Capabilities != nil { + { + size, err := m.Capabilities.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintHandshake(dAtA, i, uint64(size)) } i-- - dAtA[i] = 0x18 + dAtA[i] = 0x22 } if m.NetworkID != 0 { i = encodeVarintHandshake(dAtA, i, uint64(m.NetworkID)) i-- - dAtA[i] = 0x10 + dAtA[i] = 0x18 } if m.Address != nil { { @@ -378,12 +668,19 @@ func (m *Ack) MarshalToSizedBuffer(dAtA []byte) (int, error) { i = encodeVarintHandshake(dAtA, i, uint64(size)) } i-- + dAtA[i] = 0x12 + } + if len(m.ObservedUnderlay) > 0 { + i -= len(m.ObservedUnderlay) + copy(dAtA[i:], m.ObservedUnderlay) + i = encodeVarintHandshake(dAtA, i, uint64(len(m.ObservedUnderlay))) + i-- dAtA[i] = 0xa } return len(dAtA) - i, nil } -func (m *SynAck) Marshal() (dAtA []byte, err error) { +func (m *HandshakeAck) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) @@ -393,19 +690,35 @@ func (m *SynAck) Marshal() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *SynAck) MarshalTo(dAtA []byte) (int, error) { +func (m *HandshakeAck) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } -func (m *SynAck) MarshalToSizedBuffer(dAtA []byte) (int, error) { +func (m *HandshakeAck) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int _ = l - if m.Ack != nil { + if len(m.WelcomeMessage) > 0 { + i -= len(m.WelcomeMessage) + copy(dAtA[i:], m.WelcomeMessage) + i = encodeVarintHandshake(dAtA, i, uint64(len(m.WelcomeMessage))) + i-- + dAtA[i] = 0x6 + i-- + dAtA[i] = 0x9a + } + if len(m.Nonce) > 0 { + i -= len(m.Nonce) + copy(dAtA[i:], m.Nonce) + i = encodeVarintHandshake(dAtA, i, uint64(len(m.Nonce))) + i-- + dAtA[i] = 0x22 + } + if m.Capabilities != nil { { - size, err := m.Ack.MarshalToSizedBuffer(dAtA[:i]) + size, err := m.Capabilities.MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } @@ -413,11 +726,16 @@ func (m *SynAck) MarshalToSizedBuffer(dAtA []byte) (int, error) { i = encodeVarintHandshake(dAtA, i, uint64(size)) } i-- - dAtA[i] = 0x12 + dAtA[i] = 0x1a } - if m.Syn != nil { + if m.NetworkID != 0 { + i = encodeVarintHandshake(dAtA, i, uint64(m.NetworkID)) + i-- + dAtA[i] = 0x10 + } + if m.Address != nil { { - size, err := m.Syn.MarshalToSizedBuffer(dAtA[:i]) + size, err := m.Address.MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } @@ -430,6 +748,49 @@ func (m *SynAck) MarshalToSizedBuffer(dAtA []byte) (int, error) { return len(dAtA) - i, nil } +func (m *Capabilities) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Capabilities) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Capabilities) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.TraceHeaders { + i-- + if m.TraceHeaders { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i-- + dAtA[i] = 0x10 + } + if m.FullNode { + i-- + if m.FullNode { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + func (m *BzzAddress) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) @@ -485,7 +846,55 @@ func encodeVarintHandshake(dAtA []byte, offset int, v uint64) int { dAtA[offset] = uint8(v) return base } -func (m *Syn) Size() (n int) { +func (m *Handshake) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Payload != nil { + n += m.Payload.Size() + } + return n +} + +func (m *Handshake_Syn) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Syn != nil { + l = m.Syn.Size() + n += 1 + l + sovHandshake(uint64(l)) + } + return n +} +func (m *Handshake_SynAck) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.SynAck != nil { + l = m.SynAck.Size() + n += 1 + l + sovHandshake(uint64(l)) + } + return n +} +func (m *Handshake_Ack) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Ack != nil { + l = m.Ack.Size() + n += 1 + l + sovHandshake(uint64(l)) + } + return n +} +func (m *HandshakeSyn) Size() (n int) { if m == nil { return 0 } @@ -498,12 +907,16 @@ func (m *Syn) Size() (n int) { return n } -func (m *Ack) Size() (n int) { +func (m *HandshakeSynAck) Size() (n int) { if m == nil { return 0 } var l int _ = l + l = len(m.ObservedUnderlay) + if l > 0 { + n += 1 + l + sovHandshake(uint64(l)) + } if m.Address != nil { l = m.Address.Size() n += 1 + l + sovHandshake(uint64(l)) @@ -511,8 +924,9 @@ func (m *Ack) Size() (n int) { if m.NetworkID != 0 { n += 1 + sovHandshake(uint64(m.NetworkID)) } - if m.FullNode { - n += 2 + if m.Capabilities != nil { + l = m.Capabilities.Size() + n += 1 + l + sovHandshake(uint64(l)) } l = len(m.Nonce) if l > 0 { @@ -525,20 +939,46 @@ func (m *Ack) Size() (n int) { return n } -func (m *SynAck) Size() (n int) { +func (m *HandshakeAck) Size() (n int) { if m == nil { return 0 } var l int _ = l - if m.Syn != nil { - l = m.Syn.Size() + if m.Address != nil { + l = m.Address.Size() n += 1 + l + sovHandshake(uint64(l)) } - if m.Ack != nil { - l = m.Ack.Size() + if m.NetworkID != 0 { + n += 1 + sovHandshake(uint64(m.NetworkID)) + } + if m.Capabilities != nil { + l = m.Capabilities.Size() + n += 1 + l + sovHandshake(uint64(l)) + } + l = len(m.Nonce) + if l > 0 { n += 1 + l + sovHandshake(uint64(l)) } + l = len(m.WelcomeMessage) + if l > 0 { + n += 2 + l + sovHandshake(uint64(l)) + } + return n +} + +func (m *Capabilities) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.FullNode { + n += 2 + } + if m.TraceHeaders { + n += 2 + } return n } @@ -560,16 +1000,174 @@ func (m *BzzAddress) Size() (n int) { if l > 0 { n += 1 + l + sovHandshake(uint64(l)) } - return n -} + return n +} + +func sovHandshake(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozHandshake(x uint64) (n int) { + return sovHandshake(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *Handshake) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Handshake: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Handshake: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Syn", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthHandshake + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthHandshake + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + v := &HandshakeSyn{} + if err := v.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + m.Payload = &Handshake_Syn{v} + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field SynAck", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthHandshake + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthHandshake + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + v := &HandshakeSynAck{} + if err := v.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + m.Payload = &Handshake_SynAck{v} + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Ack", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthHandshake + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthHandshake + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + v := &HandshakeAck{} + if err := v.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + m.Payload = &Handshake_Ack{v} + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipHandshake(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthHandshake + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthHandshake + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } -func sovHandshake(x uint64) (n int) { - return (math_bits.Len64(x|1) + 6) / 7 -} -func sozHandshake(x uint64) (n int) { - return sovHandshake(uint64((x << 1) ^ uint64((int64(x) >> 63)))) + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil } -func (m *Syn) Unmarshal(dAtA []byte) error { +func (m *HandshakeSyn) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -592,10 +1190,10 @@ func (m *Syn) Unmarshal(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: Syn: wiretype end group for non-group") + return fmt.Errorf("proto: HandshakeSyn: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: Syn: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: HandshakeSyn: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: @@ -656,7 +1254,7 @@ func (m *Syn) Unmarshal(dAtA []byte) error { } return nil } -func (m *Ack) Unmarshal(dAtA []byte) error { +func (m *HandshakeSynAck) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -679,13 +1277,47 @@ func (m *Ack) Unmarshal(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: Ack: wiretype end group for non-group") + return fmt.Errorf("proto: HandshakeSynAck: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: Ack: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: HandshakeSynAck: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ObservedUnderlay", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthHandshake + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthHandshake + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ObservedUnderlay = append(m.ObservedUnderlay[:0], dAtA[iNdEx:postIndex]...) + if m.ObservedUnderlay == nil { + m.ObservedUnderlay = []byte{} + } + iNdEx = postIndex + case 2: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field Address", wireType) } @@ -721,7 +1353,7 @@ func (m *Ack) Unmarshal(dAtA []byte) error { return err } iNdEx = postIndex - case 2: + case 3: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field NetworkID", wireType) } @@ -740,11 +1372,11 @@ func (m *Ack) Unmarshal(dAtA []byte) error { break } } - case 3: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field FullNode", wireType) + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Capabilities", wireType) } - var v int + var msglen int for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowHandshake @@ -754,13 +1386,29 @@ func (m *Ack) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - v |= int(b&0x7F) << shift + msglen |= int(b&0x7F) << shift if b < 0x80 { break } } - m.FullNode = bool(v != 0) - case 4: + if msglen < 0 { + return ErrInvalidLengthHandshake + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthHandshake + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Capabilities == nil { + m.Capabilities = &Capabilities{} + } + if err := m.Capabilities.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 5: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field Nonce", wireType) } @@ -850,7 +1498,7 @@ func (m *Ack) Unmarshal(dAtA []byte) error { } return nil } -func (m *SynAck) Unmarshal(dAtA []byte) error { +func (m *HandshakeAck) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -873,15 +1521,15 @@ func (m *SynAck) Unmarshal(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: SynAck: wiretype end group for non-group") + return fmt.Errorf("proto: HandshakeAck: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: SynAck: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: HandshakeAck: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Syn", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field Address", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -908,16 +1556,35 @@ func (m *SynAck) Unmarshal(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - if m.Syn == nil { - m.Syn = &Syn{} + if m.Address == nil { + m.Address = &BzzAddress{} } - if err := m.Syn.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + if err := m.Address.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field NetworkID", wireType) + } + m.NetworkID = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.NetworkID |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 3: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Ack", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field Capabilities", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -944,13 +1611,172 @@ func (m *SynAck) Unmarshal(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - if m.Ack == nil { - m.Ack = &Ack{} + if m.Capabilities == nil { + m.Capabilities = &Capabilities{} } - if err := m.Ack.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + if err := m.Capabilities.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Nonce", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthHandshake + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthHandshake + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Nonce = append(m.Nonce[:0], dAtA[iNdEx:postIndex]...) + if m.Nonce == nil { + m.Nonce = []byte{} + } + iNdEx = postIndex + case 99: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field WelcomeMessage", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthHandshake + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthHandshake + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.WelcomeMessage = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipHandshake(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthHandshake + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthHandshake + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *Capabilities) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Capabilities: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Capabilities: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field FullNode", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + m.FullNode = bool(v != 0) + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field TraceHeaders", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + m.TraceHeaders = bool(v != 0) default: iNdEx = preIndex skippy, err := skipHandshake(dAtA[iNdEx:]) diff --git a/pkg/p2p/libp2p/internal/handshake/pb/handshake.proto b/pkg/p2p/libp2p/internal/handshake/pb/handshake.proto index a3811a895b8..5fbaf4387e2 100644 --- a/pkg/p2p/libp2p/internal/handshake/pb/handshake.proto +++ b/pkg/p2p/libp2p/internal/handshake/pb/handshake.proto @@ -8,25 +8,42 @@ package handshake; option go_package = "pb"; -message Syn { +message Handshake { + oneof payload { + HandshakeSyn syn = 1; + HandshakeSynAck syn_ack = 2; + HandshakeAck ack = 3; + } +} + +message HandshakeSyn { + bytes ObservedUnderlay = 1; +} + +message HandshakeSynAck { bytes ObservedUnderlay = 1; + BzzAddress Address = 2; + uint64 NetworkID = 3; + Capabilities capabilities = 4; + bytes Nonce = 5; + string WelcomeMessage = 99; } -message Ack { +message HandshakeAck { BzzAddress Address = 1; uint64 NetworkID = 2; - bool FullNode = 3; + Capabilities capabilities = 3; bytes Nonce = 4; - string WelcomeMessage = 99; + string WelcomeMessage = 99; } -message SynAck { - Syn Syn = 1; - Ack Ack = 2; +message Capabilities { + bool full_node = 1; + bool trace_headers = 2; } message BzzAddress { bytes Underlay = 1; bytes Signature = 2; bytes Overlay = 3; -} +} \ No newline at end of file diff --git a/pkg/p2p/libp2p/libp2p.go b/pkg/p2p/libp2p/libp2p.go index 4d4c3773f92..4be731017a0 100644 --- a/pkg/p2p/libp2p/libp2p.go +++ b/pkg/p2p/libp2p/libp2p.go @@ -44,6 +44,7 @@ import ( basichost "github.com/libp2p/go-libp2p/p2p/host/basic" "github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem" rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" + lp2pswarm "github.com/libp2p/go-libp2p/p2p/net/swarm" libp2pping "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/transport/tcp" @@ -55,7 +56,6 @@ import ( ocprom "contrib.go.opencensus.io/exporter/prometheus" m2 "github.com/ethersphere/bee/v2/pkg/metrics" - rcmgrObs "github.com/libp2p/go-libp2p/p2p/host/resource-manager" "github.com/prometheus/client_golang/prometheus" ) @@ -120,17 +120,18 @@ type lightnodes interface { } type Options struct { - PrivateKey *ecdsa.PrivateKey - NATAddr string - EnableWS bool - FullNode bool - LightNodeLimit int - WelcomeMessage string - Nonce []byte - ValidateOverlay bool - hostFactory func(...libp2p.Option) (host.Host, error) - HeadersRWTimeout time.Duration - Registry *prometheus.Registry + PrivateKey *ecdsa.PrivateKey + NATAddr string + EnableWS bool + FullNode bool + LightNodeLimit int + WelcomeMessage string + Nonce []byte + ValidateOverlay bool + hostFactory func(...libp2p.Option) (host.Host, error) + HeadersRWTimeout time.Duration + Registry *prometheus.Registry + EnableTraceHeaders bool } func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay swarm.Address, addr string, ab addressbook.Putter, storer storage.StateStorer, lightNodes *lightnode.Container, logger log.Logger, tracer *tracing.Tracer, o Options) (*Service, error) { @@ -175,7 +176,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay } if o.Registry != nil { - rcmgrObs.MustRegisterWith(o.Registry) + rcmgr.MustRegisterWith(o.Registry) } _, err = ocprom.NewExporter(ocprom.Options{ @@ -201,7 +202,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay // The resource manager expects a limiter, se we create one from our limits. limiter := rcmgr.NewFixedLimiter(limits) - str, err := rcmgrObs.NewStatsTraceReporter() + str, err := rcmgr.NewStatsTraceReporter() if err != nil { return nil, err } @@ -304,7 +305,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay advertisableAddresser = natAddrResolver } - handshakeService, err := handshake.New(signer, advertisableAddresser, overlay, networkID, o.FullNode, o.Nonce, o.WelcomeMessage, o.ValidateOverlay, h.ID(), logger) + handshakeService, err := handshake.New(signer, advertisableAddresser, overlay, networkID, o.FullNode, o.EnableTraceHeaders, o.Nonce, o.WelcomeMessage, o.ValidateOverlay, h.ID(), logger) if err != nil { return nil, fmt.Errorf("handshake service: %w", err) } @@ -437,7 +438,7 @@ func (s *Service) handleIncoming(stream network.Stream) { return } - if exists := s.peers.addIfNotExists(stream.Conn(), overlay, i.FullNode); exists { + if exists := s.peers.addIfNotExists(stream.Conn(), overlay, i.Capabilities); exists { s.logger.Debug("stream handler: peer already exists", "peer_address", overlay) if err = handshakeStream.FullClose(); err != nil { s.logger.Debug("stream handler: could not close stream", "peer_address", overlay, "error", err) @@ -454,7 +455,7 @@ func (s *Service) handleIncoming(stream network.Stream) { return } - if i.FullNode { + if i.Capabilities.FullNode { err = s.addressbook.Put(i.BzzAddress.Overlay, *i.BzzAddress) if err != nil { s.logger.Debug("stream handler: addressbook put error", "peer_id", peerID, "error", err) @@ -464,7 +465,7 @@ func (s *Service) handleIncoming(stream network.Stream) { } } - peer := p2p.Peer{Address: overlay, FullNode: i.FullNode, EthereumAddress: i.BzzAddress.EthereumAddress} + peer := p2p.Peer{Address: overlay, FullNode: i.Capabilities.FullNode, EthereumAddress: i.BzzAddress.EthereumAddress} s.protocolsmu.RLock() for _, tn := range s.protocols { @@ -480,10 +481,10 @@ func (s *Service) handleIncoming(stream network.Stream) { s.protocolsmu.RUnlock() if s.notifier != nil { - if !i.FullNode { + if !i.Capabilities.FullNode { s.lightNodes.Connected(s.ctx, peer) // light node announces explicitly - if err := s.notifier.Announce(s.ctx, peer.Address, i.FullNode); err != nil { + if err := s.notifier.Announce(s.ctx, peer.Address, i.Capabilities.FullNode); err != nil { s.logger.Debug("stream handler: notifier.Announce failed", "peer", peer.Address, "error", err) } @@ -524,7 +525,7 @@ func (s *Service) handleIncoming(stream network.Stream) { if err := s.notifier.AnnounceTo(s.ctx, addressee, peer, fullnode); err != nil { s.logger.Debug("stream handler: notifier.AnnounceTo failed", "addressee", addressee, "peer", peer, "error", err) } - }(addr, peer.Address, i.FullNode) + }(addr, peer.Address, i.Capabilities.FullNode) return false, false, nil }) } @@ -577,38 +578,52 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) { return } - stream := newStream(streamlibp2p, s.metrics) - - // exchange headers - headersStartTime := time.Now() - ctx, cancel := context.WithTimeout(s.ctx, s.HeadersRWTimeout) - defer cancel() - if err := handleHeaders(ctx, ss.Headler, stream, overlay); err != nil { - s.logger.Debug("handle protocol: handle headers failed", "protocol", p.Name, "version", p.Version, "stream", ss.Name, "peer", overlay, "error", err) - _ = stream.Reset() - return + // Check if BOTH local and remote peer support trace headers + // At this point, handshake has completed and capabilities are known + peerCaps := s.peers.capabilities(overlay) + localSupportsTrace := s.handshakeService.SupportsTraceHeaders() + remoteSupportsTrace := peerCaps != nil && peerCaps.TraceHeaders + bothSupportTrace := localSupportsTrace && remoteSupportsTrace + var stream *stream + var err error + + if bothSupportTrace { + // Both support trace headers - exchange them with timeout + headersStartTime := time.Now() + ctx, cancel := context.WithTimeout(s.ctx, s.HeadersRWTimeout) + defer cancel() + + stream, err = newStreamWithHeaders(streamlibp2p, s.metrics, ctx, ss.Headler, overlay, make(p2p.Headers)) + if err != nil { + s.logger.Debug("handle protocol: handle headers failed", "protocol", p.Name, "version", p.Version, "stream", ss.Name, "peer", overlay, "error", err) + _ = streamlibp2p.Reset() + return + } + s.metrics.HeadersExchangeDuration.Observe(time.Since(headersStartTime).Seconds()) + } else { + // NO header exchange - nothing on wire + stream = newStream(streamlibp2p, s.metrics) } - s.metrics.HeadersExchangeDuration.Observe(time.Since(headersStartTime).Seconds()) - ctx, cancel = context.WithCancel(s.ctx) + ctxStream, cancelStream := context.WithCancel(s.ctx) - s.peers.addStream(peerID, streamlibp2p, cancel) + s.peers.addStream(peerID, streamlibp2p, cancelStream) defer s.peers.removeStream(peerID, streamlibp2p) // tracing: get span tracing context and add it to the context // silently ignore if the peer is not providing tracing - ctx, err := s.tracer.WithContextFromHeaders(ctx, stream.Headers()) + ctxStream, err = s.tracer.WithContextFromHeaders(ctxStream, stream.Headers()) if err != nil && !errors.Is(err, tracing.ErrContextNotFound) { s.logger.Debug("handle protocol: get tracing context failed", "protocol", p.Name, "version", p.Version, "stream", ss.Name, "peer", overlay, "error", err) _ = stream.Reset() return } - logger := tracing.NewLoggerWithTraceID(ctx, s.logger) + logger := tracing.NewLoggerWithTraceID(ctxStream, s.logger) loggerV1 := logger.V(1).Build() s.metrics.HandledStreamCount.Inc() - if err := ss.Handler(ctx, p2p.Peer{Address: overlay, FullNode: full}, stream); err != nil { + if err := ss.Handler(ctxStream, p2p.Peer{Address: overlay, FullNode: full}, stream); err != nil { var de *p2p.DisconnectError if errors.As(err, &de) { loggerV1.Debug("libp2p handler: disconnecting due to disconnect error", "protocol", p.Name, "address", overlay) @@ -758,7 +773,7 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz. return nil, fmt.Errorf("handshake: %w", err) } - if !i.FullNode { + if !i.Capabilities.FullNode { _ = handshakeStream.Reset() _ = s.host.Network().ClosePeer(info.ID) return nil, p2p.ErrDialLightNode @@ -782,7 +797,7 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz. return nil, p2p.ErrPeerBlocklisted } - if exists := s.peers.addIfNotExists(stream.Conn(), overlay, i.FullNode); exists { + if exists := s.peers.addIfNotExists(stream.Conn(), overlay, i.Capabilities); exists { if err := handshakeStream.FullClose(); err != nil { _ = s.Disconnect(overlay, "failed closing handshake stream after connect") return nil, fmt.Errorf("peer exists, full close: %w", err) @@ -796,7 +811,7 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz. return nil, fmt.Errorf("connect full close %w", err) } - if i.FullNode { + if i.Capabilities.FullNode { err = s.addressbook.Put(overlay, *i.BzzAddress) if err != nil { _ = s.Disconnect(overlay, "failed storing peer in addressbook") @@ -807,7 +822,7 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz. s.protocolsmu.RLock() for _, tn := range s.protocols { if tn.ConnectOut != nil { - if err := tn.ConnectOut(ctx, p2p.Peer{Address: overlay, FullNode: i.FullNode, EthereumAddress: i.BzzAddress.EthereumAddress}); err != nil { + if err := tn.ConnectOut(ctx, p2p.Peer{Address: overlay, FullNode: i.Capabilities.FullNode, EthereumAddress: i.BzzAddress.EthereumAddress}); err != nil { s.logger.Debug("connectOut: failed to connect", "protocol", tn.Name, "version", tn.Version, "peer", overlay, "error", err) _ = s.Disconnect(overlay, "failed to process outbound connection notifier") s.protocolsmu.RUnlock() @@ -939,21 +954,27 @@ func (s *Service) NewStream(ctx context.Context, overlay swarm.Address, headers stream := newStream(streamlibp2p, s.metrics) - // tracing: add span context header - if headers == nil { - headers = make(p2p.Headers) - } - if err := s.tracer.AddContextHeader(ctx, headers); err != nil && !errors.Is(err, tracing.ErrContextNotFound) { - _ = stream.Reset() - return nil, fmt.Errorf("new stream add context header fail: %w", err) - } + peerCaps := s.peers.capabilities(overlay) + localSupportsTrace := s.handshakeService.SupportsTraceHeaders() + remoteSupportsTrace := peerCaps != nil && peerCaps.TraceHeaders + bothSupportTrace := localSupportsTrace && remoteSupportsTrace + // Only exchange headers if BOTH local and remote peer support trace headers + if bothSupportTrace { + // Both support trace headers - exchange them with timeout + if headers == nil { + headers = make(p2p.Headers) + } + if err := s.tracer.AddContextHeader(ctx, headers); err != nil && !errors.Is(err, tracing.ErrContextNotFound) { + _ = stream.Reset() + return nil, fmt.Errorf("new stream add context header fail: %w", err) + } - // exchange headers - ctx, cancel := context.WithTimeout(ctx, s.HeadersRWTimeout) - defer cancel() - if err := sendHeaders(ctx, headers, stream); err != nil { - _ = stream.Reset() - return nil, fmt.Errorf("send headers: %w", err) + ctxTimeout, cancel := context.WithTimeout(ctx, s.HeadersRWTimeout) + defer cancel() + if err := sendHeaders(ctxTimeout, headers, stream); err != nil { + _ = stream.Reset() + return nil, fmt.Errorf("send headers: %w", err) + } } return stream, nil diff --git a/pkg/p2p/libp2p/peer.go b/pkg/p2p/libp2p/peer.go index 5aa0f75713a..72ff668082e 100644 --- a/pkg/p2p/libp2p/peer.go +++ b/pkg/p2p/libp2p/peer.go @@ -11,6 +11,7 @@ import ( "sync" "github.com/ethersphere/bee/v2/pkg/p2p" + "github.com/ethersphere/bee/v2/pkg/p2p/libp2p/internal/handshake/pb" "github.com/ethersphere/bee/v2/pkg/swarm" "github.com/libp2p/go-libp2p/core/network" libp2ppeer "github.com/libp2p/go-libp2p/core/peer" @@ -18,12 +19,12 @@ import ( ) type peerRegistry struct { - underlays map[string]libp2ppeer.ID // map overlay address to underlay peer id - overlays map[libp2ppeer.ID]swarm.Address // map underlay peer id to overlay address - full map[libp2ppeer.ID]bool // map to track whether a node is full or light node (true=full) - connections map[libp2ppeer.ID]map[network.Conn]struct{} // list of connections for safe removal on Disconnect notification - streams map[libp2ppeer.ID]map[network.Stream]context.CancelFunc - mu sync.RWMutex + underlays map[string]libp2ppeer.ID // map overlay address to underlay peer id + overlays map[libp2ppeer.ID]swarm.Address // map underlay peer id to overlay address + peerCapabilities map[libp2ppeer.ID]*pb.Capabilities // map to track peer capabilities + connections map[libp2ppeer.ID]map[network.Conn]struct{} // list of connections for safe removal on Disconnect notification + streams map[libp2ppeer.ID]map[network.Stream]context.CancelFunc + mu sync.RWMutex //nolint:misspell disconnecter disconnecter // peerRegistry notifies libp2p on peer disconnection @@ -36,11 +37,11 @@ type disconnecter interface { func newPeerRegistry() *peerRegistry { return &peerRegistry{ - underlays: make(map[string]libp2ppeer.ID), - overlays: make(map[libp2ppeer.ID]swarm.Address), - full: make(map[libp2ppeer.ID]bool), - connections: make(map[libp2ppeer.ID]map[network.Conn]struct{}), - streams: make(map[libp2ppeer.ID]map[network.Stream]context.CancelFunc), + underlays: make(map[string]libp2ppeer.ID), + overlays: make(map[libp2ppeer.ID]swarm.Address), + peerCapabilities: make(map[libp2ppeer.ID]*pb.Capabilities), + connections: make(map[libp2ppeer.ID]map[network.Conn]struct{}), + streams: make(map[libp2ppeer.ID]map[network.Stream]context.CancelFunc), Notifiee: new(network.NoopNotifiee), } @@ -80,7 +81,7 @@ func (r *peerRegistry) Disconnected(_ network.Network, c network.Conn) { cancel() } delete(r.streams, peerID) - delete(r.full, peerID) + delete(r.peerCapabilities, peerID) r.mu.Unlock() r.disconnecter.disconnected(overlay) @@ -119,9 +120,14 @@ func (r *peerRegistry) peers() []p2p.Peer { r.mu.RLock() peers := make([]p2p.Peer, 0, len(r.overlays)) for p, a := range r.overlays { + caps := r.peerCapabilities[p] + fullNode := false + if caps != nil { + fullNode = caps.FullNode + } peers = append(peers, p2p.Peer{ Address: a, - FullNode: r.full[p], + FullNode: fullNode, }) } r.mu.RUnlock() @@ -131,7 +137,7 @@ func (r *peerRegistry) peers() []p2p.Peer { return peers } -func (r *peerRegistry) addIfNotExists(c network.Conn, overlay swarm.Address, full bool) (exists bool) { +func (r *peerRegistry) addIfNotExists(c network.Conn, overlay swarm.Address, capabilities *pb.Capabilities) (exists bool) { peerID := c.RemotePeer() r.mu.Lock() defer r.mu.Unlock() @@ -150,7 +156,7 @@ func (r *peerRegistry) addIfNotExists(c network.Conn, overlay swarm.Address, ful r.streams[peerID] = make(map[network.Stream]context.CancelFunc) r.underlays[overlay.ByteString()] = peerID r.overlays[peerID] = overlay - r.full[peerID] = full + r.peerCapabilities[peerID] = capabilities return false } @@ -171,9 +177,12 @@ func (r *peerRegistry) overlay(peerID libp2ppeer.ID) (swarm.Address, bool) { func (r *peerRegistry) fullnode(peerID libp2ppeer.ID) (bool, bool) { r.mu.RLock() - full, found := r.full[peerID] + caps, found := r.peerCapabilities[peerID] r.mu.RUnlock() - return full, found + if !found || caps == nil { + return false, found + } + return caps.FullNode, found } func (r *peerRegistry) isConnected(peerID libp2ppeer.ID, remoteAddr ma.Multiaddr) (swarm.Address, bool) { @@ -215,13 +224,28 @@ func (r *peerRegistry) remove(overlay swarm.Address) (found, full bool, peerID l cancel() } delete(r.streams, peerID) - full = r.full[peerID] - delete(r.full, peerID) + caps := r.peerCapabilities[peerID] + if caps != nil { + full = caps.FullNode + } + delete(r.peerCapabilities, peerID) r.mu.Unlock() return found, full, peerID } +func (r *peerRegistry) capabilities(overlay swarm.Address) *pb.Capabilities { + peerID, exists := r.peerID(overlay) + if !exists { + return nil + } + + r.mu.RLock() + caps := r.peerCapabilities[peerID] + r.mu.RUnlock() + return caps +} + func (r *peerRegistry) setDisconnecter(d disconnecter) { r.disconnecter = d } diff --git a/pkg/p2p/libp2p/protocols_test.go b/pkg/p2p/libp2p/protocols_test.go index 4563083726a..4a61eb0cc4a 100644 --- a/pkg/p2p/libp2p/protocols_test.go +++ b/pkg/p2p/libp2p/protocols_test.go @@ -317,7 +317,14 @@ func TestDisconnectError(t *testing.T) { // error is not checked as opening a new stream should cause disconnect from s1 which is async and can make errors in newStream function // it is important to validate that disconnect will happen after NewStream() - _, _ = s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName) + stream, _ := s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName) + + // Write data to trigger the protocol handler + if stream != nil { + _, _ = stream.Write([]byte("test")) + stream.Close() + } + expectPeersEventually(t, s1) } diff --git a/pkg/p2p/libp2p/stream.go b/pkg/p2p/libp2p/stream.go index 93e7ce6181a..b3465bfe7b1 100644 --- a/pkg/p2p/libp2p/stream.go +++ b/pkg/p2p/libp2p/stream.go @@ -5,11 +5,13 @@ package libp2p import ( + "context" "errors" "io" "time" "github.com/ethersphere/bee/v2/pkg/p2p" + "github.com/ethersphere/bee/v2/pkg/swarm" "github.com/libp2p/go-libp2p/core/network" ) @@ -21,14 +23,29 @@ var _ p2p.Stream = (*stream)(nil) type stream struct { network.Stream - headers map[string][]byte - responseHeaders map[string][]byte + headers p2p.Headers + responseHeaders p2p.Headers metrics metrics } func newStream(s network.Stream, metrics metrics) *stream { return &stream{Stream: s, metrics: metrics} } + +func newStreamWithHeaders(s network.Stream, metrics metrics, ctx context.Context, headler p2p.HeadlerFunc, peerAddress swarm.Address, headers p2p.Headers) (*stream, error) { + stream := &stream{ + Stream: s, + metrics: metrics, + headers: make(p2p.Headers), + responseHeaders: make(p2p.Headers), + } + + if err := handleHeaders(ctx, headler, stream, peerAddress); err != nil { + return nil, err + } + + return stream, nil +} func (s *stream) Headers() p2p.Headers { return s.headers } diff --git a/pkg/p2p/libp2p/tracing_test.go b/pkg/p2p/libp2p/tracing_test.go index 7e02f8e2d7c..481dd9143a9 100644 --- a/pkg/p2p/libp2p/tracing_test.go +++ b/pkg/p2p/libp2p/tracing_test.go @@ -37,10 +37,13 @@ func TestTracing(t *testing.T) { defer closer2.Close() s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{ - FullNode: true, + FullNode: true, + EnableTraceHeaders: true, }}) - s2, _ := newService(t, 1, libp2pServiceOpts{}) + s2, _ := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{ + EnableTraceHeaders: true, + }}) var handledTracingSpan string handled := make(chan struct{}) diff --git a/pkg/shed/vector_uint64_test.go b/pkg/shed/vector_uint64_test.go index 2d602a034c0..4eeaf1cdaf3 100644 --- a/pkg/shed/vector_uint64_test.go +++ b/pkg/shed/vector_uint64_test.go @@ -56,7 +56,7 @@ func TestUint64Vector(t *testing.T) { } for _, index := range []uint64{0, 1, 2, 5, 100} { - var want uint64 = 42 + index + var want = 42 + index err = bins.Put(index, want) if err != nil { t.Fatal(err) @@ -70,7 +70,7 @@ func TestUint64Vector(t *testing.T) { } t.Run("overwrite", func(t *testing.T) { - var want uint64 = 84 + index + var want = 84 + index err = bins.Put(index, want) if err != nil { t.Fatal(err) @@ -97,7 +97,7 @@ func TestUint64Vector(t *testing.T) { for _, index := range []uint64{0, 1, 2, 3, 5, 10} { batch := new(leveldb.Batch) - var want uint64 = 43 + index + var want = 43 + index bins.PutInBatch(batch, index, want) err = db.WriteBatch(batch) if err != nil { @@ -113,7 +113,7 @@ func TestUint64Vector(t *testing.T) { t.Run("overwrite", func(t *testing.T) { batch := new(leveldb.Batch) - var want uint64 = 85 + index + var want = 85 + index bins.PutInBatch(batch, index, want) err = db.WriteBatch(batch) if err != nil { diff --git a/pkg/transaction/transaction.go b/pkg/transaction/transaction.go index fc7a4904510..4b85351f038 100644 --- a/pkg/transaction/transaction.go +++ b/pkg/transaction/transaction.go @@ -371,7 +371,7 @@ func (t *transactionService) nextNonce(ctx context.Context) (uint64, error) { // PendingNonceAt returns the nonce we should use, but we will // compare this to our pending tx list, therefore the -1. - var maxNonce uint64 = onchainNonce - 1 + var maxNonce = onchainNonce - 1 for _, txHash := range pendingTxs { trx, _, err := t.backend.TransactionByHash(ctx, txHash) @@ -419,7 +419,7 @@ func (t *transactionService) WatchSentTransaction(txHash common.Hash) (<-chan ty } func (t *transactionService) PendingTransactions() ([]common.Hash, error) { - var txHashes []common.Hash = make([]common.Hash, 0) + var txHashes = make([]common.Hash, 0) err := t.store.Iterate(pendingTransactionPrefix, func(key, value []byte) (stop bool, err error) { txHash := common.HexToHash(strings.TrimPrefix(string(key), pendingTransactionPrefix)) txHashes = append(txHashes, txHash)