diff --git a/lingvo/core/BUILD b/lingvo/core/BUILD index 718edb9f1..a77cc4d03 100644 --- a/lingvo/core/BUILD +++ b/lingvo/core/BUILD @@ -2383,6 +2383,16 @@ py_library( ], ) +py_test( + name = "program_utils_test", + srcs = ["program_utils_test.py"], + deps = [ + ":program_utils", + ":test_utils", + "//lingvo:compat", + ], +) + py_library( name = "program_lib", srcs = ["program.py"], diff --git a/lingvo/core/program_utils.py b/lingvo/core/program_utils.py index 5bc2d47b3..abb320a08 100644 --- a/lingvo/core/program_utils.py +++ b/lingvo/core/program_utils.py @@ -67,23 +67,34 @@ def __init__(self, program_dir): if content: self.ckpt_key = content[0] if len(content) > 1: - self.decoded_datasets = content[1:] + for dataset_name in content[1:]: + self._AddDecodedDataset(dataset_name) + + def _AddDecodedDataset(self, dataset_name): + if dataset_name and dataset_name not in self.decoded_datasets: + self.decoded_datasets.append(dataset_name) + + def _WriteStatusFile(self): + content = '\n'.join([self.ckpt_key] + self.decoded_datasets) + '\n' + tmp_status_file = self.status_file + '.tmp' + with tf.io.gfile.GFile(tmp_status_file, 'w') as f: + f.write(content) + tf.io.gfile.rename(tmp_status_file, self.status_file, overwrite=True) def UpdateCkpt(self, ckpt_key): """Update checkpoint key in the status.""" if ckpt_key != self.ckpt_key: self.ckpt_key = ckpt_key self.decoded_datasets = [] - with tf.io.gfile.GFile(self.status_file, 'w') as f: - f.write(self.ckpt_key) + self._WriteStatusFile() def UpdateDataset(self, dataset_name, summaries): """Update decoded dataset in the status.""" cache_file = os.path.join(self.cache_dir, f'{dataset_name}.csv') with tf.io.gfile.GFile(cache_file, 'w') as f: f.write(SummaryToCsv(summaries)) - with tf.io.gfile.GFile(self.status_file, 'w+') as f: - f.write(f.read().strip() + '\n' + dataset_name) + self._AddDecodedDataset(dataset_name) + self._WriteStatusFile() def TryLoadCache(self, ckpt_key, dataset_name): """Try load summary cache for ckpt_key, dataset_name. @@ -102,8 +113,6 @@ def TryLoadCache(self, ckpt_key, dataset_name): return None with tf.io.gfile.GFile(cache_file, 'r') as f: summaries = CsvToSummary(f.read()) - with tf.io.gfile.GFile(self.status_file, 'w+') as f: - f.write(f.read().strip() + '\n' + dataset_name) return summaries return None diff --git a/lingvo/core/program_utils_test.py b/lingvo/core/program_utils_test.py new file mode 100644 index 000000000..f2072a7f6 --- /dev/null +++ b/lingvo/core/program_utils_test.py @@ -0,0 +1,57 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for program_utils.""" + +import os + +import lingvo.compat as tf +from lingvo.core import program_utils +from lingvo.core import test_utils + + +class DecodeStatusCacheTest(test_utils.TestCase): + + def _Summary(self, tag, value): + return tf.Summary( + value=[tf.Summary.Value(tag=tag, simple_value=value)]) + + def testUpdateDatasetPreservesCheckpointAndDecodedDatasets(self): + program_dir = self.create_tempdir().full_path + cache = program_utils.DecodeStatusCache(program_dir) + + cache.UpdateCkpt('ckpt-123') + cache.UpdateDataset('Dev', {'acc': self._Summary('acc', 0.5)}) + cache.UpdateDataset('Test', {'loss': self._Summary('loss', 1.25)}) + cache.UpdateDataset('Dev', {'acc': self._Summary('acc', 0.75)}) + + reloaded_cache = program_utils.DecodeStatusCache(program_dir) + + self.assertEqual('ckpt-123', reloaded_cache.ckpt_key) + self.assertEqual(['Dev', 'Test'], reloaded_cache.decoded_datasets) + + summaries = reloaded_cache.TryLoadCache('ckpt-123', 'Dev') + self.assertIsNotNone(summaries) + self.assertIn('acc', summaries) + self.assertAlmostEqual(0.75, summaries['acc'].value[0].simple_value) + self.assertEqual(['Dev', 'Test'], reloaded_cache.decoded_datasets) + + status_file = os.path.join(program_dir, 'decoded_datasets.txt') + with tf.io.gfile.GFile(status_file, 'r') as f: + self.assertEqual(['ckpt-123', 'Dev', 'Test'], + [line.strip() for line in f.readlines()]) + + +if __name__ == '__main__': + test_utils.main()