Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions init2winit/dataset_lib/fineweb_edu_10b_input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
https://github.com/mlcommons/algorithmic-efficiency/blob/main/datasets/dataset_setup.py.
"""

# import tensorflow.compat.v2 as tf
import os
from absl import logging
from ml_collections.config_dict import config_dict
Expand All @@ -35,8 +34,7 @@
SHUFFLE_BUFFER_SIZE = 100_000
VOCAB_SIZE = 50_257

PAD_ID = tf.constant(-1, dtype=tf.int64)
# PAD_ID = -1
PAD_ID = -1

AUTOTUNE = tf.data.experimental.AUTOTUNE

Expand All @@ -61,6 +59,8 @@ def batch_with_padding(

# tf.data.Dataset.padded_batch pads elements in the batch so we call it
# again with batch_size=1 to pad each element in the original batch.
if isinstance(padding_id, int):
padding_id = tf.constant(padding_id, dtype=tf.int64)
padded_batched_dataset = batched_dataset.padded_batch(
1, padded_shapes=padded_shapes, padding_values=padding_id
)
Expand Down Expand Up @@ -95,7 +95,6 @@ def get_fineweb_edu_dataset(
train_path = os.path.join(DATA_DIR, TRAIN_DIR)
val_path = os.path.join(DATA_DIR, VAL_DIR)

# Load datasets and cast to int32.
train_dataset = tf.data.Dataset.load(train_path)
val_dataset = tf.data.Dataset.load(val_path)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class FinewebEdu10bInputPipelineTest(absltest.TestCase):

def test_batch_with_padding(self):
"""Test batching with padding."""
arr = np.arange(18, dtype=np.int32)
arr = np.arange(18, dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices(arr)
ds = ds.batch(
6,
Expand Down
4 changes: 2 additions & 2 deletions init2winit/dataset_lib/test_fineweb_edu_10b_mdlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_eval_batch_padding_applied(self):
self.assertLen(batches, 2)

padded_batch = batches[1]
pad_id = int(input_pipeline.PAD_ID.numpy())
pad_id = input_pipeline.PAD_ID

# The second row of the padded batch should be all PAD_ID.
np.testing.assert_array_equal(padded_batch['inputs'][1], np.full(4, pad_id))
Expand All @@ -157,7 +157,7 @@ def test_eval_batch_padding_not_in_full_batches(self):

batches = list(valid_ds.as_numpy_iterator())
full_batch = batches[0]
pad_id = int(input_pipeline.PAD_ID.numpy())
pad_id = input_pipeline.PAD_ID

# No element in the full batch should be PAD_ID.
self.assertTrue(np.all(full_batch['inputs'] != pad_id))
Expand Down