Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pkg/cmd/roachtest/tests/vecindex.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ func runVectorIndex(ctx context.Context, t test.Test, c cluster.Cluster, opts ve
t.L().Printf("Loading dataset %s", opts.dataset)
loader := vecann.DatasetLoader{
DatasetName: opts.dataset,
ResetCache: true,
OnProgress: func(ctx context.Context, format string, args ...any) {
t.L().Printf(format, args...)
},
Expand Down
96 changes: 57 additions & 39 deletions pkg/workload/vecann/datasets.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"time"

Expand Down Expand Up @@ -90,8 +91,6 @@ type DatasetLoader struct {
// CacheFolder is the path to the temporary folder where datasets will be
// cached. It defaults to ~/.cache/workload-datasets.
CacheFolder string
// ResetCache indicates that the cache should be re-populated.
ResetCache bool

// OnProgress logs the progress of the loading process.
OnProgress func(ctx context.Context, format string, args ...any)
Expand Down Expand Up @@ -130,12 +129,12 @@ func (dl *DatasetLoader) loadFiles(ctx context.Context) error {
neighbors := fmt.Sprintf("%s/%s-neighbors-%s.ibin", baseDir, baseName, metric)

// Download test and neighbors files if missing.
if dl.ResetCache || !fileExists(test) {
if !fileExists(test) {
if err := dl.downloadAndUnzip(ctx, baseName, baseName+"-test.fbin.zip", test); err != nil {
return err
}
}
if dl.ResetCache || !fileExists(neighbors) {
if !fileExists(neighbors) {
fileName := baseName + "-neighbors-" + metric + ".ibin.zip"
if err := dl.downloadAndUnzip(ctx, baseName, fileName, neighbors); err != nil {
return err
Expand Down Expand Up @@ -181,7 +180,7 @@ func (dl *DatasetLoader) downloadTrainFiles(
// First, check for files in the cache.
onlyFileName := fmt.Sprintf("%s/%s.fbin", baseDir, baseName)
firstPartName := fmt.Sprintf("%s/%s-1.fbin", baseDir, baseName)
if dl.ResetCache || (!fileExists(onlyFileName) && !fileExists(firstPartName)) {
if !fileExists(onlyFileName) && !fileExists(firstPartName) {
// No files in cache, download them.
partNum := 0
for {
Expand Down Expand Up @@ -253,23 +252,21 @@ func (dl *DatasetLoader) downloadTrainFiles(
}

// downloadAndUnzip downloads a zip file from GCP and extracts the contained
// file to destPath.
// file to destPath. All intermediate files use process-unique temp paths so
// that concurrent processes downloading the same dataset do not corrupt each
// other's files. The final extraction is installed via atomic rename, ensuring
// that destPath is either absent or contains a complete file.
func (dl *DatasetLoader) downloadAndUnzip(
ctx context.Context, baseName, objectFile, destPath string,
) (err error) {
) error {
objectName := fmt.Sprintf("%s/%s/%s", bucketDirName, baseName, objectFile)
tempZipFile := destPath + ".zip"
defer func() {
err = errors.CombineErrors(err, os.Remove(tempZipFile))
}()
destDir := filepath.Dir(destPath)

client, err := storage.NewClient(ctx)
if err != nil {
return errors.Wrapf(err, "creating GCS client")
}
defer func() {
err = errors.CombineErrors(err, client.Close())
}()
defer func() { _ = client.Close() }()

bucket := client.Bucket(bucketName)
object := bucket.Object(objectName)
Expand All @@ -281,53 +278,74 @@ func (dl *DatasetLoader) downloadAndUnzip(
// Only report progress once we know file exists.
dl.OnProgress(ctx, "Downloading %s from %s", objectName, bucketName)

tempZip, err := os.Create(tempZipFile)
// Download to a unique temp file to avoid races with concurrent downloaders
// sharing the same cache directory.
tempZip, err := os.CreateTemp(destDir, ".dl-*.zip")
if err != nil {
return errors.Wrapf(err, "creating temp zip file %s", tempZipFile)
return errors.Wrapf(err, "creating temp zip file in %s", destDir)
}
defer func() {
err = errors.CombineErrors(err, tempZip.Close())
}()
tempZipPath := tempZip.Name()
defer func() { _ = os.Remove(tempZipPath) }()

reader, err := object.NewReader(ctx)
if err != nil {
_ = tempZip.Close()
return errors.Wrapf(err, "creating reader for %s/%s", bucketName, objectName)
}
defer func() {
err = errors.CombineErrors(err, reader.Close())
}()

writer := makeProgressWriter(tempZip, attrs.Size)
writer.OnProgress = dl.OnDownloadProgress
if _, err := io.Copy(&writer, reader); err != nil {
return errors.Wrapf(err, "downloading to file %s", tempZipFile)
}

// Unzip the file
zipR, err := zip.OpenReader(tempZipFile)
_, copyErr := io.Copy(&writer, reader)
// reader is a GCS read stream; any data errors are already surfaced by
// io.Copy, and Close just releases the HTTP connection.
_ = reader.Close()
// tempZip.Close flushes buffered writes to disk, so a failure here means
// the zip file may be incomplete.
closeErr := tempZip.Close()
if copyErr != nil {
return errors.Wrapf(copyErr, "downloading %s", objectName)
}
if closeErr != nil {
return errors.Wrapf(closeErr, "closing temp zip file %s", tempZipPath)
}

// Unzip the downloaded file.
zipR, err := zip.OpenReader(tempZipPath)
if err != nil {
return errors.Wrapf(err, "opening zip file %s", tempZipFile)
return errors.Wrapf(err, "opening zip file %s", tempZipPath)
}
defer func() {
err = errors.CombineErrors(err, zipR.Close())
}()
defer func() { _ = zipR.Close() }()

if len(zipR.File) == 0 {
return errors.Newf("zip file %s is empty", tempZipFile)
return errors.Newf("zip file %s is empty", tempZipPath)
}
zfile := zipR.File[0]
zreader, err := zfile.Open()
if err != nil {
return errors.Wrapf(err, "opening zipped file %s", zfile.Name)
}
defer zreader.Close()
out, err := os.Create(destPath)
defer func() { _ = zreader.Close() }()

// Extract to a unique temp file, then atomically rename to the destination.
// This ensures the cached file is either absent or complete, never a
// truncated partial write.
tempOut, err := os.CreateTemp(destDir, ".extract-*.tmp")
if err != nil {
return errors.Wrapf(err, "creating output file %s", destPath)
return errors.Wrapf(err, "creating temp output file in %s", destDir)
}
defer out.Close()
if _, err := io.Copy(out, zreader); err != nil {
return errors.Wrapf(err, "extracting to %s", destPath)
tempOutPath := tempOut.Name()
defer func() { _ = os.Remove(tempOutPath) }()

if _, err := io.Copy(tempOut, zreader); err != nil {
_ = tempOut.Close()
return errors.Wrapf(err, "extracting %s", zfile.Name)
}
if err := tempOut.Close(); err != nil {
return errors.Wrapf(err, "closing temp output file %s", tempOutPath)
}

if err := os.Rename(tempOutPath, destPath); err != nil {
return errors.Wrapf(err, "moving extracted file to %s", destPath)
}
return nil
}
Expand Down