diff --git a/packages/orchestrator/internal/sandbox/block/chunk.go b/packages/orchestrator/internal/sandbox/block/chunk.go index fff60dc792..d04c96ad63 100644 --- a/packages/orchestrator/internal/sandbox/block/chunk.go +++ b/packages/orchestrator/internal/sandbox/block/chunk.go @@ -17,7 +17,16 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) -type Chunker struct { +// Chunker is the interface satisfied by both FullFetchChunker and StreamingChunker. +type Chunker interface { + Slice(ctx context.Context, off, length int64) ([]byte, error) + ReadAt(ctx context.Context, b []byte, off int64) (int, error) + WriteTo(ctx context.Context, w io.Writer) (int64, error) + Close() error + FileSize() (int64, error) +} + +type FullFetchChunker struct { base storage.SeekableReader cache *Cache metrics metrics.Metrics @@ -28,18 +37,18 @@ type Chunker struct { fetchers *utils.WaitMap } -func NewChunker( +func NewFullFetchChunker( size, blockSize int64, base storage.SeekableReader, cachePath string, metrics metrics.Metrics, -) (*Chunker, error) { +) (*FullFetchChunker, error) { cache, err := NewCache(size, blockSize, cachePath, false) if err != nil { return nil, fmt.Errorf("failed to create file cache: %w", err) } - chunker := &Chunker{ + chunker := &FullFetchChunker{ size: size, base: base, cache: cache, @@ -50,7 +59,7 @@ func NewChunker( return chunker, nil } -func (c *Chunker) ReadAt(ctx context.Context, b []byte, off int64) (int, error) { +func (c *FullFetchChunker) ReadAt(ctx context.Context, b []byte, off int64) (int, error) { slice, err := c.Slice(ctx, off, int64(len(b))) if err != nil { return 0, fmt.Errorf("failed to slice cache at %d-%d: %w", off, off+int64(len(b)), err) @@ -59,7 +68,7 @@ func (c *Chunker) ReadAt(ctx context.Context, b []byte, off int64) (int, error) return copy(b, slice), nil } -func (c *Chunker) WriteTo(ctx context.Context, w io.Writer) (int64, error) { +func (c *FullFetchChunker) WriteTo(ctx context.Context, w io.Writer) (int64, error) { for i := int64(0); i < c.size; i += storage.MemoryChunkSize { chunk := make([]byte, storage.MemoryChunkSize) @@ -77,7 +86,7 @@ func (c *Chunker) WriteTo(ctx context.Context, w io.Writer) (int64, error) { return c.size, nil } -func (c *Chunker) Slice(ctx context.Context, off, length int64) ([]byte, error) { +func (c *FullFetchChunker) Slice(ctx context.Context, off, length int64) ([]byte, error) { timer := c.metrics.SlicesTimerFactory.Begin() b, err := c.cache.Slice(off, length) @@ -121,7 +130,7 @@ func (c *Chunker) Slice(ctx context.Context, off, length int64) ([]byte, error) } // fetchToCache ensures that the data at the given offset and length is available in the cache. -func (c *Chunker) fetchToCache(ctx context.Context, off, length int64) error { +func (c *FullFetchChunker) fetchToCache(ctx context.Context, off, length int64) error { var eg errgroup.Group chunks := header.BlocksOffsets(length, storage.MemoryChunkSize) @@ -194,11 +203,11 @@ func (c *Chunker) fetchToCache(ctx context.Context, off, length int64) error { return nil } -func (c *Chunker) Close() error { +func (c *FullFetchChunker) Close() error { return c.cache.Close() } -func (c *Chunker) FileSize() (int64, error) { +func (c *FullFetchChunker) FileSize() (int64, error) { return c.cache.FileSize() } diff --git a/packages/orchestrator/internal/sandbox/block/streaming_chunk.go b/packages/orchestrator/internal/sandbox/block/streaming_chunk.go new file mode 100644 index 0000000000..b45a773c35 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/block/streaming_chunk.go @@ -0,0 +1,507 @@ +package block + +import ( + "cmp" + "context" + "errors" + "fmt" + "io" + "slices" + "sync" + "sync/atomic" + "time" + + "go.opentelemetry.io/otel/attribute" + "golang.org/x/sync/errgroup" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block/metrics" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" +) + +const ( + // defaultFetchTimeout is the maximum time a single 4MB chunk fetch may take. + // Acts as a safety net: if the upstream hangs, the goroutine won't live forever. + defaultFetchTimeout = 60 * time.Second +) + +type rangeWaiter struct { + // endByte is the byte offset (relative to chunkOff) at which this waiter's + // entire requested range is cached. Equal to the end of the last block + // overlapping the requested range. Always a multiple of blockSize. + endByte int64 + ch chan error // buffered cap 1 +} + +const ( + fetchStateRunning = iota + fetchStateDone + fetchStateErrored +) + +type fetchSession struct { + mu sync.Mutex + chunker *StreamingChunker + chunkOff int64 + chunkLen int64 + waiters []*rangeWaiter // sorted by endByte ascending + state int + fetchErr error + + // bytesReady is the byte count (from chunkOff) up to which all blocks are + // fully written to mmap and marked cached. Always a multiple of blockSize + // during progressive reads. Used to cheaply determine which sorted waiters + // are satisfied without calling isCached. + bytesReady int64 +} + +// registerAndWait adds a waiter for the given range and blocks until the range +// is cached or the context is cancelled. Returns nil if the range was already +// cached before registering. +func (s *fetchSession) registerAndWait(ctx context.Context, off, length int64) error { + // endByte is the byte offset (relative to chunkOff) past which all blocks + // covering [off, off+length) are fully cached. + blockSize := s.chunker.blockSize + lastBlockIdx := (off + length - 1 - s.chunkOff) / blockSize + endByte := (lastBlockIdx + 1) * blockSize + + // Fast path: already cached (handles pre-existing cache from prior sessions). + // No lock needed — atomic load + sync.Map lookup are both thread-safe. + if cache := s.chunker.cache.Load(); cache != nil && cache.isCached(off, length) { + return nil + } + + s.mu.Lock() + + // Session already done — all data that will ever be fetched is in cache. + // Unlock first: once state is Done no goroutine mutates the dirty map for + // this chunk, so isCached is safe to call without the session lock. + if s.state == fetchStateDone { + s.mu.Unlock() + if cache := s.chunker.cache.Load(); cache != nil && cache.isCached(off, length) { + return nil + } + + return fmt.Errorf("fetch completed but range %d-%d not cached", off, off+length) + } + + // Session errored — partial data may still be usable. + if s.state == fetchStateErrored { + fetchErr := s.fetchErr + s.mu.Unlock() + if cache := s.chunker.cache.Load(); cache != nil && cache.isCached(off, length) { + return nil + } + + return fmt.Errorf("fetch failed: %w", fetchErr) + } + + w := &rangeWaiter{ + endByte: endByte, + ch: make(chan error, 1), + } + + // Insert in sorted order so notifyWaiters can iterate front-to-back. + idx, _ := slices.BinarySearchFunc(s.waiters, endByte, func(w *rangeWaiter, target int64) int { + return cmp.Compare(w.endByte, target) + }) + s.waiters = slices.Insert(s.waiters, idx, w) + + s.mu.Unlock() + + select { + case err := <-w.ch: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +// notifyWaiters notifies waiters whose ranges are satisfied. +// +// Because waiters are sorted by endByte and the fetch fills the chunk +// sequentially, we only need to walk from the front until we hit a waiter +// whose endByte exceeds bytesReady — all subsequent waiters are unsatisfied. +// +// In terminal states (done/errored) all remaining waiters are notified. +// Must be called with s.mu held. +func (s *fetchSession) notifyWaiters(sendErr error) { + // Terminal: notify every remaining waiter. + if s.state != fetchStateRunning { + for _, w := range s.waiters { + if sendErr != nil && w.endByte > s.bytesReady { + w.ch <- sendErr + } else { + w.ch <- nil + } + } + s.waiters = nil + + return + } + + // Progress: pop satisfied waiters from the sorted front. + i := 0 + for i < len(s.waiters) && s.waiters[i].endByte <= s.bytesReady { + s.waiters[i].ch <- nil + i++ + } + s.waiters = s.waiters[i:] +} + +type StreamingChunker struct { + upstream storage.StreamingReader + cache atomic.Pointer[Cache] // nil until ensureInitialized succeeds + metrics metrics.Metrics + fetchTimeout time.Duration + + size atomic.Int64 // 0 until ensureInitialized succeeds + blockSize int64 + + fetchMu sync.Mutex + fetchMap map[int64]*fetchSession + + initOnce sync.Once + initErr error + // Fields used only by ensureInitialized (immutable after construction). + cachePath string + sizeFunc func(context.Context) (int64, error) +} + +// NewStreamingChunker creates a streaming chunker that defers cache creation +// until the first range read discovers the object size. The sizeFunc should be +// the storage object's Size method, which returns the cached value after the +// first OpenRangeReader call populates it. +func NewStreamingChunker( + blockSize int64, + upstream storage.StreamingReader, + sizeFunc func(context.Context) (int64, error), + cachePath string, + metrics metrics.Metrics, +) *StreamingChunker { + return &StreamingChunker{ + blockSize: blockSize, + upstream: upstream, + metrics: metrics, + fetchTimeout: defaultFetchTimeout, + fetchMap: make(map[int64]*fetchSession), + cachePath: cachePath, + sizeFunc: sizeFunc, + } +} + +// ensureInitialized creates the mmap-backed cache on first call. +// The caller must have already triggered a range read so that sizeFunc +// returns the cached value without a network call. +// Safe to call from multiple goroutines; sync.Once serializes. +func (c *StreamingChunker) ensureInitialized(ctx context.Context) error { + c.initOnce.Do(func() { + size, err := c.sizeFunc(ctx) + if err != nil { + c.initErr = fmt.Errorf("failed to get object size: %w", err) + + return + } + + cache, err := NewCache(size, c.blockSize, c.cachePath, false) + if err != nil { + c.initErr = fmt.Errorf("failed to create file cache: %w", err) + + return + } + + // Store size before cache: any goroutine that sees cache != nil + // is guaranteed to see the size (atomic sequential consistency). + c.size.Store(size) + c.cache.Store(cache) + }) + + return c.initErr +} + +func (c *StreamingChunker) ReadAt(ctx context.Context, b []byte, off int64) (int, error) { + slice, err := c.Slice(ctx, off, int64(len(b))) + if err != nil { + return 0, fmt.Errorf("failed to slice cache at %d-%d: %w", off, off+int64(len(b)), err) + } + + return copy(b, slice), nil +} + +func (c *StreamingChunker) WriteTo(ctx context.Context, w io.Writer) (int64, error) { + chunk := make([]byte, storage.MemoryChunkSize) + size := c.size.Load() + + for i := int64(0); i < size; i += storage.MemoryChunkSize { + n, err := c.ReadAt(ctx, chunk, i) + if err != nil { + return 0, fmt.Errorf("failed to slice cache at %d-%d: %w", i, i+storage.MemoryChunkSize, err) + } + + _, err = w.Write(chunk[:n]) + if err != nil { + return 0, fmt.Errorf("failed to write chunk %d to writer: %w", i, err) + } + } + + return size, nil +} + +func (c *StreamingChunker) Slice(ctx context.Context, off, length int64) ([]byte, error) { + timer := c.metrics.SlicesTimerFactory.Begin() + + // Fast path: already cached. Skip if cache hasn't been created yet (lazy init). + if cache := c.cache.Load(); cache != nil { + b, err := cache.Slice(off, length) + if err == nil { + timer.Success(ctx, length, + attribute.String(pullType, pullTypeLocal)) + + return b, nil + } + + if !errors.As(err, &BytesNotAvailableError{}) { + timer.Failure(ctx, length, + attribute.String(pullType, pullTypeLocal), + attribute.String(failureReason, failureTypeLocalRead)) + + return nil, fmt.Errorf("failed read from cache at offset %d: %w", off, err) + } + } + + // Compute which 4MB chunks overlap with the requested range + firstChunkOff := (off / storage.MemoryChunkSize) * storage.MemoryChunkSize + lastChunkOff := ((off + length - 1) / storage.MemoryChunkSize) * storage.MemoryChunkSize + + var eg errgroup.Group + + for fetchOff := firstChunkOff; fetchOff <= lastChunkOff; fetchOff += storage.MemoryChunkSize { + eg.Go(func() error { + // Clip request to this chunk's boundaries. + chunkEnd := fetchOff + storage.MemoryChunkSize + clippedOff := max(off, fetchOff) + clippedEnd := min(off+length, chunkEnd) + // Clip to known size if initialized; before init, size is + // unknown so we let the fetch discover it. + if s := c.size.Load(); s > 0 { + clippedEnd = min(clippedEnd, s) + } + clippedLen := clippedEnd - clippedOff + + if clippedLen <= 0 { + return nil + } + + session := c.getOrCreateSession(ctx, fetchOff) + + return session.registerAndWait(ctx, clippedOff, clippedLen) + }) + } + + if err := eg.Wait(); err != nil { + timer.Failure(ctx, length, + attribute.String(pullType, pullTypeRemote), + attribute.String(failureReason, failureTypeCacheFetch)) + + return nil, fmt.Errorf("failed to ensure data at %d-%d: %w", off, off+length, err) + } + + b, cacheErr := c.cache.Load().Slice(off, length) + if cacheErr != nil { + timer.Failure(ctx, length, + attribute.String(pullType, pullTypeLocal), + attribute.String(failureReason, failureTypeLocalReadAgain)) + + return nil, fmt.Errorf("failed to read from cache after ensuring data at %d-%d: %w", off, off+length, cacheErr) + } + + timer.Success(ctx, length, + attribute.String(pullType, pullTypeRemote)) + + return b, nil +} + +func (c *StreamingChunker) getOrCreateSession(ctx context.Context, fetchOff int64) *fetchSession { + chunkLen := int64(storage.MemoryChunkSize) + + // Before init, use the full chunk size as default; + // runFetch will correct it after ensureInitialized. + if s := c.size.Load(); s > 0 { + chunkLen = min(chunkLen, s-fetchOff) + } + + s := &fetchSession{ + chunker: c, + chunkOff: fetchOff, + chunkLen: chunkLen, + state: fetchStateRunning, + } + + c.fetchMu.Lock() + if existing, ok := c.fetchMap[fetchOff]; ok { + c.fetchMu.Unlock() + + return existing + } + c.fetchMap[fetchOff] = s + c.fetchMu.Unlock() + + // Detach from the caller's cancel signal so the shared fetch goroutine + // continues even if the first caller's context is cancelled. Trace/value + // context is preserved for metrics. + go c.runFetch(context.WithoutCancel(ctx), s) + + return s +} + +func (c *StreamingChunker) runFetch(ctx context.Context, s *fetchSession) { + ctx, cancel := context.WithTimeout(ctx, c.fetchTimeout) + defer cancel() + + defer func() { + c.fetchMu.Lock() + delete(c.fetchMap, s.chunkOff) + c.fetchMu.Unlock() + }() + + // Panic recovery: ensure waiters are always notified even if the fetch + // goroutine panics (e.g. nil pointer in upstream reader, mmap fault). + // Without this, waiters would block forever on their channels. + defer func() { + if r := recover(); r != nil { + err := fmt.Errorf("fetch panicked: %v", r) + s.mu.Lock() + if s.state == fetchStateRunning { + s.state = fetchStateErrored + s.fetchErr = err + s.notifyWaiters(err) + } + s.mu.Unlock() + } + }() + + // Open range reader first — for lazy init, this triggers size discovery + // on the storage object before we need the cache. + reader, err := c.upstream.OpenRangeReader(ctx, s.chunkOff, s.chunkLen) + if err != nil { + err = fmt.Errorf("failed to open range reader at %d: %w", s.chunkOff, err) + s.mu.Lock() + s.state = fetchStateErrored + s.fetchErr = err + s.notifyWaiters(err) + s.mu.Unlock() + + return + } + defer reader.Close() + + // For lazy init: now that OpenRangeReader has cached the object size, + // create the mmap-backed cache. + if err := c.ensureInitialized(ctx); err != nil { + s.mu.Lock() + s.state = fetchStateErrored + s.fetchErr = err + s.notifyWaiters(err) + s.mu.Unlock() + + return + } + + // Correct chunkLen now that we know the real file size. + // Only the runFetch goroutine writes s.chunkLen; no lock needed. + size := c.size.Load() + if s.chunkLen > size-s.chunkOff { + s.chunkLen = size - s.chunkOff + } + + mmapSlice, releaseLock, err := c.cache.Load().addressBytes(s.chunkOff, s.chunkLen) + if err != nil { + s.mu.Lock() + s.state = fetchStateErrored + s.fetchErr = err + s.notifyWaiters(err) + s.mu.Unlock() + + return + } + defer releaseLock() + + fetchTimer := c.metrics.RemoteReadsTimerFactory.Begin() + + err = c.progressiveRead(ctx, s, mmapSlice, reader) + if err != nil { + fetchTimer.Failure(ctx, s.chunkLen, + attribute.String(failureReason, failureTypeRemoteRead)) + + s.mu.Lock() + s.state = fetchStateErrored + s.fetchErr = err + s.notifyWaiters(err) + s.mu.Unlock() + + return + } + + fetchTimer.Success(ctx, s.chunkLen) + + s.mu.Lock() + s.state = fetchStateDone + s.notifyWaiters(nil) + s.mu.Unlock() +} + +func (c *StreamingChunker) progressiveRead(_ context.Context, s *fetchSession, mmapSlice []byte, reader io.Reader) error { + blockSize := c.blockSize + var totalRead int64 + var prevCompleted int64 + + for totalRead < s.chunkLen { + // Cap each Read to blockSize so the HTTP/GCS client returns after each + // block rather than buffering the entire remaining range. + readEnd := min(totalRead+blockSize, s.chunkLen) + n, readErr := reader.Read(mmapSlice[totalRead:readEnd]) + totalRead += int64(n) + + completedBlocks := totalRead / blockSize + if completedBlocks > prevCompleted { + newBytes := (completedBlocks - prevCompleted) * blockSize + c.cache.Load().setIsCached(s.chunkOff+prevCompleted*blockSize, newBytes) + prevCompleted = completedBlocks + + s.mu.Lock() + s.bytesReady = completedBlocks * blockSize + s.notifyWaiters(nil) + s.mu.Unlock() + } + + if errors.Is(readErr, io.EOF) { + // Mark final partial block if any + if totalRead > prevCompleted*blockSize { + c.cache.Load().setIsCached(s.chunkOff+prevCompleted*blockSize, totalRead-prevCompleted*blockSize) + } + // Remaining waiters are notified in runFetch via the Done state. + break + } + + if readErr != nil { + return fmt.Errorf("failed reading at offset %d after %d bytes: %w", s.chunkOff, totalRead, readErr) + } + } + + return nil +} + +func (c *StreamingChunker) Close() error { + if cache := c.cache.Load(); cache != nil { + return cache.Close() + } + + return nil +} + +func (c *StreamingChunker) FileSize() (int64, error) { + if cache := c.cache.Load(); cache != nil { + return cache.FileSize() + } + + return 0, nil +} diff --git a/packages/orchestrator/internal/sandbox/block/streaming_chunk_test.go b/packages/orchestrator/internal/sandbox/block/streaming_chunk_test.go new file mode 100644 index 0000000000..66808b9294 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/block/streaming_chunk_test.go @@ -0,0 +1,775 @@ +package block + +import ( + "context" + "crypto/rand" + "fmt" + "io" + mathrand "math/rand/v2" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" + "golang.org/x/sync/errgroup" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block/metrics" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +const ( + testBlockSize = header.PageSize // 4KB +) + +// slowUpstream simulates GCS: implements both SeekableReader and StreamingReader. +// OpenRangeReader returns a reader that yields blockSize bytes per Read() call +// with a configurable delay between calls. +type slowUpstream struct { + data []byte + blockSize int64 + delay time.Duration +} + +var ( + _ storage.SeekableReader = (*slowUpstream)(nil) + _ storage.StreamingReader = (*slowUpstream)(nil) +) + +func (s *slowUpstream) ReadAt(_ context.Context, buffer []byte, off int64) (int, error) { + end := min(off+int64(len(buffer)), int64(len(s.data))) + n := copy(buffer, s.data[off:end]) + + return n, nil +} + +func (s *slowUpstream) Size(_ context.Context) (int64, error) { + return int64(len(s.data)), nil +} + +func (s *slowUpstream) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { + end := min(off+length, int64(len(s.data))) + + return &slowReader{ + data: s.data[off:end], + blockSize: int(s.blockSize), + delay: s.delay, + }, nil +} + +type slowReader struct { + data []byte + pos int + blockSize int + delay time.Duration +} + +func (r *slowReader) Read(p []byte) (int, error) { + if r.pos >= len(r.data) { + return 0, io.EOF + } + + if r.delay > 0 { + time.Sleep(r.delay) + } + + end := min(r.pos+r.blockSize, len(r.data)) + + n := copy(p, r.data[r.pos:end]) + r.pos += n + + if r.pos >= len(r.data) { + return n, io.EOF + } + + return n, nil +} + +func (r *slowReader) Close() error { + return nil +} + +// fastUpstream simulates NFS: same interfaces but no delay. +type fastUpstream = slowUpstream + +// errorAfterNUpstream fails after reading n bytes. +type errorAfterNUpstream struct { + data []byte + failAfter int64 + blockSize int64 +} + +var _ storage.StreamingReader = (*errorAfterNUpstream)(nil) + +func (u *errorAfterNUpstream) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { + end := min(off+length, int64(len(u.data))) + + return &errorAfterNReader{ + data: u.data[off:end], + blockSize: int(u.blockSize), + failAfter: int(u.failAfter - off), + }, nil +} + +type errorAfterNReader struct { + data []byte + pos int + blockSize int + failAfter int +} + +func (r *errorAfterNReader) Read(p []byte) (int, error) { + if r.pos >= len(r.data) { + return 0, io.EOF + } + + if r.pos >= r.failAfter { + return 0, fmt.Errorf("simulated upstream error") + } + + end := min(r.pos+r.blockSize, len(r.data)) + + n := copy(p, r.data[r.pos:end]) + r.pos += n + + if r.pos >= len(r.data) { + return n, io.EOF + } + + return n, nil +} + +func (r *errorAfterNReader) Close() error { + return nil +} + +func newTestMetrics(t *testing.T) metrics.Metrics { + t.Helper() + + m, err := metrics.NewMetrics(noop.NewMeterProvider()) + require.NoError(t, err) + + return m +} + +func makeTestData(t *testing.T, size int) []byte { + t.Helper() + + data := make([]byte, size) + _, err := rand.Read(data) + require.NoError(t, err) + + return data +} + +func TestStreamingChunker_BasicSlice(t *testing.T) { + t.Parallel() + + data := makeTestData(t, storage.MemoryChunkSize) + upstream := &fastUpstream{data: data, blockSize: testBlockSize} + + chunker := NewStreamingChunker( + testBlockSize, + upstream, upstream.Size, + t.TempDir()+"/cache", + newTestMetrics(t), + ) + defer chunker.Close() + + // Read first page + slice, err := chunker.Slice(t.Context(), 0, testBlockSize) + require.NoError(t, err) + require.Equal(t, data[:testBlockSize], slice) +} + +func TestStreamingChunker_CacheHit(t *testing.T) { + t.Parallel() + + data := makeTestData(t, storage.MemoryChunkSize) + readCount := atomic.Int64{} + + upstream := &countingUpstream{ + inner: &fastUpstream{data: data, blockSize: testBlockSize}, + readCount: &readCount, + } + + chunker := NewStreamingChunker( + testBlockSize, + upstream, upstream.Size, + t.TempDir()+"/cache", + newTestMetrics(t), + ) + defer chunker.Close() + + // First read: triggers fetch + _, err := chunker.Slice(t.Context(), 0, testBlockSize) + require.NoError(t, err) + + // Wait for the full chunk to be fetched + time.Sleep(50 * time.Millisecond) + + firstCount := readCount.Load() + require.Positive(t, firstCount) + + // Second read: should hit cache + slice, err := chunker.Slice(t.Context(), 0, testBlockSize) + require.NoError(t, err) + require.Equal(t, data[:testBlockSize], slice) + + // No additional reads should have happened + assert.Equal(t, firstCount, readCount.Load()) +} + +type countingUpstream struct { + inner *fastUpstream + readCount *atomic.Int64 +} + +var ( + _ storage.SeekableReader = (*countingUpstream)(nil) + _ storage.StreamingReader = (*countingUpstream)(nil) +) + +func (c *countingUpstream) ReadAt(ctx context.Context, buffer []byte, off int64) (int, error) { + c.readCount.Add(1) + + return c.inner.ReadAt(ctx, buffer, off) +} + +func (c *countingUpstream) Size(ctx context.Context) (int64, error) { + return c.inner.Size(ctx) +} + +func (c *countingUpstream) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { + c.readCount.Add(1) + + return c.inner.OpenRangeReader(ctx, off, length) +} + +func TestStreamingChunker_ConcurrentSameChunk(t *testing.T) { + t.Parallel() + + data := makeTestData(t, storage.MemoryChunkSize) + // Use a slow upstream so requests will overlap + upstream := &slowUpstream{ + data: data, + blockSize: testBlockSize, + delay: 50 * time.Microsecond, + } + + chunker := NewStreamingChunker( + testBlockSize, + upstream, upstream.Size, + t.TempDir()+"/cache", + newTestMetrics(t), + ) + defer chunker.Close() + + numGoroutines := 10 + offsets := make([]int64, numGoroutines) + for i := range numGoroutines { + offsets[i] = int64(i) * testBlockSize + } + + results := make([][]byte, numGoroutines) + + var eg errgroup.Group + + for i := range numGoroutines { + eg.Go(func() error { + slice, err := chunker.Slice(t.Context(), offsets[i], testBlockSize) + if err != nil { + return fmt.Errorf("goroutine %d failed: %w", i, err) + } + results[i] = make([]byte, len(slice)) + copy(results[i], slice) + + return nil + }) + } + + require.NoError(t, eg.Wait()) + + for i := range numGoroutines { + require.Equal(t, data[offsets[i]:offsets[i]+testBlockSize], results[i], + "goroutine %d got wrong data", i) + } +} + +func TestStreamingChunker_EarlyReturn(t *testing.T) { + t.Parallel() + + data := makeTestData(t, storage.MemoryChunkSize) + upstream := &slowUpstream{ + data: data, + blockSize: testBlockSize, + delay: 100 * time.Microsecond, + } + + chunker := NewStreamingChunker( + testBlockSize, + upstream, upstream.Size, + t.TempDir()+"/cache", + newTestMetrics(t), + ) + defer chunker.Close() + + // Time how long it takes to get the first block + start := time.Now() + _, err := chunker.Slice(t.Context(), 0, testBlockSize) + earlyLatency := time.Since(start) + require.NoError(t, err) + + // Time how long it takes to get the last block (on a fresh chunker) + chunker2 := NewStreamingChunker( + testBlockSize, + upstream, upstream.Size, + t.TempDir()+"/cache2", + newTestMetrics(t), + ) + defer chunker2.Close() + + lastOff := int64(len(data)) - testBlockSize + start = time.Now() + _, err = chunker2.Slice(t.Context(), lastOff, testBlockSize) + lateLatency := time.Since(start) + require.NoError(t, err) + + // The early slice should return significantly faster + t.Logf("early latency: %v, late latency: %v", earlyLatency, lateLatency) + assert.Less(t, earlyLatency, lateLatency, + "first-block latency should be less than last-block latency") +} + +func TestStreamingChunker_ErrorKeepsPartialData(t *testing.T) { + t.Parallel() + + chunkSize := storage.MemoryChunkSize + data := makeTestData(t, chunkSize) + failAfter := int64(chunkSize / 2) // Fail at 2MB + + upstream := &errorAfterNUpstream{ + data: data, + failAfter: failAfter, + blockSize: testBlockSize, + } + + size := int64(len(data)) + chunker := NewStreamingChunker( + testBlockSize, + upstream, func(_ context.Context) (int64, error) { return size, nil }, + t.TempDir()+"/cache", + newTestMetrics(t), + ) + defer chunker.Close() + + // Request the last page — this should fail because upstream dies at 2MB + lastOff := int64(chunkSize) - testBlockSize + _, err := chunker.Slice(t.Context(), lastOff, testBlockSize) + require.Error(t, err) + + // But first page (within first 2MB) should still be cached and servable + slice, err := chunker.Slice(t.Context(), 0, testBlockSize) + require.NoError(t, err) + require.Equal(t, data[:testBlockSize], slice) +} + +func TestStreamingChunker_ContextCancellation(t *testing.T) { + t.Parallel() + + data := makeTestData(t, storage.MemoryChunkSize) + upstream := &slowUpstream{ + data: data, + blockSize: testBlockSize, + delay: 1 * time.Millisecond, + } + + chunker := NewStreamingChunker( + testBlockSize, + upstream, upstream.Size, + t.TempDir()+"/cache", + newTestMetrics(t), + ) + defer chunker.Close() + + // Request with a context that we'll cancel quickly + ctx, cancel := context.WithTimeout(t.Context(), 1*time.Millisecond) + defer cancel() + + lastOff := int64(storage.MemoryChunkSize) - testBlockSize + _, err := chunker.Slice(ctx, lastOff, testBlockSize) + // This should fail with context cancellation + require.Error(t, err) + + // But another caller with a valid context should still get the data + // because the fetch goroutine uses background context + time.Sleep(200 * time.Millisecond) // Wait for fetch to complete + slice, err := chunker.Slice(t.Context(), 0, testBlockSize) + require.NoError(t, err) + require.Equal(t, data[:testBlockSize], slice) +} + +func TestStreamingChunker_LastBlockPartial(t *testing.T) { + t.Parallel() + + // File size not aligned to blockSize + size := storage.MemoryChunkSize - 100 + data := makeTestData(t, size) + upstream := &fastUpstream{data: data, blockSize: testBlockSize} + + chunker := NewStreamingChunker( + testBlockSize, + upstream, upstream.Size, + t.TempDir()+"/cache", + newTestMetrics(t), + ) + defer chunker.Close() + + // Read the last partial block + lastBlockOff := (int64(size) / testBlockSize) * testBlockSize + remaining := int64(size) - lastBlockOff + + slice, err := chunker.Slice(t.Context(), lastBlockOff, remaining) + require.NoError(t, err) + require.Equal(t, data[lastBlockOff:], slice) +} + +func TestStreamingChunker_MultiChunkSlice(t *testing.T) { + t.Parallel() + + // Two 4MB chunks + size := storage.MemoryChunkSize * 2 + data := makeTestData(t, size) + upstream := &fastUpstream{data: data, blockSize: testBlockSize} + + chunker := NewStreamingChunker( + testBlockSize, + upstream, upstream.Size, + t.TempDir()+"/cache", + newTestMetrics(t), + ) + defer chunker.Close() + + // Request spanning two chunks: last page of chunk 0 + first page of chunk 1 + off := int64(storage.MemoryChunkSize) - testBlockSize + length := testBlockSize * 2 + + slice, err := chunker.Slice(t.Context(), off, int64(length)) + require.NoError(t, err) + require.Equal(t, data[off:off+int64(length)], slice) +} + +// panicUpstream panics during Read after delivering a configurable number of bytes. +type panicUpstream struct { + data []byte + blockSize int64 + panicAfter int64 // byte offset at which to panic (0 = panic immediately) +} + +var _ storage.StreamingReader = (*panicUpstream)(nil) + +func (u *panicUpstream) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { + end := min(off+length, int64(len(u.data))) + + return &panicReader{ + data: u.data[off:end], + blockSize: int(u.blockSize), + panicAfter: int(u.panicAfter - off), + }, nil +} + +type panicReader struct { + data []byte + pos int + blockSize int + panicAfter int +} + +func (r *panicReader) Read(p []byte) (int, error) { + if r.pos >= r.panicAfter { + panic("simulated upstream panic") + } + + if r.pos >= len(r.data) { + return 0, io.EOF + } + + end := min(r.pos+r.blockSize, len(r.data)) + n := copy(p, r.data[r.pos:end]) + r.pos += n + + return n, nil +} + +func (r *panicReader) Close() error { + return nil +} + +func TestStreamingChunker_PanicRecovery(t *testing.T) { + t.Parallel() + + data := makeTestData(t, storage.MemoryChunkSize) + panicAt := int64(storage.MemoryChunkSize / 2) // Panic at 2MB + + upstream := &panicUpstream{ + data: data, + blockSize: testBlockSize, + panicAfter: panicAt, + } + + size := int64(len(data)) + chunker := NewStreamingChunker( + testBlockSize, + upstream, func(_ context.Context) (int64, error) { return size, nil }, + t.TempDir()+"/cache", + newTestMetrics(t), + ) + defer chunker.Close() + + // Request data past the panic point — should get an error, not hang or crash + lastOff := int64(storage.MemoryChunkSize) - testBlockSize + _, err := chunker.Slice(t.Context(), lastOff, testBlockSize) + require.Error(t, err) + assert.Contains(t, err.Error(), "panicked") + + // Data before the panic point should still be cached + slice, err := chunker.Slice(t.Context(), 0, testBlockSize) + require.NoError(t, err) + require.Equal(t, data[:testBlockSize], slice) +} + +// --- Benchmarks --- +// +// Uses a bandwidth-limited upstream with real time.Sleep to simulate GCS and +// NFS backends. Measures actual wall-clock latency per caller. +// +// Backend parameters (tuned to match observed production latencies): +// GCS: 20ms TTFB + 100 MB/s → 4MB chunk ≈ 62ms (observed ~60ms) +// NFS: 1ms TTFB + 500 MB/s → 4MB chunk ≈ 9ms (observed ~9-10ms) +// +// All sub-benchmarks share a pre-generated offset sequence so results are +// directly comparable across chunker types and backends. +// +// Recommended invocation (~1 minute): +// go test -bench BenchmarkRandomAccess -benchtime 150x -count=3 -run '^$' ./... + +func newBenchmarkMetrics(b *testing.B) metrics.Metrics { + b.Helper() + + m, err := metrics.NewMetrics(noop.NewMeterProvider()) + require.NoError(b, err) + + return m +} + +// realisticUpstream simulates a storage backend with configurable time-to-first-byte +// and bandwidth. ReadAt blocks for the full transfer duration (bulk fetch model). +// OpenRangeReader returns a bandwidth-limited progressive reader. +type realisticUpstream struct { + data []byte + blockSize int64 + ttfb time.Duration + bytesPerSec float64 +} + +var ( + _ storage.SeekableReader = (*realisticUpstream)(nil) + _ storage.StreamingReader = (*realisticUpstream)(nil) +) + +func (u *realisticUpstream) ReadAt(_ context.Context, buffer []byte, off int64) (int, error) { + transferTime := time.Duration(float64(len(buffer)) / u.bytesPerSec * float64(time.Second)) + time.Sleep(u.ttfb + transferTime) + + end := min(off+int64(len(buffer)), int64(len(u.data))) + n := copy(buffer, u.data[off:end]) + + return n, nil +} + +func (u *realisticUpstream) Size(_ context.Context) (int64, error) { + return int64(len(u.data)), nil +} + +func (u *realisticUpstream) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { + end := min(off+length, int64(len(u.data))) + + return &bandwidthReader{ + data: u.data[off:end], + blockSize: int(u.blockSize), + ttfb: u.ttfb, + bytesPerSec: u.bytesPerSec, + }, nil +} + +// bandwidthReader delivers data at a steady rate after an initial TTFB delay. +// Uses cumulative timing (time since first byte) so OS scheduling jitter does +// not compound across blocks. +type bandwidthReader struct { + data []byte + pos int + blockSize int + ttfb time.Duration + bytesPerSec float64 + startTime time.Time + started bool +} + +func (r *bandwidthReader) Read(p []byte) (int, error) { + if !r.started { + r.started = true + time.Sleep(r.ttfb) + r.startTime = time.Now() + } + + if r.pos >= len(r.data) { + return 0, io.EOF + } + + end := min(r.pos+r.blockSize, len(r.data)) + n := copy(p, r.data[r.pos:end]) + r.pos += n + + // Enforce bandwidth: sleep until this many bytes should have arrived. + expectedArrival := r.startTime.Add(time.Duration(float64(r.pos) / r.bytesPerSec * float64(time.Second))) + if wait := time.Until(expectedArrival); wait > 0 { + time.Sleep(wait) + } + + if r.pos >= len(r.data) { + return n, io.EOF + } + + return n, nil +} + +func (r *bandwidthReader) Close() error { + return nil +} + +type benchChunker interface { + Slice(ctx context.Context, off, length int64) ([]byte, error) + Close() error +} + +func BenchmarkRandomAccess(b *testing.B) { + size := int64(storage.MemoryChunkSize) + data := make([]byte, size) + + backends := []struct { + name string + upstream *realisticUpstream + }{ + { + name: "GCS", + upstream: &realisticUpstream{ + data: data, + blockSize: testBlockSize, + ttfb: 20 * time.Millisecond, + bytesPerSec: 100e6, // 100 MB/s — full 4MB chunk ≈ 62ms (observed ~60ms) + }, + }, + { + name: "NFS", + upstream: &realisticUpstream{ + data: data, + blockSize: testBlockSize, + ttfb: 1 * time.Millisecond, + bytesPerSec: 500e6, // 500 MB/s — full 4MB chunk ≈ 9ms (observed ~9-10ms) + }, + }, + } + + chunkerTypes := []struct { + name string + newChunker func(b *testing.B, m metrics.Metrics, upstream *realisticUpstream) benchChunker + }{ + { + name: "StreamingChunker", + newChunker: func(b *testing.B, m metrics.Metrics, upstream *realisticUpstream) benchChunker { + b.Helper() + + return NewStreamingChunker(testBlockSize, upstream, upstream.Size, b.TempDir()+"/cache", m) + }, + }, + { + name: "FullFetchChunker", + newChunker: func(b *testing.B, m metrics.Metrics, upstream *realisticUpstream) benchChunker { + b.Helper() + c, err := NewFullFetchChunker(size, testBlockSize, upstream, b.TempDir()+"/cache", m) + require.NoError(b, err) + + return c + }, + }, + } + + // Realistic concurrency: UFFD faults are limited by vCPU count (typically + // 1-2 for Firecracker VMs) and NBD requests are largely sequential. + const numCallers = 3 + + // Pre-generate a fixed sequence of random offsets so all sub-benchmarks + // use identical access patterns, making results directly comparable. + const maxIters = 500 + numBlocks := size / testBlockSize + rng := mathrand.New(mathrand.NewPCG(42, 0)) + + allOffsets := make([][]int64, maxIters) + for i := range allOffsets { + offsets := make([]int64, numCallers) + for j := range offsets { + offsets[j] = rng.Int64N(numBlocks) * testBlockSize + } + allOffsets[i] = offsets + } + + for _, backend := range backends { + for _, ct := range chunkerTypes { + b.Run(backend.name+"/"+ct.name, func(b *testing.B) { + m := newBenchmarkMetrics(b) + + b.ReportMetric(0, "ns/op") + + var sumAvg, sumMax float64 + + for i := range b.N { + offsets := allOffsets[i%maxIters] + + chunker := ct.newChunker(b, m, backend.upstream) + + latencies := make([]time.Duration, numCallers) + + var eg errgroup.Group + for ci, off := range offsets { + eg.Go(func() error { + start := time.Now() + _, err := chunker.Slice(context.Background(), off, testBlockSize) + latencies[ci] = time.Since(start) + + return err + }) + } + require.NoError(b, eg.Wait()) + + var totalLatency time.Duration + var maxLatency time.Duration + for _, l := range latencies { + totalLatency += l + maxLatency = max(maxLatency, l) + } + + avgUs := float64(totalLatency.Microseconds()) / float64(numCallers) + sumAvg += avgUs + sumMax = max(sumMax, float64(maxLatency.Microseconds())) + + chunker.Close() + } + + b.ReportMetric(sumAvg/float64(b.N), "avg-us/caller") + b.ReportMetric(sumMax, "worst-us/caller") + }) + } + } +} diff --git a/packages/orchestrator/internal/sandbox/build/build.go b/packages/orchestrator/internal/sandbox/build/build.go index b032c940c6..a87cbe52f6 100644 --- a/packages/orchestrator/internal/sandbox/build/build.go +++ b/packages/orchestrator/internal/sandbox/build/build.go @@ -123,6 +123,7 @@ func (b *File) getBuild(ctx context.Context, buildID *uuid.UUID) (Diff, error) { int64(b.header.Metadata.BlockSize), b.metrics, b.persistence, + b.store.flags, ) if err != nil { return nil, fmt.Errorf("failed to create storage diff: %w", err) diff --git a/packages/orchestrator/internal/sandbox/build/storage_diff.go b/packages/orchestrator/internal/sandbox/build/storage_diff.go index 8a05ec165d..53413eddde 100644 --- a/packages/orchestrator/internal/sandbox/build/storage_diff.go +++ b/packages/orchestrator/internal/sandbox/build/storage_diff.go @@ -7,6 +7,7 @@ import ( "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" blockmetrics "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block/metrics" + featureflags "github.com/e2b-dev/infra/packages/shared/pkg/feature-flags" "github.com/e2b-dev/infra/packages/shared/pkg/storage" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) @@ -16,15 +17,16 @@ func storagePath(buildId string, diffType DiffType) string { } type StorageDiff struct { - chunker *utils.SetOnce[*block.Chunker] + chunker *utils.SetOnce[block.Chunker] cachePath string cacheKey DiffStoreKey storagePath string storageObjectType storage.SeekableObjectType - blockSize int64 - metrics blockmetrics.Metrics - persistence storage.StorageProvider + blockSize int64 + metrics blockmetrics.Metrics + persistence storage.StorageProvider + featureFlags *featureflags.Client } var _ Diff = (*StorageDiff)(nil) @@ -44,6 +46,7 @@ func newStorageDiff( blockSize int64, metrics blockmetrics.Metrics, persistence storage.StorageProvider, + featureFlags *featureflags.Client, ) (*StorageDiff, error) { storagePath := storagePath(buildId, diffType) storageObjectType, ok := storageObjectType(diffType) @@ -57,10 +60,11 @@ func newStorageDiff( storagePath: storagePath, storageObjectType: storageObjectType, cachePath: cachePath, - chunker: utils.NewSetOnce[*block.Chunker](), + chunker: utils.NewSetOnce[block.Chunker](), blockSize: blockSize, metrics: metrics, persistence: persistence, + featureFlags: featureFlags, cacheKey: GetDiffStoreKey(buildId, diffType), }, nil } @@ -86,23 +90,31 @@ func (b *StorageDiff) Init(ctx context.Context) error { return err } - size, err := obj.Size(ctx) - if err != nil { - errMsg := fmt.Errorf("failed to get object size: %w", err) - b.chunker.SetError(errMsg) - - return errMsg - } - - chunker, err := block.NewChunker(size, b.blockSize, obj, b.cachePath, b.metrics) - if err != nil { - errMsg := fmt.Errorf("failed to create chunker: %w", err) - b.chunker.SetError(errMsg) - - return errMsg + var c block.Chunker + if b.featureFlags != nil && b.featureFlags.BoolFlag(ctx, featureflags.UseStreamingChunkerFlag) { + // Lazy init: the object size is discovered from the first + // OpenRangeReader response (free from Content-Range), eliminating + // the need for a separate Size()/Attrs()/HeadObject call. + c = block.NewStreamingChunker(b.blockSize, obj, obj.Size, b.cachePath, b.metrics) + } else { + size, err := obj.Size(ctx) + if err != nil { + errMsg := fmt.Errorf("failed to get object size: %w", err) + b.chunker.SetError(errMsg) + + return errMsg + } + + c, err = block.NewFullFetchChunker(size, b.blockSize, obj, b.cachePath, b.metrics) + if err != nil { + errMsg := fmt.Errorf("failed to create chunker: %w", err) + b.chunker.SetError(errMsg) + + return errMsg + } } - return b.chunker.SetValue(chunker) + return b.chunker.SetValue(c) } func (b *StorageDiff) Close() error { diff --git a/packages/shared/pkg/feature-flags/flags.go b/packages/shared/pkg/feature-flags/flags.go index c29a2bd595..12cc293acf 100644 --- a/packages/shared/pkg/feature-flags/flags.go +++ b/packages/shared/pkg/feature-flags/flags.go @@ -92,6 +92,7 @@ var ( BestOfKTooManyStartingFlag = newBoolFlag("best-of-k-too-many-starting", false) EdgeProvidedSandboxMetricsFlag = newBoolFlag("edge-provided-sandbox-metrics", false) CreateStorageCacheSpansFlag = newBoolFlag("create-storage-cache-spans", env.IsDevelopment()) + UseStreamingChunkerFlag = newBoolFlag("use-streaming-chunker", false) SandboxAutoResumeFlag = newBoolFlag("sandbox-auto-resume", env.IsDevelopment()) PersistentVolumesFlag = newBoolFlag("can-use-persistent-volumes", env.IsDevelopment()) ) diff --git a/packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go b/packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go index 6bfc8f60ad..3931f6b349 100644 --- a/packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go +++ b/packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go @@ -6,6 +6,7 @@ package storagemocks import ( "context" + "io" mock "github.com/stretchr/testify/mock" ) @@ -37,6 +38,80 @@ func (_m *MockSeekable) EXPECT() *MockSeekable_Expecter { return &MockSeekable_Expecter{mock: &_m.Mock} } +// OpenRangeReader provides a mock function for the type MockSeekable +func (_mock *MockSeekable) OpenRangeReader(ctx context.Context, off int64, length int64) (io.ReadCloser, error) { + ret := _mock.Called(ctx, off, length) + + if len(ret) == 0 { + panic("no return value specified for OpenRangeReader") + } + + var r0 io.ReadCloser + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64) (io.ReadCloser, error)); ok { + return returnFunc(ctx, off, length) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64) io.ReadCloser); ok { + r0 = returnFunc(ctx, off, length) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.ReadCloser) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, int64, int64) error); ok { + r1 = returnFunc(ctx, off, length) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockSeekable_OpenRangeReader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OpenRangeReader' +type MockSeekable_OpenRangeReader_Call struct { + *mock.Call +} + +// OpenRangeReader is a helper method to define mock.On call +// - ctx context.Context +// - off int64 +// - length int64 +func (_e *MockSeekable_Expecter) OpenRangeReader(ctx interface{}, off interface{}, length interface{}) *MockSeekable_OpenRangeReader_Call { + return &MockSeekable_OpenRangeReader_Call{Call: _e.mock.On("OpenRangeReader", ctx, off, length)} +} + +func (_c *MockSeekable_OpenRangeReader_Call) Run(run func(ctx context.Context, off int64, length int64)) *MockSeekable_OpenRangeReader_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 int64 + if args[1] != nil { + arg1 = args[1].(int64) + } + var arg2 int64 + if args[2] != nil { + arg2 = args[2].(int64) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *MockSeekable_OpenRangeReader_Call) Return(readCloser io.ReadCloser, err error) *MockSeekable_OpenRangeReader_Call { + _c.Call.Return(readCloser, err) + return _c +} + +func (_c *MockSeekable_OpenRangeReader_Call) RunAndReturn(run func(ctx context.Context, off int64, length int64) (io.ReadCloser, error)) *MockSeekable_OpenRangeReader_Call { + _c.Call.Return(run) + return _c +} + // ReadAt provides a mock function for the type MockSeekable func (_mock *MockSeekable) ReadAt(ctx context.Context, buffer []byte, off int64) (int, error) { ret := _mock.Called(ctx, buffer, off) diff --git a/packages/shared/pkg/storage/storage.go b/packages/shared/pkg/storage/storage.go index 2446539f4f..12f5ed95ed 100644 --- a/packages/shared/pkg/storage/storage.go +++ b/packages/shared/pkg/storage/storage.go @@ -79,6 +79,11 @@ type SeekableReader interface { Size(ctx context.Context) (int64, error) } +// StreamingReader supports progressive reads via a streaming range reader. +type StreamingReader interface { + OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) +} + type SeekableWriter interface { // Store entire file StoreFile(ctx context.Context, path string) error @@ -87,6 +92,7 @@ type SeekableWriter interface { type Seekable interface { SeekableReader SeekableWriter + StreamingReader } func GetTemplateStorageProvider(ctx context.Context, limiter *limit.Limiter) (StorageProvider, error) { diff --git a/packages/shared/pkg/storage/storage_aws.go b/packages/shared/pkg/storage/storage_aws.go index dd1555d936..5ba1aaf4a9 100644 --- a/packages/shared/pkg/storage/storage_aws.go +++ b/packages/shared/pkg/storage/storage_aws.go @@ -7,7 +7,9 @@ import ( "fmt" "io" "os" + "strconv" "strings" + "sync/atomic" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -38,6 +40,10 @@ type awsObject struct { client *s3.Client path string bucketName string + + // discoveredSize caches the total object size learned from range-read + // responses (Content-Range header), avoiding a separate HeadObject call. + discoveredSize atomic.Int64 } var ( @@ -211,6 +217,31 @@ func (o *awsObject) Put(ctx context.Context, data []byte) error { return nil } +func (o *awsObject) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { + readRange := aws.String(fmt.Sprintf("bytes=%d-%d", off, off+length-1)) + resp, err := o.client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: aws.String(o.bucketName), + Key: aws.String(o.path), + Range: readRange, + }) + if err != nil { + var nsk *types.NoSuchKey + if errors.As(err, &nsk) { + return nil, ErrObjectNotExist + } + + return nil, fmt.Errorf("failed to create S3 range reader for %q: %w", o.path, err) + } + + if resp.ContentRange != nil { + if total := parseContentRangeTotal(*resp.ContentRange); total > 0 { + o.discoveredSize.Store(total) + } + } + + return resp.Body, nil +} + func (o *awsObject) ReadAt(ctx context.Context, buff []byte, off int64) (n int, err error) { ctx, cancel := context.WithTimeout(ctx, awsReadTimeout) defer cancel() @@ -232,6 +263,12 @@ func (o *awsObject) ReadAt(ctx context.Context, buff []byte, off int64) (n int, defer resp.Body.Close() + if resp.ContentRange != nil { + if total := parseContentRangeTotal(*resp.ContentRange); total > 0 { + o.discoveredSize.Store(total) + } + } + // When the object is smaller than requested range there will be unexpected EOF, // but backend expects to return EOF in this case. n, err = io.ReadFull(resp.Body, buff) @@ -243,6 +280,10 @@ func (o *awsObject) ReadAt(ctx context.Context, buff []byte, off int64) (n int, } func (o *awsObject) Size(ctx context.Context) (int64, error) { + if s := o.discoveredSize.Load(); s > 0 { + return s, nil + } + ctx, cancel := context.WithTimeout(ctx, awsOperationTimeout) defer cancel() @@ -259,6 +300,22 @@ func (o *awsObject) Size(ctx context.Context) (int64, error) { return *resp.ContentLength, nil } +// parseContentRangeTotal extracts the total size from a Content-Range header +// value like "bytes 0-99/12345". Returns 0 if the format is unexpected. +func parseContentRangeTotal(cr string) int64 { + idx := strings.LastIndex(cr, "/") + if idx < 0 || idx+1 >= len(cr) { + return 0 + } + + total, err := strconv.ParseInt(cr[idx+1:], 10, 64) + if err != nil { + return 0 + } + + return total +} + func (o *awsObject) Exists(ctx context.Context) (bool, error) { _, err := o.Size(ctx) diff --git a/packages/shared/pkg/storage/storage_aws_test.go b/packages/shared/pkg/storage/storage_aws_test.go new file mode 100644 index 0000000000..10476053bb --- /dev/null +++ b/packages/shared/pkg/storage/storage_aws_test.go @@ -0,0 +1,34 @@ +package storage + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseContentRangeTotal(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected int64 + }{ + {"standard range", "bytes 0-99/12345", 12345}, + {"large object", "bytes 0-4194303/1073741824", 1073741824}, + {"mid-range request", "bytes 4194304-8388607/1073741824", 1073741824}, + {"single byte", "bytes 0-0/1", 1}, + {"no slash", "bytes 0-99", 0}, + {"empty string", "", 0}, + {"unknown total", "bytes 0-99/*", 0}, + {"trailing slash", "bytes 0-99/", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := parseContentRangeTotal(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/packages/shared/pkg/storage/storage_cache_metrics.go b/packages/shared/pkg/storage/storage_cache_metrics.go index e93c4739e8..037bc7ed06 100644 --- a/packages/shared/pkg/storage/storage_cache_metrics.go +++ b/packages/shared/pkg/storage/storage_cache_metrics.go @@ -32,6 +32,8 @@ const ( cacheOpReadAt cacheOp = "read_at" cacheOpSize cacheOp = "size" + cacheOpOpenRangeReader cacheOp = "open_range_reader" + cacheOpWrite cacheOp = "write" cacheOpWriteFromFileSystem cacheOp = "write_from_filesystem" ) diff --git a/packages/shared/pkg/storage/storage_cache_seekable.go b/packages/shared/pkg/storage/storage_cache_seekable.go index bdd6f0e117..f7388a8946 100644 --- a/packages/shared/pkg/storage/storage_cache_seekable.go +++ b/packages/shared/pkg/storage/storage_cache_seekable.go @@ -1,6 +1,7 @@ package storage import ( + "bytes" "context" "errors" "fmt" @@ -60,7 +61,10 @@ type cachedSeekable struct { wg sync.WaitGroup } -var _ Seekable = (*cachedSeekable)(nil) +var ( + _ Seekable = (*cachedSeekable)(nil) + _ StreamingReader = (*cachedSeekable)(nil) +) func (c *cachedSeekable) ReadAt(ctx context.Context, buff []byte, offset int64) (n int, err error) { ctx, span := c.tracer.Start(ctx, "read object at offset", trace.WithAttributes( @@ -122,6 +126,84 @@ func (c *cachedSeekable) ReadAt(ctx context.Context, buff []byte, offset int64) return readCount, err } +func (c *cachedSeekable) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { + // Try NFS cache file first + chunkPath := c.makeChunkFilename(off) + + fp, err := os.Open(chunkPath) + if err == nil { + recordCacheRead(ctx, true, length, cacheTypeSeekable, cacheOpOpenRangeReader) + + return &fsRangeReadCloser{ + Reader: io.NewSectionReader(fp, 0, length), + file: fp, + }, nil + } + + if !os.IsNotExist(err) { + recordCacheReadError(ctx, cacheTypeSeekable, cacheOpOpenRangeReader, err) + } + + // Cache miss: delegate to the inner backend (Seekable embeds StreamingReader). + inner, err := c.inner.OpenRangeReader(ctx, off, length) + if err != nil { + return nil, fmt.Errorf("failed to open inner range reader: %w", err) + } + + recordCacheRead(ctx, false, length, cacheTypeSeekable, cacheOpOpenRangeReader) + + // Wrap in a write-through reader that caches data on Close + return &cacheWriteThroughReader{ + inner: inner, + buf: bytes.NewBuffer(make([]byte, 0, length)), + cache: c, + ctx: ctx, + off: off, + chunkPath: chunkPath, + }, nil +} + +// cacheWriteThroughReader wraps an inner reader, buffering all data read through it. +// On Close, it asynchronously writes the buffered data to the NFS cache. +type cacheWriteThroughReader struct { + inner io.ReadCloser + buf *bytes.Buffer + cache *cachedSeekable + ctx context.Context //nolint:containedctx // needed for async cache write-back in Close + off int64 + chunkPath string +} + +func (r *cacheWriteThroughReader) Read(p []byte) (int, error) { + n, err := r.inner.Read(p) + if n > 0 { + r.buf.Write(p[:n]) + } + + return n, err +} + +func (r *cacheWriteThroughReader) Close() error { + closeErr := r.inner.Close() + + if r.buf.Len() > 0 { + data := make([]byte, r.buf.Len()) + copy(data, r.buf.Bytes()) + + r.cache.goCtx(r.ctx, func(ctx context.Context) { + ctx, span := r.cache.tracer.Start(ctx, "write range reader chunk back to cache") + defer span.End() + + if err := r.cache.writeChunkToCache(ctx, r.off, r.chunkPath, data); err != nil { + recordError(span, err) + recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpOpenRangeReader, err) + } + }) + } + + return closeErr +} + func (c *cachedSeekable) Size(ctx context.Context) (n int64, e error) { ctx, span := c.tracer.Start(ctx, "get size of object") defer func() { diff --git a/packages/shared/pkg/storage/storage_fs.go b/packages/shared/pkg/storage/storage_fs.go index c47692fd1d..c02ef84948 100644 --- a/packages/shared/pkg/storage/storage_fs.go +++ b/packages/shared/pkg/storage/storage_fs.go @@ -22,10 +22,21 @@ type fsObject struct { } var ( - _ Seekable = (*fsObject)(nil) - _ Blob = (*fsObject)(nil) + _ Seekable = (*fsObject)(nil) + _ Blob = (*fsObject)(nil) + _ StreamingReader = (*fsObject)(nil) ) +type fsRangeReadCloser struct { + io.Reader + + file *os.File +} + +func (r *fsRangeReadCloser) Close() error { + return r.file.Close() +} + func newFileSystemStorage(basePath string) *fsStorage { return &fsStorage{ basePath: basePath, @@ -117,6 +128,18 @@ func (o *fsObject) StoreFile(_ context.Context, path string) error { return nil } +func (o *fsObject) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { + f, err := o.getHandle(true) + if err != nil { + return nil, err + } + + return &fsRangeReadCloser{ + Reader: io.NewSectionReader(f, off, length), + file: f, + }, nil +} + func (o *fsObject) ReadAt(_ context.Context, buff []byte, off int64) (n int, err error) { handle, err := o.getHandle(true) if err != nil { diff --git a/packages/shared/pkg/storage/storage_google.go b/packages/shared/pkg/storage/storage_google.go index ed4e631428..6d7ed6505b 100644 --- a/packages/shared/pkg/storage/storage_google.go +++ b/packages/shared/pkg/storage/storage_google.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "os" + "sync/atomic" "time" "cloud.google.com/go/storage" @@ -75,11 +76,16 @@ type gcpObject struct { handle *storage.ObjectHandle limiter *limit.Limiter + + // discoveredSize caches the total object size learned from range-read + // responses (Content-Range header), avoiding a separate Attrs() call. + discoveredSize atomic.Int64 } var ( - _ Seekable = (*gcpObject)(nil) - _ Blob = (*gcpObject)(nil) + _ Seekable = (*gcpObject)(nil) + _ Blob = (*gcpObject)(nil) + _ StreamingReader = (*gcpObject)(nil) ) func NewGCP(ctx context.Context, bucketName string, limiter *limit.Limiter) (StorageProvider, error) { @@ -208,6 +214,10 @@ func (o *gcpObject) Exists(ctx context.Context) (bool, error) { } func (o *gcpObject) Size(ctx context.Context) (int64, error) { + if s := o.discoveredSize.Load(); s > 0 { + return s, nil + } + ctx, cancel := context.WithTimeout(ctx, googleOperationTimeout) defer cancel() @@ -224,6 +234,37 @@ func (o *gcpObject) Size(ctx context.Context) (int64, error) { return attrs.Size, nil } +func (o *gcpObject) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { + ctx, cancel := context.WithTimeout(ctx, googleReadTimeout) + + reader, err := o.handle.NewRangeReader(ctx, off, length) + if err != nil { + cancel() + + return nil, fmt.Errorf("failed to create GCS range reader for %q at %d+%d: %w", o.path, off, length, err) + } + + if s := reader.Attrs.Size; s > 0 { + o.discoveredSize.Store(s) + } + + return &cancelOnCloseReader{ReadCloser: reader, cancel: cancel}, nil +} + +// cancelOnCloseReader wraps a ReadCloser and calls a CancelFunc on Close, +// ensuring the context used to create the reader is cleaned up. +type cancelOnCloseReader struct { + io.ReadCloser + + cancel context.CancelFunc +} + +func (r *cancelOnCloseReader) Close() error { + defer r.cancel() + + return r.ReadCloser.Close() +} + func (o *gcpObject) ReadAt(ctx context.Context, buff []byte, off int64) (n int, err error) { timer := googleReadTimerFactory.Begin(attribute.String(gcsOperationAttr, gcsOperationAttrReadAt)) @@ -240,6 +281,10 @@ func (o *gcpObject) ReadAt(ctx context.Context, buff []byte, off int64) (n int, defer reader.Close() + if s := reader.Attrs.Size; s > 0 { + o.discoveredSize.Store(s) + } + for reader.Remain() > 0 { nr, err := reader.Read(buff[n:]) n += nr