Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion lingvo/core/batch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,13 @@ def scale_global_to_infeed(global_batch_size, use_per_host_infeed):
if not py_utils.use_tpu():
raise ValueError('Scaling to TPU hosts without TPUs. {}'.format(
cluster.num_tpu_hosts))
return global_batch_size // cluster.num_tpu_hosts
infeed_batch_size, remainder = divmod(global_batch_size,
cluster.num_tpu_hosts)
if remainder:
raise ValueError(
f'global_batch_size {global_batch_size} did not divide evenly by '
f'{cluster.num_tpu_hosts} TPU hosts.')
return infeed_batch_size
else:
return global_batch_size

Expand Down
21 changes: 21 additions & 0 deletions lingvo/core/batch_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,27 @@ def testScaleInfeedToGlobalTPU(self, use_per_host_infeed, num_tpu_hosts):
batch_utils.scale_infeed_to_global(1024, use_per_host_infeed),
1024 * num_infeeds)

@parameterized.parameters(
    # (use_per_host_infeed, global_batch_size, num_tpu_hosts,
    #  expected_infeed_batch_size)
    (False, 16, 4, 16),
    (True, 16, 4, 4),
    (True, 10, 1, 10),
)
def testScaleGlobalToInfeed(self, use_per_host_infeed, global_batch_size,
                            num_tpu_hosts, expected_infeed_batch_size):
  # Runs with the TPU device flag set and a test worker cluster that has the
  # requested number of TPU hosts, then checks the batch size returned by
  # scale_global_to_infeed against the expected value for this case.
  with flagsaver.flagsaver(xla_device='tpu', enable_asserts=False), \
       cluster_factory.ForTestingWorker(
           tpus=128, num_tpu_hosts=num_tpu_hosts):
    actual_infeed_batch_size = batch_utils.scale_global_to_infeed(
        global_batch_size, use_per_host_infeed)
    self.assertEqual(actual_infeed_batch_size, expected_infeed_batch_size)

def testScaleGlobalToInfeedRejectsNonDivisiblePerHostBatch(self):
  # With per-host infeed enabled, a global batch of 10 on a 4-host test
  # cluster is expected to raise a ValueError whose message matches
  # 'did not divide evenly'.
  global_batch_size = 10
  num_tpu_hosts = 4
  with flagsaver.flagsaver(xla_device='tpu', enable_asserts=False), \
       cluster_factory.ForTestingWorker(
           tpus=128, num_tpu_hosts=num_tpu_hosts):
    with self.assertRaisesRegex(ValueError, 'did not divide evenly'):
      batch_utils.scale_global_to_infeed(
          global_batch_size, use_per_host_infeed=True)

@parameterized.parameters(
itertools.product(
(False, True), # use_per_host_infeed
Expand Down