diff --git a/.plzconfig b/.plzconfig
index b423e69f..dff2c9b5 100644
--- a/.plzconfig
+++ b/.plzconfig
@@ -2,7 +2,7 @@
 version = >=17.0.0
 
 [build]
-path = /usr/local/go/bin:/usr/local/bin:/usr/bin:/bin
+path = /usr/local/go/bin:/usr/local/bin:/usr/bin:/bin:/sbin:/usr/sbin
 
 [buildconfig]
 local-host = 127.0.0.1
diff --git a/elan/rpc/BUILD b/elan/rpc/BUILD
index 1bd9895a..32a16616 100644
--- a/elan/rpc/BUILD
+++ b/elan/rpc/BUILD
@@ -57,6 +57,15 @@ go_test(
     ],
 )
 
+go_test(
+    name = "inflight_test",
+    srcs = ["inflight_test.go"],
+    deps = [
+        ":rpc",
+        "///third_party/go/github.com_stretchr_testify//require",
+    ],
+)
+
 genrule(
     name = "test_data",
     cmd = [
diff --git a/elan/rpc/inflight.go b/elan/rpc/inflight.go
new file mode 100644
index 00000000..23101f8a
--- /dev/null
+++ b/elan/rpc/inflight.go
@@ -0,0 +1,77 @@
+package rpc
+
+import (
+	"context"
+	"sync"
+	"time"
+)
+
+// defaultWriteTimeout is the maximum time a reader will wait for an in-progress
+// write to complete before proceeding anyway.
+const defaultWriteTimeout = 10 * time.Minute
+
+// inflightWrites tracks blob writes that are currently in progress so that
+// concurrent readers can block until the write completes rather than getting
+// a spurious NotFound.
+type inflightWrites struct {
+	mu struct {
+		sync.Mutex
+		// blobs maps a blob hash to the context of the most recently
+		// registered write for it; that context is done once the write ends.
+		blobs map[string]context.Context
+	}
+}
+
+// newInflightWrites returns an empty tracker that is ready for use.
+func newInflightWrites() *inflightWrites {
+	w := &inflightWrites{}
+	w.mu.blobs = make(map[string]context.Context)
+	return w
+}
+
+// startWrite registers a blob write in progress. The returned cancel function
+// must be called when the write completes (success or failure) — use defer.
+// The write context inherits from the caller so parent cancellation propagates.
+// A default timeout of defaultWriteTimeout is applied as a safety net.
+//
+// If another write for the same hash is registered while this one is running,
+// the newer registration replaces this one in the map; finishing this write
+// then leaves the newer entry in place.
+func (w *inflightWrites) startWrite(ctx context.Context, hash string) (context.Context, context.CancelFunc) {
+	ctx, cancel := context.WithTimeout(ctx, defaultWriteTimeout)
+	w.mu.Lock()
+	w.mu.blobs[hash] = ctx
+	w.mu.Unlock()
+	return ctx, func() {
+		w.mu.Lock()
+		// Only drop the entry if it is still ours, so that a concurrent
+		// write of the same hash keeps its own registration.
+		if w.mu.blobs[hash] == ctx {
+			delete(w.mu.blobs, hash)
+		}
+		w.mu.Unlock()
+		cancel()
+	}
+}
+
+// waitForWrite blocks until any in-progress write for the given hash completes.
+// If no write is in progress, it returns immediately.
+func (w *inflightWrites) waitForWrite(hash string) {
+	w.waitForWriteContext(context.Background(), hash)
+}
+
+// waitForWriteContext is like waitForWrite, but it also stops waiting as soon
+// as ctx is done, so a reader whose own request has been cancelled or has timed
+// out is not held until the write finishes.
+func (w *inflightWrites) waitForWriteContext(ctx context.Context, hash string) {
+	w.mu.Lock()
+	writeCtx, ok := w.mu.blobs[hash]
+	w.mu.Unlock()
+	if !ok {
+		return
+	}
+	select {
+	case <-writeCtx.Done():
+	case <-ctx.Done():
+	}
+}
diff --git a/elan/rpc/inflight_test.go b/elan/rpc/inflight_test.go
new file mode 100644
index 00000000..a32d2eb7
--- /dev/null
+++ b/elan/rpc/inflight_test.go
@@ -0,0 +1,175 @@
+package rpc
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+// testTimeout bounds how long a test waits for a reader goroutine before it is
+// declared stuck. t.Context() is only cancelled once the test function has
+// returned, so it cannot be used as an in-test deadline.
+const testTimeout = 5 * time.Second
+
+func TestWaitForWrite_NoInflight(t *testing.T) {
+	w := newInflightWrites()
+	// Should return immediately when nothing is in flight.
+	doneCtx, doneFunc := context.WithCancel(t.Context())
+	go func() {
+		w.waitForWrite("abc123")
+		doneFunc()
+	}()
+	select {
+	case <-doneCtx.Done():
+	case <-time.After(testTimeout):
+		t.Fatal("waitForWrite blocked when no write was in progress")
+	}
+}
+
+func TestWaitForWrite_BlocksUntilDone(t *testing.T) {
+	w := newInflightWrites()
+	_, finish := w.startWrite(t.Context(), "abc123")
+
+	var order []string
+	var mu sync.Mutex
+	record := func(s string) {
+		mu.Lock()
+		order = append(order, s)
+		mu.Unlock()
+	}
+
+	doneCtx, doneFunc := context.WithCancel(t.Context())
+	go func() {
+		w.waitForWrite("abc123")
+		record("read")
+		doneFunc()
+	}()
+
+	// Give the reader goroutine time to block.
+	time.Sleep(50 * time.Millisecond)
+	record("write")
+	finish()
+
+	select {
+	case <-doneCtx.Done():
+	case <-time.After(testTimeout):
+		t.Fatal("reader never unblocked")
+	}
+
+	mu.Lock()
+	defer mu.Unlock()
+	require.Equal(t, []string{"write", "read"}, order)
+}
+
+func TestWaitForWrite_DifferentDigests(t *testing.T) {
+	w := newInflightWrites()
+	_, finish := w.startWrite(t.Context(), "abc123")
+	defer finish()
+
+	// A different digest should not block.
+	doneCtx, doneFunc := context.WithCancel(t.Context())
+	go func() {
+		w.waitForWrite("def456")
+		doneFunc()
+	}()
+	select {
+	case <-doneCtx.Done():
+	case <-time.After(testTimeout):
+		t.Fatal("waitForWrite blocked on a different digest")
+	}
+}
+
+func TestWaitForWrite_ParentCancellation(t *testing.T) {
+	w := newInflightWrites()
+	ctx, cancel := context.WithCancel(t.Context())
+	_, finish := w.startWrite(ctx, "abc123")
+	defer finish()
+
+	doneCtx, doneFunc := context.WithCancel(t.Context())
+	go func() {
+		w.waitForWrite("abc123")
+		doneFunc()
+	}()
+
+	// Cancel the parent context — reader should unblock.
+	cancel()
+	select {
+	case <-doneCtx.Done():
+	case <-time.After(testTimeout):
+		t.Fatal("reader did not unblock after parent context cancellation")
+	}
+}
+
+func TestWaitForWrite_MultipleReaders(t *testing.T) {
+	w := newInflightWrites()
+	_, finish := w.startWrite(t.Context(), "abc123")
+
+	const numReaders = 10
+	doneCtx, doneFunc := context.WithCancel(t.Context())
+	var wg sync.WaitGroup
+	wg.Add(numReaders)
+	for range numReaders {
+		go func() {
+			w.waitForWrite("abc123")
+			wg.Done()
+		}()
+	}
+
+	// All readers should be blocked. Finish the write.
+	time.Sleep(50 * time.Millisecond)
+	finish()
+
+	go func() {
+		wg.Wait()
+		doneFunc()
+	}()
+	select {
+	case <-doneCtx.Done():
+	case <-time.After(testTimeout):
+		t.Fatal("not all readers unblocked")
+	}
+}
+
+func TestWaitForWriteContext_ReaderCancellation(t *testing.T) {
+	w := newInflightWrites()
+	_, finish := w.startWrite(t.Context(), "abc123")
+	defer finish()
+
+	readerCtx, cancelReader := context.WithCancel(t.Context())
+	doneCtx, doneFunc := context.WithCancel(t.Context())
+	go func() {
+		w.waitForWriteContext(readerCtx, "abc123")
+		doneFunc()
+	}()
+
+	// Cancel the reader's own context — it should stop waiting even though
+	// the write is still in progress.
+	cancelReader()
+	select {
+	case <-doneCtx.Done():
+	case <-time.After(testTimeout):
+		t.Fatal("reader did not unblock after its own context was cancelled")
+	}
+}
+
+func TestStartWrite_OverlappingWrites(t *testing.T) {
+	w := newInflightWrites()
+	_, finishFirst := w.startWrite(t.Context(), "abc123")
+	_, finishSecond := w.startWrite(t.Context(), "abc123")
+
+	// Finishing the first write must leave the second one registered.
+	finishFirst()
+	w.mu.Lock()
+	_, ok := w.mu.blobs["abc123"]
+	w.mu.Unlock()
+	require.True(t, ok)
+
+	finishSecond()
+	w.mu.Lock()
+	_, ok = w.mu.blobs["abc123"]
+	w.mu.Unlock()
+	require.False(t, ok)
+}
diff --git a/elan/rpc/rpc.go b/elan/rpc/rpc.go
index 3f5afe71..b48b7ac6 100644
--- a/elan/rpc/rpc.go
+++ b/elan/rpc/rpc.go
@@ -179,6 +179,7 @@ func createServer(storage string, parallelism int, maxDirCacheSize, maxKnownBlob
 		decompressor:  dec,
 		readRedis:     readRedis,
 		largeBlobSize: largeBlobSize,
+		inflight:      newInflightWrites(),
 	}
 }
 
@@ -222,6 +223,7 @@ type server struct {
 	decompressor  *zstd.Decoder
 	readRedis     *redis.Client
 	largeBlobSize int64
+	inflight      *inflightWrites
 }
 
 func (s *server) GetCapabilities(ctx context.Context, req *pb.GetCapabilitiesRequest) (*pb.ServerCapabilities, error) {
@@ -549,6 +551,7 @@ func (s *server) Read(req *bs.ReadRequest, srv bs.ByteStream_ReadServer) error {
 }
 
 func (s *server) readCompressed(ctx context.Context, prefix string, digest *pb.Digest, compressed bool, offset, limit int64) (io.ReadCloser, bool, error) {
+	s.inflight.waitForWriteContext(ctx, digest.Hash)
 	if prefix != "cas" {
 		if compressed {
 			return nil, false, fmt.Errorf("Attempted to do a compressed read for non-CAS prefix %s", prefix) // This is a programming error and shouldn't happen.
@@ -687,6 +690,7 @@ func (s *server) readAllBlobBatched(ctx context.Context, prefix string, digest *
 }
 
 func (s *server) readAllBlobCompressed(ctx context.Context, digest *pb.Digest, key string, batched, compressed bool) ([]byte, error) {
+	s.inflight.waitForWriteContext(ctx, digest.Hash)
 	if digest.SizeBytes > s.largeBlobSize {
 		s.limiter <- struct{}{}
 		defer func() { <-s.limiter }()
@@ -724,6 +728,8 @@ func (s *server) compressedKey(prefix string, digest *pb.Digest, compressed bool
 }
 
 func (s *server) writeBlob(ctx context.Context, prefix string, digest *pb.Digest, r io.Reader, compressed bool) error {
+	_, done := s.inflight.startWrite(ctx, digest.Hash)
+	defer done()
 	key := s.compressedKey(prefix, digest, compressed)
 	if s.isEmpty(digest) || s.blobExists(ctx, prefix, digest, compressed, true) {
 		// Read and discard entire content; there is no need to update.
@@ -774,6 +780,8 @@ func (s *server) writeBlob(ctx context.Context, prefix string, digest *pb.Digest
 }
 
 func (s *server) writeAll(ctx context.Context, digest *pb.Digest, data []byte, compressed bool) error {
+	_, done := s.inflight.startWrite(ctx, digest.Hash)
+	defer done()
 	if digest.SizeBytes > s.largeBlobSize {
 		s.limiter <- struct{}{}
 		defer func() { <-s.limiter }()
diff --git a/go.mod b/go.mod
index a5adf5a4..13660866 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
 module github.com/thought-machine/please-servers
 
-go 1.22
+go 1.24
 
 require (
 	cloud.google.com/go/profiler v0.4.0
diff --git a/third_party/go/BUILD b/third_party/go/BUILD
index 96cc0e21..d503d813 100644
--- a/third_party/go/BUILD
+++ b/third_party/go/BUILD
@@ -4,7 +4,7 @@ package(default_visibility = ["PUBLIC"])
 
 go_toolchain(
     name = "toolchain",
-    version = "1.23.2",
+    version = "1.24.2",
     install_std = False,
 )
 