diff --git a/.gitignore b/.gitignore index 58e1bff2e1..85a81ee327 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,5 @@ go.work.sum .vscode/mise-tools /packages/fc-kernels /packages/fc-versions +/compress-build +/inspect-build diff --git a/.mockery.yaml b/.mockery.yaml index 4b175b33c2..c80d238c16 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -7,32 +7,46 @@ packages: filename: mocks.go pkgname: filesystemconnectmocks + github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block: + interfaces: + flagsClient: + config: + dir: packages/orchestrator/internal/sandbox/block + filename: mock_flagsclient_test.go + pkgname: block + inpackage: true + structname: MockFlagsClient + github.com/e2b-dev/infra/packages/shared/pkg/storage: interfaces: featureFlagsClient: config: - dir: packages/shared/pkg/storage/mocks - filename: mockfeatureflagsclient.go - pkgname: storagemocks + dir: packages/shared/pkg/storage + filename: mock_featureflagsclient_test.go + pkgname: storage + inpackage: true structname: MockFeatureFlagsClient Blob: config: - dir: packages/shared/pkg/storage/mocks - filename: mockobjectprovider.go - pkgname: storagemocks - Seekable: + dir: packages/shared/pkg/storage + filename: mock_blob_test.go + pkgname: storage + inpackage: true + FramedFile: config: - dir: packages/shared/pkg/storage/mocks - filename: mockseekableobjectprovider.go - pkgname: storagemocks + dir: packages/shared/pkg/storage + filename: mock_framedfile_test.go + pkgname: storage + inpackage: true io: interfaces: Reader: config: - dir: packages/shared/pkg/storage/mocks - filename: mockioreader.go - pkgname: storagemocks + dir: packages/shared/pkg/storage + filename: mock_ioreader_test.go + pkgname: storage + inpackage: true github.com/e2b-dev/infra/packages/shared/pkg/utils: interfaces: diff --git a/CLAUDE.md b/CLAUDE.md index 613b69e811..16e299e2ed 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -286,6 +286,11 @@ make connect-orchestrator - Access: `https://nomad.` - Token: GCP 
Secrets Manager -### Logs +### Nomad Logs +- Use `nomad alloc logs -job ` to fetch service logs (e.g., `nomad alloc logs -job orchestrator-dev`) +- Use `-stderr` flag for stderr output: `nomad alloc logs -job orchestrator-dev -stderr` +- Use `-tail` for live tailing: `nomad alloc logs -job orchestrator-dev -tail` +- The orchestrator job in dev is called `orchestrator-dev` +- Integration test failures should be diagnosed by checking these logs first - Local: Docker logs in `make local-infra` - Production: Grafana Loki or Nomad UI diff --git a/Makefile b/Makefile index 0dd4518187..4dfe36723f 100644 --- a/Makefile +++ b/Makefile @@ -171,6 +171,22 @@ test: test-integration: $(MAKE) -C tests/integration test +.PHONY: test-integration/sandboxes +test-integration/sandboxes: + $(MAKE) -C tests/integration test/api/sandboxes + +.PHONY: test-integration/templates +test-integration/templates: + $(MAKE) -C tests/integration test/api/templates + +.PHONY: test-integration/envd +test-integration/envd: + $(MAKE) -C tests/integration test/envd + +.PHONY: reset-test-env +reset-test-env: + scripts/reset-test-env.sh + .PHONY: connect-orchestrator connect-orchestrator: $(MAKE) -C tests/integration connect-orchestrator diff --git a/docs/compression-architecture.md b/docs/compression-architecture.md new file mode 100644 index 0000000000..c6a9c827db --- /dev/null +++ b/docs/compression-architecture.md @@ -0,0 +1,349 @@ +# Template Compression: Architecture & Status + +- [A. Architecture](#a-architecture) + - [Storage Format](#storage-format) · [Storage interface](#storage-interface) · [Feature Flags](#feature-flags) · [Template Loading](#template-loading) · [Read Path](#read-path-nbd--uffd--prefetch) +- [B. Biggest Changes](#b-biggest-changes) +- [C. Read Path Diagram](#c-read-path-diagram) +- [D. Remaining Work](#d-remaining-work) + - [From This Branch](#from-this-branch) · [From lev-zstd-compression](#from-lev-zstd-compression-unported) +- [E. 
Write Paths](#e-write-paths) + - [Inline Build / Pause](#inline-build--pause) · [Background Compression](#background-compression-compress-build-cli) +- [F. Failure Modes](#f-failure-modes) +- [G. Cost & Benefit](#g-cost--benefit) + - [Storage](#storage) · [CPU](#cpu) · [Memory](#memory) · [Net](#net) +- [H. Grafana Metrics](#h-grafana-metrics) + - [Chunker](#chunker-meter-internalsandboxblockmetrics) · [NFS Cache](#nfs-cache-meter-sharedpkgstorage) · [GCS Backend](#gcs-backend-meter-sharedpkgstorage) · [Key Queries](#key-queries) +- [I. Rollout Strategy](#i-rollout-strategy) + +## A. Architecture + +Templates are stored in GCS as build artifacts. Each build produces two data files (memfile, rootfs) plus a header and metadata. Each data file can have an uncompressed variant (`{buildId}/memfile`) and a compressed variant (`{buildId}/v4.memfile.lz4`), with corresponding v3 and v4 headers. + +### Storage Format + +- Data is broken into **frames**, each independently decompressible (LZ4 or Zstd). +- Frames are aligned to `FrameAlignmentSize` (= `MemoryChunkSize` = 4 MiB) in uncompressed space, with a minimum of 1 MB compressed and a maximum of 32 MB uncompressed (configurable). +- The **v4 header** embeds a `FrameTable` per mapping: `CompressionType + StartAt + []FrameSize`. The header itself is always LZ4-block-compressed, regardless of data compression type. +- The `FrameTable` is subset per mapping so each mapping carries only the frames it references. + +### Storage interface + +The most relevant change is `FramedFile` (returned by `OpenFramedFile`) replaces the old `Seekable` (returned by `OpenSeekable`). Where `Seekable` had separate `ReadAt`, `OpenRangeReader`, and `StoreFile` methods, `FramedFile` unifies reads into a single `GetFrame(ctx, offsetU, frameTable, decompress, buf, readSize, onRead)` that handles both compressed and uncompressed data, plus `Size` and `StoreFile` (with optional compression via `FramedUploadOptions`). 
For compressed data, raw compressed frames are cached individually on NFS by `(path, frameStart, frameSize)` key. + +### Feature Flags + +Two LaunchDarkly JSON flags control compression, with per-team/cluster/template targeting: + +**`chunker-config`** (read path): + +```json +// (restart required for existing chunkers) +{ + "useCompressedAssets": false, // load v4 headers, use compressed read path if available + "minReadBatchSizeKB": 16 // floor for read batch size in KB +} +``` + +**`compress-config`** (write path): + +```json +{ + "compressBuilds": false, // enable compressed dual-write uploads + "compressionType": "zstd", // "lz4" or "zstd" + "level": 2, // compression level (0=fast, higher=better ratio) + "frameTargetMB": 2, // target compressed frame size in MiB + "frameMaxUncompressedMB": 16, // cap on uncompressed bytes per frame (= 4 × MemoryChunkSize) + "uploadPartTargetMB": 50, // target GCS multipart upload part size in MiB + "encoderConcurrency": 1, // goroutines per zstd encoder + "decoderConcurrency": 1 // goroutines per pooled zstd decoder +} +``` + +### Template Loading + +When an orchestrator loads a template from storage (cache miss): + +1. **Header probe**: if `useCompressedAssets`, probes for v4 and v3 headers in parallel, preferring v4. Falls back to v3 if v4 is missing. +2. **Asset probe**: for each build referenced in header mappings, probes for 3 data variants in parallel (uncompressed, `.lz4`, `.zstd`). Missing variants are silently skipped. +3. **Chunker creation**: one `Chunker` per `(buildId, fileType)`. The chunker's `AssetInfo` records which variants exist. + +### Read Path (NBD / UFFD / Prefetch) + +All three consumer types share the same path at read time: + +``` +GetBlock(offset, length, ft) // was Slice() + → header.GetShiftedMapping(offset) // in-memory → BuildMap with FrameTable + → DiffStore.Get(buildId) // TTL cache hit → cached Chunker + → Chunker.GetBlock(offset, length, ft) + → mmap cache hit? 
return reference + → miss: regionLock dedup → fetchSession → GetFrame → NFS cache → GCS + → decompressed bytes written into mmap, waiters notified +``` + +- Prefetch reads 4 MiB, UFFD reads 4 KB or 2 MB (hugepage), NBD reads 4 KB. +- Frames are aligned to `MemoryChunkSize` (4 MiB), so no `GetBlock` call ever crosses a frame boundary. +- If the v4 header was loaded, each mapping carries a subset `FrameTable`; this `ft` is threaded through to `GetBlock`, routing to compressed or uncompressed fetch, no header fetch is needed. + +--- + +## B. Biggest Changes + +- **Unified Chunker**: collapsed `FullFetchChunker`, `StreamingChunker`, and the `Chunker` interface back into a single concrete `Chunker` struct backed by slot-based `regionLock` for fetch deduplication; a single code path handles both compressed and uncompressed data via `GetFrame`. + +- **Asset probing at init**: `StorageDiff.Init` now probes for all 3 data variants (uncompressed, lz4, zstd) in parallel via `probeAssets`, constructing an `AssetInfo` that the Chunker uses to route reads. This replaces the previous `OpenSeekable` single-object path. + +- **Upload API on TemplateBuild**: moved the upload lifecycle from `Snapshot` to `TemplateBuild`, which now owns path extraction, `PendingFrameTables` accumulation, and V4 header serialization. `UploadAll` is synchronous (no internal goroutine); multi-layer builds use `UploadExceptV4Headers` + `UploadV4Header` with explicit coordination via `UploadTracker`. + +- **NFS cache for compressed frames**: `GetFrame` on the NFS cache layer stores and retrieves individual compressed frames by `(path, frameStart, frameSize)`, with progressive decompression into mmap. Uncompressed reads use the same `GetFrame` codepath with `ft=nil`. 
+ +- **FrameTable validation and testing**: added `validateGetFrameParams` at the `GetFrame` entry point (alignment checks for compressed, bounds checks for uncompressed), fixed `FrameTable.Range` bug (was not initializing from `StartAt`), and added comprehensive `FrameTable` unit tests. + +--- + +## C. Read Path Diagram + +```mermaid +flowchart TD + subgraph Consumers + NBD["NBD (4 KB)"] + UFFD["UFFD (4 KB / 2 MB)"] + PF["Prefetch (4 MiB)"] + end + + NBD & UFFD & PF --> GM["header.GetShiftedMapping(offset)"] + GM -->|"BuildMap + FrameTable"| DS["DiffStore.Get(buildId)"] + DS -->|"cached Chunker"| GB["Chunker.GetBlock(offset, length, ft)"] + + GB --> MC{"mmap cache hit?"} + MC -->|"hit"| REF["return []byte (reference to mmap)"] + MC -->|"miss"| RL["regionLock (dedup / wait)"] + + RL --> ROUTE{"matching compressed asset exists?"} + + ROUTE -->|"compressed"| GFC["GetFrame (ft, decompress=true)"] + ROUTE -->|"uncompressed"| GFU["GetFrame (ft=nil, decompress=false)"] + + GFC --> NFS{"NFS cache hit?"} + GFU --> NFS + + NFS -->|"hit"| WRITE["write to mmap + notify waiters"] + NFS -->|"miss"| GCS["GCS range read (C-space or U-space)"] + + GCS --> DEC{"compressed?"} + DEC -->|"yes"| DECOMP["pooled zstd/lz4 decoder"] + DEC -->|"no"| STORE_NFS + + DECOMP --> STORE_NFS["store frame in NFS cache"] + STORE_NFS --> WRITE + WRITE --> REF +``` + +
+ASCII version + +``` + NBD (4KB) UFFD (4KB/2MB) Prefetch (4MiB) + \ | / + `---------.---'--------.-----' + v v + header.GetShiftedMapping(offset) + | + v + DiffStore.Get(buildId) ──> cached Chunker + | + v + Chunker.GetBlock(offset, length, ft) + | + .------+------. + v v + [mmap hit] [mmap miss] + return ref | + regionLock (dedup/wait) + | + .--------+--------. + v v + ft != nil? ft == nil + compressed uncompressed + asset exists? + | | + v v + GetFrame GetFrame + (decompress=T) (decompress=F) + | | + '--------+-------' + | + NFS cache hit? ──yes──> write to mmap + | + notify waiters + no | + | v + GCS range read return []byte ref + (C-space / U-space) + | + compressed? ──no──> store in NFS + | | + yes v + | write to mmap + zstd/lz4 decode + notify waiters + | | + store in NFS v + | return []byte ref + v + write to mmap + + notify waiters + | + v + return []byte ref +``` + +
+ +--- + +## D. Remaining Work + +### From This Branch + +1. **Per-artifact compression config**: memfile and rootfs have different runtime requirements. The `compress-config` flag should support separate codec, level, and frame size settings per artifact type rather than applying a single config to both. + +2. **Verify `getFrame` timer lifecycle**: audit that `Success()`/`Failure()` is always called on every code path in the storage cache's `getFrameCompressed` and `getFrameUncompressed`. + +3. **Feature flag to disable progressive `GetBlock` reading**: add a flag that bypasses progressive reading/returning in `GetBlock` and falls back to the original whole-block fetch behavior. Useful as a fault-tolerance lever if progressive reads cause issues in production. + +4. **NFS write-through for compressed uploads**: during `StoreFile` with compression, tee out uncompressed chunk data to NFS cache via a callback, so uncompressed `GetFrame` reads can hit cache immediately after upload without a cold GCS fetch. + +### Compression Modes & Write-Path Timing + +5. **Compressed-only write mode**: add a `compress-config` flag (e.g. `"skipUncompressed": true`) that skips the uncompressed upload entirely and writes only compressed data + v4 header. Code: `TemplateBuild.UploadAll` / `UploadExceptV4Headers` currently always uploads uncompressed; gate that behind the flag. Read path: `probeAssets` already handles missing uncompressed variants, so this should work as-is. Saves the dual-write bandwidth and storage cost, but makes rollback to uncompressed reads impossible for those builds. + +6. **Purity enforcement (no mixed compressed/uncompressed stacks)**: add a `chunker-config` flag (e.g. `"requirePureCompression": true`) that, at template load time, validates that if the top-layer build has compressed assets then every ancestor build in the header's mappings also has compressed assets (and vice versa). Fail sandbox creation if the check fails rather than silently mixing. 
This interacts with the write path: when `requirePureCompression` is enabled and a new layer is built on top of an uncompressed parent, the build must either (a) refuse to compress, (b) refuse to start, or (c) trigger background compression of the parent chain first. Today's `probeAssets` per-build routing lets mixed stacks work; purity enforcement would intentionally break that flexibility for correctness guarantees. + +7. **Sync vs async layer compression**: today compression is either inline (during `TemplateBuild.Upload*`, blocking the build) or fully async (background `compress-build` CLI, after the fact). Middle ground to explore: + - **Compress before upload submission**: the snapshot data is already in memory/mmap after Firecracker pause. Compress frames in-process before kicking off the GCS upload, so the upload only sends compressed data (pairs with #5). Tradeoff: adds compression latency to the critical path before the sandbox can be resumed on another server. + - **Compress shortly after build completes**: fire an async compression job (in-process goroutine or separate task) that runs after the uncompressed upload finishes. The sandbox is resumable immediately from uncompressed data, and compressed data appears later. But: if another build references this layer before compression finishes, the child gets an uncompressed parent — violating purity (#6). And if the sandbox is resumed from the uncompressed image on a different server while compression is in-flight, we have a race on the GCS objects. + - **Implications for purity**: strict purity enforcement (#6) effectively forces synchronous compression of the entire ancestor chain before a compressed child can be built. Async compression is only safe when purity is not enforced, or when there's a coordination mechanism (e.g. a "compression pending" state that blocks child builds until the parent is compressed). + +### From `lev-zstd-compression` (Unported) + +8. 
**Storage Provider/Backend layer separation**: decompose `StorageProvider` into distinct Provider (high-level: `FrameGetter`, `FileStorer`, `Blobber`) and Backend (low-level: `Basic`, `RangeGetter`, `MultipartUploaderFactory`) layers. Prerequisite for clean instrumentation wrapping. + +9. **OTEL instrumentation middleware** (`instrumented_provider.go`, `instrumented_backend.go`): full span and metrics wrapping at both layers. ~400 lines. + +10. **Test coverage** (~4300 lines total): chunker matrix tests (`chunk_test.go` — concurrent access, decompression stats, cross-chunker coverage), compression round-trip tests (`compress_test.go`), NFS cache with compressed data (`storage_cache_seekable_test.go`), template build upload tests (`template_build_test.go`). + +--- + +## E. Write Paths + +### Inline Build / Pause + +Triggered by `sbx.Pause()` or initial template build. The orchestrator creates a `Snapshot` (FC memory + rootfs diffs, headers, snapfile, metadata), then constructs a `TemplateBuild` which owns the upload lifecycle: + +- **Single-layer** (initial build, simple pause): `TemplateBuild.UploadAll(ctx)` — synchronous, creates its own `PendingFrameTables` internally. Uploads uncompressed data + compressed data (if `compressBuilds` FF enabled) + uncompressed headers + snapfile + metadata concurrently in an errgroup. V4 headers are finalized and uploaded after all data uploads complete (they depend on `FrameTable` results). + +- **Multi-layer** (layered build): `TemplateBuild.UploadExceptV4Headers(ctx)` uploads all data, then returns `hasCompressed`. The caller coordinates with `UploadTracker` to wait for ancestor layers, then calls `TemplateBuild.UploadV4Header(ctx)` which reads accumulated `PendingFrameTables` from all layers and serializes the final v4 header. 
+ +### Background Compression (`compress-build` CLI) + +A standalone CLI tool for compressing existing uncompressed builds after the fact: + +``` +compress-build -build [-storage gs://bucket] [-compression lz4|zstd] [-recursive] +``` + +- Reads the uncompressed data from GCS, compresses into frames, writes compressed data + v4 header back. +- `--recursive` walks header mappings to discover and compress dependency builds first (parent templates), avoiding nil-FrameTable gaps in derived templates. +- Supports `--dry-run`, `-template ` (resolves via E2B API), configurable frame size and compression level. +- Idempotent: skips builds that already have compressed artifacts. + +--- + +## F. Failure Modes + +**Corrupted compressed frame in GCS or NFS**: no automatic fallback to uncompressed today. The read fails, `GetBlock` returns an error, and the sandbox page-faults. Unresolved: should the Chunker retry with the uncompressed variant when decompression fails and `HasUncompressed` is true? + +**Half-compressed builds** (some layers have v4 header + compressed data, ancestors don't): handled by design. `probeAssets` finds whichever variants exist per build; each Chunker routes independently. A v4 header with a nil FrameTable for an ancestor mapping falls through to uncompressed fetch for that mapping. + +**NFS unavailable**: compressed frames that miss NFS go straight to GCS (existing behavior). Uncompressed reads also use NFS caching with read-through and async write-back. No circuit breaker — repeated NFS timeouts will add latency to every miss until the cache recovers. + +**Upload path complexity**: dual-write (uncompressed + compressed), `PendingFrameTables` accumulation, and V4 header serialization add failure surface to the build hot path. Multi-layer builds add `UploadTracker` coordination between layers. A compression failure during upload could fail the entire build. 
Back-out: set `compressBuilds: false` in `compress-config` — this disables compressed writes entirely; uncompressed uploads continue as before and the read path already handles missing compressed variants. No cleanup of already-written compressed data needed (it becomes inert). + +### Unresolved + +- Should Chunker fall back to uncompressed on a corrupt V4 header or a decompression error? + +--- + +## G. Cost & Benefit + +### Storage + +Sampled from `gs://e2b-staging-lev-fc-templates/` (262 builds, zstd level 2): + +| Artifact | Builds sampled | Avg uncompressed | Avg compressed | Ratio | +|----------|---------------|-----------------|---------------|-------| +| memfile | 191 (both variants) | 140 MiB | 35 MiB | **4.0x** | +| rootfs | 153 (compressed-only) | unknown | varies | est. 2-10x (diff layers are tiny, full builds ~2x) | + +During dual-write, GCS storage increases ~25% for memfile. After dropping uncompressed, net savings are **~75% for memfile**. Rootfs savings depend on the mix of diff vs full builds. + +### CPU + +New per-orchestrator CPU cost: decompressing every GCS-fetched frame. At ~35 MiB compressed per cold memfile load and zstd level 2 decode throughput of ~1-2 GB/s, each cold load burns ~20-40 ms of CPU. Scales with cold template load rate, not sandbox count. Encode cost is write-path only (build/pause), bounded by upload concurrency. + +### Memory + +The main cost: **mmap regions are allocated at uncompressed size** but frames are fetched whole. A 4 KB NBD read triggers a full frame fetch (4-16 MiB uncompressed), filling mmap with data the sandbox may never touch. This inflates RSS and can pressure the orchestrator fleet into scaling. Mitigations: tune `frameMaxUncompressedMB` down, or drop unrequested bytes from the mmap after the requesting read completes. + +### Net + +Smaller GCS reads (4x fewer bytes) and smaller NFS cache entries reduce network bandwidth. Upload path doubles bandwidth during dual-write. + +--- + +## H. 
Grafana Metrics + +Each `TimerFactory` metric emits three series with the same name but different units: a duration histogram (ms), a bytes counter (By), and an ops counter. All three carry the same attributes listed below plus an automatic `result` = `success` | `failure`. + +### Chunker (meter: `internal.sandbox.block.metrics`) + +| Metric | What it measures | Attributes | +|--------|-----------------|------------| +| `orchestrator.blocks.slices` | End-to-end `GetBlock` latency (mmap hit or remote fetch) | `compressed` (bool), `pull-type` (`local` · `remote`), `failure-reason`\* | +| `orchestrator.blocks.chunks.fetch` | Remote storage fetch (GCS range read + optional decompress) | `compressed` (bool), `failure-reason`\* | +| `orchestrator.blocks.chunks.store` | Writing fetched data into local mmap cache | — | + +\* `failure-reason` values: `local-read`, `local-read-again`, `remote-read`, `cache-fetch`, `session_create` + +### NFS Cache (meter: `shared.pkg.storage`) + +| Metric | What it measures | Attributes | +|--------|-----------------|------------| +| `orchestrator.storage.slab.nfs.read` | NFS cache read (frame or size lookup) | `operation` (`GetFrame` · `Size`) | +| `orchestrator.storage.slab.nfs.write` | NFS cache write (store frame after GCS fetch) | — | +| `orchestrator.storage.cache.ops` | NFS cache operation count | `cache_type` (`blob` · `framed_file`), `op_type`\*, `cache_hit` (bool) | +| `orchestrator.storage.cache.bytes` | NFS cache bytes transferred | `cache_type`, `op_type`\*, `cache_hit` (bool) | +| `orchestrator.storage.cache.errors` | NFS cache errors (excluding expected `ErrNotExist`) | `cache_type`, `op_type`\*, `error_type` (`read` · `write` · `write-lock`) | + +\* `op_type` values: `get_frame`, `write_to`, `size`, `put`, `store_file` + +### GCS Backend (meter: `shared.pkg.storage`) + +| Metric | What it measures | Attributes | +|--------|-----------------|------------| +| `orchestrator.storage.gcs.read` | GCS read operations | `operation` 
(`Size` · `WriteTo` · `GetFrame`) | +| `orchestrator.storage.gcs.write` | GCS write operations | `operation` (`Write` · `WriteFromFileSystem` · `WriteFromFileSystemOneShot`) | + +### Key Queries + +- **Compressed vs uncompressed latency**: `orchestrator.blocks.slices` grouped by `compressed`, filtered to `result=success` +- **Cache hit rate**: `orchestrator.blocks.slices` where `pull-type=local` vs `pull-type=remote` +- **NFS effectiveness**: `orchestrator.storage.cache.ops` where `op_type=get_frame`, ratio of `cache_hit=true` to total +- **GCS fetch volume**: `orchestrator.storage.gcs.read` where `operation=GetFrame`, bytes counter +- **Decompression overhead**: `orchestrator.blocks.chunks.fetch` where `compressed=true`, compare duration histogram to `compressed=false` + +--- + +## I. Rollout Strategy + +_TBD_ diff --git a/packages/orchestrator/benchmark_test.go b/packages/orchestrator/benchmark_test.go index 0f0935e7da..6cee110964 100644 --- a/packages/orchestrator/benchmark_test.go +++ b/packages/orchestrator/benchmark_test.go @@ -1,8 +1,15 @@ // run with something like: // -// sudo `which go` test -benchtime=15s -bench=. 
-v // sudo modprobe nbd -// echo 1024 | sudo tee /proc/sys/vm/nr_hugepages +// sudo `which go` test ./packages/orchestrator/ -bench=BenchmarkBaseImage -v -timeout=60m +// +// Single mode: +// +// sudo `which go` test ./packages/orchestrator/ -bench=BenchmarkBaseImage/zstd-2 -v +// +// More iterations: +// +// sudo `which go` test ./packages/orchestrator/ -bench=BenchmarkBaseImage -benchtime=5x -v -timeout=60m package main import ( @@ -14,6 +21,7 @@ import ( "testing" "time" + "github.com/launchdarkly/go-sdk-common/v3/ldvalue" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" @@ -32,7 +40,6 @@ import ( "github.com/e2b-dev/infra/packages/orchestrator/internal/template/build" buildconfig "github.com/e2b-dev/infra/packages/orchestrator/internal/template/build/config" "github.com/e2b-dev/infra/packages/orchestrator/internal/template/build/metrics" - "github.com/e2b-dev/infra/packages/orchestrator/internal/template/metadata" artifactsregistry "github.com/e2b-dev/infra/packages/shared/pkg/artifacts-registry" "github.com/e2b-dev/infra/packages/shared/pkg/dockerhub" featureflags "github.com/e2b-dev/infra/packages/shared/pkg/feature-flags" @@ -45,21 +52,34 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) -var tracer = otel.Tracer("github.com/e2b-dev/infra/packages/orchestrator") +type benchMode struct { + name string + buildID string + compressionType string // "lz4" or "zstd"; "" = uncompressed + level int +} + +func (m benchMode) compressed() bool { return m.compressionType != "" } -func BenchmarkBaseImageLaunch(b *testing.B) { +var benchModes = []benchMode{ + {"uncompressed", "ba6aae36-0000-0000-0000-000000000000", "", 0}, + {"lz4", "ba6aae36-0000-0000-0000-000000000001", "lz4", 0}, + {"zstd-0", "ba6aae36-0000-0000-0000-000000000002", "zstd", 0}, + {"zstd-1", "ba6aae36-0000-0000-0000-000000000003", "zstd", 1}, + {"zstd-2", "ba6aae36-0000-0000-0000-000000000004", "zstd", 2}, + {"zstd-3", 
"ba6aae36-0000-0000-0000-000000000005", "zstd", 3}, +} + +func BenchmarkBaseImage(b *testing.B) { if os.Geteuid() != 0 { b.Skip("skipping benchmark because not running as root") } - // test configuration const ( - testType = onlyStart baseImage = "e2bdev/base" kernelVersion = "vmlinux-6.1.158" fcVersion = featureflags.DefaultFirecrackerVersion templateID = "fcb33d09-3141-42c4-8d3b-c2df411681db" - buildID = "ba6aae36-74f7-487a-b6f7-74fd7c94e479" useHugePages = false templateVersion = "v2.0.0" ) @@ -92,7 +112,7 @@ func BenchmarkBaseImageLaunch(b *testing.B) { }) require.NoError(b, err) - resource, err := telemetry.GetResource(b.Context(), "node-id", "BenchmarkBaseImageLaunch", "service-commit", "service-version", "service-instance-id") + resource, err := telemetry.GetResource(b.Context(), "node-id", "BenchmarkBaseImage", "service-commit", "service-version", "service-instance-id") require.NoError(b, err) tracerProvider := telemetry.NewTracerProvider(spanExporter, resource) otel.SetTracerProvider(tracerProvider) @@ -105,11 +125,12 @@ func BenchmarkBaseImageLaunch(b *testing.B) { downloadKernel(b, linuxKernelFilename, linuxKernelURL) // hacks, these should go away + templateStoragePath := abs(filepath.Join(persistenceDir, "templates")) b.Setenv("ARTIFACTS_REGISTRY_PROVIDER", "Local") b.Setenv("FIRECRACKER_VERSIONS_DIR", abs(filepath.Join("..", "fc-versions", "builds"))) b.Setenv("HOST_ENVD_PATH", abs(filepath.Join("..", "envd", "bin", "envd"))) b.Setenv("HOST_KERNELS_DIR", abs(kernelsDir)) - b.Setenv("LOCAL_TEMPLATE_STORAGE_BASE_PATH", abs(filepath.Join(persistenceDir, "templates"))) + b.Setenv("LOCAL_TEMPLATE_STORAGE_BASE_PATH", templateStoragePath) b.Setenv("ORCHESTRATOR_BASE_PATH", tempDir) b.Setenv("SANDBOX_DIR", abs(sandboxDir)) b.Setenv("SNAPSHOT_CACHE_DIR", abs(filepath.Join(tempDir, "snapshot-cache"))) @@ -130,7 +151,6 @@ func BenchmarkBaseImageLaunch(b *testing.B) { require.NoError(b, err) sbxlogger.SetSandboxLoggerInternal(l) - // 
sbxlogger.SetSandboxLoggerExternal(logger) slotStorage, err := network.NewStorageLocal(b.Context(), config.NetworkConfig) require.NoError(b, err) @@ -175,9 +195,7 @@ func BenchmarkBaseImageLaunch(b *testing.B) { require.NoError(b, err) c, err := cfg.Parse() - if err != nil { - b.Fatalf("error parsing config: %v", err) - } + require.NoError(b, err) templateCache, err := template.NewCache(c, featureFlags, persistence, blockMetrics) require.NoError(b, err) @@ -279,50 +297,93 @@ func BenchmarkBaseImageLaunch(b *testing.B) { buildMetrics, ) - buildPath := filepath.Join(os.Getenv("LOCAL_TEMPLATE_STORAGE_BASE_PATH"), buildID, "rootfs.ext4") - if _, err := os.Stat(buildPath); os.IsNotExist(err) { - // build template - force := true - templateConfig := buildconfig.TemplateConfig{ - Version: templateVersion, - TemplateID: templateID, - FromImage: baseImage, - Force: &force, - VCpuCount: sandboxConfig.Vcpu, - MemoryMB: sandboxConfig.RamMB, - StartCmd: "echo 'start cmd debug' && sleep .1 && echo 'done starting command debug'", - DiskSizeMB: sandboxConfig.TotalDiskSizeMB, - HugePages: sandboxConfig.HugePages, - KernelVersion: kernelVersion, - FirecrackerVersion: fcVersion, - } - - metadata := storage.TemplateFiles{ - BuildID: buildID, - } - _, err = builder.Build(b.Context(), metadata, templateConfig, l.Detach(b.Context()).Core()) - require.NoError(b, err) - } - - // retrieve template - tmpl, err := templateCache.GetTemplate( - b.Context(), - buildID, - false, - false, - ) - require.NoError(b, err) - - tc := testContainer{ - sandboxFactory: sandboxFactory, - testType: testType, - tmpl: tmpl, - sandboxConfig: sandboxConfig, - runtime: runtime, + force := true + templateConfig := buildconfig.TemplateConfig{ + Version: templateVersion, + TemplateID: templateID, + FromImage: baseImage, + Force: &force, + VCpuCount: sandboxConfig.Vcpu, + MemoryMB: sandboxConfig.RamMB, + StartCmd: "echo 'start cmd debug' && sleep .1 && echo 'done starting command debug'", + DiskSizeMB: 
sandboxConfig.TotalDiskSizeMB, + HugePages: sandboxConfig.HugePages, + KernelVersion: kernelVersion, + FirecrackerVersion: fcVersion, } - for b.Loop() { - tc.testOneItem(b, buildID, kernelVersion, fcVersion) + for _, mode := range benchModes { + b.Run(mode.name, func(b *testing.B) { + // Set flags for this mode + featureflags.OverrideJSONFlag(featureflags.CompressConfigFlag, ldvalue.FromJSONMarshal(map[string]any{ + "compressBuilds": mode.compressed(), + "compressionType": mode.compressionType, + "level": mode.level, + "frameTargetMB": 2, + "uploadPartTargetMB": 50, + "frameMaxUncompressedMB": 16, + "encoderConcurrency": 1, + "decoderConcurrency": 1, + })) + featureflags.OverrideJSONFlag(featureflags.ChunkerConfigFlag, ldvalue.FromJSONMarshal(map[string]any{ + "useCompressedAssets": mode.compressed(), + "minReadBatchSizeKB": 16, + })) + + b.Logf("mode=%s buildID=%s compressed=%v type=%s level=%d", + mode.name, mode.buildID, mode.compressed(), mode.compressionType, mode.level) + + // Build (exactly once, timed for reporting). + // Skipped if template already exists on disk. + // To force rebuild: rm -rf /root/.cache/e2b-orchestrator-benchmark/templates/ + buildStart := time.Now() + buildPath := filepath.Join(templateStoragePath, mode.buildID, "rootfs.ext4") + if _, err := os.Stat(buildPath); os.IsNotExist(err) { + metadata := storage.TemplateFiles{BuildID: mode.buildID} + _, err = builder.Build(b.Context(), metadata, templateConfig, l.Detach(b.Context()).Core()) + require.NoError(b, err) + } + buildDuration := time.Since(buildStart) + + // Cold start benchmark. + // Each iteration gets a fresh template with empty block caches. + // InvalidateAll() evicts the cached template; GetTemplate() creates + // a new storageTemplate with fresh chunkers (no mmap data cached). + // Template headers reload from local FS (cheap, OS page cache). + // The timed ResumeSandbox() then triggers real block fetches on + // every page fault — a true cold start. 
+ b.ResetTimer() + b.StopTimer() + for range b.N { + // Setup (untimed): fresh template with empty block cache + templateCache.InvalidateAll() + tmpl, err := templateCache.GetTemplate(b.Context(), mode.buildID, false, false) + require.NoError(b, err) + + _, err = tmpl.Metadata() + require.NoError(b, err) + + // Timed: cold start sandbox launch + b.StartTimer() + sbx, err := sandboxFactory.ResumeSandbox( + b.Context(), + tmpl, + sandboxConfig, + runtime, + time.Now(), + time.Now().Add(time.Second*15), + nil, + ) + b.StopTimer() + require.NoError(b, err) + + // Cleanup (untimed) + err = sbx.Close(b.Context()) + require.NoError(b, err) + } + + b.ReportMetric(buildDuration.Seconds(), "build-s") + }) } } @@ -335,76 +396,6 @@ func getPersistenceDir() string { return filepath.Join(os.TempDir(), "e2b-orchestrator-benchmark") } -type testCycle string - -const ( - onlyStart testCycle = "only-start" - startAndPause testCycle = "start-and-pause" - startPauseResume testCycle = "start-pause-resume" -) - -type testContainer struct { - testType testCycle - sandboxFactory *sandbox.Factory - tmpl template.Template - sandboxConfig sandbox.Config - runtime sandbox.RuntimeMetadata -} - -func (tc *testContainer) testOneItem(b *testing.B, buildID, kernelVersion, fcVersion string) { - b.Helper() - - ctx, span := tracer.Start(b.Context(), "testOneItem") - defer span.End() - - sbx, err := tc.sandboxFactory.ResumeSandbox( - ctx, - tc.tmpl, - tc.sandboxConfig, - tc.runtime, - time.Now(), - time.Now().Add(time.Second*15), - nil, - ) - require.NoError(b, err) - - if tc.testType == onlyStart { - b.StopTimer() - err = sbx.Close(ctx) - require.NoError(b, err) - b.StartTimer() - - return - } - - meta, err := sbx.Template.Metadata() - require.NoError(b, err) - - templateMetadata := meta.SameVersionTemplate(metadata.TemplateMetadata{ - BuildID: buildID, - KernelVersion: kernelVersion, - FirecrackerVersion: fcVersion, - }) - snap, err := sbx.Pause(ctx, templateMetadata) - require.NoError(b, err) - 
require.NotNil(b, snap) - - if tc.testType == startAndPause { - b.StopTimer() - err = sbx.Close(ctx) - require.NoError(b, err) - b.StartTimer() - } - - // resume sandbox - sbx, err = tc.sandboxFactory.ResumeSandbox(ctx, tc.tmpl, tc.sandboxConfig, tc.runtime, time.Now(), time.Now().Add(time.Second*15), nil) - require.NoError(b, err) - - // close sandbox - err = sbx.Close(ctx) - require.NoError(b, err) -} - func downloadKernel(b *testing.B, filename, url string) { b.Helper() diff --git a/packages/orchestrator/cmd/benchmark-compress/main.go b/packages/orchestrator/cmd/benchmark-compress/main.go new file mode 100644 index 0000000000..6d66041da3 --- /dev/null +++ b/packages/orchestrator/cmd/benchmark-compress/main.go @@ -0,0 +1,567 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "io" + "log" + "net/http" + "os" + "slices" + "strings" + "sync" + "time" + + "github.com/klauspost/compress/zstd" + lz4 "github.com/pierrec/lz4/v4" + + "github.com/e2b-dev/infra/packages/orchestrator/cmd/internal/cmdutil" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" +) + +// bufferPartUploader implements storage.PartUploader for in-memory writes. +// Parts are collected by index and assembled in order on Complete, since +// CompressStream uploads parts concurrently and they may arrive out of order. 
+type bufferPartUploader struct { + mu sync.Mutex + parts map[int][]byte + buf bytes.Buffer +} + +func (b *bufferPartUploader) Start(_ context.Context) error { + b.parts = make(map[int][]byte) + + return nil +} + +func (b *bufferPartUploader) UploadPart(_ context.Context, partIndex int, data ...[]byte) error { + var combined bytes.Buffer + for _, d := range data { + combined.Write(d) + } + b.mu.Lock() + b.parts[partIndex] = combined.Bytes() + b.mu.Unlock() + + return nil +} + +func (b *bufferPartUploader) Complete(_ context.Context) error { + // Assemble parts in order + keys := make([]int, 0, len(b.parts)) + for k := range b.parts { + keys = append(keys, k) + } + slices.Sort(keys) + for _, k := range keys { + b.buf.Write(b.parts[k]) + } + b.parts = nil + + return nil +} + +type benchResult struct { + codec string + level int + rawEncTime time.Duration + frmEncTime time.Duration + rawDecTime time.Duration + frmDecTime time.Duration + rawSize int64 + frmSize int64 + origSize int64 + numFrames int +} + +func main() { + build := flag.String("build", "", "build ID") + template := flag.String("template", "", "template ID or alias (requires E2B_API_KEY)") + storagePath := flag.String("storage", ".local-build", "storage: local path or gs://bucket") + doMemfile := flag.Bool("memfile", false, "benchmark memfile only") + doRootfs := flag.Bool("rootfs", false, "benchmark rootfs only") + iterations := flag.Int("iterations", 1, "number of iterations for timing (results averaged)") + + flag.Parse() + + cmdutil.SuppressNoisyLogsKeepStdLog() + + // Resolve build ID + if *template != "" && *build != "" { + log.Fatal("specify either -build or -template, not both") + } + if *template != "" { + resolvedBuild, err := resolveTemplateID(*template) + if err != nil { + log.Fatalf("failed to resolve template: %s", err) + } + *build = resolvedBuild + fmt.Printf("Resolved template %q to build %s\n", *template, *build) + } + + if *build == "" { + fmt.Fprintf(os.Stderr, "Usage: 
benchmark-compress (-build | -template ) [flags]\n\n") + fmt.Fprintf(os.Stderr, "Benchmarks raw vs framed compression to measure framing overhead.\n\n") + flag.PrintDefaults() + os.Exit(1) + } + + // Determine which artifacts to benchmark + type artifact struct { + name string + file string + } + var artifacts []artifact + if !*doMemfile && !*doRootfs { + // Default: both + artifacts = []artifact{ + {"memfile", storage.MemfileName}, + {"rootfs", storage.RootfsName}, + } + } else { + if *doMemfile { + artifacts = append(artifacts, artifact{"memfile", storage.MemfileName}) + } + if *doRootfs { + artifacts = append(artifacts, artifact{"rootfs", storage.RootfsName}) + } + } + + ctx := context.Background() + + for _, a := range artifacts { + data, err := loadArtifact(ctx, *storagePath, *build, a.file) + if err != nil { + log.Fatalf("failed to load %s: %s", a.name, err) + } + + printHeader(a.name, int64(len(data))) + benchmarkArtifact(data, *iterations, func(r benchResult) { + printRow(r) + }) + fmt.Println() + } +} + +func loadArtifact(ctx context.Context, storagePath, buildID, file string) ([]byte, error) { + reader, dataSize, source, err := cmdutil.OpenDataFile(ctx, storagePath, buildID, file) + if err != nil { + return nil, fmt.Errorf("open %s: %w", file, err) + } + defer reader.Close() + + fmt.Printf("Loading %s from %s (%d bytes, %.1f MiB)...\n", + file, source, dataSize, float64(dataSize)/1024/1024) + + data := make([]byte, dataSize) + _, err = io.ReadFull(io.NewSectionReader(reader, 0, dataSize), data) + if err != nil { + return nil, fmt.Errorf("read %s: %w", file, err) + } + + return data, nil +} + +func benchmarkArtifact(data []byte, iterations int, emit func(benchResult)) { + type codecConfig struct { + name string + ct storage.CompressionType + levels []int + } + codecs := []codecConfig{ + {"lz4", storage.CompressionLZ4, []int{0, 1}}, + {"zstd", storage.CompressionZstd, []int{ + int(zstd.SpeedFastest), // 1 + int(zstd.SpeedDefault), // 2 + 
int(zstd.SpeedBetterCompression), // 3 + int(zstd.SpeedBestCompression), // 4 + }}, + } + + for _, codec := range codecs { + for _, level := range codec.levels { + r := benchResult{ + codec: codec.name, + level: level, + origSize: int64(len(data)), + } + + var rawCompressed, framedCompressed []byte + var ft *storage.FrameTable + + for range iterations { + rc, rawDur := rawEncode(data, codec.ct, level) + fc, fft, frmDur := framedEncode(data, codec.ct, level) + + r.rawEncTime += rawDur + r.frmEncTime += frmDur + + rawCompressed = rc + framedCompressed = fc + ft = fft + } + + r.rawEncTime /= time.Duration(iterations) + r.frmEncTime /= time.Duration(iterations) + r.rawSize = int64(len(rawCompressed)) + r.frmSize = int64(len(framedCompressed)) + + if ft != nil { + r.numFrames = len(ft.Frames) + } + + // Pre-allocate a shared output buffer for decode benchmarks + // so both paths pay the same allocation cost (zero). + decBuf := make([]byte, len(data)) + + for range iterations { + rawDecDur := rawDecode(rawCompressed, codec.ct, decBuf) + frmDecDur := framedDecode(framedCompressed, ft, codec.ct, decBuf) + + r.rawDecTime += rawDecDur + r.frmDecTime += frmDecDur + } + + r.rawDecTime /= time.Duration(iterations) + r.frmDecTime /= time.Duration(iterations) + + emit(r) + } + } +} + +func rawEncode(data []byte, ct storage.CompressionType, level int) ([]byte, time.Duration) { + var buf bytes.Buffer + buf.Grow(len(data)) + + start := time.Now() + + switch ct { + case storage.CompressionLZ4: + w := lz4.NewWriter(&buf) + opts := []lz4.Option{lz4.ConcurrencyOption(1)} + if level > 0 { + opts = append(opts, lz4.CompressionLevelOption(lz4.CompressionLevel(1<<(8+level)))) + } + _ = w.Apply(opts...) + _, _ = w.Write(data) + _ = w.Close() + + case storage.CompressionZstd: + // Match the framed encoder: CompressStream passes TargetFrameSize as + // windowSize to newZstdEncoder, so we must use the same window here + // for an apples-to-apples comparison. 
+ w, err := zstd.NewWriter(&buf, + zstd.WithEncoderLevel(zstd.EncoderLevel(level)), + zstd.WithEncoderConcurrency(1), + zstd.WithWindowSize(2*1024*1024)) + if err != nil { + log.Fatalf("zstd raw encoder (level %d): %s", level, err) + } + _, _ = w.Write(data) + _ = w.Close() + } + + elapsed := time.Since(start) + + return buf.Bytes(), elapsed +} + +func framedEncode(data []byte, ct storage.CompressionType, level int) ([]byte, *storage.FrameTable, time.Duration) { + uploader := &bufferPartUploader{} + + opts := &storage.FramedUploadOptions{ + CompressionType: ct, + Level: level, + CompressionConcurrency: 1, + TargetFrameSize: 2 * 1024 * 1024, // 2 MiB + MaxUncompressedFrameSize: storage.DefaultMaxFrameUncompressedSize, + TargetPartSize: 50 * 1024 * 1024, + } + + ctx := context.Background() + reader := bytes.NewReader(data) + + start := time.Now() + ft, err := storage.CompressStream(ctx, reader, opts, uploader) + elapsed := time.Since(start) + + if err != nil { + log.Fatalf("framed encode failed: %s", err) + } + + return uploader.buf.Bytes(), ft, elapsed +} + +func rawDecode(compressed []byte, ct storage.CompressionType, buf []byte) time.Duration { + start := time.Now() + + switch ct { + case storage.CompressionLZ4: + r := lz4.NewReader(bytes.NewReader(compressed)) + _, _ = io.ReadFull(r, buf) + + case storage.CompressionZstd: + r, _ := zstd.NewReader(bytes.NewReader(compressed), zstd.WithDecoderConcurrency(1)) + _, _ = io.ReadFull(r, buf) + r.Close() + } + + return time.Since(start) +} + +func framedDecode(compressed []byte, ft *storage.FrameTable, ct storage.CompressionType, buf []byte) time.Duration { + if ft == nil || len(ft.Frames) == 0 { + return 0 + } + + start := time.Now() + + var cOffset int64 + var uOffset int + for _, frame := range ft.Frames { + frameData := compressed[cOffset : cOffset+int64(frame.C)] + frameBuf := buf[uOffset : uOffset+int(frame.U)] + decompressFrameInto(ct, frameData, frameBuf) + cOffset += int64(frame.C) + uOffset += int(frame.U) + } 
+ + return time.Since(start) +} + +// decompressFrameInto decompresses into a pre-allocated buffer to avoid +// per-frame allocation. Uses single-threaded decoders to match rawDecode. +func decompressFrameInto(ct storage.CompressionType, compressed, buf []byte) { + switch ct { + case storage.CompressionLZ4: + r := lz4.NewReader(bytes.NewReader(compressed)) + _, err := io.ReadFull(r, buf) + if err != nil { + log.Fatalf("framed lz4 decode failed: %s", err) + } + + case storage.CompressionZstd: + r, err := zstd.NewReader(bytes.NewReader(compressed), zstd.WithDecoderConcurrency(1)) + if err != nil { + log.Fatalf("framed zstd decoder create failed: %s", err) + } + _, err = io.ReadFull(r, buf) + if err != nil { + log.Fatalf("framed zstd decode failed: %s", err) + } + r.Close() + } +} + +// ANSI colors. +const ( + colorReset = "\033[0m" + colorGreen = "\033[32m" + colorYellow = "\033[33m" + colorRed = "\033[91m" +) + +func overheadColor(pct float64) string { + switch { + case pct < 5: + return colorGreen + case pct < 15: + return colorYellow + default: + return colorRed + } +} + +// pad right-pads s with spaces to exactly width visible characters. +func pad(s string, width int) string { + if len(s) >= width { + return s + } + + return s + strings.Repeat(" ", width-len(s)) +} + +// rpad right-aligns s within width visible characters. +func rpad(s string, width int) string { + if len(s) >= width { + return s + } + + return strings.Repeat(" ", width-len(s)) + s +} + +// colorWrap wraps text with ANSI color, pre-padded to width so alignment is correct. 
+func colorWrap(color, text string, width int) string { + padded := pad(text, width) + + return color + padded + colorReset +} + +func fmtSpeed(dataSize int64, d time.Duration) string { + if d == 0 { + return rpad("N/A", 9) + } + mbps := float64(dataSize) / d.Seconds() / (1024 * 1024) + + return rpad(fmt.Sprintf("%.0f MB/s", mbps), 9) +} + +func fmtOverhead(raw, framed time.Duration) string { + if raw == 0 { + return pad("N/A", 7) + } + pct := float64(framed-raw) / float64(raw) * 100 + text := fmt.Sprintf("%+.1f%%", pct) + + return colorWrap(overheadColor(pct), text, 7) +} + +func fmtSizeOH(rawSize, frmSize int64) string { + if rawSize == 0 { + return pad("N/A", 7) + } + pct := float64(frmSize-rawSize) / float64(rawSize) * 100 + text := fmt.Sprintf("%+.1f%%", pct) + + return colorWrap(overheadColor(pct), text, 7) +} + +func fmtMiB(b int64) string { + return rpad(fmt.Sprintf("%.1f MiB", float64(b)/1024/1024), 9) +} + +func printHeader(artifact string, origSize int64) { + fmt.Printf("\n=== %s (%.1f MiB) ===\n\n", artifact, float64(origSize)/1024/1024) + + hdr := fmt.Sprintf("%-4s %3s %9s %9s %-7s %9s %9s %-7s %9s %9s %-7s %-5s %6s %8s", + "Codec", "Lvl", + "Raw Enc", "Frm Enc", "Enc OH", + "Raw Dec", "Frm Dec", "Dec OH", + "Raw Size", "Frm Size", "Size OH", + "Ratio", "Frames", "Dec/Frm") + sep := fmt.Sprintf("%-4s %3s %9s %9s %-7s %9s %9s %-7s %9s %9s %-7s %-5s %6s %8s", + "----", "---", + "---------", "---------", "-------", + "---------", "---------", "-------", + "---------", "---------", "-------", + "-----", "------", "--------") + fmt.Println(hdr) + fmt.Println(sep) +} + +func printRow(r benchResult) { + ratio := float64(r.origSize) / float64(r.frmSize) + ratioColor := cmdutil.RatioColor(ratio) + ratioText := fmt.Sprintf("%.1fx", ratio) + if ratio >= 100 { + ratioText = fmt.Sprintf("%.0fx", ratio) + } + + var decPerFrame string + if r.numFrames > 0 { + usPerFrame := r.frmDecTime.Microseconds() / int64(r.numFrames) + decPerFrame = rpad(fmt.Sprintf("%d us", 
usPerFrame), 8) + } else { + decPerFrame = rpad("N/A", 8) + } + + fmt.Printf("%-4s %3d %s %s %s %s %s %s %s %s %s %s %6d %s\n", + r.codec, + r.level, + fmtSpeed(r.origSize, r.rawEncTime), + fmtSpeed(r.origSize, r.frmEncTime), + fmtOverhead(r.rawEncTime, r.frmEncTime), + fmtSpeed(r.origSize, r.rawDecTime), + fmtSpeed(r.origSize, r.frmDecTime), + fmtOverhead(r.rawDecTime, r.frmDecTime), + fmtMiB(r.rawSize), + fmtMiB(r.frmSize), + fmtSizeOH(r.rawSize, r.frmSize), + colorWrap(ratioColor, ratioText, 5), + r.numFrames, + decPerFrame, + ) +} + +// --- Template resolution (copied from compress-build) --- + +type templateInfo struct { + TemplateID string `json:"templateID"` + BuildID string `json:"buildID"` + Aliases []string `json:"aliases"` + Names []string `json:"names"` +} + +func resolveTemplateID(input string) (string, error) { + apiKey := os.Getenv("E2B_API_KEY") + if apiKey == "" { + return "", fmt.Errorf("E2B_API_KEY environment variable required for -template flag") + } + + apiURL := "https://api.e2b.dev/templates" + if domain := os.Getenv("E2B_DOMAIN"); domain != "" { + apiURL = fmt.Sprintf("https://api.%s/templates", domain) + } + + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("X-API-Key", apiKey) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("failed to fetch templates: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + + return "", fmt.Errorf("API returned %d: %s", resp.StatusCode, string(body)) + } + + var templates []templateInfo + if err := json.NewDecoder(resp.Body).Decode(&templates); err != nil { + return "", fmt.Errorf("failed to parse API response: %w", err) + } + + var match *templateInfo + var availableAliases []string + + for i := range templates { + t := &templates[i] + availableAliases = 
append(availableAliases, t.Aliases...) + + if t.TemplateID == input { + match = t + + break + } + if slices.Contains(t.Aliases, input) { + match = t + + break + } + if slices.Contains(t.Names, input) { + match = t + + break + } + } + + if match == nil { + return "", fmt.Errorf("template %q not found. Available aliases: %s", input, strings.Join(availableAliases, ", ")) + } + + if match.BuildID == "" || match.BuildID == cmdutil.NilUUID { + return "", fmt.Errorf("template %q has no successful build", input) + } + + return match.BuildID, nil +} diff --git a/packages/orchestrator/cmd/compress-build/main.go b/packages/orchestrator/cmd/compress-build/main.go new file mode 100644 index 0000000000..a35b2e8bd7 --- /dev/null +++ b/packages/orchestrator/cmd/compress-build/main.go @@ -0,0 +1,665 @@ +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "io" + "log" + "net/http" + "os" + "os/exec" + "path/filepath" + "slices" + "strconv" + "strings" + "time" + + "github.com/e2b-dev/infra/packages/orchestrator/cmd/internal/cmdutil" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +// filePartWriter implements storage.PartUploader for local file writes. +type filePartWriter struct { + path string + f *os.File +} + +func (w *filePartWriter) Start(_ context.Context) error { + dir := filepath.Dir(w.path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("mkdir %s: %w", dir, err) + } + f, err := os.Create(w.path) + if err != nil { + return err + } + w.f = f + + return nil +} + +func (w *filePartWriter) UploadPart(_ context.Context, _ int, data ...[]byte) error { + for _, d := range data { + if _, err := w.f.Write(d); err != nil { + return err + } + } + + return nil +} + +func (w *filePartWriter) Complete(_ context.Context) error { + return w.f.Close() +} + +// compressConfig holds the flags for a compression run. 
+type compressConfig struct { + storagePath string + compType storage.CompressionType + level int + frameSize int + maxFrameU int + dryRun bool + recursive bool + verbose bool +} + +func main() { + build := flag.String("build", "", "build ID") + template := flag.String("template", "", "template ID or alias (requires E2B_API_KEY)") + storagePath := flag.String("storage", ".local-build", "storage: local path or gs://bucket") + compression := flag.String("compression", "lz4", "compression type: lz4 or zstd") + level := flag.Int("level", storage.DefaultCompressionOptions.Level, "compression level (0=default)") + frameSize := flag.Int("frame-size", storage.DefaultCompressionOptions.TargetFrameSize, "target compressed frame size in bytes") + maxFrameU := flag.Int("max-frame-u", storage.DefaultMaxFrameUncompressedSize, "max uncompressed bytes per frame") + dryRun := flag.Bool("dry-run", false, "show what would be done without making changes") + recursive := flag.Bool("recursive", false, "recursively compress dependencies (referenced builds)") + verbose := flag.Bool("v", false, "verbose: print per-frame info during compression") + + flag.Parse() + + // Resolve build ID from template if provided + if *template != "" && *build != "" { + log.Fatal("specify either -build or -template, not both") + } + if *template != "" { + resolvedBuild, err := resolveTemplateID(*template) + if err != nil { + log.Fatalf("failed to resolve template: %s", err) + } + *build = resolvedBuild + fmt.Printf("Resolved template %q to build %s\n", *template, *build) + } + + if *build == "" { + printUsage() + os.Exit(1) + } + + // Parse compression type + var compType storage.CompressionType + switch *compression { + case "lz4": + compType = storage.CompressionLZ4 + case "zstd": + compType = storage.CompressionZstd + default: + log.Fatalf("unsupported compression type: %s (use 'lz4' or 'zstd')", *compression) + } + + cfg := &compressConfig{ + storagePath: *storagePath, + compType: compType, + level: 
*level, + frameSize: *frameSize, + maxFrameU: *maxFrameU, + dryRun: *dryRun, + recursive: *recursive, + verbose: *verbose, + } + + ctx := context.Background() + + if err := compressBuild(ctx, cfg, *build, nil); err != nil { + log.Fatalf("failed to compress build %s: %s", *build, err) + } + + fmt.Printf("\nDone.\n") +} + +func printUsage() { + fmt.Fprintf(os.Stderr, "Usage: compress-build (-build | -template ) [-storage ] [-compression lz4|zstd] [-level N] [-frame-size N] [-dry-run] [-recursive]\n\n") + fmt.Fprintf(os.Stderr, "Compresses uncompressed build artifacts and creates v4 headers.\n\n") + fmt.Fprintf(os.Stderr, "The -template flag requires E2B_API_KEY environment variable.\n") + fmt.Fprintf(os.Stderr, "Set E2B_DOMAIN for non-production environments.\n\n") + fmt.Fprintf(os.Stderr, "Examples:\n") + fmt.Fprintf(os.Stderr, " compress-build -build abc123 # compress with default LZ4\n") + fmt.Fprintf(os.Stderr, " compress-build -build abc123 -compression zstd # compress with zstd\n") + fmt.Fprintf(os.Stderr, " compress-build -build abc123 -dry-run # show what would be done\n") + fmt.Fprintf(os.Stderr, " compress-build -build abc123 -storage gs://my-bucket # compress from GCS\n") + fmt.Fprintf(os.Stderr, " compress-build -build abc123 -recursive # compress build and all dependencies\n") + fmt.Fprintf(os.Stderr, " compress-build -template base -storage gs://bucket # compress by template alias\n") + fmt.Fprintf(os.Stderr, " compress-build -template gtjfpksmxd9ct81x1f8e # compress by template ID\n") +} + +// compressBuild compresses a single build and optionally its dependencies. +// visited tracks already-processed builds to avoid cycles. 
+func compressBuild(ctx context.Context, cfg *compressConfig, buildID string, visited map[string]bool) error { + if visited == nil { + visited = make(map[string]bool) + } + if visited[buildID] { + return nil + } + visited[buildID] = true + + artifacts := []struct { + name string + file string + }{ + {"memfile", storage.MemfileName}, + {"rootfs", storage.RootfsName}, + } + + // In recursive mode, first discover and compress dependencies. + if cfg.recursive { + deps, err := findDependencies(ctx, cfg.storagePath, buildID) + if err != nil { + fmt.Printf(" Warning: could not discover dependencies for %s: %s\n", buildID, err) + } else if len(deps) > 0 { + fmt.Printf("\nBuild %s has %d dependency build(s): %s\n", buildID, len(deps), strings.Join(deps, ", ")) + for _, depBuild := range deps { + // Check if the dependency already has compressed data. + alreadyCompressed := true + for _, a := range artifacts { + compressedFile := storage.V4DataName(a.file, cfg.compType) + info := cmdutil.ProbeFile(ctx, cfg.storagePath, depBuild, compressedFile) + if !info.Exists { + alreadyCompressed = false + + break + } + } + if alreadyCompressed { + fmt.Printf(" Dependency %s already compressed, skipping\n", depBuild) + + continue + } + + fmt.Printf("\n>>> Compressing dependency %s\n", depBuild) + if err := compressBuild(ctx, cfg, depBuild, visited); err != nil { + return fmt.Errorf("dependency %s: %w", depBuild, err) + } + } + } + } + + fmt.Printf("\n====== Build %s ======\n", buildID) + + for _, artifact := range artifacts { + if err := compressArtifact(ctx, cfg, buildID, artifact.name, artifact.file); err != nil { + return fmt.Errorf("failed to compress %s: %w", artifact.name, err) + } + } + + return nil +} + +// findDependencies reads headers for a build and returns unique build IDs +// referenced in mappings (excluding the build itself and nil UUIDs). 
+func findDependencies(ctx context.Context, storagePath, buildID string) ([]string, error) { + seen := make(map[string]bool) + + for _, file := range []string{storage.MemfileName, storage.RootfsName} { + headerFile := file + storage.HeaderSuffix + headerData, _, err := cmdutil.ReadFileIfExists(ctx, storagePath, buildID, headerFile) + if err != nil { + return nil, fmt.Errorf("read header %s: %w", headerFile, err) + } + if headerData == nil { + continue + } + + h, err := header.DeserializeBytes(headerData) + if err != nil { + return nil, fmt.Errorf("deserialize %s: %w", headerFile, err) + } + + for _, m := range h.Mapping { + bid := m.BuildId.String() + if bid != buildID && bid != cmdutil.NilUUID { + seen[bid] = true + } + } + } + + deps := make([]string, 0, len(seen)) + for bid := range seen { + deps = append(deps, bid) + } + + return deps, nil +} + +func compressArtifact(ctx context.Context, cfg *compressConfig, buildID, name, file string) error { + fmt.Printf("\n=== %s ===\n", name) + + // Read uncompressed header + headerFile := file + storage.HeaderSuffix + headerData, _, err := cmdutil.ReadFile(ctx, cfg.storagePath, buildID, headerFile) + if err != nil { + return fmt.Errorf("read header: %w", err) + } + + h, err := header.DeserializeBytes(headerData) + if err != nil { + return fmt.Errorf("deserialize header: %w", err) + } + fmt.Printf(" Header: version=%d, mappings=%d, size=%#x\n", + h.Metadata.Version, len(h.Mapping), h.Metadata.Size) + + // Check if compressed data already exists + compressedFile := storage.V4DataName(file, cfg.compType) + existing := cmdutil.ProbeFile(ctx, cfg.storagePath, buildID, compressedFile) + if existing.Exists { + fmt.Printf(" Compressed file already exists: %s (%#x), skipping\n", existing.Path, existing.Size) + + return nil + } + + // Check if v4 header already exists + compressedHeaderFile := storage.V4HeaderName(file) + existingHeader := cmdutil.ProbeFile(ctx, cfg.storagePath, buildID, compressedHeaderFile) + if 
existingHeader.Exists { + fmt.Printf(" Compressed header already exists: %s (%#x), skipping\n", existingHeader.Path, existingHeader.Size) + + return nil + } + + if cfg.dryRun { + fmt.Printf(" [dry-run] Would compress %s -> %s\n", file, compressedFile) + fmt.Printf(" [dry-run] Would create compressed header -> %s\n", compressedHeaderFile) + + return nil + } + + // Open data file for reading + reader, dataSize, dataSource, err := cmdutil.OpenDataFile(ctx, cfg.storagePath, buildID, file) + if err != nil { + return fmt.Errorf("open data file: %w", err) + } + defer reader.Close() + + fmt.Printf(" Data: %s (%#x, %.1f MiB)\n", dataSource, dataSize, float64(dataSize)/1024/1024) + + // Set up compression options + opts := &storage.FramedUploadOptions{ + CompressionType: cfg.compType, + Level: cfg.level, + TargetFrameSize: cfg.frameSize, + MaxUncompressedFrameSize: cfg.maxFrameU, + TargetPartSize: 50 * 1024 * 1024, + } + + if cfg.verbose { + frameIdx := 0 + lastFrameTime := time.Now() + opts.OnFrameReady = func(offset storage.FrameOffset, size storage.FrameSize, _ []byte) error { + now := time.Now() + elapsed := now.Sub(lastFrameTime) + mbps := float64(size.U) / elapsed.Seconds() / (1024 * 1024) + lastFrameTime = now + ratio := float64(size.U) / float64(size.C) + fmt.Printf(" frame[%d] U=%#x+%#x C=%#x+%#x ratio=%s %v %.0f MB/s\n", + frameIdx, offset.U, size.U, offset.C, size.C, + cmdutil.FormatRatio(ratio), elapsed.Round(time.Millisecond), mbps) + frameIdx++ + + return nil + } + } + + // Compress to a temp file, then upload if GCS + tmpDir, err := os.MkdirTemp("", "compress-build-*") + if err != nil { + return fmt.Errorf("create temp dir: %w", err) + } + defer os.RemoveAll(tmpDir) + + tmpCompressedPath := filepath.Join(tmpDir, compressedFile) + uploader := &filePartWriter{path: tmpCompressedPath} + + // Create an io.Reader from the DataReader (which supports ReadAt) + sectionReader := io.NewSectionReader(reader, 0, dataSize) + + fmt.Printf(" Compressing with %s (level=%d, 
frame-size=%#x, max-frame-u=%#x)...\n", + cfg.compType, cfg.level, cfg.frameSize, cfg.maxFrameU) + + // Compress + compressStart := time.Now() + frameTable, err := storage.CompressStream(ctx, sectionReader, opts, uploader) + if err != nil { + return fmt.Errorf("compress: %w", err) + } + compressElapsed := time.Since(compressStart) + + // Print compression stats + var totalU, totalC int64 + for _, f := range frameTable.Frames { + totalU += int64(f.U) + totalC += int64(f.C) + } + ratio := float64(totalU) / float64(totalC) + savings := 100.0 * (1.0 - float64(totalC)/float64(totalU)) + mbps := float64(totalU) / compressElapsed.Seconds() / (1024 * 1024) + fmt.Printf(" Compressed: %d frames, U=%#x C=%#x ratio=%s savings=%.1f%% in %v (%.0f MB/s)\n", + len(frameTable.Frames), totalU, totalC, cmdutil.FormatRatio(ratio), savings, + compressElapsed.Round(time.Millisecond), mbps) + + // Apply frame tables to header (current build's own data) + h.AddFrames(frameTable) + + // Propagate FrameTables from compressed dependencies into this header. + // Without this, mappings referencing parent builds would have nil FrameTable, + // forcing uncompressed chunkers for those layers even though compressed data exists. 
+ propagateDependencyFrames(ctx, cfg.storagePath, h, file) + + h.Metadata.Version = header.MetadataVersionCompressed + + // Serialize as v4 + headerBytes, err := header.Serialize(h.Metadata, h.Mapping) + if err != nil { + return fmt.Errorf("serialize v4 header: %w", err) + } + + // LZ4-block-compress the header + compressedHeaderBytes, err := storage.CompressLZ4(headerBytes) + if err != nil { + return fmt.Errorf("LZ4-compress header: %w", err) + } + + // Write compressed header to temp + tmpHeaderPath := filepath.Join(tmpDir, compressedHeaderFile) + if err := os.WriteFile(tmpHeaderPath, compressedHeaderBytes, 0o644); err != nil { + return fmt.Errorf("write compressed header: %w", err) + } + + // Upload to destination + if cmdutil.IsGCSPath(cfg.storagePath) { + gcsBase := cmdutil.NormalizeGCSPath(cfg.storagePath) + "/" + buildID + "/" + + fmt.Printf(" Uploading compressed data to %s%s...\n", gcsBase, compressedFile) + if err := gcloudCopy(ctx, tmpCompressedPath, gcsBase+compressedFile, map[string]string{ + "uncompressed-size": strconv.FormatInt(dataSize, 10), + }); err != nil { + return fmt.Errorf("upload compressed data: %w", err) + } + + fmt.Printf(" Uploading compressed header to %s%s...\n", gcsBase, compressedHeaderFile) + if err := gcloudCopy(ctx, tmpHeaderPath, gcsBase+compressedHeaderFile, nil); err != nil { + return fmt.Errorf("upload compressed header: %w", err) + } + } else { + // Local storage: move from temp to final location + localBase := filepath.Join(cfg.storagePath, "templates", buildID) + if err := os.MkdirAll(localBase, 0o755); err != nil { + return fmt.Errorf("mkdir: %w", err) + } + + finalCompressed := filepath.Join(localBase, compressedFile) + if err := os.Rename(tmpCompressedPath, finalCompressed); err != nil { + return fmt.Errorf("move compressed data: %w", err) + } + fmt.Printf(" Output: %s\n", finalCompressed) + + // Write uncompressed-size sidecar for local storage + sidecarPath := finalCompressed + ".uncompressed-size" + if err := 
os.WriteFile(sidecarPath, []byte(strconv.FormatInt(dataSize, 10)), 0o644); err != nil { + return fmt.Errorf("write uncompressed-size sidecar: %w", err) + } + + finalHeader := filepath.Join(localBase, compressedHeaderFile) + if err := os.Rename(tmpHeaderPath, finalHeader); err != nil { + return fmt.Errorf("move compressed header: %w", err) + } + fmt.Printf(" Compressed header: %s\n", finalHeader) + } + + fmt.Printf(" Compressed header: %#x (uncompressed: %#x)\n", + len(compressedHeaderBytes), len(headerBytes)) + + return nil +} + +// propagateDependencyFrames reads compressed headers for dependency builds +// and injects their FrameTables into the current header's dependency mappings. +// +// When a derived template references base build data, the header mappings for +// those base builds initially have nil FrameTable. If the base build was +// previously compressed (has a v4 header), we read its FrameTable +// and apply it to the matching mappings in this header. This ensures the +// orchestrator creates compressed chunkers for ALL layers, not just the current build. +func propagateDependencyFrames(ctx context.Context, storagePath string, h *header.Header, artifactFile string) { + currentBuildID := h.Metadata.BuildId.String() + + // Collect unique dependency build IDs that have nil FrameTable. 
+ depBuilds := make(map[string]bool) + for _, m := range h.Mapping { + bid := m.BuildId.String() + if bid == currentBuildID || bid == cmdutil.NilUUID { + continue + } + if m.FrameTable == nil { + depBuilds[bid] = true + } + } + + if len(depBuilds) == 0 { + return + } + + for depBuild := range depBuilds { + depH, _, err := cmdutil.ReadCompressedHeader(ctx, storagePath, depBuild, artifactFile) + if err != nil { + fmt.Printf(" Warning: could not read compressed header for dependency %s: %s\n", depBuild, err) + + continue + } + if depH == nil { + fmt.Printf(" Warning: no compressed header found for dependency %s (not compressed yet?)\n", depBuild) + + continue + } + + // Reconstruct the full FrameTable for the dependency by collecting + // all FrameTables from the dependency's own mappings and merging them. + fullFT := reconstructFullFrameTable(depH, depBuild) + if fullFT == nil { + fmt.Printf(" Warning: dependency %s compressed header has no FrameTable for its own data\n", depBuild) + + continue + } + + // Apply the full FrameTable to matching mappings in the current header. + applied := 0 + for _, m := range h.Mapping { + if m.BuildId.String() != depBuild || m.FrameTable != nil { + continue + } + if err := m.AddFrames(fullFT); err != nil { + fmt.Printf(" Warning: could not apply frames for dependency %s mapping at offset %#x: %s\n", + depBuild, m.Offset, err) + + continue + } + applied++ + } + if applied > 0 { + fmt.Printf(" Propagated %d FrameTable(s) from dependency %s (%d frames, %s)\n", + applied, depBuild, len(fullFT.Frames), fullFT.CompressionType) + } + } +} + +// reconstructFullFrameTable merges all per-mapping FrameTables for a given +// build ID from a header into a single FrameTable covering the entire data file. 
+func reconstructFullFrameTable(h *header.Header, buildID string) *storage.FrameTable { + var result *storage.FrameTable + + for _, m := range h.Mapping { + if m.BuildId.String() != buildID || m.FrameTable == nil { + continue + } + + ft := m.FrameTable + if result == nil { + // First FrameTable — start with a copy + result = &storage.FrameTable{ + CompressionType: ft.CompressionType, + StartAt: ft.StartAt, + Frames: make([]storage.FrameSize, len(ft.Frames)), + } + copy(result.Frames, ft.Frames) + + continue + } + + // Extend: calculate where the current result ends (uncompressed offset). + resultEndU := result.StartAt.U + for _, f := range result.Frames { + resultEndU += int64(f.U) + } + + // Append non-overlapping frames from ft. + ftCurrentU := ft.StartAt.U + for _, f := range ft.Frames { + frameEndU := ftCurrentU + int64(f.U) + if frameEndU <= resultEndU { + // Already covered + ftCurrentU = frameEndU + + continue + } + if ftCurrentU < resultEndU { + // Overlapping frame — same physical frame, skip it + ftCurrentU = frameEndU + + continue + } + // New frame beyond what we have + result.Frames = append(result.Frames, f) + ftCurrentU = frameEndU + } + } + + return result +} + +func gcloudCopy(ctx context.Context, localPath, gcsPath string, metadata map[string]string) error { + cmd := exec.CommandContext(ctx, "gcloud", "storage", "cp", "--verbosity", "error", localPath, gcsPath) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("gcloud storage cp failed: %w\n%s", err, string(output)) + } + + // Set custom metadata separately — gcloud storage cp --custom-metadata + // doesn't work with parallel composite uploads for large files. 
+ if len(metadata) > 0 { + pairs := make([]string, 0, len(metadata)) + for k, v := range metadata { + pairs = append(pairs, k+"="+v) + } + updateCmd := exec.CommandContext(ctx, "gcloud", "storage", "objects", "update", + "--custom-metadata="+strings.Join(pairs, ","), gcsPath) + updateOutput, updateErr := updateCmd.CombinedOutput() + if updateErr != nil { + return fmt.Errorf("gcloud storage objects update failed: %w\n%s", updateErr, string(updateOutput)) + } + } + + return nil +} + +// templateInfo represents a template from the E2B API. +type templateInfo struct { + TemplateID string `json:"templateID"` + BuildID string `json:"buildID"` + Aliases []string `json:"aliases"` + Names []string `json:"names"` +} + +// resolveTemplateID fetches the build ID for a template from the E2B API. +func resolveTemplateID(input string) (string, error) { + apiKey := os.Getenv("E2B_API_KEY") + if apiKey == "" { + return "", fmt.Errorf("E2B_API_KEY environment variable required for -template flag") + } + + apiURL := "https://api.e2b.dev/templates" + if domain := os.Getenv("E2B_DOMAIN"); domain != "" { + apiURL = fmt.Sprintf("https://api.%s/templates", domain) + } + + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("X-API-Key", apiKey) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("failed to fetch templates: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + + return "", fmt.Errorf("API returned %d: %s", resp.StatusCode, string(body)) + } + + var templates []templateInfo + if err := json.NewDecoder(resp.Body).Decode(&templates); err != nil { + return "", fmt.Errorf("failed to parse API response: %w", err) + } + + var match *templateInfo + var availableAliases []string + + for i := range templates { + t := &templates[i] + 
availableAliases = append(availableAliases, t.Aliases...) + + if t.TemplateID == input { + match = t + + break + } + + if slices.Contains(t.Aliases, input) { + match = t + + break + } + + if slices.Contains(t.Names, input) { + match = t + + break + } + } + + if match == nil { + return "", fmt.Errorf("template %q not found. Available aliases: %s", input, strings.Join(availableAliases, ", ")) + } + + if match.BuildID == "" || match.BuildID == cmdutil.NilUUID { + return "", fmt.Errorf("template %q has no successful build", input) + } + + return match.BuildID, nil +} diff --git a/packages/orchestrator/cmd/copy-build/main.go b/packages/orchestrator/cmd/copy-build/main.go index f5d3d01e3a..d7a71db720 100644 --- a/packages/orchestrator/cmd/copy-build/main.go +++ b/packages/orchestrator/cmd/copy-build/main.go @@ -75,13 +75,13 @@ func NewDestinationFromPath(prefix, file string) (*Destination, error) { }, nil } -func NewHeaderFromObject(ctx context.Context, bucketName string, headerPath string, objectType storage.ObjectType) (*header.Header, error) { +func NewHeaderFromObject(ctx context.Context, bucketName string, headerPath string) (*header.Header, error) { b, err := storage.NewGCP(ctx, bucketName, nil) if err != nil { return nil, fmt.Errorf("failed to create GCS bucket storage provider: %w", err) } - obj, err := b.OpenBlob(ctx, headerPath, objectType) + obj, err := b.OpenBlob(ctx, headerPath) if err != nil { return nil, fmt.Errorf("failed to open object: %w", err) } @@ -219,7 +219,7 @@ func main() { if strings.HasPrefix(*from, "gs://") { bucketName, _ := strings.CutPrefix(*from, "gs://") - h, err := NewHeaderFromObject(ctx, bucketName, buildMemfileHeaderPath, storage.MemfileHeaderObjectType) + h, err := NewHeaderFromObject(ctx, bucketName, buildMemfileHeaderPath) if err != nil { log.Fatalf("failed to create header from object: %s", err) } @@ -245,7 +245,7 @@ func main() { var rootfsHeader *header.Header if strings.HasPrefix(*from, "gs://") { bucketName, _ := 
strings.CutPrefix(*from, "gs://") - h, err := NewHeaderFromObject(ctx, bucketName, buildRootfsHeaderPath, storage.RootFSHeaderObjectType) + h, err := NewHeaderFromObject(ctx, bucketName, buildRootfsHeaderPath) if err != nil { log.Fatalf("failed to create header from object: %s", err) } diff --git a/packages/orchestrator/cmd/create-build/main.go b/packages/orchestrator/cmd/create-build/main.go index 0332c56073..594ebf6d28 100644 --- a/packages/orchestrator/cmd/create-build/main.go +++ b/packages/orchestrator/cmd/create-build/main.go @@ -363,7 +363,7 @@ func printArtifactSizes(ctx context.Context, persistence storage.StorageProvider printLocalFileSizes(basePath, buildID) } else { // For remote storage, get sizes from storage provider - if memfile, err := persistence.OpenSeekable(ctx, files.StorageMemfilePath(), storage.MemfileObjectType); err == nil { + if memfile, err := persistence.OpenFramedFile(ctx, files.StorageMemfilePath()); err == nil { if size, err := memfile.Size(ctx); err == nil { fmt.Printf(" Memfile: %d MB\n", size>>20) } diff --git a/packages/orchestrator/cmd/inspect-build/main.go b/packages/orchestrator/cmd/inspect-build/main.go index 660a8c3af3..61479f1c46 100644 --- a/packages/orchestrator/cmd/inspect-build/main.go +++ b/packages/orchestrator/cmd/inspect-build/main.go @@ -3,6 +3,8 @@ package main import ( "bytes" "context" + "crypto/md5" + "encoding/hex" "encoding/json" "flag" "fmt" @@ -15,21 +17,28 @@ import ( "unsafe" "github.com/e2b-dev/infra/packages/orchestrator/cmd/internal/cmdutil" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) -const nilUUID = "00000000-0000-0000-0000-000000000000" - func main() { build := flag.String("build", "", "build ID") template := flag.String("template", "", "template ID or alias (requires E2B_API_KEY)") storagePath := flag.String("storage", ".local-build", "storage: local path or gs://bucket") memfile := flag.Bool("memfile", false, "inspect 
memfile artifact") rootfs := flag.Bool("rootfs", false, "inspect rootfs artifact") + compressed := flag.Bool("compressed", false, "read v4 compressed header (.v4.header)") + summary := flag.Bool("summary", false, "show only metadata + summary (skip per-mapping listing)") + listFiles := flag.Bool("list-files", false, "list all files for this build with existence and size info") data := flag.Bool("data", false, "inspect data blocks (default: header only)") start := flag.Int64("start", 0, "start block (only with -data)") end := flag.Int64("end", 0, "end block, 0 = all (only with -data)") + // Validation flags + validateAll := flag.Bool("validate-all", false, "validate both memfile and rootfs") + validateMemfile := flag.Bool("validate-memfile", false, "validate memfile data integrity") + validateRootfs := flag.Bool("validate-rootfs", false, "validate rootfs data integrity") + flag.Parse() // Resolve build ID from template if provided @@ -49,7 +58,40 @@ func main() { os.Exit(1) } - // Determine artifact type + ctx := context.Background() + + // Handle list-files mode + if *listFiles { + printFileList(ctx, *storagePath, *build) + os.Exit(0) + } + + // Handle validation mode + if *validateAll || *validateMemfile || *validateRootfs { + exitCode := 0 + + if *validateAll || *validateMemfile { + if err := validateArtifact(ctx, *storagePath, *build, "memfile"); err != nil { + fmt.Printf("memfile validation FAILED: %s\n", err) + exitCode = 1 + } else { + fmt.Printf("memfile validation PASSED\n") + } + } + + if *validateAll || *validateRootfs { + if err := validateArtifact(ctx, *storagePath, *build, "rootfs.ext4"); err != nil { + fmt.Printf("rootfs validation FAILED: %s\n", err) + exitCode = 1 + } else { + fmt.Printf("rootfs validation PASSED\n") + } + } + + os.Exit(exitCode) + } + + // Determine artifact type for inspection if !*memfile && !*rootfs { *memfile = true // default to memfile } @@ -64,22 +106,36 @@ func main() { artifactName = "rootfs.ext4" } - ctx := 
context.Background() + // Read header (compressed or default) + var h *header.Header + var headerSource string - // Read header - headerFile := artifactName + ".header" - headerData, headerSource, err := cmdutil.ReadFile(ctx, *storagePath, *build, headerFile) - if err != nil { - log.Fatalf("failed to read header: %s", err) - } + if *compressed { + var err error + h, headerSource, err = cmdutil.ReadCompressedHeader(ctx, *storagePath, *build, artifactName) + if err != nil { + log.Fatalf("failed to read compressed header: %s", err) + } + if h == nil { + log.Fatalf("compressed header not found for %s", artifactName) + } + headerSource += " [compressed header]" + } else { + headerFile := artifactName + storage.HeaderSuffix + headerData, source, err := cmdutil.ReadFile(ctx, *storagePath, *build, headerFile) + if err != nil { + log.Fatalf("failed to read header: %s", err) + } - h, err := header.DeserializeBytes(headerData) - if err != nil { - log.Fatalf("failed to deserialize header: %s", err) + h, err = header.DeserializeBytes(headerData) + if err != nil { + log.Fatalf("failed to deserialize header: %s", err) + } + headerSource = source } // Print header info - printHeader(h, headerSource) + printHeader(h, headerSource, *summary) // If -data flag, also inspect data blocks if *data { @@ -89,24 +145,31 @@ func main() { } func printUsage() { - fmt.Fprintf(os.Stderr, "Usage: inspect-build (-build | -template ) [-storage ] [-memfile|-rootfs] [-data [-start N] [-end N]]\n\n") + fmt.Fprintf(os.Stderr, "Usage: inspect-build (-build | -template ) [-storage ] [-memfile|-rootfs] [-compressed] [-summary] [-data [-start N] [-end N]]\n") + fmt.Fprintf(os.Stderr, " inspect-build (-build | -template ) [-storage ] -validate-all|-validate-memfile|-validate-rootfs\n") + fmt.Fprintf(os.Stderr, " inspect-build (-build | -template ) [-storage ] -list-files\n\n") fmt.Fprintf(os.Stderr, "The -template flag requires E2B_API_KEY environment variable.\n") fmt.Fprintf(os.Stderr, "Set E2B_DOMAIN for 
non-production environments.\n\n") fmt.Fprintf(os.Stderr, "Examples:\n") fmt.Fprintf(os.Stderr, " inspect-build -build abc123 # inspect memfile header\n") + fmt.Fprintf(os.Stderr, " inspect-build -build abc123 -compressed # inspect compressed memfile header\n") + fmt.Fprintf(os.Stderr, " inspect-build -build abc123 -summary # metadata + summaries only\n") + fmt.Fprintf(os.Stderr, " inspect-build -build abc123 -list-files # list all build files\n") fmt.Fprintf(os.Stderr, " inspect-build -template base -storage gs://bucket # inspect by template alias\n") fmt.Fprintf(os.Stderr, " inspect-build -template gtjfpksmxd9ct81x1f8e # inspect by template ID\n") fmt.Fprintf(os.Stderr, " inspect-build -build abc123 -rootfs # inspect rootfs header\n") fmt.Fprintf(os.Stderr, " inspect-build -build abc123 -data # inspect memfile header + data\n") fmt.Fprintf(os.Stderr, " inspect-build -build abc123 -rootfs -data -end 100 # inspect rootfs header + first 100 blocks\n") fmt.Fprintf(os.Stderr, " inspect-build -build abc123 -storage gs://bucket # inspect from GCS\n") + fmt.Fprintf(os.Stderr, " inspect-build -build abc123 -validate-all # validate both memfile and rootfs\n") + fmt.Fprintf(os.Stderr, " inspect-build -build abc123 -validate-memfile # validate memfile integrity\n") } -func printHeader(h *header.Header, source string) { +func printHeader(h *header.Header, source string, summaryOnly bool) { // Validate mappings err := header.ValidateMappings(h.Mapping, h.Metadata.Size, h.Metadata.BlockSize) if err != nil { - fmt.Printf("\n⚠️ WARNING: Mapping validation failed!\n%s\n\n", err) + fmt.Printf("\nWARNING: Mapping validation failed!\n%s\n\n", err) } fmt.Printf("\nMETADATA\n") @@ -116,23 +179,25 @@ func printHeader(h *header.Header, source string) { fmt.Printf("Generation %d\n", h.Metadata.Generation) fmt.Printf("Build ID %s\n", h.Metadata.BuildId) fmt.Printf("Base build ID %s\n", h.Metadata.BaseBuildId) - fmt.Printf("Size %d B (%d MiB)\n", h.Metadata.Size, h.Metadata.Size/1024/1024) 
- fmt.Printf("Block size %d B\n", h.Metadata.BlockSize) + fmt.Printf("Size %#x (%d MiB)\n", h.Metadata.Size, h.Metadata.Size/1024/1024) + fmt.Printf("Block size %#x\n", h.Metadata.BlockSize) fmt.Printf("Blocks %d\n", (h.Metadata.Size+h.Metadata.BlockSize-1)/h.Metadata.BlockSize) - totalSize := int64(unsafe.Sizeof(header.BuildMap{})) * int64(len(h.Mapping)) / 1024 - var sizeMessage string - if totalSize == 0 { - sizeMessage = "<1 KiB" - } else { - sizeMessage = fmt.Sprintf("%d KiB", totalSize) - } + if !summaryOnly { + totalSize := int64(unsafe.Sizeof(header.BuildMap{})) * int64(len(h.Mapping)) / 1024 + var sizeMessage string + if totalSize == 0 { + sizeMessage = "<1 KiB" + } else { + sizeMessage = fmt.Sprintf("%d KiB", totalSize) + } - fmt.Printf("\nMAPPING (%d maps, uses %s in storage)\n", len(h.Mapping), sizeMessage) - fmt.Printf("=======\n") + fmt.Printf("\nMAPPING (%d maps, uses %s in storage)\n", len(h.Mapping), sizeMessage) + fmt.Printf("=======\n") - for _, mapping := range h.Mapping { - fmt.Println(mapping.Format(h.Metadata.BlockSize)) + for _, mapping := range h.Mapping { + fmt.Println(cmdutil.FormatMappingWithCompression(mapping, h.Metadata.BlockSize)) + } } fmt.Printf("\nMAPPING SUMMARY\n") @@ -150,11 +215,59 @@ func printHeader(h *header.Header, source string) { additionalInfo = " (current)" case h.Metadata.BaseBuildId.String(): additionalInfo = " (parent)" - case nilUUID: + case cmdutil.NilUUID: additionalInfo = " (sparse)" } fmt.Printf("%s%s: %d blocks, %d MiB (%0.2f%%)\n", buildID, additionalInfo, uint64(size)/h.Metadata.BlockSize, uint64(size)/1024/1024, float64(size)/float64(h.Metadata.Size)*100) } + + // Print compression summary + cmdutil.PrintCompressionSummary(h) +} + +// printFileList lists all files that actually exist for this build in storage. 
+func printFileList(ctx context.Context, storagePath, buildID string) { + fmt.Printf("\nFILES for build %s\n", buildID) + fmt.Printf("====================\n") + + files, err := cmdutil.ListFiles(ctx, storagePath, buildID) + if err != nil { + fmt.Printf("ERROR listing files: %s\n", err) + + return + } + + if len(files) == 0 { + fmt.Printf("(no files found)\n") + + return + } + + fmt.Printf("%-45s %12s\n", "FILE", "SIZE") + fmt.Printf("%-45s %12s\n", strings.Repeat("-", 45), strings.Repeat("-", 12)) + + for _, info := range files { + extra := "" + if uSize, ok := info.Metadata["uncompressed-size"]; ok { + extra = fmt.Sprintf(" (uncompressed-size=%s)", uSize) + } + fmt.Printf("%-45s %12s%s\n", info.Name, formatSize(info.Size), extra) + } + + fmt.Printf("\n%d files total\n", len(files)) +} + +func formatSize(size int64) string { + switch { + case size >= 1024*1024*1024: + return fmt.Sprintf("%.1f GiB", float64(size)/1024/1024/1024) + case size >= 1024*1024: + return fmt.Sprintf("%.1f MiB", float64(size)/1024/1024) + case size >= 1024: + return fmt.Sprintf("%.1f KiB", float64(size)/1024) + default: + return fmt.Sprintf("%d B", size) + } } func inspectData(ctx context.Context, storagePath, buildID, dataFile string, h *header.Header, start, end int64) { @@ -186,7 +299,7 @@ func inspectData(ctx context.Context, storagePath, buildID, dataFile string, h * fmt.Printf("\nDATA\n") fmt.Printf("====\n") fmt.Printf("Source %s\n", source) - fmt.Printf("Size %d B (%d MiB)\n", size, size/1024/1024) + fmt.Printf("Size %#x (%d MiB)\n", size, size/1024/1024) b := make([]byte, blockSize) emptyCount := 0 @@ -206,10 +319,10 @@ func inspectData(ctx context.Context, storagePath, buildID, dataFile string, h * if nonZeroCount > 0 { nonEmptyCount++ - fmt.Printf("%-10d [%11d,%11d) %d non-zero bytes\n", i/blockSize, i, i+blockSize, nonZeroCount) + fmt.Printf("%-10d [%#x,%#x) %#x non-zero bytes\n", i/blockSize, i, i+blockSize, nonZeroCount) } else { emptyCount++ - fmt.Printf("%-10d [%11d,%11d) 
EMPTY\n", i/blockSize, i, i+blockSize) + fmt.Printf("%-10d [%#x,%#x) EMPTY\n", i/blockSize, i, i+blockSize) } } @@ -218,12 +331,313 @@ func inspectData(ctx context.Context, storagePath, buildID, dataFile string, h * fmt.Printf("Empty blocks: %d\n", emptyCount) fmt.Printf("Non-empty blocks: %d\n", nonEmptyCount) fmt.Printf("Total blocks inspected: %d\n", emptyCount+nonEmptyCount) - fmt.Printf("Total size inspected: %d B (%d MiB)\n", int64(emptyCount+nonEmptyCount)*blockSize, int64(emptyCount+nonEmptyCount)*blockSize/1024/1024) - fmt.Printf("Empty size: %d B (%d MiB)\n", int64(emptyCount)*blockSize, int64(emptyCount)*blockSize/1024/1024) + fmt.Printf("Total size inspected: %#x (%d MiB)\n", int64(emptyCount+nonEmptyCount)*blockSize, int64(emptyCount+nonEmptyCount)*blockSize/1024/1024) + fmt.Printf("Empty size: %#x (%d MiB)\n", int64(emptyCount)*blockSize, int64(emptyCount)*blockSize/1024/1024) reader.Close() } +// validateArtifact validates data integrity for an artifact (memfile or rootfs). +func validateArtifact(ctx context.Context, storagePath, buildID, artifactName string) error { + fmt.Printf("\n=== Validating %s for build %s ===\n", artifactName, buildID) + + // 1. Read and deserialize header + headerFile := artifactName + ".header" + headerData, _, err := cmdutil.ReadFile(ctx, storagePath, buildID, headerFile) + if err != nil { + return fmt.Errorf("failed to read header: %w", err) + } + + h, err := header.DeserializeBytes(headerData) + if err != nil { + return fmt.Errorf("failed to deserialize header: %w", err) + } + fmt.Printf(" Header: version=%d size=%#x blockSize=%#x mappings=%d\n", + h.Metadata.Version, h.Metadata.Size, h.Metadata.BlockSize, len(h.Mapping)) + + // 2. Validate mappings cover entire file + if err := header.ValidateHeader(h); err != nil { + return fmt.Errorf("header validation failed: %w", err) + } + fmt.Printf(" Mappings: coverage validated\n") + + // 3. 
Open data file and check size + reader, dataSize, _, err := cmdutil.OpenDataFile(ctx, storagePath, buildID, artifactName) + if err != nil { + return fmt.Errorf("failed to open data file: %w", err) + } + defer reader.Close() + + fmt.Printf(" Data file: size=%#x\n", dataSize) + + // 4. Validate mappings for the current build only + currentBuildID := h.Metadata.BuildId.String() + validatedCount := 0 + for i, mapping := range h.Mapping { + if mapping.BuildId.String() != currentBuildID { + continue + } + if err := validateMapping(ctx, storagePath, artifactName, h, mapping, i); err != nil { + return fmt.Errorf("mapping[%d] validation failed: %w", i, err) + } + validatedCount++ + } + fmt.Printf(" %d/%d current-build mappings validated\n", validatedCount, len(h.Mapping)) + + // 5. Compute and display MD5 of actual data on storage + hash := md5.New() + chunkSize := int64(1024 * 1024) + buf := make([]byte, chunkSize) + + for offset := int64(0); offset < dataSize; offset += chunkSize { + readSize := chunkSize + if offset+chunkSize > dataSize { + readSize = dataSize - offset + } + n, err := reader.ReadAt(buf[:readSize], offset) + if err != nil && n == 0 { + return fmt.Errorf("failed to read at offset %d: %w", offset, err) + } + hash.Write(buf[:n]) + } + + dataMD5 := hex.EncodeToString(hash.Sum(nil)) + fmt.Printf(" Data MD5 (storage): %s\n", dataMD5) + + // 6. 
Validate compressed header and frames if it exists + compressedH, _, compErr := cmdutil.ReadCompressedHeader(ctx, storagePath, buildID, artifactName) + + switch { + case compErr != nil: + fmt.Printf(" Compressed header: read error: %s\n", compErr) + case compressedH != nil: + if err := header.ValidateHeader(compressedH); err != nil { + return fmt.Errorf("compressed header validation failed: %w", err) + } + fmt.Printf(" Compressed header: validated (mappings=%d)\n", len(compressedH.Mapping)) + + if err := validateCompressedFrames(ctx, storagePath, artifactName, compressedH); err != nil { + return fmt.Errorf("compressed frame validation failed: %w", err) + } + default: + fmt.Printf(" Compressed header: not present\n") + } + + return nil +} + +// validateMapping validates a single mapping's data integrity. +func validateMapping(ctx context.Context, storagePath, artifactName string, h *header.Header, mapping *header.BuildMap, _ int) error { + if mapping.BuildId.String() == cmdutil.NilUUID { + return nil + } + + if !storage.IsCompressed(mapping.FrameTable) { + reader, _, _, err := cmdutil.OpenDataFile(ctx, storagePath, mapping.BuildId.String(), artifactName) + if err != nil { + return fmt.Errorf("failed to open data for build %s: %w", mapping.BuildId, err) + } + defer reader.Close() + + buf := make([]byte, h.Metadata.BlockSize) + _, err = reader.ReadAt(buf, int64(mapping.BuildStorageOffset)) + if err != nil { + return fmt.Errorf("failed to read data at offset %d: %w", mapping.BuildStorageOffset, err) + } + + return nil + } + + ft := mapping.FrameTable + + var totalU int64 + for _, frame := range ft.Frames { + totalU += int64(frame.U) + } + + if totalU < int64(mapping.Length) { + return fmt.Errorf("frame table covers %#x bytes but mapping length is %#x", totalU, mapping.Length) + } + + reader, fileSize, _, err := cmdutil.OpenDataFile(ctx, storagePath, mapping.BuildId.String(), artifactName) + if err != nil { + return fmt.Errorf("failed to open compressed data for build 
%s: %w", mapping.BuildId, err) + } + defer reader.Close() + + var totalC int64 + for _, frame := range ft.Frames { + totalC += int64(frame.C) + } + expectedSize := ft.StartAt.C + totalC + + if fileSize < expectedSize { + return fmt.Errorf("compressed file size %#x is less than expected %#x (startC=%#x + framesC=%#x)", + fileSize, expectedSize, ft.StartAt.C, totalC) + } + + firstFrameBuf := make([]byte, ft.Frames[0].C) + _, err = reader.ReadAt(firstFrameBuf, ft.StartAt.C) + if err != nil { + return fmt.Errorf("failed to read first compressed frame at C=%#x: %w", ft.StartAt.C, err) + } + + if len(ft.Frames) > 1 { + lastIdx := len(ft.Frames) - 1 + lastOffset := calculateCOffset(ft, lastIdx) + lastFrameBuf := make([]byte, ft.Frames[lastIdx].C) + _, err = reader.ReadAt(lastFrameBuf, lastOffset) + if err != nil { + return fmt.Errorf("failed to read last compressed frame at C=%#x: %w", lastOffset, err) + } + } + + return nil +} + +// validateCompressedFrames decompresses every frame described in the compressed +// header and compares the result with the uncompressed data file byte-for-byte. +func validateCompressedFrames(ctx context.Context, storagePath, artifactName string, compressedH *header.Header) error { + // Collect unique frames to validate, keyed by (buildID, C-offset). 
+ type frameInfo struct { + offset storage.FrameOffset + size storage.FrameSize + ct storage.CompressionType + } + type frameKey struct { + buildID string + cOffset int64 + } + + buildFrames := make(map[string][]frameInfo) + seen := make(map[frameKey]bool) + + for _, mapping := range compressedH.Mapping { + ft := mapping.FrameTable + if !storage.IsCompressed(ft) { + continue + } + + bid := mapping.BuildId.String() + if bid == cmdutil.NilUUID { + continue + } + + currentOffset := ft.StartAt + for _, frame := range ft.Frames { + key := frameKey{bid, currentOffset.C} + if !seen[key] { + seen[key] = true + buildFrames[bid] = append(buildFrames[bid], frameInfo{ + offset: currentOffset, + size: frame, + ct: ft.CompressionType, + }) + } + currentOffset.Add(frame) + } + } + + if len(buildFrames) == 0 { + fmt.Printf(" No compressed frames to validate\n") + + return nil + } + + totalFrames := 0 + for _, frames := range buildFrames { + totalFrames += len(frames) + } + fmt.Printf(" Validating %d unique compressed frames across %d builds\n", totalFrames, len(buildFrames)) + + for bid, frames := range buildFrames { + // Open compressed file (e.g., v4.memfile.lz4) + compressedFile := storage.V4DataName(artifactName, frames[0].ct) + compReader, compSize, _, err := cmdutil.OpenDataFile(ctx, storagePath, bid, compressedFile) + if err != nil { + return fmt.Errorf("build %s: failed to open %s: %w", bid, compressedFile, err) + } + + // Open uncompressed file (e.g., memfile) + uncReader, uncSize, _, err := cmdutil.OpenDataFile(ctx, storagePath, bid, artifactName) + if err != nil { + compReader.Close() + + return fmt.Errorf("build %s: failed to open %s: %w", bid, artifactName, err) + } + + fmt.Printf(" Build %s: %d frames, compressed=%#x uncompressed=%#x\n", bid, len(frames), compSize, uncSize) + + for i, frame := range frames { + // Read compressed bytes from .lz4 at C offset + compBuf := make([]byte, frame.size.C) + _, err := compReader.ReadAt(compBuf, frame.offset.C) + if err != nil { 
+ compReader.Close() + uncReader.Close() + + return fmt.Errorf("build %s frame[%d]: read compressed at C=%#x size=%#x: %w", + bid, i, frame.offset.C, frame.size.C, err) + } + + // Decompress + decompressed, err := storage.DecompressFrame(frame.ct, compBuf, frame.size.U) + if err != nil { + previewLen := min(32, len(compBuf)) + compReader.Close() + uncReader.Close() + + return fmt.Errorf("build %s frame[%d]: decompress at C=%#x (first %d bytes: %x): %w", + bid, i, frame.offset.C, previewLen, compBuf[:previewLen], err) + } + + // Read corresponding uncompressed bytes + uncBuf := make([]byte, frame.size.U) + _, err = uncReader.ReadAt(uncBuf, frame.offset.U) + if err != nil { + compReader.Close() + uncReader.Close() + + return fmt.Errorf("build %s frame[%d]: read uncompressed at U=%#x size=%#x: %w", + bid, i, frame.offset.U, frame.size.U, err) + } + + // Compare + if !bytes.Equal(decompressed, uncBuf) { + for j := range decompressed { + if j < len(uncBuf) && decompressed[j] != uncBuf[j] { + compReader.Close() + uncReader.Close() + + return fmt.Errorf("build %s frame[%d]: mismatch at U=%#x+%d (byte %d: got %#x want %#x)", + bid, i, frame.offset.U, j, j, decompressed[j], uncBuf[j]) + } + } + } + + fmt.Printf(" frame[%d] U=%#x C=%#x OK (%#x→%#x)\n", + i, frame.offset.U, frame.offset.C, frame.size.C, frame.size.U) + } + + compReader.Close() + uncReader.Close() + } + + fmt.Printf(" Compressed frames: all %d validated\n", totalFrames) + + return nil +} + +// calculateCOffset calculates the compressed offset for frame at index i. +func calculateCOffset(ft *storage.FrameTable, frameIdx int) int64 { + offset := ft.StartAt.C + for i := range frameIdx { + offset += int64(ft.Frames[i].C) + } + + return offset +} + // templateInfo represents a template from the E2B API. type templateInfo struct { TemplateID string `json:"templateID"` @@ -233,20 +647,17 @@ type templateInfo struct { } // resolveTemplateID fetches the build ID for a template from the E2B API. 
-// Input can be a template ID, alias, or full name (e.g., "e2b/base"). func resolveTemplateID(input string) (string, error) { apiKey := os.Getenv("E2B_API_KEY") if apiKey == "" { return "", fmt.Errorf("E2B_API_KEY environment variable required for -template flag") } - // Determine API URL apiURL := "https://api.e2b.dev/templates" if domain := os.Getenv("E2B_DOMAIN"); domain != "" { apiURL = fmt.Sprintf("https://api.%s/templates", domain) } - // Make HTTP request ctx := context.Background() req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil) if err != nil { @@ -266,37 +677,30 @@ func resolveTemplateID(input string) (string, error) { return "", fmt.Errorf("API returned %d: %s", resp.StatusCode, string(body)) } - // Parse response var templates []templateInfo if err := json.NewDecoder(resp.Body).Decode(&templates); err != nil { return "", fmt.Errorf("failed to parse API response: %w", err) } - // Find matching template var match *templateInfo var availableAliases []string for i := range templates { t := &templates[i] - - // Collect aliases for error message availableAliases = append(availableAliases, t.Aliases...) - // Match by template ID if t.TemplateID == input { match = t break } - // Match by alias if slices.Contains(t.Aliases, input) { match = t break } - // Match by full name (e.g., "e2b/base") if slices.Contains(t.Names, input) { match = t @@ -308,7 +712,7 @@ func resolveTemplateID(input string) (string, error) { return "", fmt.Errorf("template %q not found. 
Available aliases: %s", input, strings.Join(availableAliases, ", ")) } - if match.BuildID == "" || match.BuildID == nilUUID { + if match.BuildID == "" || match.BuildID == cmdutil.NilUUID { return "", fmt.Errorf("template %q has no successful build", input) } diff --git a/packages/orchestrator/cmd/internal/cmdutil/cmdutil.go b/packages/orchestrator/cmd/internal/cmdutil/cmdutil.go index 4530bbc832..5b4d069c4a 100644 --- a/packages/orchestrator/cmd/internal/cmdutil/cmdutil.go +++ b/packages/orchestrator/cmd/internal/cmdutil/cmdutil.go @@ -72,24 +72,55 @@ func GetActualFileSize(path string) (int64, error) { // ArtifactInfo contains information about a build artifact. type ArtifactInfo struct { - Name string - File string - HeaderFile string + Name string + File string // e.g., "memfile" + HeaderFile string // e.g., "memfile.header" + CompressedFiles []string // e.g., ["v4.memfile.lz4", "v4.memfile.zstd"] + CompressedHeaderFile string // e.g., "v4.memfile.header.lz4" +} + +// allCompressionTypes lists all supported compression types for file probing. +var allCompressionTypes = []storage.CompressionType{ + storage.CompressionLZ4, + storage.CompressionZstd, } // MainArtifacts returns the list of main artifacts (rootfs, memfile). 
func MainArtifacts() []ArtifactInfo { return []ArtifactInfo{ - {"Rootfs", storage.RootfsName, storage.RootfsName + storage.HeaderSuffix}, - {"Memfile", storage.MemfileName, storage.MemfileName + storage.HeaderSuffix}, + { + Name: "Rootfs", + File: storage.RootfsName, + HeaderFile: storage.RootfsName + storage.HeaderSuffix, + CompressedFiles: v4DataNames(storage.RootfsName), + CompressedHeaderFile: storage.V4HeaderName(storage.RootfsName), + }, + { + Name: "Memfile", + File: storage.MemfileName, + HeaderFile: storage.MemfileName + storage.HeaderSuffix, + CompressedFiles: v4DataNames(storage.MemfileName), + CompressedHeaderFile: storage.V4HeaderName(storage.MemfileName), + }, + } +} + +func v4DataNames(fileName string) []string { + names := make([]string, len(allCompressionTypes)) + for i, ct := range allCompressionTypes { + names[i] = storage.V4DataName(fileName, ct) } + + return names } // SmallArtifacts returns the list of small artifacts (headers, snapfile, metadata). func SmallArtifacts() []struct{ Name, File string } { return []struct{ Name, File string }{ {"Rootfs header", storage.RootfsName + storage.HeaderSuffix}, + {"Rootfs v4 header", storage.V4HeaderName(storage.RootfsName)}, {"Memfile header", storage.MemfileName + storage.HeaderSuffix}, + {"Memfile v4 header", storage.V4HeaderName(storage.MemfileName)}, {"Snapfile", storage.SnapfileName}, {"Metadata", storage.MetadataName}, } diff --git a/packages/orchestrator/cmd/internal/cmdutil/format.go b/packages/orchestrator/cmd/internal/cmdutil/format.go new file mode 100644 index 0000000000..f7cb92b15a --- /dev/null +++ b/packages/orchestrator/cmd/internal/cmdutil/format.go @@ -0,0 +1,195 @@ +package cmdutil + +import ( + "fmt" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +const NilUUID = "00000000-0000-0000-0000-000000000000" + +// ANSI color codes for compression ratio visualization. 
+const ( + colorReset = "\033[0m" + colorRed = "\033[91m" // bright red — incompressible + colorYellow = "\033[33m" // yellow — poor + colorGreen = "\033[32m" // green — good + colorCyan = "\033[36m" // cyan — very sparse + colorBlue = "\033[34m" // blue — nearly empty +) + +// RatioColor returns an ANSI color code for a compression ratio value. +func RatioColor(ratio float64) string { + switch { + case ratio < 1.5: + return colorRed + case ratio < 2.5: + return colorYellow + case ratio < 4: + return colorReset + case ratio < 8: + return colorGreen + case ratio < 50: + return colorCyan + default: + return colorBlue + } +} + +// FormatRatio returns a color-coded ratio string (4 chars wide). +func FormatRatio(ratio float64) string { + color := RatioColor(ratio) + if ratio >= 100 { + return fmt.Sprintf("%s%4.0f%s", color, ratio, colorReset) + } + + return fmt.Sprintf("%s%4.1f%s", color, ratio, colorReset) +} + +// FormatMappingWithCompression returns mapping info with compression details. +func FormatMappingWithCompression(mapping *header.BuildMap, blockSize uint64) string { + base := mapping.Format(blockSize) + + if mapping.FrameTable == nil { + return base + " [uncompressed]" + } + + ft := mapping.FrameTable + var totalU, totalC int64 + for _, frame := range ft.Frames { + totalU += int64(frame.U) + totalC += int64(frame.C) + } + + ratio := float64(totalU) / float64(totalC) + + return fmt.Sprintf("%s [%s: %d frames, U=%#x C=%#x ratio=%s]", + base, ft.CompressionType.String(), len(ft.Frames), totalU, totalC, FormatRatio(ratio)) +} + +// PrintCompressionSummary prints compression statistics for a header. 
+func PrintCompressionSummary(h *header.Header) { + var compressedMappings, uncompressedMappings int + var totalUncompressedBytes, totalCompressedBytes int64 + var totalFrames int + + type buildStats struct { + uncompressedBytes int64 + compressedBytes int64 + frames []storage.FrameSize + compressed bool + } + buildCompressionStats := make(map[string]*buildStats) + + for _, mapping := range h.Mapping { + buildID := mapping.BuildId.String() + if buildID == NilUUID { + continue + } + + if _, ok := buildCompressionStats[buildID]; !ok { + buildCompressionStats[buildID] = &buildStats{} + } + stats := buildCompressionStats[buildID] + + if mapping.FrameTable != nil && mapping.FrameTable.CompressionType != storage.CompressionNone { + compressedMappings++ + stats.compressed = true + + for _, frame := range mapping.FrameTable.Frames { + totalUncompressedBytes += int64(frame.U) + totalCompressedBytes += int64(frame.C) + stats.uncompressedBytes += int64(frame.U) + stats.compressedBytes += int64(frame.C) + stats.frames = append(stats.frames, frame) + } + totalFrames += len(mapping.FrameTable.Frames) + } else { + uncompressedMappings++ + totalUncompressedBytes += int64(mapping.Length) + stats.uncompressedBytes += int64(mapping.Length) + } + } + + fmt.Printf("\nCOMPRESSION SUMMARY\n") + fmt.Printf("===================\n") + + if compressedMappings == 0 && uncompressedMappings == 0 { + fmt.Printf("No data mappings (all sparse)\n") + + return + } + + fmt.Printf("Mappings: %d compressed, %d uncompressed\n", compressedMappings, uncompressedMappings) + + if compressedMappings > 0 { + ratio := float64(totalUncompressedBytes) / float64(totalCompressedBytes) + savings := 100.0 * (1.0 - float64(totalCompressedBytes)/float64(totalUncompressedBytes)) + fmt.Printf("Total frames: %d\n", totalFrames) + fmt.Printf("Uncompressed size: %#x (%.2f MiB)\n", totalUncompressedBytes, float64(totalUncompressedBytes)/1024/1024) + fmt.Printf("Compressed size: %#x (%.2f MiB)\n", totalCompressedBytes, 
float64(totalCompressedBytes)/1024/1024) + fmt.Printf("Compression ratio: %s (%.1f%% space savings)\n", FormatRatio(ratio), savings) + } else { + fmt.Printf("All mappings are uncompressed\n") + } + + hasCompressedBuilds := false + for _, stats := range buildCompressionStats { + if stats.compressed { + hasCompressedBuilds = true + + break + } + } + + if hasCompressedBuilds { + fmt.Printf("\nPer-build compression:\n") + for buildID, stats := range buildCompressionStats { + label := buildID[:8] + "..." + if buildID == h.Metadata.BuildId.String() { + label += " (current)" + } else if buildID == h.Metadata.BaseBuildId.String() { + label += " (parent)" + } + + if !stats.compressed { + fmt.Printf(" %s: uncompressed, %#x\n", label, stats.uncompressedBytes) + + continue + } + + ratio := float64(stats.uncompressedBytes) / float64(stats.compressedBytes) + fmt.Printf(" %s: %d frames, U=%#x C=%#x (%s)\n", + label, len(stats.frames), stats.uncompressedBytes, stats.compressedBytes, FormatRatio(ratio)) + + // Frame stats + if len(stats.frames) > 0 { + minC, maxC := stats.frames[0].C, stats.frames[0].C + for _, f := range stats.frames[1:] { + minC = min(minC, f.C) + maxC = max(maxC, f.C) + } + avgC := stats.compressedBytes / int64(len(stats.frames)) + fmt.Printf(" Frame sizes: avg %d KiB, min %d KiB, max %d KiB\n", + avgC/1024, minC/1024, maxC/1024) + } + + // Ratio matrix: 16 frames per row + if len(stats.frames) > 1 { + const cols = 16 + fmt.Printf("\n Ratio matrix (%d per row):\n", cols) + for row := 0; row < len(stats.frames); row += cols { + end := min(row+cols, len(stats.frames)) + fmt.Printf(" %4d: ", row) + for _, f := range stats.frames[row:end] { + r := float64(f.U) / float64(f.C) + fmt.Printf(" %s", FormatRatio(r)) + } + fmt.Println() + } + fmt.Println() + } + } + } +} diff --git a/packages/orchestrator/cmd/internal/cmdutil/storage.go b/packages/orchestrator/cmd/internal/cmdutil/storage.go index 69817e75e4..0307732bfc 100644 --- 
a/packages/orchestrator/cmd/internal/cmdutil/storage.go +++ b/packages/orchestrator/cmd/internal/cmdutil/storage.go @@ -2,6 +2,7 @@ package cmdutil import ( "context" + "errors" "fmt" "io" "os" @@ -9,6 +10,10 @@ import ( "strings" gcsstorage "cloud.google.com/go/storage" + "google.golang.org/api/iterator" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) // IsGCSPath checks if the path is a GCS path (gs:// or gs:). @@ -210,3 +215,190 @@ func openGCS(ctx context.Context, gcsPath string) (DataReader, int64, string, er return &gcsReader{client: client, bucket: bucket, object: object}, attrs.Size, gcsPath, nil } + +// ReadFileIfExists reads a file from local storage or GCS. +// Returns nil, "", nil when the file doesn't exist (instead of an error). +func ReadFileIfExists(ctx context.Context, storagePath, buildID, filename string) ([]byte, string, error) { + data, source, err := ReadFile(ctx, storagePath, buildID, filename) + if err != nil { + if isNotFoundError(err) { + return nil, "", nil + } + + return nil, "", err + } + + return data, source, nil +} + +// ReadCompressedHeader reads a v4 header file (e.g. "v4.memfile.header.lz4"), +// LZ4-block-decompresses it, and deserializes. +// Returns nil, "", nil when the v4 header doesn't exist. 
+func ReadCompressedHeader(ctx context.Context, storagePath, buildID, artifactName string) (*header.Header, string, error) { + filename := storage.V4HeaderName(artifactName) + data, source, err := ReadFileIfExists(ctx, storagePath, buildID, filename) + if err != nil { + return nil, "", fmt.Errorf("failed to read compressed header: %w", err) + } + if data == nil { + return nil, "", nil + } + + decompressed, err := storage.DecompressLZ4(data, storage.MaxCompressedHeaderSize) + if err != nil { + return nil, "", fmt.Errorf("failed to decompress LZ4 header from %s: %w", source, err) + } + + h, err := header.DeserializeBytes(decompressed) + if err != nil { + return nil, "", fmt.Errorf("failed to deserialize compressed header from %s: %w", source, err) + } + + return h, source, nil +} + +// FileInfo contains existence and size information about a file. +type FileInfo struct { + Name string + Path string + Exists bool + Size int64 + Metadata map[string]string // GCS custom metadata (nil for local files) +} + +// ProbeFile checks if a file exists and returns its info. 
+func ProbeFile(ctx context.Context, storagePath, buildID, filename string) FileInfo { + info := FileInfo{Name: filename} + + if IsGCSPath(storagePath) { + gcsPath := NormalizeGCSPath(storagePath) + "/" + buildID + "/" + filename + info.Path = gcsPath + + path := strings.TrimPrefix(gcsPath, "gs://") + parts := strings.SplitN(path, "/", 2) + if len(parts) != 2 { + return info + } + + client, err := gcsstorage.NewClient(ctx) + if err != nil { + return info + } + defer client.Close() + + attrs, err := client.Bucket(parts[0]).Object(parts[1]).Attrs(ctx) + if err != nil { + return info + } + + info.Exists = true + info.Size = attrs.Size + info.Metadata = attrs.Metadata + } else { + localPath := filepath.Join(storagePath, "templates", buildID, filename) + info.Path = localPath + + fi, err := os.Stat(localPath) + if err != nil { + return info + } + + info.Exists = true + info.Size = fi.Size() + } + + return info +} + +// isNotFoundError checks if an error indicates a file/object doesn't exist. +func isNotFoundError(err error) bool { + if os.IsNotExist(err) { + return true + } + + if errors.Is(err, gcsstorage.ErrObjectNotExist) { + return true + } + + return false +} + +// ListFiles lists all files for a build in storage. +// Returns FileInfo for each file found. 
+func ListFiles(ctx context.Context, storagePath, buildID string) ([]FileInfo, error) { + if IsGCSPath(storagePath) { + return listGCSFiles(ctx, storagePath, buildID) + } + + return listLocalFiles(storagePath, buildID) +} + +func listGCSFiles(ctx context.Context, storagePath, buildID string) ([]FileInfo, error) { + normalized := NormalizeGCSPath(storagePath) + bucket := ExtractBucketName(storagePath) + prefix := buildID + "/" + + client, err := gcsstorage.NewClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create GCS client: %w", err) + } + defer client.Close() + + var files []FileInfo + it := client.Bucket(bucket).Objects(ctx, &gcsstorage.Query{Prefix: prefix}) + + for { + attrs, err := it.Next() + if errors.Is(err, iterator.Done) { + break + } + if err != nil { + return nil, fmt.Errorf("failed to list objects: %w", err) + } + + name := strings.TrimPrefix(attrs.Name, prefix) + files = append(files, FileInfo{ + Name: name, + Path: normalized + "/" + attrs.Name, + Exists: true, + Size: attrs.Size, + Metadata: attrs.Metadata, + }) + } + + return files, nil +} + +func listLocalFiles(storagePath, buildID string) ([]FileInfo, error) { + dir := filepath.Join(storagePath, "templates", buildID) + + entries, err := os.ReadDir(dir) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + + return nil, fmt.Errorf("failed to read directory: %w", err) + } + + var files []FileInfo + for _, entry := range entries { + if entry.IsDir() { + continue + } + + fi, err := entry.Info() + if err != nil { + continue + } + + files = append(files, FileInfo{ + Name: entry.Name(), + Path: filepath.Join(dir, entry.Name()), + Exists: true, + Size: fi.Size(), + }) + } + + return files, nil +} diff --git a/packages/orchestrator/cmd/resume-build/main.go b/packages/orchestrator/cmd/resume-build/main.go index cfdf52c6af..c22ddd0e15 100644 --- a/packages/orchestrator/cmd/resume-build/main.go +++ b/packages/orchestrator/cmd/resume-build/main.go @@ -622,16 +622,20 @@ func (r 
*runner) pauseOnce(ctx context.Context, opts pauseOptions, verbose bool) // Only upload when not in benchmark mode (verbose = true means single run) if verbose { - templateFiles := storage.TemplateFiles{BuildID: opts.newBuildID} + tb, err := sandbox.NewTemplateBuild(snapshot, r.storage, storage.TemplateFiles{BuildID: opts.newBuildID}, nil, nil) + if err != nil { + return timings, fmt.Errorf("failed to create template build: %w", err) + } + if opts.isRemoteStorage { fmt.Println("📤 Uploading snapshot...") - if err := snapshot.Upload(ctx, r.storage, templateFiles); err != nil { + if err := tb.UploadAll(ctx); err != nil { return timings, fmt.Errorf("failed to upload snapshot: %w", err) } fmt.Println("✅ Snapshot uploaded successfully") } else { fmt.Println("💾 Saving snapshot to local storage...") - if err := snapshot.Upload(ctx, r.storage, templateFiles); err != nil { + if err := tb.UploadAll(ctx); err != nil { return timings, fmt.Errorf("failed to save snapshot: %w", err) } fmt.Println("✅ Snapshot saved successfully") diff --git a/packages/orchestrator/go.mod b/packages/orchestrator/go.mod index dddc5c5f02..21ed91d4e0 100644 --- a/packages/orchestrator/go.mod +++ b/packages/orchestrator/go.mod @@ -44,11 +44,13 @@ require ( github.com/hashicorp/consul/api v1.32.1 github.com/inetaf/tcpproxy v0.0.0-20250222171855-c4b9df066048 github.com/jellydator/ttlcache/v3 v3.4.0 + github.com/klauspost/compress v1.18.2 github.com/launchdarkly/go-sdk-common/v3 v3.3.0 github.com/launchdarkly/go-server-sdk/v7 v7.13.0 github.com/ngrok/firewall_toolkit v0.0.18 github.com/oapi-codegen/gin-middleware v1.0.2 github.com/oapi-codegen/runtime v1.1.1 + github.com/pierrec/lz4/v4 v4.1.22 github.com/pkg/errors v0.9.1 github.com/shirou/gopsutil/v4 v4.25.9 github.com/soheilhy/cmux v0.1.5 @@ -202,7 +204,6 @@ require ( github.com/hashicorp/serf v0.10.2 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/compress v1.18.2 // 
indirect github.com/klauspost/cpuid/v2 v2.2.11 // indirect github.com/klauspost/pgzip v1.2.6 // indirect github.com/launchdarkly/ccache v1.1.0 // indirect @@ -248,7 +249,6 @@ require ( github.com/paulmach/orb v0.11.1 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect - github.com/pierrec/lz4/v4 v4.1.22 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect diff --git a/packages/orchestrator/internal/sandbox/block/cache.go b/packages/orchestrator/internal/sandbox/block/cache.go index 8080841ccc..234f1f5634 100644 --- a/packages/orchestrator/internal/sandbox/block/cache.go +++ b/packages/orchestrator/internal/sandbox/block/cache.go @@ -8,7 +8,6 @@ import ( "math" "math/rand" "os" - "slices" "sync" "sync/atomic" "syscall" @@ -49,7 +48,7 @@ type Cache struct { blockSize int64 mmap *mmap.MMap mu sync.RWMutex - dirty sync.Map + dirty []atomic.Bool // indexed by off/blockSize — block is present and dirty dirtyFile bool closed atomic.Bool } @@ -87,12 +86,15 @@ func NewCache(size, blockSize int64, filePath string, dirtyFile bool) (*Cache, e return nil, fmt.Errorf("error mapping file: %w", err) } + numBlocks := (size + blockSize - 1) / blockSize + return &Cache{ mmap: &mm, filePath: filePath, size: size, blockSize: blockSize, dirtyFile: dirtyFile, + dirty: make([]atomic.Bool, numBlocks), }, nil } @@ -246,9 +248,11 @@ func (c *Cache) Slice(off, length int64) ([]byte, error) { } func (c *Cache) isCached(off, length int64) bool { - for _, blockOff := range header.BlocksOffsets(length, c.blockSize) { - _, dirty := c.dirty.Load(off + blockOff) - if !dirty { + startIdx := off / c.blockSize + endIdx := (off + length + c.blockSize - 1) / c.blockSize + + for idx := startIdx; idx < endIdx; idx++ { + if !c.dirty[idx].Load() { 
return false } } @@ -257,8 +261,11 @@ func (c *Cache) isCached(off, length int64) bool { } func (c *Cache) setIsCached(off, length int64) { - for _, blockOff := range header.BlocksOffsets(length, c.blockSize) { - c.dirty.Store(off+blockOff, struct{}{}) + startIdx := off / c.blockSize + endIdx := (off + length + c.blockSize - 1) / c.blockSize + + for idx := startIdx; idx < endIdx; idx++ { + c.dirty[idx].Store(true) } } @@ -281,16 +288,14 @@ func (c *Cache) WriteAtWithoutLock(b []byte, off int64) (int, error) { return n, nil } -// dirtySortedKeys returns a sorted list of dirty keys. -// Key represents a block offset. +// dirtySortedKeys returns a sorted list of dirty block offsets. func (c *Cache) dirtySortedKeys() []int64 { var keys []int64 - c.dirty.Range(func(key, _ any) bool { - keys = append(keys, key.(int64)) - - return true - }) - slices.Sort(keys) + for i := range c.dirty { + if c.dirty[i].Load() { + keys = append(keys, int64(i)*c.blockSize) + } + } return keys } @@ -481,9 +486,7 @@ func (c *Cache) copyProcessMemory( return fmt.Errorf("failed to read memory: expected %d bytes, got %d", segmentSize, n) } - for _, blockOff := range header.BlocksOffsets(segmentSize, c.blockSize) { - c.dirty.Store(offset+blockOff, struct{}{}) - } + c.setIsCached(offset, segmentSize) offset += segmentSize diff --git a/packages/orchestrator/internal/sandbox/block/chunk.go b/packages/orchestrator/internal/sandbox/block/chunk.go index f90c7d1feb..fb2027b951 100644 --- a/packages/orchestrator/internal/sandbox/block/chunk.go +++ b/packages/orchestrator/internal/sandbox/block/chunk.go @@ -4,125 +4,49 @@ import ( "context" "errors" "fmt" - "io" "go.opentelemetry.io/otel/attribute" - "go.uber.org/zap" "golang.org/x/sync/errgroup" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block/metrics" - featureflags "github.com/e2b-dev/infra/packages/shared/pkg/feature-flags" - "github.com/e2b-dev/infra/packages/shared/pkg/logger" 
"github.com/e2b-dev/infra/packages/shared/pkg/storage" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) -// Chunker is the interface satisfied by both FullFetchChunker and StreamingChunker. -type Chunker interface { - Slice(ctx context.Context, off, length int64) ([]byte, error) - ReadAt(ctx context.Context, b []byte, off int64) (int, error) - WriteTo(ctx context.Context, w io.Writer) (int64, error) - Close() error - FileSize() (int64, error) -} - -// NewChunker creates a Chunker based on the chunker-config feature flag. -// It reads the flag internally so callers don't need to parse flag values. -func NewChunker( - ctx context.Context, - featureFlags *featureflags.Client, - size, blockSize int64, - upstream storage.Seekable, - cachePath string, - metrics metrics.Metrics, -) (Chunker, error) { - useStreaming, minReadBatchSizeKB := getChunkerConfig(ctx, featureFlags) - - if useStreaming { - return NewStreamingChunker(size, blockSize, upstream, cachePath, metrics, int64(minReadBatchSizeKB)*1024, featureFlags) - } - - return NewFullFetchChunker(size, blockSize, upstream, cachePath, metrics) -} - -// getChunkerConfig fetches the chunker-config feature flag and returns the parsed values. -func getChunkerConfig(ctx context.Context, ff *featureflags.Client) (useStreaming bool, minReadBatchSizeKB int) { - value := ff.JSONFlag(ctx, featureflags.ChunkerConfigFlag) - - if v := value.GetByKey("useStreaming"); v.IsDefined() { - useStreaming = v.BoolValue() - } - - if v := value.GetByKey("minReadBatchSizeKB"); v.IsDefined() { - minReadBatchSizeKB = v.IntValue() - } - - return useStreaming, minReadBatchSizeKB -} - -type FullFetchChunker struct { - base storage.SeekableReader - cache *Cache - metrics metrics.Metrics - - size int64 - - // TODO: Optimize this so we don't need to keep the fetchers in memory. +// fullFetchChunker is a benchmark-only port of main's FullFetchChunker. 
+// It fetches aligned MemoryChunkSize (4 MB) chunks via GetFrame and uses +// WaitMap for dedup (one in-flight fetch per chunk offset). +type fullFetchChunker struct { + upstream storage.FramedFile + cache *Cache + metrics metrics.Metrics + size int64 fetchers *utils.WaitMap } -func NewFullFetchChunker( +func newFullFetchChunker( size, blockSize int64, - base storage.SeekableReader, + upstream storage.FramedFile, cachePath string, - metrics metrics.Metrics, -) (*FullFetchChunker, error) { + m metrics.Metrics, +) (*fullFetchChunker, error) { cache, err := NewCache(size, blockSize, cachePath, false) if err != nil { return nil, fmt.Errorf("failed to create file cache: %w", err) } - chunker := &FullFetchChunker{ + return &fullFetchChunker{ size: size, - base: base, + upstream: upstream, cache: cache, fetchers: utils.NewWaitMap(), - metrics: metrics, - } - - return chunker, nil + metrics: m, + }, nil } -func (c *FullFetchChunker) ReadAt(ctx context.Context, b []byte, off int64) (int, error) { - slice, err := c.Slice(ctx, off, int64(len(b))) - if err != nil { - return 0, fmt.Errorf("failed to slice cache at %d-%d: %w", off, off+int64(len(b)), err) - } - - return copy(b, slice), nil -} - -func (c *FullFetchChunker) WriteTo(ctx context.Context, w io.Writer) (int64, error) { - for i := int64(0); i < c.size; i += storage.MemoryChunkSize { - chunk := make([]byte, storage.MemoryChunkSize) - - n, err := c.ReadAt(ctx, chunk, i) - if err != nil { - return 0, fmt.Errorf("failed to slice cache at %d-%d: %w", i, i+storage.MemoryChunkSize, err) - } - - _, err = w.Write(chunk[:n]) - if err != nil { - return 0, fmt.Errorf("failed to write chunk %d to writer: %w", i, err) - } - } - - return c.size, nil -} - -func (c *FullFetchChunker) Slice(ctx context.Context, off, length int64) ([]byte, error) { - timer := c.metrics.SlicesTimerFactory.Begin() +func (c *fullFetchChunker) Slice(ctx context.Context, off, length int64) ([]byte, error) { + timer := c.metrics.BlocksTimerFactory.Begin() b, 
err := c.cache.Slice(off, length) if err == nil { @@ -132,7 +56,8 @@ func (c *FullFetchChunker) Slice(ctx context.Context, off, length int64) ([]byte return b, nil } - if !errors.As(err, &BytesNotAvailableError{}) { + var bytesNotAvailableError BytesNotAvailableError + if !errors.As(err, &bytesNotAvailableError) { timer.Failure(ctx, length, attribute.String(pullType, pullTypeLocal), attribute.String(failureReason, failureTypeLocalRead)) @@ -164,97 +89,64 @@ func (c *FullFetchChunker) Slice(ctx context.Context, off, length int64) ([]byte return b, nil } -// fetchToCache ensures that the data at the given offset and length is available in the cache. -func (c *FullFetchChunker) fetchToCache(ctx context.Context, off, length int64) error { +// fetchToCache ensures the MemoryChunkSize-aligned region(s) covering +// [off, off+length) are present in the cache. Uses WaitMap for dedup. +func (c *fullFetchChunker) fetchToCache(ctx context.Context, off, length int64) error { var eg errgroup.Group chunks := header.BlocksOffsets(length, storage.MemoryChunkSize) - startingChunk := header.BlockIdx(off, storage.MemoryChunkSize) startingChunkOffset := header.BlockOffset(startingChunk, storage.MemoryChunkSize) for _, chunkOff := range chunks { - // Ensure the closure captures the correct block offset. fetchOff := startingChunkOffset + chunkOff - eg.Go(func() (err error) { - defer func() { - if r := recover(); r != nil { - logger.L().Error(ctx, "recovered from panic in the fetch handler", zap.Any("error", r)) - err = fmt.Errorf("recovered from panic in the fetch handler: %v", r) - } - }() - - err = c.fetchers.Wait(fetchOff, func() error { + eg.Go(func() error { + return c.fetchers.Wait(fetchOff, func() error { select { case <-ctx.Done(): return fmt.Errorf("error fetching range %d-%d: %w", fetchOff, fetchOff+storage.MemoryChunkSize, ctx.Err()) default: } - // The size of the buffer is adjusted if the last chunk is not a multiple of the block size. 
- b, releaseCacheCloseLock, err := c.cache.addressBytes(fetchOff, storage.MemoryChunkSize) + b, releaseLock, err := c.cache.addressBytes(fetchOff, storage.MemoryChunkSize) if err != nil { return err } - - defer releaseCacheCloseLock() + defer releaseLock() fetchSW := c.metrics.RemoteReadsTimerFactory.Begin() - readBytes, err := c.base.ReadAt(ctx, b, fetchOff) - if err != nil { - fetchSW.Failure(ctx, int64(readBytes), - attribute.String(failureReason, failureTypeRemoteRead), - ) - - return fmt.Errorf("failed to read chunk from base %d: %w", fetchOff, err) + // Pass onRead + readSize identical to the branch Chunker so + // slowFrameGetter simulates the same bandwidth delay. + readSize := int64(defaultMinReadBatchSize) + onRead := func(totalWritten int64) { + c.cache.setIsCached(fetchOff, totalWritten) } - if readBytes != len(b) { - fetchSW.Failure(ctx, int64(readBytes), - attribute.String(failureReason, failureTypeRemoteRead), - ) + _, err = c.upstream.GetFrame(ctx, fetchOff, nil, false, b, readSize, onRead) + if err != nil { + fetchSW.Failure(ctx, int64(len(b)), + attribute.String(failureReason, failureTypeRemoteRead)) - return fmt.Errorf("failed to read chunk from base %d: expected %d bytes, got %d bytes", fetchOff, len(b), readBytes) + return fmt.Errorf("failed to read chunk from upstream at %d: %w", fetchOff, err) } - c.cache.setIsCached(fetchOff, int64(readBytes)) - - fetchSW.Success(ctx, int64(readBytes)) + c.cache.setIsCached(fetchOff, int64(len(b))) + fetchSW.Success(ctx, int64(len(b))) return nil }) - - return err }) } - err := eg.Wait() - if err != nil { + if err := eg.Wait(); err != nil { return fmt.Errorf("failed to ensure data at %d-%d: %w", off, off+length, err) } return nil } -func (c *FullFetchChunker) Close() error { +func (c *fullFetchChunker) Close() error { return c.cache.Close() } - -func (c *FullFetchChunker) FileSize() (int64, error) { - return c.cache.FileSize() -} - -const ( - pullType = "pull-type" - pullTypeLocal = "local" - pullTypeRemote 
= "remote" - - failureReason = "failure-reason" - - failureTypeLocalRead = "local-read" - failureTypeLocalReadAgain = "local-read-again" - failureTypeRemoteRead = "remote-read" - failureTypeCacheFetch = "cache-fetch" -) diff --git a/packages/orchestrator/internal/sandbox/block/chunk_bench_test.go b/packages/orchestrator/internal/sandbox/block/chunk_bench_test.go new file mode 100644 index 0000000000..28a147116e --- /dev/null +++ b/packages/orchestrator/internal/sandbox/block/chunk_bench_test.go @@ -0,0 +1,456 @@ +package block + +import ( + "context" + "fmt" + "math/rand/v2" + "testing" + "time" + + "github.com/launchdarkly/go-sdk-common/v3/ldvalue" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" +) + +// --------------------------------------------------------------------------- +// Benchmark constants & dimensions +// --------------------------------------------------------------------------- + +const benchDataSize = 100 * 1024 * 1024 // 100 MB + +var benchFrameSizes = []int{ + 1 * 1024 * 1024, // 1 MB + 2 * 1024 * 1024, // 2 MB + 4 * 1024 * 1024, // 4 MB (= MemoryChunkSize) +} + +var benchBlockSizes = []int64{ + 4 * 1024, // 4 KB — typical VM page fault + 2 * 1024 * 1024, // 2 MB — large sequential read +} + +// --------------------------------------------------------------------------- +// Backend profiles (simulated latency/bandwidth) +// --------------------------------------------------------------------------- + +type backendProfile struct { + name string + ttfb time.Duration + bandwidth int64 // bytes/sec +} + +var profiles = []backendProfile{ + {name: "GCS", ttfb: 50 * time.Millisecond, bandwidth: 100 * 1024 * 1024}, + {name: "NFS", ttfb: 1 * time.Millisecond, bandwidth: 500 * 1024 * 1024}, +} + +// --------------------------------------------------------------------------- +// Codec configurations +// 
--------------------------------------------------------------------------- + +type codecConfig struct { + name string + compressionType storage.CompressionType + level int +} + +var benchCodecs = []codecConfig{ + {name: "LZ4", compressionType: storage.CompressionLZ4, level: 0}, + {name: "Zstd1", compressionType: storage.CompressionZstd, level: 1}, + {name: "Zstd3", compressionType: storage.CompressionZstd, level: 3}, +} + +// --------------------------------------------------------------------------- +// Generic read function + setup types +// --------------------------------------------------------------------------- + +type benchReadFunc func(ctx context.Context, off, length int64) ([]byte, error) + +type coldSetup struct { + read benchReadFunc + close func() + fetchCount func() int64 + storeBytes int64 // compressed bytes transferred per iteration (= benchDataSize for uncompressed) +} + +// --------------------------------------------------------------------------- +// Shared helpers +// --------------------------------------------------------------------------- + +const benchWorkers = 4 + +func newBenchFlags(tb testing.TB) *MockFlagsClient { + tb.Helper() + + m := NewMockFlagsClient(tb) + m.EXPECT().JSONFlag(mock.Anything, mock.Anything).Return( + ldvalue.FromJSONMarshal(map[string]any{"minReadBatchSizeKB": 256}), + ).Maybe() + + return m +} + +func generateSemiRandomData(size int) []byte { + data := make([]byte, size) + rng := rand.New(rand.NewPCG(1, 2)) //nolint:gosec // deterministic for benchmarks + + // Random byte value repeated 1–16 times. Resembles real VM memory: + // mostly random with occasional short runs (zero-filled structs, padding). + // Kept short enough that compression stays under ~4x so frame count + // scales with TargetFrameSize without hitting DefaultMaxFrameUncompressedSize. 
+ i := 0 + for i < size { + runLen := rng.IntN(16) + 1 + if i+runLen > size { + runLen = size - i + } + b := byte(rng.IntN(256)) + for j := range runLen { + data[i+j] = b + } + i += runLen + } + + return data +} + +func newBenchChunker(tb testing.TB, assets AssetInfo, blockSize int64) *Chunker { + tb.Helper() + + c, err := NewChunker(assets, blockSize, tb.TempDir()+"/cache", newTestMetrics(tb), newBenchFlags(tb)) + require.NoError(tb, err) + + return c +} + +func newFullFetchBench(tb testing.TB, upstream storage.FramedFile, size, blockSize int64) *fullFetchChunker { + tb.Helper() + + c, err := newFullFetchChunker(size, blockSize, upstream, tb.TempDir()+"/cache", newTestMetrics(tb)) + require.NoError(tb, err) + + return c +} + +func shuffledOffsets(dataSize, blockSize int64) []int64 { + n := (dataSize + blockSize - 1) / blockSize + offsets := make([]int64, n) + for i := range offsets { + offsets[i] = int64(i) * blockSize + } + rng := rand.New(rand.NewPCG(42, 99)) //nolint:gosec // deterministic for benchmarks + rng.Shuffle(len(offsets), func(i, j int) { + offsets[i], offsets[j] = offsets[j], offsets[i] + }) + + return offsets +} + +func fmtSize(n int64) string { + switch { + case n >= 1024*1024: + return fmt.Sprintf("%dMB", n/(1024*1024)) + case n >= 1024: + return fmt.Sprintf("%dKB", n/1024) + default: + return fmt.Sprintf("%dB", n) + } +} + +func frameTableCompressedSize(ft *storage.FrameTable) int64 { + var total int64 + for _, f := range ft.Frames { + total += int64(f.C) + } + + return total +} + +func setCompressedAsset(a *AssetInfo, ct storage.CompressionType, file storage.FramedFile) { + switch ct { + case storage.CompressionLZ4: + a.HasLZ4 = true + a.LZ4 = file + case storage.CompressionZstd: + a.HasZstd = true + a.Zstd = file + } +} + +// --------------------------------------------------------------------------- +// Leaf runners +// --------------------------------------------------------------------------- + +// runColdLeaf runs a single cold-concurrent 
benchmark leaf (one profile, one +// blockSize, one mode). Each b.N iteration creates a fresh cold cache. +// +// Reported metrics (in addition to ns/op): +// - U-MB/op — uncompressed megabytes delivered per iteration (fixed) +// - U-MB/s — uncompressed throughput to the client +// - C-MB/op — compressed megabytes fetched from store per iteration +// - fetches/op — upstream fetch count (deduped) +func runColdLeaf(b *testing.B, data []byte, blockSize int64, profile backendProfile, newIter func(tb testing.TB, slow *slowFrameGetter, blockSize int64) coldSetup) { + b.Helper() + + dataSize := int64(len(data)) + offsets := shuffledOffsets(dataSize, blockSize) + b.ResetTimer() + + var totalElapsed time.Duration + var storeBytes int64 + + for range b.N { + b.StopTimer() + slow := &slowFrameGetter{data: data, ttfb: profile.ttfb, bandwidth: profile.bandwidth} + s := newIter(b, slow, blockSize) + storeBytes = s.storeBytes + b.StartTimer() + + start := time.Now() + + g, ctx := errgroup.WithContext(context.Background()) + for w := range benchWorkers { + g.Go(func() error { + for i := w; i < len(offsets); i += benchWorkers { + off := offsets[i] + length := min(blockSize, dataSize-off) + if _, err := s.read(ctx, off, length); err != nil { + return err + } + } + + return nil + }) + } + if err := g.Wait(); err != nil { + b.Fatal(err) + } + + totalElapsed += time.Since(start) + + b.StopTimer() + b.ReportMetric(float64(s.fetchCount()), "fetches/op") + s.close() + b.StartTimer() + } + + uMB := float64(dataSize) / (1024 * 1024) + cMB := float64(storeBytes) / (1024 * 1024) + + b.ReportMetric(uMB, "U-MB/op") + b.ReportMetric(cMB, "C-MB/op") + + if totalElapsed > 0 { + b.ReportMetric(uMB/(totalElapsed.Seconds()/float64(b.N)), "U-MB/s") + } +} + +// runCacheHitLeaf runs a single cache-hit benchmark leaf (one blockSize, one +// mode). Creates one chunker, warms the cache, then measures b.N reads. 
+func runCacheHitLeaf(b *testing.B, dataSize, blockSize int64, read benchReadFunc) { + b.Helper() + + ctx := context.Background() + for off := int64(0); off < dataSize; off += blockSize { + _, err := read(ctx, off, min(blockSize, dataSize-off)) + require.NoError(b, err) + } + + nOffsets := dataSize / blockSize + b.ResetTimer() + + for i := range b.N { + off := (int64(i) % nOffsets) * blockSize + if _, err := read(ctx, off, blockSize); err != nil { + b.Fatal(err) + } + } +} + +// --------------------------------------------------------------------------- +// BenchmarkCacheHit +// +// block=4KB/ +// +// Legacy +// Uncompressed +// +// block=2MB/ +// +// Legacy +// Uncompressed +// +// --------------------------------------------------------------------------- +func BenchmarkCacheHit(b *testing.B) { + data := generateSemiRandomData(benchDataSize) + dataSize := int64(len(data)) + + for _, blockSize := range benchBlockSizes { + b.Run(fmt.Sprintf("block=%s", fmtSize(blockSize)), func(b *testing.B) { + b.Run("Legacy", func(b *testing.B) { + getter := &slowFrameGetter{data: data} + c := newFullFetchBench(b, getter, dataSize, blockSize) + defer c.Close() + + runCacheHitLeaf(b, dataSize, blockSize, func(ctx context.Context, off, length int64) ([]byte, error) { + return c.Slice(ctx, off, length) + }) + }) + + b.Run("Uncompressed", func(b *testing.B) { + getter := &slowFrameGetter{data: data} + assets := AssetInfo{ + BasePath: "bench", + Size: dataSize, + HasUncompressed: true, + Uncompressed: getter, + } + c := newBenchChunker(b, assets, blockSize) + defer c.Close() + + runCacheHitLeaf(b, dataSize, blockSize, func(ctx context.Context, off, length int64) ([]byte, error) { + return c.GetBlock(ctx, off, length, nil) + }) + }) + }) + } +} + +// --------------------------------------------------------------------------- +// BenchmarkColdConcurrent +// +// GCS/ +// +// no-frame/ +// block=4KB/ +// Legacy +// Uncompressed +// frame=1MB/ +// block=4KB/ +// LZ4 +// Zstd1 +// Zstd3 +// 
+// NFS/ +// +// ... +// +// --------------------------------------------------------------------------- +func BenchmarkColdConcurrent(b *testing.B) { + data := generateSemiRandomData(benchDataSize) + dataSize := int64(len(data)) + + // Precompute frame tables so CompressBytes runs once per combo, not per profile. + type ftEntry struct { + ft *storage.FrameTable + } + type ftKey struct { + frameSize int + codecIdx int + } + + frameTables := make(map[ftKey]ftEntry) + + for _, frameSize := range benchFrameSizes { + for ci, codec := range benchCodecs { + _, ft, err := storage.CompressBytes(context.Background(), data, &storage.FramedUploadOptions{ + CompressionType: codec.compressionType, + Level: codec.level, + CompressionConcurrency: 1, + TargetFrameSize: frameSize, + MaxUncompressedFrameSize: storage.DefaultMaxFrameUncompressedSize, + TargetPartSize: 50 * 1024 * 1024, + }) + require.NoError(b, err) + + frameTables[ftKey{frameSize, ci}] = ftEntry{ft} + } + } + + legacyFactory := func(tb testing.TB, slow *slowFrameGetter, blockSize int64) coldSetup { + tb.Helper() + + c := newFullFetchBench(tb, slow, dataSize, blockSize) + + return coldSetup{ + read: func(ctx context.Context, off, length int64) ([]byte, error) { return c.Slice(ctx, off, length) }, + close: func() { c.Close() }, + fetchCount: func() int64 { return slow.fetchCount.Load() }, + storeBytes: benchDataSize, + } + } + + uncompressedFactory := func(tb testing.TB, slow *slowFrameGetter, blockSize int64) coldSetup { + tb.Helper() + + assets := AssetInfo{ + BasePath: "bench", + Size: dataSize, + HasUncompressed: true, + Uncompressed: slow, + } + c := newBenchChunker(tb, assets, blockSize) + + return coldSetup{ + read: func(ctx context.Context, off, length int64) ([]byte, error) { return c.GetBlock(ctx, off, length, nil) }, + close: func() { c.Close() }, + fetchCount: func() int64 { return slow.fetchCount.Load() }, + storeBytes: benchDataSize, + } + } + + for _, profile := range profiles { + b.Run(profile.name, 
func(b *testing.B) { + // Uncompressed: no-frame → block → {Legacy, Uncompressed} + b.Run("no-frame", func(b *testing.B) { + for _, blockSize := range benchBlockSizes { + b.Run(fmt.Sprintf("block=%s", fmtSize(blockSize)), func(b *testing.B) { + b.Run("Legacy", func(b *testing.B) { + runColdLeaf(b, data, blockSize, profile, legacyFactory) + }) + b.Run("Uncompressed", func(b *testing.B) { + runColdLeaf(b, data, blockSize, profile, uncompressedFactory) + }) + }) + } + }) + + // Compressed: frame → block → codec + for _, frameSize := range benchFrameSizes { + b.Run(fmt.Sprintf("frame=%s", fmtSize(int64(frameSize))), func(b *testing.B) { + for _, blockSize := range benchBlockSizes { + b.Run(fmt.Sprintf("block=%s", fmtSize(blockSize)), func(b *testing.B) { + for ci, codec := range benchCodecs { + ft := frameTables[ftKey{frameSize, ci}].ft + cBytes := frameTableCompressedSize(ft) + + b.Run(codec.name, func(b *testing.B) { + runColdLeaf(b, data, blockSize, profile, func(tb testing.TB, slow *slowFrameGetter, blockSize int64) coldSetup { + tb.Helper() + + assets := AssetInfo{ + BasePath: "bench", + Size: dataSize, + } + setCompressedAsset(&assets, codec.compressionType, slow) + c := newBenchChunker(tb, assets, blockSize) + + return coldSetup{ + read: func(ctx context.Context, off, length int64) ([]byte, error) { return c.GetBlock(ctx, off, length, ft) }, + close: func() { c.Close() }, + fetchCount: func() int64 { return slow.fetchCount.Load() }, + storeBytes: cBytes, + } + }) + }) + } + }) + } + }) + } + }) + } +} diff --git a/packages/orchestrator/internal/sandbox/block/chunk_framed.go b/packages/orchestrator/internal/sandbox/block/chunk_framed.go new file mode 100644 index 0000000000..5a7557b739 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/block/chunk_framed.go @@ -0,0 +1,388 @@ +package block + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/launchdarkly/go-sdk-common/v3/ldcontext" + 
"github.com/launchdarkly/go-sdk-common/v3/ldvalue" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block/metrics" + featureflags "github.com/e2b-dev/infra/packages/shared/pkg/feature-flags" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" + "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" +) + +const ( + compressedAttr = "compressed" + + // decompressFetchTimeout is the maximum time a single frame/chunk fetch may take. + decompressFetchTimeout = 60 * time.Second + + // defaultMinReadBatchSize is the floor for the read batch size when blockSize + // is very small (e.g. 4KB rootfs). The actual batch is max(blockSize, minReadBatchSize). + defaultMinReadBatchSize = 256 * 1024 // 256 KB +) + +// AssetInfo describes which storage variants exist for a build artifact. +type AssetInfo struct { + BasePath string // uncompressed path (e.g., "build-123/memfile") + Size int64 // uncompressed size (from either source) + HasUncompressed bool // true if the uncompressed object exists in storage + HasLZ4 bool // true if a .lz4 compressed variant exists + HasZstd bool // true if a .zstd compressed variant exists + + // Opened FramedFile handles — may be nil if the corresponding asset doesn't exist. + Uncompressed storage.FramedFile + LZ4 storage.FramedFile + Zstd storage.FramedFile +} + +// HasCompressed reports whether a compressed asset matching ft's type exists. +func (a *AssetInfo) HasCompressed(ft *storage.FrameTable) bool { + if ft == nil { + return false + } + + switch ft.CompressionType { + case storage.CompressionLZ4: + return a.HasLZ4 + case storage.CompressionZstd: + return a.HasZstd + default: + return false + } +} + +// CompressedFile returns the FramedFile for the compression type in ft, or nil. 
// CompressedFile returns the opened FramedFile handle matching ft's
// compression type, or nil when ft is nil or the type is unknown.
// Note: this returns the handle field without consulting the Has* flags;
// callers are expected to gate on HasCompressed first.
func (a *AssetInfo) CompressedFile(ft *storage.FrameTable) storage.FramedFile {
	if ft == nil {
		return nil
	}

	switch ft.CompressionType {
	case storage.CompressionLZ4:
		return a.LZ4
	case storage.CompressionZstd:
		return a.Zstd
	default:
		return nil
	}
}

// flagsClient is the subset of featureflags.Client used by Chunker.
// Extracted as an interface so benchmarks and tests can supply lightweight fakes.
type flagsClient interface {
	JSONFlag(ctx context.Context, flag featureflags.JSONFlag, ldctx ...ldcontext.Context) ldvalue.Value
}

// precomputedAttrs holds one pre-built metric attribute set per outcome of a
// block read, so the hot path records metrics without rebuilding attribute
// slices on every call.
type precomputedAttrs struct {
	successFromCache  metric.MeasurementOption
	successFromRemote metric.MeasurementOption

	failCacheRead      metric.MeasurementOption
	failRemoteFetch    metric.MeasurementOption
	failLocalReadAgain metric.MeasurementOption

	// begin is the single attribute passed to the timer factory at the start
	// of a read (compressed=true/false).
	begin attribute.KeyValue
}

// precomputeAttributes builds the full outcome→attributes table for one value
// of the compressed flag. Called exactly twice at package init (see vars below).
func precomputeAttributes(isCompressed bool) precomputedAttrs {
	compressed := attribute.Bool(compressedAttr, isCompressed)

	return precomputedAttrs{
		successFromCache: telemetry.PrecomputeAttrs(
			telemetry.Success, compressed,
			attribute.String(pullType, pullTypeLocal)),

		successFromRemote: telemetry.PrecomputeAttrs(
			telemetry.Success, compressed,
			attribute.String(pullType, pullTypeRemote)),

		failCacheRead: telemetry.PrecomputeAttrs(
			telemetry.Failure, compressed,
			attribute.String(pullType, pullTypeLocal),
			attribute.String(failureReason, failureTypeLocalRead)),

		failRemoteFetch: telemetry.PrecomputeAttrs(
			telemetry.Failure, compressed,
			attribute.String(pullType, pullTypeRemote),
			attribute.String(failureReason, failureTypeCacheFetch)),

		failLocalReadAgain: telemetry.PrecomputeAttrs(
			telemetry.Failure, compressed,
			attribute.String(pullType, pullTypeLocal),
			attribute.String(failureReason, failureTypeLocalReadAgain)),

		begin: compressed,
	}
}

// Both attribute tables are built once at init; attrs() selects between them.
var (
	precomputedCompressed   = precomputeAttributes(true)
	precomputedUncompressed = precomputeAttributes(false)
)

// attrs returns the precomputed attribute table for the given compressed flag.
func attrs(compressed bool) precomputedAttrs {
	if compressed {
		return precomputedCompressed
	}

	return precomputedUncompressed
}

// Chunker serves block reads from an mmap-backed cache, fetching missing
// ranges from storage (compressed or uncompressed) on demand. Concurrent
// misses over the same range are deduplicated via the sessions list.
type Chunker struct {
	assets AssetInfo

	cache   *Cache
	metrics metrics.Metrics
	flags   flagsClient

	// sessionsMu guards sessions; see getOrCreateFetchSession/releaseFetchSession.
	sessionsMu sync.Mutex
	sessions   []*fetchSession
}

var _ Reader = (*Chunker)(nil)

// NewChunker creates a Chunker backed by a new mmap cache at cachePath.
// assets.Size determines the cache size; blockSize its block granularity.
func NewChunker(
	assets AssetInfo,
	blockSize int64,
	cachePath string,
	m metrics.Metrics,
	flags flagsClient,
) (*Chunker, error) {
	cache, err := NewCache(assets.Size, blockSize, cachePath, false)
	if err != nil {
		return nil, fmt.Errorf("failed to create cache: %w", err)
	}

	return &Chunker{
		assets:  assets,
		cache:   cache,
		metrics: m,
		flags:   flags,
	}, nil
}

// ReadBlock copies the block at off into b and returns the number of bytes
// copied. It delegates to GetBlock; ft selects the compressed path (see GetBlock).
func (c *Chunker) ReadBlock(ctx context.Context, b []byte, off int64, ft *storage.FrameTable) (int, error) {
	block, err := c.GetBlock(ctx, off, int64(len(b)), ft)
	if err != nil {
		return 0, fmt.Errorf("failed to get block at %d-%d: %w", off, off+int64(len(b)), err)
	}

	return copy(b, block), nil
}

// GetBlock returns a reference to the mmap cache at the given uncompressed
// offset. On cache miss, fetches from storage into the cache first. A non-nil
// ft with a matching compressed asset selects the compressed fetch path.
func (c *Chunker) GetBlock(ctx context.Context, off, length int64, ft *storage.FrameTable) ([]byte, error) {
	compressed := c.assets.HasCompressed(ft)
	attrs := attrs(compressed)
	// Timer is started up front so every exit path below records exactly one
	// measurement with the matching outcome attributes.
	timer := c.metrics.BlocksTimerFactory.Begin(attrs.begin)

	// Fast path: the range is already in the mmap cache; attributes are
	// precomputed so this path allocates no attribute slices.
	b, err := c.cache.Slice(off, length)
	if err == nil {
		timer.Record(ctx, length, attrs.successFromCache)

		return b, nil
	}

	// Only a BytesNotAvailableError means "cache miss, go fetch"; anything
	// else is a real cache failure.
	var bytesNotAvailableError BytesNotAvailableError
	if !errors.As(err, &bytesNotAvailableError) {
		timer.Record(ctx, length, attrs.failCacheRead)

		return nil, fmt.Errorf("failed read from cache at offset %d: %w", off, err)
	}

	// Join (or start) the fetch session covering this offset.
	session, sessionErr := c.getOrCreateSession(ctx, off, ft, compressed)
	if sessionErr != nil {
		timer.Record(ctx, length, attrs.failRemoteFetch)

		return nil, sessionErr
	}

	// Block until the session has made [off, off+length) available (or failed).
	if err := session.registerAndWait(ctx, off, length); err != nil {
		timer.Record(ctx, length, attrs.failRemoteFetch)

		return nil, fmt.Errorf("failed to fetch data at %#x: %w", off, err)
	}

	// Re-read from the cache now that the fetch published the data.
	b, cacheErr := c.cache.Slice(off, length)
	if cacheErr != nil {
		timer.Record(ctx, length, attrs.failLocalReadAgain)

		return nil, fmt.Errorf("failed to read from cache after fetch at %d-%d: %w", off, off+length, cacheErr)
	}

	timer.Record(ctx, length, attrs.successFromRemote)

	return b, nil
}

// getOrCreateSession returns an existing session covering [off, off+length) or
// creates a new one. Session boundaries are frame-aligned for compressed
// requests and MemoryChunkSize-aligned for uncompressed requests.
//
// Deduplication is handled by the sessions list: if an active session's range
// contains the requested offset, the caller joins it instead of creating a
// new fetch.
func (c *Chunker) getOrCreateSession(ctx context.Context, off int64, ft *storage.FrameTable, useCompressed bool) (*fetchSession, error) {
	var (
		chunkOff   int64
		chunkLen   int64
		decompress bool
	)

	if useCompressed {
		// Compressed: the fetch unit is the whole frame containing off.
		frameStarts, frameSize, err := ft.FrameFor(off)
		if err != nil {
			return nil, fmt.Errorf("failed to get frame for offset %#x: %w", off, err)
		}

		chunkOff = frameStarts.U
		chunkLen = int64(frameSize.U)
		decompress = true
	} else {
		// Uncompressed: the fetch unit is a MemoryChunkSize-aligned chunk,
		// clamped at the end of the asset.
		chunkOff = (off / storage.MemoryChunkSize) * storage.MemoryChunkSize
		chunkLen = min(int64(storage.MemoryChunkSize), c.assets.Size-chunkOff)
		decompress = false
	}

	session, isNew := c.getOrCreateFetchSession(chunkOff, chunkLen)

	if isNew {
		// WithoutCancel: the fetch serves all waiters on this session, so it
		// must not die with the first caller's context.
		go c.runFetch(context.WithoutCancel(ctx), session, chunkOff, ft, decompress)
	}

	return session, nil
}

// runFetch fetches data from storage into the mmap cache. Runs in a background goroutine.
// Works for both compressed (decompress=true, ft!=nil) and uncompressed (decompress=false, ft=nil) paths.
func (c *Chunker) runFetch(ctx context.Context, s *fetchSession, offsetU int64, ft *storage.FrameTable, decompress bool) {
	ctx, cancel := context.WithTimeout(ctx, decompressFetchTimeout)
	defer cancel()

	// Remove session from active list after completion.
	defer c.releaseFetchSession(s)

	// A panic in the fetch path must not take down the process; convert it
	// to a session error so waiters are unblocked.
	defer func() {
		if r := recover(); r != nil {
			s.setError(fmt.Errorf("fetch panicked: %v", r), true)
		}
	}()

	// Get mmap region for the fetch target; the lock is held until the
	// fetch finishes writing into it.
	mmapSlice, releaseLock, err := c.cache.addressBytes(s.chunkOff, s.chunkLen)
	if err != nil {
		s.setError(err, false)

		return
	}
	defer releaseLock()

	fetchSW := c.metrics.RemoteReadsTimerFactory.Begin(
		attribute.Bool(compressedAttr, decompress),
	)

	// Compute read batch size from feature flag. This controls how frequently
	// onRead fires (progress granularity). Deliberately independent of blockSize
	// to avoid a broadcast-wake storm when blockSize is small.
	readSize := int64(defaultMinReadBatchSize)
	if v := c.flags.JSONFlag(ctx, featureflags.ChunkerConfigFlag).AsValueMap().Get("minReadBatchSizeKB"); v.IsNumber() {
		readSize = int64(v.IntValue()) * 1024
	}

	// Build onRead callback: publishes blocks to mmap cache and wakes waiters
	// as each readSize-aligned chunk arrives. totalWritten is cumulative, so
	// only the delta since the previous callback is newly marked as cached.
	var prevTotal int64
	onRead := func(totalWritten int64) {
		newBytes := totalWritten - prevTotal
		c.cache.setIsCached(s.chunkOff+prevTotal, newBytes)
		s.advance(totalWritten)
		prevTotal = totalWritten
	}

	// Select the storage handle matching the fetch mode.
	var handle storage.FramedFile
	if decompress {
		handle = c.assets.CompressedFile(ft)
	} else {
		handle = c.assets.Uncompressed
	}

	_, err = handle.GetFrame(ctx, offsetU, ft, decompress, mmapSlice[:s.chunkLen], readSize, onRead)
	if err != nil {
		fetchSW.Failure(ctx, s.chunkLen,
			attribute.String(failureReason, failureTypeRemoteRead))
		s.setError(fmt.Errorf("failed to fetch data at %#x: %w", offsetU, err), false)

		return
	}

	fetchSW.Success(ctx, s.chunkLen)
	s.setDone()
}

// Close releases the underlying mmap cache.
func (c *Chunker) Close() error {
	return c.cache.Close()
}

// FileSize reports the on-disk size of the cache file.
func (c *Chunker) FileSize() (int64, error) {
	return c.cache.FileSize()
}

// getOrCreateFetchSession returns an existing session whose range contains
// [off, off+length) or creates a new one. At most ~4-8 sessions are active at
// a time so a linear scan is sufficient.
func (c *Chunker) getOrCreateFetchSession(off, length int64) (*fetchSession, bool) {
	c.sessionsMu.Lock()
	defer c.sessionsMu.Unlock()

	for _, s := range c.sessions {
		if s.chunkOff <= off && s.chunkOff+s.chunkLen >= off+length {
			return s, false
		}
	}

	s := newFetchSession(off, length, c.cache.BlockSize(), c.cache.isCached)
	c.sessions = append(c.sessions, s)

	return s, true
}

// releaseFetchSession removes s from the active list (swap-delete).
+func (c *Chunker) releaseFetchSession(s *fetchSession) { + c.sessionsMu.Lock() + defer c.sessionsMu.Unlock() + + for i, a := range c.sessions { + if a == s { + c.sessions[i] = c.sessions[len(c.sessions)-1] + c.sessions[len(c.sessions)-1] = nil + c.sessions = c.sessions[:len(c.sessions)-1] + + return + } + } +} + +const ( + pullType = "pull-type" + pullTypeLocal = "local" + pullTypeRemote = "remote" + + failureReason = "failure-reason" + + failureTypeLocalRead = "local-read" + failureTypeLocalReadAgain = "local-read-again" + failureTypeRemoteRead = "remote-read" + failureTypeCacheFetch = "cache-fetch" +) diff --git a/packages/orchestrator/internal/sandbox/block/chunker_test.go b/packages/orchestrator/internal/sandbox/block/chunker_test.go new file mode 100644 index 0000000000..1199b0e653 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/block/chunker_test.go @@ -0,0 +1,970 @@ +package block + +import ( + "bytes" + "context" + "crypto/rand" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" +) + +// --------------------------------------------------------------------------- +// Test constants +// --------------------------------------------------------------------------- + +const ( + testFrameSize = 256 * 1024 // 256 KB per frame for fast tests + testFileSize = testFrameSize * 4 +) + +// --------------------------------------------------------------------------- +// Test fakes +// --------------------------------------------------------------------------- + +// slowFrameGetter implements storage.FramedFile for testing and benchmarks. +// Serves raw uncompressed data with optional latency (ttfb) and bandwidth +// simulation. Used as both the Uncompressed and compressed FramedFile handle +// (Chunker always passes decompress=true, so real decompression never happens). 
+type slowFrameGetter struct { + data []byte + ttfb time.Duration + bandwidth int64 // bytes/sec; 0 = instant + fetchCount atomic.Int64 +} + +var _ storage.FramedFile = (*slowFrameGetter)(nil) + +func (s *slowFrameGetter) Size(_ context.Context) (int64, error) { + return int64(len(s.data)), nil +} + +func (s *slowFrameGetter) StoreFile(context.Context, string, *storage.FramedUploadOptions) (*storage.FrameTable, error) { + panic("slowFrameGetter: StoreFile not used in tests") +} + +func (s *slowFrameGetter) GetFrame(_ context.Context, offsetU int64, _ *storage.FrameTable, _ bool, buf []byte, readSize int64, onRead func(int64)) (storage.Range, error) { + s.fetchCount.Add(1) + + if s.ttfb > 0 { + time.Sleep(s.ttfb) + } + + end := min(offsetU+int64(len(buf)), int64(len(s.data))) + n := copy(buf, s.data[offsetU:end]) + + // Progressive delivery with optional bandwidth simulation. + if onRead != nil { + batch := readSize + if batch <= 0 { + batch = int64(n) + } + + for written := batch; written <= int64(n); written += batch { + if s.bandwidth > 0 { + delay := time.Duration(float64(batch) / float64(s.bandwidth) * float64(time.Second)) + time.Sleep(delay) + } + onRead(written) + } + if int64(n)%batch != 0 { + tail := int64(n) % batch + if s.bandwidth > 0 { + delay := time.Duration(float64(tail) / float64(s.bandwidth) * float64(time.Second)) + time.Sleep(delay) + } + onRead(int64(n)) + } + } + + return storage.Range{Start: offsetU, Length: n}, nil +} + +// makeCompressedTestData builds a synthetic FrameTable with testFrameSize +// boundaries and a slowFrameGetter that serves the original data. The C sizes +// are set equal to U sizes since Chunker only uses U-space values. 
+func makeCompressedTestData(tb testing.TB, data []byte, ttfb time.Duration) (*storage.FrameTable, *slowFrameGetter) { + tb.Helper() + + ft := &storage.FrameTable{CompressionType: storage.CompressionLZ4} + for off := 0; off < len(data); off += testFrameSize { + u := int32(min(testFrameSize, len(data)-off)) + ft.Frames = append(ft.Frames, storage.FrameSize{U: u, C: u}) + } + + return ft, &slowFrameGetter{data: data, ttfb: ttfb} +} + +// testProgressiveStorage implements storage.FramedFile with progressive +// batch delivery and injectable faults. Used by the ported progressive tests. +type testProgressiveStorage struct { + data []byte + batchDelay time.Duration // delay between onRead callbacks + failAfter int64 // absolute U-offset to error at (-1 = disabled) + panicAfter int64 // absolute U-offset to panic at (-1 = disabled) + gate chan struct{} // if non-nil, GetFrame blocks until closed + fetchCount atomic.Int64 +} + +var _ storage.FramedFile = (*testProgressiveStorage)(nil) + +func (p *testProgressiveStorage) Size(_ context.Context) (int64, error) { + return int64(len(p.data)), nil +} + +func (p *testProgressiveStorage) StoreFile(_ context.Context, _ string, _ *storage.FramedUploadOptions) (*storage.FrameTable, error) { + return nil, fmt.Errorf("testProgressiveStorage: StoreFile not supported") +} + +func (p *testProgressiveStorage) GetFrame(_ context.Context, offsetU int64, ft *storage.FrameTable, _ bool, buf []byte, readSize int64, onRead func(int64)) (storage.Range, error) { + p.fetchCount.Add(1) + + if p.gate != nil { + <-p.gate + } + + // Determine the copy region. 
+ var srcStart, srcEnd int64 + if ft != nil { + starts, size, err := ft.FrameFor(offsetU) + if err != nil { + return storage.Range{}, fmt.Errorf("testProgressiveStorage: %w", err) + } + srcStart = starts.U + srcEnd = min(starts.U+int64(size.U), int64(len(p.data))) + } else { + srcStart = offsetU + srcEnd = min(offsetU+int64(len(buf)), int64(len(p.data))) + } + + batchSize := int64(testBlockSize) + if readSize > 0 { + batchSize = readSize + } + + var written int64 + for pos := srcStart; pos < srcEnd; pos += batchSize { + end := min(pos+batchSize, srcEnd) + relStart := pos - srcStart + relEnd := end - srcStart + + // Check fault injection before each batch. + if p.panicAfter >= 0 && pos >= p.panicAfter { + panic("simulated upstream panic") + } + if p.failAfter >= 0 && pos >= p.failAfter { + // Notify what we have so far, then error. + if onRead != nil && written > 0 { + onRead(written) + } + + return storage.Range{Start: srcStart, Length: int(written)}, fmt.Errorf("simulated upstream error at offset %d", pos) + } + + copy(buf[relStart:relEnd], p.data[pos:end]) + written = relEnd + + if p.batchDelay > 0 { + time.Sleep(p.batchDelay) + } + + if onRead != nil { + onRead(written) + } + } + + return storage.Range{Start: srcStart, Length: int(written)}, nil +} + +// --------------------------------------------------------------------------- +// Test case helpers +// --------------------------------------------------------------------------- + +type chunkerTestCase struct { + name string + newChunker func(t *testing.T, data []byte, delay time.Duration) (*Chunker, *storage.FrameTable) +} + +func allChunkerTestCases() []chunkerTestCase { + return []chunkerTestCase{ + { + name: "Chunker_Compressed", + newChunker: func(t *testing.T, data []byte, delay time.Duration) (*Chunker, *storage.FrameTable) { + t.Helper() + ft, getter := makeCompressedTestData(t, data, delay) + c, err := NewChunker( + AssetInfo{ + BasePath: "test-object", + Size: int64(len(data)), + HasLZ4: true, + 
Uncompressed: getter, + LZ4: getter, + }, + testBlockSize, + t.TempDir()+"/cache", + newTestMetrics(t), + newTestFlags(t), + ) + require.NoError(t, err) + + return c, ft + }, + }, + { + name: "Chunker_Uncompressed", + newChunker: func(t *testing.T, data []byte, delay time.Duration) (*Chunker, *storage.FrameTable) { + t.Helper() + getter := &slowFrameGetter{data: data, ttfb: delay} + c, err := NewChunker( + AssetInfo{ + BasePath: "test-object", + Size: int64(len(data)), + HasUncompressed: true, + Uncompressed: getter, + }, + testBlockSize, + t.TempDir()+"/cache", + newTestMetrics(t), + newTestFlags(t), + ) + require.NoError(t, err) + + return c, nil + }, + }, + } +} + +// --------------------------------------------------------------------------- +// Concurrency tests (from chunker_concurrency_test.go) +// --------------------------------------------------------------------------- + +func TestChunker_ConcurrentSameOffset(t *testing.T) { + t.Parallel() + + for _, tc := range allChunkerTestCases() { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data := makeTestData(t, testFileSize) + chunker, ft := tc.newChunker(t, data, 100*time.Microsecond) + defer chunker.Close() + + const numGoroutines = 20 + off := int64(0) + readLen := int64(testBlockSize) + + results := make([][]byte, numGoroutines) + var eg errgroup.Group + + for i := range numGoroutines { + eg.Go(func() error { + slice, err := chunker.GetBlock(t.Context(), off, readLen, ft) + if err != nil { + return fmt.Errorf("goroutine %d: %w", i, err) + } + results[i] = make([]byte, len(slice)) + copy(results[i], slice) + + return nil + }) + } + + require.NoError(t, eg.Wait()) + + for i := range numGoroutines { + assert.Equal(t, data[off:off+readLen], results[i], + "goroutine %d got wrong data", i) + } + }) + } +} + +func TestChunker_ConcurrentDifferentOffsets(t *testing.T) { + t.Parallel() + + for _, tc := range allChunkerTestCases() { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data := 
makeTestData(t, testFileSize) + chunker, ft := tc.newChunker(t, data, 50*time.Microsecond) + defer chunker.Close() + + const numGoroutines = 10 + readLen := int64(testBlockSize) + + // Pick offsets spread across the file. + offsets := make([]int64, numGoroutines) + for i := range numGoroutines { + offsets[i] = int64(i) * readLen + if offsets[i]+readLen > int64(len(data)) { + offsets[i] = 0 + } + } + + results := make([][]byte, numGoroutines) + var eg errgroup.Group + + for i := range numGoroutines { + eg.Go(func() error { + slice, err := chunker.GetBlock(t.Context(), offsets[i], readLen, ft) + if err != nil { + return fmt.Errorf("goroutine %d (off=%d): %w", i, offsets[i], err) + } + results[i] = make([]byte, len(slice)) + copy(results[i], slice) + + return nil + }) + } + + require.NoError(t, eg.Wait()) + + for i := range numGoroutines { + assert.Equal(t, data[offsets[i]:offsets[i]+readLen], results[i], + "goroutine %d got wrong data", i) + } + }) + } +} + +func TestChunker_ConcurrentMixed(t *testing.T) { + t.Parallel() + + for _, tc := range allChunkerTestCases() { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data := makeTestData(t, testFileSize) + chunker, ft := tc.newChunker(t, data, 50*time.Microsecond) + defer chunker.Close() + + // Mix of ReadBlock, GetBlock, and repeated same-offset reads. 
+ const numGoroutines = 15 + readLen := int64(testBlockSize) + + var eg errgroup.Group + + for i := range numGoroutines { + off := int64((i % 4) * testBlockSize) // 4 distinct offsets + eg.Go(func() error { + if i%2 == 0 { + // GetBlock path + slice, err := chunker.GetBlock(t.Context(), off, readLen, ft) + if err != nil { + return fmt.Errorf("goroutine %d GetBlock: %w", i, err) + } + if !bytes.Equal(data[off:off+readLen], slice) { + return fmt.Errorf("goroutine %d GetBlock: data mismatch at off=%d", i, off) + } + } else { + // ReadBlock path + buf := make([]byte, readLen) + n, err := chunker.ReadBlock(t.Context(), buf, off, ft) + if err != nil { + return fmt.Errorf("goroutine %d ReadBlock: %w", i, err) + } + if !bytes.Equal(data[off:off+int64(n)], buf[:n]) { + return fmt.Errorf("goroutine %d ReadBlock: data mismatch at off=%d", i, off) + } + } + + return nil + }) + } + + require.NoError(t, eg.Wait()) + }) + } +} + +func TestChunker_ConcurrentStress(t *testing.T) { + t.Parallel() + + for _, tc := range allChunkerTestCases() { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data := makeTestData(t, testFileSize) + chunker, ft := tc.newChunker(t, data, 0) // no delay for stress + defer chunker.Close() + + const numGoroutines = 50 + const opsPerGoroutine = 5 + readLen := int64(testBlockSize) + + var eg errgroup.Group + + for i := range numGoroutines { + eg.Go(func() error { + for j := range opsPerGoroutine { + off := int64(((i*opsPerGoroutine)+j)%(len(data)/int(readLen))) * readLen + slice, err := chunker.GetBlock(t.Context(), off, readLen, ft) + if err != nil { + return fmt.Errorf("goroutine %d op %d: %w", i, j, err) + } + if !bytes.Equal(data[off:off+readLen], slice) { + return fmt.Errorf("goroutine %d op %d: data mismatch at off=%d", i, j, off) + } + } + + return nil + }) + } + + require.NoError(t, eg.Wait()) + }) + } +} + +func TestChunker_ConcurrentReadBlock_CrossFrame(t *testing.T) { + t.Parallel() + + // Test cross-frame ReadBlock for both compressed and 
uncompressed modes. + for _, tc := range allChunkerTestCases() { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data := makeTestData(t, testFileSize) + chunker, ft := tc.newChunker(t, data, 50*time.Microsecond) + defer chunker.Close() + + const numGoroutines = 10 + + // Read spanning multiple blocks/frames. + readLen := testBlockSize * 2 + if int64(readLen) > int64(len(data)) { + readLen = len(data) + } + + var eg errgroup.Group + + for i := range numGoroutines { + off := int64(0) // all read from start + eg.Go(func() error { + buf := make([]byte, readLen) + n, err := chunker.ReadBlock(t.Context(), buf, off, ft) + if err != nil { + return fmt.Errorf("goroutine %d: %w", i, err) + } + if !bytes.Equal(data[off:off+int64(n)], buf[:n]) { + return fmt.Errorf("goroutine %d: data mismatch", i) + } + + return nil + }) + } + + require.NoError(t, eg.Wait()) + }) + } +} + +// TestChunker_FetchDedup verifies that concurrent requests for the same data +// don't cause duplicate upstream fetches. +func TestChunker_FetchDedup(t *testing.T) { + t.Parallel() + + t.Run("DecompressMMapChunker_Compressed", func(t *testing.T) { + t.Parallel() + + data := make([]byte, testFileSize) + _, err := rand.Read(data) + require.NoError(t, err) + + ft, getter := makeCompressedTestData(t, data, 10*time.Millisecond) + + chunker, err := NewChunker( + AssetInfo{ + BasePath: "test-object", + Size: int64(len(data)), + HasLZ4: true, + Uncompressed: getter, + LZ4: getter, + }, + testBlockSize, + t.TempDir()+"/cache", + newTestMetrics(t), + newTestFlags(t), + ) + require.NoError(t, err) + defer chunker.Close() + + const numGoroutines = 10 + + var eg errgroup.Group + for range numGoroutines { + eg.Go(func() error { + // All request offset 0 (same frame). + _, err := chunker.GetBlock(t.Context(), 0, testBlockSize, ft) + + return err + }) + } + require.NoError(t, eg.Wait()) + + // With frameFlight dedup, only 1 fetch should have happened. 
+ assert.Equal(t, int64(1), getter.fetchCount.Load(), + "expected 1 fetch (dedup), got %d", getter.fetchCount.Load()) + }) +} + +// TestChunker_DualMode_SharedCache verifies that a single chunker +// instance correctly serves both compressed and uncompressed callers, sharing +// the mmap cache across modes. If region X is fetched via compressed path, +// a subsequent uncompressed request for region X is served from cache (no fetch). +func TestChunker_DualMode_SharedCache(t *testing.T) { + t.Parallel() + + data := makeTestData(t, testFileSize) + ft, getter := makeCompressedTestData(t, data, 0) + + // Create ONE chunker with both compressed and uncompressed assets available. + chunker, err := NewChunker( + AssetInfo{ + BasePath: "test-object", + Size: int64(len(data)), + HasLZ4: true, + HasUncompressed: true, + Uncompressed: getter, + LZ4: getter, + }, + testBlockSize, + t.TempDir()+"/cache", + newTestMetrics(t), + newTestFlags(t), + ) + require.NoError(t, err) + defer chunker.Close() + + readLen := int64(testBlockSize) + + // --- Phase 1: Compressed caller fetches frame 0 --- + slice1, err := chunker.GetBlock(t.Context(), 0, readLen, ft) + require.NoError(t, err) + assert.Equal(t, data[0:readLen], slice1, "compressed read: data mismatch at offset 0") + + fetchesAfterPhase1 := getter.fetchCount.Load() + assert.Equal(t, int64(1), fetchesAfterPhase1, "expected 1 fetch for frame 0") + + // --- Phase 2: Uncompressed caller reads offset 0 — should be served from cache --- + slice2, err := chunker.GetBlock(t.Context(), 0, readLen, nil) + require.NoError(t, err) + assert.Equal(t, data[0:readLen], slice2, "uncompressed read from cache: data mismatch at offset 0") + + // No new fetches should have occurred. 
+ assert.Equal(t, fetchesAfterPhase1, getter.fetchCount.Load(), + "uncompressed read of cached region should not trigger any fetch") + + // --- Phase 3: Uncompressed caller reads a new region (frame 1) --- + frame1Off := int64(testFrameSize) // start of frame 1 + slice3, err := chunker.GetBlock(t.Context(), frame1Off, readLen, nil) + require.NoError(t, err) + assert.Equal(t, data[frame1Off:frame1Off+readLen], slice3, + "uncompressed read: data mismatch at frame 1") + + // This should have triggered a new fetch via GetFrame (uncompressed path). + assert.Greater(t, getter.fetchCount.Load(), fetchesAfterPhase1, + "new region should trigger a fetch") + fetchesAfterPhase3 := getter.fetchCount.Load() + + // --- Phase 4: Compressed caller reads frame 1 — should be served from cache --- + slice4, err := chunker.GetBlock(t.Context(), frame1Off, readLen, ft) + require.NoError(t, err) + assert.Equal(t, data[frame1Off:frame1Off+readLen], slice4, + "compressed read from cache: data mismatch at frame 1") + + // No new fetches for frame 1. + assert.Equal(t, fetchesAfterPhase3, getter.fetchCount.Load(), + "compressed read of cached region should not trigger new fetch") +} + +// --------------------------------------------------------------------------- +// Progressive delivery tests (ported from main's streaming_chunk_test.go) +// --------------------------------------------------------------------------- + +// TestChunker_BasicGetBlock is a simple smoke test: read one block at offset 0. 
+func TestChunker_BasicGetBlock(t *testing.T) { + t.Parallel() + + for _, tc := range allChunkerTestCases() { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data := makeTestData(t, testFileSize) + chunker, ft := tc.newChunker(t, data, 0) + defer chunker.Close() + + slice, err := chunker.GetBlock(t.Context(), 0, testBlockSize, ft) + require.NoError(t, err) + require.Equal(t, data[:testBlockSize], slice) + }) + } +} + +// TestChunker_FullChunkCachedAfterPartialRequest verifies that requesting the +// first block triggers a full background fetch of the entire chunk/frame, so +// the last block becomes available without additional upstream fetches. +func TestChunker_FullChunkCachedAfterPartialRequest(t *testing.T) { + t.Parallel() + + t.Run("Compressed", func(t *testing.T) { + t.Parallel() + + data := makeTestData(t, testFileSize) + ft, getter := makeCompressedTestData(t, data, 0) + + chunker, err := NewChunker( + AssetInfo{ + BasePath: "test-object", + Size: int64(len(data)), + HasLZ4: true, + Uncompressed: getter, + LZ4: getter, + }, + testBlockSize, + t.TempDir()+"/cache", + newTestMetrics(t), + newTestFlags(t), + ) + require.NoError(t, err) + defer chunker.Close() + + // Request only the FIRST block (triggers fetch of entire frame). + _, err = chunker.GetBlock(t.Context(), 0, testBlockSize, ft) + require.NoError(t, err) + + // The entire frame should now be cached. The last block of frame 0 + // should be available without triggering an additional fetch. 
+ lastBlockInFrame := int64(testFrameSize) - testBlockSize + require.Eventually(t, func() bool { + slice, err := chunker.GetBlock(t.Context(), lastBlockInFrame, testBlockSize, ft) + if err != nil { + return false + } + + return bytes.Equal(data[lastBlockInFrame:lastBlockInFrame+testBlockSize], slice) + }, 5*time.Second, 10*time.Millisecond) + + assert.Equal(t, int64(1), getter.fetchCount.Load(), + "expected 1 fetch (full frame cached in background), got %d", getter.fetchCount.Load()) + }) + + t.Run("Uncompressed", func(t *testing.T) { + t.Parallel() + + data := makeTestData(t, storage.MemoryChunkSize) + getter := &slowFrameGetter{data: data} + + chunker, err := NewChunker( + AssetInfo{ + BasePath: "test-object", + Size: int64(len(data)), + HasUncompressed: true, + Uncompressed: getter, + }, + testBlockSize, + t.TempDir()+"/cache", + newTestMetrics(t), + newTestFlags(t), + ) + require.NoError(t, err) + defer chunker.Close() + + _, err = chunker.GetBlock(t.Context(), 0, testBlockSize, nil) + require.NoError(t, err) + + lastOff := int64(storage.MemoryChunkSize) - testBlockSize + require.Eventually(t, func() bool { + slice, err := chunker.GetBlock(t.Context(), lastOff, testBlockSize, nil) + if err != nil { + return false + } + + return bytes.Equal(data[lastOff:lastOff+testBlockSize], slice) + }, 5*time.Second, 10*time.Millisecond) + + assert.Equal(t, int64(1), getter.fetchCount.Load(), + "expected 1 fetch (full chunk cached in background), got %d", getter.fetchCount.Load()) + }) +} + +// TestChunker_EarlyReturn verifies progressive delivery: earlier offsets +// complete before later offsets within the same chunk. 
+func TestChunker_EarlyReturn(t *testing.T) { + t.Parallel() + + data := makeTestData(t, testFileSize) + gate := make(chan struct{}) + + getter := &testProgressiveStorage{ + data: data, + batchDelay: 50 * time.Microsecond, + failAfter: -1, + panicAfter: -1, + gate: gate, + } + + chunker, err := NewChunker( + AssetInfo{ + BasePath: "test-object", + Size: int64(len(data)), + HasUncompressed: true, + Uncompressed: getter, + }, + testBlockSize, + t.TempDir()+"/cache", + newTestMetrics(t), + newTestFlags(t), + ) + require.NoError(t, err) + defer chunker.Close() + + // Request blocks at different offsets, recording completion order. + var mu sync.Mutex + var order []int64 + + offsets := []int64{ + 0, + int64(testFileSize/2) - testBlockSize, + int64(testFileSize) - testBlockSize, + } + + var eg errgroup.Group + for _, off := range offsets { + eg.Go(func() error { + _, err := chunker.GetBlock(t.Context(), off, testBlockSize, nil) + if err != nil { + return err + } + + mu.Lock() + order = append(order, off) + mu.Unlock() + + return nil + }) + } + + // Let the goroutines register, then release the gate. + time.Sleep(5 * time.Millisecond) + close(gate) + + require.NoError(t, eg.Wait()) + + // The first offset should complete first (progressive delivery). + require.Len(t, order, 3) + assert.Equal(t, int64(0), order[0], + "expected offset 0 to complete first, got order: %v", order) +} + +// TestChunker_ErrorKeepsPartialData verifies that an upstream error at the +// midpoint of a chunk still allows data before the error to be served. 
+func TestChunker_ErrorKeepsPartialData(t *testing.T) { + t.Parallel() + + data := makeTestData(t, testFileSize) + + getter := &testProgressiveStorage{ + data: data, + failAfter: int64(testFileSize / 2), + panicAfter: -1, + } + + chunker, err := NewChunker( + AssetInfo{ + BasePath: "test-object", + Size: int64(len(data)), + HasUncompressed: true, + Uncompressed: getter, + }, + testBlockSize, + t.TempDir()+"/cache", + newTestMetrics(t), + newTestFlags(t), + ) + require.NoError(t, err) + defer chunker.Close() + + // Request the last block — should fail because upstream dies at midpoint. + lastOff := int64(testFileSize) - testBlockSize + _, err = chunker.GetBlock(t.Context(), lastOff, testBlockSize, nil) + require.Error(t, err) + + // First block (within the first half) should still be cached and servable. + slice, err := chunker.GetBlock(t.Context(), 0, testBlockSize, nil) + require.NoError(t, err) + require.Equal(t, data[:testBlockSize], slice) +} + +// TestChunker_ContextCancellation verifies that a cancelled caller context +// doesn't kill the background fetch — another caller can still get data. +func TestChunker_ContextCancellation(t *testing.T) { + t.Parallel() + + data := makeTestData(t, testFileSize) + + getter := &testProgressiveStorage{ + data: data, + batchDelay: 100 * time.Microsecond, + failAfter: -1, + panicAfter: -1, + } + + chunker, err := NewChunker( + AssetInfo{ + BasePath: "test-object", + Size: int64(len(data)), + HasUncompressed: true, + Uncompressed: getter, + }, + testBlockSize, + t.TempDir()+"/cache", + newTestMetrics(t), + newTestFlags(t), + ) + require.NoError(t, err) + defer chunker.Close() + + // Request with a short-lived context — should fail. + ctx, cancel := context.WithTimeout(t.Context(), 1*time.Millisecond) + defer cancel() + + lastOff := int64(testFileSize) - testBlockSize + _, err = chunker.GetBlock(ctx, lastOff, testBlockSize, nil) + require.Error(t, err) + + // Wait for the background fetch to complete. 
+ time.Sleep(200 * time.Millisecond) + + // Another caller with a valid context should still get the data. + slice, err := chunker.GetBlock(t.Context(), 0, testBlockSize, nil) + require.NoError(t, err) + require.Equal(t, data[:testBlockSize], slice) +} + +// TestChunker_LastBlockPartial verifies correct handling of a file whose size +// is not aligned to blockSize — the final block is shorter than blockSize. +func TestChunker_LastBlockPartial(t *testing.T) { + t.Parallel() + + // File size not aligned to blockSize. + size := testFileSize - 100 + data := makeTestData(t, size) + + for _, tc := range []chunkerTestCase{ + { + name: "Uncompressed", + newChunker: func(t *testing.T, data []byte, _ time.Duration) (*Chunker, *storage.FrameTable) { + t.Helper() + getter := &slowFrameGetter{data: data} + c, err := NewChunker( + AssetInfo{ + BasePath: "test-object", + Size: int64(len(data)), + HasUncompressed: true, + Uncompressed: getter, + }, + testBlockSize, + t.TempDir()+"/cache", + newTestMetrics(t), + newTestFlags(t), + ) + require.NoError(t, err) + + return c, nil + }, + }, + { + name: "Compressed", + newChunker: func(t *testing.T, data []byte, _ time.Duration) (*Chunker, *storage.FrameTable) { + t.Helper() + ft, getter := makeCompressedTestData(t, data, 0) + c, err := NewChunker( + AssetInfo{ + BasePath: "test-object", + Size: int64(len(data)), + HasLZ4: true, + Uncompressed: getter, + LZ4: getter, + }, + testBlockSize, + t.TempDir()+"/cache", + newTestMetrics(t), + newTestFlags(t), + ) + require.NoError(t, err) + + return c, ft + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + localData := make([]byte, len(data)) + copy(localData, data) + + chunker, ft := tc.newChunker(t, localData, 0) + defer chunker.Close() + + // Read the last partial block. 
+ lastBlockOff := (int64(size) / testBlockSize) * testBlockSize + remaining := int64(size) - lastBlockOff + + slice, err := chunker.GetBlock(t.Context(), lastBlockOff, remaining, ft) + require.NoError(t, err) + require.Equal(t, localData[lastBlockOff:], slice) + }) + } +} + +// TestChunker_PanicRecovery verifies that an upstream panic is recovered and +// converted to an error. Data before the panic point remains servable. +func TestChunker_PanicRecovery(t *testing.T) { + t.Parallel() + + data := makeTestData(t, testFileSize) + panicAt := int64(testFileSize / 2) + + getter := &testProgressiveStorage{ + data: data, + panicAfter: panicAt, + failAfter: -1, + } + + chunker, err := NewChunker( + AssetInfo{ + BasePath: "test-object", + Size: int64(len(data)), + HasUncompressed: true, + Uncompressed: getter, + }, + testBlockSize, + t.TempDir()+"/cache", + newTestMetrics(t), + newTestFlags(t), + ) + require.NoError(t, err) + defer chunker.Close() + + // Request data past the panic point — should get an error, not hang or crash. + lastOff := int64(testFileSize) - testBlockSize + _, err = chunker.GetBlock(t.Context(), lastOff, testBlockSize, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "panicked") + + // Data before the panic point should still be cached. 
+ slice, err := chunker.GetBlock(t.Context(), 0, testBlockSize, nil) + require.NoError(t, err) + require.Equal(t, data[:testBlockSize], slice) +} diff --git a/packages/orchestrator/internal/sandbox/block/chunker_test_helpers_test.go b/packages/orchestrator/internal/sandbox/block/chunker_test_helpers_test.go new file mode 100644 index 0000000000..d33253347b --- /dev/null +++ b/packages/orchestrator/internal/sandbox/block/chunker_test_helpers_test.go @@ -0,0 +1,46 @@ +package block + +import ( + "crypto/rand" + "testing" + + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block/metrics" + featureflags "github.com/e2b-dev/infra/packages/shared/pkg/feature-flags" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +const ( + testBlockSize = header.PageSize // 4KB +) + +func newTestMetrics(tb testing.TB) metrics.Metrics { + tb.Helper() + + m, err := metrics.NewMetrics(noop.NewMeterProvider()) + require.NoError(tb, err) + + return m +} + +func newTestFlags(t *testing.T) *featureflags.Client { + t.Helper() + + flags, err := featureflags.NewClient() + require.NoError(t, err) + t.Cleanup(func() { _ = flags.Close(t.Context()) }) + + return flags +} + +func makeTestData(t *testing.T, size int) []byte { + t.Helper() + + data := make([]byte, size) + _, err := rand.Read(data) + require.NoError(t, err) + + return data +} diff --git a/packages/orchestrator/internal/sandbox/block/device.go b/packages/orchestrator/internal/sandbox/block/device.go index 39a1cae845..d4db613f93 100644 --- a/packages/orchestrator/internal/sandbox/block/device.go +++ b/packages/orchestrator/internal/sandbox/block/device.go @@ -8,19 +8,28 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) +// BytesNotAvailableError indicates the requested range is not yet cached. 
type BytesNotAvailableError struct{} func (BytesNotAvailableError) Error() string { return "The requested bytes are not available on the device" } +// Reader reads data with optional FrameTable for compressed fetch. +type Reader interface { + ReadBlock(ctx context.Context, p []byte, off int64, ft *storage.FrameTable) (int, error) + GetBlock(ctx context.Context, off, length int64, ft *storage.FrameTable) ([]byte, error) +} + +// Slicer provides plain block reads (no FrameTable). Used by UFFD/NBD. type Slicer interface { Slice(ctx context.Context, off, length int64) ([]byte, error) BlockSize() int64 } type ReadonlyDevice interface { - storage.SeekableReader + ReadAt(ctx context.Context, p []byte, off int64) (int, error) + Size(ctx context.Context) (int64, error) io.Closer Slicer BlockSize() int64 diff --git a/packages/orchestrator/internal/sandbox/block/fetch_session.go b/packages/orchestrator/internal/sandbox/block/fetch_session.go new file mode 100644 index 0000000000..1929f85976 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/block/fetch_session.go @@ -0,0 +1,145 @@ +package block + +import ( + "context" + "fmt" + "sync" + "sync/atomic" +) + +type fetchSession struct { + chunkOff int64 // absolute start offset in U-space + chunkLen int64 // total length of this chunk/frame + blockSize int64 // progress tracking granularity + + mu sync.Mutex + fetchErr error + signal chan struct{} // closed on each advance; nil when terminated + + // bytesReady is the byte count (from chunkOff) up to which all blocks + // are fully written and marked cached. Atomic so registerAndWait can + // do a lock-free fast-path check: bytesReady only increases. + bytesReady atomic.Int64 + + // isCachedFn checks persistent cache for data from previous sessions. + isCachedFn func(off, length int64) bool +} + +// terminated reports whether the session reached a terminal state. +// Must be called with mu held. 
+func (s *fetchSession) terminated() bool { + return s.fetchErr != nil || s.bytesReady.Load() == s.chunkLen +} + +func newFetchSession(chunkOff, chunkLen, blockSize int64, isCachedFn func(off, length int64) bool) *fetchSession { + return &fetchSession{ + chunkOff: chunkOff, + chunkLen: chunkLen, + blockSize: blockSize, + isCachedFn: isCachedFn, + signal: make(chan struct{}), + } +} + +// registerAndWait blocks until [off, off+length) is cached, the session +// terminates, or ctx is cancelled. +func (s *fetchSession) registerAndWait(ctx context.Context, off, length int64) error { + relEnd := off + length - s.chunkOff + + var endByte int64 + if s.blockSize > 0 { + lastBlockIdx := (relEnd - 1) / s.blockSize + endByte = min((lastBlockIdx+1)*s.blockSize, s.chunkLen) + } else { + endByte = s.chunkLen + } + + for { + // Lock-free fast path: bytesReady only increases, so >= endByte + // guarantees data is available. + if s.bytesReady.Load() >= endByte { + return nil + } + + s.mu.Lock() + + // Re-check under lock. + if s.bytesReady.Load() >= endByte { + s.mu.Unlock() + + return nil + } + + // Terminal but range not covered — only happens on error + // (setDone sets bytesReady=chunkLen). Check cache for prior session data. + if s.terminated() { + fetchErr := s.fetchErr + s.mu.Unlock() + + if s.isCachedFn != nil && s.isCachedFn(off, length) { + return nil + } + + if fetchErr != nil { + return fmt.Errorf("fetch failed: %w", fetchErr) + } + + return nil + } + + ch := s.signal + s.mu.Unlock() + + select { + case <-ch: + continue + case <-ctx.Done(): + return ctx.Err() + } + } +} + +// advance updates progress and wakes all waiters by closing the current +// broadcast channel and replacing it with a fresh one. +func (s *fetchSession) advance(bytesReady int64) { + s.mu.Lock() + s.bytesReady.Store(bytesReady) + old := s.signal + s.signal = make(chan struct{}) + s.mu.Unlock() + + close(old) +} + +// setDone marks the session as successfully completed and wakes all waiters. 
+func (s *fetchSession) setDone() { + s.mu.Lock() + s.bytesReady.Store(s.chunkLen) + old := s.signal + s.signal = nil + s.mu.Unlock() + + close(old) +} + +// setError records the error and wakes all waiters. +// When onlyIfRunning is true, it is a no-op if the session already +// terminated (used for panic recovery to avoid overriding a successful +// completion or double-closing the broadcast channel). +func (s *fetchSession) setError(err error, onlyIfRunning bool) { + s.mu.Lock() + if onlyIfRunning && s.terminated() { + s.mu.Unlock() + + return + } + + s.fetchErr = err + old := s.signal + s.signal = nil + s.mu.Unlock() + + if old != nil { + close(old) + } +} diff --git a/packages/orchestrator/internal/sandbox/block/metrics/main.go b/packages/orchestrator/internal/sandbox/block/metrics/main.go index ca45a4e64d..d151331132 100644 --- a/packages/orchestrator/internal/sandbox/block/metrics/main.go +++ b/packages/orchestrator/internal/sandbox/block/metrics/main.go @@ -15,13 +15,16 @@ const ( ) type Metrics struct { - // SlicesMetric is used to measure page faulting performance. + // BlocksTimerFactory measures page-fault / GetBlock latency. + BlocksTimerFactory telemetry.TimerFactory + + // SlicesTimerFactory is the legacy name for BlocksTimerFactory (fullFetchChunker path). SlicesTimerFactory telemetry.TimerFactory - // WriteChunksMetric is used to measure the time taken to download chunks from remote storage + // RemoteReadsTimerFactory measures the time taken to download chunks from remote storage. RemoteReadsTimerFactory telemetry.TimerFactory - // WriteChunksMetric is used to measure performance of writing chunks to disk. + // WriteChunksTimerFactory measures performance of writing chunks to disk. 
WriteChunksTimerFactory telemetry.TimerFactory } @@ -31,7 +34,7 @@ func NewMetrics(meterProvider metric.MeterProvider) (Metrics, error) { blocksMeter := meterProvider.Meter("internal.sandbox.block.metrics") var err error - if m.SlicesTimerFactory, err = telemetry.NewTimerFactory( + if m.BlocksTimerFactory, err = telemetry.NewTimerFactory( blocksMeter, orchestratorBlockSlices, "Time taken to retrieve memory slices", "Total bytes requested", @@ -40,6 +43,8 @@ func NewMetrics(meterProvider metric.MeterProvider) (Metrics, error) { return m, fmt.Errorf("error creating slices timer factory: %w", err) } + m.SlicesTimerFactory = m.BlocksTimerFactory + if m.RemoteReadsTimerFactory, err = telemetry.NewTimerFactory( blocksMeter, orchestratorBlockChunksFetch, "Time taken to fetch memory chunks from remote store", diff --git a/packages/orchestrator/internal/sandbox/block/mock_flagsclient_test.go b/packages/orchestrator/internal/sandbox/block/mock_flagsclient_test.go new file mode 100644 index 0000000000..274f146e41 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/block/mock_flagsclient_test.go @@ -0,0 +1,113 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package block + +import ( + "context" + + "github.com/e2b-dev/infra/packages/shared/pkg/feature-flags" + "github.com/launchdarkly/go-sdk-common/v3/ldcontext" + "github.com/launchdarkly/go-sdk-common/v3/ldvalue" + mock "github.com/stretchr/testify/mock" +) + +// NewMockFlagsClient creates a new instance of MockFlagsClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockFlagsClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockFlagsClient { + mock := &MockFlagsClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// MockFlagsClient is an autogenerated mock type for the flagsClient type +type MockFlagsClient struct { + mock.Mock +} + +type MockFlagsClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockFlagsClient) EXPECT() *MockFlagsClient_Expecter { + return &MockFlagsClient_Expecter{mock: &_m.Mock} +} + +// JSONFlag provides a mock function for the type MockFlagsClient +func (_mock *MockFlagsClient) JSONFlag(ctx context.Context, flag feature_flags.JSONFlag, ldctx ...ldcontext.Context) ldvalue.Value { + var tmpRet mock.Arguments + if len(ldctx) > 0 { + tmpRet = _mock.Called(ctx, flag, ldctx) + } else { + tmpRet = _mock.Called(ctx, flag) + } + ret := tmpRet + + if len(ret) == 0 { + panic("no return value specified for JSONFlag") + } + + var r0 ldvalue.Value + if returnFunc, ok := ret.Get(0).(func(context.Context, feature_flags.JSONFlag, ...ldcontext.Context) ldvalue.Value); ok { + r0 = returnFunc(ctx, flag, ldctx...) 
+ } else { + r0 = ret.Get(0).(ldvalue.Value) + } + return r0 +} + +// MockFlagsClient_JSONFlag_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'JSONFlag' +type MockFlagsClient_JSONFlag_Call struct { + *mock.Call +} + +// JSONFlag is a helper method to define mock.On call +// - ctx context.Context +// - flag feature_flags.JSONFlag +// - ldctx ...ldcontext.Context +func (_e *MockFlagsClient_Expecter) JSONFlag(ctx interface{}, flag interface{}, ldctx ...interface{}) *MockFlagsClient_JSONFlag_Call { + return &MockFlagsClient_JSONFlag_Call{Call: _e.mock.On("JSONFlag", + append([]interface{}{ctx, flag}, ldctx...)...)} +} + +func (_c *MockFlagsClient_JSONFlag_Call) Run(run func(ctx context.Context, flag feature_flags.JSONFlag, ldctx ...ldcontext.Context)) *MockFlagsClient_JSONFlag_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 feature_flags.JSONFlag + if args[1] != nil { + arg1 = args[1].(feature_flags.JSONFlag) + } + var arg2 []ldcontext.Context + var variadicArgs []ldcontext.Context + if len(args) > 2 { + variadicArgs = args[2].([]ldcontext.Context) + } + arg2 = variadicArgs + run( + arg0, + arg1, + arg2..., + ) + }) + return _c +} + +func (_c *MockFlagsClient_JSONFlag_Call) Return(value ldvalue.Value) *MockFlagsClient_JSONFlag_Call { + _c.Call.Return(value) + return _c +} + +func (_c *MockFlagsClient_JSONFlag_Call) RunAndReturn(run func(ctx context.Context, flag feature_flags.JSONFlag, ldctx ...ldcontext.Context) ldvalue.Value) *MockFlagsClient_JSONFlag_Call { + _c.Call.Return(run) + return _c +} diff --git a/packages/orchestrator/internal/sandbox/block/streaming_chunk.go b/packages/orchestrator/internal/sandbox/block/streaming_chunk.go deleted file mode 100644 index c62395237c..0000000000 --- a/packages/orchestrator/internal/sandbox/block/streaming_chunk.go +++ /dev/null @@ -1,447 +0,0 @@ -package block - -import ( - "cmp" - 
"context" - "errors" - "fmt" - "io" - "slices" - "sync" - "sync/atomic" - "time" - - "go.opentelemetry.io/otel/attribute" - "golang.org/x/sync/errgroup" - - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block/metrics" - featureflags "github.com/e2b-dev/infra/packages/shared/pkg/feature-flags" - "github.com/e2b-dev/infra/packages/shared/pkg/storage" - "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" -) - -const ( - // defaultFetchTimeout is the maximum time a single 4MB chunk fetch may take. - // Acts as a safety net: if the upstream hangs, the goroutine won't live forever. - defaultFetchTimeout = 60 * time.Second - - // defaultMinReadBatchSize is the floor for the read batch size when blockSize - // is very small (e.g. 4KB rootfs). The actual batch is max(blockSize, minReadBatchSize). - defaultMinReadBatchSize = 16 * 1024 // 16 KB -) - -type rangeWaiter struct { - // endByte is the byte offset (relative to chunkOff) at which this waiter's - // entire requested range is cached. Equal to the end of the last block - // overlapping the requested range. Always a multiple of blockSize. - endByte int64 - ch chan error // buffered cap 1 -} - -type fetchSession struct { - mu sync.Mutex - chunkOff int64 - chunkLen int64 - cache *Cache - waiters []*rangeWaiter // sorted by endByte ascending - fetchErr error - - // bytesReady is the byte count (from chunkOff) up to which all blocks are - // fully written to mmap and marked cached. Always a multiple of blockSize - // during progressive reads. Used to cheaply determine which sorted waiters - // are satisfied without calling isCached. - // - // Atomic so registerAndWait can do a lock-free fast-path check: - // bytesReady only increases, so a Load() >= endByte guarantees data - // availability without taking the mutex. - bytesReady atomic.Int64 -} - -// terminated reports whether the fetch session has reached a terminal state -// (done or errored). Must be called with s.mu held. 
-func (s *fetchSession) terminated() bool { - return s.fetchErr != nil || s.bytesReady.Load() == s.chunkLen -} - -// registerAndWait adds a waiter for the given range and blocks until the range -// is cached or the context is cancelled. Returns nil if the range was already -// cached before registering. -func (s *fetchSession) registerAndWait(ctx context.Context, off, length int64) error { - blockSize := s.cache.BlockSize() - lastBlockIdx := (off + length - 1 - s.chunkOff) / blockSize - endByte := (lastBlockIdx + 1) * blockSize - - // Lock-free fast path: bytesReady only increases, so >= endByte - // guarantees data is available without taking the lock. - if s.bytesReady.Load() >= endByte { - return nil - } - - s.mu.Lock() - - // Re-check under lock. - if endByte <= s.bytesReady.Load() { - s.mu.Unlock() - - return nil - } - - // Terminal but range not covered — only happens on error - // (Done sets bytesReady=chunkLen). Check cache for prior session data. - if s.terminated() { - fetchErr := s.fetchErr - s.mu.Unlock() - if s.cache.isCached(off, length) { - return nil - } - - if fetchErr != nil { - return fmt.Errorf("fetch failed: %w", fetchErr) - } - - return fmt.Errorf("fetch completed but range %d-%d not cached", off, off+length) - } - - // Fetch in progress — register waiter. - w := &rangeWaiter{endByte: endByte, ch: make(chan error, 1)} - idx, _ := slices.BinarySearchFunc(s.waiters, endByte, func(w *rangeWaiter, target int64) int { - return cmp.Compare(w.endByte, target) - }) - s.waiters = slices.Insert(s.waiters, idx, w) - s.mu.Unlock() - - select { - case err := <-w.ch: - return err - case <-ctx.Done(): - return ctx.Err() - } -} - -// notifyWaiters notifies waiters whose ranges are satisfied. -// -// Because waiters are sorted by endByte and the fetch fills the chunk -// sequentially, we only need to walk from the front until we hit a waiter -// whose endByte exceeds bytesReady — all subsequent waiters are unsatisfied. 
-// -// In terminal states (done/errored) all remaining waiters are notified. -// Must be called with s.mu held. -func (s *fetchSession) notifyWaiters(sendErr error) { - ready := s.bytesReady.Load() - - // Terminal: notify every remaining waiter. - if s.terminated() { - for _, w := range s.waiters { - if sendErr != nil && w.endByte > ready { - w.ch <- sendErr - } - close(w.ch) - } - s.waiters = nil - - return - } - - // Progress: pop satisfied waiters from the sorted front. - i := 0 - for i < len(s.waiters) && s.waiters[i].endByte <= ready { - close(s.waiters[i].ch) - i++ - } - s.waiters = s.waiters[i:] -} - -type StreamingChunker struct { - upstream storage.StreamingReader - cache *Cache - metrics metrics.Metrics - fetchTimeout time.Duration - featureFlags *featureflags.Client - minReadBatchSize int64 - - size int64 - - fetchMu sync.Mutex - fetchMap map[int64]*fetchSession -} - -func NewStreamingChunker( - size, blockSize int64, - upstream storage.StreamingReader, - cachePath string, - metrics metrics.Metrics, - minReadBatchSize int64, - ff *featureflags.Client, -) (*StreamingChunker, error) { - cache, err := NewCache(size, blockSize, cachePath, false) - if err != nil { - return nil, fmt.Errorf("failed to create file cache: %w", err) - } - - if minReadBatchSize <= 0 { - minReadBatchSize = defaultMinReadBatchSize - } - - return &StreamingChunker{ - size: size, - upstream: upstream, - cache: cache, - metrics: metrics, - featureFlags: ff, - fetchTimeout: defaultFetchTimeout, - minReadBatchSize: minReadBatchSize, - fetchMap: make(map[int64]*fetchSession), - }, nil -} - -func (c *StreamingChunker) ReadAt(ctx context.Context, b []byte, off int64) (int, error) { - slice, err := c.Slice(ctx, off, int64(len(b))) - if err != nil { - return 0, fmt.Errorf("failed to slice cache at %d-%d: %w", off, off+int64(len(b)), err) - } - - return copy(b, slice), nil -} - -func (c *StreamingChunker) WriteTo(ctx context.Context, w io.Writer) (int64, error) { - chunk := make([]byte, 
storage.MemoryChunkSize) - - for i := int64(0); i < c.size; i += storage.MemoryChunkSize { - n, err := c.ReadAt(ctx, chunk, i) - if err != nil { - return 0, fmt.Errorf("failed to slice cache at %d-%d: %w", i, i+storage.MemoryChunkSize, err) - } - - _, err = w.Write(chunk[:n]) - if err != nil { - return 0, fmt.Errorf("failed to write chunk %d to writer: %w", i, err) - } - } - - return c.size, nil -} - -func (c *StreamingChunker) Slice(ctx context.Context, off, length int64) ([]byte, error) { - timer := c.metrics.SlicesTimerFactory.Begin() - - // Fast path: already cached - b, err := c.cache.Slice(off, length) - if err == nil { - timer.Success(ctx, length, - attribute.String(pullType, pullTypeLocal)) - - return b, nil - } - - if !errors.As(err, &BytesNotAvailableError{}) { - timer.Failure(ctx, length, - attribute.String(pullType, pullTypeLocal), - attribute.String(failureReason, failureTypeLocalRead)) - - return nil, fmt.Errorf("failed read from cache at offset %d: %w", off, err) - } - - // Compute which 4MB chunks overlap with the requested range - firstChunkOff := header.BlockOffset(header.BlockIdx(off, storage.MemoryChunkSize), storage.MemoryChunkSize) - lastChunkOff := header.BlockOffset(header.BlockIdx(off+length-1, storage.MemoryChunkSize), storage.MemoryChunkSize) - - var eg errgroup.Group - - for fetchOff := firstChunkOff; fetchOff <= lastChunkOff; fetchOff += storage.MemoryChunkSize { - eg.Go(func() error { - // Clip request to this chunk's boundaries - chunkEnd := fetchOff + storage.MemoryChunkSize - clippedOff := max(off, fetchOff) - clippedEnd := min(off+length, chunkEnd, c.size) - clippedLen := clippedEnd - clippedOff - - if clippedLen <= 0 { - return nil - } - - session := c.getOrCreateSession(ctx, fetchOff) - - return session.registerAndWait(ctx, clippedOff, clippedLen) - }) - } - - if err := eg.Wait(); err != nil { - timer.Failure(ctx, length, - attribute.String(pullType, pullTypeRemote), - attribute.String(failureReason, failureTypeCacheFetch)) - - 
return nil, fmt.Errorf("failed to ensure data at %d-%d: %w", off, off+length, err) - } - - b, cacheErr := c.cache.Slice(off, length) - if cacheErr != nil { - timer.Failure(ctx, length, - attribute.String(pullType, pullTypeLocal), - attribute.String(failureReason, failureTypeLocalReadAgain)) - - return nil, fmt.Errorf("failed to read from cache after ensuring data at %d-%d: %w", off, off+length, cacheErr) - } - - timer.Success(ctx, length, - attribute.String(pullType, pullTypeRemote)) - - return b, nil -} - -func (c *StreamingChunker) getOrCreateSession(ctx context.Context, fetchOff int64) *fetchSession { - s := &fetchSession{ - chunkOff: fetchOff, - chunkLen: min(int64(storage.MemoryChunkSize), c.size-fetchOff), - cache: c.cache, - } - - c.fetchMu.Lock() - if existing, ok := c.fetchMap[fetchOff]; ok { - c.fetchMu.Unlock() - - return existing - } - c.fetchMap[fetchOff] = s - c.fetchMu.Unlock() - - // Detach from the caller's cancel signal so the shared fetch goroutine - // continues even if the first caller's context is cancelled. Trace/value - // context is preserved for metrics. - go c.runFetch(context.WithoutCancel(ctx), s) - - return s -} - -func (s *fetchSession) setDone() { - s.mu.Lock() - defer s.mu.Unlock() - - s.bytesReady.Store(s.chunkLen) - s.notifyWaiters(nil) -} - -func (s *fetchSession) setError(err error, onlyIfRunning bool) { - s.mu.Lock() - defer s.mu.Unlock() - - if onlyIfRunning && s.terminated() { - return - } - - s.fetchErr = err - s.notifyWaiters(err) -} - -func (c *StreamingChunker) runFetch(ctx context.Context, s *fetchSession) { - ctx, cancel := context.WithTimeout(ctx, c.fetchTimeout) - defer cancel() - - defer func() { - c.fetchMu.Lock() - delete(c.fetchMap, s.chunkOff) - c.fetchMu.Unlock() - }() - - // Panic recovery: ensure waiters are always notified even if the fetch - // goroutine panics (e.g. nil pointer in upstream reader, mmap fault). - // Without this, waiters would block forever on their channels. 
- defer func() { - if r := recover(); r != nil { - err := fmt.Errorf("fetch panicked: %v", r) - s.setError(err, true) - } - }() - - mmapSlice, releaseLock, err := c.cache.addressBytes(s.chunkOff, s.chunkLen) - if err != nil { - s.setError(err, false) - - return - } - defer releaseLock() - - fetchTimer := c.metrics.RemoteReadsTimerFactory.Begin() - - err = c.progressiveRead(ctx, s, mmapSlice) - if err != nil { - fetchTimer.Failure(ctx, s.chunkLen, - attribute.String(failureReason, failureTypeRemoteRead)) - - s.setError(err, false) - - return - } - - fetchTimer.Success(ctx, s.chunkLen) - s.setDone() -} - -func (c *StreamingChunker) progressiveRead(ctx context.Context, s *fetchSession, mmapSlice []byte) error { - reader, err := c.upstream.OpenRangeReader(ctx, s.chunkOff, s.chunkLen) - if err != nil { - return fmt.Errorf("failed to open range reader at %d: %w", s.chunkOff, err) - } - defer reader.Close() - - blockSize := c.cache.BlockSize() - readBatch := max(blockSize, c.getMinReadBatchSize(ctx)) - var totalRead int64 - var prevCompleted int64 - - for totalRead < s.chunkLen { - // Read in batches of max(blockSize, 16KB) to align notification - // granularity with the read size and minimize lock/notify overhead. - readEnd := min(totalRead+readBatch, s.chunkLen) - n, readErr := reader.Read(mmapSlice[totalRead:readEnd]) - totalRead += int64(n) - - completedBlocks := totalRead / blockSize - if completedBlocks > prevCompleted { - newBytes := (completedBlocks - prevCompleted) * blockSize - c.cache.setIsCached(s.chunkOff+prevCompleted*blockSize, newBytes) - prevCompleted = completedBlocks - - s.mu.Lock() - s.bytesReady.Store(completedBlocks * blockSize) - s.notifyWaiters(nil) - s.mu.Unlock() - } - - if errors.Is(readErr, io.EOF) { - // Mark final partial block if any - if totalRead > prevCompleted*blockSize { - c.cache.setIsCached(s.chunkOff+prevCompleted*blockSize, totalRead-prevCompleted*blockSize) - } - // Remaining waiters are notified in runFetch via the Done state. 
- break - } - - if readErr != nil { - return fmt.Errorf("failed reading at offset %d after %d bytes: %w", s.chunkOff, totalRead, readErr) - } - } - - return nil -} - -// getMinReadBatchSize returns the effective min read batch size. When a feature -// flags client is available, the value is read just-in-time from the flag so -// it can be tuned without restarting the service. -func (c *StreamingChunker) getMinReadBatchSize(ctx context.Context) int64 { - if c.featureFlags != nil { - _, minKB := getChunkerConfig(ctx, c.featureFlags) - if minKB > 0 { - return int64(minKB) * 1024 - } - } - - return c.minReadBatchSize -} - -func (c *StreamingChunker) Close() error { - return c.cache.Close() -} - -func (c *StreamingChunker) FileSize() (int64, error) { - return c.cache.FileSize() -} diff --git a/packages/orchestrator/internal/sandbox/block/streaming_chunk_test.go b/packages/orchestrator/internal/sandbox/block/streaming_chunk_test.go deleted file mode 100644 index c509e0af38..0000000000 --- a/packages/orchestrator/internal/sandbox/block/streaming_chunk_test.go +++ /dev/null @@ -1,953 +0,0 @@ -package block - -import ( - "bytes" - "context" - "crypto/rand" - "fmt" - "io" - mathrand "math/rand/v2" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel/metric/noop" - "golang.org/x/sync/errgroup" - - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block/metrics" - "github.com/e2b-dev/infra/packages/shared/pkg/storage" - "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" -) - -const ( - testBlockSize = header.PageSize // 4KB -) - -// slowUpstream simulates GCS: implements both SeekableReader and StreamingReader. -// OpenRangeReader returns a reader that yields blockSize bytes per Read() call -// with a configurable delay between calls. 
-type slowUpstream struct { - data []byte - blockSize int64 - delay time.Duration -} - -var ( - _ storage.SeekableReader = (*slowUpstream)(nil) - _ storage.StreamingReader = (*slowUpstream)(nil) -) - -func (s *slowUpstream) ReadAt(_ context.Context, buffer []byte, off int64) (int, error) { - end := min(off+int64(len(buffer)), int64(len(s.data))) - n := copy(buffer, s.data[off:end]) - - return n, nil -} - -func (s *slowUpstream) Size(_ context.Context) (int64, error) { - return int64(len(s.data)), nil -} - -func (s *slowUpstream) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { - end := min(off+length, int64(len(s.data))) - - return &slowReader{ - data: s.data[off:end], - blockSize: int(s.blockSize), - delay: s.delay, - }, nil -} - -type slowReader struct { - data []byte - pos int - blockSize int - delay time.Duration -} - -func (r *slowReader) Read(p []byte) (int, error) { - if r.pos >= len(r.data) { - return 0, io.EOF - } - - if r.delay > 0 { - time.Sleep(r.delay) - } - - end := min(r.pos+r.blockSize, len(r.data)) - - n := copy(p, r.data[r.pos:end]) - r.pos += n - - if r.pos >= len(r.data) { - return n, io.EOF - } - - return n, nil -} - -func (r *slowReader) Close() error { - return nil -} - -// fastUpstream simulates NFS: same interfaces but no delay. -type fastUpstream = slowUpstream - -// streamingFunc adapts a function into a StreamingReader. -type streamingFunc func(ctx context.Context, off, length int64) (io.ReadCloser, error) - -func (f streamingFunc) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { - return f(ctx, off, length) -} - -// errorAfterNUpstream fails after reading n bytes. 
-type errorAfterNUpstream struct { - data []byte - failAfter int64 - blockSize int64 -} - -var _ storage.StreamingReader = (*errorAfterNUpstream)(nil) - -func (u *errorAfterNUpstream) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { - end := min(off+length, int64(len(u.data))) - - return &errorAfterNReader{ - data: u.data[off:end], - blockSize: int(u.blockSize), - failAfter: int(u.failAfter - off), - }, nil -} - -type errorAfterNReader struct { - data []byte - pos int - blockSize int - failAfter int -} - -func (r *errorAfterNReader) Read(p []byte) (int, error) { - if r.pos >= len(r.data) { - return 0, io.EOF - } - - if r.pos >= r.failAfter { - return 0, fmt.Errorf("simulated upstream error") - } - - end := min(r.pos+r.blockSize, len(r.data)) - - n := copy(p, r.data[r.pos:end]) - r.pos += n - - if r.pos >= len(r.data) { - return n, io.EOF - } - - return n, nil -} - -func (r *errorAfterNReader) Close() error { - return nil -} - -func newTestMetrics(t *testing.T) metrics.Metrics { - t.Helper() - - m, err := metrics.NewMetrics(noop.NewMeterProvider()) - require.NoError(t, err) - - return m -} - -func makeTestData(t *testing.T, size int) []byte { - t.Helper() - - data := make([]byte, size) - _, err := rand.Read(data) - require.NoError(t, err) - - return data -} - -func TestStreamingChunker_BasicSlice(t *testing.T) { - t.Parallel() - - data := makeTestData(t, storage.MemoryChunkSize) - upstream := &fastUpstream{data: data, blockSize: testBlockSize} - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() - - // Read first page - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - require.Equal(t, data[:testBlockSize], slice) -} - -func TestStreamingChunker_CacheHit(t *testing.T) { - t.Parallel() - - data := makeTestData(t, storage.MemoryChunkSize) - readCount := 
atomic.Int64{} - - upstream := &countingUpstream{ - inner: &fastUpstream{data: data, blockSize: testBlockSize}, - readCount: &readCount, - } - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() - - // First read: triggers fetch - _, err = chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - - // Wait for the full chunk to be fetched - time.Sleep(50 * time.Millisecond) - - firstCount := readCount.Load() - require.Positive(t, firstCount) - - // Second read: should hit cache - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - require.Equal(t, data[:testBlockSize], slice) - - // No additional reads should have happened - assert.Equal(t, firstCount, readCount.Load()) -} - -type countingUpstream struct { - inner *fastUpstream - readCount *atomic.Int64 -} - -var ( - _ storage.SeekableReader = (*countingUpstream)(nil) - _ storage.StreamingReader = (*countingUpstream)(nil) -) - -func (c *countingUpstream) ReadAt(ctx context.Context, buffer []byte, off int64) (int, error) { - c.readCount.Add(1) - - return c.inner.ReadAt(ctx, buffer, off) -} - -func (c *countingUpstream) Size(ctx context.Context) (int64, error) { - return c.inner.Size(ctx) -} - -func (c *countingUpstream) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { - c.readCount.Add(1) - - return c.inner.OpenRangeReader(ctx, off, length) -} - -func TestStreamingChunker_FullChunkCachedAfterPartialRequest(t *testing.T) { - t.Parallel() - - data := makeTestData(t, storage.MemoryChunkSize) - openCount := atomic.Int64{} - - upstream := &countingUpstream{ - inner: &fastUpstream{data: data, blockSize: testBlockSize}, - readCount: &openCount, - } - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - 
require.NoError(t, err) - defer chunker.Close() - - // Request only the FIRST block of the 4MB chunk. - _, err = chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - - // The background goroutine should continue fetching the remaining data. - // Wait for it to complete. - require.Eventually(t, func() bool { - // Try reading the LAST block — if the full chunk is cached this - // will succeed without opening another range reader. - lastOff := int64(storage.MemoryChunkSize) - testBlockSize - slice, err := chunker.Slice(t.Context(), lastOff, testBlockSize) - if err != nil { - return false - } - - return bytes.Equal(data[lastOff:], slice) - }, 5*time.Second, 10*time.Millisecond) - - // Exactly one OpenRangeReader call should have been made for the entire - // chunk, not one per requested block. - assert.Equal(t, int64(1), openCount.Load(), - "expected 1 OpenRangeReader call (full chunk fetched in background), got %d", openCount.Load()) -} - -func TestStreamingChunker_ConcurrentSameChunk(t *testing.T) { - t.Parallel() - - data := makeTestData(t, storage.MemoryChunkSize) - // Use a slow upstream so requests will overlap - upstream := &slowUpstream{ - data: data, - blockSize: testBlockSize, - delay: 50 * time.Microsecond, - } - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() - - numGoroutines := 10 - offsets := make([]int64, numGoroutines) - for i := range numGoroutines { - offsets[i] = int64(i) * testBlockSize - } - - results := make([][]byte, numGoroutines) - - var eg errgroup.Group - - for i := range numGoroutines { - eg.Go(func() error { - slice, err := chunker.Slice(t.Context(), offsets[i], testBlockSize) - if err != nil { - return fmt.Errorf("goroutine %d failed: %w", i, err) - } - results[i] = make([]byte, len(slice)) - copy(results[i], slice) - - return nil - }) - } - - require.NoError(t, eg.Wait()) - - 
for i := range numGoroutines { - require.Equal(t, data[offsets[i]:offsets[i]+testBlockSize], results[i], - "goroutine %d got wrong data", i) - } -} - -func TestStreamingChunker_EarlyReturn(t *testing.T) { - t.Parallel() - - type testCase struct { - name string - blockSize int64 - delay time.Duration - // blockIndices are block indices within the chunk, listed in the - // expected completion order (earlier blocks are notified first). - blockIndices []int - } - - cases := []testCase{ - { - name: "hugepage", - blockSize: header.HugepageSize, // 2MB → 2 blocks per 4MB chunk - delay: 50 * time.Millisecond, - blockIndices: []int{0, 1}, - }, - { - name: "4K", - blockSize: header.PageSize, // 4KB → 1024 blocks per 4MB chunk - delay: 100 * time.Microsecond, - blockIndices: []int{1, 512, 1022}, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - data := makeTestData(t, storage.MemoryChunkSize) - - gate := make(chan struct{}) - upstream := streamingFunc(func(_ context.Context, off, length int64) (io.ReadCloser, error) { - <-gate - end := min(off+length, int64(len(data))) - - return &slowReader{ - data: data[off:end], - blockSize: int(tc.blockSize), - delay: tc.delay, - }, nil - }) - - chunker, err := NewStreamingChunker( - int64(len(data)), tc.blockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() - - n := len(tc.blockIndices) - completionOrder := make(chan int, n) - - var eg errgroup.Group - for i, blockIdx := range tc.blockIndices { - off := int64(blockIdx) * tc.blockSize - eg.Go(func() error { - _, err := chunker.Slice(t.Context(), off, tc.blockSize) - if err != nil { - return fmt.Errorf("request %d (block %d) failed: %w", i, blockIdx, err) - } - completionOrder <- i - - return nil - }) - } - - // Let all goroutines register as waiters before the fetch begins. 
- time.Sleep(10 * time.Millisecond) - close(gate) - - require.NoError(t, eg.Wait()) - close(completionOrder) - - got := make([]int, 0, n) - for idx := range completionOrder { - got = append(got, idx) - } - - expected := make([]int, n) - for i := range expected { - expected[i] = i - } - - assert.Equal(t, expected, got, - "requests should complete in offset order (earlier blocks first)") - }) - } -} - -func TestStreamingChunker_ErrorKeepsPartialData(t *testing.T) { - t.Parallel() - - chunkSize := storage.MemoryChunkSize - data := makeTestData(t, chunkSize) - failAfter := int64(chunkSize / 2) // Fail at 2MB - - upstream := &errorAfterNUpstream{ - data: data, - failAfter: failAfter, - blockSize: testBlockSize, - } - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() - - // Request the last page — this should fail because upstream dies at 2MB - lastOff := int64(chunkSize) - testBlockSize - _, err = chunker.Slice(t.Context(), lastOff, testBlockSize) - require.Error(t, err) - - // But first page (within first 2MB) should still be cached and servable - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - require.Equal(t, data[:testBlockSize], slice) -} - -func TestStreamingChunker_ContextCancellation(t *testing.T) { - t.Parallel() - - data := makeTestData(t, storage.MemoryChunkSize) - upstream := &slowUpstream{ - data: data, - blockSize: testBlockSize, - delay: 1 * time.Millisecond, - } - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() - - // Request with a context that we'll cancel quickly - ctx, cancel := context.WithTimeout(t.Context(), 1*time.Millisecond) - defer cancel() - - lastOff := int64(storage.MemoryChunkSize) - testBlockSize - _, err = 
chunker.Slice(ctx, lastOff, testBlockSize) - // This should fail with context cancellation - require.Error(t, err) - - // But another caller with a valid context should still get the data - // because the fetch goroutine uses background context - time.Sleep(200 * time.Millisecond) // Wait for fetch to complete - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - require.Equal(t, data[:testBlockSize], slice) -} - -func TestStreamingChunker_LastBlockPartial(t *testing.T) { - t.Parallel() - - // File size not aligned to blockSize - size := storage.MemoryChunkSize - 100 - data := makeTestData(t, size) - upstream := &fastUpstream{data: data, blockSize: testBlockSize} - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() - - // Read the last partial block - lastBlockOff := (int64(size) / testBlockSize) * testBlockSize - remaining := int64(size) - lastBlockOff - - slice, err := chunker.Slice(t.Context(), lastBlockOff, remaining) - require.NoError(t, err) - require.Equal(t, data[lastBlockOff:], slice) -} - -func TestStreamingChunker_MultiChunkSlice(t *testing.T) { - t.Parallel() - - // Two 4MB chunks - size := storage.MemoryChunkSize * 2 - data := makeTestData(t, size) - upstream := &fastUpstream{data: data, blockSize: testBlockSize} - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() - - // Request spanning two chunks: last page of chunk 0 + first page of chunk 1 - off := int64(storage.MemoryChunkSize) - testBlockSize - length := testBlockSize * 2 - - slice, err := chunker.Slice(t.Context(), off, int64(length)) - require.NoError(t, err) - require.Equal(t, data[off:off+int64(length)], slice) -} - -// panicUpstream panics during Read after delivering a 
configurable number of bytes. -type panicUpstream struct { - data []byte - blockSize int64 - panicAfter int64 // byte offset at which to panic (0 = panic immediately) -} - -var _ storage.StreamingReader = (*panicUpstream)(nil) - -func (u *panicUpstream) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { - end := min(off+length, int64(len(u.data))) - - return &panicReader{ - data: u.data[off:end], - blockSize: int(u.blockSize), - panicAfter: int(u.panicAfter - off), - }, nil -} - -type panicReader struct { - data []byte - pos int - blockSize int - panicAfter int -} - -func (r *panicReader) Read(p []byte) (int, error) { - if r.pos >= r.panicAfter { - panic("simulated upstream panic") - } - - if r.pos >= len(r.data) { - return 0, io.EOF - } - - end := min(r.pos+r.blockSize, len(r.data)) - n := copy(p, r.data[r.pos:end]) - r.pos += n - - return n, nil -} - -func (r *panicReader) Close() error { - return nil -} - -func TestStreamingChunker_PanicRecovery(t *testing.T) { - t.Parallel() - - data := makeTestData(t, storage.MemoryChunkSize) - panicAt := int64(storage.MemoryChunkSize / 2) // Panic at 2MB - - upstream := &panicUpstream{ - data: data, - blockSize: testBlockSize, - panicAfter: panicAt, - } - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() - - // Request data past the panic point — should get an error, not hang or crash - lastOff := int64(storage.MemoryChunkSize) - testBlockSize - _, err = chunker.Slice(t.Context(), lastOff, testBlockSize) - require.Error(t, err) - assert.Contains(t, err.Error(), "panicked") - - // Data before the panic point should still be cached - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - require.Equal(t, data[:testBlockSize], slice) -} - -func TestStreamingChunker_ConcurrentSameChunk_SharedSession(t *testing.T) { - t.Parallel() 
- - data := makeTestData(t, storage.MemoryChunkSize) - - gate := make(chan struct{}) - openCount := atomic.Int64{} - - // OpenRangeReader blocks on the gate, keeping the session in fetchMap - // until both callers have entered. This removes the scheduling-dependent - // race in the old slow-upstream version of this test. - upstream := streamingFunc(func(_ context.Context, off, length int64) (io.ReadCloser, error) { - openCount.Add(1) - <-gate - - end := min(off+length, int64(len(data))) - - return io.NopCloser(bytes.NewReader(data[off:end])), nil - }) - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() - - // Two different ranges inside the same 4MB chunk. - offA := int64(0) - offB := int64(storage.MemoryChunkSize) - testBlockSize // last block - - var eg errgroup.Group - var sliceA, sliceB []byte - - eg.Go(func() error { - s, err := chunker.Slice(t.Context(), offA, testBlockSize) - if err != nil { - return err - } - sliceA = make([]byte, len(s)) - copy(sliceA, s) - - return nil - }) - eg.Go(func() error { - s, err := chunker.Slice(t.Context(), offB, testBlockSize) - if err != nil { - return err - } - sliceB = make([]byte, len(s)) - copy(sliceB, s) - - return nil - }) - - // Let both goroutines enter getOrCreateSession, then release the fetch. - time.Sleep(10 * time.Millisecond) - close(gate) - - require.NoError(t, eg.Wait()) - - assert.Equal(t, data[offA:offA+testBlockSize], sliceA) - assert.Equal(t, data[offB:offB+testBlockSize], sliceB) - assert.Equal(t, int64(1), openCount.Load(), - "expected exactly 1 OpenRangeReader call (shared session), got %d", openCount.Load()) -} - -// --- Benchmarks --- -// -// Uses a bandwidth-limited upstream with real time.Sleep to simulate GCS and -// NFS backends. Measures actual wall-clock latency per caller. 
-// -// Backend parameters (tuned to match observed production latencies): -// GCS: 20ms TTFB + 100 MB/s → 4MB chunk ≈ 62ms (observed ~60ms) -// NFS: 1ms TTFB + 500 MB/s → 4MB chunk ≈ 9ms (observed ~9-10ms) -// -// All sub-benchmarks share a pre-generated offset sequence so results are -// directly comparable across chunker types and backends. -// -// Recommended invocation (~1 minute): -// go test -bench BenchmarkRandomAccess -benchtime 150x -count=3 -run '^$' ./... - -func newBenchmarkMetrics(b *testing.B) metrics.Metrics { - b.Helper() - - m, err := metrics.NewMetrics(noop.NewMeterProvider()) - require.NoError(b, err) - - return m -} - -// realisticUpstream simulates a storage backend with configurable time-to-first-byte -// and bandwidth. ReadAt blocks for the full transfer duration (bulk fetch model). -// OpenRangeReader returns a bandwidth-limited progressive reader. -type realisticUpstream struct { - data []byte - blockSize int64 - ttfb time.Duration - bytesPerSec float64 -} - -var ( - _ storage.SeekableReader = (*realisticUpstream)(nil) - _ storage.StreamingReader = (*realisticUpstream)(nil) -) - -func (u *realisticUpstream) ReadAt(_ context.Context, buffer []byte, off int64) (int, error) { - transferTime := time.Duration(float64(len(buffer)) / u.bytesPerSec * float64(time.Second)) - time.Sleep(u.ttfb + transferTime) - - end := min(off+int64(len(buffer)), int64(len(u.data))) - n := copy(buffer, u.data[off:end]) - - return n, nil -} - -func (u *realisticUpstream) Size(_ context.Context) (int64, error) { - return int64(len(u.data)), nil -} - -func (u *realisticUpstream) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { - end := min(off+length, int64(len(u.data))) - - return &bandwidthReader{ - data: u.data[off:end], - blockSize: int(u.blockSize), - ttfb: u.ttfb, - bytesPerSec: u.bytesPerSec, - }, nil -} - -// bandwidthReader delivers data at a steady rate after an initial TTFB delay. 
-// Uses cumulative timing (time since first byte) so OS scheduling jitter does -// not compound across blocks. -type bandwidthReader struct { - data []byte - pos int - blockSize int - ttfb time.Duration - bytesPerSec float64 - startTime time.Time - started bool -} - -func (r *bandwidthReader) Read(p []byte) (int, error) { - if !r.started { - r.started = true - time.Sleep(r.ttfb) - r.startTime = time.Now() - } - - if r.pos >= len(r.data) { - return 0, io.EOF - } - - end := min(r.pos+r.blockSize, len(r.data)) - n := copy(p, r.data[r.pos:end]) - r.pos += n - - // Enforce bandwidth: sleep until this many bytes should have arrived. - expectedArrival := r.startTime.Add(time.Duration(float64(r.pos) / r.bytesPerSec * float64(time.Second))) - if wait := time.Until(expectedArrival); wait > 0 { - time.Sleep(wait) - } - - if r.pos >= len(r.data) { - return n, io.EOF - } - - return n, nil -} - -func (r *bandwidthReader) Close() error { - return nil -} - -type benchChunker interface { - Slice(ctx context.Context, off, length int64) ([]byte, error) - Close() error -} - -func BenchmarkRandomAccess(b *testing.B) { - size := int64(storage.MemoryChunkSize) - data := make([]byte, size) - - backends := []struct { - name string - upstream *realisticUpstream - }{ - { - name: "GCS", - upstream: &realisticUpstream{ - data: data, - blockSize: testBlockSize, - ttfb: 20 * time.Millisecond, - bytesPerSec: 100e6, // 100 MB/s — full 4MB chunk ≈ 62ms (observed ~60ms) - }, - }, - { - name: "NFS", - upstream: &realisticUpstream{ - data: data, - blockSize: testBlockSize, - ttfb: 1 * time.Millisecond, - bytesPerSec: 500e6, // 500 MB/s — full 4MB chunk ≈ 9ms (observed ~9-10ms) - }, - }, - } - - chunkerTypes := []struct { - name string - newChunker func(b *testing.B, m metrics.Metrics, upstream *realisticUpstream) benchChunker - }{ - { - name: "StreamingChunker", - newChunker: func(b *testing.B, m metrics.Metrics, upstream *realisticUpstream) benchChunker { - b.Helper() - c, err := 
NewStreamingChunker(size, testBlockSize, upstream, b.TempDir()+"/cache", m, 0, nil) - require.NoError(b, err) - - return c - }, - }, - { - name: "FullFetchChunker", - newChunker: func(b *testing.B, m metrics.Metrics, upstream *realisticUpstream) benchChunker { - b.Helper() - c, err := NewFullFetchChunker(size, testBlockSize, upstream, b.TempDir()+"/cache", m) - require.NoError(b, err) - - return c - }, - }, - } - - // Realistic concurrency: UFFD faults are limited by vCPU count (typically - // 1-2 for Firecracker VMs) and NBD requests are largely sequential. - const numCallers = 3 - - // Pre-generate a fixed sequence of random offsets so all sub-benchmarks - // use identical access patterns, making results directly comparable. - const maxIters = 500 - numBlocks := size / testBlockSize - rng := mathrand.New(mathrand.NewPCG(42, 0)) - - allOffsets := make([][]int64, maxIters) - for i := range allOffsets { - offsets := make([]int64, numCallers) - for j := range offsets { - offsets[j] = rng.Int64N(numBlocks) * testBlockSize - } - allOffsets[i] = offsets - } - - for _, backend := range backends { - for _, ct := range chunkerTypes { - b.Run(backend.name+"/"+ct.name, func(b *testing.B) { - m := newBenchmarkMetrics(b) - - b.ReportMetric(0, "ns/op") - - var sumAvg, sumMax float64 - - for i := range b.N { - offsets := allOffsets[i%maxIters] - - chunker := ct.newChunker(b, m, backend.upstream) - - latencies := make([]time.Duration, numCallers) - - var eg errgroup.Group - for ci, off := range offsets { - eg.Go(func() error { - start := time.Now() - _, err := chunker.Slice(context.Background(), off, testBlockSize) - latencies[ci] = time.Since(start) - - return err - }) - } - require.NoError(b, eg.Wait()) - - var totalLatency time.Duration - var maxLatency time.Duration - for _, l := range latencies { - totalLatency += l - maxLatency = max(maxLatency, l) - } - - avgUs := float64(totalLatency.Microseconds()) / float64(numCallers) - sumAvg += avgUs - sumMax = max(sumMax, 
float64(maxLatency.Microseconds())) - - chunker.Close() - } - - b.ReportMetric(sumAvg/float64(b.N), "avg-us/caller") - b.ReportMetric(sumMax, "worst-us/caller") - }) - } - } -} diff --git a/packages/orchestrator/internal/sandbox/build/build.go b/packages/orchestrator/internal/sandbox/build/build.go index a87cbe52f6..4fbbbe7d1a 100644 --- a/packages/orchestrator/internal/sandbox/build/build.go +++ b/packages/orchestrator/internal/sandbox/build/build.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" blockmetrics "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block/metrics" + featureflags "github.com/e2b-dev/infra/packages/shared/pkg/feature-flags" "github.com/e2b-dev/infra/packages/shared/pkg/logger" "github.com/e2b-dev/infra/packages/shared/pkg/storage" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" @@ -19,6 +20,7 @@ type File struct { fileType DiffType persistence storage.StorageProvider metrics blockmetrics.Metrics + flags *featureflags.Client } func NewFile( @@ -27,6 +29,7 @@ func NewFile( fileType DiffType, persistence storage.StorageProvider, metrics blockmetrics.Metrics, + flags *featureflags.Client, ) *File { return &File{ header: header, @@ -34,19 +37,19 @@ func NewFile( fileType: fileType, persistence: persistence, metrics: metrics, + flags: flags, } } func (b *File) ReadAt(ctx context.Context, p []byte, off int64) (n int, err error) { for n < len(p) { - mappedOffset, mappedLength, buildID, err := b.header.GetShiftedMapping(ctx, off+int64(n)) + mappedToBuild, err := b.header.GetShiftedMapping(ctx, off+int64(n)) if err != nil { return 0, fmt.Errorf("failed to get mapping: %w", err) } remainingReadLength := int64(len(p)) - int64(n) - - readLength := min(mappedLength, remainingReadLength) + readLength := min(int64(mappedToBuild.Length), remainingReadLength) if readLength <= 0 { logger.L().Error(ctx, fmt.Sprintf( @@ -54,13 +57,13 @@ func (b *File) ReadAt(ctx context.Context, p []byte, off int64) (n int, err erro len(p)-n, 
off, readLength, - buildID, + mappedToBuild.BuildId, b.fileType, - mappedOffset, + mappedToBuild.Offset, n, int64(n)+readLength, n, - mappedLength, + mappedToBuild.Length, remainingReadLength, )) @@ -70,20 +73,23 @@ func (b *File) ReadAt(ctx context.Context, p []byte, off int64) (n int, err erro // Skip reading when the uuid is nil. // We will use this to handle base builds that are already diffs. // The passed slice p must start as empty, otherwise we would need to copy the empty values there. - if *buildID == uuid.Nil { + if mappedToBuild.BuildId == uuid.Nil { n += int(readLength) continue } - mappedBuild, err := b.getBuild(ctx, buildID) + mappedBuild, err := b.getBuild(ctx, mappedToBuild.BuildId) if err != nil { return 0, fmt.Errorf("failed to get build: %w", err) } - buildN, err := mappedBuild.ReadAt(ctx, + ft := mappedToBuild.FrameTable + + buildN, err := mappedBuild.ReadBlock(ctx, p[n:int64(n)+readLength], - mappedOffset, + int64(mappedToBuild.Offset), + ft, ) if err != nil { return 0, fmt.Errorf("failed to read from source: %w", err) @@ -97,25 +103,25 @@ func (b *File) ReadAt(ctx context.Context, p []byte, off int64) (n int, err erro // The slice access must be in the predefined blocksize of the build. func (b *File) Slice(ctx context.Context, off, _ int64) ([]byte, error) { - mappedOffset, _, buildID, err := b.header.GetShiftedMapping(ctx, off) + mappedBuild, err := b.header.GetShiftedMapping(ctx, off) if err != nil { return nil, fmt.Errorf("failed to get mapping: %w", err) } // Pass empty huge page when the build id is nil. 
- if *buildID == uuid.Nil { + if mappedBuild.BuildId == uuid.Nil { return header.EmptyHugePage, nil } - build, err := b.getBuild(ctx, buildID) + build, err := b.getBuild(ctx, mappedBuild.BuildId) if err != nil { return nil, fmt.Errorf("failed to get build: %w", err) } - return build.Slice(ctx, mappedOffset, int64(b.header.Metadata.BlockSize)) + return build.GetBlock(ctx, int64(mappedBuild.Offset), int64(b.header.Metadata.BlockSize), mappedBuild.FrameTable) } -func (b *File) getBuild(ctx context.Context, buildID *uuid.UUID) (Diff, error) { +func (b *File) getBuild(ctx context.Context, buildID uuid.UUID) (Diff, error) { storageDiff, err := newStorageDiff( b.store.cachePath, buildID.String(), @@ -123,7 +129,7 @@ func (b *File) getBuild(ctx context.Context, buildID *uuid.UUID) (Diff, error) { int64(b.header.Metadata.BlockSize), b.metrics, b.persistence, - b.store.flags, + b.flags, ) if err != nil { return nil, fmt.Errorf("failed to create storage diff: %w", err) diff --git a/packages/orchestrator/internal/sandbox/build/cache_test.go b/packages/orchestrator/internal/sandbox/build/cache_test.go index 460135fe53..df510bc52b 100644 --- a/packages/orchestrator/internal/sandbox/build/cache_test.go +++ b/packages/orchestrator/internal/sandbox/build/cache_test.go @@ -13,8 +13,10 @@ package build // causing a race when closing the cancel channel. import ( + "context" "fmt" "sync" + "sync/atomic" "testing" "time" @@ -25,6 +27,8 @@ import ( "github.com/e2b-dev/infra/packages/orchestrator/internal/cfg" featureflags "github.com/e2b-dev/infra/packages/shared/pkg/feature-flags" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) const ( @@ -496,6 +500,105 @@ func TestDiffStoreResetDeleteRace(t *testing.T) { time.Sleep(delay * 2) } +// concurrentTestDiff mimics StorageDiff's SetOnce pattern for testing +// concurrent Init + access through DiffStore. 
+type concurrentTestDiff struct { + data *utils.SetOnce[[]byte] + key DiffStoreKey + initCount *atomic.Int32 + testData []byte +} + +var _ Diff = (*concurrentTestDiff)(nil) + +func (d *concurrentTestDiff) Init(_ context.Context) error { + d.initCount.Add(1) + time.Sleep(50 * time.Millisecond) // simulate slow probe + chunker creation + + return d.data.SetValue(d.testData) +} + +func (d *concurrentTestDiff) ReadBlock(_ context.Context, p []byte, off int64, _ *storage.FrameTable) (int, error) { + data, err := d.data.Wait() + if err != nil { + return 0, err + } + + return copy(p, data[off:]), nil +} + +func (d *concurrentTestDiff) GetBlock(_ context.Context, off, length int64, _ *storage.FrameTable) ([]byte, error) { + data, err := d.data.Wait() + if err != nil { + return nil, err + } + + return data[off : off+length], nil +} + +func (d *concurrentTestDiff) CacheKey() DiffStoreKey { return d.key } +func (d *concurrentTestDiff) CachePath() (string, error) { return "", nil } +func (d *concurrentTestDiff) FileSize() (int64, error) { return int64(len(d.testData)), nil } +func (d *concurrentTestDiff) BlockSize() int64 { return 4096 } +func (d *concurrentTestDiff) Close() error { return nil } + +// TestDiffStoreConcurrentInitAndAccess simulates multiple UFFD handlers +// concurrently calling getBuild → DiffStore.Get for the same build. +// Only the first caller triggers Init; others block on SetOnce.Wait() +// until init completes, then all read correct data. 
+func TestDiffStoreConcurrentInitAndAccess(t *testing.T) { + t.Parallel() + + cachePath := t.TempDir() + c, err := cfg.Parse() + require.NoError(t, err) + flags := flagsWithMaxBuildCachePercentage(t, 100) + + store, err := NewDiffStore(c, flags, cachePath, 60*time.Second, 60*time.Second) + require.NoError(t, err) + store.Start(t.Context()) + t.Cleanup(store.Close) + + const numGoroutines = 50 + const dataSize = 4096 + + testData := make([]byte, dataSize) + for i := range testData { + testData[i] = byte(i % 256) + } + + var initCount atomic.Int32 + var wg sync.WaitGroup + + for range numGoroutines { + wg.Go(func() { + // Each goroutine creates its own diff instance (mimicking getBuild), + // but all share the same cache key. GetOrSet stores only the first. + diff := &concurrentTestDiff{ + data: utils.NewSetOnce[[]byte](), + key: "concurrent-test/memfile", + initCount: &initCount, + testData: testData, + } + + result, err := store.Get(t.Context(), diff) + require.NoError(t, err) + + // Read — blocks until the winning goroutine's Init completes. + buf := make([]byte, 256) + n, err := result.ReadBlock(t.Context(), buf, 0, nil) + require.NoError(t, err) + assert.Equal(t, 256, n) + assert.Equal(t, testData[:256], buf) + }) + } + + wg.Wait() + + // Init must have been called exactly once. 
+ assert.Equal(t, int32(1), initCount.Load()) +} + func flagsWithMaxBuildCachePercentage(tb testing.TB, maxBuildCachePercentage int) *featureflags.Client { tb.Helper() diff --git a/packages/orchestrator/internal/sandbox/build/diff.go b/packages/orchestrator/internal/sandbox/build/diff.go index 73891339b0..a60c59da58 100644 --- a/packages/orchestrator/internal/sandbox/build/diff.go +++ b/packages/orchestrator/internal/sandbox/build/diff.go @@ -26,11 +26,11 @@ const ( type Diff interface { io.Closer - storage.SeekableReader - block.Slicer + block.Reader CacheKey() DiffStoreKey CachePath() (string, error) FileSize() (int64, error) + BlockSize() int64 Init(ctx context.Context) error } @@ -42,7 +42,7 @@ func (n *NoDiff) CachePath() (string, error) { return "", NoDiffError{} } -func (n *NoDiff) Slice(_ context.Context, _, _ int64) ([]byte, error) { +func (n *NoDiff) GetBlock(_ context.Context, _, _ int64, _ *storage.FrameTable) ([]byte, error) { return nil, NoDiffError{} } @@ -50,7 +50,7 @@ func (n *NoDiff) Close() error { return nil } -func (n *NoDiff) ReadAt(_ context.Context, _ []byte, _ int64) (int, error) { +func (n *NoDiff) ReadBlock(_ context.Context, _ []byte, _ int64, _ *storage.FrameTable) (int, error) { return 0, NoDiffError{} } @@ -58,10 +58,6 @@ func (n *NoDiff) FileSize() (int64, error) { return 0, NoDiffError{} } -func (n *NoDiff) Size(_ context.Context) (int64, error) { - return 0, NoDiffError{} -} - func (n *NoDiff) CacheKey() DiffStoreKey { return "" } diff --git a/packages/orchestrator/internal/sandbox/build/local_diff.go b/packages/orchestrator/internal/sandbox/build/local_diff.go index 9936650986..f3718adc74 100644 --- a/packages/orchestrator/internal/sandbox/build/local_diff.go +++ b/packages/orchestrator/internal/sandbox/build/local_diff.go @@ -6,6 +6,7 @@ import ( "os" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) type LocalDiffFile struct { @@ -114,18 +115,14 @@ 
func (b *localDiff) Close() error { return b.cache.Close() } -func (b *localDiff) ReadAt(_ context.Context, p []byte, off int64) (int, error) { +func (b *localDiff) ReadBlock(_ context.Context, p []byte, off int64, _ *storage.FrameTable) (int, error) { return b.cache.ReadAt(p, off) } -func (b *localDiff) Slice(_ context.Context, off, length int64) ([]byte, error) { +func (b *localDiff) GetBlock(_ context.Context, off, length int64, _ *storage.FrameTable) ([]byte, error) { return b.cache.Slice(off, length) } -func (b *localDiff) Size(_ context.Context) (int64, error) { - return b.cache.Size() -} - func (b *localDiff) FileSize() (int64, error) { return b.cache.FileSize() } diff --git a/packages/orchestrator/internal/sandbox/build/storage_diff.go b/packages/orchestrator/internal/sandbox/build/storage_diff.go index 1b5e6756a4..1dbd8cfbb0 100644 --- a/packages/orchestrator/internal/sandbox/build/storage_diff.go +++ b/packages/orchestrator/internal/sandbox/build/storage_diff.go @@ -3,7 +3,8 @@ package build import ( "context" "fmt" - "io" + + "golang.org/x/sync/errgroup" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" blockmetrics "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block/metrics" @@ -17,11 +18,10 @@ func storagePath(buildId string, diffType DiffType) string { } type StorageDiff struct { - chunker *utils.SetOnce[block.Chunker] - cachePath string - cacheKey DiffStoreKey - storagePath string - storageObjectType storage.SeekableObjectType + chunker *utils.SetOnce[*block.Chunker] + cachePath string + cacheKey DiffStoreKey + storagePath string blockSize int64 metrics blockmetrics.Metrics @@ -49,35 +49,26 @@ func newStorageDiff( featureFlags *featureflags.Client, ) (*StorageDiff, error) { storagePath := storagePath(buildId, diffType) - storageObjectType, ok := storageObjectType(diffType) - if !ok { + if !isKnownDiffType(diffType) { return nil, UnknownDiffTypeError{diffType} } cachePath := GenerateDiffCachePath(basePath, 
buildId, diffType) return &StorageDiff{ - storagePath: storagePath, - storageObjectType: storageObjectType, - cachePath: cachePath, - chunker: utils.NewSetOnce[block.Chunker](), - blockSize: blockSize, - metrics: metrics, - persistence: persistence, - featureFlags: featureFlags, - cacheKey: GetDiffStoreKey(buildId, diffType), + storagePath: storagePath, + cachePath: cachePath, + chunker: utils.NewSetOnce[*block.Chunker](), + blockSize: blockSize, + metrics: metrics, + persistence: persistence, + featureFlags: featureFlags, + cacheKey: GetDiffStoreKey(buildId, diffType), }, nil } -func storageObjectType(diffType DiffType) (storage.SeekableObjectType, bool) { - switch diffType { - case Memfile: - return storage.MemfileObjectType, true - case Rootfs: - return storage.RootFSObjectType, true - default: - return storage.UnknownSeekableObjectType, false - } +func isKnownDiffType(diffType DiffType) bool { + return diffType == Memfile || diffType == Rootfs } func (b *StorageDiff) CacheKey() DiffStoreKey { @@ -85,28 +76,112 @@ func (b *StorageDiff) CacheKey() DiffStoreKey { } func (b *StorageDiff) Init(ctx context.Context) error { - obj, err := b.persistence.OpenSeekable(ctx, b.storagePath, b.storageObjectType) + chunker, err := b.createChunker(ctx) if err != nil { - return err - } - - size, err := obj.Size(ctx) - if err != nil { - errMsg := fmt.Errorf("failed to get object size: %w", err) + errMsg := fmt.Errorf("failed to create chunker: %w", err) b.chunker.SetError(errMsg) return errMsg } - c, err := block.NewChunker(ctx, b.featureFlags, size, b.blockSize, obj, b.cachePath, b.metrics) - if err != nil { - errMsg := fmt.Errorf("failed to create chunker: %w", err) - b.chunker.SetError(errMsg) + return b.chunker.SetValue(chunker) +} - return errMsg +// createChunker probes for available assets and creates a Chunker. 
+func (b *StorageDiff) createChunker(ctx context.Context) (*block.Chunker, error) { + assets := b.probeAssets(ctx) + if assets.Size == 0 { + return nil, fmt.Errorf("no asset found for %s (no uncompressed or compressed with metadata)", b.storagePath) } - return b.chunker.SetValue(c) + return block.NewChunker(assets, b.blockSize, b.cachePath, b.metrics, b.featureFlags) +} + +// probeAssets probes for uncompressed and compressed asset variants in parallel. +// For compressed objects, Size() returns the uncompressed size from metadata, +// allowing us to derive the mmap allocation size even when the uncompressed +// object doesn't exist. +func (b *StorageDiff) probeAssets(ctx context.Context) block.AssetInfo { + assets := block.AssetInfo{BasePath: b.storagePath} + + var ( + lz4UncompressedSize int64 + zstdUncompressedSize int64 + ) + + // Probe all 3 paths in parallel: uncompressed, v4.*.lz4, v4.*.zstd. + // Errors are swallowed (missing assets are expected). + eg, ctx := errgroup.WithContext(ctx) + + eg.Go(func() error { + obj, err := b.persistence.OpenFramedFile(ctx, b.storagePath) + if err != nil { + return nil //nolint:nilerr // missing asset is expected + } + + uncompressedSize, err := obj.Size(ctx) + if err != nil { + return nil //nolint:nilerr // missing asset is expected + } + + assets.Size = uncompressedSize + assets.HasUncompressed = true + assets.Uncompressed = obj + + return nil + }) + + eg.Go(func() error { + lz4Path := storage.V4DataPath(b.storagePath, storage.CompressionLZ4) + obj, err := b.persistence.OpenFramedFile(ctx, lz4Path) + if err != nil { + return nil //nolint:nilerr // missing asset is expected + } + + uncompressedSize, err := obj.Size(ctx) + if err != nil { + return nil //nolint:nilerr // missing asset is expected + } + + assets.HasLZ4 = true + assets.LZ4 = obj + lz4UncompressedSize = uncompressedSize + + return nil + }) + + eg.Go(func() error { + zstdPath := storage.V4DataPath(b.storagePath, storage.CompressionZstd) + obj, err := 
b.persistence.OpenFramedFile(ctx, zstdPath) + if err != nil { + return nil //nolint:nilerr // missing asset is expected + } + + uncompressedSize, err := obj.Size(ctx) + if err != nil { + return nil //nolint:nilerr // missing asset is expected + } + + assets.HasZstd = true + assets.Zstd = obj + zstdUncompressedSize = uncompressedSize + + return nil + }) + + _ = eg.Wait() + + // If no uncompressed object exists, derive the mmap allocation size + // from the compressed object's uncompressed-size metadata. + if assets.Size == 0 { + if lz4UncompressedSize > 0 { + assets.Size = lz4UncompressedSize + } else if zstdUncompressedSize > 0 { + assets.Size = zstdUncompressedSize + } + } + + return assets } func (b *StorageDiff) Close() error { @@ -118,31 +193,22 @@ func (b *StorageDiff) Close() error { return c.Close() } -func (b *StorageDiff) ReadAt(ctx context.Context, p []byte, off int64) (int, error) { - c, err := b.chunker.Wait() +func (b *StorageDiff) ReadBlock(ctx context.Context, p []byte, off int64, ft *storage.FrameTable) (int, error) { + chunker, err := b.chunker.Wait() if err != nil { return 0, err } - return c.ReadAt(ctx, p, off) + return chunker.ReadBlock(ctx, p, off, ft) } -func (b *StorageDiff) Slice(ctx context.Context, off, length int64) ([]byte, error) { - c, err := b.chunker.Wait() +func (b *StorageDiff) GetBlock(ctx context.Context, off, length int64, ft *storage.FrameTable) ([]byte, error) { + chunker, err := b.chunker.Wait() if err != nil { return nil, err } - return c.Slice(ctx, off, length) -} - -func (b *StorageDiff) WriteTo(ctx context.Context, w io.Writer) (int64, error) { - c, err := b.chunker.Wait() - if err != nil { - return 0, err - } - - return c.WriteTo(ctx, w) + return chunker.GetBlock(ctx, off, length, ft) } // The local file might not be synced. 
@@ -159,10 +225,6 @@ func (b *StorageDiff) FileSize() (int64, error) { return c.FileSize() } -func (b *StorageDiff) Size(_ context.Context) (int64, error) { - return b.FileSize() -} - func (b *StorageDiff) BlockSize() int64 { return b.blockSize } diff --git a/packages/orchestrator/internal/sandbox/nbd/dispatch.go b/packages/orchestrator/internal/sandbox/nbd/dispatch.go index 3a40e79c71..ad051e3f64 100644 --- a/packages/orchestrator/internal/sandbox/nbd/dispatch.go +++ b/packages/orchestrator/internal/sandbox/nbd/dispatch.go @@ -11,13 +11,13 @@ import ( "go.uber.org/zap" "github.com/e2b-dev/infra/packages/shared/pkg/logger" - "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) var ErrShuttingDown = errors.New("shutting down. Cannot serve any new requests") type Provider interface { - storage.SeekableReader + ReadAt(ctx context.Context, p []byte, off int64) (int, error) + Size(ctx context.Context) (int64, error) io.WriterAt } diff --git a/packages/orchestrator/internal/sandbox/nbd/testutils/template_rootfs.go b/packages/orchestrator/internal/sandbox/nbd/testutils/template_rootfs.go index c2145bc0e9..c767626a8f 100644 --- a/packages/orchestrator/internal/sandbox/nbd/testutils/template_rootfs.go +++ b/packages/orchestrator/internal/sandbox/nbd/testutils/template_rootfs.go @@ -30,7 +30,7 @@ func TemplateRootfs(ctx context.Context, buildID string) (*BuildDevice, *Cleaner return nil, &cleaner, fmt.Errorf("failed to get storage provider: %w", err) } - obj, err := s.OpenBlob(ctx, files.StorageRootfsHeaderPath(), storage.RootFSHeaderObjectType) + obj, err := s.OpenBlob(ctx, files.StorageRootfsHeaderPath()) if err != nil { return nil, &cleaner, fmt.Errorf("failed to open object: %w", err) } @@ -42,7 +42,7 @@ func TemplateRootfs(ctx context.Context, buildID string) (*BuildDevice, *Cleaner return nil, &cleaner, fmt.Errorf("failed to parse build id: %w", err) } - r, err := s.OpenSeekable(ctx, files.StorageRootfsPath(), storage.RootFSObjectType) + r, err := 
s.OpenFramedFile(ctx, files.StorageRootfsPath()) if err != nil { return nil, &cleaner, fmt.Errorf("failed to open object: %w", err) } @@ -112,7 +112,7 @@ func TemplateRootfs(ctx context.Context, buildID string) (*BuildDevice, *Cleaner } buildDevice := NewBuildDevice( - build.NewFile(h, store, build.Rootfs, s, m), + build.NewFile(h, store, build.Rootfs, s, m, flags), h, int64(h.Metadata.BlockSize), ) diff --git a/packages/orchestrator/internal/sandbox/pending_frame_tables.go b/packages/orchestrator/internal/sandbox/pending_frame_tables.go new file mode 100644 index 0000000000..ab9155a2b0 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/pending_frame_tables.go @@ -0,0 +1,59 @@ +package sandbox + +import ( + "fmt" + "sync" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +// PendingFrameTables collects FrameTables from compressed data uploads across +// all layers. After all data files are uploaded, the collected tables are applied +// to headers before the compressed headers are serialized and uploaded. 
+type PendingFrameTables struct { + tables sync.Map // key: "buildId/fileType", value: *storage.FrameTable +} + +func pendingFrameTableKey(buildID, fileType string) string { + return buildID + "/" + fileType +} + +func (p *PendingFrameTables) add(key string, ft *storage.FrameTable) { + if ft == nil { + return + } + + p.tables.Store(key, ft) +} + +func (p *PendingFrameTables) get(key string) *storage.FrameTable { + v, ok := p.tables.Load(key) + if !ok { + return nil + } + + return v.(*storage.FrameTable) +} + +func (p *PendingFrameTables) applyToHeader(h *header.Header, fileType string) error { + if h == nil { + return nil + } + + for _, mapping := range h.Mapping { + key := pendingFrameTableKey(mapping.BuildId.String(), fileType) + ft := p.get(key) + + if ft == nil { + continue + } + + if err := mapping.AddFrames(ft); err != nil { + return fmt.Errorf("apply frames to mapping at offset %#x for build %s: %w", + mapping.Offset, mapping.BuildId.String(), err) + } + } + + return nil +} diff --git a/packages/orchestrator/internal/sandbox/snapshot.go b/packages/orchestrator/internal/sandbox/snapshot.go index b4d7330fc7..478a8aa4ee 100644 --- a/packages/orchestrator/internal/sandbox/snapshot.go +++ b/packages/orchestrator/internal/sandbox/snapshot.go @@ -6,7 +6,6 @@ import ( "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/build" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/template" - "github.com/e2b-dev/infra/packages/shared/pkg/storage" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) @@ -21,59 +20,6 @@ type Snapshot struct { cleanup *Cleanup } -func (s *Snapshot) Upload( - ctx context.Context, - persistence storage.StorageProvider, - templateFiles storage.TemplateFiles, -) error { - var memfilePath *string - switch r := s.MemfileDiff.(type) { - case *build.NoDiff: - default: - memfileLocalPath, err := r.CachePath() - if err != nil { - return fmt.Errorf("error getting memfile diff path: %w", err) - } - - memfilePath = 
&memfileLocalPath - } - - var rootfsPath *string - switch r := s.RootfsDiff.(type) { - case *build.NoDiff: - default: - rootfsLocalPath, err := r.CachePath() - if err != nil { - return fmt.Errorf("error getting rootfs diff path: %w", err) - } - - rootfsPath = &rootfsLocalPath - } - - templateBuild := NewTemplateBuild( - s.MemfileDiffHeader, - s.RootfsDiffHeader, - persistence, - templateFiles, - ) - - uploadErrCh := templateBuild.Upload( - ctx, - s.Metafile.Path(), - s.Snapfile.Path(), - memfilePath, - rootfsPath, - ) - - // Wait for the upload to finish - uploadErr := <-uploadErrCh - if uploadErr != nil { - return fmt.Errorf("error uploading template build: %w", uploadErr) - } - - return nil -} - func (s *Snapshot) Close(ctx context.Context) error { err := s.cleanup.Run(ctx) if err != nil { diff --git a/packages/orchestrator/internal/sandbox/template/cache.go b/packages/orchestrator/internal/sandbox/template/cache.go index 24c9b9322c..bdaf06d056 100644 --- a/packages/orchestrator/internal/sandbox/template/cache.go +++ b/packages/orchestrator/internal/sandbox/template/cache.go @@ -140,7 +140,6 @@ func (c *Cache) GetTemplate( attribute.Bool("is_building", isBuilding), )) defer span.End() - persistence := c.persistence // Because of the template caching, if we enable the NFS cache feature flag, // it will start working only for new orchestrators or new builds. 
@@ -157,6 +156,7 @@ func (c *Cache) GetTemplate( buildID, nil, nil, + c.flags, persistence, c.blockMetrics, nil, @@ -196,6 +196,7 @@ func (c *Cache) AddSnapshot( buildId, memfileHeader, rootfsHeader, + c.flags, c.persistence, c.blockMetrics, localSnapfile, diff --git a/packages/orchestrator/internal/sandbox/template/storage.go b/packages/orchestrator/internal/sandbox/template/storage.go index 6fd722e87f..32dcec73c1 100644 --- a/packages/orchestrator/internal/sandbox/template/storage.go +++ b/packages/orchestrator/internal/sandbox/template/storage.go @@ -6,9 +6,11 @@ import ( "fmt" "github.com/google/uuid" + "golang.org/x/sync/errgroup" blockmetrics "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block/metrics" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/build" + featureflags "github.com/e2b-dev/infra/packages/shared/pkg/feature-flags" "github.com/e2b-dev/infra/packages/shared/pkg/storage" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) @@ -23,26 +25,77 @@ type Storage struct { source *build.File } -func storageHeaderObjectType(diffType build.DiffType) (storage.ObjectType, bool) { - switch diffType { - case build.Memfile: - return storage.MemfileHeaderObjectType, true - case build.Rootfs: - return storage.RootFSHeaderObjectType, true - default: - return storage.UnknownObjectType, false +func isKnownDiffType(diffType build.DiffType) bool { + return diffType == build.Memfile || diffType == build.Rootfs +} + +// loadHeaderV3 loads a v3 header from the standard (uncompressed) path. +// Returns (nil, nil) if not found. 
+func loadHeaderV3(ctx context.Context, persistence storage.StorageProvider, path string) (*header.Header, error) { + blob, err := persistence.OpenBlob(ctx, path) + if err != nil { + if errors.Is(err, storage.ErrObjectNotExist) { + return nil, nil + } + + return nil, err } + + return header.Deserialize(ctx, blob) } -func objectType(diffType build.DiffType) (storage.SeekableObjectType, bool) { - switch diffType { - case build.Memfile: - return storage.MemfileObjectType, true - case build.Rootfs: - return storage.RootFSObjectType, true - default: - return storage.UnknownSeekableObjectType, false +// loadV4Header loads a v4 header (LZ4 compressed), decompresses, and deserializes it. +// Returns (nil, nil) if not found. +func loadV4Header(ctx context.Context, persistence storage.StorageProvider, path string) (*header.Header, error) { + data, err := storage.LoadBlob(ctx, persistence, path) + if err != nil { + if errors.Is(err, storage.ErrObjectNotExist) { + return nil, nil + } + + return nil, err } + + return header.DeserializeV4(data) +} + +// loadHeaderPreferV4 fetches both v3 and v4 headers in parallel, +// preferring the v4 (compressed) header if available. 
+func loadHeaderPreferV4(ctx context.Context, persistence storage.StorageProvider, buildId string, fileType build.DiffType) (*header.Header, error) { + files := storage.TemplateFiles{BuildID: buildId} + v3Path := files.HeaderPath(string(fileType)) + v4Path := files.CompressedHeaderPath(string(fileType)) + + var v3Header, v4Header *header.Header + var v3Err, v4Err error + + eg, egCtx := errgroup.WithContext(ctx) + eg.Go(func() error { + v3Header, v3Err = loadHeaderV3(egCtx, persistence, v3Path) + + return nil + }) + eg.Go(func() error { + v4Header, v4Err = loadV4Header(egCtx, persistence, v4Path) + + return nil + }) + _ = eg.Wait() + + if v4Err == nil && v4Header != nil { + return v4Header, nil + } + if v3Err == nil && v3Header != nil { + return v3Header, nil + } + if v4Err != nil { + return nil, v4Err + } + if v3Err != nil { + return nil, v3Err + } + + return nil, nil } func NewStorage( @@ -51,41 +104,38 @@ func NewStorage( buildId string, fileType build.DiffType, h *header.Header, + flags *featureflags.Client, persistence storage.StorageProvider, metrics blockmetrics.Metrics, ) (*Storage, error) { + // Read chunker config from feature flag. 
+ chunkerCfg := flags.JSONFlag(ctx, featureflags.ChunkerConfigFlag).AsValueMap() + useCompressedAssets := chunkerCfg.Get("useCompressedAssets").BoolValue() + if h == nil { - headerObjectPath := buildId + "/" + string(fileType) + storage.HeaderSuffix - headerObjectType, ok := storageHeaderObjectType(fileType) - if !ok { + if !isKnownDiffType(fileType) { return nil, build.UnknownDiffTypeError{DiffType: fileType} } - headerObject, err := persistence.OpenBlob(ctx, headerObjectPath, headerObjectType) + var err error + if useCompressedAssets { + h, err = loadHeaderPreferV4(ctx, persistence, buildId, fileType) + } else { + files := storage.TemplateFiles{BuildID: buildId} + h, err = loadHeaderV3(ctx, persistence, files.HeaderPath(string(fileType))) + } if err != nil { return nil, err } - - diffHeader, err := header.Deserialize(ctx, headerObject) - - // If we can't find the diff header in storage, we switch to templates without a headers - if err != nil && !errors.Is(err, storage.ErrObjectNotExist) { - return nil, fmt.Errorf("failed to deserialize header: %w", err) - } - - if err == nil { - h = diffHeader - } } // If we can't find the diff header in storage, we try to find the "old" style template without a header as a fallback. 
if h == nil { objectPath := buildId + "/" + string(fileType) - objectType, ok := objectType(fileType) - if !ok { + if !isKnownDiffType(fileType) { return nil, build.UnknownDiffTypeError{DiffType: fileType} } - object, err := persistence.OpenSeekable(ctx, objectPath, objectType) + object, err := persistence.OpenFramedFile(ctx, objectPath) if err != nil { return nil, err } @@ -126,7 +176,7 @@ func NewStorage( } } - b := build.NewFile(h, store, fileType, persistence, metrics) + b := build.NewFile(h, store, fileType, persistence, metrics, flags) return &Storage{ source: b, diff --git a/packages/orchestrator/internal/sandbox/template/storage_file.go b/packages/orchestrator/internal/sandbox/template/storage_file.go index 52eed020f1..fd3256c8b3 100644 --- a/packages/orchestrator/internal/sandbox/template/storage_file.go +++ b/packages/orchestrator/internal/sandbox/template/storage_file.go @@ -18,7 +18,6 @@ func newStorageFile( persistence storage.StorageProvider, objectPath string, path string, - objectType storage.ObjectType, ) (*storageFile, error) { f, err := os.Create(path) if err != nil { @@ -27,7 +26,7 @@ func newStorageFile( defer f.Close() - object, err := persistence.OpenBlob(ctx, objectPath, objectType) + object, err := persistence.OpenBlob(ctx, objectPath) if err != nil { return nil, err } diff --git a/packages/orchestrator/internal/sandbox/template/storage_template.go b/packages/orchestrator/internal/sandbox/template/storage_template.go index b967fc6e28..01f7733518 100644 --- a/packages/orchestrator/internal/sandbox/template/storage_template.go +++ b/packages/orchestrator/internal/sandbox/template/storage_template.go @@ -15,6 +15,7 @@ import ( blockmetrics "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block/metrics" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/build" "github.com/e2b-dev/infra/packages/orchestrator/internal/template/metadata" + featureflags "github.com/e2b-dev/infra/packages/shared/pkg/feature-flags" 
"github.com/e2b-dev/infra/packages/shared/pkg/logger" "github.com/e2b-dev/infra/packages/shared/pkg/storage" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" @@ -35,6 +36,7 @@ type storageTemplate struct { localSnapfile File localMetafile File + flags *featureflags.Client metrics blockmetrics.Metrics persistence storage.StorageProvider } @@ -44,6 +46,7 @@ func newTemplateFromStorage( buildId string, memfileHeader *header.Header, rootfsHeader *header.Header, + flags *featureflags.Client, persistence storage.StorageProvider, metrics blockmetrics.Metrics, localSnapfile File, @@ -62,6 +65,7 @@ func newTemplateFromStorage( localMetafile: localMetafile, memfileHeader: memfileHeader, rootfsHeader: rootfsHeader, + flags: flags, metrics: metrics, persistence: persistence, memfile: utils.NewSetOnce[block.ReadonlyDevice](), @@ -76,7 +80,6 @@ func (t *storageTemplate) Fetch(ctx context.Context, buildStore *build.DiffStore telemetry.WithBuildID(t.files.BuildID), )) defer span.End() - var wg errgroup.Group wg.Go(func() error { @@ -93,7 +96,6 @@ func (t *storageTemplate) Fetch(ctx context.Context, buildStore *build.DiffStore t.persistence, t.files.StorageSnapfilePath(), t.files.CacheSnapfilePath(), - storage.SnapfileObjectType, ) if snapfileErr != nil { errMsg := fmt.Errorf("failed to fetch snapfile: %w", snapfileErr) @@ -126,7 +128,6 @@ func (t *storageTemplate) Fetch(ctx context.Context, buildStore *build.DiffStore t.persistence, t.files.StorageMetadataPath(), t.files.CacheMetadataPath(), - storage.MetadataObjectType, ) if err != nil && !errors.Is(err, storage.ErrObjectNotExist) { sourceErr := fmt.Errorf("failed to fetch metafile: %w", err) @@ -178,10 +179,10 @@ func (t *storageTemplate) Fetch(ctx context.Context, buildStore *build.DiffStore t.files.BuildID, build.Memfile, t.memfileHeader, + t.flags, t.persistence, t.metrics, ) - if memfileErr != nil { errMsg := fmt.Errorf("failed to create memfile storage: %w", memfileErr) @@ -206,6 +207,7 @@ func (t 
*storageTemplate) Fetch(ctx context.Context, buildStore *build.DiffStore t.files.BuildID, build.Rootfs, t.rootfsHeader, + t.flags, t.persistence, t.metrics, ) diff --git a/packages/orchestrator/internal/sandbox/template_build.go b/packages/orchestrator/internal/sandbox/template_build.go index 323c26e068..2d6068fc35 100644 --- a/packages/orchestrator/internal/sandbox/template_build.go +++ b/packages/orchestrator/internal/sandbox/template_build.go @@ -8,6 +8,8 @@ import ( "golang.org/x/sync/errgroup" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/build" + featureflags "github.com/e2b-dev/infra/packages/shared/pkg/feature-flags" "github.com/e2b-dev/infra/packages/shared/pkg/storage" headers "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) @@ -15,19 +17,63 @@ import ( type TemplateBuild struct { files storage.TemplateFiles persistence storage.StorageProvider + ff *featureflags.Client memfileHeader *headers.Header rootfsHeader *headers.Header + + memfilePath *string + rootfsPath *string + metadataPath string + snapfilePath string + + pending *PendingFrameTables } -func NewTemplateBuild(memfileHeader *headers.Header, rootfsHeader *headers.Header, persistence storage.StorageProvider, files storage.TemplateFiles) *TemplateBuild { +func NewTemplateBuild(snapshot *Snapshot, persistence storage.StorageProvider, files storage.TemplateFiles, ff *featureflags.Client, pending *PendingFrameTables) (*TemplateBuild, error) { + var memfilePath *string + switch r := snapshot.MemfileDiff.(type) { + case *build.NoDiff: + default: + p, err := r.CachePath() + if err != nil { + return nil, fmt.Errorf("error getting memfile diff path: %w", err) + } + + memfilePath = &p + } + + var rootfsPath *string + switch r := snapshot.RootfsDiff.(type) { + case *build.NoDiff: + default: + p, err := r.CachePath() + if err != nil { + return nil, fmt.Errorf("error getting rootfs diff path: %w", err) + } + + rootfsPath = &p + } + + if pending == nil { + pending = 
&PendingFrameTables{} + } + return &TemplateBuild{ persistence: persistence, files: files, + ff: ff, - memfileHeader: memfileHeader, - rootfsHeader: rootfsHeader, - } + memfileHeader: snapshot.MemfileDiffHeader, + rootfsHeader: snapshot.RootfsDiffHeader, + + memfilePath: memfilePath, + rootfsPath: rootfsPath, + metadataPath: snapshot.Metafile.Path(), + snapfilePath: snapshot.Snapfile.Path(), + + pending: pending, + }, nil } func (t *TemplateBuild) Remove(ctx context.Context) error { @@ -39,8 +85,8 @@ func (t *TemplateBuild) Remove(ctx context.Context) error { return nil } -func (t *TemplateBuild) uploadMemfileHeader(ctx context.Context, h *headers.Header) error { - object, err := t.persistence.OpenBlob(ctx, t.files.StorageMemfileHeaderPath(), storage.MemfileHeaderObjectType) +func (t *TemplateBuild) uploadMemfileHeaderV3(ctx context.Context, h *headers.Header) error { + object, err := t.persistence.OpenBlob(ctx, t.files.StorageMemfileHeaderPath()) if err != nil { return err } @@ -59,21 +105,20 @@ func (t *TemplateBuild) uploadMemfileHeader(ctx context.Context, h *headers.Head } func (t *TemplateBuild) uploadMemfile(ctx context.Context, memfilePath string) error { - object, err := t.persistence.OpenSeekable(ctx, t.files.StorageMemfilePath(), storage.MemfileObjectType) + object, err := t.persistence.OpenFramedFile(ctx, t.files.StorageMemfilePath()) if err != nil { return err } - err = object.StoreFile(ctx, memfilePath) - if err != nil { + if _, err := object.StoreFile(ctx, memfilePath, nil); err != nil { return fmt.Errorf("error when uploading memfile: %w", err) } return nil } -func (t *TemplateBuild) uploadRootfsHeader(ctx context.Context, h *headers.Header) error { - object, err := t.persistence.OpenBlob(ctx, t.files.StorageRootfsHeaderPath(), storage.RootFSHeaderObjectType) +func (t *TemplateBuild) uploadRootfsHeaderV3(ctx context.Context, h *headers.Header) error { + object, err := t.persistence.OpenBlob(ctx, t.files.StorageRootfsHeaderPath()) if err != nil { 
return err } @@ -92,13 +137,12 @@ func (t *TemplateBuild) uploadRootfsHeader(ctx context.Context, h *headers.Heade } func (t *TemplateBuild) uploadRootfs(ctx context.Context, rootfsPath string) error { - object, err := t.persistence.OpenSeekable(ctx, t.files.StorageRootfsPath(), storage.RootFSObjectType) + object, err := t.persistence.OpenFramedFile(ctx, t.files.StorageRootfsPath()) if err != nil { return err } - err = object.StoreFile(ctx, rootfsPath) - if err != nil { + if _, err := object.StoreFile(ctx, rootfsPath, nil); err != nil { return fmt.Errorf("error when uploading rootfs: %w", err) } @@ -107,7 +151,7 @@ func (t *TemplateBuild) uploadRootfs(ctx context.Context, rootfsPath string) err // Snap-file is small enough so we don't use composite upload. func (t *TemplateBuild) uploadSnapfile(ctx context.Context, path string) error { - object, err := t.persistence.OpenBlob(ctx, t.files.StorageSnapfilePath(), storage.SnapfileObjectType) + object, err := t.persistence.OpenBlob(ctx, t.files.StorageSnapfilePath()) if err != nil { return err } @@ -121,7 +165,7 @@ func (t *TemplateBuild) uploadSnapfile(ctx context.Context, path string) error { // Metadata is small enough so we don't use composite upload. func (t *TemplateBuild) uploadMetadata(ctx context.Context, path string) error { - object, err := t.persistence.OpenBlob(ctx, t.files.StorageMetadataPath(), storage.MetadataObjectType) + object, err := t.persistence.OpenBlob(ctx, t.files.StorageMetadataPath()) if err != nil { return err } @@ -153,78 +197,187 @@ func uploadFileAsBlob(ctx context.Context, b storage.Blob, path string) error { return nil } -func (t *TemplateBuild) Upload(ctx context.Context, metadataPath string, fcSnapfilePath string, memfilePath *string, rootfsPath *string) chan error { +// UploadExceptV4Headers uploads all template build files except compressed (V4) headers. +// This includes: V3 headers, uncompressed data, compressed data (when enabled via +// feature flag), snapfile, and metadata. 
Frame tables from compressed uploads are +// registered in the shared PendingFrameTables for later use by UploadV4Header. +// Returns true if compression was enabled (i.e. V4 headers need uploading). +func (t *TemplateBuild) UploadExceptV4Headers(ctx context.Context) (hasCompressed bool, err error) { + compressOpts := storage.GetUploadOptions(ctx, t.ff) eg, ctx := errgroup.WithContext(ctx) + buildID := t.files.BuildID + // Uncompressed headers (always) eg.Go(func() error { if t.rootfsHeader == nil { return nil } - err := t.uploadRootfsHeader(ctx, t.rootfsHeader) - if err != nil { - return err - } - - return nil + return t.uploadRootfsHeaderV3(ctx, t.rootfsHeader) }) eg.Go(func() error { - if rootfsPath == nil { + if t.memfileHeader == nil { return nil } - err := t.uploadRootfs(ctx, *rootfsPath) - if err != nil { - return err - } - - return nil + return t.uploadMemfileHeaderV3(ctx, t.memfileHeader) }) + // Uncompressed data (always, for rollback safety) eg.Go(func() error { - if t.memfileHeader == nil { + if t.rootfsPath == nil { return nil } - err := t.uploadMemfileHeader(ctx, t.memfileHeader) - if err != nil { - return err - } - - return nil + return t.uploadRootfs(ctx, *t.rootfsPath) }) eg.Go(func() error { - if memfilePath == nil { + if t.memfilePath == nil { return nil } - err := t.uploadMemfile(ctx, *memfilePath) - if err != nil { - return err + return t.uploadMemfile(ctx, *t.memfilePath) + }) + + // Compressed data (when enabled) + if compressOpts != nil { + if t.memfilePath != nil { + hasCompressed = true + + eg.Go(func() error { + ft, err := t.uploadCompressed(ctx, *t.memfilePath, storage.MemfileName, compressOpts) + if err != nil { + return fmt.Errorf("compressed memfile upload: %w", err) + } + + t.pending.add(pendingFrameTableKey(buildID, storage.MemfileName), ft) + + return nil + }) } - return nil - }) + if t.rootfsPath != nil { + hasCompressed = true - eg.Go(func() error { - if err := t.uploadSnapfile(ctx, fcSnapfilePath); err != nil { - return 
fmt.Errorf("error when uploading snapfile: %w", err) + eg.Go(func() error { + ft, err := t.uploadCompressed(ctx, *t.rootfsPath, storage.RootfsName, compressOpts) + if err != nil { + return fmt.Errorf("compressed rootfs upload: %w", err) + } + + t.pending.add(pendingFrameTableKey(buildID, storage.RootfsName), ft) + + return nil + }) } + } - return nil + // Snapfile + metadata + eg.Go(func() error { + return t.uploadSnapfile(ctx, t.snapfilePath) }) eg.Go(func() error { - return t.uploadMetadata(ctx, metadataPath) + return t.uploadMetadata(ctx, t.metadataPath) }) - done := make(chan error) + if err := eg.Wait(); err != nil { + return false, err + } + + return hasCompressed, nil +} + +// uploadCompressed compresses and uploads a file to the compressed data path. +func (t *TemplateBuild) uploadCompressed(ctx context.Context, localPath, fileName string, opts *storage.FramedUploadOptions) (*storage.FrameTable, error) { + objectPath := t.files.CompressedDataPath(fileName, opts.CompressionType) + + object, err := t.persistence.OpenFramedFile(ctx, objectPath) + if err != nil { + return nil, fmt.Errorf("error opening framed file for %s: %w", objectPath, err) + } + + ft, err := object.StoreFile(ctx, localPath, opts) + if err != nil { + return nil, fmt.Errorf("error compressing %s to %s: %w", fileName, objectPath, err) + } + + return ft, nil +} + +// serializeAndUploadHeader serializes a header as v4 compressed format, LZ4-compresses it, +// and uploads to the compressed header path. 
+func (t *TemplateBuild) serializeAndUploadHeader(ctx context.Context, h *headers.Header, fileType string) error { + meta := *h.Metadata + meta.Version = headers.MetadataVersionCompressed + + serialized, err := headers.Serialize(&meta, h.Mapping) + if err != nil { + return fmt.Errorf("serialize compressed %s header: %w", fileType, err) + } - go func() { - done <- eg.Wait() - }() + compressed, err := storage.CompressLZ4(serialized) + if err != nil { + return fmt.Errorf("compress %s header: %w", fileType, err) + } + + objectPath := t.files.CompressedHeaderPath(fileType) + blob, err := t.persistence.OpenBlob(ctx, objectPath) + if err != nil { + return fmt.Errorf("open blob for compressed %s header: %w", fileType, err) + } + + if err := blob.Put(ctx, compressed); err != nil { + return fmt.Errorf("upload compressed %s header: %w", fileType, err) + } - return done + return nil +} + +// UploadV4Header applies pending frame tables to headers and uploads them as V4 compressed format. +// Frame tables must have been registered by a prior UploadExceptV4Headers call. +func (t *TemplateBuild) UploadV4Header(ctx context.Context) error { + eg, ctx := errgroup.WithContext(ctx) + + if t.memfileHeader != nil { + eg.Go(func() error { + if err := t.pending.applyToHeader(t.memfileHeader, storage.MemfileName); err != nil { + return fmt.Errorf("apply frames to memfile header: %w", err) + } + + return t.serializeAndUploadHeader(ctx, t.memfileHeader, storage.MemfileName) + }) + } + + if t.rootfsHeader != nil { + eg.Go(func() error { + if err := t.pending.applyToHeader(t.rootfsHeader, storage.RootfsName); err != nil { + return fmt.Errorf("apply frames to rootfs header: %w", err) + } + + return t.serializeAndUploadHeader(ctx, t.rootfsHeader, storage.RootfsName) + }) + } + + return eg.Wait() +} + +// UploadAll uploads all template build files including V4 headers for a single-layer build. 
+// For multi-layer builds, use UploadExceptV4Headers + UploadV4Header with a shared +// PendingFrameTables instead. +func (t *TemplateBuild) UploadAll(ctx context.Context) error { + hasCompressed, err := t.UploadExceptV4Headers(ctx) + if err != nil { + return err + } + + if hasCompressed { + if err := t.UploadV4Header(ctx); err != nil { + return fmt.Errorf("error uploading compressed headers: %w", err) + } + } + + return nil } diff --git a/packages/orchestrator/internal/server/sandboxes.go b/packages/orchestrator/internal/server/sandboxes.go index e2949395e0..9b1bf78ce5 100644 --- a/packages/orchestrator/internal/server/sandboxes.go +++ b/packages/orchestrator/internal/server/sandboxes.go @@ -55,7 +55,6 @@ func (s *Server) Create(ctx context.Context, req *orchestrator.SandboxCreateRequ // set up tracing ctx, childSpan := tracer.Start(ctx, "sandbox-create") defer childSpan.End() - childSpan.SetAttributes( telemetry.WithTemplateID(req.GetSandbox().GetTemplateId()), attribute.String("kernel.version", req.GetSandbox().GetKernelVersion()), @@ -112,7 +111,6 @@ func (s *Server) Create(ctx context.Context, req *orchestrator.SandboxCreateRequ if err != nil { return nil, fmt.Errorf("failed to get template snapshot data: %w", err) } - // Clone the network config to avoid modifying the original request network := proto.CloneOf(req.GetSandbox().GetNetwork()) @@ -602,12 +600,16 @@ func (s *Server) snapshotAndCacheSandbox( telemetry.ReportEvent(ctx, "added snapshot to template cache") // Start upload in background, return a wait function + tb, err := sandbox.NewTemplateBuild(snapshot, s.persistence, storage.TemplateFiles{BuildID: meta.Template.BuildID}, s.featureFlags, nil) + if err != nil { + return metadata.Template{}, nil, fmt.Errorf("error creating template build: %w", err) + } + uploadCtx := context.WithoutCancel(ctx) errCh := make(chan error, 1) go func() { - err := snapshot.Upload(uploadCtx, s.persistence, storage.TemplateFiles{BuildID: meta.Template.BuildID}) - if err != 
nil { + if err := tb.UploadAll(uploadCtx); err != nil { sbxlogger.I(sbx).Error(uploadCtx, "error uploading snapshot", zap.Error(err)) errCh <- err diff --git a/packages/orchestrator/internal/template/build/builder.go b/packages/orchestrator/internal/template/build/builder.go index cd7fec3e4c..886921edcf 100644 --- a/packages/orchestrator/internal/template/build/builder.go +++ b/packages/orchestrator/internal/template/build/builder.go @@ -269,6 +269,7 @@ func runBuild( builder.buildStorage, index, uploadTracker, + builder.featureFlags, ) baseBuilder := base.New( @@ -404,7 +405,7 @@ func getRootfsSize( s storage.StorageProvider, metadata storage.TemplateFiles, ) (uint64, error) { - obj, err := s.OpenBlob(ctx, metadata.StorageRootfsHeaderPath(), storage.RootFSHeaderObjectType) + obj, err := s.OpenBlob(ctx, metadata.StorageRootfsHeaderPath()) if err != nil { return 0, fmt.Errorf("error opening rootfs header object: %w", err) } diff --git a/packages/orchestrator/internal/template/build/commands/copy.go b/packages/orchestrator/internal/template/build/commands/copy.go index f8fdca1111..70131b55db 100644 --- a/packages/orchestrator/internal/template/build/commands/copy.go +++ b/packages/orchestrator/internal/template/build/commands/copy.go @@ -80,7 +80,7 @@ func (c *Copy) Execute( } // 1) Download the layer tar file from the storage to the local filesystem - obj, err := c.FilesStorage.OpenBlob(ctx, paths.GetLayerFilesCachePath(c.CacheScope, step.GetFilesHash()), storage.BuildLayerFileObjectType) + obj, err := c.FilesStorage.OpenBlob(ctx, paths.GetLayerFilesCachePath(c.CacheScope, step.GetFilesHash())) if err != nil { return metadata.Context{}, fmt.Errorf("failed to open files object from storage: %w", err) } diff --git a/packages/orchestrator/internal/template/build/layer/layer_executor.go b/packages/orchestrator/internal/template/build/layer/layer_executor.go index 23466dddee..cb95c7f24d 100644 --- a/packages/orchestrator/internal/template/build/layer/layer_executor.go 
+++ b/packages/orchestrator/internal/template/build/layer/layer_executor.go @@ -16,6 +16,7 @@ import ( "github.com/e2b-dev/infra/packages/orchestrator/internal/template/build/sandboxtools" "github.com/e2b-dev/infra/packages/orchestrator/internal/template/build/storage/cache" "github.com/e2b-dev/infra/packages/orchestrator/internal/template/metadata" + featureflags "github.com/e2b-dev/infra/packages/shared/pkg/feature-flags" "github.com/e2b-dev/infra/packages/shared/pkg/logger" "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) @@ -34,6 +35,7 @@ type LayerExecutor struct { buildStorage storage.StorageProvider index cache.Index uploadTracker *UploadTracker + featureFlags *featureflags.Client } func NewLayerExecutor( @@ -46,6 +48,7 @@ func NewLayerExecutor( buildStorage storage.StorageProvider, index cache.Index, uploadTracker *UploadTracker, + featureFlags *featureflags.Client, ) *LayerExecutor { return &LayerExecutor{ BuildContext: buildContext, @@ -59,6 +62,7 @@ func NewLayerExecutor( buildStorage: buildStorage, index: index, uploadTracker: uploadTracker, + featureFlags: featureFlags, } } @@ -280,46 +284,57 @@ func (lb *LayerExecutor) PauseAndUpload( // Upload snapshot async, it's added to the template cache immediately userLogger.Debug(ctx, fmt.Sprintf("Saving: %s", meta.Template.BuildID)) - // Register this upload and get functions to signal completion and wait for previous uploads + // Pipeline per layer: + // 1. Upload all files (uncompressed + compressed, except the V4 headers) — parallel across layers + // 2. Wait for previous layers to complete (data + headers) + // 3. Finalize compressed headers — all upstream FTs now available + // 4. 
Signal complete, save cache index completeUpload, waitForPreviousUploads := lb.uploadTracker.StartUpload() + buildID := meta.Template.BuildID + + tb, err := sandbox.NewTemplateBuild(snapshot, lb.templateStorage, storage.TemplateFiles{BuildID: buildID}, lb.featureFlags, lb.uploadTracker.Pending()) + if err != nil { + completeUpload() + + return fmt.Errorf("error creating template build: %w", err) + } lb.UploadErrGroup.Go(func() error { ctx := context.WithoutCancel(ctx) - ctx, span := tracer.Start(ctx, "upload snapshot") + ctx, span := tracer.Start(ctx, "upload layer") defer span.End() - // Always signal completion to unblock waiting goroutines, even on error. - // This prevents deadlocks when an earlier layer fails - later layers can - // still unblock and the errgroup can properly collect all errors. + // Signal completion when done (including on error) to unblock downstream layers. defer completeUpload() - err := snapshot.Upload( - ctx, - lb.templateStorage, - storage.TemplateFiles{BuildID: meta.Template.BuildID}, - ) + // Step 1: Upload everything except V4 headers (parallel across layers) + hasCompressed, err := tb.UploadExceptV4Headers(ctx) if err != nil { - return fmt.Errorf("error uploading snapshot: %w", err) + return fmt.Errorf("error uploading data files: %w", err) } - // Wait for all previous layer uploads to complete before saving the cache entry. - // This prevents race conditions where another build hits this cache entry - // before its dependencies (previous layers) are available in storage. 
- err = waitForPreviousUploads(ctx) - if err != nil { + // Step 2: Wait for all previous layers (data + headers) to complete + if err := waitForPreviousUploads(ctx); err != nil { return fmt.Errorf("error waiting for previous uploads: %w", err) } - err = lb.index.SaveLayerMeta(ctx, hash, cache.LayerMetadata{ + // Step 3: Finalize V4 compressed headers — all upstream FTs are now in pending + if hasCompressed { + if err := tb.UploadV4Header(ctx); err != nil { + return fmt.Errorf("error uploading compressed headers: %w", err) + } + } + + // Step 4: Save cache index + if err := lb.index.SaveLayerMeta(ctx, hash, cache.LayerMetadata{ Template: cache.Template{ - BuildID: meta.Template.BuildID, + BuildID: buildID, }, - }) - if err != nil { + }); err != nil { return fmt.Errorf("error saving UUID to hash mapping: %w", err) } - userLogger.Debug(ctx, fmt.Sprintf("Saved: %s", meta.Template.BuildID)) + userLogger.Debug(ctx, fmt.Sprintf("Saved: %s", buildID)) return nil }) diff --git a/packages/orchestrator/internal/template/build/layer/upload_tracker.go b/packages/orchestrator/internal/template/build/layer/upload_tracker.go index 213938f147..6105153818 100644 --- a/packages/orchestrator/internal/template/build/layer/upload_tracker.go +++ b/packages/orchestrator/internal/template/build/layer/upload_tracker.go @@ -3,22 +3,36 @@ package layer import ( "context" "sync" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox" ) -// UploadTracker tracks in-flight uploads and allows waiting for all previous uploads to complete. -// This prevents race conditions where a layer's cache entry is saved before its -// dependencies (previous layers) are fully uploaded. +// UploadTracker tracks in-flight layer uploads and provides ordering guarantees. +// +// Each layer's upload proceeds as: data files → wait for previous → compressed headers → save cache. 
+// waitForPreviousUploads ensures that by the time layer N finalizes its compressed headers, +// all upstream layers (0..N-1) have completed both their data uploads and header uploads, +// so all upstream frame tables are available in the shared PendingFrameTables. type UploadTracker struct { mu sync.Mutex waitChs []chan struct{} + + // pending collects frame tables from compressed uploads across all layers. + pending *sandbox.PendingFrameTables } func NewUploadTracker() *UploadTracker { return &UploadTracker{ waitChs: make([]chan struct{}, 0), + pending: &sandbox.PendingFrameTables{}, } } +// Pending returns the shared PendingFrameTables for collecting frame tables. +func (t *UploadTracker) Pending() *sandbox.PendingFrameTables { + return t.pending +} + // StartUpload registers that a new upload has started. // Returns a function that should be called when the upload completes. func (t *UploadTracker) StartUpload() (complete func(), waitForPrevious func(context.Context) error) { diff --git a/packages/orchestrator/internal/template/build/storage/cache/cache.go b/packages/orchestrator/internal/template/build/storage/cache/cache.go index b0ac924073..695a3a6ce1 100644 --- a/packages/orchestrator/internal/template/build/storage/cache/cache.go +++ b/packages/orchestrator/internal/template/build/storage/cache/cache.go @@ -62,14 +62,9 @@ func (h *HashIndex) LayerMetaFromHash(ctx context.Context, hash string) (LayerMe ctx, span := tracer.Start(ctx, "get layer_metadata") defer span.End() - obj, err := h.indexStorage.OpenBlob(ctx, paths.HashToPath(h.cacheScope, hash), storage.LayerMetadataObjectType) + data, err := storage.LoadBlob(ctx, h.indexStorage, paths.HashToPath(h.cacheScope, hash)) if err != nil { - return LayerMetadata{}, fmt.Errorf("error opening object for layer metadata: %w", err) - } - - data, err := storage.GetBlob(ctx, obj) - if err != nil { - return LayerMetadata{}, fmt.Errorf("error reading layer metadata from object: %w", err) + return LayerMetadata{}, 
fmt.Errorf("error reading layer metadata: %w", err) } var layerMetadata LayerMetadata @@ -89,7 +84,7 @@ func (h *HashIndex) SaveLayerMeta(ctx context.Context, hash string, template Lay ctx, span := tracer.Start(ctx, "save layer_metadata") defer span.End() - obj, err := h.indexStorage.OpenBlob(ctx, paths.HashToPath(h.cacheScope, hash), storage.LayerMetadataObjectType) + obj, err := h.indexStorage.OpenBlob(ctx, paths.HashToPath(h.cacheScope, hash)) if err != nil { return fmt.Errorf("error creating object for saving UUID: %w", err) } diff --git a/packages/orchestrator/internal/template/metadata/prefetch.go b/packages/orchestrator/internal/template/metadata/prefetch.go index 76229773ba..ef450fa4d7 100644 --- a/packages/orchestrator/internal/template/metadata/prefetch.go +++ b/packages/orchestrator/internal/template/metadata/prefetch.go @@ -51,7 +51,7 @@ func UploadMetadata(ctx context.Context, persistence storage.StorageProvider, t templateFiles := storage.TemplateFiles{BuildID: t.Template.BuildID} metadataPath := templateFiles.StorageMetadataPath() - object, err := persistence.OpenBlob(ctx, metadataPath, storage.MetadataObjectType) + object, err := persistence.OpenBlob(ctx, metadataPath) if err != nil { return fmt.Errorf("failed to open metadata object: %w", err) } diff --git a/packages/orchestrator/internal/template/metadata/template_metadata.go b/packages/orchestrator/internal/template/metadata/template_metadata.go index dcac79c075..e4f24a444f 100644 --- a/packages/orchestrator/internal/template/metadata/template_metadata.go +++ b/packages/orchestrator/internal/template/metadata/template_metadata.go @@ -204,14 +204,9 @@ func fromTemplate(ctx context.Context, s storage.StorageProvider, files storage. 
ctx, span := tracer.Start(ctx, "from template") defer span.End() - obj, err := s.OpenBlob(ctx, files.StorageMetadataPath(), storage.MetadataObjectType) + data, err := storage.LoadBlob(ctx, s, files.StorageMetadataPath()) if err != nil { - return Template{}, fmt.Errorf("error opening object for template metadata: %w", err) - } - - data, err := storage.GetBlob(ctx, obj) - if err != nil { - return Template{}, fmt.Errorf("error reading template metadata from object: %w", err) + return Template{}, fmt.Errorf("error reading template metadata: %w", err) } templateMetadata, err := deserialize(bytes.NewReader(data)) diff --git a/packages/orchestrator/internal/template/server/upload_layer_files_template.go b/packages/orchestrator/internal/template/server/upload_layer_files_template.go index 0830934740..fdd8f1a2e3 100644 --- a/packages/orchestrator/internal/template/server/upload_layer_files_template.go +++ b/packages/orchestrator/internal/template/server/upload_layer_files_template.go @@ -7,7 +7,6 @@ import ( "github.com/e2b-dev/infra/packages/orchestrator/internal/template/build/storage/paths" templatemanager "github.com/e2b-dev/infra/packages/shared/pkg/grpc/template-manager" - "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) const signedUrlExpiration = time.Minute * 30 @@ -23,7 +22,7 @@ func (s *ServerStore) InitLayerFileUpload(ctx context.Context, in *templatemanag } path := paths.GetLayerFilesCachePath(cacheScope, in.GetHash()) - obj, err := s.buildStorage.OpenBlob(ctx, path, storage.BuildLayerFileObjectType) + obj, err := s.buildStorage.OpenBlob(ctx, path) if err != nil { return nil, fmt.Errorf("failed to open layer files cache: %w", err) } diff --git a/packages/orchestrator/main.go b/packages/orchestrator/main.go index 0490675094..6feb265927 100644 --- a/packages/orchestrator/main.go +++ b/packages/orchestrator/main.go @@ -283,6 +283,16 @@ func run(config cfg.Config) (success bool) { } closers = append(closers, closer{"feature flags", featureFlags.Close}) + // 
Log compression-related feature flags for diagnostics. + chunkerCfg := featureFlags.JSONFlag(ctx, featureflags.ChunkerConfigFlag) + compressCfg := featureFlags.JSONFlag(ctx, featureflags.CompressConfigFlag) + globalLogger.Info(ctx, "Feature flags", + zap.String("chunker-config", chunkerCfg.JSONString()), + zap.String("compress-config", compressCfg.JSONString()), + ) + + storage.InitDecoders(ctx, featureFlags) + if config.DomainName != "" { featureFlags.SetDeploymentName(config.DomainName) } diff --git a/packages/shared/go.mod b/packages/shared/go.mod index 601f719fd9..ada0090d35 100644 --- a/packages/shared/go.mod +++ b/packages/shared/go.mod @@ -30,11 +30,13 @@ require ( github.com/hashicorp/go-retryablehttp v0.7.7 github.com/hashicorp/nomad/api v0.0.0-20251216171439-1dee0671280e github.com/jellydator/ttlcache/v3 v3.4.0 + github.com/klauspost/compress v1.18.2 github.com/launchdarkly/go-sdk-common/v3 v3.3.0 github.com/launchdarkly/go-server-sdk/v7 v7.13.0 github.com/ngrok/firewall_toolkit v0.0.18 github.com/oapi-codegen/runtime v1.1.1 github.com/orcaman/concurrent-map/v2 v2.0.1 + github.com/pierrec/lz4/v4 v4.1.22 github.com/redis/go-redis/extra/redisotel/v9 v9.17.3 github.com/redis/go-redis/v9 v9.17.3 github.com/stretchr/testify v1.11.1 @@ -228,7 +230,6 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/julienschmidt/httprouter v1.3.0 // indirect github.com/kamstrup/intmap v0.5.1 // indirect - github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/cpuid/v2 v2.2.11 // indirect github.com/knadh/koanf/maps v0.1.2 // indirect github.com/knadh/koanf/providers/confmap v1.0.0 // indirect @@ -280,7 +281,6 @@ require ( github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect - github.com/pierrec/lz4/v4 v4.1.22 // indirect github.com/pires/go-proxyproto v0.7.0 // indirect github.com/pkg/browser 
v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pkg/errors v0.9.1 // indirect diff --git a/packages/shared/pkg/feature-flags/flags.go b/packages/shared/pkg/feature-flags/flags.go index 6ed0d4add7..ca3819786b 100644 --- a/packages/shared/pkg/feature-flags/flags.go +++ b/packages/shared/pkg/feature-flags/flags.go @@ -245,13 +245,48 @@ func GetTrackedTemplatesSet(ctx context.Context, ff *Client) map[string]struct{} // ChunkerConfigFlag is a JSON flag controlling the chunker implementation and tuning. // -// NOTE: Changing useStreaming has no effect on chunkers already created for -// cached templates. A service restart (redeploy) is required for that change -// to take effect. minReadBatchSizeKB is checked just-in-time on each fetch, -// so it takes effect immediately. +// Fields: +// - useCompressedAssets (bool): Try loading v4 compressed headers and use +// the compressed read path. Restart required — no effect on already-cached templates. +// - minReadBatchSizeKB (int): Floor for uncompressed read batch size in KB. +// Applied at chunker creation time; restart required for existing chunkers. // -// JSON format: {"useStreaming": false, "minReadBatchSizeKB": 16} +// JSON format: {"useCompressedAssets": false, "minReadBatchSizeKB": 16} var ChunkerConfigFlag = newJSONFlag("chunker-config", ldvalue.FromJSONMarshal(map[string]any{ - "useStreaming": false, - "minReadBatchSizeKB": 16, + "useCompressedAssets": false, + "minReadBatchSizeKB": 16, +})) + +// CompressConfigFlag is a JSON flag controlling compression behaviour. +// +// Fields: +// - compressBuilds (bool): Enable compressed (dual-write) uploads during +// template builds. Default false. +// - compressionType (string): "lz4" or "zstd". Default "zstd". +// - level (int): Compression level. For LZ4 0=fast, higher=better ratio. Default 2. +// - frameTargetMB (int): Target compressed frame size in MiB. Default 2. +// - frameMaxUncompressedMB (int): Cap on uncompressed bytes per frame in MiB.
+// Default 16 (= 4 × MemoryChunkSize). +// - uploadPartTargetMB (int): Target upload part size in MiB. Default 50. +// - encoderConcurrency (int): Goroutines per zstd encoder. Default 1. +// - decoderConcurrency (int): Goroutines per pooled zstd decoder. Default 1. +// +// JSON format: {"compressBuilds": false, "compressionType": "zstd", "level": 2, ...} +// OverrideJSONFlag updates a JSON flag value in the offline store. +// The change is visible immediately to all clients created from the offline store. +// Intended for benchmarks and tests. +func OverrideJSONFlag(flag JSONFlag, value ldvalue.Value) { + builder := launchDarklyOfflineStore.Flag(flag.Key()).ValueForAll(value) + launchDarklyOfflineStore.Update(builder) +} + +var CompressConfigFlag = newJSONFlag("compress-config", ldvalue.FromJSONMarshal(map[string]any{ + "compressBuilds": false, + "compressionType": "zstd", + "level": 2, + "frameTargetMB": 2, + "uploadPartTargetMB": 50, + "frameMaxUncompressedMB": 16, + "encoderConcurrency": 1, + "decoderConcurrency": 1, })) diff --git a/packages/shared/pkg/storage/compressed_upload.go b/packages/shared/pkg/storage/compressed_upload.go new file mode 100644 index 0000000000..f70be0c3ec --- /dev/null +++ b/packages/shared/pkg/storage/compressed_upload.go @@ -0,0 +1,515 @@ +package storage + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "slices" + "sync" + + "github.com/klauspost/compress/zstd" + lz4 "github.com/pierrec/lz4/v4" + "golang.org/x/sync/errgroup" + + featureflags "github.com/e2b-dev/infra/packages/shared/pkg/feature-flags" +) + +const ( + defaultTargetFrameSizeC = 2 * megabyte // target compressed frame size + defaultLZ4CompressionLevel = 3 // lz4 compression level (0=fast, higher=better ratio) + defaultCompressionConcurrency = 0 // use default compression concurrency settings + defaultUploadPartSize = 50 * megabyte + + // DefaultMaxFrameUncompressedSize caps the uncompressed bytes in a single frame.
+ // When a frame's uncompressed size reaches this limit it is flushed regardless + // of the compressed size. 4× MemoryChunkSize = 16 MiB. + DefaultMaxFrameUncompressedSize = 4 * MemoryChunkSize + + // FrameAlignmentSize is the read granularity for compression input. + // Frames are composed of whole chunks of this size, guaranteeing that + // no request served by the chunker (UFFD, NBD, prefetch) ever crosses + // a frame boundary. + // + // This MUST be >= every block/page size the system uses: + // - MemoryChunkSize (4 MiB) — uncompressed fetch unit + // - header.HugepageSize (2 MiB) — UFFD huge-page size + // - header.RootfsBlockSize (4 KiB) — NBD / rootfs block size + // + // Do NOT increase this without also ensuring all compressed frame + // sizes remain exact multiples. Changing it is not free. + FrameAlignmentSize = 1 * MemoryChunkSize +) + +// PartUploader is the interface for uploading data in parts. +// Implementations exist for GCS multipart uploads and local file writes. +type PartUploader interface { + Start(ctx context.Context) error + UploadPart(ctx context.Context, partIndex int, data ...[]byte) error + Complete(ctx context.Context) error +} + +// FramedUploadOptions configures compression for framed uploads. +// Input is read in FrameAlignmentSize chunks; frames are always composed +// of whole chunks so no chunker request ever crosses a frame boundary. +type FramedUploadOptions struct { + CompressionType CompressionType + Level int + CompressionConcurrency int + TargetFrameSize int // frames may be bigger than this due to chunk alignment and async compression. + TargetPartSize int + + // MaxUncompressedFrameSize caps uncompressed bytes per frame. + // 0 = use DefaultMaxFrameUncompressedSize. + MaxUncompressedFrameSize int + + OnFrameReady func(offset FrameOffset, size FrameSize, data []byte) error +} + +// DefaultCompressionOptions is the default compression configuration (LZ4). 
+var DefaultCompressionOptions = &FramedUploadOptions{ + CompressionType: CompressionLZ4, + TargetFrameSize: defaultTargetFrameSizeC, + Level: defaultLZ4CompressionLevel, + CompressionConcurrency: defaultCompressionConcurrency, + TargetPartSize: defaultUploadPartSize, + MaxUncompressedFrameSize: DefaultMaxFrameUncompressedSize, +} + +// NoCompression indicates no compression should be applied. +var NoCompression = (*FramedUploadOptions)(nil) + +// GetUploadOptions reads the compress-config feature flag and returns +// FramedUploadOptions. Returns nil when compression is disabled. +func GetUploadOptions(ctx context.Context, ff *featureflags.Client) *FramedUploadOptions { + v := ff.JSONFlag(ctx, featureflags.CompressConfigFlag).AsValueMap() + + if !v.Get("compressBuilds").BoolValue() { + return nil + } + + intOr := func(key string, fallback int) int { + if n := v.Get(key).IntValue(); n != 0 { + return n + } + + return fallback + } + strOr := func(key, fallback string) string { + if s := v.Get(key).StringValue(); s != "" { + return s + } + + return fallback + } + + ct := parseCompressionType(strOr("compressionType", "lz4")) + if ct == CompressionNone { + return nil + } + + return &FramedUploadOptions{ + CompressionType: ct, + Level: intOr("level", 3), + TargetFrameSize: intOr("frameTargetMB", 2) * megabyte, + TargetPartSize: intOr("uploadPartTargetMB", 50) * megabyte, + MaxUncompressedFrameSize: intOr("frameMaxUncompressedMB", 16) * megabyte, + CompressionConcurrency: intOr("encoderConcurrency", 1), + } +} + +// InitDecoders reads the compress-config feature flag and sets the pooled +// zstd decoder concurrency. Call once at startup before any reads. +func InitDecoders(ctx context.Context, ff *featureflags.Client) { + v := ff.JSONFlag(ctx, featureflags.CompressConfigFlag).AsValueMap() + n := max(v.Get("decoderConcurrency").IntValue(), 1) + SetDecoderConcurrency(n) +} + +// ValidateCompressionOptions checks that compression options are valid. 
+func ValidateCompressionOptions(opts *FramedUploadOptions) error { + if opts == nil || opts.CompressionType == CompressionNone { + return nil + } + + return nil +} + +// CompressBytes compresses data using opts and returns the concatenated +// compressed bytes along with the FrameTable. This is a convenience wrapper +// around CompressStream that collects all parts in memory. +func CompressBytes(ctx context.Context, data []byte, opts *FramedUploadOptions) ([]byte, *FrameTable, error) { + up := &memPartUploader{} + + ft, err := CompressStream(ctx, bytes.NewReader(data), opts, up) + if err != nil { + return nil, nil, err + } + + return up.assemble(), ft, nil +} + +// memPartUploader collects compressed parts in memory. +type memPartUploader struct { + parts map[int][]byte +} + +func (m *memPartUploader) Start(context.Context) error { + m.parts = make(map[int][]byte) + + return nil +} + +func (m *memPartUploader) UploadPart(_ context.Context, partIndex int, data ...[]byte) error { + var buf bytes.Buffer + for _, d := range data { + buf.Write(d) + } + m.parts[partIndex] = buf.Bytes() + + return nil +} + +func (m *memPartUploader) Complete(context.Context) error { return nil } + +func (m *memPartUploader) assemble() []byte { + keys := make([]int, 0, len(m.parts)) + for k := range m.parts { + keys = append(keys, k) + } + slices.Sort(keys) + + var buf bytes.Buffer + for _, k := range keys { + buf.Write(m.parts[k]) + } + + return buf.Bytes() +} + +// CompressStream reads from in, compresses using opts, and writes parts through uploader. +// Returns the resulting FrameTable describing the compressed frames. 
+func CompressStream(ctx context.Context, in io.Reader, opts *FramedUploadOptions, uploader PartUploader) (*FrameTable, error) { + targetPartSize := int64(opts.TargetPartSize) + if targetPartSize == 0 { + targetPartSize = int64(defaultUploadPartSize) + } + enc := newFrameEncoder(opts, uploader, targetPartSize, 4) + + return enc.uploadFramed(ctx, in) +} + +type encoder struct { + opts *FramedUploadOptions + maxUploadConcurrency int + + // frame rotation is protected by mutex + mu sync.Mutex + frame *frame + frameTable *FrameTable + readyFrames [][]byte + offset FrameOffset // tracks cumulative offset for OnFrameReady callback + + // Upload-specific data + targetPartSize int64 + partIndex int + partLen int64 + uploader PartUploader +} + +type frame struct { + e *encoder + enc io.WriteCloser + compressedBuffer *bytes.Buffer + flushing bool + + // lenU is updated by the Copy goroutine when it writes uncompressed data + // into the _current_ frame; can be read without locking after the frame + // starts closing since the incoming data is going to a new frame. + lenU int + + // lenC is updated in the Write() method as compressed data is written into + // the compressedBuffer. It can be read without locking after the frame's + // encoder is flushed (closed). 
+ lenC int +} + +var _ io.Writer = (*frame)(nil) // for compression output + +func newFrameEncoder(opts *FramedUploadOptions, u PartUploader, targetPartSize int64, maxUploadConcurrency int) *encoder { + return &encoder{ + opts: opts, + maxUploadConcurrency: maxUploadConcurrency, + targetPartSize: targetPartSize, + readyFrames: make([][]byte, 0, 8), + uploader: u, + frameTable: &FrameTable{ + CompressionType: opts.CompressionType, + }, + } +} + +func (e *encoder) uploadFramed(ctx context.Context, in io.Reader) (*FrameTable, error) { + // Set up the uploader + uploadEG, uploadCtx := errgroup.WithContext(ctx) + if e.maxUploadConcurrency > 0 { + uploadEG.SetLimit(e.maxUploadConcurrency) + } + + err := e.uploader.Start(ctx) + if err != nil { + return nil, fmt.Errorf("failed to start framed upload: %w", err) + } + + // Start copying file to the compression encoder. Use a return channel + // instead of errgroup to be able to detect completion in the event loop. + // Buffer 8 chunks to allow read-ahead and better pipelining. 
+ chunkCh := make(chan []byte, 8) + readErrorCh := make(chan error, 1) + go e.readFile(ctx, in, FrameAlignmentSize, chunkCh, readErrorCh) + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + + case err = <-readErrorCh: + return nil, err + + case chunk, haveData := <-chunkCh: + // See if we need to flush and to start a new frame + e.mu.Lock() + var flush *frame + if haveData { + if e.frame == nil || e.frame.flushing { + // Start a new frame and flush the current one + flush = e.frame + if e.frame, err = e.startFrame(); err != nil { + e.mu.Unlock() + + return nil, fmt.Errorf("failed to start frame: %w", err) + } + } + } else { + // No more data; flush current frame + flush = e.frame + } + frame := e.frame + e.mu.Unlock() + + if flush != nil { + if err = e.flushFrame(uploadEG, uploadCtx, flush, !haveData); err != nil { + return nil, fmt.Errorf("failed to flush frame: %w", err) + } + } + + // If we have data, write it to the current frame and continue + if haveData { + if err = e.writeChunk(frame, chunk); err != nil { + return nil, fmt.Errorf("failed to encode to frame: %w", err) + } + + continue + } + + // No more data to process; wait for the uploads to complete and done! 
+ if err = uploadEG.Wait(); err != nil { + return nil, fmt.Errorf("failed to upload frames: %w", err) + } + + if e.uploader != nil { + if err = e.uploader.Complete(ctx); err != nil { + return nil, fmt.Errorf("failed to finish uploading frames: %w", err) + } + } + + return e.frameTable, nil + } + } +} + +func (e *encoder) flushFrame(eg *errgroup.Group, uploadCtx context.Context, f *frame, last bool) error { + if err := f.enc.Close(); err != nil { + return fmt.Errorf("failed to close encoder: %w", err) + } + + ft := FrameSize{ + U: int32(f.lenU), + C: int32(f.lenC), + } + + e.frameTable.Frames = append(e.frameTable.Frames, ft) + + data := f.compressedBuffer.Bytes() + + // Notify callback if provided (e.g., for cache write-through) + if e.opts.OnFrameReady != nil { + if err := e.opts.OnFrameReady(e.offset, ft, data); err != nil { + return fmt.Errorf("OnFrameReady callback failed: %w", err) + } + } + + // Advance offset for next frame + e.offset.Add(ft) + + e.partLen += int64(len(data)) + e.readyFrames = append(e.readyFrames, data) + + if e.partLen >= e.targetPartSize || last { + e.partIndex++ + + i := e.partIndex + frameData := append([][]byte{}, e.readyFrames...) + e.partLen = 0 + e.readyFrames = e.readyFrames[:0] + + eg.Go(func() error { + err := e.uploader.UploadPart(uploadCtx, i, frameData...) + if err != nil { + return fmt.Errorf("failed to upload part %d: %w", i, err) + } + + return nil + }) + } + + return nil +} + +func (e *encoder) readFile(ctx context.Context, in io.Reader, chunkSize int, chunkCh chan<- []byte, errorCh chan<- error) { + for i := 0; ; i++ { + chunk := make([]byte, chunkSize) + n, err := io.ReadFull(in, chunk) + + if err == nil { + if ctxErr := ctx.Err(); ctxErr != nil { + errorCh <- ctxErr + + return + } + chunkCh <- chunk[:n] + + continue + } + + // ErrUnexpectedEOF means a partial read (last chunk shorter than chunkSize). 
+ if errors.Is(err, io.ErrUnexpectedEOF) { + if n > 0 { + chunkCh <- chunk[:n] + } + close(chunkCh) + + return + } + // EOF means no bytes were read at all. + if errors.Is(err, io.EOF) { + close(chunkCh) + + return + } + + errorCh <- fmt.Errorf("failed to read file chunk %d: %w", i, err) + + return + } +} + +func (e *encoder) startFrame() (*frame, error) { + var enc io.WriteCloser + var err error + frame := &frame{ + e: e, + compressedBuffer: bytes.NewBuffer(make([]byte, 0, e.opts.TargetFrameSize+e.opts.TargetFrameSize/2)), // pre-allocate buffer to avoid resizes during compression + } + switch e.opts.CompressionType { + case CompressionZstd: + enc, err = newZstdEncoder(frame, e.opts.CompressionConcurrency, e.opts.TargetFrameSize, zstd.EncoderLevel(e.opts.Level)) + case CompressionLZ4: + enc = newLZ4Encoder(frame, e.opts.Level) + default: + return nil, fmt.Errorf("unsupported compression type: %v", e.opts.CompressionType) + } + if err != nil { + return nil, fmt.Errorf("failed to create encoder: %w", err) + } + frame.enc = enc + + return frame, nil +} + +// writeChunk writes uncompressed data chunk into the frame. len(data) is expected to be <= FrameAlignmentSize. +func (e *encoder) writeChunk(frame *frame, data []byte) error { + for len(data) > 0 { + // Write out data that fits the current chunk + written, err := frame.enc.Write(data) + if err != nil { + return err + } + frame.lenU += written + data = data[written:] + } + + // Enforce uncompressed frame size cap. + maxU := e.opts.MaxUncompressedFrameSize + if maxU == 0 { + maxU = DefaultMaxFrameUncompressedSize + } + if frame.lenU >= maxU { + e.mu.Lock() + frame.flushing = true + e.mu.Unlock() + } + + return nil +} + +// Write implements io.Writer to be used as the output of the compression encoder. 
+func (frame *frame) Write(p []byte) (n int, err error) { + e := frame.e + n, err = frame.compressedBuffer.Write(p) + frame.lenC += n + + e.mu.Lock() + if frame.lenC < e.opts.TargetFrameSize || frame.flushing { + e.mu.Unlock() + + return n, err + } + frame.flushing = true + e.mu.Unlock() + + return n, err +} + +func newZstdEncoder(out io.Writer, concurrency int, windowSize int, compressionLevel zstd.EncoderLevel) (*zstd.Encoder, error) { + switch { + case concurrency > 0 && windowSize > 0: + return zstd.NewWriter(out, + zstd.WithEncoderConcurrency(concurrency), + zstd.WithWindowSize(windowSize), + zstd.WithEncoderLevel(compressionLevel)) + case concurrency > 0: + return zstd.NewWriter(out, + zstd.WithEncoderConcurrency(concurrency), + zstd.WithEncoderLevel(compressionLevel)) + case windowSize > 0: + return zstd.NewWriter(out, + zstd.WithWindowSize(windowSize), + zstd.WithEncoderLevel(compressionLevel)) + default: + return zstd.NewWriter(out, + zstd.WithEncoderLevel(compressionLevel)) + } +} + +func newLZ4Encoder(out io.Writer, level int) io.WriteCloser { + w := lz4.NewWriter(out) + opts := []lz4.Option{lz4.ConcurrencyOption(1)} + if level > 0 { + opts = append(opts, lz4.CompressionLevelOption(lz4.CompressionLevel(1<<(8+level)))) + } + _ = w.Apply(opts...) + + return w +} diff --git a/packages/shared/pkg/storage/decoders.go b/packages/shared/pkg/storage/decoders.go new file mode 100644 index 0000000000..4e12358290 --- /dev/null +++ b/packages/shared/pkg/storage/decoders.go @@ -0,0 +1,76 @@ +package storage + +import ( + "io" + "sync" + "sync/atomic" + + "github.com/klauspost/compress/zstd" + lz4 "github.com/pierrec/lz4/v4" +) + +var decoderConcurrency atomic.Int32 + +func init() { + decoderConcurrency.Store(1) +} + +// SetDecoderConcurrency sets the number of concurrent goroutines used by +// pooled zstd decoders. Call from orchestrator startup before any reads. 
+func SetDecoderConcurrency(n int) { + if n < 1 { + n = 1 + } + decoderConcurrency.Store(int32(n)) +} + +// --- zstd pool --- + +var zstdPool sync.Pool + +func getZstdDecoder(r io.Reader) (*zstd.Decoder, error) { + if v := zstdPool.Get(); v != nil { + dec := v.(*zstd.Decoder) + if err := dec.Reset(r); err != nil { + dec.Close() + + return nil, err + } + + return dec, nil + } + + dec, err := zstd.NewReader(r, + zstd.WithDecoderConcurrency(int(decoderConcurrency.Load())), + ) + if err != nil { + return nil, err + } + + return dec, nil +} + +func putZstdDecoder(dec *zstd.Decoder) { + dec.Reset(nil) + zstdPool.Put(dec) +} + +// --- lz4 pool --- + +var lz4Pool sync.Pool + +func getLZ4Reader(r io.Reader) *lz4.Reader { + if v := lz4Pool.Get(); v != nil { + rd := v.(*lz4.Reader) + rd.Reset(r) + + return rd + } + + return lz4.NewReader(r) +} + +func putLZ4Reader(rd *lz4.Reader) { + rd.Reset(nil) + lz4Pool.Put(rd) +} diff --git a/packages/shared/pkg/storage/frame_table.go b/packages/shared/pkg/storage/frame_table.go new file mode 100644 index 0000000000..43b85cd777 --- /dev/null +++ b/packages/shared/pkg/storage/frame_table.go @@ -0,0 +1,259 @@ +package storage + +import ( + "bytes" + "fmt" + "io" +) + +type CompressionType byte + +const ( + CompressionNone = CompressionType(iota) + CompressionZstd + CompressionLZ4 +) + +func (ct CompressionType) Suffix() string { + switch ct { + case CompressionZstd: + return ".zstd" + case CompressionLZ4: + return ".lz4" + default: + return "" + } +} + +func (ct CompressionType) String() string { + switch ct { + case CompressionZstd: + return "zstd" + case CompressionLZ4: + return "lz4" + default: + return "none" + } +} + +// parseCompressionType converts a string to CompressionType. +// Returns CompressionNone for unrecognised values. 
+func parseCompressionType(s string) CompressionType { + switch s { + case "lz4": + return CompressionLZ4 + case "zstd": + return CompressionZstd + default: + return CompressionNone + } +} + +type FrameOffset struct { + U int64 + C int64 +} + +func (o *FrameOffset) String() string { + return fmt.Sprintf("U:%#x/C:%#x", o.U, o.C) +} + +func (o *FrameOffset) Add(f FrameSize) { + o.U += int64(f.U) + o.C += int64(f.C) +} + +type FrameSize struct { + U int32 + C int32 +} + +func (s FrameSize) String() string { + return fmt.Sprintf("U:%#x/C:%#x", s.U, s.C) +} + +type Range struct { + Start int64 + Length int +} + +func (r Range) String() string { + return fmt.Sprintf("%#x/%#x", r.Start, r.Length) +} + +type FrameTable struct { + CompressionType CompressionType + StartAt FrameOffset + Frames []FrameSize +} + +// CompressionTypeSuffix returns ".lz4", ".zstd", or "" (nil-safe). +func (ft *FrameTable) CompressionTypeSuffix() string { + if ft == nil { + return "" + } + + return ft.CompressionType.Suffix() +} + +// IsCompressed reports whether ft is non-nil and has a compression type set. +func IsCompressed(ft *FrameTable) bool { + return ft != nil && ft.CompressionType != CompressionNone +} + +// Range calls fn for each frame overlapping [start, start+length). 
+func (ft *FrameTable) Range(start, length int64, fn func(offset FrameOffset, frame FrameSize) error) error { + currentOffset := ft.StartAt + for _, frame := range ft.Frames { + frameEnd := currentOffset.U + int64(frame.U) + requestEnd := start + length + if frameEnd <= start { + currentOffset.U += int64(frame.U) + currentOffset.C += int64(frame.C) + + continue + } + if currentOffset.U >= requestEnd { + break + } + + if err := fn(currentOffset, frame); err != nil { + return err + } + currentOffset.U += int64(frame.U) + currentOffset.C += int64(frame.C) + } + + return nil +} + +func (ft *FrameTable) Size() (uncompressed, compressed int64) { + for _, frame := range ft.Frames { + uncompressed += int64(frame.U) + compressed += int64(frame.C) + } + + return uncompressed, compressed +} + +// Subset returns frames covering r. Whole frames only (can't split compressed). +// Stops silently at the end of the frameset if r extends beyond. +func (ft *FrameTable) Subset(r Range) (*FrameTable, error) { + if ft == nil || r.Length == 0 { + return nil, nil + } + if r.Start < ft.StartAt.U { + return nil, fmt.Errorf("requested range starts before the beginning of the frame table") + } + newFrameTable := &FrameTable{ + CompressionType: ft.CompressionType, + } + + startSet := false + currentOffset := ft.StartAt + requestedEnd := r.Start + int64(r.Length) + for _, frame := range ft.Frames { + frameEnd := currentOffset.U + int64(frame.U) + if frameEnd <= r.Start { + currentOffset.Add(frame) + + continue + } + if currentOffset.U >= requestedEnd { + break + } + + if !startSet { + newFrameTable.StartAt = currentOffset + startSet = true + } + newFrameTable.Frames = append(newFrameTable.Frames, frame) + currentOffset.Add(frame) + } + + if !startSet { + return nil, fmt.Errorf("requested range is beyond the end of the frame table") + } + + return newFrameTable, nil +} + +// FrameFor finds the frame containing the given offset and returns its start position and full size. 
+func (ft *FrameTable) FrameFor(offset int64) (starts FrameOffset, size FrameSize, err error) { + if ft == nil { + return FrameOffset{}, FrameSize{}, fmt.Errorf("FrameFor called with nil frame table - data is not compressed") + } + + currentOffset := ft.StartAt + for _, frame := range ft.Frames { + frameEnd := currentOffset.U + int64(frame.U) + if offset >= currentOffset.U && offset < frameEnd { + return currentOffset, frame, nil + } + currentOffset.Add(frame) + } + + return FrameOffset{}, FrameSize{}, fmt.Errorf("offset %#x is beyond the end of the frame table", offset) +} + +// GetFetchRange translates a U-space range to C-space using the frame table. +func (ft *FrameTable) GetFetchRange(rangeU Range) (Range, error) { + fetchRange := rangeU + if ft != nil && ft.CompressionType != CompressionNone { + start, size, err := ft.FrameFor(rangeU.Start) + if err != nil { + return Range{}, fmt.Errorf("getting frame for offset %#x: %w", rangeU.Start, err) + } + endOffset := rangeU.Start + int64(rangeU.Length) + frameEnd := start.U + int64(size.U) + if endOffset > frameEnd { + return Range{}, fmt.Errorf("range %v spans beyond frame ending at %#x", rangeU, frameEnd) + } + fetchRange = Range{ + Start: start.C, + Length: int(size.C), + } + } + + return fetchRange, nil +} + +// DecompressReader decompresses from r into a new buffer of uncompressedSize. 
+func DecompressReader(ct CompressionType, r io.Reader, uncompressedSize int) ([]byte, error) { + buf := make([]byte, uncompressedSize) + + switch ct { + case CompressionZstd: + dec, err := getZstdDecoder(r) + if err != nil { + return nil, fmt.Errorf("failed to create zstd reader: %w", err) + } + defer putZstdDecoder(dec) + + n, err := io.ReadFull(dec, buf) + if err != nil { + return nil, fmt.Errorf("zstd decompress: %w", err) + } + + return buf[:n], nil + + case CompressionLZ4: + rd := getLZ4Reader(r) + defer putLZ4Reader(rd) + + n, err := io.ReadFull(rd, buf) + if err != nil { + return nil, fmt.Errorf("lz4 decompress: %w", err) + } + + return buf[:n], nil + + default: + return nil, fmt.Errorf("unsupported compression type: %d", ct) + } +} + +// DecompressFrame decompresses an in-memory compressed byte slice. +func DecompressFrame(ct CompressionType, compressed []byte, uncompressedSize int32) ([]byte, error) { + return DecompressReader(ct, bytes.NewReader(compressed), int(uncompressedSize)) +} diff --git a/packages/shared/pkg/storage/frame_table_test.go b/packages/shared/pkg/storage/frame_table_test.go new file mode 100644 index 0000000000..89c5128535 --- /dev/null +++ b/packages/shared/pkg/storage/frame_table_test.go @@ -0,0 +1,261 @@ +package storage + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// threeFrameFT returns a FrameTable with three 1MB uncompressed frames +// and varying compressed sizes, starting at the given offset. +func threeFrameFT(startU, startC int64) *FrameTable { + return &FrameTable{ + CompressionType: CompressionLZ4, + StartAt: FrameOffset{U: startU, C: startC}, + Frames: []FrameSize{ + {U: 1 << 20, C: 500_000}, // frame 0 + {U: 1 << 20, C: 600_000}, // frame 1 + {U: 1 << 20, C: 400_000}, // frame 2 + }, + } +} + +// collectRange calls ft.Range and returns the offsets visited. 
+func collectRange(ft *FrameTable, start, length int64) ([]FrameOffset, error) { + var offsets []FrameOffset + err := ft.Range(start, length, func(offset FrameOffset, _ FrameSize) error { + offsets = append(offsets, offset) + + return nil + }) + + return offsets, err +} + +func TestRange(t *testing.T) { + t.Parallel() + ft := threeFrameFT(0, 0) + + t.Run("selects all frames", func(t *testing.T) { + t.Parallel() + offsets, err := collectRange(ft, 0, 3<<20) + require.NoError(t, err) + assert.Len(t, offsets, 3) + }) + + t.Run("selects single middle frame", func(t *testing.T) { + t.Parallel() + offsets, err := collectRange(ft, 1<<20, 1<<20) + require.NoError(t, err) + require.Len(t, offsets, 1) + assert.Equal(t, int64(1<<20), offsets[0].U) + assert.Equal(t, int64(500_000), offsets[0].C) + }) + + t.Run("partial overlap selects touched frames", func(t *testing.T) { + t.Parallel() + // 1 byte spanning frames 0 and 1 boundary. + offsets, err := collectRange(ft, (1<<20)-1, 2) + require.NoError(t, err) + assert.Len(t, offsets, 2) + }) + + t.Run("beyond end returns nothing", func(t *testing.T) { + t.Parallel() + offsets, err := collectRange(ft, 3<<20, 1) + require.NoError(t, err) + assert.Empty(t, offsets) + }) + + t.Run("callback error propagates", func(t *testing.T) { + t.Parallel() + sentinel := fmt.Errorf("stop") + err := ft.Range(0, 3<<20, func(_ FrameOffset, _ FrameSize) error { + return sentinel + }) + assert.ErrorIs(t, err, sentinel) + }) + + t.Run("respects StartAt on subset", func(t *testing.T) { + t.Parallel() + sub, err := ft.Subset(Range{Start: 1 << 20, Length: 2 << 20}) + require.NoError(t, err) + + // Query for offset 2MB — the second frame of the subset. + offsets, err := collectRange(sub, 2<<20, 1<<20) + require.NoError(t, err) + require.Len(t, offsets, 1) + assert.Equal(t, int64(2<<20), offsets[0].U) + assert.Equal(t, int64(1_100_000), offsets[0].C) // 500k + 600k + + // Query for offset 0 — before the subset, should find nothing. 
+ offsets, err = collectRange(sub, 0, 1<<20) + require.NoError(t, err) + assert.Empty(t, offsets, "Range should not find frames before StartAt") + }) +} + +func TestSubset(t *testing.T) { + t.Parallel() + ft := threeFrameFT(0, 0) + + t.Run("full range", func(t *testing.T) { + t.Parallel() + sub, err := ft.Subset(Range{Start: 0, Length: 3 << 20}) + require.NoError(t, err) + assert.Len(t, sub.Frames, 3) + assert.Equal(t, int64(0), sub.StartAt.U) + }) + + t.Run("last frame", func(t *testing.T) { + t.Parallel() + sub, err := ft.Subset(Range{Start: 2 << 20, Length: 1 << 20}) + require.NoError(t, err) + require.Len(t, sub.Frames, 1) + assert.Equal(t, int64(2<<20), sub.StartAt.U) + assert.Equal(t, int64(1_100_000), sub.StartAt.C) + assert.Equal(t, int32(400_000), sub.Frames[0].C) + }) + + t.Run("preserves compression type", func(t *testing.T) { + t.Parallel() + sub, err := ft.Subset(Range{Start: 0, Length: 1 << 20}) + require.NoError(t, err) + assert.Equal(t, CompressionLZ4, sub.CompressionType) + }) + + t.Run("nil table returns nil", func(t *testing.T) { + t.Parallel() + sub, err := (*FrameTable)(nil).Subset(Range{Start: 0, Length: 100}) + require.NoError(t, err) + assert.Nil(t, sub) + }) + + t.Run("zero length returns nil", func(t *testing.T) { + t.Parallel() + sub, err := ft.Subset(Range{Start: 0, Length: 0}) + require.NoError(t, err) + assert.Nil(t, sub) + }) + + t.Run("before StartAt errors", func(t *testing.T) { + t.Parallel() + sub := threeFrameFT(1<<20, 500_000) + _, err := sub.Subset(Range{Start: 0, Length: 1 << 20}) + assert.Error(t, err) + }) + + t.Run("beyond end errors", func(t *testing.T) { + t.Parallel() + _, err := ft.Subset(Range{Start: 4 << 20, Length: 1 << 20}) + assert.Error(t, err) + }) +} + +func TestFrameFor(t *testing.T) { + t.Parallel() + ft := threeFrameFT(0, 0) + + t.Run("first byte of each frame", func(t *testing.T) { + t.Parallel() + for i, wantU := range []int64{0, 1 << 20, 2 << 20} { + start, size, err := ft.FrameFor(wantU) + 
require.NoError(t, err, "frame %d", i) + assert.Equal(t, wantU, start.U) + assert.Equal(t, int32(1<<20), size.U) + } + }) + + t.Run("last byte of frame", func(t *testing.T) { + t.Parallel() + start, _, err := ft.FrameFor((1 << 20) - 1) + require.NoError(t, err) + assert.Equal(t, int64(0), start.U) + }) + + t.Run("returns correct C offset", func(t *testing.T) { + t.Parallel() + start, _, err := ft.FrameFor(2 << 20) + require.NoError(t, err) + assert.Equal(t, int64(1_100_000), start.C) // 500k + 600k + }) + + t.Run("beyond end errors", func(t *testing.T) { + t.Parallel() + _, _, err := ft.FrameFor(3 << 20) + assert.Error(t, err) + }) + + t.Run("nil table errors", func(t *testing.T) { + t.Parallel() + _, _, err := (*FrameTable)(nil).FrameFor(0) + assert.Error(t, err) + }) + + t.Run("respects StartAt", func(t *testing.T) { + t.Parallel() + sub := threeFrameFT(1<<20, 500_000) + start, _, err := sub.FrameFor(1 << 20) + require.NoError(t, err) + assert.Equal(t, int64(1<<20), start.U) + assert.Equal(t, int64(500_000), start.C) + + // Before StartAt — no frame should contain offset 0. 
+ _, _, err = sub.FrameFor(0) + assert.Error(t, err) + }) +} + +func TestGetFetchRange(t *testing.T) { + t.Parallel() + ft := threeFrameFT(0, 0) + + t.Run("translates U-space to C-space", func(t *testing.T) { + t.Parallel() + r, err := ft.GetFetchRange(Range{Start: 1 << 20, Length: 1 << 20}) + require.NoError(t, err) + assert.Equal(t, int64(500_000), r.Start) + assert.Equal(t, 600_000, r.Length) + }) + + t.Run("range spanning multiple frames errors", func(t *testing.T) { + t.Parallel() + _, err := ft.GetFetchRange(Range{Start: 0, Length: 2 << 20}) + assert.Error(t, err) + }) + + t.Run("nil table returns input unchanged", func(t *testing.T) { + t.Parallel() + input := Range{Start: 42, Length: 100} + r, err := (*FrameTable)(nil).GetFetchRange(input) + require.NoError(t, err) + assert.Equal(t, input, r) + }) + + t.Run("uncompressed table returns input unchanged", func(t *testing.T) { + t.Parallel() + uncompressed := &FrameTable{CompressionType: CompressionNone} + input := Range{Start: 42, Length: 100} + r, err := uncompressed.GetFetchRange(input) + require.NoError(t, err) + assert.Equal(t, input, r) + }) +} + +func TestSize(t *testing.T) { + t.Parallel() + ft := threeFrameFT(0, 0) + u, c := ft.Size() + assert.Equal(t, int64(3<<20), u) + assert.Equal(t, int64(1_500_000), c) +} + +func TestIsCompressed(t *testing.T) { + t.Parallel() + assert.False(t, IsCompressed(nil)) + assert.False(t, IsCompressed(&FrameTable{CompressionType: CompressionNone})) + assert.True(t, IsCompressed(&FrameTable{CompressionType: CompressionLZ4})) + assert.True(t, IsCompressed(&FrameTable{CompressionType: CompressionZstd})) +} diff --git a/packages/shared/pkg/storage/gcp_multipart.go b/packages/shared/pkg/storage/gcp_multipart.go index 75324c16c1..45e8d95a6e 100644 --- a/packages/shared/pkg/storage/gcp_multipart.go +++ b/packages/shared/pkg/storage/gcp_multipart.go @@ -139,6 +139,53 @@ type MultipartUploader struct { client *retryablehttp.Client retryConfig RetryConfig baseURL string // Allow 
overriding for testing + + // Fields for PartUploader interface + uploadID string + mu sync.Mutex + parts []Part +} + +var _ PartUploader = (*MultipartUploader)(nil) + +// Start initiates the GCS multipart upload. +func (m *MultipartUploader) Start(ctx context.Context) error { + uploadID, err := m.initiateUpload(ctx) + if err != nil { + return fmt.Errorf("failed to initiate multipart upload: %w", err) + } + + m.uploadID = uploadID + + return nil +} + +// UploadPart uploads a single part to GCS. Multiple data slices are hashed +// and uploaded without copying into a single contiguous buffer. +func (m *MultipartUploader) UploadPart(ctx context.Context, partIndex int, data ...[]byte) error { + etag, err := m.uploadPartSlices(ctx, m.uploadID, partIndex, data) + if err != nil { + return fmt.Errorf("failed to upload part %d: %w", partIndex, err) + } + + m.mu.Lock() + m.parts = append(m.parts, Part{ + PartNumber: partIndex, + ETag: etag, + }) + m.mu.Unlock() + + return nil +} + +// Complete finalizes the GCS multipart upload with all collected parts. +func (m *MultipartUploader) Complete(ctx context.Context) error { + m.mu.Lock() + parts := make([]Part, len(m.parts)) + copy(parts, m.parts) + m.mu.Unlock() + + return m.completeUpload(ctx, m.uploadID, parts) } func NewMultipartUploaderWithRetryConfig(ctx context.Context, bucketName, objectName string, retryConfig RetryConfig) (*MultipartUploader, error) { @@ -232,6 +279,60 @@ func (m *MultipartUploader) uploadPart(ctx context.Context, uploadID string, par return etag, nil } +// uploadPartSlices uploads a part from multiple byte slices without concatenating them. +// It computes MD5 by hashing each slice and uses a ReaderFunc for retryable reads. 
+func (m *MultipartUploader) uploadPartSlices(ctx context.Context, uploadID string, partNumber int, slices [][]byte) (string, error) { + // Compute MD5 and total length without copying + hasher := md5.New() + totalLen := 0 + for _, s := range slices { + hasher.Write(s) + totalLen += len(s) + } + md5Sum := base64.StdEncoding.EncodeToString(hasher.Sum(nil)) + + url := fmt.Sprintf("%s/%s?partNumber=%d&uploadId=%s", + m.baseURL, m.objectName, partNumber, uploadID) + + // Use a ReaderFunc so the retryable client can replay the body on retries + bodyFn := func() (io.Reader, error) { + readers := make([]io.Reader, len(slices)) + for i, s := range slices { + readers[i] = bytes.NewReader(s) + } + + return io.MultiReader(readers...), nil + } + + req, err := retryablehttp.NewRequestWithContext(ctx, "PUT", url, retryablehttp.ReaderFunc(bodyFn)) + if err != nil { + return "", err + } + + req.Header.Set("Authorization", "Bearer "+m.token) + req.Header.Set("Content-Length", fmt.Sprintf("%d", totalLen)) + req.Header.Set("Content-MD5", md5Sum) + + resp, err := m.client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + + return "", fmt.Errorf("failed to upload part %d (status %d): %s", partNumber, resp.StatusCode, string(body)) + } + + etag := resp.Header.Get("ETag") + if etag == "" { + return "", fmt.Errorf("no ETag returned for part %d", partNumber) + } + + return etag, nil +} + func (m *MultipartUploader) completeUpload(ctx context.Context, uploadID string, parts []Part) error { // Sort parts by part number sort.Slice(parts, func(i, j int) bool { diff --git a/packages/shared/pkg/storage/gcp_multipart_test.go b/packages/shared/pkg/storage/gcp_multipart_test.go index c0daaa6eef..c3a7e748fa 100644 --- a/packages/shared/pkg/storage/gcp_multipart_test.go +++ b/packages/shared/pkg/storage/gcp_multipart_test.go @@ -170,20 +170,18 @@ func 
TestMultipartUploader_UploadFileInParallel_Success(t *testing.T) { err := os.WriteFile(testFile, []byte(testContent), 0o644) require.NoError(t, err) - var uploadID string var initiateCount, uploadPartCount, completeCount int32 - receivedParts := make(map[int]string) + var receivedParts sync.Map handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.URL.RawQuery == uploadsPath: // Initiate upload atomic.AddInt32(&initiateCount, 1) - uploadID = "test-upload-id-123" response := InitiateMultipartUploadResult{ Bucket: testBucketName, Key: testObjectName, - UploadID: uploadID, + UploadID: "test-upload-id-123", } xmlData, _ := xml.Marshal(response) w.Header().Set("Content-Type", "application/xml") @@ -194,7 +192,7 @@ func TestMultipartUploader_UploadFileInParallel_Success(t *testing.T) { // Upload part partNum := atomic.AddInt32(&uploadPartCount, 1) body, _ := io.ReadAll(r.Body) - receivedParts[int(partNum)] = string(body) + receivedParts.Store(int(partNum), string(body)) w.Header().Set("ETag", fmt.Sprintf(`"etag%d"`, partNum)) w.WriteHeader(http.StatusOK) @@ -217,7 +215,9 @@ func TestMultipartUploader_UploadFileInParallel_Success(t *testing.T) { // Verify all parts were uploaded and content matches var reconstructed strings.Builder for i := 1; i <= int(atomic.LoadInt32(&uploadPartCount)); i++ { - reconstructed.WriteString(receivedParts[i]) + part, ok := receivedParts.Load(i) + require.True(t, ok, "missing part %d", i) + reconstructed.WriteString(part.(string)) } require.Equal(t, testContent, reconstructed.String()) } @@ -522,7 +522,7 @@ func TestMultipartUploader_EdgeCases_VerySmallFile(t *testing.T) { err := os.WriteFile(smallFile, []byte(smallContent), 0o644) require.NoError(t, err) - var receivedData string + var receivedParts sync.Map handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { @@ -538,7 +538,8 @@ func TestMultipartUploader_EdgeCases_VerySmallFile(t *testing.T) { case 
strings.Contains(r.URL.RawQuery, "partNumber"): body, _ := io.ReadAll(r.Body) - receivedData = string(body) + partNum := r.URL.Query().Get("partNumber") + receivedParts.Store(partNum, string(body)) w.Header().Set("ETag", `"small-etag"`) w.WriteHeader(http.StatusOK) @@ -551,7 +552,18 @@ func TestMultipartUploader_EdgeCases_VerySmallFile(t *testing.T) { uploader := createTestMultipartUploader(t, handler) _, err = uploader.UploadFileInParallel(t.Context(), smallFile, 10) // High concurrency for small file require.NoError(t, err) - require.Equal(t, smallContent, receivedData) + + // Small file should produce exactly one part + var partCount int + receivedParts.Range(func(_, _ any) bool { + partCount++ + + return true + }) + require.Equal(t, 1, partCount) + data, ok := receivedParts.Load("1") + require.True(t, ok) + require.Equal(t, smallContent, data.(string)) } type repeatReader struct { @@ -654,6 +666,7 @@ func TestMultipartUploader_BoundaryConditions_ExactChunkSize(t *testing.T) { err := os.WriteFile(testFile, []byte(testContent), 0o644) require.NoError(t, err) + var mu sync.Mutex var partSizes []int handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -670,7 +683,9 @@ func TestMultipartUploader_BoundaryConditions_ExactChunkSize(t *testing.T) { case strings.Contains(r.URL.RawQuery, "partNumber"): body, _ := io.ReadAll(r.Body) + mu.Lock() partSizes = append(partSizes, len(body)) + mu.Unlock() partNum := strings.Split(strings.Split(r.URL.RawQuery, "partNumber=")[1], "&")[0] w.Header().Set("ETag", fmt.Sprintf(`"boundary-etag-%s"`, partNum)) @@ -687,8 +702,9 @@ func TestMultipartUploader_BoundaryConditions_ExactChunkSize(t *testing.T) { // Should have exactly 2 parts, each of ChunkSize require.Len(t, partSizes, 2) - require.Equal(t, gcpMultipartUploadChunkSize, partSizes[0]) - require.Equal(t, gcpMultipartUploadChunkSize, partSizes[1]) + for _, size := range partSizes { + require.Equal(t, gcpMultipartUploadChunkSize, size) + } } func 
TestMultipartUploader_FileNotFound_Error(t *testing.T) { diff --git a/packages/shared/pkg/storage/header/header.go b/packages/shared/pkg/storage/header/header.go index 9a1f3008f5..f2e30bce69 100644 --- a/packages/shared/pkg/storage/header/header.go +++ b/packages/shared/pkg/storage/header/header.go @@ -5,10 +5,10 @@ import ( "fmt" "github.com/bits-and-blooms/bitset" - "github.com/google/uuid" "go.uber.org/zap" "github.com/e2b-dev/infra/packages/shared/pkg/logger" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) const NormalizeFixVersion = 3 @@ -47,12 +47,68 @@ func NewHeader(metadata *Metadata, mapping []*BuildMap) (*Header, error) { startMap[block] = mapping } - return &Header{ + h := &Header{ blockStarts: intervals, Metadata: metadata, Mapping: mapping, startMap: startMap, - }, nil + } + + // Validate header integrity at creation time + if err := ValidateHeader(h); err != nil { + return nil, fmt.Errorf("header validation failed: %w", err) + } + + return h, nil +} + +func (t *Header) String() string { + if t == nil { + return "[nil Header]" + } + + return fmt.Sprintf("[Header: version=%d, size=%d, blockSize=%d, generation=%d, buildId=%s, mappings=%d]", + t.Metadata.Version, + t.Metadata.Size, + t.Metadata.BlockSize, + t.Metadata.Generation, + t.Metadata.BuildId.String(), + len(t.Mapping), + ) +} + +func (t *Header) Mappings(all bool) string { + if t == nil { + return "[nil Header, no mappings]" + } + n := 0 + for _, m := range t.Mapping { + if all || m.BuildId == t.Metadata.BuildId { + n++ + } + } + result := fmt.Sprintf("All mappings: %d\n", n) + if !all { + result = fmt.Sprintf("Mappings for build %s: %d\n", t.Metadata.BuildId.String(), n) + } + for _, m := range t.Mapping { + if !all && m.BuildId != t.Metadata.BuildId { + continue + } + frames := 0 + if m.FrameTable != nil { + frames = len(m.FrameTable.Frames) + } + result += fmt.Sprintf(" - Offset: %#x, Length: %#x, BuildId: %s, BuildStorageOffset: %#x, numFrames: %d\n", + m.Offset, + m.Length, + 
m.BuildId.String(), + m.BuildStorageOffset, + frames, + ) + } + + return result } // IsNormalizeFixApplied is a helper method to soft fail for older versions of the header where fix for normalization was not applied. @@ -61,29 +117,34 @@ func (t *Header) IsNormalizeFixApplied() bool { return t.Metadata.Version >= NormalizeFixVersion } -func (t *Header) GetShiftedMapping(ctx context.Context, offset int64) (mappedOffset int64, mappedLength int64, buildID *uuid.UUID, err error) { +func (t *Header) GetShiftedMapping(ctx context.Context, offset int64) (mappedToBuild *BuildMap, err error) { mapping, shift, err := t.getMapping(ctx, offset) if err != nil { - return 0, 0, nil, err + return nil, err } + lengthInBuild := int64(mapping.Length) - shift - mappedOffset = int64(mapping.BuildStorageOffset) + shift - mappedLength = int64(mapping.Length) - shift - buildID = &mapping.BuildId + b := &BuildMap{ + Offset: mapping.BuildStorageOffset + uint64(shift), + Length: uint64(lengthInBuild), + BuildId: mapping.BuildId, + FrameTable: mapping.FrameTable, + } - if mappedLength < 0 { + if lengthInBuild < 0 { if t.IsNormalizeFixApplied() { - return 0, 0, nil, fmt.Errorf("mapped length for offset %d is negative: %d", offset, mappedLength) + return nil, fmt.Errorf("mapped length for offset %d is negative: %d", offset, lengthInBuild) } + b.Length = 0 logger.L().Warn(ctx, "mapped length is negative, but normalize fix is not applied", zap.Int64("offset", offset), - zap.Int64("mappedLength", mappedLength), + zap.Int64("mappedLength", lengthInBuild), logger.WithBuildID(mapping.BuildId.String()), ) } - return mappedOffset, mappedLength, buildID, nil + return b, nil } // TODO: Maybe we can optimize mapping by automatically assuming the mapping is uuid.Nil if we don't find it + stopping storing the nil mapping. 
@@ -143,3 +204,105 @@ func (t *Header) getMapping(ctx context.Context, offset int64) (*BuildMap, int64 return mapping, shift, nil } + +// ValidateHeader checks header integrity and returns an error if corruption is detected. +// This verifies: +// 1. Header and metadata are valid +// 2. Mappings cover the entire file [0, Size) with no gaps +// 3. Mappings don't extend beyond file size (with block alignment tolerance) +func ValidateHeader(h *Header) error { + if h == nil { + return fmt.Errorf("header is nil") + } + if h.Metadata == nil { + return fmt.Errorf("header metadata is nil") + } + if h.Metadata.BlockSize == 0 { + return fmt.Errorf("header has zero block size") + } + if h.Metadata.Size == 0 { + return fmt.Errorf("header has zero size") + } + if len(h.Mapping) == 0 { + return fmt.Errorf("header has no mappings") + } + + // Sort mappings by offset to check for gaps/overlaps + sortedMappings := make([]*BuildMap, len(h.Mapping)) + copy(sortedMappings, h.Mapping) + for i := range len(sortedMappings) - 1 { + for j := i + 1; j < len(sortedMappings); j++ { + if sortedMappings[j].Offset < sortedMappings[i].Offset { + sortedMappings[i], sortedMappings[j] = sortedMappings[j], sortedMappings[i] + } + } + } + + // Check that first mapping starts at 0 + if sortedMappings[0].Offset != 0 { + return fmt.Errorf("mappings don't start at 0: first mapping starts at %#x for buildId %s", + sortedMappings[0].Offset, h.Metadata.BuildId.String()) + } + + // Check for gaps and overlaps between consecutive mappings + for i := range len(sortedMappings) - 1 { + currentEnd := sortedMappings[i].Offset + sortedMappings[i].Length + nextStart := sortedMappings[i+1].Offset + + if currentEnd < nextStart { + return fmt.Errorf("gap in mappings: mapping[%d] ends at %#x but mapping[%d] starts at %#x (gap=%d bytes) for buildId %s", + i, currentEnd, i+1, nextStart, nextStart-currentEnd, h.Metadata.BuildId.String()) + } + if currentEnd > nextStart { + return fmt.Errorf("overlap in mappings: mapping[%d] 
ends at %#x but mapping[%d] starts at %#x (overlap=%d bytes) for buildId %s", + i, currentEnd, i+1, nextStart, currentEnd-nextStart, h.Metadata.BuildId.String()) + } + } + + // Check that last mapping covers up to (at least) Size + lastMapping := sortedMappings[len(sortedMappings)-1] + lastEnd := lastMapping.Offset + lastMapping.Length + if lastEnd < h.Metadata.Size { + return fmt.Errorf("mappings don't cover entire file: last mapping ends at %#x but file size is %#x (missing %d bytes) for buildId %s", + lastEnd, h.Metadata.Size, h.Metadata.Size-lastEnd, h.Metadata.BuildId.String()) + } + + // Allow last mapping to extend up to one block past size (for alignment) + if lastEnd > h.Metadata.Size+h.Metadata.BlockSize { + return fmt.Errorf("last mapping extends too far: ends at %#x but file size is %#x (overhang=%d bytes, max allowed=%d) for buildId %s", + lastEnd, h.Metadata.Size, lastEnd-h.Metadata.Size, h.Metadata.BlockSize, h.Metadata.BuildId.String()) + } + + // Validate individual mapping bounds + for i, m := range h.Mapping { + if m.Offset > h.Metadata.Size { + return fmt.Errorf("mapping[%d] has Offset %#x beyond header size %#x for buildId %s", + i, m.Offset, h.Metadata.Size, m.BuildId.String()) + } + if m.Length == 0 { + return fmt.Errorf("mapping[%d] has zero length at offset %#x for buildId %s", + i, m.Offset, m.BuildId.String()) + } + } + + return nil +} + +// AddFrames associates compression frame information with this header's mappings. +// +// Only mappings matching this header's BuildId will be updated. Returns nil if frameTable is nil. 
+func (t *Header) AddFrames(frameTable *storage.FrameTable) error { + if frameTable == nil { + return nil + } + + for _, mapping := range t.Mapping { + if mapping.BuildId == t.Metadata.BuildId { + if err := mapping.AddFrames(frameTable); err != nil { + return err + } + } + } + + return nil +} diff --git a/packages/shared/pkg/storage/header/mapping.go b/packages/shared/pkg/storage/header/mapping.go index 0802bb1fe8..096ffd3308 100644 --- a/packages/shared/pkg/storage/header/mapping.go +++ b/packages/shared/pkg/storage/header/mapping.go @@ -6,6 +6,8 @@ import ( "github.com/bits-and-blooms/bitset" "github.com/google/uuid" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) // Start, Length and SourceStart are in bytes of the data file @@ -13,10 +15,11 @@ import ( // The list of block mappings will be in order of increasing Start, covering the entire file type BuildMap struct { // Offset defines which block of the current layer this mapping starts at - Offset uint64 + Offset uint64 // in the memory space Length uint64 BuildId uuid.UUID BuildStorageOffset uint64 + FrameTable *storage.FrameTable } func (mapping *BuildMap) Copy() *BuildMap { @@ -25,9 +28,40 @@ func (mapping *BuildMap) Copy() *BuildMap { Length: mapping.Length, BuildId: mapping.BuildId, BuildStorageOffset: mapping.BuildStorageOffset, + FrameTable: mapping.FrameTable, // Preserve FrameTable for compressed data } } +// AddFrames associates compression frame information with this mapping. +// +// When a file is uploaded with compression, the compressor produces a FrameTable +// that describes how the compressed data is organized into frames. This method +// computes which compressed frames cover this mapping's data within the build's +// storage file based on BuildStorageOffset and Length. +// +// Returns nil if frameTable is nil. Returns an error if the mapping's range +// cannot be found in the frame table. 
+func (mapping *BuildMap) AddFrames(frameTable *storage.FrameTable) error { + if frameTable == nil { + return nil + } + + mappedRange := storage.Range{ + Start: int64(mapping.BuildStorageOffset), + Length: int(mapping.Length), + } + + subset, err := frameTable.Subset(mappedRange) + if err != nil { + return fmt.Errorf("mapping at virtual offset %#x (storage offset %#x, length %#x): %w", + mapping.Offset, mapping.BuildStorageOffset, mapping.Length, err) + } + + mapping.FrameTable = subset + + return nil +} + func CreateMapping( buildId *uuid.UUID, dirty *bitset.BitSet, @@ -160,6 +194,7 @@ func MergeMappings( // the build storage offset is the same as the base mapping BuildStorageOffset: base.BuildStorageOffset, } + leftBase.FrameTable, _ = base.FrameTable.Subset(storage.Range{Start: int64(leftBase.BuildStorageOffset), Length: int(leftBase.Length)}) mappings = append(mappings, leftBase) } @@ -178,6 +213,7 @@ func MergeMappings( BuildId: base.BuildId, BuildStorageOffset: base.BuildStorageOffset + uint64(rightBaseShift), } + rightBase.FrameTable, _ = base.FrameTable.Subset(storage.Range{Start: int64(rightBase.BuildStorageOffset), Length: int(rightBase.Length)}) baseMapping[baseIdx] = rightBase } else { @@ -205,6 +241,7 @@ func MergeMappings( BuildId: base.BuildId, BuildStorageOffset: base.BuildStorageOffset + uint64(rightBaseShift), } + rightBase.FrameTable, _ = base.FrameTable.Subset(storage.Range{Start: int64(rightBase.BuildStorageOffset), Length: int(rightBase.Length)}) baseMapping[baseIdx] = rightBase } else { @@ -226,6 +263,7 @@ func MergeMappings( BuildId: base.BuildId, BuildStorageOffset: base.BuildStorageOffset, } + leftBase.FrameTable, _ = base.FrameTable.Subset(storage.Range{Start: int64(leftBase.BuildStorageOffset), Length: int(leftBase.Length)}) mappings = append(mappings, leftBase) } @@ -245,6 +283,8 @@ func MergeMappings( } // NormalizeMappings joins adjacent mappings that have the same buildId. 
+// When merging mappings, FrameTables are also merged by extending the first +// mapping's FrameTable with frames from subsequent mappings. func NormalizeMappings(mappings []*BuildMap) []*BuildMap { if len(mappings) == 0 { return nil @@ -252,7 +292,7 @@ result := make([]*BuildMap, 0, len(mappings)) - // Start with a copy of the first mapping + // Start with a copy of the first mapping (Copy() now includes FrameTable) current := mappings[0].Copy() for i := 1; i < len(mappings); i++ { @@ -260,10 +300,22 @@ mp := mappings[i] if mp.BuildId != current.BuildId { // BuildId changed, add the current map to results and start a new one result = append(result, current) - current = mp.Copy() // New copy + current = mp.Copy() // New copy (includes FrameTable) } else { - // Same BuildId, just add the length + // Same BuildId, merge: add the length and extend FrameTable current.Length += mp.Length + + // Extend FrameTable if the mapping being merged has one + if mp.FrameTable != nil { + if current.FrameTable == nil { + // Current has no FrameTable but merged one does - take it + current.FrameTable = mp.FrameTable + } else { + // Both have FrameTables - extend current's with mp's frames + // The frames are contiguous subsets, so we append non-overlapping frames + current.FrameTable = mergeFrameTables(current.FrameTable, mp.FrameTable) + } + } } } @@ -272,3 +324,63 @@ return result } + +// mergeFrameTables extends ft1 with frames from ft2. The FrameTables are +// assumed to be contiguous subsets from the same original, so ft2's frames +// follow ft1's frames (with possible overlap at the boundary). This function +// returns either a reference to one of the input tables, unchanged, or a new +// FrameTable with frames from both tables.
+func mergeFrameTables(ft1, ft2 *storage.FrameTable) *storage.FrameTable { + if ft1 == nil { + return ft2 + } + if ft2 == nil { + return ft1 + } + + // Calculate where ft1 ends (uncompressed offset) + ft1EndU := ft1.StartAt.U + for _, frame := range ft1.Frames { + ft1EndU += int64(frame.U) + } + + // Find where to start appending from ft2 (skip frames already covered by ft1) + ft2CurrentU := ft2.StartAt.U + startIdx := 0 + for i, frame := range ft2.Frames { + frameEndU := ft2CurrentU + int64(frame.U) + if frameEndU <= ft1EndU { + // This frame is already covered by ft1 + ft2CurrentU = frameEndU + startIdx = i + 1 + + continue + } + if ft2CurrentU < ft1EndU { + // This frame overlaps with ft1's last frame - it's the same frame, skip it + ft2CurrentU = frameEndU + startIdx = i + 1 + + continue + } + // This frame is beyond ft1's coverage + break + } + + // Append remaining frames from ft2 + if startIdx < len(ft2.Frames) { + // Create a new FrameTable with extended frames + newFrames := make([]storage.FrameSize, len(ft1.Frames), len(ft1.Frames)+len(ft2.Frames)-startIdx) + copy(newFrames, ft1.Frames) + newFrames = append(newFrames, ft2.Frames[startIdx:]...) + + return &storage.FrameTable{ + CompressionType: ft1.CompressionType, + StartAt: ft1.StartAt, + Frames: newFrames, + } + } + + // All of ft2's frames were already covered by ft1 + return ft1 +} diff --git a/packages/shared/pkg/storage/header/serialization.go b/packages/shared/pkg/storage/header/serialization.go index 6af71f832b..5abbac82cf 100644 --- a/packages/shared/pkg/storage/header/serialization.go +++ b/packages/shared/pkg/storage/header/serialization.go @@ -13,7 +13,12 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) -const metadataVersion = 3 +const ( + // metadataVersion is used by template-manager for uncompressed builds (V3 headers). + metadataVersion = 3 + // MetadataVersionCompressed is used by compress-build for compressed builds (V4 headers with FrameTables). 
+ MetadataVersionCompressed = 4 +) type Metadata struct { Version uint64 @@ -25,6 +30,25 @@ type Metadata struct { BaseBuildId uuid.UUID } +type v3SerializableBuildMap struct { + Offset uint64 + Length uint64 + BuildId uuid.UUID + BuildStorageOffset uint64 +} + +type v4SerializableBuildMap struct { + Offset uint64 + Length uint64 + BuildId uuid.UUID + BuildStorageOffset uint64 + CompressionTypeNumFrames uint64 // CompressionType is stored as uint8 in the high byte, the low 24 bits are NumFrames + + // if CompressionType != CompressionNone and there are frames + // - followed by frames offset (16 bytes) + // - followed by frames... (16 bytes * NumFrames) +} + func NewTemplateMetadata(buildId uuid.UUID, blockSize, size uint64) *Metadata { return &Metadata{ Version: metadataVersion, @@ -55,11 +79,53 @@ func Serialize(metadata *Metadata, mappings []*BuildMap) ([]byte, error) { return nil, fmt.Errorf("failed to write metadata: %w", err) } + var v any for _, mapping := range mappings { - err := binary.Write(&buf, binary.LittleEndian, mapping) + var offset *storage.FrameOffset + var frames []storage.FrameSize + if metadata.Version <= 3 { + v = &v3SerializableBuildMap{ + Offset: mapping.Offset, + Length: mapping.Length, + BuildId: mapping.BuildId, + BuildStorageOffset: mapping.BuildStorageOffset, + } + } else { + v4 := &v4SerializableBuildMap{ + Offset: mapping.Offset, + Length: mapping.Length, + BuildId: mapping.BuildId, + BuildStorageOffset: mapping.BuildStorageOffset, + } + if mapping.FrameTable != nil { + v4.CompressionTypeNumFrames = uint64(mapping.FrameTable.CompressionType)<<24 | uint64(len(mapping.FrameTable.Frames)) + // Only write offset/frames when the packed value is non-zero, + // matching the deserializer's condition. A FrameTable with + // CompressionNone and zero frames produces a packed value of 0. 
+ if v4.CompressionTypeNumFrames != 0 { + offset = &mapping.FrameTable.StartAt + frames = mapping.FrameTable.Frames + } + } + v = v4 + } + + err := binary.Write(&buf, binary.LittleEndian, v) if err != nil { return nil, fmt.Errorf("failed to write block mapping: %w", err) } + if offset != nil { + err := binary.Write(&buf, binary.LittleEndian, offset) + if err != nil { + return nil, fmt.Errorf("failed to write compression frames starting offset: %w", err) + } + } + for _, frame := range frames { + err := binary.Write(&buf, binary.LittleEndian, frame) + if err != nil { + return nil, fmt.Errorf("failed to write compression frame: %w", err) + } + } } return buf.Bytes(), nil @@ -75,8 +141,8 @@ func Deserialize(ctx context.Context, in storage.Blob) (*Header, error) { } func DeserializeBytes(data []byte) (*Header, error) { - reader := bytes.NewReader(data) var metadata Metadata + reader := bytes.NewReader(data) err := binary.Read(reader, binary.LittleEndian, &metadata) if err != nil { return nil, fmt.Errorf("failed to read metadata: %w", err) @@ -84,19 +150,90 @@ func DeserializeBytes(data []byte) (*Header, error) { mappings := make([]*BuildMap, 0) +MAPPINGS: for { var m BuildMap - err := binary.Read(reader, binary.LittleEndian, &m) - if errors.Is(err, io.EOF) { - break - } - if err != nil { - return nil, fmt.Errorf("failed to read block mapping: %w", err) + switch metadata.Version { + case 0, 1, 2, 3: + var v3 v3SerializableBuildMap + err = binary.Read(reader, binary.LittleEndian, &v3) + if errors.Is(err, io.EOF) { + break MAPPINGS + } + if err != nil { + return nil, fmt.Errorf("failed to read block mapping: %w", err) + } + + m.Offset = v3.Offset + m.Length = v3.Length + m.BuildId = v3.BuildId + m.BuildStorageOffset = v3.BuildStorageOffset + + case 4: + var v4 v4SerializableBuildMap + err = binary.Read(reader, binary.LittleEndian, &v4) + if errors.Is(err, io.EOF) { + break MAPPINGS + } + if err != nil { + return nil, fmt.Errorf("failed to read block mapping: %w", err) + } 
+ + m.Offset = v4.Offset + m.Length = v4.Length + m.BuildId = v4.BuildId + m.BuildStorageOffset = v4.BuildStorageOffset + + if v4.CompressionTypeNumFrames != 0 { + m.FrameTable = &storage.FrameTable{ + CompressionType: storage.CompressionType((v4.CompressionTypeNumFrames >> 24) & 0xFF), + } + numFrames := v4.CompressionTypeNumFrames & 0xFFFFFF + + var startAt storage.FrameOffset + err = binary.Read(reader, binary.LittleEndian, &startAt) + if err != nil { + return nil, fmt.Errorf("failed to read compression frames starting offset: %w", err) + } + m.FrameTable.StartAt = startAt + + for range numFrames { + var frame storage.FrameSize + err = binary.Read(reader, binary.LittleEndian, &frame) + if err != nil { + return nil, fmt.Errorf("failed to read the expected compression frame: %w", err) + } + m.FrameTable.Frames = append(m.FrameTable.Frames, frame) + } + } } mappings = append(mappings, &m) } - return NewHeader(&metadata, mappings) + return newValidatedHeader(&metadata, mappings) +} + +// DeserializeV4 decompresses LZ4-block-compressed data and deserializes a v4 header with frame tables. 
+func DeserializeV4(data []byte) (*Header, error) { + decompressed, err := storage.DecompressLZ4(data, storage.MaxCompressedHeaderSize) + if err != nil { + return nil, fmt.Errorf("failed to decompress v4 header: %w", err) + } + + return DeserializeBytes(decompressed) +} + +func newValidatedHeader(metadata *Metadata, mappings []*BuildMap) (*Header, error) { + header, err := NewHeader(metadata, mappings) + if err != nil { + return nil, err + } + + if err := ValidateHeader(header); err != nil { + return nil, fmt.Errorf("header validation failed: %w", err) + } + + return header, nil } diff --git a/packages/shared/pkg/storage/header/serialization_test.go b/packages/shared/pkg/storage/header/serialization_test.go new file mode 100644 index 0000000000..d9a99db106 --- /dev/null +++ b/packages/shared/pkg/storage/header/serialization_test.go @@ -0,0 +1,358 @@ +package header + +import ( + "crypto/rand" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" +) + +func compressLZ4Block(t *testing.T, data []byte) []byte { + t.Helper() + compressed, err := storage.CompressLZ4(data) + require.NoError(t, err) + + return compressed +} + +func TestSerializeDeserialize_V3_RoundTrip(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + baseID := uuid.New() + metadata := &Metadata{ + Version: 3, + BlockSize: 4096, + Size: 8192, + Generation: 7, + BuildId: buildID, + BaseBuildId: baseID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + BuildStorageOffset: 0, + }, + { + Offset: 4096, + Length: 4096, + BuildId: baseID, + BuildStorageOffset: 123, + }, + } + + data, err := Serialize(metadata, mappings) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Equal(t, metadata, got.Metadata) + require.Len(t, got.Mapping, 2) + assert.Equal(t, uint64(0), got.Mapping[0].Offset) + 
assert.Equal(t, uint64(4096), got.Mapping[0].Length) + assert.Equal(t, buildID, got.Mapping[0].BuildId) + assert.Equal(t, uint64(0), got.Mapping[0].BuildStorageOffset) + + assert.Equal(t, uint64(4096), got.Mapping[1].Offset) + assert.Equal(t, uint64(4096), got.Mapping[1].Length) + assert.Equal(t, baseID, got.Mapping[1].BuildId) + assert.Equal(t, uint64(123), got.Mapping[1].BuildStorageOffset) +} + +func TestDeserialize_TruncatedMetadata(t *testing.T) { + t.Parallel() + + _, err := DeserializeBytes([]byte{0x01, 0x02, 0x03}) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to read metadata") +} + +func TestSerializeDeserialize_EmptyMappings_Defaults(t *testing.T) { + t.Parallel() + + metadata := &Metadata{ + Version: 3, + BlockSize: 4096, + Size: 8192, + Generation: 0, + BuildId: uuid.New(), + BaseBuildId: uuid.New(), + } + + data, err := Serialize(metadata, nil) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + // NewHeader creates a default mapping when none provided + require.Len(t, got.Mapping, 1) + assert.Equal(t, uint64(0), got.Mapping[0].Offset) + assert.Equal(t, metadata.Size, got.Mapping[0].Length) + assert.Equal(t, metadata.BuildId, got.Mapping[0].BuildId) +} + +func TestDeserialize_BlockSizeZero(t *testing.T) { + t.Parallel() + + metadata := &Metadata{ + Version: 3, + BlockSize: 0, + Size: 4096, + Generation: 0, + BuildId: uuid.New(), + BaseBuildId: uuid.New(), + } + + data, err := Serialize(metadata, nil) + require.NoError(t, err) + + _, err = DeserializeBytes(data) + require.Error(t, err) + assert.Contains(t, err.Error(), "block size cannot be zero") +} + +func TestSerializeDeserialize_V4_WithFrameTable(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + baseID := uuid.New() + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 8192, + Generation: 1, + BuildId: buildID, + BaseBuildId: baseID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: 
buildID, + BuildStorageOffset: 0, + FrameTable: &storage.FrameTable{ + CompressionType: storage.CompressionLZ4, + StartAt: storage.FrameOffset{U: 0, C: 0}, + Frames: []storage.FrameSize{ + {U: 2048, C: 1024}, + {U: 2048, C: 900}, + }, + }, + }, + { + Offset: 4096, + Length: 4096, + BuildId: baseID, + BuildStorageOffset: 0, + }, + } + + data, err := Serialize(metadata, mappings) + require.NoError(t, err) + + got, err := DeserializeV4(compressLZ4Block(t, data)) + require.NoError(t, err) + + require.Equal(t, uint64(4), got.Metadata.Version) + require.Len(t, got.Mapping, 2) + + // First mapping has FrameTable + m0 := got.Mapping[0] + assert.Equal(t, uint64(0), m0.Offset) + assert.Equal(t, uint64(4096), m0.Length) + assert.Equal(t, buildID, m0.BuildId) + require.NotNil(t, m0.FrameTable) + assert.Equal(t, storage.CompressionLZ4, m0.FrameTable.CompressionType) + assert.Equal(t, int64(0), m0.FrameTable.StartAt.U) + assert.Equal(t, int64(0), m0.FrameTable.StartAt.C) + require.Len(t, m0.FrameTable.Frames, 2) + assert.Equal(t, int32(2048), m0.FrameTable.Frames[0].U) + assert.Equal(t, int32(1024), m0.FrameTable.Frames[0].C) + assert.Equal(t, int32(2048), m0.FrameTable.Frames[1].U) + assert.Equal(t, int32(900), m0.FrameTable.Frames[1].C) + + // Second mapping has no FrameTable + m1 := got.Mapping[1] + assert.Equal(t, uint64(4096), m1.Offset) + assert.Equal(t, uint64(4096), m1.Length) + assert.Equal(t, baseID, m1.BuildId) + assert.Nil(t, m1.FrameTable) +} + +func TestSerializeDeserialize_V4_Zstd_NonZeroStartAt(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 4096, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + BuildStorageOffset: 8192, + FrameTable: &storage.FrameTable{ + CompressionType: storage.CompressionZstd, + StartAt: storage.FrameOffset{U: 8192, C: 4000}, + Frames: []storage.FrameSize{ + {U: 4096, 
C: 3500}, + }, + }, + }, + } + + data, err := Serialize(metadata, mappings) + require.NoError(t, err) + + got, err := DeserializeV4(compressLZ4Block(t, data)) + require.NoError(t, err) + + require.Len(t, got.Mapping, 1) + m := got.Mapping[0] + require.NotNil(t, m.FrameTable) + assert.Equal(t, storage.CompressionZstd, m.FrameTable.CompressionType) + assert.Equal(t, int64(8192), m.FrameTable.StartAt.U) + assert.Equal(t, int64(4000), m.FrameTable.StartAt.C) + require.Len(t, m.FrameTable.Frames, 1) + assert.Equal(t, int32(4096), m.FrameTable.Frames[0].U) + assert.Equal(t, int32(3500), m.FrameTable.Frames[0].C) +} + +// TestSerializeDeserialize_V4_CompressionNone_EmptyFrames verifies that a +// FrameTable with CompressionNone and zero frames does not corrupt the stream. +// Before the fix, the serializer wrote a StartAt offset (16 bytes) but the +// deserializer skipped it because the packed value was 0. +func TestSerializeDeserialize_V4_CompressionNone_EmptyFrames(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + baseID := uuid.New() + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 8192, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + BuildStorageOffset: 0, + // FrameTable with CompressionNone and no frames — packed value is 0. + FrameTable: &storage.FrameTable{ + CompressionType: storage.CompressionNone, + StartAt: storage.FrameOffset{U: 100, C: 50}, + Frames: nil, + }, + }, + { + Offset: 4096, + Length: 4096, + BuildId: baseID, + BuildStorageOffset: 0, + }, + } + + data, err := Serialize(metadata, mappings) + require.NoError(t, err) + + got, err := DeserializeV4(compressLZ4Block(t, data)) + require.NoError(t, err) + + require.Len(t, got.Mapping, 2) + + // First mapping: FrameTable was effectively empty, deserializer should treat as nil. 
+ assert.Nil(t, got.Mapping[0].FrameTable) + + // Second mapping must not be corrupted by stray StartAt bytes. + assert.Equal(t, uint64(4096), got.Mapping[1].Offset) + assert.Equal(t, uint64(4096), got.Mapping[1].Length) + assert.Equal(t, baseID, got.Mapping[1].BuildId) +} + +func TestCompressDecompressLZ4_RoundTrip(t *testing.T) { + t.Parallel() + + // Random data should round-trip through LZ4 compress/decompress. + data := make([]byte, 4096) + _, err := rand.Read(data) + require.NoError(t, err) + + compressed, err := storage.CompressLZ4(data) + require.NoError(t, err) + + decompressed, err := storage.DecompressLZ4(compressed, storage.MaxCompressedHeaderSize) + require.NoError(t, err) + assert.Equal(t, data, decompressed) +} + +func TestSerializeDeserialize_V4_ManyFrames(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + const numFrames = 1000 + frames := make([]storage.FrameSize, numFrames) + for i := range frames { + frames[i] = storage.FrameSize{U: 4096, C: int32(2000 + i)} + } + + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 4096 * numFrames, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096 * numFrames, + BuildId: buildID, + BuildStorageOffset: 0, + FrameTable: &storage.FrameTable{ + CompressionType: storage.CompressionLZ4, + StartAt: storage.FrameOffset{U: 0, C: 0}, + Frames: frames, + }, + }, + } + + data, err := Serialize(metadata, mappings) + require.NoError(t, err) + + got, err := DeserializeV4(compressLZ4Block(t, data)) + require.NoError(t, err) + + require.Len(t, got.Mapping, 1) + require.NotNil(t, got.Mapping[0].FrameTable) + require.Len(t, got.Mapping[0].FrameTable.Frames, numFrames) + + // Spot-check first and last frame + assert.Equal(t, int32(4096), got.Mapping[0].FrameTable.Frames[0].U) + assert.Equal(t, int32(2000), got.Mapping[0].FrameTable.Frames[0].C) + assert.Equal(t, int32(4096), got.Mapping[0].FrameTable.Frames[numFrames-1].U) + 
assert.Equal(t, int32(2000+numFrames-1), got.Mapping[0].FrameTable.Frames[numFrames-1].C) +} diff --git a/packages/shared/pkg/storage/lz4.go b/packages/shared/pkg/storage/lz4.go new file mode 100644 index 0000000000..1adf5a6ada --- /dev/null +++ b/packages/shared/pkg/storage/lz4.go @@ -0,0 +1,43 @@ +package storage + +import ( + "fmt" + + "github.com/pierrec/lz4/v4" +) + +// MaxCompressedHeaderSize is the maximum allowed decompressed header size (64 MiB). +// Headers are typically a few hundred KiB; this is a safety bound. +const MaxCompressedHeaderSize = 64 << 20 + +// CompressLZ4 compresses data using LZ4 block compression. +// Returns an error if the data is incompressible (CompressBlock returns 0), +// since callers store the result as ".lz4" and DecompressLZ4 would fail on raw data. +func CompressLZ4(data []byte) ([]byte, error) { + bound := lz4.CompressBlockBound(len(data)) + dst := make([]byte, bound) + + n, err := lz4.CompressBlock(data, dst, nil) + if err != nil { + return nil, fmt.Errorf("lz4 compress: %w", err) + } + + if n == 0 { + return nil, fmt.Errorf("lz4 compress: data is incompressible (%d bytes)", len(data)) + } + + return dst[:n], nil +} + +// DecompressLZ4 decompresses LZ4-block-compressed data. +// maxSize is the maximum allowed decompressed size to prevent memory abuse. 
+func DecompressLZ4(data []byte, maxSize int) ([]byte, error) { + dst := make([]byte, maxSize) + + n, err := lz4.UncompressBlock(data, dst) + if err != nil { + return nil, fmt.Errorf("lz4 decompress: %w", err) + } + + return dst[:n], nil +} diff --git a/packages/shared/pkg/storage/mocks/mockobjectprovider.go b/packages/shared/pkg/storage/mock_blob_test.go similarity index 99% rename from packages/shared/pkg/storage/mocks/mockobjectprovider.go rename to packages/shared/pkg/storage/mock_blob_test.go index 6955ab4312..d65768339f 100644 --- a/packages/shared/pkg/storage/mocks/mockobjectprovider.go +++ b/packages/shared/pkg/storage/mock_blob_test.go @@ -2,7 +2,7 @@ // github.com/vektra/mockery // template: testify -package storagemocks +package storage import ( "context" diff --git a/packages/shared/pkg/storage/mocks/mockfeatureflagsclient.go b/packages/shared/pkg/storage/mock_featureflagsclient_test.go similarity index 99% rename from packages/shared/pkg/storage/mocks/mockfeatureflagsclient.go rename to packages/shared/pkg/storage/mock_featureflagsclient_test.go index d9d0706b51..dcd49bd977 100644 --- a/packages/shared/pkg/storage/mocks/mockfeatureflagsclient.go +++ b/packages/shared/pkg/storage/mock_featureflagsclient_test.go @@ -2,7 +2,7 @@ // github.com/vektra/mockery // template: testify -package storagemocks +package storage import ( "context" diff --git a/packages/shared/pkg/storage/mock_framedfile_test.go b/packages/shared/pkg/storage/mock_framedfile_test.go new file mode 100644 index 0000000000..b7d7c32267 --- /dev/null +++ b/packages/shared/pkg/storage/mock_framedfile_test.go @@ -0,0 +1,268 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package storage + +import ( + "context" + + mock "github.com/stretchr/testify/mock" +) + +// NewMockFramedFile creates a new instance of MockFramedFile. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 
+// The first argument is typically a *testing.T value. +func NewMockFramedFile(t interface { + mock.TestingT + Cleanup(func()) +}) *MockFramedFile { + mock := &MockFramedFile{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// MockFramedFile is an autogenerated mock type for the FramedFile type +type MockFramedFile struct { + mock.Mock +} + +type MockFramedFile_Expecter struct { + mock *mock.Mock +} + +func (_m *MockFramedFile) EXPECT() *MockFramedFile_Expecter { + return &MockFramedFile_Expecter{mock: &_m.Mock} +} + +// GetFrame provides a mock function for the type MockFramedFile +func (_mock *MockFramedFile) GetFrame(ctx context.Context, offsetU int64, frameTable *FrameTable, decompress bool, buf []byte, readSize int64, onRead func(totalWritten int64)) (Range, error) { + ret := _mock.Called(ctx, offsetU, frameTable, decompress, buf, readSize, onRead) + + if len(ret) == 0 { + panic("no return value specified for GetFrame") + } + + var r0 Range + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, int64, *FrameTable, bool, []byte, int64, func(totalWritten int64)) (Range, error)); ok { + return returnFunc(ctx, offsetU, frameTable, decompress, buf, readSize, onRead) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, int64, *FrameTable, bool, []byte, int64, func(totalWritten int64)) Range); ok { + r0 = returnFunc(ctx, offsetU, frameTable, decompress, buf, readSize, onRead) + } else { + r0 = ret.Get(0).(Range) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, int64, *FrameTable, bool, []byte, int64, func(totalWritten int64)) error); ok { + r1 = returnFunc(ctx, offsetU, frameTable, decompress, buf, readSize, onRead) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockFramedFile_GetFrame_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetFrame' +type MockFramedFile_GetFrame_Call struct { + *mock.Call +} + +// GetFrame is a 
helper method to define mock.On call +// - ctx context.Context +// - offsetU int64 +// - frameTable *FrameTable +// - decompress bool +// - buf []byte +// - readSize int64 +// - onRead func(totalWritten int64) +func (_e *MockFramedFile_Expecter) GetFrame(ctx interface{}, offsetU interface{}, frameTable interface{}, decompress interface{}, buf interface{}, readSize interface{}, onRead interface{}) *MockFramedFile_GetFrame_Call { + return &MockFramedFile_GetFrame_Call{Call: _e.mock.On("GetFrame", ctx, offsetU, frameTable, decompress, buf, readSize, onRead)} +} + +func (_c *MockFramedFile_GetFrame_Call) Run(run func(ctx context.Context, offsetU int64, frameTable *FrameTable, decompress bool, buf []byte, readSize int64, onRead func(totalWritten int64))) *MockFramedFile_GetFrame_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 int64 + if args[1] != nil { + arg1 = args[1].(int64) + } + var arg2 *FrameTable + if args[2] != nil { + arg2 = args[2].(*FrameTable) + } + var arg3 bool + if args[3] != nil { + arg3 = args[3].(bool) + } + var arg4 []byte + if args[4] != nil { + arg4 = args[4].([]byte) + } + var arg5 int64 + if args[5] != nil { + arg5 = args[5].(int64) + } + var arg6 func(totalWritten int64) + if args[6] != nil { + arg6 = args[6].(func(totalWritten int64)) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + arg5, + arg6, + ) + }) + return _c +} + +func (_c *MockFramedFile_GetFrame_Call) Return(rangeParam Range, err error) *MockFramedFile_GetFrame_Call { + _c.Call.Return(rangeParam, err) + return _c +} + +func (_c *MockFramedFile_GetFrame_Call) RunAndReturn(run func(ctx context.Context, offsetU int64, frameTable *FrameTable, decompress bool, buf []byte, readSize int64, onRead func(totalWritten int64)) (Range, error)) *MockFramedFile_GetFrame_Call { + _c.Call.Return(run) + return _c +} + +// Size provides a mock function for the type MockFramedFile +func (_mock 
*MockFramedFile) Size(ctx context.Context) (int64, error) { + ret := _mock.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Size") + } + + var r0 int64 + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context) (int64, error)); ok { + return returnFunc(ctx) + } + if returnFunc, ok := ret.Get(0).(func(context.Context) int64); ok { + r0 = returnFunc(ctx) + } else { + r0 = ret.Get(0).(int64) + } + if returnFunc, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = returnFunc(ctx) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockFramedFile_Size_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Size' +type MockFramedFile_Size_Call struct { + *mock.Call +} + +// Size is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockFramedFile_Expecter) Size(ctx interface{}) *MockFramedFile_Size_Call { + return &MockFramedFile_Size_Call{Call: _e.mock.On("Size", ctx)} +} + +func (_c *MockFramedFile_Size_Call) Run(run func(ctx context.Context)) *MockFramedFile_Size_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockFramedFile_Size_Call) Return(n int64, err error) *MockFramedFile_Size_Call { + _c.Call.Return(n, err) + return _c +} + +func (_c *MockFramedFile_Size_Call) RunAndReturn(run func(ctx context.Context) (int64, error)) *MockFramedFile_Size_Call { + _c.Call.Return(run) + return _c +} + +// StoreFile provides a mock function for the type MockFramedFile +func (_mock *MockFramedFile) StoreFile(ctx context.Context, path string, opts *FramedUploadOptions) (*FrameTable, error) { + ret := _mock.Called(ctx, path, opts) + + if len(ret) == 0 { + panic("no return value specified for StoreFile") + } + + var r0 *FrameTable + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, *FramedUploadOptions) 
(*FrameTable, error)); ok { + return returnFunc(ctx, path, opts) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, *FramedUploadOptions) *FrameTable); ok { + r0 = returnFunc(ctx, path, opts) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*FrameTable) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, *FramedUploadOptions) error); ok { + r1 = returnFunc(ctx, path, opts) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockFramedFile_StoreFile_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StoreFile' +type MockFramedFile_StoreFile_Call struct { + *mock.Call +} + +// StoreFile is a helper method to define mock.On call +// - ctx context.Context +// - path string +// - opts *FramedUploadOptions +func (_e *MockFramedFile_Expecter) StoreFile(ctx interface{}, path interface{}, opts interface{}) *MockFramedFile_StoreFile_Call { + return &MockFramedFile_StoreFile_Call{Call: _e.mock.On("StoreFile", ctx, path, opts)} +} + +func (_c *MockFramedFile_StoreFile_Call) Run(run func(ctx context.Context, path string, opts *FramedUploadOptions)) *MockFramedFile_StoreFile_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 *FramedUploadOptions + if args[2] != nil { + arg2 = args[2].(*FramedUploadOptions) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *MockFramedFile_StoreFile_Call) Return(frameTable *FrameTable, err error) *MockFramedFile_StoreFile_Call { + _c.Call.Return(frameTable, err) + return _c +} + +func (_c *MockFramedFile_StoreFile_Call) RunAndReturn(run func(ctx context.Context, path string, opts *FramedUploadOptions) (*FrameTable, error)) *MockFramedFile_StoreFile_Call { + _c.Call.Return(run) + return _c +} diff --git a/packages/shared/pkg/storage/mocks/mockioreader.go 
b/packages/shared/pkg/storage/mock_ioreader_test.go similarity index 99% rename from packages/shared/pkg/storage/mocks/mockioreader.go rename to packages/shared/pkg/storage/mock_ioreader_test.go index 5497bc53c5..9adb02421e 100644 --- a/packages/shared/pkg/storage/mocks/mockioreader.go +++ b/packages/shared/pkg/storage/mock_ioreader_test.go @@ -2,7 +2,7 @@ // github.com/vektra/mockery // template: testify -package storagemocks +package storage import ( mock "github.com/stretchr/testify/mock" diff --git a/packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go b/packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go deleted file mode 100644 index 3931f6b349..0000000000 --- a/packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go +++ /dev/null @@ -1,302 +0,0 @@ -// Code generated by mockery; DO NOT EDIT. -// github.com/vektra/mockery -// template: testify - -package storagemocks - -import ( - "context" - "io" - - mock "github.com/stretchr/testify/mock" -) - -// NewMockSeekable creates a new instance of MockSeekable. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. 
-func NewMockSeekable(t interface { - mock.TestingT - Cleanup(func()) -}) *MockSeekable { - mock := &MockSeekable{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// MockSeekable is an autogenerated mock type for the Seekable type -type MockSeekable struct { - mock.Mock -} - -type MockSeekable_Expecter struct { - mock *mock.Mock -} - -func (_m *MockSeekable) EXPECT() *MockSeekable_Expecter { - return &MockSeekable_Expecter{mock: &_m.Mock} -} - -// OpenRangeReader provides a mock function for the type MockSeekable -func (_mock *MockSeekable) OpenRangeReader(ctx context.Context, off int64, length int64) (io.ReadCloser, error) { - ret := _mock.Called(ctx, off, length) - - if len(ret) == 0 { - panic("no return value specified for OpenRangeReader") - } - - var r0 io.ReadCloser - var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64) (io.ReadCloser, error)); ok { - return returnFunc(ctx, off, length) - } - if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64) io.ReadCloser); ok { - r0 = returnFunc(ctx, off, length) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(io.ReadCloser) - } - } - if returnFunc, ok := ret.Get(1).(func(context.Context, int64, int64) error); ok { - r1 = returnFunc(ctx, off, length) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// MockSeekable_OpenRangeReader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OpenRangeReader' -type MockSeekable_OpenRangeReader_Call struct { - *mock.Call -} - -// OpenRangeReader is a helper method to define mock.On call -// - ctx context.Context -// - off int64 -// - length int64 -func (_e *MockSeekable_Expecter) OpenRangeReader(ctx interface{}, off interface{}, length interface{}) *MockSeekable_OpenRangeReader_Call { - return &MockSeekable_OpenRangeReader_Call{Call: _e.mock.On("OpenRangeReader", ctx, off, length)} -} - -func (_c *MockSeekable_OpenRangeReader_Call) 
Run(run func(ctx context.Context, off int64, length int64)) *MockSeekable_OpenRangeReader_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 int64 - if args[1] != nil { - arg1 = args[1].(int64) - } - var arg2 int64 - if args[2] != nil { - arg2 = args[2].(int64) - } - run( - arg0, - arg1, - arg2, - ) - }) - return _c -} - -func (_c *MockSeekable_OpenRangeReader_Call) Return(readCloser io.ReadCloser, err error) *MockSeekable_OpenRangeReader_Call { - _c.Call.Return(readCloser, err) - return _c -} - -func (_c *MockSeekable_OpenRangeReader_Call) RunAndReturn(run func(ctx context.Context, off int64, length int64) (io.ReadCloser, error)) *MockSeekable_OpenRangeReader_Call { - _c.Call.Return(run) - return _c -} - -// ReadAt provides a mock function for the type MockSeekable -func (_mock *MockSeekable) ReadAt(ctx context.Context, buffer []byte, off int64) (int, error) { - ret := _mock.Called(ctx, buffer, off) - - if len(ret) == 0 { - panic("no return value specified for ReadAt") - } - - var r0 int - var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64) (int, error)); ok { - return returnFunc(ctx, buffer, off) - } - if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64) int); ok { - r0 = returnFunc(ctx, buffer, off) - } else { - r0 = ret.Get(0).(int) - } - if returnFunc, ok := ret.Get(1).(func(context.Context, []byte, int64) error); ok { - r1 = returnFunc(ctx, buffer, off) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// MockSeekable_ReadAt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadAt' -type MockSeekable_ReadAt_Call struct { - *mock.Call -} - -// ReadAt is a helper method to define mock.On call -// - ctx context.Context -// - buffer []byte -// - off int64 -func (_e *MockSeekable_Expecter) ReadAt(ctx interface{}, buffer interface{}, off interface{}) 
*MockSeekable_ReadAt_Call { - return &MockSeekable_ReadAt_Call{Call: _e.mock.On("ReadAt", ctx, buffer, off)} -} - -func (_c *MockSeekable_ReadAt_Call) Run(run func(ctx context.Context, buffer []byte, off int64)) *MockSeekable_ReadAt_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 []byte - if args[1] != nil { - arg1 = args[1].([]byte) - } - var arg2 int64 - if args[2] != nil { - arg2 = args[2].(int64) - } - run( - arg0, - arg1, - arg2, - ) - }) - return _c -} - -func (_c *MockSeekable_ReadAt_Call) Return(n int, err error) *MockSeekable_ReadAt_Call { - _c.Call.Return(n, err) - return _c -} - -func (_c *MockSeekable_ReadAt_Call) RunAndReturn(run func(ctx context.Context, buffer []byte, off int64) (int, error)) *MockSeekable_ReadAt_Call { - _c.Call.Return(run) - return _c -} - -// Size provides a mock function for the type MockSeekable -func (_mock *MockSeekable) Size(ctx context.Context) (int64, error) { - ret := _mock.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Size") - } - - var r0 int64 - var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context) (int64, error)); ok { - return returnFunc(ctx) - } - if returnFunc, ok := ret.Get(0).(func(context.Context) int64); ok { - r0 = returnFunc(ctx) - } else { - r0 = ret.Get(0).(int64) - } - if returnFunc, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = returnFunc(ctx) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// MockSeekable_Size_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Size' -type MockSeekable_Size_Call struct { - *mock.Call -} - -// Size is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockSeekable_Expecter) Size(ctx interface{}) *MockSeekable_Size_Call { - return &MockSeekable_Size_Call{Call: _e.mock.On("Size", ctx)} -} - -func (_c *MockSeekable_Size_Call) Run(run func(ctx 
context.Context)) *MockSeekable_Size_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *MockSeekable_Size_Call) Return(n int64, err error) *MockSeekable_Size_Call { - _c.Call.Return(n, err) - return _c -} - -func (_c *MockSeekable_Size_Call) RunAndReturn(run func(ctx context.Context) (int64, error)) *MockSeekable_Size_Call { - _c.Call.Return(run) - return _c -} - -// StoreFile provides a mock function for the type MockSeekable -func (_mock *MockSeekable) StoreFile(ctx context.Context, path string) error { - ret := _mock.Called(ctx, path) - - if len(ret) == 0 { - panic("no return value specified for StoreFile") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { - r0 = returnFunc(ctx, path) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockSeekable_StoreFile_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StoreFile' -type MockSeekable_StoreFile_Call struct { - *mock.Call -} - -// StoreFile is a helper method to define mock.On call -// - ctx context.Context -// - path string -func (_e *MockSeekable_Expecter) StoreFile(ctx interface{}, path interface{}) *MockSeekable_StoreFile_Call { - return &MockSeekable_StoreFile_Call{Call: _e.mock.On("StoreFile", ctx, path)} -} - -func (_c *MockSeekable_StoreFile_Call) Run(run func(ctx context.Context, path string)) *MockSeekable_StoreFile_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 string - if args[1] != nil { - arg1 = args[1].(string) - } - run( - arg0, - arg1, - ) - }) - return _c -} - -func (_c *MockSeekable_StoreFile_Call) Return(err error) *MockSeekable_StoreFile_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockSeekable_StoreFile_Call) RunAndReturn(run func(ctx 
context.Context, path string) error) *MockSeekable_StoreFile_Call { - _c.Call.Return(run) - return _c -} diff --git a/packages/shared/pkg/storage/storage.go b/packages/shared/pkg/storage/storage.go index 12f5ed95ed..9a3e4e6613 100644 --- a/packages/shared/pkg/storage/storage.go +++ b/packages/shared/pkg/storage/storage.go @@ -39,13 +39,8 @@ const ( MemoryChunkSize = 4 * 1024 * 1024 // 4 MB ) -type SeekableObjectType int - -const ( - UnknownSeekableObjectType SeekableObjectType = iota - MemfileObjectType - RootFSObjectType -) +// rangeReadFunc is a callback for reading a byte range from storage. +type rangeReadFunc func(ctx context.Context, offset int64, length int) (io.ReadCloser, error) type ObjectType int @@ -62,8 +57,8 @@ const ( type StorageProvider interface { DeleteObjectsWithPrefix(ctx context.Context, prefix string) error UploadSignedURL(ctx context.Context, path string, ttl time.Duration) (string, error) - OpenBlob(ctx context.Context, path string, objectType ObjectType) (Blob, error) - OpenSeekable(ctx context.Context, path string, seekableObjectType SeekableObjectType) (Seekable, error) + OpenBlob(ctx context.Context, path string) (Blob, error) + OpenFramedFile(ctx context.Context, path string) (FramedFile, error) GetDetails() string } @@ -73,26 +68,26 @@ type Blob interface { Exists(ctx context.Context) (bool, error) } -type SeekableReader interface { - // Random slice access, off and buffer length must be aligned to block size - ReadAt(ctx context.Context, buffer []byte, off int64) (int, error) +// FramedFile represents a storage object that supports frame-based reads. +// The object knows its own path; callers do not need to supply it. +type FramedFile interface { + // GetFrame reads a single frame from storage into buf. When frameTable is + // nil (uncompressed data), reads directly without frame translation. 
When + // onRead is non-nil, data is written in readSize-aligned chunks and onRead + // is called after each chunk with the cumulative byte count written so far. + // When readSize <= 0, MemoryChunkSize is used as the default. + GetFrame(ctx context.Context, offsetU int64, frameTable *FrameTable, decompress bool, + buf []byte, readSize int64, onRead func(totalWritten int64)) (Range, error) + + // Size returns the uncompressed size of the object. For compressed objects + // with metadata, this returns the original uncompressed size. Size(ctx context.Context) (int64, error) -} - -// StreamingReader supports progressive reads via a streaming range reader. -type StreamingReader interface { - OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) -} -type SeekableWriter interface { - // Store entire file - StoreFile(ctx context.Context, path string) error -} - -type Seekable interface { - SeekableReader - SeekableWriter - StreamingReader + // StoreFile uploads the local file at path, as a multipart upload. When + // opts is non-nil with a compression type, compresses the data and returns + // the FrameTable describing the compressed frames. When opts is nil, + // performs a simple uncompressed upload (returns nil FrameTable). + StoreFile(ctx context.Context, path string, opts *FramedUploadOptions) (*FrameTable, error) } func GetTemplateStorageProvider(ctx context.Context, limiter *limit.Limiter) (StorageProvider, error) { @@ -158,3 +153,136 @@ func GetBlob(ctx context.Context, b Blob) ([]byte, error) { return buf.Bytes(), nil } + +// LoadBlob opens a blob by path and reads its contents. +func LoadBlob(ctx context.Context, s StorageProvider, path string) ([]byte, error) { + blob, err := s.OpenBlob(ctx, path) + if err != nil { + return nil, fmt.Errorf("failed to open blob %s: %w", path, err) + } + + return GetBlob(ctx, blob) +} + +// getFrame is the shared implementation for reading a single frame from storage. 
+// Each backend (GCP, AWS, FS) calls this with their own rangeRead callback. +// +// When onRead is non-nil, the output is written to buf in readSize-aligned +// blocks and onRead is called after each block with the cumulative bytes +// written. This pipelines network I/O with decompression — the LZ4/zstd reader +// pulls compressed bytes from the HTTP stream on demand, so fetch and decompress +// overlap naturally. When readSize <= 0, MemoryChunkSize is used. +func getFrame(ctx context.Context, rangeRead rangeReadFunc, storageDetails string, offsetU int64, frameTable *FrameTable, decompress bool, buf []byte, readSize int64, onRead func(totalWritten int64)) (Range, error) { + // Handle uncompressed data (nil frameTable) - read directly without frame translation + if !IsCompressed(frameTable) { + return getFrameUncompressed(ctx, rangeRead, storageDetails, offsetU, buf, readSize, onRead) + } + + // Get the frame info: translate U offset -> C offset for fetching + frameStart, frameSize, err := frameTable.FrameFor(offsetU) + if err != nil { + return Range{}, fmt.Errorf("get frame for offset %#x, %s: %w", offsetU, storageDetails, err) + } + + // Validate buffer size + expectedSize := int(frameSize.C) + if decompress { + expectedSize = int(frameSize.U) + } + if len(buf) < expectedSize { + return Range{}, fmt.Errorf("buffer too small: got %d bytes, need %d bytes for frame", len(buf), expectedSize) + } + + // Fetch the compressed data from storage + respBody, err := rangeRead(ctx, frameStart.C, int(frameSize.C)) + if err != nil { + return Range{}, fmt.Errorf("getting frame at %#x from %s: %w", frameStart.C, storageDetails, err) + } + defer respBody.Close() + + var from io.Reader = respBody + totalSize := int(frameSize.C) + + if decompress { + totalSize = int(frameSize.U) + + switch frameTable.CompressionType { + case CompressionZstd: + dec, err := getZstdDecoder(respBody) + if err != nil { + return Range{}, fmt.Errorf("failed to create zstd decoder: %w", err) + } + defer 
putZstdDecoder(dec) + from = dec + + case CompressionLZ4: + rd := getLZ4Reader(respBody) + defer putLZ4Reader(rd) + from = rd + + default: + return Range{}, fmt.Errorf("unsupported compression type: %s", frameTable.CompressionType) + } + } + + // Progressive mode: read in readSize blocks, call onRead after each. + if onRead != nil { + return readProgressive(from, buf, totalSize, frameStart.C, readSize, onRead) + } + + n, err := io.ReadFull(from, buf[:totalSize]) + + return Range{Start: frameStart.C, Length: n}, err +} + +// readProgressive reads from src into buf in readSize-aligned blocks, +// calling onRead after each block with the cumulative bytes written. +// When readSize <= 0, MemoryChunkSize is used as the default. +func readProgressive(src io.Reader, buf []byte, totalSize int, rangeStart int64, readSize int64, onRead func(totalWritten int64)) (Range, error) { + if readSize <= 0 { + readSize = MemoryChunkSize + } + + var total int64 + + for total < int64(totalSize) { + end := min(total+readSize, int64(totalSize)) + n, err := io.ReadFull(src, buf[total:end]) + total += int64(n) + + if int64(n) > 0 { + onRead(total) + } + + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + break + } + + if err != nil { + return Range{}, fmt.Errorf("progressive read error after %d bytes: %w", total, err) + } + } + + return Range{Start: rangeStart, Length: int(total)}, nil +} + +// getFrameUncompressed reads uncompressed data directly from storage. +// When onRead is non-nil, uses readProgressive for progressive delivery. 
+func getFrameUncompressed(ctx context.Context, rangeRead rangeReadFunc, storageDetails string, offset int64, buf []byte, readSize int64, onRead func(totalWritten int64)) (Range, error) { + respBody, err := rangeRead(ctx, offset, len(buf)) + if err != nil { + return Range{}, fmt.Errorf("getting uncompressed data at %#x from %s: %w", offset, storageDetails, err) + } + defer respBody.Close() + + if onRead != nil { + return readProgressive(respBody, buf, len(buf), offset, readSize, onRead) + } + + n, err := io.ReadFull(respBody, buf) + if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) { + return Range{}, fmt.Errorf("reading uncompressed data from %s: %w", storageDetails, err) + } + + return Range{Start: offset, Length: n}, nil +} diff --git a/packages/shared/pkg/storage/storage_aws.go b/packages/shared/pkg/storage/storage_aws.go index 189e1cd501..20f18633fe 100644 --- a/packages/shared/pkg/storage/storage_aws.go +++ b/packages/shared/pkg/storage/storage_aws.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "os" + "strconv" "strings" "time" @@ -41,8 +42,8 @@ type awsObject struct { } var ( - _ Seekable = (*awsObject)(nil) - _ Blob = (*awsObject)(nil) + _ FramedFile = (*awsObject)(nil) + _ Blob = (*awsObject)(nil) ) func newAWSStorage(ctx context.Context, bucketName string) (*awsStorage, error) { @@ -127,7 +128,7 @@ func (s *awsStorage) UploadSignedURL(ctx context.Context, path string, ttl time. 
return resp.URL, nil } -func (s *awsStorage) OpenSeekable(_ context.Context, path string, _ SeekableObjectType) (Seekable, error) { +func (s *awsStorage) OpenFramedFile(_ context.Context, path string) (FramedFile, error) { return &awsObject{ client: s.client, bucketName: s.bucketName, @@ -135,7 +136,7 @@ func (s *awsStorage) OpenSeekable(_ context.Context, path string, _ SeekableObje }, nil } -func (s *awsStorage) OpenBlob(_ context.Context, path string, _ ObjectType) (Blob, error) { +func (s *awsStorage) OpenBlob(_ context.Context, path string) (Blob, error) { return &awsObject{ client: s.client, bucketName: s.bucketName, @@ -162,13 +163,17 @@ func (o *awsObject) WriteTo(ctx context.Context, dst io.Writer) (int64, error) { return io.Copy(dst, resp.Body) } -func (o *awsObject) StoreFile(ctx context.Context, path string) error { +func (o *awsObject) StoreFile(ctx context.Context, path string, opts *FramedUploadOptions) (*FrameTable, error) { + if opts != nil && opts.CompressionType != CompressionNone { + return nil, fmt.Errorf("compressed uploads are not supported on AWS (builds target GCP only)") + } + ctx, cancel := context.WithTimeout(ctx, awsWriteTimeout) defer cancel() f, err := os.Open(path) if err != nil { - return fmt.Errorf("failed to open file %s: %w", path, err) + return nil, fmt.Errorf("failed to open file %s: %w", path, err) } defer f.Close() @@ -189,7 +194,7 @@ func (o *awsObject) StoreFile(ctx context.Context, path string) error { }, ) - return err + return nil, err } func (o *awsObject) Put(ctx context.Context, data []byte) error { @@ -211,8 +216,8 @@ func (o *awsObject) Put(ctx context.Context, data []byte) error { return nil } -func (o *awsObject) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { - readRange := aws.String(fmt.Sprintf("bytes=%d-%d", off, off+length-1)) +func (o *awsObject) openRangeReader(ctx context.Context, off int64, length int) (io.ReadCloser, error) { + readRange := 
aws.String(fmt.Sprintf("bytes=%d-%d", off, off+int64(length)-1)) resp, err := o.client.GetObject(ctx, &s3.GetObjectInput{ Bucket: aws.String(o.bucketName), Key: aws.String(o.path), @@ -230,37 +235,6 @@ func (o *awsObject) OpenRangeReader(ctx context.Context, off, length int64) (io. return resp.Body, nil } -func (o *awsObject) ReadAt(ctx context.Context, buff []byte, off int64) (n int, err error) { - ctx, cancel := context.WithTimeout(ctx, awsReadTimeout) - defer cancel() - - readRange := aws.String(fmt.Sprintf("bytes=%d-%d", off, off+int64(len(buff))-1)) - resp, err := o.client.GetObject(ctx, &s3.GetObjectInput{ - Bucket: aws.String(o.bucketName), - Key: aws.String(o.path), - Range: readRange, - }) - if err != nil { - var nsk *types.NoSuchKey - if errors.As(err, &nsk) { - return 0, ErrObjectNotExist - } - - return 0, err - } - - defer resp.Body.Close() - - // When the object is smaller than requested range there will be unexpected EOF, - // but backend expects to return EOF in this case. 
- n, err = io.ReadFull(resp.Body, buff) - if errors.Is(err, io.ErrUnexpectedEOF) { - err = io.EOF - } - - return n, err -} - func (o *awsObject) Size(ctx context.Context) (int64, error) { ctx, cancel := context.WithTimeout(ctx, awsOperationTimeout) defer cancel() @@ -276,6 +250,13 @@ func (o *awsObject) Size(ctx context.Context) (int64, error) { return 0, err } + if v, ok := resp.Metadata["uncompressed-size"]; ok { + parsed, parseErr := strconv.ParseInt(v, 10, 64) + if parseErr == nil { + return parsed, nil + } + } + return *resp.ContentLength, nil } @@ -306,3 +287,7 @@ func ignoreNotExists(err error) error { return err } + +func (o *awsObject) GetFrame(ctx context.Context, offsetU int64, frameTable *FrameTable, decompress bool, buf []byte, readSize int64, onRead func(totalWritten int64)) (Range, error) { + return getFrame(ctx, o.openRangeReader, "S3:"+o.path, offsetU, frameTable, decompress, buf, readSize, onRead) +} diff --git a/packages/shared/pkg/storage/storage_cache.go b/packages/shared/pkg/storage/storage_cache.go index 2f5f05f43c..d314e5b4fc 100644 --- a/packages/shared/pkg/storage/storage_cache.go +++ b/packages/shared/pkg/storage/storage_cache.go @@ -68,8 +68,8 @@ func (c cache) UploadSignedURL(ctx context.Context, path string, ttl time.Durati return c.inner.UploadSignedURL(ctx, path, ttl) } -func (c cache) OpenBlob(ctx context.Context, path string, objectType ObjectType) (Blob, error) { - innerObject, err := c.inner.OpenBlob(ctx, path, objectType) +func (c cache) OpenBlob(ctx context.Context, path string) (Blob, error) { + innerObject, err := c.inner.OpenBlob(ctx, path) if err != nil { return nil, fmt.Errorf("failed to open object: %w", err) } @@ -88,8 +88,8 @@ func (c cache) OpenBlob(ctx context.Context, path string, objectType ObjectType) }, nil } -func (c cache) OpenSeekable(ctx context.Context, path string, objectType SeekableObjectType) (Seekable, error) { - innerObject, err := c.inner.OpenSeekable(ctx, path, objectType) +func (c cache) 
OpenFramedFile(ctx context.Context, path string) (FramedFile, error) { + innerObject, err := c.inner.OpenFramedFile(ctx, path) if err != nil { return nil, fmt.Errorf("failed to open object: %w", err) } @@ -99,7 +99,7 @@ func (c cache) OpenSeekable(ctx context.Context, path string, objectType Seekabl return nil, fmt.Errorf("failed to create cache directory: %w", err) } - return &cachedSeekable{ + return &cachedFramedFile{ path: localPath, chunkSize: c.chunkSize, inner: innerObject, diff --git a/packages/shared/pkg/storage/storage_cache_blob.go b/packages/shared/pkg/storage/storage_cache_blob.go index 33cdcbaac0..696a66126d 100644 --- a/packages/shared/pkg/storage/storage_cache_blob.go +++ b/packages/shared/pkg/storage/storage_cache_blob.go @@ -45,12 +45,12 @@ func (b *cachedBlob) WriteTo(ctx context.Context, dst io.Writer) (n int64, e err bytesRead, err := b.copyFullFileFromCache(ctx, dst) if err == nil { - recordCacheRead(ctx, true, bytesRead, cacheTypeObject, cacheOpWriteTo) + recordCacheRead(ctx, true, bytesRead, cacheTypeBlob, cacheOpWriteTo) return bytesRead, nil } - recordCacheReadError(ctx, cacheTypeObject, cacheOpWriteTo, err) + recordCacheReadError(ctx, cacheTypeBlob, cacheOpWriteTo, err) // This is semi-arbitrary. this code path is called for files that tend to be less than 1 MB (headers, metadata, etc), // so 2 MB allows us to read the file without needing to allocate more memory, with some room for growth. 
If the @@ -72,13 +72,13 @@ func (b *cachedBlob) WriteTo(ctx context.Context, dst io.Writer) (n int64, e err count, err := b.writeFileToCache(ctx, buffer) if err != nil { - recordCacheWriteError(ctx, cacheTypeObject, cacheOpWriteTo, err) + recordCacheWriteError(ctx, cacheTypeBlob, cacheOpWriteTo, err) recordError(span, err) return } - recordCacheWrite(ctx, count, cacheTypeObject, cacheOpWriteTo) + recordCacheWrite(ctx, count, cacheTypeBlob, cacheOpWriteTo) }) written, err := dst.Write(data) @@ -86,7 +86,7 @@ func (b *cachedBlob) WriteTo(ctx context.Context, dst io.Writer) (n int64, e err return int64(written), fmt.Errorf("failed to write object: %w", err) } - recordCacheRead(ctx, false, int64(written), cacheTypeObject, cacheOpWriteTo) + recordCacheRead(ctx, false, int64(written), cacheTypeBlob, cacheOpWriteTo) return int64(written), err // in case err == EOF } @@ -108,9 +108,9 @@ func (b *cachedBlob) Put(ctx context.Context, data []byte) (e error) { count, err := b.writeFileToCache(ctx, bytes.NewReader(data)) if err != nil { recordError(span, err) - recordCacheWriteError(ctx, cacheTypeObject, cacheOpWrite, err) + recordCacheWriteError(ctx, cacheTypeBlob, cacheOpPut, err) } else { - recordCacheWrite(ctx, count, cacheTypeObject, cacheOpWrite) + recordCacheWrite(ctx, count, cacheTypeBlob, cacheOpPut) } }) } diff --git a/packages/shared/pkg/storage/storage_cache_blob_test.go b/packages/shared/pkg/storage/storage_cache_blob_test.go index 1054a05d36..1d226bd7ca 100644 --- a/packages/shared/pkg/storage/storage_cache_blob_test.go +++ b/packages/shared/pkg/storage/storage_cache_blob_test.go @@ -13,8 +13,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace/noop" - - storagemocks "github.com/e2b-dev/infra/packages/shared/pkg/storage/mocks" ) var noopTracer = noop.TracerProvider{}.Tracer("") @@ -32,12 +30,12 @@ func TestCachedObjectProvider_Put(t *testing.T) { err := os.MkdirAll(cacheDir, os.ModePerm) 
require.NoError(t, err) - inner := storagemocks.NewMockBlob(t) + inner := NewMockBlob(t) inner.EXPECT(). Put(mock.Anything, mock.Anything). Return(nil) - featureFlags := storagemocks.NewMockFeatureFlagsClient(t) + featureFlags := NewMockFeatureFlagsClient(t) featureFlags.EXPECT().BoolFlag(mock.Anything, mock.Anything).Return(true) c := cachedBlob{path: cacheDir, inner: inner, chunkSize: 1024, flags: featureFlags, tracer: noopTracer} @@ -68,7 +66,7 @@ func TestCachedObjectProvider_Put(t *testing.T) { const dataSize = 10 * megabyte actualData := generateData(t, dataSize) - inner := storagemocks.NewMockBlob(t) + inner := NewMockBlob(t) inner.EXPECT(). WriteTo(mock.Anything, mock.Anything). RunAndReturn(func(_ context.Context, dst io.Writer) (int64, error) { @@ -101,7 +99,7 @@ func TestCachedObjectProvider_WriteFileToCache(t *testing.T) { tracer: noopTracer, } errTarget := errors.New("find me") - reader := storagemocks.NewMockReader(t) + reader := NewMockReader(t) reader.EXPECT().Read(mock.Anything).Return(4, nil).Once() reader.EXPECT().Read(mock.Anything).Return(0, errTarget).Once() diff --git a/packages/shared/pkg/storage/storage_cache_metrics.go b/packages/shared/pkg/storage/storage_cache_metrics.go index 037bc7ed06..7fd659ec7e 100644 --- a/packages/shared/pkg/storage/storage_cache_metrics.go +++ b/packages/shared/pkg/storage/storage_cache_metrics.go @@ -28,21 +28,19 @@ var ( type cacheOp string const ( - cacheOpWriteTo cacheOp = "write_to" - cacheOpReadAt cacheOp = "read_at" - cacheOpSize cacheOp = "size" + cacheOpWriteTo cacheOp = "write_to" + cacheOpGetFrame cacheOp = "get_frame" + cacheOpSize cacheOp = "size" - cacheOpOpenRangeReader cacheOp = "open_range_reader" - - cacheOpWrite cacheOp = "write" - cacheOpWriteFromFileSystem cacheOp = "write_from_filesystem" + cacheOpPut cacheOp = "put" + cacheOpStoreFile cacheOp = "store_file" ) type cacheType string const ( - cacheTypeObject cacheType = "object" - cacheTypeSeekable cacheType = "seekable" + cacheTypeBlob 
cacheType = "blob" + cacheTypeFramedFile cacheType = "framed_file" ) func recordCacheRead(ctx context.Context, isHit bool, bytesRead int64, t cacheType, op cacheOp) { diff --git a/packages/shared/pkg/storage/storage_cache_seekable.go b/packages/shared/pkg/storage/storage_cache_seekable.go index 47d65ae94d..51869a6658 100644 --- a/packages/shared/pkg/storage/storage_cache_seekable.go +++ b/packages/shared/pkg/storage/storage_cache_seekable.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" "strconv" + "strings" "sync" "github.com/google/uuid" @@ -24,6 +25,12 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) +const ( + nfsCacheOperationAttr = "operation" + nfsCacheOperationAttrGetFrame = "GetFrame" + nfsCacheOperationAttrSize = "Size" +) + var ( ErrOffsetUnaligned = errors.New("offset must be a multiple of chunk size") ErrBufferTooSmall = errors.New("buffer is too small") @@ -31,12 +38,6 @@ var ( ErrBufferTooLarge = errors.New("buffer is too large") ) -const ( - nfsCacheOperationAttr = "operation" - nfsCacheOperationAttrReadAt = "ReadAt" - nfsCacheOperationAttrSize = "Size" -) - var ( cacheSlabReadTimerFactory = utils.Must(telemetry.NewTimerFactory(meter, "orchestrator.storage.slab.nfs.read", @@ -57,160 +58,275 @@ type featureFlagsClient interface { IntFlag(ctx context.Context, flag featureflags.IntFlag, ldctx ...ldcontext.Context) int } -type cachedSeekable struct { +type cachedFramedFile struct { path string chunkSize int64 - inner Seekable + inner FramedFile flags featureFlagsClient tracer trace.Tracer wg sync.WaitGroup } -var ( - _ Seekable = (*cachedSeekable)(nil) - _ StreamingReader = (*cachedSeekable)(nil) -) +var _ FramedFile = (*cachedFramedFile)(nil) -func (c *cachedSeekable) ReadAt(ctx context.Context, buff []byte, offset int64) (n int, err error) { - ctx, span := c.tracer.Start(ctx, "read object at offset", trace.WithAttributes( - attribute.Int64("offset", offset), - attribute.Int("buff_len", len(buff)), +// GetFrame reads a single frame 
from storage with NFS caching. +// +// Compressed path (ft != nil): cache key is the compressed frame file (.frm). +// Cache hit → read compressed bytes from NFS → decompress if requested. +// Cache miss → inner.GetFrame(decompress=false) → async write-back → decompress. +// +// Uncompressed path (ft == nil): cache key is the chunk file (.bin). +// Cache hit → read from NFS chunk file → deliver. +// Cache miss → inner.GetFrame → async write-back. +func (c *cachedFramedFile) GetFrame(ctx context.Context, offsetU int64, frameTable *FrameTable, decompress bool, buf []byte, readSize int64, onRead func(totalWritten int64)) (Range, error) { + if err := c.validateGetFrameParams(offsetU, len(buf), frameTable, decompress); err != nil { + return Range{}, err + } + + if IsCompressed(frameTable) { + return c.getFrameCompressed(ctx, offsetU, frameTable, decompress, buf, readSize, onRead) + } + + return c.getFrameUncompressed(ctx, offsetU, buf, readSize, onRead) +} + +func (c *cachedFramedFile) getFrameCompressed(ctx context.Context, offsetU int64, frameTable *FrameTable, decompress bool, buf []byte, readSize int64, onRead func(totalWritten int64)) (_ Range, e error) { + ctx, span := c.tracer.Start(ctx, "get_frame at offset", trace.WithAttributes( + attribute.Int64("offset", offsetU), + attribute.Int("buf_len", len(buf)), + attribute.Bool("compressed", true), )) defer func() { - recordError(span, err) + recordError(span, e) span.End() }() - if err := c.validateReadAtParams(int64(len(buff)), offset); err != nil { - return 0, err + frameStart, frameSize, err := frameTable.FrameFor(offsetU) + if err != nil { + return Range{}, fmt.Errorf("cache GetFrame: frame lookup for offset %#x: %w", offsetU, err) } - // try to read from cache first - chunkPath := c.makeChunkFilename(offset) + framePath := makeFrameFilename(c.path, frameStart, frameSize) + + // Try NFS cache + readTimer := cacheSlabReadTimerFactory.Begin(attribute.String(nfsCacheOperationAttr, nfsCacheOperationAttrGetFrame)) + 
compressedBuf := make([]byte, frameSize.C) + n, readErr := readCacheFile(framePath, compressedBuf) + + if readErr == nil { + // Cache hit + readTimer.Success(ctx, int64(n)) + recordCacheRead(ctx, true, int64(n), cacheTypeFramedFile, cacheOpGetFrame) + } else { + readTimer.Failure(ctx, 0) + + if !os.IsNotExist(readErr) { + recordCacheReadError(ctx, cacheTypeFramedFile, cacheOpGetFrame, readErr) + } + + // Cache miss: fetch compressed data from inner + _, err = c.inner.GetFrame(ctx, offsetU, frameTable, false, compressedBuf, readSize, nil) + if err != nil { + return Range{}, fmt.Errorf("cache GetFrame: inner fetch for offset %#x: %w", offsetU, err) + } - readTimer := cacheSlabReadTimerFactory.Begin(attribute.String(nfsCacheOperationAttr, nfsCacheOperationAttrReadAt)) - count, err := c.readAtFromCache(ctx, chunkPath, buff) - if ignoreEOF(err) == nil { - recordCacheRead(ctx, true, int64(count), cacheTypeSeekable, cacheOpReadAt) - readTimer.Success(ctx, int64(count)) + n = int(frameSize.C) + recordCacheRead(ctx, false, int64(n), cacheTypeFramedFile, cacheOpGetFrame) - return count, err // return `err` in case it's io.EOF + // Async write-back + dataCopy := make([]byte, n) + copy(dataCopy, compressedBuf[:n]) + + c.goCtx(ctx, func(ctx context.Context) { + if err := c.writeFrameToCache(ctx, framePath, dataCopy); err != nil { + recordCacheWriteError(ctx, cacheTypeFramedFile, cacheOpGetFrame, err) + } + }) } - readTimer.Failure(ctx, int64(count)) - if !os.IsNotExist(err) { - recordCacheReadError(ctx, cacheTypeSeekable, cacheOpReadAt, err) + if !decompress { + copy(buf, compressedBuf[:n]) + if onRead != nil { + onRead(int64(n)) + } + + return Range{Start: frameStart.C, Length: n}, nil + } + + // Decompress: stream compressed data through a pooled decoder into buf + decompN, err := decompressInto(frameTable.CompressionType, compressedBuf[:n], buf, readSize, onRead) + if err != nil { + return Range{}, fmt.Errorf("cache GetFrame: decompress for offset %#x: %w", offsetU, err) } - 
logger.L().Debug(ctx, "failed to read cached chunk, falling back to remote read", + return Range{Start: frameStart.C, Length: decompN}, nil +} + +func (c *cachedFramedFile) getFrameUncompressed(ctx context.Context, offsetU int64, buf []byte, readSize int64, onRead func(totalWritten int64)) (_ Range, e error) { + ctx, span := c.tracer.Start(ctx, "get_frame at offset", trace.WithAttributes( + attribute.Int64("offset", offsetU), + attribute.Int("buf_len", len(buf)), + attribute.Bool("compressed", false), + )) + defer func() { + recordError(span, e) + span.End() + }() + + chunkPath := c.makeChunkFilename(offsetU) + + readTimer := cacheSlabReadTimerFactory.Begin(attribute.String(nfsCacheOperationAttr, nfsCacheOperationAttrGetFrame)) + n, readErr := readCacheFile(chunkPath, buf) + + if readErr == nil { + // Cache hit + readTimer.Success(ctx, int64(n)) + recordCacheRead(ctx, true, int64(n), cacheTypeFramedFile, cacheOpGetFrame) + + if onRead != nil { + onRead(int64(n)) + } + + return Range{Start: offsetU, Length: n}, nil + } + readTimer.Failure(ctx, 0) + + if !os.IsNotExist(readErr) { + recordCacheReadError(ctx, cacheTypeFramedFile, cacheOpGetFrame, readErr) + } + + logger.L().Debug(ctx, "cache miss for uncompressed chunk, falling back to remote read", zap.String("chunk_path", chunkPath), - zap.Int64("offset", offset), - zap.Error(err)) + zap.Int64("offset", offsetU), + zap.Error(readErr)) - // read remote file - readCount, err := c.inner.ReadAt(ctx, buff, offset) - if ignoreEOF(err) != nil { - return readCount, fmt.Errorf("failed to perform uncached read: %w", err) + // Cache miss: fetch from inner + r, err := c.inner.GetFrame(ctx, offsetU, nil, false, buf, readSize, onRead) + if err != nil { + return Range{}, fmt.Errorf("cache GetFrame uncompressed: inner fetch at %#x: %w", offsetU, err) } - shadowBuff := make([]byte, readCount) - copy(shadowBuff, buff[:readCount]) + recordCacheRead(ctx, false, int64(r.Length), cacheTypeFramedFile, cacheOpGetFrame) - c.goCtx(ctx, 
func(ctx context.Context) { - ctx, span := c.tracer.Start(ctx, "write chunk at offset back to cache") - defer span.End() + // Async write-back + dataCopy := make([]byte, r.Length) + copy(dataCopy, buf[:r.Length]) - if err := c.writeChunkToCache(ctx, offset, chunkPath, shadowBuff); err != nil { - recordError(span, err) - recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpReadAt, err) + c.goCtx(ctx, func(ctx context.Context) { + if err := c.writeChunkToCache(ctx, offsetU, chunkPath, dataCopy); err != nil { + recordCacheWriteError(ctx, cacheTypeFramedFile, cacheOpGetFrame, err) } }) - recordCacheRead(ctx, false, int64(readCount), cacheTypeSeekable, cacheOpReadAt) - - return readCount, err + return r, nil } -func (c *cachedSeekable) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { - // Try NFS cache file first - chunkPath := c.makeChunkFilename(off) +// decompressInto decompresses src into dst using pooled decoders. +// If onRead is non-nil, calls it progressively in readSize chunks. +func decompressInto(ct CompressionType, src, dst []byte, readSize int64, onRead func(int64)) (int, error) { + r := bytes.NewReader(src) - fp, err := os.Open(chunkPath) - if err == nil { - recordCacheRead(ctx, true, length, cacheTypeSeekable, cacheOpOpenRangeReader) + switch ct { + case CompressionZstd: + dec, err := getZstdDecoder(r) + if err != nil { + return 0, fmt.Errorf("zstd decoder: %w", err) + } + defer putZstdDecoder(dec) + + return readIntoWithCallback(dec, dst, readSize, onRead) + + case CompressionLZ4: + rd := getLZ4Reader(r) + defer putLZ4Reader(rd) + + return readIntoWithCallback(rd, dst, readSize, onRead) - return &fsRangeReadCloser{ - Reader: io.NewSectionReader(fp, 0, length), - file: fp, - }, nil + default: + return 0, fmt.Errorf("unsupported compression type: %s", ct) } +} + +// readIntoWithCallback reads from src into dst. If onRead is non-nil, +// delivers data in readSize-aligned chunks with progressive callbacks. 
+func readIntoWithCallback(src io.Reader, dst []byte, readSize int64, onRead func(int64)) (int, error) { + if onRead == nil { + n, err := io.ReadFull(src, dst) + if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) { + return n, err + } - if !os.IsNotExist(err) { - recordCacheReadError(ctx, cacheTypeSeekable, cacheOpOpenRangeReader, err) + return n, nil } - // Cache miss: delegate to the inner backend (Seekable embeds StreamingReader). - inner, err := c.inner.OpenRangeReader(ctx, off, length) - if err != nil { - return nil, fmt.Errorf("failed to open inner range reader: %w", err) + if readSize <= 0 { + readSize = MemoryChunkSize } - recordCacheRead(ctx, false, length, cacheTypeSeekable, cacheOpOpenRangeReader) + var total int64 + totalSize := int64(len(dst)) - // Wrap in a write-through reader that caches data on Close - return &cacheWriteThroughReader{ - inner: inner, - buf: bytes.NewBuffer(make([]byte, 0, length)), - cache: c, - ctx: ctx, - off: off, - chunkPath: chunkPath, - }, nil -} + for total < totalSize { + end := min(total+readSize, totalSize) + n, err := io.ReadFull(src, dst[total:end]) + total += int64(n) + + if n > 0 { + onRead(total) + } + + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + break + } + + if err != nil { + return int(total), fmt.Errorf("progressive decompress error after %d bytes: %w", total, err) + } + } -// cacheWriteThroughReader wraps an inner reader, buffering all data read through it. -// On Close, it asynchronously writes the buffered data to the NFS cache. -type cacheWriteThroughReader struct { - inner io.ReadCloser - buf *bytes.Buffer - cache *cachedSeekable - ctx context.Context //nolint:containedctx // needed for async cache write-back in Close - off int64 - chunkPath string + return int(total), nil } -func (r *cacheWriteThroughReader) Read(p []byte) (int, error) { - n, err := r.inner.Read(p) - if n > 0 { - r.buf.Write(p[:n]) +// readCacheFile reads a cache file into buf. 
Returns bytes read and error. +func readCacheFile(path string, buf []byte) (int, error) { + f, err := os.Open(path) + if err != nil { + return 0, err } + defer f.Close() - return n, err + n, err := io.ReadFull(f, buf) + if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) { + return n, err + } + + return n, nil } -func (r *cacheWriteThroughReader) Close() error { - closeErr := r.inner.Close() +// writeFrameToCache writes compressed frame data to the NFS cache. +func (c *cachedFramedFile) writeFrameToCache(ctx context.Context, framePath string, data []byte) error { + writeTimer := cacheSlabWriteTimerFactory.Begin() - if r.buf.Len() > 0 { - data := make([]byte, r.buf.Len()) - copy(data, r.buf.Bytes()) + dir := filepath.Dir(framePath) + if err := os.MkdirAll(dir, cacheDirPermissions); err != nil { + writeTimer.Failure(ctx, 0) - r.cache.goCtx(r.ctx, func(ctx context.Context) { - ctx, span := r.cache.tracer.Start(ctx, "write range reader chunk back to cache") - defer span.End() + return fmt.Errorf("failed to create frame cache dir: %w", err) + } - if err := r.cache.writeChunkToCache(ctx, r.off, r.chunkPath, data); err != nil { - recordError(span, err) - recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpOpenRangeReader, err) - } - }) + if err := os.WriteFile(framePath, data, cacheFilePermissions); err != nil { + writeTimer.Failure(ctx, int64(len(data))) + + return fmt.Errorf("failed to write frame to cache: %w", err) } - return closeErr + writeTimer.Success(ctx, int64(len(data))) + + return nil } -func (c *cachedSeekable) Size(ctx context.Context) (n int64, e error) { +func (c *cachedFramedFile) Size(ctx context.Context) (size int64, e error) { ctx, span := c.tracer.Start(ctx, "get size of object") defer func() { recordError(span, e) @@ -219,38 +335,43 @@ func (c *cachedSeekable) Size(ctx context.Context) (n int64, e error) { readTimer := cacheSlabReadTimerFactory.Begin(attribute.String(nfsCacheOperationAttr, nfsCacheOperationAttrSize)) - 
size, err := c.readLocalSize(ctx) + u, err := c.readLocalSize(ctx) if err == nil { - recordCacheRead(ctx, true, 0, cacheTypeSeekable, cacheOpSize) + recordCacheRead(ctx, true, 0, cacheTypeFramedFile, cacheOpSize) readTimer.Success(ctx, 0) - return size, nil + return u, nil } readTimer.Failure(ctx, 0) - recordCacheReadError(ctx, cacheTypeSeekable, cacheOpSize, err) + recordCacheReadError(ctx, cacheTypeFramedFile, cacheOpSize, err) - size, err = c.inner.Size(ctx) + u, err = c.inner.Size(ctx) if err != nil { - return size, err + return 0, err } + finalU := u c.goCtx(ctx, func(ctx context.Context) { ctx, span := c.tracer.Start(ctx, "write size of object to cache") defer span.End() - if err := c.writeLocalSize(ctx, size); err != nil { + if err := c.writeLocalSize(ctx, finalU); err != nil { recordError(span, err) - recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpSize, err) + recordCacheWriteError(ctx, cacheTypeFramedFile, cacheOpSize, err) } }) - recordCacheRead(ctx, false, 0, cacheTypeSeekable, cacheOpSize) + recordCacheRead(ctx, false, 0, cacheTypeFramedFile, cacheOpSize) - return size, nil + return u, nil } -func (c *cachedSeekable) StoreFile(ctx context.Context, path string) (e error) { +func (c *cachedFramedFile) StoreFile(ctx context.Context, path string, opts *FramedUploadOptions) (_ *FrameTable, e error) { + if opts != nil && opts.CompressionType != CompressionNone { + return c.storeFileCompressed(ctx, path, opts) + } + ctx, span := c.tracer.Start(ctx, "write object from file system", trace.WithAttributes(attribute.String("path", path)), ) @@ -259,9 +380,6 @@ func (c *cachedSeekable) StoreFile(ctx context.Context, path string) (e error) { span.End() }() - // write the file to the disk and the remote system at the same time. 
- // this opens the file twice, but the API makes it difficult to use a MultiWriter - if c.flags.BoolFlag(ctx, featureflags.EnableWriteThroughCacheFlag) { c.goCtx(ctx, func(ctx context.Context) { ctx, span := c.tracer.Start(ctx, "write cache object from file system", @@ -271,112 +389,146 @@ func (c *cachedSeekable) StoreFile(ctx context.Context, path string) (e error) { size, err := c.createCacheBlocksFromFile(ctx, path) if err != nil { recordError(span, err) - recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpWriteFromFileSystem, fmt.Errorf("failed to create cache blocks: %w", err)) + recordCacheWriteError(ctx, cacheTypeFramedFile, cacheOpStoreFile, fmt.Errorf("failed to create cache blocks: %w", err)) return } - recordCacheWrite(ctx, size, cacheTypeSeekable, cacheOpWriteFromFileSystem) + recordCacheWrite(ctx, size, cacheTypeFramedFile, cacheOpStoreFile) if err := c.writeLocalSize(ctx, size); err != nil { recordError(span, err) - recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpWriteFromFileSystem, fmt.Errorf("failed to write local file size: %w", err)) + recordCacheWriteError(ctx, cacheTypeFramedFile, cacheOpStoreFile, fmt.Errorf("failed to write local file size: %w", err)) } }) } - return c.inner.StoreFile(ctx, path) + return c.inner.StoreFile(ctx, path, nil) +} + +// storeFileCompressed wraps the inner StoreFile with an OnFrameReady callback +// that writes each compressed frame to the NFS cache. 
+func (c *cachedFramedFile) storeFileCompressed(ctx context.Context, localPath string, opts *FramedUploadOptions) (*FrameTable, error) { + // Copy opts so we don't mutate the caller's value + modifiedOpts := *opts + modifiedOpts.OnFrameReady = func(offset FrameOffset, size FrameSize, data []byte) error { + framePath := makeFrameFilename(c.path, offset, size) + + dir := filepath.Dir(framePath) + if err := os.MkdirAll(dir, cacheDirPermissions); err != nil { + logger.L().Warn(ctx, "failed to create cache directory for compressed frame", + zap.String("dir", dir), + zap.Error(err)) + + return nil // non-fatal: cache write failures should not block uploads + } + + if err := os.WriteFile(framePath, data, cacheFilePermissions); err != nil { + logger.L().Warn(ctx, "failed to write compressed frame to cache", + zap.String("path", framePath), + zap.Error(err)) + + return nil // non-fatal + } + + return nil + } + + // Chain the original callback if present + if opts.OnFrameReady != nil { + origCallback := opts.OnFrameReady + wrappedCallback := modifiedOpts.OnFrameReady + modifiedOpts.OnFrameReady = func(offset FrameOffset, size FrameSize, data []byte) error { + if err := origCallback(offset, size, data); err != nil { + return err + } + + return wrappedCallback(offset, size, data) + } + } + + return c.inner.StoreFile(ctx, localPath, &modifiedOpts) +} + +// makeFrameFilename returns the NFS cache path for a compressed frame. 
+// Format: {cacheBasePath}/{016xC}-{xC}.frm +func makeFrameFilename(cacheBasePath string, offset FrameOffset, size FrameSize) string { + return fmt.Sprintf("%s/%016x-%x.frm", cacheBasePath, offset.C, size.C) } -func (c *cachedSeekable) goCtx(ctx context.Context, fn func(context.Context)) { +func (c *cachedFramedFile) goCtx(ctx context.Context, fn func(context.Context)) { c.wg.Go(func() { fn(context.WithoutCancel(ctx)) }) } -func (c *cachedSeekable) makeChunkFilename(offset int64) string { +func (c *cachedFramedFile) makeChunkFilename(offset int64) string { return fmt.Sprintf("%s/%012d-%d.bin", c.path, offset/c.chunkSize, c.chunkSize) } -func (c *cachedSeekable) makeTempChunkFilename(offset int64) string { +func (c *cachedFramedFile) makeTempChunkFilename(offset int64) string { tempFilename := uuid.NewString() return fmt.Sprintf("%s/.temp.%012d-%d.bin.%s", c.path, offset/c.chunkSize, c.chunkSize, tempFilename) } -func (c *cachedSeekable) readAtFromCache(ctx context.Context, chunkPath string, buff []byte) (n int, e error) { - ctx, span := c.tracer.Start(ctx, "read chunk at offset from cache") - defer func() { - recordError(span, e) - span.End() - }() - - fp, err := os.Open(chunkPath) - if err != nil { - return 0, fmt.Errorf("failed to open file: %w", err) - } - - defer utils.Cleanup(ctx, "failed to close chunk", fp.Close) - - count, err := fp.Read(buff) - if ignoreEOF(err) != nil { - return 0, fmt.Errorf("failed to read from chunk: %w", err) - } - - return count, err // return `err` in case it's io.EOF -} - -func (c *cachedSeekable) sizeFilename() string { +func (c *cachedFramedFile) sizeFilename() string { return filepath.Join(c.path, "size.txt") } -func (c *cachedSeekable) readLocalSize(context.Context) (int64, error) { +func (c *cachedFramedFile) readLocalSize(context.Context) (int64, error) { filename := c.sizeFilename() - content, err := os.ReadFile(filename) - if err != nil { - return 0, fmt.Errorf("failed to read cached size: %w", err) + content, readErr := 
os.ReadFile(filename) + if readErr != nil { + return 0, fmt.Errorf("failed to read cached size: %w", readErr) } - size, err := strconv.ParseInt(string(content), 10, 64) - if err != nil { - return 0, fmt.Errorf("failed to parse cached size: %w", err) + parts := strings.Fields(string(content)) + if len(parts) == 0 { + return 0, fmt.Errorf("empty cached size file") + } + + u, parseErr := strconv.ParseInt(parts[0], 10, 64) + if parseErr != nil { + return 0, fmt.Errorf("failed to parse cached uncompressed size: %w", parseErr) } - return size, nil + return u, nil } -func (c *cachedSeekable) validateReadAtParams(buffSize, offset int64) error { - if buffSize == 0 { +func (c *cachedFramedFile) validateGetFrameParams(off int64, length int, frameTable *FrameTable, _ bool) error { + if length == 0 { return ErrBufferTooSmall } - if buffSize > c.chunkSize { - return ErrBufferTooLarge + + // Compressed reads: the frame table handles alignment, no chunk checks needed. + if IsCompressed(frameTable) { + return nil } - if offset%c.chunkSize != 0 { - return ErrOffsetUnaligned + + // Uncompressed reads: enforce chunk alignment and bounds. 
+ if off%c.chunkSize != 0 { + return fmt.Errorf("offset %#x is not aligned to chunk size %#x: %w", off, c.chunkSize, ErrOffsetUnaligned) } - if (offset%c.chunkSize)+buffSize > c.chunkSize { - return ErrMultipleChunks + + if int64(length) > c.chunkSize { + return fmt.Errorf("buffer length %d exceeds chunk size %d: %w", length, c.chunkSize, ErrBufferTooLarge) } return nil } -func (c *cachedSeekable) writeChunkToCache(ctx context.Context, offset int64, chunkPath string, bytes []byte) error { +func (c *cachedFramedFile) writeChunkToCache(ctx context.Context, offset int64, chunkPath string, bytes []byte) error { writeTimer := cacheSlabWriteTimerFactory.Begin() - // Try to acquire lock for this chunk write to NFS cache lockFile, err := lock.TryAcquireLock(ctx, chunkPath) if err != nil { - // failed to acquire lock, which is a different category of failure than "write failed" - recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpReadAt, err) - + recordCacheWriteError(ctx, cacheTypeFramedFile, cacheOpGetFrame, err) writeTimer.Failure(ctx, 0) return nil } - // Release lock after write completes defer func() { err := lock.ReleaseLock(ctx, lockFile) if err != nil { @@ -408,16 +560,14 @@ func (c *cachedSeekable) writeChunkToCache(ctx context.Context, offset int64, ch return nil } -func (c *cachedSeekable) writeLocalSize(ctx context.Context, size int64) error { +func (c *cachedFramedFile) writeLocalSize(ctx context.Context, size int64) error { finalFilename := c.sizeFilename() - // Try to acquire lock for this chunk write to NFS cache lockFile, err := lock.TryAcquireLock(ctx, finalFilename) if err != nil { return fmt.Errorf("failed to acquire lock for local size: %w", err) } - // Release lock after write completes defer func() { err := lock.ReleaseLock(ctx, lockFile) if err != nil { @@ -443,7 +593,7 @@ func (c *cachedSeekable) writeLocalSize(ctx context.Context, size int64) error { return nil } -func (c *cachedSeekable) createCacheBlocksFromFile(ctx context.Context, inputPath 
string) (count int64, err error) { +func (c *cachedFramedFile) createCacheBlocksFromFile(ctx context.Context, inputPath string) (count int64, err error) { ctx, span := c.tracer.Start(ctx, "create cache blocks from filesystem") defer func() { recordError(span, err) @@ -486,10 +636,7 @@ func (c *cachedSeekable) createCacheBlocksFromFile(ctx context.Context, inputPat return totalSize, err } -// writeChunkFromFile writes a piece of a local file. It does not need to worry about race conditions, as it will only -// be called in the build layer, which cannot be built on multiple machines at the same time, or multiple times on the -// same machine.. -func (c *cachedSeekable) writeChunkFromFile(ctx context.Context, offset int64, input *os.File) (err error) { +func (c *cachedFramedFile) writeChunkFromFile(ctx context.Context, offset int64, input *os.File) (err error) { _, span := c.tracer.Start(ctx, "write chunk from file at offset", trace.WithAttributes( attribute.Int64("offset", offset), )) diff --git a/packages/shared/pkg/storage/storage_cache_seekable_test.go b/packages/shared/pkg/storage/storage_cache_seekable_test.go index b9179f3127..24de463855 100644 --- a/packages/shared/pkg/storage/storage_cache_seekable_test.go +++ b/packages/shared/pkg/storage/storage_cache_seekable_test.go @@ -2,7 +2,6 @@ package storage import ( "context" - "errors" "io" "os" "path/filepath" @@ -11,19 +10,17 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - - storagemocks "github.com/e2b-dev/infra/packages/shared/pkg/storage/mocks" ) -func TestCachedFileObjectProvider_MakeChunkFilename(t *testing.T) { +func TestCachedFramedFile_MakeChunkFilename(t *testing.T) { t.Parallel() - c := cachedSeekable{path: "/a/b/c", chunkSize: 1024, tracer: noopTracer} + c := cachedFramedFile{path: "/a/b/c", chunkSize: 1024, tracer: noopTracer} filename := c.makeChunkFilename(1024 * 4) assert.Equal(t, "/a/b/c/000000000004-1024.bin", filename) 
} -func TestCachedFileObjectProvider_Size(t *testing.T) { +func TestCachedFramedFile_Size(t *testing.T) { t.Parallel() t.Run("can be cached successfully", func(t *testing.T) { @@ -31,10 +28,10 @@ func TestCachedFileObjectProvider_Size(t *testing.T) { const expectedSize int64 = 1024 - inner := storagemocks.NewMockSeekable(t) + inner := NewMockFramedFile(t) inner.EXPECT().Size(mock.Anything).Return(expectedSize, nil) - c := cachedSeekable{path: t.TempDir(), inner: inner, tracer: noopTracer} + c := cachedFramedFile{path: t.TempDir(), inner: inner, tracer: noopTracer} // first call will write to cache size, err := c.Size(t.Context()) @@ -53,7 +50,7 @@ func TestCachedFileObjectProvider_Size(t *testing.T) { }) } -func TestCachedFileObjectProvider_WriteFromFileSystem(t *testing.T) { +func TestCachedFramedFile_WriteFromFileSystem(t *testing.T) { t.Parallel() t.Run("can be cached successfully", func(t *testing.T) { @@ -70,19 +67,19 @@ func TestCachedFileObjectProvider_WriteFromFileSystem(t *testing.T) { err = os.WriteFile(tempFilename, data, 0o644) require.NoError(t, err) - inner := storagemocks.NewMockSeekable(t) + inner := NewMockFramedFile(t) inner.EXPECT(). - StoreFile(mock.Anything, mock.Anything). - Return(nil) + StoreFile(mock.Anything, mock.Anything, mock.Anything). 
+ Return(nil, nil) - featureFlags := storagemocks.NewMockFeatureFlagsClient(t) + featureFlags := NewMockFeatureFlagsClient(t) featureFlags.EXPECT().BoolFlag(mock.Anything, mock.Anything).Return(true) featureFlags.EXPECT().IntFlag(mock.Anything, mock.Anything).Return(10) - c := cachedSeekable{path: cacheDir, inner: inner, chunkSize: 1024, flags: featureFlags, tracer: noopTracer} + c := cachedFramedFile{path: cacheDir, inner: inner, chunkSize: 1024, flags: featureFlags, tracer: noopTracer} // write temp file - err = c.StoreFile(t.Context(), tempFilename) + _, err = c.StoreFile(t.Context(), tempFilename, nil) require.NoError(t, err) // file is written asynchronously, wait for it to finish @@ -94,26 +91,18 @@ func TestCachedFileObjectProvider_WriteFromFileSystem(t *testing.T) { size, err := c.Size(t.Context()) require.NoError(t, err) assert.Equal(t, int64(len(data)), size) - - // verify that the size has been cached - buff := make([]byte, len(data)) - bytesRead, err := c.ReadAt(t.Context(), buff, 0) - require.NoError(t, err) - assert.Equal(t, data, buff) - assert.Equal(t, len(data), bytesRead) }) } -func TestCachedFileObjectProvider_WriteTo(t *testing.T) { +func TestCachedFramedFile_GetFrame_Uncompressed(t *testing.T) { t.Parallel() - t.Run("read from cache when the file exists", func(t *testing.T) { + t.Run("cache hit from chunk file", func(t *testing.T) { t.Parallel() tempDir := t.TempDir() - tempPath := filepath.Join(tempDir, "a", "b", "c") - c := cachedSeekable{path: tempPath, chunkSize: 3, tracer: noopTracer} + c := cachedFramedFile{path: tempPath, chunkSize: 3, tracer: noopTracer} // create cache file cacheFilename := c.makeChunkFilename(0) @@ -124,62 +113,94 @@ func TestCachedFileObjectProvider_WriteTo(t *testing.T) { require.NoError(t, err) buffer := make([]byte, 3) - read, err := c.ReadAt(t.Context(), buffer, 0) + r, err := c.GetFrame(t.Context(), 0, nil, false, buffer, 0, nil) require.NoError(t, err) assert.Equal(t, []byte{1, 2, 3}, buffer) - assert.Equal(t, 
3, read) + assert.Equal(t, 3, r.Length) }) - t.Run("consecutive ReadAt calls should cache", func(t *testing.T) { + t.Run("cache miss then write-back", func(t *testing.T) { t.Parallel() fakeData := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - fakeStorageObjectProvider := storagemocks.NewMockSeekable(t) - - fakeStorageObjectProvider.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - RunAndReturn(func(_ context.Context, buff []byte, off int64) (int, error) { - start := off - end := off + int64(len(buff)) - end = min(end, int64(len(fakeData))) - copy(buff, fakeData[start:end]) - - return int(end - start), nil + inner := NewMockFramedFile(t) + inner.EXPECT(). + GetFrame(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + RunAndReturn(func(_ context.Context, offsetU int64, _ *FrameTable, _ bool, buf []byte, _ int64, onRead func(int64)) (Range, error) { + end := min(int(offsetU)+len(buf), len(fakeData)) + n := copy(buf, fakeData[offsetU:end]) + if onRead != nil { + onRead(int64(n)) + } + + return Range{Start: offsetU, Length: n}, nil }) tempDir := t.TempDir() - c := cachedSeekable{ + c := cachedFramedFile{ path: tempDir, chunkSize: 3, - inner: fakeStorageObjectProvider, + inner: inner, tracer: noopTracer, } // first read goes to source buffer := make([]byte, 3) - read, err := c.ReadAt(t.Context(), buffer, 3) + r, err := c.GetFrame(t.Context(), 3, nil, false, buffer, 0, nil) require.NoError(t, err) - assert.Equal(t, []byte{4, 5, 6}, buffer) - assert.Equal(t, 3, read) + assert.Equal(t, []byte{4, 5, 6}, buffer[:r.Length]) - // we write asynchronously, so let's wait until we're done + // wait for write-back c.wg.Wait() - // second read pulls from cache - c.inner = nil // prevent remote reads, force cache read + // second read from cache + c.inner = nil buffer = make([]byte, 3) - read, err = c.ReadAt(t.Context(), buffer, 3) + r, err = c.GetFrame(t.Context(), 3, nil, false, buffer, 0, nil) require.NoError(t, 
err) - assert.Equal(t, []byte{4, 5, 6}, buffer) - assert.Equal(t, 3, read) + assert.Equal(t, []byte{4, 5, 6}, buffer[:r.Length]) + }) +} + +func TestCachedFramedFile_GetFrame_Uncompressed_Validation(t *testing.T) { + t.Parallel() + + c := cachedFramedFile{path: "/tmp/test", chunkSize: 1024, tracer: noopTracer} + + t.Run("rejects empty buffer", func(t *testing.T) { + t.Parallel() + + buf := make([]byte, 0) + _, err := c.GetFrame(t.Context(), 0, nil, false, buf, 0, nil) + assert.ErrorIs(t, err, ErrBufferTooSmall) + }) + + t.Run("rejects unaligned offset", func(t *testing.T) { + t.Parallel() + + buf := make([]byte, 512) + _, err := c.GetFrame(t.Context(), 100, nil, false, buf, 0, nil) + assert.ErrorIs(t, err, ErrOffsetUnaligned) + }) + + t.Run("rejects oversized buffer", func(t *testing.T) { + t.Parallel() + + buf := make([]byte, 2048) + _, err := c.GetFrame(t.Context(), 0, nil, false, buf, 0, nil) + assert.ErrorIs(t, err, ErrBufferTooLarge) }) +} + +func TestCachedFramedFile_WriteTo(t *testing.T) { + t.Parallel() t.Run("WriteTo calls should read from cache", func(t *testing.T) { t.Parallel() fakeData := []byte{1, 2, 3} - fakeStorageObjectProvider := storagemocks.NewMockBlob(t) + fakeStorageObjectProvider := NewMockBlob(t) fakeStorageObjectProvider.EXPECT(). WriteTo(mock.Anything, mock.Anything). 
RunAndReturn(func(_ context.Context, dst io.Writer) (int64, error) { @@ -211,76 +232,3 @@ func TestCachedFileObjectProvider_WriteTo(t *testing.T) { assert.Equal(t, fakeData, data) }) } - -func TestCachedFileObjectProvider_validateReadAtParams(t *testing.T) { - t.Parallel() - - testcases := map[string]struct { - chunkSize, bufferSize, offset int64 - expected error - }{ - "buffer is empty": { - chunkSize: 1, - bufferSize: 0, - offset: 0, - expected: ErrBufferTooSmall, - }, - "buffer is smaller than chunk size": { - chunkSize: 10, - bufferSize: 5, - offset: 0, - }, - "offset is unaligned": { - chunkSize: 10, - bufferSize: 10, - offset: 3, - expected: ErrOffsetUnaligned, - }, - "buffer is too large (unaligned)": { - chunkSize: 10, - bufferSize: 11, - expected: ErrBufferTooLarge, - }, - "buffer is too large (aligned)": { - chunkSize: 10, - bufferSize: 20, - expected: ErrBufferTooLarge, - }, - } - - for name, tc := range testcases { - t.Run(name, func(t *testing.T) { - t.Parallel() - - c := cachedSeekable{ - chunkSize: tc.chunkSize, - tracer: noopTracer, - } - err := c.validateReadAtParams(tc.bufferSize, tc.offset) - if tc.expected == nil { - require.NoError(t, err) - } else { - require.ErrorIs(t, err, tc.expected) - } - }) - } -} - -func TestCachedSeekableObjectProvider_ReadAt(t *testing.T) { - t.Parallel() - - t.Run("failed but returns count on short read", func(t *testing.T) { - t.Parallel() - - c := cachedSeekable{chunkSize: 10, tracer: noopTracer} - errTarget := errors.New("find me") - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT().ReadAt(mock.Anything, mock.Anything, mock.Anything).Return(5, errTarget) - c.inner = mockSeeker - - buff := make([]byte, 10) - count, err := c.ReadAt(t.Context(), buff, 0) - require.ErrorIs(t, err, errTarget) - assert.Equal(t, 5, count) - }) -} diff --git a/packages/shared/pkg/storage/storage_fs.go b/packages/shared/pkg/storage/storage_fs.go index c02ef84948..249ad5498c 100644 --- 
a/packages/shared/pkg/storage/storage_fs.go +++ b/packages/shared/pkg/storage/storage_fs.go @@ -7,6 +7,8 @@ import ( "io" "os" "path/filepath" + "strconv" + "strings" "time" ) @@ -22,9 +24,8 @@ type fsObject struct { } var ( - _ Seekable = (*fsObject)(nil) - _ Blob = (*fsObject)(nil) - _ StreamingReader = (*fsObject)(nil) + _ FramedFile = (*fsObject)(nil) + _ Blob = (*fsObject)(nil) ) type fsRangeReadCloser struct { @@ -58,7 +59,7 @@ func (s *fsStorage) UploadSignedURL(_ context.Context, _ string, _ time.Duration return "", fmt.Errorf("file system storage does not support signed URLs") } -func (s *fsStorage) OpenSeekable(_ context.Context, path string, _ SeekableObjectType) (Seekable, error) { +func (s *fsStorage) OpenFramedFile(_ context.Context, path string) (FramedFile, error) { dir := filepath.Dir(s.getPath(path)) if err := os.MkdirAll(dir, 0o755); err != nil { return nil, err @@ -69,7 +70,7 @@ func (s *fsStorage) OpenSeekable(_ context.Context, path string, _ SeekableObjec }, nil } -func (s *fsStorage) OpenBlob(_ context.Context, path string, _ ObjectType) (Blob, error) { +func (s *fsStorage) OpenBlob(_ context.Context, path string) (Blob, error) { dir := filepath.Dir(s.getPath(path)) if err := os.MkdirAll(dir, 0o755); err != nil { return nil, err @@ -107,47 +108,58 @@ func (o *fsObject) Put(_ context.Context, data []byte) error { return err } -func (o *fsObject) StoreFile(_ context.Context, path string) error { +func (o *fsObject) StoreFile(ctx context.Context, path string, opts *FramedUploadOptions) (*FrameTable, error) { + if opts != nil && opts.CompressionType != CompressionNone { + return o.storeFileCompressed(ctx, path, opts) + } + r, err := os.Open(path) if err != nil { - return fmt.Errorf("failed to open file %s: %w", path, err) + return nil, fmt.Errorf("failed to open file %s: %w", path, err) } defer r.Close() handle, err := o.getHandle(false) if err != nil { - return err + return nil, err } defer handle.Close() _, err = io.Copy(handle, r) if err != 
nil { - return err + return nil, err } - return nil + return nil, nil } -func (o *fsObject) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { - f, err := o.getHandle(true) +func (o *fsObject) storeFileCompressed(ctx context.Context, localPath string, opts *FramedUploadOptions) (*FrameTable, error) { + file, err := os.Open(localPath) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to open local file %s: %w", localPath, err) } + defer file.Close() - return &fsRangeReadCloser{ - Reader: io.NewSectionReader(f, off, length), - file: f, - }, nil + uploader := &fsPartUploader{fullPath: o.path} + + ft, err := CompressStream(ctx, file, opts, uploader) + if err != nil { + return nil, fmt.Errorf("failed to compress and upload %s: %w", localPath, err) + } + + return ft, nil } -func (o *fsObject) ReadAt(_ context.Context, buff []byte, off int64) (n int, err error) { - handle, err := o.getHandle(true) +func (o *fsObject) openRangeReader(_ context.Context, off int64, length int) (io.ReadCloser, error) { + f, err := o.getHandle(true) if err != nil { - return 0, err + return nil, err } - defer handle.Close() - return handle.ReadAt(buff, off) + return &fsRangeReadCloser{ + Reader: io.NewSectionReader(f, off, int64(length)), + file: f, + }, nil } func (o *fsObject) Exists(_ context.Context) (bool, error) { @@ -171,6 +183,14 @@ func (o *fsObject) Size(_ context.Context) (int64, error) { return 0, err } + // Check for .uncompressed-size sidecar file + sidecarPath := o.path + ".uncompressed-size" + if sidecarData, sidecarErr := os.ReadFile(sidecarPath); sidecarErr == nil { + if parsed, parseErr := strconv.ParseInt(strings.TrimSpace(string(sidecarData)), 10, 64); parseErr == nil { + return parsed, nil + } + } + return fileInfo.Size(), nil } @@ -201,3 +221,42 @@ func (o *fsObject) getHandle(checkExistence bool) (*os.File, error) { return handle, nil } + +// fsPartUploader implements PartUploader for local filesystem. 
+type fsPartUploader struct { + fullPath string + file *os.File +} + +func (u *fsPartUploader) Start(_ context.Context) error { + if err := os.MkdirAll(filepath.Dir(u.fullPath), 0o755); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + f, err := os.OpenFile(u.fullPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + + u.file = f + + return nil +} + +func (u *fsPartUploader) UploadPart(_ context.Context, _ int, data ...[]byte) error { + for _, d := range data { + if _, err := u.file.Write(d); err != nil { + return fmt.Errorf("failed to write part: %w", err) + } + } + + return nil +} + +func (u *fsPartUploader) Complete(_ context.Context) error { + return u.file.Close() +} + +func (o *fsObject) GetFrame(ctx context.Context, offsetU int64, frameTable *FrameTable, decompress bool, buf []byte, readSize int64, onRead func(totalWritten int64)) (Range, error) { + return getFrame(ctx, o.openRangeReader, "FS:"+o.path, offsetU, frameTable, decompress, buf, readSize, onRead) +} diff --git a/packages/shared/pkg/storage/storage_fs_test.go b/packages/shared/pkg/storage/storage_fs_test.go index a57b982177..732b533248 100644 --- a/packages/shared/pkg/storage/storage_fs_test.go +++ b/packages/shared/pkg/storage/storage_fs_test.go @@ -24,7 +24,7 @@ func TestOpenObject_Write_Exists_WriteTo(t *testing.T) { p := newTempProvider(t) ctx := t.Context() - obj, err := p.OpenBlob(ctx, filepath.Join("sub", "file.txt"), MetadataObjectType) + obj, err := p.OpenBlob(ctx, filepath.Join("sub", "file.txt")) require.NoError(t, err) contents := []byte("hello world") @@ -53,7 +53,7 @@ func TestFSPut(t *testing.T) { const payload = "copy me please" require.NoError(t, os.WriteFile(srcPath, []byte(payload), 0o600)) - obj, err := p.OpenBlob(ctx, "copy/dst.txt", UnknownObjectType) + obj, err := p.OpenBlob(ctx, "copy/dst.txt") require.NoError(t, err) require.NoError(t, obj.Put(t.Context(), []byte(payload))) 
@@ -68,7 +68,7 @@ func TestDelete(t *testing.T) { p := newTempProvider(t) ctx := t.Context() - obj, err := p.OpenBlob(ctx, "to/delete.txt", 0) + obj, err := p.OpenBlob(ctx, "to/delete.txt") require.NoError(t, err) err = obj.Put(t.Context(), []byte("bye")) @@ -98,7 +98,7 @@ func TestDeleteObjectsWithPrefix(t *testing.T) { "data/sub/c.txt", } for _, pth := range paths { - obj, err := p.OpenBlob(ctx, pth, UnknownObjectType) + obj, err := p.OpenBlob(ctx, pth) require.NoError(t, err) err = obj.Put(t.Context(), []byte("x")) require.NoError(t, err) @@ -119,7 +119,7 @@ func TestWriteToNonExistentObject(t *testing.T) { p := newTempProvider(t) ctx := t.Context() - obj, err := p.OpenBlob(ctx, "missing/file.txt", UnknownObjectType) + obj, err := p.OpenBlob(ctx, "missing/file.txt") require.NoError(t, err) _, err = GetBlob(t.Context(), obj) diff --git a/packages/shared/pkg/storage/storage_google.go b/packages/shared/pkg/storage/storage_google.go index 837e036dca..17ab2b6ee7 100644 --- a/packages/shared/pkg/storage/storage_google.go +++ b/packages/shared/pkg/storage/storage_google.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "os" + "strconv" "time" "cloud.google.com/go/storage" @@ -38,12 +39,12 @@ const ( gcloudDefaultUploadConcurrency = 16 gcsOperationAttr = "operation" - gcsOperationAttrReadAt = "ReadAt" gcsOperationAttrWrite = "Write" gcsOperationAttrWriteFromFileSystem = "WriteFromFileSystem" gcsOperationAttrWriteFromFileSystemOneShot = "WriteFromFileSystemOneShot" gcsOperationAttrWriteTo = "WriteTo" gcsOperationAttrSize = "Size" + gcsOperationAttrGetFrame = "GetFrame" ) var ( @@ -79,9 +80,8 @@ type gcpObject struct { } var ( - _ Seekable = (*gcpObject)(nil) - _ Blob = (*gcpObject)(nil) - _ StreamingReader = (*gcpObject)(nil) + _ FramedFile = (*gcpObject)(nil) + _ Blob = (*gcpObject)(nil) ) func NewGCP(ctx context.Context, bucketName string, limiter *limit.Limiter) (StorageProvider, error) { @@ -148,7 +148,7 @@ func (s *gcpStorage) UploadSignedURL(_ context.Context, path 
string, ttl time.Du return url, nil } -func (s *gcpStorage) OpenSeekable(_ context.Context, path string, _ SeekableObjectType) (Seekable, error) { +func (s *gcpStorage) OpenFramedFile(_ context.Context, path string) (FramedFile, error) { handle := s.bucket.Object(path).Retryer( storage.WithMaxAttempts(googleMaxAttempts), storage.WithPolicy(storage.RetryAlways), @@ -170,7 +170,7 @@ func (s *gcpStorage) OpenSeekable(_ context.Context, path string, _ SeekableObje }, nil } -func (s *gcpStorage) OpenBlob(_ context.Context, path string, _ ObjectType) (Blob, error) { +func (s *gcpStorage) OpenBlob(_ context.Context, path string) (Blob, error) { handle := s.bucket.Object(path).Retryer( storage.WithMaxAttempts(googleMaxAttempts), storage.WithPolicy(storage.RetryAlways), @@ -229,13 +229,20 @@ func (o *gcpObject) Size(ctx context.Context) (int64, error) { timer.Success(ctx, 0) + if v, ok := attrs.Metadata["uncompressed-size"]; ok { + parsed, parseErr := strconv.ParseInt(v, 10, 64) + if parseErr == nil { + return parsed, nil + } + } + return attrs.Size, nil } -func (o *gcpObject) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { +func (o *gcpObject) openRangeReader(ctx context.Context, off int64, length int) (io.ReadCloser, error) { ctx, cancel := context.WithTimeout(ctx, googleReadTimeout) - reader, err := o.handle.NewRangeReader(ctx, off, length) + reader, err := o.handle.NewRangeReader(ctx, off, int64(length)) if err != nil { cancel() @@ -259,44 +266,6 @@ func (r *cancelOnCloseReader) Close() error { return r.ReadCloser.Close() } -func (o *gcpObject) ReadAt(ctx context.Context, buff []byte, off int64) (n int, err error) { - timer := googleReadTimerFactory.Begin(attribute.String(gcsOperationAttr, gcsOperationAttrReadAt)) - - ctx, cancel := context.WithTimeout(ctx, googleReadTimeout) - defer cancel() - - // The file should not be gzip compressed - reader, err := o.handle.NewRangeReader(ctx, off, int64(len(buff))) - if err != nil { - 
timer.Failure(ctx, int64(n)) - - return 0, fmt.Errorf("failed to create GCS reader for %q: %w", o.path, err) - } - - defer reader.Close() - - for reader.Remain() > 0 { - nr, err := reader.Read(buff[n:]) - n += nr - - if err == nil { - continue - } - - if errors.Is(err, io.EOF) { - break - } - - timer.Failure(ctx, int64(n)) - - return n, fmt.Errorf("failed to read %q: %w", o.path, err) - } - - timer.Success(ctx, int64(n)) - - return n, nil -} - func (o *gcpObject) Put(ctx context.Context, data []byte) (e error) { timer := googleWriteTimerFactory.Begin(attribute.String(gcsOperationAttr, gcsOperationAttrWrite)) @@ -351,7 +320,11 @@ func (o *gcpObject) WriteTo(ctx context.Context, dst io.Writer) (int64, error) { return n, nil } -func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { +func (o *gcpObject) StoreFile(ctx context.Context, path string, opts *FramedUploadOptions) (_ *FrameTable, e error) { + if opts != nil && opts.CompressionType != CompressionNone { + return o.storeFileCompressed(ctx, path, opts) + } + ctx, span := tracer.Start(ctx, "write to gcp from file system") defer func() { recordError(span, e) @@ -363,7 +336,7 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { fileInfo, err := os.Stat(path) if err != nil { - return fmt.Errorf("failed to get file size: %w", err) + return nil, fmt.Errorf("failed to get file size: %w", err) } // If the file is too small, the overhead of writing in parallel isn't worth the effort. 
@@ -377,19 +350,19 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { if err != nil { timer.Failure(ctx, 0) - return fmt.Errorf("failed to read file: %w", err) + return nil, fmt.Errorf("failed to read file: %w", err) } err = o.Put(ctx, data) if err != nil { timer.Failure(ctx, int64(len(data))) - return fmt.Errorf("failed to write file (%d bytes): %w", len(data), err) + return nil, fmt.Errorf("failed to write file (%d bytes): %w", len(data), err) } timer.Success(ctx, int64(len(data))) - return nil + return nil, nil } timer := googleWriteTimerFactory.Begin( @@ -404,7 +377,7 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { if semaphoreErr != nil { timer.Failure(ctx, 0) - return fmt.Errorf("failed to acquire semaphore: %w", semaphoreErr) + return nil, fmt.Errorf("failed to acquire semaphore: %w", semaphoreErr) } defer uploadLimiter.Release(1) } @@ -421,7 +394,7 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { if err != nil { timer.Failure(ctx, 0) - return fmt.Errorf("failed to create multipart uploader: %w", err) + return nil, fmt.Errorf("failed to create multipart uploader: %w", err) } start := time.Now() @@ -429,7 +402,7 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { if err != nil { timer.Failure(ctx, count) - return fmt.Errorf("failed to upload file in parallel: %w", err) + return nil, fmt.Errorf("failed to upload file in parallel: %w", err) } logger.L().Debug(ctx, "Uploaded file in parallel", @@ -443,7 +416,32 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { timer.Success(ctx, count) - return nil + return nil, nil +} + +func (o *gcpObject) storeFileCompressed(ctx context.Context, localPath string, opts *FramedUploadOptions) (*FrameTable, error) { + file, err := os.Open(localPath) + if err != nil { + return nil, fmt.Errorf("failed to open local file %s: %w", localPath, err) + } + defer file.Close() + + uploader, err := 
NewMultipartUploaderWithRetryConfig( + ctx, + o.storage.bucket.BucketName(), + o.path, + DefaultRetryConfig(), + ) + if err != nil { + return nil, fmt.Errorf("failed to create multipart uploader: %w", err) + } + + ft, err := CompressStream(ctx, file, opts, uploader) + if err != nil { + return nil, fmt.Errorf("failed to compress and upload %s: %w", localPath, err) + } + + return ft, nil } type gcpServiceToken struct { @@ -464,3 +462,18 @@ func parseServiceAccountBase64(serviceAccount string) (*gcpServiceToken, error) return &sa, nil } + +func (o *gcpObject) GetFrame(ctx context.Context, offsetU int64, frameTable *FrameTable, decompress bool, buf []byte, readSize int64, onRead func(totalWritten int64)) (Range, error) { + timer := googleReadTimerFactory.Begin(attribute.String(gcsOperationAttr, gcsOperationAttrGetFrame)) + + r, err := getFrame(ctx, o.openRangeReader, "GCS:"+o.path, offsetU, frameTable, decompress, buf, readSize, onRead) + if err != nil { + timer.Failure(ctx, int64(r.Length)) + + return r, err + } + + timer.Success(ctx, int64(r.Length)) + + return r, nil +} diff --git a/packages/shared/pkg/storage/template.go b/packages/shared/pkg/storage/template.go index bdd03fff01..47ab615c46 100644 --- a/packages/shared/pkg/storage/template.go +++ b/packages/shared/pkg/storage/template.go @@ -13,6 +13,13 @@ const ( MetadataName = "metadata.json" HeaderSuffix = ".header" + + // v4Prefix is prepended to the base filename for all v4 compressed assets. + v4Prefix = "v4." + + // v4HeaderSuffix is the suffix after the base filename for v4 headers. + // V4 headers are always LZ4-block-compressed. + v4HeaderSuffix = ".header.lz4" ) type TemplateFiles struct { @@ -51,3 +58,49 @@ func (t TemplateFiles) StorageSnapfilePath() string { func (t TemplateFiles) StorageMetadataPath() string { return fmt.Sprintf("%s/%s", t.StorageDir(), MetadataName) } + +// HeaderPath returns the header storage path for a given file name within this build. 
+func (t TemplateFiles) HeaderPath(fileName string) string { + return fmt.Sprintf("%s/%s%s", t.StorageDir(), fileName, HeaderSuffix) +} + +// V4DataName returns the v4 data filename: "v4.memfile.lz4". +func V4DataName(fileName string, ct CompressionType) string { + return v4Prefix + fileName + ct.Suffix() +} + +// V4HeaderName returns the v4 header filename: "v4.memfile.header.lz4". +func V4HeaderName(fileName string) string { + return v4Prefix + fileName + v4HeaderSuffix +} + +// V4DataPath transforms a base object path (e.g. "buildId/memfile") into +// the v4 compressed data path (e.g. "buildId/v4.memfile.lz4"). +func V4DataPath(basePath string, ct CompressionType) string { + dir, file := splitPath(basePath) + + return dir + V4DataName(file, ct) +} + +// splitPath splits "dir/file" into ("dir/", "file"). If there's no slash, +// dir is empty. +func splitPath(p string) (dir, file string) { + for i := len(p) - 1; i >= 0; i-- { + if p[i] == '/' { + return p[:i+1], p[i+1:] + } + } + + return "", p +} + +// CompressedDataPath returns the v4 compressed data path for a given file name. +// Example: "{buildId}/v4.memfile.lz4" +func (t TemplateFiles) CompressedDataPath(fileName string, ct CompressionType) string { + return fmt.Sprintf("%s/%s", t.StorageDir(), V4DataName(fileName, ct)) +} + +// CompressedHeaderPath returns the v4 header path: "{buildId}/v4.{fileName}.header.lz4". 
+func (t TemplateFiles) CompressedHeaderPath(fileName string) string { + return fmt.Sprintf("%s/%s", t.StorageDir(), V4HeaderName(fileName)) +} diff --git a/packages/shared/pkg/telemetry/meters.go b/packages/shared/pkg/telemetry/meters.go index 1f372b9ce8..9726ce6cd5 100644 --- a/packages/shared/pkg/telemetry/meters.go +++ b/packages/shared/pkg/telemetry/meters.go @@ -366,6 +366,12 @@ const ( resultTypeFailure = "failure" ) +var ( + // pre-allocated + Success = attribute.String(resultAttr, resultTypeSuccess) + Failure = attribute.String(resultAttr, resultTypeFailure) +) + func (t Stopwatch) Success(ctx context.Context, total int64, kv ...attribute.KeyValue) { t.end(ctx, resultTypeSuccess, total, kv...) } @@ -379,7 +385,24 @@ func (t Stopwatch) end(ctx context.Context, result string, total int64, kv ...at kv = append(t.kv, kv...) amount := time.Since(t.start).Milliseconds() - t.histogram.Record(ctx, amount, metric.WithAttributes(kv...)) - t.sum.Add(ctx, total, metric.WithAttributes(kv...)) - t.count.Add(ctx, 1, metric.WithAttributes(kv...)) + opt := metric.WithAttributeSet(attribute.NewSet(kv...)) + t.histogram.Record(ctx, amount, opt) + t.sum.Add(ctx, total, opt) + t.count.Add(ctx, 1, opt) +} + +// PrecomputeAttrs builds a reusable MeasurementOption from the given attribute +// key-values. The option must include all attributes (including "result"). +// Use with Stopwatch.Record to avoid per-call attribute allocation. +func PrecomputeAttrs(kv ...attribute.KeyValue) metric.MeasurementOption { + return metric.WithAttributeSet(attribute.NewSet(kv...)) +} + +// FastOK records an operation using a precomputed attribute +// option. Zero-allocation alternative to Success for hot paths. 
+func (t Stopwatch) Record(ctx context.Context, total int64, precomputedAttrs metric.MeasurementOption) { + amount := time.Since(t.start).Milliseconds() + t.histogram.Record(ctx, amount, precomputedAttrs) + t.sum.Add(ctx, total, precomputedAttrs) + t.count.Add(ctx, 1, precomputedAttrs) } diff --git a/tests/integration/Makefile b/tests/integration/Makefile index 00349fcfd4..1f2495378a 100644 --- a/tests/integration/Makefile +++ b/tests/integration/Makefile @@ -40,9 +40,9 @@ test/%: *.go:*) \ BASE=$${TEST_PATH%%:*}; \ TEST_FN=$${TEST_PATH#*:}; \ - go tool gotestsum --rerun-fails=1 --packages="$$BASE" --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=4 -run "$${TEST_FN}" ;; \ - *.go) go tool gotestsum --rerun-fails=1 --packages="$$TEST_PATH" --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=4 ;; \ - *) go tool gotestsum --rerun-fails=1 --packages="$$TEST_PATH/..." --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=4 ;; \ + go tool gotestsum --rerun-fails=1 --packages="$$BASE" --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=2 -run "$${TEST_FN}" ;; \ + *.go) go tool gotestsum --rerun-fails=1 --packages="$$TEST_PATH" --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=2 ;; \ + *) go tool gotestsum --rerun-fails=1 --packages="$$TEST_PATH/..." --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=2 ;; \ esac .PHONY: connect-orchestrator