Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 63 additions & 20 deletions internal/client/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,30 @@ const (
ExtMetadata = "ut_metadata"
)

// BEP 3 wire constants. The handshake is *exactly* 68 bytes: 1 length byte +
// 19 magic bytes + 8 reserved + 20 info_hash + 20 peer_id. Anything else is
// off-spec and rejected.
const (
protoMagic = "BitTorrent protocol"
pstrLen = 19 // len(protoMagic)
handshakeLen = 1 + pstrLen + 8 + 20 + 20
)

// peerState is the LangSec-style typed state of a PeerConn. Methods that
// operate on a connection require a minimum state and refuse to run otherwise.
type peerState int

const (
stateInit peerState = iota // freshly TCP-connected, no handshake yet
stateHandshook // BEP 3 handshake complete
stateExtended // BEP 10 extended handshake complete
)

// PeerConn wraps a TCP connection to a peer and tracks PWP state.
type PeerConn struct {
conn net.Conn
addr string
conn net.Conn
addr string
state peerState

AmChoking bool
AmInterested bool
Expand Down Expand Up @@ -50,9 +70,17 @@ func Connect(ctx context.Context, addr string) (*PeerConn, error) {
// Handshake performs the standard BEP 3 handshake.
// Sends: <pstrlen><pstr><reserved><info_hash><peer_id>
//
// LangSec recognition: the peer's reply must be exactly 68 bytes,
// pstrlen must be exactly 19, and pstr must be exactly "BitTorrent protocol".
// Anything else is off-spec and the connection is dropped before any state
// (PeerExtensions, MetadataSize, choke/interest flags) is touched.
//
// NOTE: Even for hybrid v1+v2 torrents, the BitTorrent wire protocol
// handshake ALWAYS uses the 20-byte SHA-1 info hash (v1) per BEP 3.
func (p *PeerConn) Handshake(ctx context.Context, infoHash []byte, peerID string) error {
if p.state != stateInit {
return fmt.Errorf("handshake called in state %d (expected init)", p.state)
}
deadline := time.Now().Add(10 * time.Second)
if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
deadline = d
Expand All @@ -67,11 +95,10 @@ func (p *PeerConn) Handshake(ctx context.Context, infoHash []byte, peerID string
return fmt.Errorf("peer id must be 20 bytes, got %d", len(peerID))
}

pstr := "BitTorrent protocol"
buf := make([]byte, 1+len(pstr)+8+20+20)
buf[0] = byte(len(pstr))
buf := make([]byte, handshakeLen)
buf[0] = byte(pstrLen)
curr := 1
curr += copy(buf[curr:], pstr)
curr += copy(buf[curr:], protoMagic)

// Reserved bytes (8 bytes)
// We set bit 43 (byte 5, bit 0x10) to signal BEP 10 Extension Protocol support.
Expand All @@ -86,28 +113,30 @@ func (p *PeerConn) Handshake(ctx context.Context, infoHash []byte, peerID string
return fmt.Errorf("write handshake: %w", err)
}

// Read peer's handshake
resBuf := make([]byte, 1)
// Read the peer's full 68-byte handshake. LangSec: read it whole, then
// validate the entire structure before letting anything downstream touch it.
resBuf := make([]byte, handshakeLen)
if _, err := io.ReadFull(p.conn, resBuf); err != nil {
return fmt.Errorf("read pstrlen: %w", err)
}
pstrlen := int(resBuf[0])
if pstrlen == 0 {
return fmt.Errorf("invalid pstrlen 0")
return fmt.Errorf("read handshake: %w", err)
}

resBuf = make([]byte, pstrlen+8+20+20)
if _, err := io.ReadFull(p.conn, resBuf); err != nil {
return fmt.Errorf("read handshake payload: %w", err)
if int(resBuf[0]) != pstrLen {
return fmt.Errorf("invalid pstrlen %d (must be %d per BEP 3)", resBuf[0], pstrLen)
}
if !bytes.Equal(resBuf[1:1+pstrLen], []byte(protoMagic)) {
return fmt.Errorf("invalid protocol magic: got %q", resBuf[1:1+pstrLen])
}

resInfoHash := resBuf[pstrlen+8 : pstrlen+8+20]
resInfoHash := resBuf[1+pstrLen+8 : 1+pstrLen+8+20]
if !bytes.Equal(resInfoHash, infoHash) {
return fmt.Errorf("info hash mismatch: expected %x, got %x", infoHash, resInfoHash)
}

// All recognition passed — promote state.
p.state = stateHandshook

// Check if peer supports BEP 10
peerReserved := resBuf[pstrlen : pstrlen+8]
peerReserved := resBuf[1+pstrLen : 1+pstrLen+8]
supportsBEP10 := (peerReserved[5] & 0x10) != 0

if supportsBEP10 {
Expand All @@ -120,6 +149,7 @@ func (p *PeerConn) Handshake(ctx context.Context, infoHash []byte, peerID string
if err := p.readExtendedHandshake(); err != nil {
return fmt.Errorf("read extended handshake: %w", err)
}
p.state = stateExtended
}

return nil
Expand Down Expand Up @@ -158,7 +188,11 @@ func (p *PeerConn) readExtendedHandshake() error {
}

// RequestMetadata sends a BEP 9 metadata request for the given piece.
// Requires that the BEP 10 extended handshake has completed.
func (p *PeerConn) RequestMetadata(piece int) error {
if p.state < stateExtended {
return fmt.Errorf("RequestMetadata called in state %d (expected extended)", p.state)
}
extID, ok := p.PeerExtensions[ExtMetadata]
if !ok {
return fmt.Errorf("peer does not support %s", ExtMetadata)
Expand Down Expand Up @@ -202,15 +236,24 @@ func (p *PeerConn) sendExtendedHandshake() error {
})
}

// ReadMessage reads the next message from the peer.
// ReadMessage reads the next message from the peer. Requires the BEP 3
// handshake to have completed — until then, raw bytes on the wire don't
// frame as PWP messages.
func (p *PeerConn) ReadMessage() (*Message, error) {
if p.state < stateHandshook {
return nil, fmt.Errorf("ReadMessage called in state %d (expected handshook)", p.state)
}
// Set a reasonable read timeout to avoid hanging forever
p.conn.SetReadDeadline(time.Now().Add(2 * time.Minute))
return ReadMessage(p.conn)
}

// WriteMessage writes a message to the peer.
// WriteMessage writes a message to the peer. Requires the BEP 3 handshake
// to have completed.
func (p *PeerConn) WriteMessage(m *Message) error {
if p.state < stateHandshook {
return fmt.Errorf("WriteMessage called in state %d (expected handshook)", p.state)
}
p.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
return WriteMessage(p.conn, m)
}
Expand Down
105 changes: 105 additions & 0 deletions internal/client/peer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,108 @@ func TestHandshakeMismatch(t *testing.T) {
t.Errorf("unexpected error: %v", err)
}
}

// fakePeer accepts one connection, reads the client's handshake, and replies
// with a custom 68-byte handshake reply. Returns the listener's addr.
func fakePeer(t *testing.T, reply []byte) string {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
t.Cleanup(func() { ln.Close() })
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
// Drain the client's 68-byte handshake.
io.ReadFull(conn, make([]byte, 68))
conn.Write(reply)
}()
return ln.Addr().String()
}

func TestHandshakeRejectsBadPstrlen(t *testing.T) {
infoHash := make([]byte, 20)
copy(infoHash, "infohash123456789012")

// Reply claims pstrlen=18 instead of 19 — handshake reads exactly 68
// bytes (1+19+8+20+20), so we need to construct a 68-byte buffer with
// a bogus first byte. The parser should reject on the length check.
reply := make([]byte, 68)
reply[0] = 18 // wrong
copy(reply[1:], "BitTorrent protocol") // would be valid magic if length matched
copy(reply[1+19+8:], infoHash)

addr := fakePeer(t, reply)
p, err := Connect(context.Background(), addr)
if err != nil {
t.Fatalf("connect: %v", err)
}
defer p.Close()

err = p.Handshake(context.Background(), infoHash, "-WL0001-123456789012")
if err == nil {
t.Fatal("expected pstrlen rejection")
}
if !bytes.Contains([]byte(err.Error()), []byte("pstrlen")) {
t.Errorf("expected pstrlen error, got %v", err)
}
}

func TestHandshakeRejectsBadMagic(t *testing.T) {
infoHash := make([]byte, 20)
copy(infoHash, "infohash123456789012")

// pstrlen=19 (valid) but magic is wrong.
reply := make([]byte, 68)
reply[0] = 19
copy(reply[1:], "WrongTorrent magic!") // 19 bytes, wrong content
copy(reply[1+19+8:], infoHash)

addr := fakePeer(t, reply)
p, err := Connect(context.Background(), addr)
if err != nil {
t.Fatalf("connect: %v", err)
}
defer p.Close()

err = p.Handshake(context.Background(), infoHash, "-WL0001-123456789012")
if err == nil {
t.Fatal("expected magic-string rejection")
}
if !bytes.Contains([]byte(err.Error()), []byte("protocol magic")) {
t.Errorf("expected protocol magic error, got %v", err)
}
}

func TestReadMessageBeforeHandshakeRejected(t *testing.T) {
// Set up a connection but never handshake.
server, client := net.Pipe()
defer server.Close()
defer client.Close()

p := &PeerConn{conn: client, state: stateInit}

if _, err := p.ReadMessage(); err == nil {
t.Error("ReadMessage in stateInit should fail")
}
if err := p.WriteMessage(&Message{ID: 0}); err == nil {
t.Error("WriteMessage in stateInit should fail")
}
}

func TestRequestMetadataBeforeExtendedRejected(t *testing.T) {
server, client := net.Pipe()
defer server.Close()
defer client.Close()

// Handshook but not Extended.
p := &PeerConn{conn: client, state: stateHandshook}

if err := p.RequestMetadata(0); err == nil {
t.Error("RequestMetadata before extended handshake should fail")
}
}
Loading