diff --git a/pkg/cmd/roachtest/tests/vecindex.go b/pkg/cmd/roachtest/tests/vecindex.go index 3403d2d97eea..50b12a57426b 100644 --- a/pkg/cmd/roachtest/tests/vecindex.go +++ b/pkg/cmd/roachtest/tests/vecindex.go @@ -253,7 +253,6 @@ func runVectorIndex(ctx context.Context, t test.Test, c cluster.Cluster, opts ve t.L().Printf("Loading dataset %s", opts.dataset) loader := vecann.DatasetLoader{ DatasetName: opts.dataset, - ResetCache: true, OnProgress: func(ctx context.Context, format string, args ...any) { t.L().Printf(format, args...) }, diff --git a/pkg/workload/vecann/datasets.go b/pkg/workload/vecann/datasets.go index ae0d2a78462b..2994e598470e 100644 --- a/pkg/workload/vecann/datasets.go +++ b/pkg/workload/vecann/datasets.go @@ -12,6 +12,7 @@ import ( "fmt" "io" "os" + "path/filepath" "strings" "time" @@ -90,8 +91,6 @@ type DatasetLoader struct { // CacheFolder is the path to the temporary folder where datasets will be // cached. It defaults to ~/.cache/workload-datasets. CacheFolder string - // ResetCache indicates that the cache should be re-populated. - ResetCache bool // OnProgress logs the progress of the loading process. OnProgress func(ctx context.Context, format string, args ...any) @@ -130,12 +129,12 @@ func (dl *DatasetLoader) loadFiles(ctx context.Context) error { neighbors := fmt.Sprintf("%s/%s-neighbors-%s.ibin", baseDir, baseName, metric) // Download test and neighbors files if missing. - if dl.ResetCache || !fileExists(test) { + if !fileExists(test) { if err := dl.downloadAndUnzip(ctx, baseName, baseName+"-test.fbin.zip", test); err != nil { return err } } - if dl.ResetCache || !fileExists(neighbors) { + if !fileExists(neighbors) { fileName := baseName + "-neighbors-" + metric + ".ibin.zip" if err := dl.downloadAndUnzip(ctx, baseName, fileName, neighbors); err != nil { return err @@ -181,7 +180,7 @@ func (dl *DatasetLoader) downloadTrainFiles( // First, check for files in the cache. onlyFileName := fmt.Sprintf("%s/%s.fbin", baseDir, baseName) firstPartName := fmt.Sprintf("%s/%s-1.fbin", baseDir, baseName) - if dl.ResetCache || (!fileExists(onlyFileName) && !fileExists(firstPartName)) { + if !fileExists(onlyFileName) && !fileExists(firstPartName) { // No files in cache, download them. partNum := 0 for { @@ -253,23 +252,21 @@ func (dl *DatasetLoader) downloadTrainFiles( } // downloadAndUnzip downloads a zip file from GCP and extracts the contained -// file to destPath. +// file to destPath. All intermediate files use process-unique temp paths so +// that concurrent processes downloading the same dataset do not corrupt each +// other's files. The final extraction is installed via atomic rename, ensuring +// that destPath is either absent or contains a complete file. func (dl *DatasetLoader) downloadAndUnzip( ctx context.Context, baseName, objectFile, destPath string, -) (err error) { +) error { objectName := fmt.Sprintf("%s/%s/%s", bucketDirName, baseName, objectFile) - tempZipFile := destPath + ".zip" - defer func() { - err = errors.CombineErrors(err, os.Remove(tempZipFile)) - }() + destDir := filepath.Dir(destPath) client, err := storage.NewClient(ctx) if err != nil { return errors.Wrapf(err, "creating GCS client") } - defer func() { - err = errors.CombineErrors(err, client.Close()) - }() + defer func() { _ = client.Close() }() bucket := client.Bucket(bucketName) object := bucket.Object(objectName) @@ -281,53 +278,74 @@ func (dl *DatasetLoader) downloadAndUnzip( // Only report progress once we know file exists. dl.OnProgress(ctx, "Downloading %s from %s", objectName, bucketName) - tempZip, err := os.Create(tempZipFile) + // Download to a unique temp file to avoid races with concurrent downloaders + // sharing the same cache directory. + tempZip, err := os.CreateTemp(destDir, ".dl-*.zip") if err != nil { - return errors.Wrapf(err, "creating temp zip file %s", tempZipFile) + return errors.Wrapf(err, "creating temp zip file in %s", destDir) } - defer func() { - err = errors.CombineErrors(err, tempZip.Close()) - }() + tempZipPath := tempZip.Name() + defer func() { _ = os.Remove(tempZipPath) }() reader, err := object.NewReader(ctx) if err != nil { + _ = tempZip.Close() return errors.Wrapf(err, "creating reader for %s/%s", bucketName, objectName) } - defer func() { - err = errors.CombineErrors(err, reader.Close()) - }() writer := makeProgressWriter(tempZip, attrs.Size) writer.OnProgress = dl.OnDownloadProgress - if _, err := io.Copy(&writer, reader); err != nil { - return errors.Wrapf(err, "downloading to file %s", tempZipFile) - } - - // Unzip the file - zipR, err := zip.OpenReader(tempZipFile) + _, copyErr := io.Copy(&writer, reader) + // reader is a GCS read stream; any data errors are already surfaced by + // io.Copy, and Close just releases the HTTP connection. + _ = reader.Close() + // tempZip.Close flushes buffered writes to disk, so a failure here means + // the zip file may be incomplete. + closeErr := tempZip.Close() + if copyErr != nil { + return errors.Wrapf(copyErr, "downloading %s", objectName) + } + if closeErr != nil { + return errors.Wrapf(closeErr, "closing temp zip file %s", tempZipPath) + } + + // Unzip the downloaded file. + zipR, err := zip.OpenReader(tempZipPath) if err != nil { - return errors.Wrapf(err, "opening zip file %s", tempZipFile) + return errors.Wrapf(err, "opening zip file %s", tempZipPath) } - defer func() { - err = errors.CombineErrors(err, zipR.Close()) - }() + defer func() { _ = zipR.Close() }() if len(zipR.File) == 0 { - return errors.Newf("zip file %s is empty", tempZipFile) + return errors.Newf("zip file %s is empty", tempZipPath) } zfile := zipR.File[0] zreader, err := zfile.Open() if err != nil { return errors.Wrapf(err, "opening zipped file %s", zfile.Name) } - defer zreader.Close() - out, err := os.Create(destPath) + defer func() { _ = zreader.Close() }() + + // Extract to a unique temp file, then atomically rename to the destination. + // This ensures the cached file is either absent or complete, never a + // truncated partial write. + tempOut, err := os.CreateTemp(destDir, ".extract-*.tmp") if err != nil { - return errors.Wrapf(err, "creating output file %s", destPath) + return errors.Wrapf(err, "creating temp output file in %s", destDir) } - defer out.Close() - if _, err := io.Copy(out, zreader); err != nil { - return errors.Wrapf(err, "extracting to %s", destPath) + tempOutPath := tempOut.Name() + defer func() { _ = os.Remove(tempOutPath) }() + + if _, err := io.Copy(tempOut, zreader); err != nil { + _ = tempOut.Close() + return errors.Wrapf(err, "extracting %s", zfile.Name) + } + if err := tempOut.Close(); err != nil { + return errors.Wrapf(err, "closing temp output file %s", tempOutPath) + } + + if err := os.Rename(tempOutPath, destPath); err != nil { + return errors.Wrapf(err, "moving extracted file to %s", destPath) } return nil }