diff --git a/init2winit/dataset_lib/criteo_terabyte_dataset.py b/init2winit/dataset_lib/criteo_terabyte_dataset.py
index 51513917..e24d566c 100644
--- a/init2winit/dataset_lib/criteo_terabyte_dataset.py
+++ b/init2winit/dataset_lib/criteo_terabyte_dataset.py
@@ -34,9 +34,9 @@
 import tensorflow as tf
 import tensorflow_datasets as tfds
 
-# Change to the path to raw dataset files.
 RAW_CRITEO1TB_FILE_PATH = ''
+PREPROCESSED_CRITEO1TB_FILE_PATH = ''
 # pylint: disable=invalid-name
 CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict(
     dict(
         input_shape=(13 + 26,),
@@ -157,6 +157,99 @@ def criteo_tsv_reader(
   return ds
 
 
+_ARRAYRECORD_FEATURE_SPEC = {
+    'inputs': tf.io.FixedLenFeature([13 + 26], tf.float32),
+    'targets': tf.io.FixedLenFeature([1], tf.float32),
+}
+
+
+@tf.function
+def _parse_arrayrecord_example_fn(serialized_examples):
+  """Parses a batch of serialized tf.train.Examples from ArrayRecord."""
+  parsed = tf.io.parse_example(serialized_examples, _ARRAYRECORD_FEATURE_SPEC)
+  return {
+      'inputs': parsed['inputs'],
+      'targets': tf.squeeze(parsed['targets'], axis=-1),
+  }
+
+
+def criteo_arrayrecord_reader(
+    split, shuffle_rng, file_path, batch_size, num_batches_to_prefetch
+):
+  """Input reader for preprocessed Criteo ArrayRecord data.
+
+  Args:
+    split: one of {'train', 'eval_train', 'validation', 'test'}.
+    shuffle_rng: jax.random.PRNGKey used for shuffling (train only).
+    file_path: glob pattern for .array_record files.
+    batch_size: per-host batch size.
+    num_batches_to_prefetch: number of batches to prefetch.
+
+  Returns:
+    A tf.data.Dataset object.
+  """
+  # Import here to avoid a hard array_record dependency for TSV-only users.
+  # NOTE: the module path for the tf.data ArrayRecord wrapper is assumed;
+  # adjust it to match the installed array_record release.
+  from array_record.python import array_record_dataset as ar_dataset  # pylint: disable=g-import-not-at-top
+
+  if split not in ['train', 'eval_train', 'validation', 'test']:
+    raise ValueError(f'Invalid split name {split}.')
+  data_shuffle_seed = None
+
+  is_training = split == 'train'
+  if is_training:
+    _, data_shuffle_seed = jax.random.split(shuffle_rng, 2)
+    data_shuffle_seed = data_utils.convert_jax_to_tf_random_seed(
+        data_shuffle_seed
+    )
+
+  # Discover all matching files.
+  all_files = sorted(tf.io.gfile.glob(file_path))
+  if not all_files:
+    raise ValueError(f'No ArrayRecord files found matching: {file_path}')
+
+  # Shard files across hosts.
+  index = jax.process_index()
+  num_hosts = jax.process_count()
+  host_files = all_files[index::num_hosts]
+
+  # Interleave per-file datasets, with batch+parse inside each file's
+  # sub-pipeline. This is critical for performance: interleaving dense float
+  # tensors (post-parse) is much faster than interleaving raw byte strings
+  # and batching them later.
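+  # Shape: files -> (per file: read -> batch -> parse) -> interleave.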
+  file_ds = tf.data.Dataset.from_tensor_slices(host_files)
+  if is_training:
+    file_ds = file_ds.repeat()
+    file_ds = file_ds.shuffle(
+        buffer_size=2 * len(host_files), seed=data_shuffle_seed
+    )
+
+  ds = file_ds.interleave(
+      lambda f: (
+          ar_dataset.ArrayRecordDataset([f])
+          .batch(
+              batch_size,
+              drop_remainder=is_training,
+              num_parallel_calls=tf.data.AUTOTUNE,
+              deterministic=False,
+          )
+          .map(
+              _parse_arrayrecord_example_fn,
+              num_parallel_calls=tf.data.AUTOTUNE,
+              deterministic=False,
+          )
+      ),
+      cycle_length=64,
+      block_length=batch_size // 8,
+      num_parallel_calls=64,
+      deterministic=False,
+  )
+  ds = ds.prefetch(num_batches_to_prefetch)
+  return ds
+
+
 def _eval_numpy_iterator(
     num_batches, per_host_eval_batch_size, tf_dataset, split_size
 ):
@@ -222,14 +315,146 @@ def get_criteo1tb(shuffle_rng,
   per_host_eval_batch_size = eval_batch_size // process_count
   per_host_batch_size = batch_size // process_count
 
+  use_raw_tsv = hps.get('use_raw_tsv', False)
+  num_batches_to_prefetch = (
+      hps.num_tf_data_prefetches
+      if hps.num_tf_data_prefetches > 0
+      else tf.data.AUTOTUNE
+  )
+  num_device_prefetches = hps.get('num_device_prefetches', 0)
+
+  if use_raw_tsv:
+    return _get_criteo1tb_tsv(
+        shuffle_rng,
+        per_host_batch_size,
+        per_host_eval_batch_size,
+        hps,
+        num_batches_to_prefetch,
+        num_device_prefetches,
+    )
+  else:
+    return _get_criteo1tb_arrayrecord(
+        shuffle_rng,
+        per_host_batch_size,
+        per_host_eval_batch_size,
+        hps,
+        num_batches_to_prefetch,
+        num_device_prefetches,
+    )
+
+
+def _get_criteo1tb_arrayrecord(
+    shuffle_rng,
+    per_host_batch_size,
+    per_host_eval_batch_size,
+    hps,
+    num_batches_to_prefetch,
+    num_device_prefetches,
+):
+  """Loads Criteo 1TB from preprocessed ArrayRecord files."""
+  base = hps.get('preprocessed_data_path', PREPROCESSED_CRITEO1TB_FILE_PATH)
+  train_file_path = os.path.join(base, 'train', '*')
+  validation_file_path = os.path.join(
+      base, 'val_set_second_half_of_day23_not_used', '*'
+  )
+  test_file_path = os.path.join(base, 'eval', '*')
+
+  train_dataset = criteo_arrayrecord_reader(
+      split='train',
+      shuffle_rng=shuffle_rng,
+      file_path=train_file_path,
+      batch_size=per_host_batch_size,
+      num_batches_to_prefetch=num_batches_to_prefetch,
+  )
+  data_utils.log_rss('train arrayrecord dataset created')
+
+  if num_device_prefetches > 0:
+    train_iterator_fn = lambda: data_utils.prefetch_iterator(
+        tfds.as_numpy(train_dataset), num_device_prefetches
+    )
+    data_utils.log_rss(
+        f'prefetching {num_device_prefetches} batches in the train dataset'
+    )
+  else:
+    train_iterator_fn = lambda: tfds.as_numpy(train_dataset)
+
+  eval_train_dataset = criteo_arrayrecord_reader(
+      split='eval_train',
+      shuffle_rng=None,
+      file_path=train_file_path,
+      batch_size=per_host_eval_batch_size,
+      num_batches_to_prefetch=num_batches_to_prefetch,
+  )
+  eval_train_iterator_fn = functools.partial(
+      _eval_numpy_iterator,
+      per_host_eval_batch_size=per_host_eval_batch_size,
+      tf_dataset=eval_train_dataset,
+      split_size=hps.train_size,
+  )
+  data_utils.log_rss('eval_train arrayrecord dataset created')
+
+  validation_dataset = criteo_arrayrecord_reader(
+      split='validation',
+      shuffle_rng=None,
+      file_path=validation_file_path,
+      batch_size=per_host_eval_batch_size,
+      num_batches_to_prefetch=num_batches_to_prefetch,
+  )
+  validation_iterator_fn = functools.partial(
+      _eval_numpy_iterator,
+      per_host_eval_batch_size=per_host_eval_batch_size,
+      tf_dataset=validation_dataset,
+      split_size=hps.valid_size,
+  )
+  data_utils.log_rss('validation arrayrecord dataset created')
+
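+  # Like the other eval splits, the test split is read without shuffling.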
+  test_dataset = criteo_arrayrecord_reader(
+      split='test',
+      shuffle_rng=None,
+      file_path=test_file_path,
+      batch_size=per_host_eval_batch_size,
+      num_batches_to_prefetch=num_batches_to_prefetch,
+  )
+  test_iterator_fn = functools.partial(
+      _eval_numpy_iterator,
+      per_host_eval_batch_size=per_host_eval_batch_size,
+      tf_dataset=test_dataset,
+      split_size=hps.test_size,
+  )
+  data_utils.log_rss('test arrayrecord dataset created')
+
+  eval_train_iterator_fn = data_utils.CachedIteratorFactory(
+      eval_train_iterator_fn(None), 'eval_train'
+  )
+  validation_iterator_fn = data_utils.CachedIteratorFactory(
+      validation_iterator_fn(None), 'validation'
+  )
+  test_iterator_fn = data_utils.CachedIteratorFactory(
+      test_iterator_fn(None), 'test'
+  )
+
+  return data_utils.Dataset(
+      train_iterator_fn,
+      eval_train_iterator_fn,
+      validation_iterator_fn,
+      test_iterator_fn,
+  )
+
+
+def _get_criteo1tb_tsv(
+    shuffle_rng,
+    per_host_batch_size,
+    per_host_eval_batch_size,
+    hps,
+    num_batches_to_prefetch,
+    num_device_prefetches,
+):
+  """Loads Criteo 1TB from raw TSV files (legacy path)."""
   train_file_path = os.path.join(RAW_CRITEO1TB_FILE_PATH, 'train/*/*')
   validation_file_path = os.path.join(
       RAW_CRITEO1TB_FILE_PATH, 'val_set_second_half_of_day23_not_used/*')
   test_file_path = os.path.join(RAW_CRITEO1TB_FILE_PATH, 'eval/day_23/*')
-  num_batches_to_prefetch = (hps.num_tf_data_prefetches
-                             if hps.num_tf_data_prefetches > 0 else tf.data.AUTOTUNE)
-
-  num_device_prefetches = hps.get('num_device_prefetches', 0)
 
   train_dataset = criteo_tsv_reader(
       split='train',
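For context, here is a minimal sketch of how `.array_record` shards matching `_ARRAYRECORD_FEATURE_SPEC` could be produced from preprocessed rows. It assumes the open-source `array_record` package (`ArrayRecordWriter`); the helper `write_arrayrecord_shard` and the record layout shown are illustrative, not the pipeline that generated the actual preprocessed dataset.

```python
# Sketch only: serializes (inputs, target) rows as tf.train.Examples in the
# layout expected by _ARRAYRECORD_FEATURE_SPEC (39 input floats, 1 target).
# write_arrayrecord_shard is a hypothetical helper, not part of this change.
import tensorflow as tf
from array_record.python.array_record_module import ArrayRecordWriter


def write_arrayrecord_shard(rows, output_path):
  """Writes an iterable of (inputs, target) pairs to one ArrayRecord shard.

  Args:
    rows: iterable of (list of 39 floats, float label) pairs.
    output_path: destination .array_record file path.
  """
  writer = ArrayRecordWriter(output_path, 'group_size:1')
  try:
    for inputs, target in rows:
      example = tf.train.Example(
          features=tf.train.Features(
              feature={
                  'inputs': tf.train.Feature(
                      float_list=tf.train.FloatList(value=inputs)),
                  'targets': tf.train.Feature(
                      float_list=tf.train.FloatList(value=[target])),
              }))
      writer.write(example.SerializeToString())
  finally:
    writer.close()
```

Read back through `criteo_arrayrecord_reader` and `tfds.as_numpy`, each batch then comes out as `{'inputs': (batch_size, 39), 'targets': (batch_size,)}`, matching what `_parse_arrayrecord_example_fn` emits.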