diff --git a/init2winit/dataset_lib/fineweb_edu_10b_input_pipeline.py b/init2winit/dataset_lib/fineweb_edu_10b_input_pipeline.py index 2f5c5ec2..44768497 100644 --- a/init2winit/dataset_lib/fineweb_edu_10b_input_pipeline.py +++ b/init2winit/dataset_lib/fineweb_edu_10b_input_pipeline.py @@ -22,7 +22,6 @@ https://github.com/mlcommons/algorithmic-efficiency/blob/main/datasets/dataset_setup.py. """ -# import tensorflow.compat.v2 as tf import os from absl import logging from ml_collections.config_dict import config_dict @@ -35,8 +34,7 @@ SHUFFLE_BUFFER_SIZE = 100_000 VOCAB_SIZE = 50_257 -PAD_ID = tf.constant(-1, dtype=tf.int64) -# PAD_ID = -1 +PAD_ID = -1 AUTOTUNE = tf.data.experimental.AUTOTUNE @@ -61,6 +59,8 @@ def batch_with_padding( # tf.data.Dataset.padded.batch pads elements in the batch so we call it # again with batch_size=1 to pad each element in original batch. + if isinstance(padding_id, int): + padding_id = tf.constant(padding_id, dtype=tf.int64) padded_batched_dataset = batched_dataset.padded_batch( 1, padded_shapes=padded_shapes, padding_values=padding_id ) @@ -95,7 +95,6 @@ def get_fineweb_edu_dataset( train_path = os.path.join(DATA_DIR, TRAIN_DIR) val_path = os.path.join(DATA_DIR, VAL_DIR) - # Load datasets and cast to int32. train_dataset = tf.data.Dataset.load(train_path) val_dataset = tf.data.Dataset.load(val_path) diff --git a/init2winit/dataset_lib/test_fineweb_edu_10b_input_pipeline.py b/init2winit/dataset_lib/test_fineweb_edu_10b_input_pipeline.py index 97ab105b..734df008 100644 --- a/init2winit/dataset_lib/test_fineweb_edu_10b_input_pipeline.py +++ b/init2winit/dataset_lib/test_fineweb_edu_10b_input_pipeline.py @@ -26,7 +26,7 @@ class FinewebEdu10bInputPipelineTest(absltest.TestCase): def test_batch_with_padding(self): """Test batching with padding.""" - arr = np.arange(18, dtype=np.int32) + arr = np.arange(18, dtype=np.int64) ds = tf.data.Dataset.from_tensor_slices(arr) ds = ds.batch( 6, diff --git a/init2winit/dataset_lib/test_fineweb_edu_10b_mdlm.py b/init2winit/dataset_lib/test_fineweb_edu_10b_mdlm.py index f31d2083..91f4bbad 100644 --- a/init2winit/dataset_lib/test_fineweb_edu_10b_mdlm.py +++ b/init2winit/dataset_lib/test_fineweb_edu_10b_mdlm.py @@ -133,7 +133,7 @@ def test_eval_batch_padding_applied(self): self.assertLen(batches, 2) padded_batch = batches[1] - pad_id = int(input_pipeline.PAD_ID.numpy()) + pad_id = input_pipeline.PAD_ID # The second row of the padded batch should be all PAD_ID. np.testing.assert_array_equal(padded_batch['inputs'][1], np.full(4, pad_id)) @@ -157,7 +157,7 @@ def test_eval_batch_padding_not_in_full_batches(self): batches = list(valid_ds.as_numpy_iterator()) full_batch = batches[0] - pad_id = int(input_pipeline.PAD_ID.numpy()) + pad_id = input_pipeline.PAD_ID # No element in the full batch should be PAD_ID. self.assertTrue(np.all(full_batch['inputs'] != pad_id))