Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions forwardproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -631,9 +631,7 @@ func serveHijack(w http.ResponseWriter, targetConn net.Conn) error {
return dualStream(targetConn, clientConn, clientConn)
}

// Copies data target->clientReader and clientWriter->target, and flushes as needed
// Returns when clientWriter-> target stream is done.
// Caddy should finish writing target -> clientReader.
// Copies data clientReader->target and target->clientWriter, and flushes as needed.
func dualStream(target net.Conn, clientReader io.ReadCloser, clientWriter io.Writer) error {
stream := func(w io.Writer, r io.Reader) error {
// copy bytes from r to w
Expand All @@ -648,14 +646,35 @@ func dualStream(target net.Conn, clientReader io.ReadCloser, clientWriter io.Wri
}
return _err
}
go stream(target, clientReader) //nolint: errcheck
return stream(clientWriter, target)

var closeOnce sync.Once
closeTarget := func() {
closeOnce.Do(func() {
_ = target.Close()
})
}
limitTargetRead := func(err error) {
if err != nil || tunnelHalfCloseTimeout <= 0 {
closeTarget()
return
}
_ = target.SetReadDeadline(time.Now().Add(tunnelHalfCloseTimeout))
}

go func() {
limitTargetRead(stream(target, clientReader))
}()
err := stream(clientWriter, target)
closeTarget()
return err
}

type closeWriter interface {
CloseWrite() error
}

var tunnelHalfCloseTimeout = 30 * time.Second

// flushingIoCopy is analogous to buffering io.Copy(), but also attempts to flush on each iteration.
// If dst does not implement http.Flusher(e.g. net.TCPConn), it will do a simple io.CopyBuffer().
// Reasoning: http2ResponseWriter will not flush on its own, so we have to do it manually.
Expand Down
50 changes: 50 additions & 0 deletions forwardproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@ import (
"golang.org/x/net/http2"
)

type errorReadCloser struct {
err error
}

func (r errorReadCloser) Read([]byte) (int, error) {
return 0, r.err
}

func (r errorReadCloser) Close() error {
return nil
}

func dial(proxyAddr, httpProxyVer string, useTLS bool) (net.Conn, error) {
// always dial localhost for testing purposes
if useTLS {
Expand Down Expand Up @@ -369,6 +381,44 @@ func TestCONNECTViaUpstream(t *testing.T) {
}
}

func TestDualStreamClosesTargetWhenClientReaderErrors(t *testing.T) {
target, peer := net.Pipe()
defer peer.Close()

done := make(chan error, 1)
go func() {
done <- dualStream(target, errorReadCloser{err: io.ErrUnexpectedEOF}, io.Discard)
}()

select {
case <-done:
case <-time.After(time.Second):
t.Fatal("dualStream did not return after client reader error")
}
}

func TestDualStreamLimitsTargetReadAfterClientReaderEOF(t *testing.T) {
origHalfCloseTimeout := tunnelHalfCloseTimeout
tunnelHalfCloseTimeout = 10 * time.Millisecond
defer func() {
tunnelHalfCloseTimeout = origHalfCloseTimeout
}()

target, peer := net.Pipe()
defer peer.Close()

done := make(chan error, 1)
go func() {
done <- dualStream(target, errorReadCloser{err: io.EOF}, io.Discard)
}()

select {
case <-done:
case <-time.After(time.Second):
t.Fatal("dualStream did not return after target read deadline")
}
}

func TestGETViaUpstream(t *testing.T) {
const useTLS = true
for range make([]byte, 5) { // do several times to test http2 connection reuse
Expand Down