From 966fb5eee802aa0e32a22047113a4be3eacf90c2 Mon Sep 17 00:00:00 2001 From: Minh Vu Date: Sat, 20 Jun 2026 23:47:54 +0200 Subject: [PATCH] Validate per-host infeed batch scaling --- lingvo/core/batch_utils.py | 8 +++++++- lingvo/core/batch_utils_test.py | 21 +++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/lingvo/core/batch_utils.py b/lingvo/core/batch_utils.py index 1842e3dd8..29646d49a 100644 --- a/lingvo/core/batch_utils.py +++ b/lingvo/core/batch_utils.py @@ -55,7 +55,13 @@ def scale_global_to_infeed(global_batch_size, use_per_host_infeed): if not py_utils.use_tpu(): raise ValueError('Scaling to TPU hosts without TPUs. {}'.format( cluster.num_tpu_hosts)) - return global_batch_size // cluster.num_tpu_hosts + infeed_batch_size, remainder = divmod(global_batch_size, + cluster.num_tpu_hosts) + if remainder: + raise ValueError( + f'global_batch_size {global_batch_size} did not divide evenly by ' + f'{cluster.num_tpu_hosts} TPU hosts.') + return infeed_batch_size else: return global_batch_size diff --git a/lingvo/core/batch_utils_test.py b/lingvo/core/batch_utils_test.py index 3fb594d07..1721ac7a8 100644 --- a/lingvo/core/batch_utils_test.py +++ b/lingvo/core/batch_utils_test.py @@ -52,6 +52,27 @@ def testScaleInfeedToGlobalTPU(self, use_per_host_infeed, num_tpu_hosts): batch_utils.scale_infeed_to_global(1024, use_per_host_infeed), 1024 * num_infeeds) + @parameterized.parameters( + (False, 16, 4, 16), + (True, 16, 4, 4), + (True, 10, 1, 10), + ) + def testScaleGlobalToInfeed(self, use_per_host_infeed, global_batch_size, + num_tpu_hosts, expected_infeed_batch_size): + with flagsaver.flagsaver(xla_device='tpu', enable_asserts=False): + with cluster_factory.ForTestingWorker( + tpus=128, num_tpu_hosts=num_tpu_hosts): + self.assertEqual( + batch_utils.scale_global_to_infeed(global_batch_size, + use_per_host_infeed), + expected_infeed_batch_size) + + def testScaleGlobalToInfeedRejectsNonDivisiblePerHostBatch(self): + with flagsaver.flagsaver(xla_device='tpu', enable_asserts=False): + with cluster_factory.ForTestingWorker(tpus=128, num_tpu_hosts=4): + with self.assertRaisesRegex(ValueError, 'did not divide evenly'): + batch_utils.scale_global_to_infeed(10, use_per_host_infeed=True) + @parameterized.parameters( itertools.product( (False, True), # use_per_host_infeed