diff --git a/internal/client/peer.go b/internal/client/peer.go index 9ecfb58..008710c 100644 --- a/internal/client/peer.go +++ b/internal/client/peer.go @@ -16,10 +16,30 @@ const ( ExtMetadata = "ut_metadata" ) +// BEP 3 wire constants. The handshake is *exactly* 68 bytes: 1 length byte + +// 19 magic bytes + 8 reserved + 20 info_hash + 20 peer_id. Anything else is +// off-spec and rejected. +const ( + protoMagic = "BitTorrent protocol" + pstrLen = 19 // len(protoMagic) + handshakeLen = 1 + pstrLen + 8 + 20 + 20 +) + +// peerState is the LangSec-style typed state of a PeerConn. Methods that +// operate on a connection require a minimum state and refuse to run otherwise. +type peerState int + +const ( + stateInit peerState = iota // freshly TCP-connected, no handshake yet + stateHandshook // BEP 3 handshake complete + stateExtended // BEP 10 extended handshake complete +) + // PeerConn wraps a TCP connection to a peer and tracks PWP state. type PeerConn struct { - conn net.Conn - addr string + conn net.Conn + addr string + state peerState AmChoking bool AmInterested bool @@ -50,9 +70,17 @@ func Connect(ctx context.Context, addr string) (*PeerConn, error) { // Handshake performs the standard BEP 3 handshake. // Sends: // +// LangSec recognition: the peer's reply must be exactly 68 bytes, +// pstrlen must be exactly 19, and pstr must be exactly "BitTorrent protocol". +// Anything else is off-spec and the connection is dropped before any state +// (PeerExtensions, MetadataSize, choke/interest flags) is touched. +// // NOTE: Even for hybrid v1+v2 torrents, the BitTorrent wire protocol // handshake ALWAYS uses the 20-byte SHA-1 info hash (v1) per BEP 3. func (p *PeerConn) Handshake(ctx context.Context, infoHash []byte, peerID string) error { + if p.state != stateInit { + return fmt.Errorf("handshake called in state %d (expected init)", p.state) + } deadline := time.Now().Add(10 * time.Second) if d, ok := ctx.Deadline(); ok && d.Before(deadline) { deadline = d @@ -67,11 +95,10 @@ func (p *PeerConn) Handshake(ctx context.Context, infoHash []byte, peerID string return fmt.Errorf("peer id must be 20 bytes, got %d", len(peerID)) } - pstr := "BitTorrent protocol" - buf := make([]byte, 1+len(pstr)+8+20+20) - buf[0] = byte(len(pstr)) + buf := make([]byte, handshakeLen) + buf[0] = byte(pstrLen) curr := 1 - curr += copy(buf[curr:], pstr) + curr += copy(buf[curr:], protoMagic) // Reserved bytes (8 bytes) // We set bit 43 (byte 5, bit 0x10) to signal BEP 10 Extension Protocol support. @@ -86,28 +113,30 @@ func (p *PeerConn) Handshake(ctx context.Context, infoHash []byte, peerID string return fmt.Errorf("write handshake: %w", err) } - // Read peer's handshake - resBuf := make([]byte, 1) + // Read the peer's full 68-byte handshake. LangSec: read it whole, then + // validate the entire structure before letting anything downstream touch it. + resBuf := make([]byte, handshakeLen) if _, err := io.ReadFull(p.conn, resBuf); err != nil { - return fmt.Errorf("read pstrlen: %w", err) - } - pstrlen := int(resBuf[0]) - if pstrlen == 0 { - return fmt.Errorf("invalid pstrlen 0") + return fmt.Errorf("read handshake: %w", err) } - resBuf = make([]byte, pstrlen+8+20+20) - if _, err := io.ReadFull(p.conn, resBuf); err != nil { - return fmt.Errorf("read handshake payload: %w", err) + if int(resBuf[0]) != pstrLen { + return fmt.Errorf("invalid pstrlen %d (must be %d per BEP 3)", resBuf[0], pstrLen) + } + if !bytes.Equal(resBuf[1:1+pstrLen], []byte(protoMagic)) { + return fmt.Errorf("invalid protocol magic: got %q", resBuf[1:1+pstrLen]) } - resInfoHash := resBuf[pstrlen+8 : pstrlen+8+20] + resInfoHash := resBuf[1+pstrLen+8 : 1+pstrLen+8+20] if !bytes.Equal(resInfoHash, infoHash) { return fmt.Errorf("info hash mismatch: expected %x, got %x", infoHash, resInfoHash) } + // All recognition passed — promote state. + p.state = stateHandshook + // Check if peer supports BEP 10 - peerReserved := resBuf[pstrlen : pstrlen+8] + peerReserved := resBuf[1+pstrLen : 1+pstrLen+8] supportsBEP10 := (peerReserved[5] & 0x10) != 0 if supportsBEP10 { @@ -120,6 +149,7 @@ func (p *PeerConn) Handshake(ctx context.Context, infoHash []byte, peerID string if err := p.readExtendedHandshake(); err != nil { return fmt.Errorf("read extended handshake: %w", err) } + p.state = stateExtended } return nil @@ -158,7 +188,11 @@ func (p *PeerConn) readExtendedHandshake() error { } // RequestMetadata sends a BEP 9 metadata request for the given piece. +// Requires that the BEP 10 extended handshake has completed. func (p *PeerConn) RequestMetadata(piece int) error { + if p.state < stateExtended { + return fmt.Errorf("RequestMetadata called in state %d (expected extended)", p.state) + } extID, ok := p.PeerExtensions[ExtMetadata] if !ok { return fmt.Errorf("peer does not support %s", ExtMetadata) @@ -202,15 +236,24 @@ func (p *PeerConn) sendExtendedHandshake() error { }) } -// ReadMessage reads the next message from the peer. +// ReadMessage reads the next message from the peer. Requires the BEP 3 +// handshake to have completed — until then, raw bytes on the wire don't +// frame as PWP messages. func (p *PeerConn) ReadMessage() (*Message, error) { + if p.state < stateHandshook { + return nil, fmt.Errorf("ReadMessage called in state %d (expected handshook)", p.state) + } // Set a reasonable read timeout to avoid hanging forever p.conn.SetReadDeadline(time.Now().Add(2 * time.Minute)) return ReadMessage(p.conn) } -// WriteMessage writes a message to the peer. +// WriteMessage writes a message to the peer. Requires the BEP 3 handshake +// to have completed. func (p *PeerConn) WriteMessage(m *Message) error { + if p.state < stateHandshook { + return fmt.Errorf("WriteMessage called in state %d (expected handshook)", p.state) + } p.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) return WriteMessage(p.conn, m) } diff --git a/internal/client/peer_test.go b/internal/client/peer_test.go index 899ac59..13eabc2 100644 --- a/internal/client/peer_test.go +++ b/internal/client/peer_test.go @@ -131,3 +131,108 @@ func TestHandshakeMismatch(t *testing.T) { t.Errorf("unexpected error: %v", err) } } + +// fakePeer accepts one connection, reads the client's handshake, and replies +// with a custom 68-byte handshake reply. Returns the listener's addr. +func fakePeer(t *testing.T, reply []byte) string { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + t.Cleanup(func() { ln.Close() }) + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + // Drain the client's 68-byte handshake. + io.ReadFull(conn, make([]byte, 68)) + conn.Write(reply) + }() + return ln.Addr().String() +} + +func TestHandshakeRejectsBadPstrlen(t *testing.T) { + infoHash := make([]byte, 20) + copy(infoHash, "infohash123456789012") + + // Reply claims pstrlen=18 instead of 19 — handshake reads exactly 68 + // bytes (1+19+8+20+20), so we need to construct a 68-byte buffer with + // a bogus first byte. The parser should reject on the length check. + reply := make([]byte, 68) + reply[0] = 18 // wrong + copy(reply[1:], "BitTorrent protocol") // would be valid magic if length matched + copy(reply[1+19+8:], infoHash) + + addr := fakePeer(t, reply) + p, err := Connect(context.Background(), addr) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer p.Close() + + err = p.Handshake(context.Background(), infoHash, "-WL0001-123456789012") + if err == nil { + t.Fatal("expected pstrlen rejection") + } + if !bytes.Contains([]byte(err.Error()), []byte("pstrlen")) { + t.Errorf("expected pstrlen error, got %v", err) + } +} + +func TestHandshakeRejectsBadMagic(t *testing.T) { + infoHash := make([]byte, 20) + copy(infoHash, "infohash123456789012") + + // pstrlen=19 (valid) but magic is wrong. + reply := make([]byte, 68) + reply[0] = 19 + copy(reply[1:], "WrongTorrent magic!") // 19 bytes, wrong content + copy(reply[1+19+8:], infoHash) + + addr := fakePeer(t, reply) + p, err := Connect(context.Background(), addr) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer p.Close() + + err = p.Handshake(context.Background(), infoHash, "-WL0001-123456789012") + if err == nil { + t.Fatal("expected magic-string rejection") + } + if !bytes.Contains([]byte(err.Error()), []byte("protocol magic")) { + t.Errorf("expected protocol magic error, got %v", err) + } +} + +func TestReadMessageBeforeHandshakeRejected(t *testing.T) { + // Set up a connection but never handshake. + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + p := &PeerConn{conn: client, state: stateInit} + + if _, err := p.ReadMessage(); err == nil { + t.Error("ReadMessage in stateInit should fail") + } + if err := p.WriteMessage(&Message{ID: 0}); err == nil { + t.Error("WriteMessage in stateInit should fail") + } +} + +func TestRequestMetadataBeforeExtendedRejected(t *testing.T) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + // Handshook but not Extended. + p := &PeerConn{conn: client, state: stateHandshook} + + if err := p.RequestMetadata(0); err == nil { + t.Error("RequestMetadata before extended handshake should fail") + } +}