Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .plzconfig
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
version = >=17.0.0

[build]
path = /usr/local/go/bin:/usr/local/bin:/usr/bin:/bin
path = /usr/local/go/bin:/usr/local/bin:/usr/bin:/bin:/sbin:/usr/sbin

[buildconfig]
local-host = 127.0.0.1
Expand Down
9 changes: 9 additions & 0 deletions elan/rpc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ go_test(
],
)

go_test(
name = "inflight_test",
srcs = ["inflight_test.go"],
deps = [
":rpc",
"///third_party/go/github.com_stretchr_testify//require",
],
)

genrule(
name = "test_data",
cmd = [
Expand Down
55 changes: 55 additions & 0 deletions elan/rpc/inflight.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package rpc

import (
"context"
"sync"
"time"
)

// defaultWriteTimeout is the maximum time a reader will wait for an in-progress
// write to complete before proceeding anyway.
const defaultWriteTimeout = 10 * time.Minute

// inflightWrites tracks blob writes that are currently in progress so that
// concurrent readers can block until the write completes rather than getting
// a spurious NotFound.
type inflightWrites struct {
mu struct {
sync.Mutex
blobs map[string]context.Context
}
}

func newInflightWrites() *inflightWrites {
w := &inflightWrites{}
w.mu.blobs = make(map[string]context.Context)
return w
}

// startWrite registers a blob write in progress. The returned cancel function
// must be called when the write completes (success or failure) — use defer.
// The write context inherits from the caller so parent cancellation propagates.
// A default timeout of defaultWriteTimeout is applied as a safety net.
func (w *inflightWrites) startWrite(ctx context.Context, hash string) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(ctx, defaultWriteTimeout)
w.mu.Lock()
w.mu.blobs[hash] = ctx
w.mu.Unlock()
return ctx, func() {
w.mu.Lock()
delete(w.mu.blobs, hash)
w.mu.Unlock()
cancel()
}
}

// waitForWrite blocks until any in-progress write for the given hash completes.
// If no write is in progress, it returns immediately.
func (w *inflightWrites) waitForWrite(hash string) {
w.mu.Lock()
ctx, ok := w.mu.blobs[hash]
w.mu.Unlock()
if ok {
<-ctx.Done()
}
}
129 changes: 129 additions & 0 deletions elan/rpc/inflight_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package rpc

import (
"context"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestWaitForWrite_NoInflight(t *testing.T) {
w := newInflightWrites()
// Should return immediately when nothing is in flight.
doneCtx, doneFunc := context.WithCancel(t.Context())
go func() {
w.waitForWrite("abc123")
doneFunc()
}()
select {
case <-doneCtx.Done():
case <-time.After(120 * time.Second):
t.Fatal("waitForWrite blocked when no write was in progress")
}
}

func TestWaitForWrite_BlocksUntilDone(t *testing.T) {
w := newInflightWrites()
_, finish := w.startWrite(t.Context(), "abc123")

var order []string
var mu sync.Mutex
record := func(s string) {
mu.Lock()
order = append(order, s)
mu.Unlock()
}

doneCtx, doneFunc := context.WithCancel(t.Context())
go func() {
w.waitForWrite("abc123")
record("read")
doneFunc()
}()

// Give the reader goroutine time to block.
time.Sleep(50 * time.Millisecond)
record("write")
finish()

select {
case <-doneCtx.Done():
case <-t.Context().Done():
t.Fatal("reader never unblocked")
}

mu.Lock()
defer mu.Unlock()
require.Equal(t, []string{"write", "read"}, order)
}

func TestWaitForWrite_DifferentDigests(t *testing.T) {
w := newInflightWrites()
_, finish := w.startWrite(t.Context(), "abc123")
defer finish()

// A different digest should not block.
doneCtx, doneFunc := context.WithCancel(t.Context())
go func() {
w.waitForWrite("def456")
doneFunc()
}()
select {
case <-doneCtx.Done():
case <-t.Context().Done():
t.Fatal("waitForWrite blocked on a different digest")
}
}

func TestWaitForWrite_ParentCancellation(t *testing.T) {
w := newInflightWrites()
ctx, cancel := context.WithCancel(t.Context())
_, finish := w.startWrite(ctx, "abc123")
defer finish()

doneCtx, doneFunc := context.WithCancel(t.Context())
go func() {
w.waitForWrite("abc123")
doneFunc()
}()

// Cancel the parent context — reader should unblock.
cancel()
select {
case <-doneCtx.Done():
case <-t.Context().Done():
t.Fatal("reader did not unblock after parent context cancellation")
}
}

func TestWaitForWrite_MultipleReaders(t *testing.T) {
w := newInflightWrites()
_, finish := w.startWrite(t.Context(), "abc123")

const numReaders = 10
doneCtx, doneFunc := context.WithCancel(t.Context())
var wg sync.WaitGroup
wg.Add(numReaders)
for range numReaders {
go func() {
w.waitForWrite("abc123")
wg.Done()
}()
}

// All readers should be blocked. Finish the write.
time.Sleep(50 * time.Millisecond)
finish()

go func() {
wg.Wait()
doneFunc()
}()
select {
case <-doneCtx.Done():
case <-t.Context().Done():
t.Fatal("not all readers unblocked")
}
}
8 changes: 8 additions & 0 deletions elan/rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ func createServer(storage string, parallelism int, maxDirCacheSize, maxKnownBlob
decompressor: dec,
readRedis: readRedis,
largeBlobSize: largeBlobSize,
inflight: newInflightWrites(),
}
}

Expand Down Expand Up @@ -222,6 +223,7 @@ type server struct {
decompressor *zstd.Decoder
readRedis *redis.Client
largeBlobSize int64
inflight *inflightWrites
}

func (s *server) GetCapabilities(ctx context.Context, req *pb.GetCapabilitiesRequest) (*pb.ServerCapabilities, error) {
Expand Down Expand Up @@ -549,6 +551,7 @@ func (s *server) Read(req *bs.ReadRequest, srv bs.ByteStream_ReadServer) error {
}

func (s *server) readCompressed(ctx context.Context, prefix string, digest *pb.Digest, compressed bool, offset, limit int64) (io.ReadCloser, bool, error) {
s.inflight.waitForWrite(digest.Hash)
if prefix != "cas" {
if compressed {
return nil, false, fmt.Errorf("Attempted to do a compressed read for non-CAS prefix %s", prefix) // This is a programming error and shouldn't happen.
Expand Down Expand Up @@ -687,6 +690,7 @@ func (s *server) readAllBlobBatched(ctx context.Context, prefix string, digest *
}

func (s *server) readAllBlobCompressed(ctx context.Context, digest *pb.Digest, key string, batched, compressed bool) ([]byte, error) {
s.inflight.waitForWrite(digest.Hash)
if digest.SizeBytes > s.largeBlobSize {
s.limiter <- struct{}{}
defer func() { <-s.limiter }()
Expand Down Expand Up @@ -724,6 +728,8 @@ func (s *server) compressedKey(prefix string, digest *pb.Digest, compressed bool
}

func (s *server) writeBlob(ctx context.Context, prefix string, digest *pb.Digest, r io.Reader, compressed bool) error {
_, done := s.inflight.startWrite(ctx, digest.Hash)
defer done()
key := s.compressedKey(prefix, digest, compressed)
if s.isEmpty(digest) || s.blobExists(ctx, prefix, digest, compressed, true) {
// Read and discard entire content; there is no need to update.
Expand Down Expand Up @@ -774,6 +780,8 @@ func (s *server) writeBlob(ctx context.Context, prefix string, digest *pb.Digest
}

func (s *server) writeAll(ctx context.Context, digest *pb.Digest, data []byte, compressed bool) error {
_, done := s.inflight.startWrite(ctx, digest.Hash)
defer done()
if digest.SizeBytes > s.largeBlobSize {
s.limiter <- struct{}{}
defer func() { <-s.limiter }()
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/thought-machine/please-servers

go 1.22
go 1.24

require (
cloud.google.com/go/profiler v0.4.0
Expand Down
2 changes: 1 addition & 1 deletion third_party/go/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package(default_visibility = ["PUBLIC"])

go_toolchain(
name = "toolchain",
version = "1.23.2",
version = "1.24.2",
install_std = False,
)

Expand Down