From da5f85a7c878f0399c7b8a5d2fcfb9d729e567ea Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 11 Mar 2025 15:46:49 +0100 Subject: [PATCH 01/98] first LM commit --- algoperf/workloads/lm/__init__.py | 0 algoperf/workloads/lm/dev/data_pytorch.py | 42 ++++++++++ algoperf/workloads/lm/input_pipeline.py | 82 ++++++++++++++++++++ algoperf/workloads/lm/lm_pytorch/__init__.py | 0 algoperf/workloads/lm/lm_pytorch/workload.py | 36 +++++++++ algoperf/workloads/lm/test_01.py | 22 ++++++ algoperf/workloads/lm/test_input_pipeline.py | 68 ++++++++++++++++ algoperf/workloads/lm/workload.py | 66 ++++++++++++++++ 8 files changed, 316 insertions(+) create mode 100644 algoperf/workloads/lm/__init__.py create mode 100644 algoperf/workloads/lm/dev/data_pytorch.py create mode 100644 algoperf/workloads/lm/input_pipeline.py create mode 100644 algoperf/workloads/lm/lm_pytorch/__init__.py create mode 100644 algoperf/workloads/lm/lm_pytorch/workload.py create mode 100644 algoperf/workloads/lm/test_01.py create mode 100644 algoperf/workloads/lm/test_input_pipeline.py create mode 100644 algoperf/workloads/lm/workload.py diff --git a/algoperf/workloads/lm/__init__.py b/algoperf/workloads/lm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/lm/dev/data_pytorch.py b/algoperf/workloads/lm/dev/data_pytorch.py new file mode 100644 index 000000000..d0081a75d --- /dev/null +++ b/algoperf/workloads/lm/dev/data_pytorch.py @@ -0,0 +1,42 @@ + +import torch + +from datasets import Dataset, load_from_disk +from torch.utils.data import DataLoader + +trainset_path = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens/train" +vocab_size = 50280 +seq_len = 2048 +sampler = 'sequential' +sampler_seed = None +num_workers = 4 + +train_set = load_from_disk(trainset_path) # + +""" +>>> type(train_set) + + +>>> len(train_set) +7501407 + +>>> train_set[0] +{'input_ids': tensor([ 5166, 20, 1639, ..., 275, 253, 19992])} + +>>> type(train_set[0]['input_ids']) + + +# In PyTorch we do: +trainloader = DataLoader( + train_set, + sampler = ..., + batch_size = ..., + num_workers = ..., + pin_memory = ..., + ) + +# PyTorch’s DataLoader expects an iterable dataset, +# which means it calls __getitem__() and __len__() on train_set. + +""" + diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py new file mode 100644 index 000000000..7424dd6d5 --- /dev/null +++ b/algoperf/workloads/lm/input_pipeline.py @@ -0,0 +1,82 @@ +"""Input pipeline for a LM dataset.""" +import functools +import os + +from datasets import Dataset, load_from_disk +from typing import Dict, List, Optional, Union + +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds + +from algoperf import data_utils +from algoperf.pytorch_utils import pytorch_setup + +RANK = pytorch_setup()[1] +# Avoid multithreading in all processes but the first (rank 0). 
+AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None + + +def get_lm_dataset(data_rng, + split: str, + data_dir: str, + is_training: bool, + vocab_size: int, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + vocab_path: Optional[str] = None): + """Load HF dataset and return a TF dataset.""" + + dataset_path = os.path.join(data_dir, split) + dataset = load_from_disk(dataset_path) # Loads HF arrow dataset + + is_training = split == "train" + shuffle = split in ['train', 'eval_train'] + + def tf_generator(): + """Generates data in a TensorFlow-friendly format.""" + for example in dataset: + yield { + "inputs": tf.convert_to_tensor(example["input_ids"][:-1], dtype=tf.int32), + "targets": tf.convert_to_tensor(example["input_ids"][1:], dtype=tf.int32), + } + + # Create a TensorFlow dataset from the generator function + ds = tf.data.Dataset.from_generator( + tf_generator, + output_signature={ + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32), + } + ) + + # Avoid creating too many threads when using PyTorch DDP. + if RANK != 0: + options = tf.data.Options() + options.threading.private_threadpool_size = 1 + ds = ds.with_options(options) + + if shuffle: + print(f"Shuffling dataset with seed: {data_rng[0]}, type={type(data_rng[0])}") + ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) + + if is_training: + ds = ds.repeat() + + # Batch the dataset, ensuring the last batch is dropped if not full during training + ds = ds.batch(global_batch_size, drop_remainder=is_training) + ds = ds.prefetch(AUTOTUNE) + + # Limit the dataset to a fixed number of batches if `num_batches` is specified + if num_batches: + ds = ds.take(num_batches) + + # Shard the dataset across multiple GPUs/TPUs if necessary + ds = map( + functools.partial( + data_utils.shard_and_maybe_pad_np, + global_batch_size=global_batch_size), + ds) + + return ds \ No newline at end of file diff --git a/algoperf/workloads/lm/lm_pytorch/__init__.py b/algoperf/workloads/lm/lm_pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py new file mode 100644 index 000000000..904657b1d --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -0,0 +1,36 @@ +"""LM workload implemented in PyTorch.""" + +import contextlib +from typing import Any, Dict, Optional, Tuple + +from absl import logging +import jax +import tensorflow as tf +import torch +import torch.distributed as dist +from torch.nn import DataParallel as DP +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP + +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +from algoperf.workloads.lm.workload import BaseLmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() + + +class LmWorkload(BaseLmWorkload): + """LM PyTorch workload.""" + + def init_model_fn(): + pass + + def model_fn(): + pass + + def _build_input_queue(): + pass + + def eval_step(): + pass diff --git a/algoperf/workloads/lm/test_01.py b/algoperf/workloads/lm/test_01.py new file mode 100644 index 000000000..e33ddf3e7 --- /dev/null +++ b/algoperf/workloads/lm/test_01.py @@ -0,0 +1,22 @@ +import os +import tensorflow as tf +import torch +from datasets import load_from_disk + +from algoperf.workloads.lm.input_pipeline import get_lm_dataset + +DATASET_PATH = 
"/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" +BATCH_SIZE = 2 +SEED = 42 # Fixed random seed for reproducibility + +tf_seed = SEED + +# Load the dataset +ds = get_lm_dataset( + data_rng=[tf_seed], # Ensure correct seed type + split="train", + data_dir=DATASET_PATH, + is_training=True, + vocab_size=0, # Not needed but kept for function signature + global_batch_size=BATCH_SIZE, +) diff --git a/algoperf/workloads/lm/test_input_pipeline.py b/algoperf/workloads/lm/test_input_pipeline.py new file mode 100644 index 000000000..47c11969f --- /dev/null +++ b/algoperf/workloads/lm/test_input_pipeline.py @@ -0,0 +1,68 @@ +import os +import tensorflow as tf +import torch +from datasets import load_from_disk + +from algoperf.workloads.lm.input_pipeline import get_lm_dataset + +DATASET_PATH = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" +BATCH_SIZE = 2 +SEED = 42 # Fixed random seed for reproducibility + + +def test_tf_dataset(): + """Tests if get_lm_dataset correctly loads the HF dataset as a TensorFlow dataset.""" + + print(f"Loading dataset from: {DATASET_PATH}") + + tf_seed = SEED + + # Load the dataset + ds = get_lm_dataset( + data_rng=[tf_seed], # Ensure correct seed type + split="train", + data_dir=DATASET_PATH, + is_training=True, + vocab_size=0, # Not needed but kept for function signature + global_batch_size=BATCH_SIZE, + ) + + print("Testing TensorFlow Dataset Output...") + for batch in ds.take(2): # Take two batches to test + print("Inputs:", batch["inputs"].numpy()) # Convert to NumPy for inspection + print("Targets:", batch["targets"].numpy()) + +def test_pytorch_dataloader(): + """Tests if the TensorFlow dataset can be converted to PyTorch format correctly.""" + + # Use the same TensorFlow-compatible seed + tf_seed = tf.constant(SEED, dtype=tf.int64) + + # Load the dataset + ds = get_lm_dataset( + data_rng=[tf_seed], # Ensure correct seed type + split="train", + data_dir=DATASET_PATH, + is_training=True, + vocab_size=0, + global_batch_size=BATCH_SIZE, + ) + + def _input_queue_generator(): + """Generator that converts TF dataset batches to PyTorch tensors.""" + for batch in iter(ds): + batch = {k: torch.tensor(v.numpy()) for k, v in batch.items()} # Convert to PyTorch tensors + yield batch + + dataloader = _input_queue_generator() + + print("\nTesting PyTorch DataLoader Output...") + for _ in range(2): # Take two batches + batch = next(dataloader) + print("Inputs:", batch["inputs"]) + print("Targets:", batch["targets"]) + +# Run tests +if __name__ == "__main__": + test_tf_dataset() + test_pytorch_dataloader() \ No newline at end of file diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py new file mode 100644 index 000000000..d070cabec --- /dev/null +++ b/algoperf/workloads/lm/workload.py @@ -0,0 +1,66 @@ +"""LM workload parent class.""" + +import abc +import math +import os +from typing import Any, Dict, Optional, Tuple + +import jax +import numpy as np +import torch + +from algoperf import spec +from algoperf.workloads.lm import input_pipeline + +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ + + +class BaseLmWorkload(spec.Workload): + """A LM workload.""" + + _vocab_size: int = 32000 + + def __init__(self) -> None: + super().__init__() + self._tokenizer = None + + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + is_training = split == 'train' + ds, self._tokenizer = input_pipeline.get_lm_dataset( 
+ data_rng, + split, + data_dir, + is_training=is_training, + vocab_size=self._vocab_size, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset) + + for batch in iter(ds): + yield batch + + def _eval_model_on_split(self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0) -> Dict[str, float]: + """Run a full evaluation of the model.""" + + def loss_fn( + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + """Evaluate the loss function at (label_batch, logits_batch).""" + pass \ No newline at end of file From a12a36404ce907c8e50e67c8e4a5eb25baa9a2f3 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 12 Mar 2025 15:49:04 +0100 Subject: [PATCH 02/98] lm data pipeline --- algoperf/workloads/lm/input_pipeline.py | 11 +-- algoperf/workloads/lm/test_01.py | 96 +++++++++++++++++++++---- datasets/dataset_setup.py | 96 +++++++++++++++++++++++++ datasets/lm_preprocess.py | 0 4 files changed, 185 insertions(+), 18 deletions(-) create mode 100644 datasets/lm_preprocess.py diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 7424dd6d5..a14cebeda 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -5,6 +5,7 @@ from datasets import Dataset, load_from_disk from typing import Dict, List, Optional, Union +import jax import numpy as np import tensorflow as tf import tensorflow_datasets as tfds @@ -17,7 +18,7 @@ AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None -def get_lm_dataset(data_rng, +def get_lm_dataset(data_rng: jax.random.PRNGKey, split: str, data_dir: str, is_training: bool, @@ -37,11 +38,12 @@ def get_lm_dataset(data_rng, def tf_generator(): """Generates data in a TensorFlow-friendly format.""" for example in dataset: + input_ids = example["input_ids"].numpy().astype(np.int32) # torch tensor TODO: remove numpy conversion yield { - "inputs": tf.convert_to_tensor(example["input_ids"][:-1], dtype=tf.int32), - "targets": tf.convert_to_tensor(example["input_ids"][1:], dtype=tf.int32), + "inputs": tf.convert_to_tensor(input_ids[:-1], dtype=tf.int32), + "targets": tf.convert_to_tensor(input_ids[1:], dtype=tf.int32), } - + # Create a TensorFlow dataset from the generator function ds = tf.data.Dataset.from_generator( tf_generator, @@ -58,7 +60,6 @@ def tf_generator(): ds = ds.with_options(options) if shuffle: - print(f"Shuffling dataset with seed: {data_rng[0]}, type={type(data_rng[0])}") ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) if is_training: diff --git a/algoperf/workloads/lm/test_01.py b/algoperf/workloads/lm/test_01.py index e33ddf3e7..977fae11a 100644 --- a/algoperf/workloads/lm/test_01.py +++ b/algoperf/workloads/lm/test_01.py @@ -1,22 +1,92 @@ + import os +import numpy as np import tensorflow as tf import torch + from datasets import load_from_disk +from absl import app +from absl import flags +from absl import logging + +from algoperf.profiler import PassThroughProfiler +from algoperf import random_utils as prng +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup from algoperf.workloads.lm.input_pipeline import get_lm_dataset + +tf.config.set_visible_devices([], 'GPU') + +# 
Environment variables +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +# disable only for deepspeech if it works fine for other workloads +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' +# (nico) +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + +flags.DEFINE_enum( + 'framework', + None, + enum_values=['jax', 'pytorch'], + help='Whether to use Jax or Pytorch for the submission. Controls among ' + 'other things if the Jax or Numpy RNG library is used for RNG.') + +FLAGS = flags.FLAGS +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() + + DATASET_PATH = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" BATCH_SIZE = 2 -SEED = 42 # Fixed random seed for reproducibility - -tf_seed = SEED - -# Load the dataset -ds = get_lm_dataset( - data_rng=[tf_seed], # Ensure correct seed type - split="train", - data_dir=DATASET_PATH, - is_training=True, - vocab_size=0, # Not needed but kept for function signature - global_batch_size=BATCH_SIZE, -) +RNG_SEED = 1996 # Fixed random seed for reproducibility + + +def main(_): + profiler = PassThroughProfiler() + if FLAGS.framework == 'pytorch': + pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + + rng = prng.PRNGKey(RNG_SEED) + data_rng, _, _, _ = prng.split(rng, 4) + + print(f"data_rng = {data_rng}") + + # Load the dataset + ds = get_lm_dataset( + data_rng=data_rng, + split="train", + data_dir=DATASET_PATH, + is_training=True, + vocab_size=0, # Not needed but kept for function signature + global_batch_size=BATCH_SIZE, + ) + # Check if `ds` acts as a generator + if hasattr(ds, '__iter__'): + print("Dataset is an iterable/generator.") + + # Fetch first batch + try: + first_batch = next(iter(ds)) + print(f"Successfully retrieved first batch.") + except Exception as e: + print(f"Error retrieving first batch: {e}") + return + + # Print structure of a batch + print(f"First batch keys: {first_batch.keys()}") + print(f"First batch shapes:") + for key, value in first_batch.items(): + print(f" - {key}: {value.shape} (dtype: {value.dtype})") + + # Validate batch dimensions + assert "inputs" in first_batch and "targets" in first_batch, "Missing expected keys!" + assert first_batch["inputs"].shape[0] == BATCH_SIZE, "Batch size mismatch!" + assert first_batch["inputs"].shape == first_batch["targets"].shape, "Inputs and targets should have the same shape!" 
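Since the pipeline emits targets as inputs shifted by one token, the shape checks above could be tightened with an element-wise check (a sketch, assuming the batch values are NumPy arrays; a later commit in this series adds the equivalent torch.equal(inputs[:, 1:], targets[:, :-1]) assertion):

    # Sketch: targets should equal inputs shifted left by one position.
    # The ellipsis tolerates a possible leading device dimension added by sharding.
    np.testing.assert_array_equal(first_batch["inputs"][..., 1:],
                                  first_batch["targets"][..., :-1])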
+ + print(f"Dataset is correctly batched and structured.") + print(f"Test completed successfully.") + +if __name__ == '__main__': + flags.mark_flag_as_required('framework') + app.run(main) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index efe923dbe..14dd24545 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -76,13 +76,21 @@ normalize_feature_names from datasets import librispeech_preprocess from datasets import librispeech_tokenizer +from datasets import lm_preprocess +import datasets as hf_datasets +# from datasets import load_dataset, Dataset +from transformers import AutoTokenizer + +import math import functools +import itertools import os import shutil import subprocess import tarfile +from typing import Dict, List, Any from absl import app from absl import flags from absl import logging @@ -126,6 +134,9 @@ flags.DEFINE_boolean('librispeech', False, 'If --all=false, whether or not to download LibriSpeech.') +flags.DEFINE_boolean('finewebedu', + False, + 'If --all=false, whether or not to download FineWebEdu.') flags.DEFINE_boolean('mnist', False, 'If --all=false, whether or not to download MNIST.') @@ -699,6 +710,86 @@ def download_wmt(data_dir): ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7) +def download_finewebedu(data_dir, tmp_dir): + """Download FineWebEdu-10B.""" + + # data_dir = "/fast/najroldi/data" + + tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser("~/.cache/huggingface/datasets") + data_dir = os.path.join(data_dir, 'finewebedu') + + _maybe_mkdir(tmp_dir) + _maybe_mkdir(data_dir) + + ds = hf_datasets.load_dataset( + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + # cache_dir=tmp_dir + ) + + ds = ds.shuffle(seed=1996) # shuffle so that multiproc has shards of similar size + + seq_len = 2048 + max_seq_length = seq_len+1 + map_setup = dict(batched=True, batch_size=1024, num_proc=8) + + # Tokenize + tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f"Vocab size of tokenizer = {len(tokenizer)}") + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq + add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] + return tokenizer( + add_eos_batched(examples["text"]), + return_special_tokens_mask=False, + return_attention_mask=False + ) + + tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + tokenized_dataset = ds.map( + tokenize, + remove_columns=['text', 'id', 'dump', 'url', 'file_path', 'language', + 'language_score', 'token_count', 'score', 'int_score'], + **map_setup + ) + tokenizer.model_max_length = seq_len + + # Concat in chunks of max_seq_len + def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + """Concatenate text and generate chunks of max_seq_length""" + concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + if total_length >= max_seq_length: + total_length = (total_length // max_seq_length) * max_seq_length + result = { + k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] + for k, t in concatenated_examples.items() + } + return result + + lm_dataset = tokenized_dataset.map( + concat_chunck, + **map_setup + ) + + n_tokens = len(lm_dataset) * max_seq_length + logging.info(f"Number of tokens in dataset: {n_tokens:_}") + + # Split dataset into training and 
validation sets + # TODO: avoid (single doc) contamination between train and val + VAL_TOKENS = 10_000_000 + val_samples = VAL_TOKENS // max_seq_length + 1 + val_dataset = lm_dataset.select(range(val_samples)) + train_dataset = lm_dataset.select(range(val_samples, len(lm_dataset))) + logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length :_}") + logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length :_}") + + # Save datasets + train_dataset.save_to_disk(os.path.join(data_dir, f"train")) + val_dataset.save_to_disk(os.path.join(data_dir, f"val")) + + def main(_): data_dir = FLAGS.data_dir tmp_dir = FLAGS.temp_dir @@ -781,6 +872,11 @@ def main(_): logging.info('Downloading WMT...') download_wmt(data_dir) + if FLAGS.all or FLAGS.finewebedu: + if not FLAGS.skip_download: + logging.info('Downloading FineWebEdu-10B...') + download_finewebedu(data_dir) + # pylint: enable=logging-format-interpolation # pylint: enable=consider-using-with diff --git a/datasets/lm_preprocess.py b/datasets/lm_preprocess.py new file mode 100644 index 000000000..e69de29bb From ca83ab8954a9e164dc538cb4749847812ee0e032 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 14 Mar 2025 11:31:08 +0100 Subject: [PATCH 03/98] testing --- algoperf/workloads/lm/{ => dev}/test_01.py | 0 .../lm/{ => dev}/test_input_pipeline.py | 0 algoperf/workloads/lm/input_pipeline.py | 37 +++++---- .../workloads/lm/lm_jax/__init__.py | 0 algoperf/workloads/lm/lm_jax/workload.py | 20 +++++ algoperf/workloads/lm/lm_pytorch/workload.py | 56 ++++++++++++- algoperf/workloads/lm/test.py | 37 +++++++++ algoperf/workloads/lm/workload.py | 80 ++++++++++++++----- datasets/dataset_setup.py | 25 ++++-- 9 files changed, 211 insertions(+), 44 deletions(-) rename algoperf/workloads/lm/{ => dev}/test_01.py (100%) rename algoperf/workloads/lm/{ => dev}/test_input_pipeline.py (100%) rename datasets/lm_preprocess.py => algoperf/workloads/lm/lm_jax/__init__.py (100%) create mode 100644 algoperf/workloads/lm/lm_jax/workload.py create mode 100644 algoperf/workloads/lm/test.py diff --git a/algoperf/workloads/lm/test_01.py b/algoperf/workloads/lm/dev/test_01.py similarity index 100% rename from algoperf/workloads/lm/test_01.py rename to algoperf/workloads/lm/dev/test_01.py diff --git a/algoperf/workloads/lm/test_input_pipeline.py b/algoperf/workloads/lm/dev/test_input_pipeline.py similarity index 100% rename from algoperf/workloads/lm/test_input_pipeline.py rename to algoperf/workloads/lm/dev/test_input_pipeline.py diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index a14cebeda..f0024e4a6 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -15,6 +15,10 @@ RANK = pytorch_setup()[1] # Avoid multithreading in all processes but the first (rank 0). +# This ensures that only the primary process (RANK == 0) uses TensorFlow's +# automatic optimization (AUTOTUNE), while other processes disable it (None). +# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine the optimal +# number of elements to prefetch or parallelize for dataset operations, improving performance. 
AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None @@ -30,34 +34,36 @@ def get_lm_dataset(data_rng: jax.random.PRNGKey, """Load HF dataset and return a TF dataset.""" dataset_path = os.path.join(data_dir, split) - dataset = load_from_disk(dataset_path) # Loads HF arrow dataset + dataset = load_from_disk(dataset_path) is_training = split == "train" shuffle = split in ['train', 'eval_train'] + dataset.set_format("tensorflow") # tf.int64 + def tf_generator(): """Generates data in a TensorFlow-friendly format.""" for example in dataset: - input_ids = example["input_ids"].numpy().astype(np.int32) # torch tensor TODO: remove numpy conversion yield { - "inputs": tf.convert_to_tensor(input_ids[:-1], dtype=tf.int32), - "targets": tf.convert_to_tensor(input_ids[1:], dtype=tf.int32), + "inputs": example["input_ids"][:-1], + "targets": example["input_ids"][1:], } - # Create a TensorFlow dataset from the generator function + # Create a TensorFlow dataset ds = tf.data.Dataset.from_generator( - tf_generator, - output_signature={ - "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32), - "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32), - } - ) + tf_generator, + output_signature={ + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64), + } + ) # Avoid creating too many threads when using PyTorch DDP. - if RANK != 0: + # Limits TensorFlow's threading for non-primary processes (RANK != 0) + if RANK != 0: options = tf.data.Options() - options.threading.private_threadpool_size = 1 - ds = ds.with_options(options) + options.threading.private_threadpool_size = 1 # restrict dataset operations to a single thread + ds = ds.with_options(options) # apply threading restrictions if shuffle: ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) @@ -66,6 +72,9 @@ def tf_generator(): ds = ds.repeat() # Batch the dataset, ensuring the last batch is dropped if not full during training + # i.e. it groups consecutive elements into fixed-size chunks. 
+ # Instead of processing individual elements, the dataset yields batches (tensors with multiple elements), + # improving efficiency and parallelism in training ds = ds.batch(global_batch_size, drop_remainder=is_training) ds = ds.prefetch(AUTOTUNE) diff --git a/datasets/lm_preprocess.py b/algoperf/workloads/lm/lm_jax/__init__.py similarity index 100% rename from datasets/lm_preprocess.py rename to algoperf/workloads/lm/lm_jax/__init__.py diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py new file mode 100644 index 000000000..4cdb42409 --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -0,0 +1,20 @@ +"""LM workload implemented in Jax.""" + +import functools +from typing import Dict, Optional, Tuple + +from flax import jax_utils +import jax +import jax.numpy as jnp +import numpy as np + +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.lm.workload import BaseLmWorkload + + +class LmWorkload(BaseLmWorkload): + + @property + def eval_batch_size(self) -> int: + return 131_072 diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 904657b1d..9ee21ccb6 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -29,8 +29,58 @@ def init_model_fn(): def model_fn(): pass - def _build_input_queue(): - pass - + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + per_device_batch_size = int(global_batch_size / N_GPUS) + + # Only create and iterate over tf input pipeline in one Python process to + # avoid creating too many threads. + if RANK == 0: + np_iter = super()._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset) + while True: + if RANK == 0: + batch = next(np_iter) + inputs = torch.as_tensor( + batch['inputs'], dtype=torch.float32, device=DEVICE) + targets = torch.as_tensor( + batch['targets'], dtype=torch.float32, device=DEVICE) + # Send batch to other devices when using DDP. + if USE_PYTORCH_DDP: + dist.broadcast(inputs, src=0) + inputs = inputs[0] # TODO: check + dist.broadcast(targets, src=0) + targets = targets[0] # TODO: check + else: + batch = {} + inputs = torch.empty((N_GPUS, per_device_batch_size, 39), + dtype=torch.float32, + device=DEVICE) + dist.broadcast(inputs, src=0) + inputs = inputs[RANK] + targets = torch.empty((N_GPUS, per_device_batch_size, 1), + dtype=torch.float32, + device=DEVICE) + dist.broadcast(targets, src=0) + targets = targets[RANK] + + batch = { + 'inputs': inputs, + 'targets': targets, + # 'weights': weights, + } + yield batch + + def eval_step(): pass diff --git a/algoperf/workloads/lm/test.py b/algoperf/workloads/lm/test.py new file mode 100644 index 000000000..7e693d0af --- /dev/null +++ b/algoperf/workloads/lm/test.py @@ -0,0 +1,37 @@ +""" +Test data pipaline in JAX and PyTorch. + +Instantiate a workload and loops over the input queue. 
+""" + +import jax +import numpy as np +import torch + +import algoperf.workloads.lm.lm_jax.workload as lm_jax +# import algoperf.workloads.lm.lm_pytorch.workload as lm_pytorch + + +data_rng = jax.random.PRNGKey(0) +split = 'train' +data_dir = "/fast/najroldi/data/finewebedu" +global_batch_size = 8 +num_batches = 10 +repeat_final_dataset = False + +# ------------------------------------------------------------------------------ +# JAX +# ------------------------------------------------------------------------------ + +# 1 GPU +workload = lm_jax.LmWorkload() + +input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset) + +next(input_queue) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index d070cabec..63d2c707e 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -32,7 +32,7 @@ def _build_input_queue(self, num_batches: Optional[int] = None, repeat_final_dataset: bool = False): is_training = split == 'train' - ds, self._tokenizer = input_pipeline.get_lm_dataset( + ds = input_pipeline.get_lm_dataset( data_rng, split, data_dir, @@ -41,26 +41,66 @@ def _build_input_queue(self, global_batch_size=global_batch_size, num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) - + for batch in iter(ds): yield batch - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: - """Run a full evaluation of the model.""" - - def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. 
- logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable - """Evaluate the loss function at (label_batch, logits_batch).""" + def _eval_model_on_split(): + pass + + def eval_period_time_sec(): + pass + + def has_reached_test_target(): + pass + + def has_reached_validation_target(): + pass + + def init_model_fn(): + pass + + def is_output_params(): + pass + + def loss_fn(): + pass + + def loss_type(): + pass + + def max_allowed_runtime_sec(): + pass + + def model_fn(): + pass + + def num_eval_train_examples(): + pass + + def num_test_examples(): + pass + + def num_train_examples(): + pass + + def num_validation_examples(): + pass + + def step_hint(): + pass + + def test_target_value(): + pass + + def train_mean(): + pass + + def train_stddev(): + pass + + def validation_target_value(): + pass + + def target_metric_name(): pass \ No newline at end of file diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 14dd24545..aab793832 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -76,10 +76,8 @@ normalize_feature_names from datasets import librispeech_preprocess from datasets import librispeech_tokenizer -from datasets import lm_preprocess import datasets as hf_datasets -# from datasets import load_dataset, Dataset from transformers import AutoTokenizer import math @@ -721,6 +719,9 @@ def download_finewebedu(data_dir, tmp_dir): _maybe_mkdir(tmp_dir) _maybe_mkdir(data_dir) + # Use local disk instead of NFS for temp storage + os.environ["TMPDIR"] = tmp_dir + ds = hf_datasets.load_dataset( 'HuggingFaceFW/fineweb-edu', name='sample-10BT', @@ -745,7 +746,6 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: return_special_tokens_mask=False, return_attention_mask=False ) - tokenizer.model_max_length = 1e30 # prevent truncation during tokenization tokenized_dataset = ds.map( tokenize, @@ -754,8 +754,21 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: **map_setup ) tokenizer.model_max_length = seq_len + + tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) + from datasets import load_from_disk + tokenized_dataset = load_from_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) # Concat in chunks of max_seq_len + # TODO: this might take to much memory + # TODO: bug fix: Python's shutil.rmtree tried to delete a .nfs* file, but it was still in use (OSError: [Errno 16] Device or resource busy + # TODO: bug fix: I am losing tokens in the concat-chunk: num_tokens before split: 9_944_182_212 + # (1) loss happening because of batched=True: potentially losing the last tokens in the last batch of the 1024 batched examples + # NOTE: the current approach leads to data loss at batch boundaries, + # but concatenation *cannot* happen if batched=False, + # because concat_chunck relies on processing multiple examples at once. + # (2) loss happening because of nproc>1: potentially losing the last tokens in each process + # TODO: this does not allow to later change the seq_len... 
not a problem in AlgoPerf, but bad in plainLM def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} @@ -767,13 +780,11 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: for k, t in concatenated_examples.items() } return result - lm_dataset = tokenized_dataset.map( - concat_chunck, + concat_chunck,\ **map_setup ) - - n_tokens = len(lm_dataset) * max_seq_length + n_tokens = len(lm_dataset) * max_seq_length # 9_944_182_212 logging.info(f"Number of tokens in dataset: {n_tokens:_}") # Split dataset into training and validation sets From e3e78dc6443c5485af64bfe986951f72d9754f99 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Mon, 17 Mar 2025 11:18:41 +0100 Subject: [PATCH 04/98] LM workload tested torch pipeline --- algoperf/data_utils.py | 2 +- .../lm/dev/test_build_input_queue_torch.py | 80 +++++++++++++++++++ .../workloads/lm/{test.py => dev/test_jax.py} | 19 ++++- algoperf/workloads/lm/input_pipeline.py | 3 +- algoperf/workloads/lm/lm_jax/workload.py | 5 +- algoperf/workloads/lm/lm_pytorch/workload.py | 68 +++++++++------- algoperf/workloads/lm/workload.py | 7 +- submission_runner.py | 2 +- 8 files changed, 146 insertions(+), 40 deletions(-) create mode 100644 algoperf/workloads/lm/dev/test_build_input_queue_torch.py rename algoperf/workloads/lm/{test.py => dev/test_jax.py} (63%) diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 37d1bd20f..068c21c03 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -65,7 +65,7 @@ def _prepare(x): # Assumes that `global_batch_size % local_device_count == 0`. return x.reshape((local_device_count, -1, *x.shape[1:])) - return jax.tree.map(_prepare, batch) + return jax.tree_util.tree_map(_prepare, batch) def pad(tensor: np.ndarray, diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py new file mode 100644 index 000000000..86b1ca6b7 --- /dev/null +++ b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py @@ -0,0 +1,80 @@ + +import jax +import torch +import pdb +import numpy as np + +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.profiler import PassThroughProfiler +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.lm.lm_pytorch.workload import LmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() + +n_gpus = max(N_GPUS, jax.local_device_count()) + +def sync_ddp(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def test_dataloader_torch(): + # Test config. + rng_seed = 1996 + data_dir = '/fast/najroldi/data/finewebedu' + split = 'train' + global_batch_size = 8 + dtype = torch.int32 + seq_len = 2048 + + local_batch_size = global_batch_size // N_GPUS + + workload = LmWorkload() + + data_rng = jax.random.PRNGKey(rng_seed) + + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + # batch = next(input_queue) + + print(f"RANK {RANK} of {N_GPUS}") + sync_ddp() + + # Start test. 
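The loop below repeatedly pulls batches from the input queue and checks their type, dtype, and shape. Under DDP this script is meant to run with one process per GPU; a typical launch would look something like the following (the launcher and flags are an assumption, not part of this patch — pytorch_setup() detects DDP from the LOCAL_RANK environment variable):

    torchrun --standalone --nproc_per_node=<num_gpus> \
        algoperf/workloads/lm/dev/test_build_input_queue_torch.py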
+ for _ in range(100): + + batch = next(input_queue) + assert type(batch) == dict + + assert 'inputs' in batch + assert 'targets' in batch + + assert type(batch['inputs']) == torch.Tensor + assert type(batch['targets']) == torch.Tensor + + assert batch['inputs'].dtype == dtype + assert batch['targets'].dtype == dtype + + assert batch['inputs'].shape == (local_batch_size, seq_len) + assert batch['targets'].shape == (local_batch_size, seq_len) + + sync_ddp() + + print(f"=== ALL TEST PASSED ===") + + +def main(): + profiler = PassThroughProfiler() + pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + test_dataloader_torch() + + +if __name__ == '__main__': + main() + diff --git a/algoperf/workloads/lm/test.py b/algoperf/workloads/lm/dev/test_jax.py similarity index 63% rename from algoperf/workloads/lm/test.py rename to algoperf/workloads/lm/dev/test_jax.py index 7e693d0af..4ba3de631 100644 --- a/algoperf/workloads/lm/test.py +++ b/algoperf/workloads/lm/dev/test_jax.py @@ -15,6 +15,7 @@ data_rng = jax.random.PRNGKey(0) split = 'train' data_dir = "/fast/najroldi/data/finewebedu" +seq_len = 2048 global_batch_size = 8 num_batches = 10 repeat_final_dataset = False @@ -34,4 +35,20 @@ num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) -next(input_queue) +batch = next(input_queue) +assert type(batch) == dict + +assert 'inputs' in batch +assert 'targets' in batch + +assert type(batch['inputs']) == np.ndarray +assert type(batch['targets']) == np.ndarray + +assert batch['inputs'].dtype == np.int64 +assert batch['targets'].dtype == np.int64 + +assert batch['inputs'].shape == (1, global_batch_size, seq_len) +assert batch['targets'].shape == (1, global_batch_size, seq_len) + +print(f"JAX devices = {jax.devices()}") +print("1") diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index f0024e4a6..e74490a16 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -25,7 +25,6 @@ def get_lm_dataset(data_rng: jax.random.PRNGKey, split: str, data_dir: str, - is_training: bool, vocab_size: int, global_batch_size: int, num_batches: Optional[int] = None, @@ -39,7 +38,7 @@ def get_lm_dataset(data_rng: jax.random.PRNGKey, is_training = split == "train" shuffle = split in ['train', 'eval_train'] - dataset.set_format("tensorflow") # tf.int64 + dataset.set_format("tensorflow") # tf.int64 # TODO: is this needed? 
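The generator below slices each stored example into model inputs and next-token targets. A small sketch of that slicing, assuming (as in the preprocessing step above) that every stored example holds max_seq_length = seq_len + 1 = 2049 token ids:

    # Dropping the last / first id yields aligned 2048-token inputs and targets.
    ids = list(range(2049))            # stand-in for example["input_ids"]
    inputs, targets = ids[:-1], ids[1:]
    assert len(inputs) == len(targets) == 2048
    assert inputs[1:] == targets[:-1]  # targets are inputs shifted by one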
def tf_generator(): """Generates data in a TensorFlow-friendly format.""" diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 4cdb42409..773f8c54c 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -14,7 +14,4 @@ class LmWorkload(BaseLmWorkload): - - @property - def eval_batch_size(self) -> int: - return 131_072 + pass diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 9ee21ccb6..0ff7884c7 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -1,7 +1,7 @@ """LM workload implemented in PyTorch.""" import contextlib -from typing import Any, Dict, Optional, Tuple +from typing import Dict, Iterator, Optional, Tuple from absl import logging import jax @@ -22,12 +22,6 @@ class LmWorkload(BaseLmWorkload): """LM PyTorch workload.""" - - def init_model_fn(): - pass - - def model_fn(): - pass def _build_input_queue(self, data_rng: jax.random.PRNGKey, @@ -35,8 +29,12 @@ def _build_input_queue(self, data_dir: str, global_batch_size: int, num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): + repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + not_train = split != 'train' per_device_batch_size = int(global_batch_size / N_GPUS) + + seq_len = 2048 # TODO: define it somewehere else + DTYPE = torch.int32 # TODO: decide between int32 and int64. # Only create and iterate over tf input pipeline in one Python process to # avoid creating too many threads. @@ -48,36 +46,50 @@ def _build_input_queue(self, global_batch_size=global_batch_size, num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) + weights = None + while True: + # Only iterate over tf input pipeline in one Python process to + # avoid creating too many threads. if RANK == 0: - batch = next(np_iter) - inputs = torch.as_tensor( - batch['inputs'], dtype=torch.float32, device=DEVICE) - targets = torch.as_tensor( - batch['targets'], dtype=torch.float32, device=DEVICE) + batch = next(np_iter) # pylint: disable=stop-iteration-return + inputs = torch.as_tensor(batch['inputs'], dtype=DTYPE, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + targets = torch.as_tensor(batch['targets'], dtype=DTYPE, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: - dist.broadcast(inputs, src=0) - inputs = inputs[0] # TODO: check - dist.broadcast(targets, src=0) - targets = targets[0] # TODO: check + if not_train: + # During eval, the batch size of the remainder might be different. + per_device_batch_size = torch.tensor(len(targets[0]), dtype=DTYPE, device=DEVICE) + dist.broadcast(per_device_batch_size, src=0) + # We don't broadcast the shard for RANK 0. + dist.broadcast(inputs[1:], src=0) + dist.broadcast(targets[1:], src=0) + + # RANK 0 extracts his shard. If not DDP, this just flattens. + inputs, targets = inputs[0], targets[0] + else: - batch = {} - inputs = torch.empty((N_GPUS, per_device_batch_size, 39), - dtype=torch.float32, - device=DEVICE) + # Receive batch from rank 0. + if not_train: + # During eval, the batch size of the remainder might be different. + per_device_batch_size = torch.empty((1,), dtype=DTYPE, device=DEVICE) + dist.broadcast(per_device_batch_size, src=0) + + # N_GPUS - 1 since we don't broadcast the shard for RANK 0. 
+ inputs = torch.empty((N_GPUS-1, per_device_batch_size, seq_len), dtype=DTYPE, device=DEVICE) + targets = torch.empty((N_GPUS-1, per_device_batch_size, seq_len), dtype=DTYPE, device=DEVICE) dist.broadcast(inputs, src=0) - inputs = inputs[RANK] - targets = torch.empty((N_GPUS, per_device_batch_size, 1), - dtype=torch.float32, - device=DEVICE) dist.broadcast(targets, src=0) - targets = targets[RANK] - + # RANK - 1 since we don't broadcast the shard for RANK 0. + inputs, targets = inputs[RANK-1], targets[RANK-1] + + if weights is None: + weights = torch.ones(per_device_batch_size, device=DEVICE) batch = { 'inputs': inputs, 'targets': targets, - # 'weights': weights, + 'weights': weights, } yield batch diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 63d2c707e..7b1313dd7 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -31,12 +31,10 @@ def _build_input_queue(self, global_batch_size: int, num_batches: Optional[int] = None, repeat_final_dataset: bool = False): - is_training = split == 'train' ds = input_pipeline.get_lm_dataset( data_rng, split, data_dir, - is_training=is_training, vocab_size=self._vocab_size, global_batch_size=global_batch_size, num_batches=num_batches, @@ -103,4 +101,7 @@ def validation_target_value(): pass def target_metric_name(): - pass \ No newline at end of file + pass + + def eval_batch_size(): + pass diff --git a/submission_runner.py b/submission_runner.py index a2521e77b..6fac50d99 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -234,7 +234,7 @@ def train_once( dropout_rate = hyperparameters.dropout_rate if hasattr(hyperparameters, 'aux_dropout_rate'): aux_dropout_rate = hyperparameters.aux_dropout_rate - model_params, model_state = workload.init_model_fn( + model_params, model_state = workload.init_model_fn( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ From e6194950fc524793906127f09b330a8329ad079f Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Mon, 17 Mar 2025 11:34:10 +0100 Subject: [PATCH 05/98] LM workload - fix torch tests --- .../lm/dev/test_build_input_queue_torch.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py index 86b1ca6b7..66205d091 100644 --- a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py +++ b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py @@ -41,30 +41,33 @@ def test_dataloader_torch(): data_dir=data_dir, global_batch_size=global_batch_size) - # batch = next(input_queue) - print(f"RANK {RANK} of {N_GPUS}") sync_ddp() # Start test. 
for _ in range(100): - + batch = next(input_queue) - assert type(batch) == dict + assert type(batch) == dict assert 'inputs' in batch assert 'targets' in batch - assert type(batch['inputs']) == torch.Tensor - assert type(batch['targets']) == torch.Tensor + inputs, targets = batch['inputs'], batch['targets'] + + assert type(inputs) == torch.Tensor + assert type(targets) == torch.Tensor + + assert inputs.device == DEVICE + assert targets.device == DEVICE + + assert inputs.dtype == dtype + assert targets.dtype == dtype - assert batch['inputs'].dtype == dtype - assert batch['targets'].dtype == dtype + assert inputs.shape == (local_batch_size, seq_len) + assert targets.shape == (local_batch_size, seq_len) - assert batch['inputs'].shape == (local_batch_size, seq_len) - assert batch['targets'].shape == (local_batch_size, seq_len) - - sync_ddp() + assert torch.equal(inputs[:,1:], targets[:,:-1]) print(f"=== ALL TEST PASSED ===") From d8e9c56738de817e561e79cffee638ab7197eaed Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 09:44:36 +0100 Subject: [PATCH 06/98] add LM tests, remove dev files --- algoperf/workloads/lm/dev/data_pytorch.py | 42 --------- algoperf/workloads/lm/dev/test_01.py | 92 ------------------- .../lm/dev/test_build_input_queue_torch.py | 83 ----------------- .../workloads/lm/dev/test_input_pipeline.py | 68 -------------- algoperf/workloads/lm/dev/test_jax.py | 54 ----------- 5 files changed, 339 deletions(-) delete mode 100644 algoperf/workloads/lm/dev/data_pytorch.py delete mode 100644 algoperf/workloads/lm/dev/test_01.py delete mode 100644 algoperf/workloads/lm/dev/test_build_input_queue_torch.py delete mode 100644 algoperf/workloads/lm/dev/test_input_pipeline.py delete mode 100644 algoperf/workloads/lm/dev/test_jax.py diff --git a/algoperf/workloads/lm/dev/data_pytorch.py b/algoperf/workloads/lm/dev/data_pytorch.py deleted file mode 100644 index d0081a75d..000000000 --- a/algoperf/workloads/lm/dev/data_pytorch.py +++ /dev/null @@ -1,42 +0,0 @@ - -import torch - -from datasets import Dataset, load_from_disk -from torch.utils.data import DataLoader - -trainset_path = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens/train" -vocab_size = 50280 -seq_len = 2048 -sampler = 'sequential' -sampler_seed = None -num_workers = 4 - -train_set = load_from_disk(trainset_path) # - -""" ->>> type(train_set) - - ->>> len(train_set) -7501407 - ->>> train_set[0] -{'input_ids': tensor([ 5166, 20, 1639, ..., 275, 253, 19992])} - ->>> type(train_set[0]['input_ids']) - - -# In PyTorch we do: -trainloader = DataLoader( - train_set, - sampler = ..., - batch_size = ..., - num_workers = ..., - pin_memory = ..., - ) - -# PyTorch’s DataLoader expects an iterable dataset, -# which means it calls __getitem__() and __len__() on train_set. 
- -""" - diff --git a/algoperf/workloads/lm/dev/test_01.py b/algoperf/workloads/lm/dev/test_01.py deleted file mode 100644 index 977fae11a..000000000 --- a/algoperf/workloads/lm/dev/test_01.py +++ /dev/null @@ -1,92 +0,0 @@ - -import os -import numpy as np -import tensorflow as tf -import torch - -from datasets import load_from_disk - -from absl import app -from absl import flags -from absl import logging - -from algoperf.profiler import PassThroughProfiler -from algoperf import random_utils as prng -from algoperf.pytorch_utils import pytorch_init -from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.lm.input_pipeline import get_lm_dataset - - -tf.config.set_visible_devices([], 'GPU') - -# Environment variables -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. -# disable only for deepspeech if it works fine for other workloads -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' -# (nico) -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' - -flags.DEFINE_enum( - 'framework', - None, - enum_values=['jax', 'pytorch'], - help='Whether to use Jax or Pytorch for the submission. Controls among ' - 'other things if the Jax or Numpy RNG library is used for RNG.') - -FLAGS = flags.FLAGS -USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() - - -DATASET_PATH = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" -BATCH_SIZE = 2 -RNG_SEED = 1996 # Fixed random seed for reproducibility - - -def main(_): - profiler = PassThroughProfiler() - if FLAGS.framework == 'pytorch': - pytorch_init(USE_PYTORCH_DDP, RANK, profiler) - - rng = prng.PRNGKey(RNG_SEED) - data_rng, _, _, _ = prng.split(rng, 4) - - print(f"data_rng = {data_rng}") - - # Load the dataset - ds = get_lm_dataset( - data_rng=data_rng, - split="train", - data_dir=DATASET_PATH, - is_training=True, - vocab_size=0, # Not needed but kept for function signature - global_batch_size=BATCH_SIZE, - ) - # Check if `ds` acts as a generator - if hasattr(ds, '__iter__'): - print("Dataset is an iterable/generator.") - - # Fetch first batch - try: - first_batch = next(iter(ds)) - print(f"Successfully retrieved first batch.") - except Exception as e: - print(f"Error retrieving first batch: {e}") - return - - # Print structure of a batch - print(f"First batch keys: {first_batch.keys()}") - print(f"First batch shapes:") - for key, value in first_batch.items(): - print(f" - {key}: {value.shape} (dtype: {value.dtype})") - - # Validate batch dimensions - assert "inputs" in first_batch and "targets" in first_batch, "Missing expected keys!" - assert first_batch["inputs"].shape[0] == BATCH_SIZE, "Batch size mismatch!" - assert first_batch["inputs"].shape == first_batch["targets"].shape, "Inputs and targets should have the same shape!" 
- - print(f"Dataset is correctly batched and structured.") - print(f"Test completed successfully.") - -if __name__ == '__main__': - flags.mark_flag_as_required('framework') - app.run(main) diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py b/algoperf/workloads/lm/dev/test_build_input_queue_torch.py deleted file mode 100644 index 66205d091..000000000 --- a/algoperf/workloads/lm/dev/test_build_input_queue_torch.py +++ /dev/null @@ -1,83 +0,0 @@ - -import jax -import torch -import pdb -import numpy as np - -from algoperf import random_utils as prng -from algoperf import spec -from algoperf.profiler import PassThroughProfiler -from algoperf.pytorch_utils import pytorch_init -from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.lm.lm_pytorch.workload import LmWorkload - -USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() - -n_gpus = max(N_GPUS, jax.local_device_count()) - -def sync_ddp(): - if torch.cuda.is_available(): - torch.cuda.synchronize() - - -def test_dataloader_torch(): - # Test config. - rng_seed = 1996 - data_dir = '/fast/najroldi/data/finewebedu' - split = 'train' - global_batch_size = 8 - dtype = torch.int32 - seq_len = 2048 - - local_batch_size = global_batch_size // N_GPUS - - workload = LmWorkload() - - data_rng = jax.random.PRNGKey(rng_seed) - - input_queue = workload._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size) - - print(f"RANK {RANK} of {N_GPUS}") - sync_ddp() - - # Start test. - for _ in range(100): - - batch = next(input_queue) - - assert type(batch) == dict - assert 'inputs' in batch - assert 'targets' in batch - - inputs, targets = batch['inputs'], batch['targets'] - - assert type(inputs) == torch.Tensor - assert type(targets) == torch.Tensor - - assert inputs.device == DEVICE - assert targets.device == DEVICE - - assert inputs.dtype == dtype - assert targets.dtype == dtype - - assert inputs.shape == (local_batch_size, seq_len) - assert targets.shape == (local_batch_size, seq_len) - - assert torch.equal(inputs[:,1:], targets[:,:-1]) - - print(f"=== ALL TEST PASSED ===") - - -def main(): - profiler = PassThroughProfiler() - pytorch_init(USE_PYTORCH_DDP, RANK, profiler) - test_dataloader_torch() - - -if __name__ == '__main__': - main() - diff --git a/algoperf/workloads/lm/dev/test_input_pipeline.py b/algoperf/workloads/lm/dev/test_input_pipeline.py deleted file mode 100644 index 47c11969f..000000000 --- a/algoperf/workloads/lm/dev/test_input_pipeline.py +++ /dev/null @@ -1,68 +0,0 @@ -import os -import tensorflow as tf -import torch -from datasets import load_from_disk - -from algoperf.workloads.lm.input_pipeline import get_lm_dataset - -DATASET_PATH = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens" -BATCH_SIZE = 2 -SEED = 42 # Fixed random seed for reproducibility - - -def test_tf_dataset(): - """Tests if get_lm_dataset correctly loads the HF dataset as a TensorFlow dataset.""" - - print(f"Loading dataset from: {DATASET_PATH}") - - tf_seed = SEED - - # Load the dataset - ds = get_lm_dataset( - data_rng=[tf_seed], # Ensure correct seed type - split="train", - data_dir=DATASET_PATH, - is_training=True, - vocab_size=0, # Not needed but kept for function signature - global_batch_size=BATCH_SIZE, - ) - - print("Testing TensorFlow Dataset Output...") - for batch in ds.take(2): # Take two batches to test - print("Inputs:", batch["inputs"].numpy()) # Convert to NumPy for inspection - print("Targets:", batch["targets"].numpy()) - -def 
test_pytorch_dataloader(): - """Tests if the TensorFlow dataset can be converted to PyTorch format correctly.""" - - # Use the same TensorFlow-compatible seed - tf_seed = tf.constant(SEED, dtype=tf.int64) - - # Load the dataset - ds = get_lm_dataset( - data_rng=[tf_seed], # Ensure correct seed type - split="train", - data_dir=DATASET_PATH, - is_training=True, - vocab_size=0, - global_batch_size=BATCH_SIZE, - ) - - def _input_queue_generator(): - """Generator that converts TF dataset batches to PyTorch tensors.""" - for batch in iter(ds): - batch = {k: torch.tensor(v.numpy()) for k, v in batch.items()} # Convert to PyTorch tensors - yield batch - - dataloader = _input_queue_generator() - - print("\nTesting PyTorch DataLoader Output...") - for _ in range(2): # Take two batches - batch = next(dataloader) - print("Inputs:", batch["inputs"]) - print("Targets:", batch["targets"]) - -# Run tests -if __name__ == "__main__": - test_tf_dataset() - test_pytorch_dataloader() \ No newline at end of file diff --git a/algoperf/workloads/lm/dev/test_jax.py b/algoperf/workloads/lm/dev/test_jax.py deleted file mode 100644 index 4ba3de631..000000000 --- a/algoperf/workloads/lm/dev/test_jax.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Test data pipaline in JAX and PyTorch. - -Instantiate a workload and loops over the input queue. -""" - -import jax -import numpy as np -import torch - -import algoperf.workloads.lm.lm_jax.workload as lm_jax -# import algoperf.workloads.lm.lm_pytorch.workload as lm_pytorch - - -data_rng = jax.random.PRNGKey(0) -split = 'train' -data_dir = "/fast/najroldi/data/finewebedu" -seq_len = 2048 -global_batch_size = 8 -num_batches = 10 -repeat_final_dataset = False - -# ------------------------------------------------------------------------------ -# JAX -# ------------------------------------------------------------------------------ - -# 1 GPU -workload = lm_jax.LmWorkload() - -input_queue = workload._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) - -batch = next(input_queue) -assert type(batch) == dict - -assert 'inputs' in batch -assert 'targets' in batch - -assert type(batch['inputs']) == np.ndarray -assert type(batch['targets']) == np.ndarray - -assert batch['inputs'].dtype == np.int64 -assert batch['targets'].dtype == np.int64 - -assert batch['inputs'].shape == (1, global_batch_size, seq_len) -assert batch['targets'].shape == (1, global_batch_size, seq_len) - -print(f"JAX devices = {jax.devices()}") -print("1") From 6b4ff12356c5f41b01ce703801b556a11079d354 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 09:44:58 +0100 Subject: [PATCH 07/98] add LM tests, remove dev files --- algoperf/workloads/lm/dev/data_pytorch.py | 42 ++++++ .../lm/dev/test_build_input_queue_jax.py | 127 ++++++++++++++++++ .../lm/tests/test_build_input_queue_torch.py | 87 ++++++++++++ 3 files changed, 256 insertions(+) create mode 100644 algoperf/workloads/lm/dev/data_pytorch.py create mode 100644 algoperf/workloads/lm/dev/test_build_input_queue_jax.py create mode 100644 algoperf/workloads/lm/tests/test_build_input_queue_torch.py diff --git a/algoperf/workloads/lm/dev/data_pytorch.py b/algoperf/workloads/lm/dev/data_pytorch.py new file mode 100644 index 000000000..d0081a75d --- /dev/null +++ b/algoperf/workloads/lm/dev/data_pytorch.py @@ -0,0 +1,42 @@ + +import torch + +from datasets import Dataset, load_from_disk +from torch.utils.data import DataLoader + 
+trainset_path = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens/train" +vocab_size = 50280 +seq_len = 2048 +sampler = 'sequential' +sampler_seed = None +num_workers = 4 + +train_set = load_from_disk(trainset_path) # + +""" +>>> type(train_set) + + +>>> len(train_set) +7501407 + +>>> train_set[0] +{'input_ids': tensor([ 5166, 20, 1639, ..., 275, 253, 19992])} + +>>> type(train_set[0]['input_ids']) + + +# In PyTorch we do: +trainloader = DataLoader( + train_set, + sampler = ..., + batch_size = ..., + num_workers = ..., + pin_memory = ..., + ) + +# PyTorch’s DataLoader expects an iterable dataset, +# which means it calls __getitem__() and __len__() on train_set. + +""" + diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_jax.py b/algoperf/workloads/lm/dev/test_build_input_queue_jax.py new file mode 100644 index 000000000..08354be74 --- /dev/null +++ b/algoperf/workloads/lm/dev/test_build_input_queue_jax.py @@ -0,0 +1,127 @@ + +# TODO: redo with pmap!! + +import os +import jax +import tensorflow as tf +import torch +import pdb +import numpy as np + +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.profiler import PassThroughProfiler +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.lm.lm_jax.workload import LmWorkload + +# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make +# it unavailable to JAX. +tf.config.set_visible_devices([], 'GPU') + +# Environment variables +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +# disable only for deepspeech if it works fine for other workloads +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' + + +N_GPUS = jax.local_device_count() + +print(f"jax.local_devices() = {jax.local_devices()}") +print(f"jax.local_device_count() = {jax.local_device_count()}") + +print(f"N_GPUS = {N_GPUS}") + +def check_batch(batch): + assert type(batch) == dict + assert 'inputs' in batch + assert 'targets' in batch + + inputs, targets = batch['inputs'], batch['targets'] + + assert type(inputs) == torch.Tensor + assert type(targets) == torch.Tensor + + assert inputs.device == DEVICE + assert targets.device == DEVICE + + assert inputs.dtype == dtype + assert targets.dtype == dtype + + assert inputs.shape == (local_batch_size, seq_len) + assert targets.shape == (local_batch_size, seq_len) + + assert torch.equal(inputs[:,1:], targets[:,:-1]) + + +def process_shard(batch): + inputs, targets = batch['inputs'], batch['targets'] + jax.debug.print("Processing on GPU with inputs: {shape}", shape=inputs.shape) + jax.debug.print("inputs {inputs}", inputs=inputs) + jax.debug.callback(check_batch, batch) + return inputs, targets + +# Apply process_batch across devices, sharding batch across devices +pmap_process = jax.pmap(process_shard, axis_name='batch') + + +def test_dataloader_jax(): + # Test config. 
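+  # Smoke-test settings: a global batch of 8 sequences of length 2048,
+  # split evenly across the visible local devices.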
+ rng_seed = 1996 + data_dir = '/fast/najroldi/data/finewebedu' + split = 'train' + global_batch_size = 8 + dtype = np.int32 + seq_len = 2048 + + local_batch_size = global_batch_size // N_GPUS + + workload = LmWorkload() + + data_rng = jax.random.PRNGKey(rng_seed) + + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + batch = next(input_queue) + + inputs, targets = batch['inputs'], batch['targets'] + print(f"Processing on GPU with inputs: {inputs.shape}") + + inputs, targets = pmap_process(batch) + print(f"Processing on GPU with inputs: {inputs.shape}") + print(f"Processing on GPU with inputs: {inputs}") + + # inputs, targets = batch['inputs'], batch['targets'] + # print(f"inputs.shape: {inputs.shape}") + # print(f"inputs[0]: {inputs[0]}") + # print(f"inputs[1]: {inputs[1]}") + + # for device_id in range(2): + # # Access the sharded data for each GPU + # print(inputs.shape) + # device_inputs = inputs[device_id] + # print(f" GPU {device_id} Inputs: {device_inputs.shape}") + + # @jax.pmap + # def process_batch(batch): + # inputs, targets = batch['inputs'], batch['targets'] + # print(f"inputs.shape: {inputs.shape}") + + # return inputs, targets + + # inputs, targets = batch['inputs'], batch['targets'] #process_batch(batch) + # print(f"inputs: {inputs[0]}") + + + +def main(): + test_dataloader_jax() + + +if __name__ == '__main__': + main() + diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py new file mode 100644 index 000000000..83a18ec15 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py @@ -0,0 +1,87 @@ + +import jax +import torch +import pdb +import numpy as np + +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.profiler import PassThroughProfiler +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.lm.lm_pytorch.workload import LmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() + + +def sync_ddp(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def test_dataloader_torch(): + # Test config. + rng_seed = 1996 + data_dir = '/fast/najroldi/data/finewebedu' + split = 'train' + global_batch_size = 8 + dtype = torch.int32 + seq_len = 2048 + + local_batch_size = global_batch_size // N_GPUS + + workload = LmWorkload() + + data_rng = jax.random.PRNGKey(rng_seed) + + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + print(f"RANK {RANK} of {N_GPUS}") + sync_ddp() + + # batch = next(input_queue) + # inputs, targets = batch['inputs'], batch['targets'] + # print(f"inputs.shape: {inputs.shape}") + # print(f"inputs: {inputs}") + + # Start test. 
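+  # Pull 100 batches and check structure, device placement, dtype and shape,
+  # and that targets equal the inputs shifted left by one token.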
+ for _ in range(100): + + batch = next(input_queue) + + assert type(batch) == dict + assert 'inputs' in batch + assert 'targets' in batch + + inputs, targets = batch['inputs'], batch['targets'] + + assert type(inputs) == torch.Tensor + assert type(targets) == torch.Tensor + + assert inputs.device == DEVICE + assert targets.device == DEVICE + + assert inputs.dtype == dtype + assert targets.dtype == dtype + + assert inputs.shape == (local_batch_size, seq_len) + assert targets.shape == (local_batch_size, seq_len) + + assert torch.equal(inputs[:,1:], targets[:,:-1]) + + print(f"=== ALL TEST PASSED ===") + + +def main(): + profiler = PassThroughProfiler() + pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + test_dataloader_torch() + + +if __name__ == '__main__': + main() + From 3c5c847eb1489fa11a65c98c0f3327bd3c23c088 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 09:45:41 +0100 Subject: [PATCH 08/98] Stop tracking .gitignore --- .gitignore | 28 ---------------------------- 1 file changed, 28 deletions(-) delete mode 100644 .gitignore diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 7d35f0ccc..000000000 --- a/.gitignore +++ /dev/null @@ -1,28 +0,0 @@ -__pycache__/* -__pycache__ -*egg-info -*eggs -.vscode/ -env/ -venv/ -workdir/ -makefile -*.out -*.sh -*.swp -*/data/ -*events.out.tfevents* -algoperf/workloads/librispeech_conformer/data_dir -algoperf/workloads/librispeech_conformer/work_dir -*.flac -*.npy -*.csv -*.vocab -wandb/ -*.txt -scoring/plots/ - -!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv -!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv - -algoperf/_version.py From 20d841b1932408bc905051dc2e188f3a43e0d749 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 09:47:55 +0100 Subject: [PATCH 09/98] Remove dev/ from repo, keep locally --- algoperf/workloads/lm/dev/data_pytorch.py | 42 ------ .../lm/dev/test_build_input_queue_jax.py | 127 ------------------ 2 files changed, 169 deletions(-) delete mode 100644 algoperf/workloads/lm/dev/data_pytorch.py delete mode 100644 algoperf/workloads/lm/dev/test_build_input_queue_jax.py diff --git a/algoperf/workloads/lm/dev/data_pytorch.py b/algoperf/workloads/lm/dev/data_pytorch.py deleted file mode 100644 index d0081a75d..000000000 --- a/algoperf/workloads/lm/dev/data_pytorch.py +++ /dev/null @@ -1,42 +0,0 @@ - -import torch - -from datasets import Dataset, load_from_disk -from torch.utils.data import DataLoader - -trainset_path = "/fast/najroldi/data/lm/slim_pajama/new_sp_15B_tokens/train" -vocab_size = 50280 -seq_len = 2048 -sampler = 'sequential' -sampler_seed = None -num_workers = 4 - -train_set = load_from_disk(trainset_path) # - -""" ->>> type(train_set) - - ->>> len(train_set) -7501407 - ->>> train_set[0] -{'input_ids': tensor([ 5166, 20, 1639, ..., 275, 253, 19992])} - ->>> type(train_set[0]['input_ids']) - - -# In PyTorch we do: -trainloader = DataLoader( - train_set, - sampler = ..., - batch_size = ..., - num_workers = ..., - pin_memory = ..., - ) - -# PyTorch’s DataLoader expects an iterable dataset, -# which means it calls __getitem__() and __len__() on train_set. - -""" - diff --git a/algoperf/workloads/lm/dev/test_build_input_queue_jax.py b/algoperf/workloads/lm/dev/test_build_input_queue_jax.py deleted file mode 100644 index 08354be74..000000000 --- a/algoperf/workloads/lm/dev/test_build_input_queue_jax.py +++ /dev/null @@ -1,127 +0,0 @@ - -# TODO: redo with pmap!! 
- -import os -import jax -import tensorflow as tf -import torch -import pdb -import numpy as np - -from algoperf import random_utils as prng -from algoperf import spec -from algoperf.profiler import PassThroughProfiler -from algoperf.pytorch_utils import pytorch_init -from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.lm.lm_jax.workload import LmWorkload - -# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make -# it unavailable to JAX. -tf.config.set_visible_devices([], 'GPU') - -# Environment variables -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. -# disable only for deepspeech if it works fine for other workloads -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' - - -N_GPUS = jax.local_device_count() - -print(f"jax.local_devices() = {jax.local_devices()}") -print(f"jax.local_device_count() = {jax.local_device_count()}") - -print(f"N_GPUS = {N_GPUS}") - -def check_batch(batch): - assert type(batch) == dict - assert 'inputs' in batch - assert 'targets' in batch - - inputs, targets = batch['inputs'], batch['targets'] - - assert type(inputs) == torch.Tensor - assert type(targets) == torch.Tensor - - assert inputs.device == DEVICE - assert targets.device == DEVICE - - assert inputs.dtype == dtype - assert targets.dtype == dtype - - assert inputs.shape == (local_batch_size, seq_len) - assert targets.shape == (local_batch_size, seq_len) - - assert torch.equal(inputs[:,1:], targets[:,:-1]) - - -def process_shard(batch): - inputs, targets = batch['inputs'], batch['targets'] - jax.debug.print("Processing on GPU with inputs: {shape}", shape=inputs.shape) - jax.debug.print("inputs {inputs}", inputs=inputs) - jax.debug.callback(check_batch, batch) - return inputs, targets - -# Apply process_batch across devices, sharding batch across devices -pmap_process = jax.pmap(process_shard, axis_name='batch') - - -def test_dataloader_jax(): - # Test config. 
- rng_seed = 1996 - data_dir = '/fast/najroldi/data/finewebedu' - split = 'train' - global_batch_size = 8 - dtype = np.int32 - seq_len = 2048 - - local_batch_size = global_batch_size // N_GPUS - - workload = LmWorkload() - - data_rng = jax.random.PRNGKey(rng_seed) - - input_queue = workload._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size) - - batch = next(input_queue) - - inputs, targets = batch['inputs'], batch['targets'] - print(f"Processing on GPU with inputs: {inputs.shape}") - - inputs, targets = pmap_process(batch) - print(f"Processing on GPU with inputs: {inputs.shape}") - print(f"Processing on GPU with inputs: {inputs}") - - # inputs, targets = batch['inputs'], batch['targets'] - # print(f"inputs.shape: {inputs.shape}") - # print(f"inputs[0]: {inputs[0]}") - # print(f"inputs[1]: {inputs[1]}") - - # for device_id in range(2): - # # Access the sharded data for each GPU - # print(inputs.shape) - # device_inputs = inputs[device_id] - # print(f" GPU {device_id} Inputs: {device_inputs.shape}") - - # @jax.pmap - # def process_batch(batch): - # inputs, targets = batch['inputs'], batch['targets'] - # print(f"inputs.shape: {inputs.shape}") - - # return inputs, targets - - # inputs, targets = batch['inputs'], batch['targets'] #process_batch(batch) - # print(f"inputs: {inputs[0]}") - - - -def main(): - test_dataloader_jax() - - -if __name__ == '__main__': - main() - From f3ba0593d955c657b6da8a07eede425509dbc6b9 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 10:00:44 +0100 Subject: [PATCH 10/98] fix comments --- algoperf/workloads/lm/input_pipeline.py | 2 +- datasets/dataset_setup.py | 27 +++++++------------------ 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index e74490a16..bae1f5e45 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -38,7 +38,7 @@ def get_lm_dataset(data_rng: jax.random.PRNGKey, is_training = split == "train" shuffle = split in ['train', 'eval_train'] - dataset.set_format("tensorflow") # tf.int64 # TODO: is this needed? + dataset.set_format("tensorflow") # tf.int64 # TODO (nico): is this needed? 
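+  # set_format("tensorflow") makes the HF dataset return tf tensors from
+  # __getitem__; the generator below only slices them, so the call may be
+  # redundant (hence the TODO).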
def tf_generator(): """Generates data in a TensorFlow-friendly format.""" diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index aab793832..8299133c1 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -711,8 +711,6 @@ def download_wmt(data_dir): def download_finewebedu(data_dir, tmp_dir): """Download FineWebEdu-10B.""" - # data_dir = "/fast/najroldi/data" - tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser("~/.cache/huggingface/datasets") data_dir = os.path.join(data_dir, 'finewebedu') @@ -726,7 +724,7 @@ def download_finewebedu(data_dir, tmp_dir): 'HuggingFaceFW/fineweb-edu', name='sample-10BT', split='train', - # cache_dir=tmp_dir + cache_dir=tmp_dir ) ds = ds.shuffle(seed=1996) # shuffle so that multiproc has shards of similar size @@ -756,19 +754,11 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: tokenizer.model_max_length = seq_len tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - from datasets import load_from_disk - tokenized_dataset = load_from_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) # Concat in chunks of max_seq_len - # TODO: this might take to much memory - # TODO: bug fix: Python's shutil.rmtree tried to delete a .nfs* file, but it was still in use (OSError: [Errno 16] Device or resource busy - # TODO: bug fix: I am losing tokens in the concat-chunk: num_tokens before split: 9_944_182_212 - # (1) loss happening because of batched=True: potentially losing the last tokens in the last batch of the 1024 batched examples - # NOTE: the current approach leads to data loss at batch boundaries, - # but concatenation *cannot* happen if batched=False, - # because concat_chunck relies on processing multiple examples at once. - # (2) loss happening because of nproc>1: potentially losing the last tokens in each process - # TODO: this does not allow to later change the seq_len... 
not a problem in AlgoPerf, but bad in plainLM + # TODO (nico): this might take to much memory + # TODO (nico): bug fix: Python's shutil.rmtree tried to delete .nfs file, but it was still in use (OSError: [Errno 16] Device or resource busy + # TODO (nico): make it sequential or increase batch_size in the map_setup def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} @@ -780,15 +770,12 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: for k, t in concatenated_examples.items() } return result - lm_dataset = tokenized_dataset.map( - concat_chunck,\ - **map_setup - ) - n_tokens = len(lm_dataset) * max_seq_length # 9_944_182_212 + lm_dataset = tokenized_dataset.map(concat_chunck, **map_setup) + n_tokens = len(lm_dataset) * max_seq_length logging.info(f"Number of tokens in dataset: {n_tokens:_}") # Split dataset into training and validation sets - # TODO: avoid (single doc) contamination between train and val + # TODO (nico): avoid (single doc) contamination, by splitting before concatenation VAL_TOKENS = 10_000_000 val_samples = VAL_TOKENS // max_seq_length + 1 val_dataset = lm_dataset.select(range(val_samples)) From 381451f04a34e4a78a5256f92e1e7c092e0eadeb Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 10:46:45 +0100 Subject: [PATCH 11/98] add class specifications --- algoperf/workloads/lm/lm_jax/workload.py | 36 +++- algoperf/workloads/lm/lm_pytorch/workload.py | 26 ++- algoperf/workloads/lm/workload.py | 201 +++++++++++++------ datasets/dataset_setup.py | 6 +- 4 files changed, 199 insertions(+), 70 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 773f8c54c..84377b4bc 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -1,17 +1,47 @@ """LM workload implemented in Jax.""" import functools -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Iterator, Optional, Tuple +from absl import logging from flax import jax_utils +from flax import linen as nn +from flax.training import common_utils import jax import jax.numpy as jnp import numpy as np +import optax from algoperf import param_utils +from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload - class LmWorkload(BaseLmWorkload): - pass + """LM JAX workload.""" + + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + """aux_dropout_rate is used as attention_dropout_rate.""" + pass + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + pass + + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + pass diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 0ff7884c7..404dc2532 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ 
b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -23,6 +23,24 @@ class LmWorkload(BaseLmWorkload): """LM PyTorch workload.""" + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + """aux_dropout_rate is used as attention_dropout_rate.""" + pass + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + pass + def _build_input_queue(self, data_rng: jax.random.PRNGKey, split: str, @@ -93,6 +111,10 @@ def _build_input_queue(self, } yield batch - - def eval_step(): + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" pass diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 7b1313dd7..e36d54625 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -5,6 +5,9 @@ import os from typing import Any, Dict, Optional, Tuple +from absl import flags +import torch.distributed as dist + import jax import numpy as np import torch @@ -12,17 +15,98 @@ from algoperf import spec from algoperf.workloads.lm import input_pipeline +FLAGS = flags.FLAGS + USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ class BaseLmWorkload(spec.Workload): - """A LM workload.""" + """LM workload.""" _vocab_size: int = 32000 def __init__(self) -> None: super().__init__() - self._tokenizer = None + + @property + def target_metric_name(self) -> str: + """The name of the target metric (useful for scoring/processing code).""" + return 'ppl' + + def has_reached_validation_target(self, eval_result: float) -> bool: + return eval_result['validation/ppl'] > self.validation_target_value + + @property + def validation_target_value(self) -> float: + pass + + def has_reached_test_target(self, eval_result: float) -> bool: + return eval_result['test/ppl'] > self.test_target_value + + @property + def test_target_value(self) -> float: + pass + + @property + def loss_type(self) -> spec.LossType: + return spec.LossType.SOFTMAX_CROSS_ENTROPY + + @property + def num_train_examples(self) -> int: + pass + + @property + def num_eval_train_examples(self) -> int: + pass + + @property + def num_validation_examples(self) -> int: + pass + + @property + def num_test_examples(self) -> int: + pass + + @property + def eval_batch_size(self) -> int: + pass + + @property + def train_mean(self): + raise NotImplementedError + + @property + def train_stddev(self): + raise NotImplementedError + + @property + def max_allowed_runtime_sec(self) -> int: + pass + + @property + def eval_period_time_sec(self) -> int: + pass + + @property + def step_hint(self) -> int: + """Approx. 
steps the baseline can do in the allowed runtime budget.""" + pass + + @property + def pre_ln(self) -> bool: + return True + + @property + def attention_temp(self) -> float: + return 1.0 + + @property + def activation(self) -> str: + return 'silu' + + @property + def glu(self) -> bool: + return True def _build_input_queue(self, data_rng: jax.random.PRNGKey, @@ -43,65 +127,58 @@ def _build_input_queue(self, for batch in iter(ds): yield batch - def _eval_model_on_split(): - pass - - def eval_period_time_sec(): - pass - - def has_reached_test_target(): - pass - - def has_reached_validation_target(): - pass - - def init_model_fn(): - pass - - def is_output_params(): - pass - - def loss_fn(): - pass - - def loss_type(): - pass - - def max_allowed_runtime_sec(): - pass - - def model_fn(): - pass - - def num_eval_train_examples(): - pass - - def num_test_examples(): - pass - - def num_train_examples(): - pass - - def num_validation_examples(): - pass - - def step_hint(): - pass - - def test_target_value(): - pass - - def train_mean(): - pass - - def train_stddev(): - pass - - def validation_target_value(): - pass - - def target_metric_name(): - pass + @abc.abstractmethod + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + + def _eval_model_on_split(self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0) -> Dict[str, float]: + """Run a full evaluation of the model.""" + num_batches = int(math.ceil(num_examples / global_batch_size)) + if split not in self._eval_iters: + # These iterators will repeat indefinitely. + self._eval_iters[split] = self._build_input_queue( + rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset=True) + + for _ in range(num_batches): + eval_batch = next(self._eval_iters[split]) + loss += self._eval_batch(params, eval_batch) + if USE_PYTORCH_DDP: + dist.all_reduce(loss) + mean_loss = loss.item() / num_examples + return {'loss': mean_loss} - def eval_batch_size(): + # Does NOT apply regularization, which is left to the submitter to do in + # `update_params`. + def loss_fn( + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + """Evaluate the (masked) loss function at (label_batch, logits_batch). + + Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + (not synced across devices). 
+ """ pass + + + diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 8299133c1..fb8701f4d 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -711,11 +711,11 @@ def download_wmt(data_dir): def download_finewebedu(data_dir, tmp_dir): """Download FineWebEdu-10B.""" - tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser("~/.cache/huggingface/datasets") data_dir = os.path.join(data_dir, 'finewebedu') - - _maybe_mkdir(tmp_dir) + tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None \ + else os.path.expanduser("~/.cache/huggingface/datasets") _maybe_mkdir(data_dir) + _maybe_mkdir(tmp_dir) # Use local disk instead of NFS for temp storage os.environ["TMPDIR"] = tmp_dir From f111d2e8baada7af619504a87974fa78f3e34d55 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 11:29:37 +0100 Subject: [PATCH 12/98] add workload LM info --- algoperf/workloads/workloads.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py index 4712f4e25..6b99a25a6 100644 --- a/algoperf/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -114,6 +114,7 @@ 'workload_path': 'librispeech_deepspeech/librispeech', 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', }, + 'lm': {'workload_path': 'lm/lm', 'workload_class_name': 'LmWorkload'}, 'mnist': { 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload' }, @@ -150,6 +151,7 @@ 'imagenet_vit', 'librispeech_conformer', 'librispeech_deepspeech', + 'lm', 'ogbg', 'wmt' ] From 808d398ee2cf78e92cea29e2d0696eb6ce592929 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 11:32:48 +0100 Subject: [PATCH 13/98] restore data_utils.py tree map --- algoperf/data_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 068c21c03..37d1bd20f 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -65,7 +65,7 @@ def _prepare(x): # Assumes that `global_batch_size % local_device_count == 0`. 
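    # e.g. a (64, 2048) token batch on 8 local devices becomes (8, 8, 2048),
    # i.e. one leading shard axis per device.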
return x.reshape((local_device_count, -1, *x.shape[1:])) - return jax.tree_util.tree_map(_prepare, batch) + return jax.tree.map(_prepare, batch) def pad(tensor: np.ndarray, From 35f8f8942cb993628f1b20c3d29346e4d7b40e95 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 14:38:41 +0100 Subject: [PATCH 14/98] fixed NFS bug --- datasets/dataset_setup.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index fb8701f4d..a68da3ff5 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -708,26 +708,28 @@ def download_wmt(data_dir): ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7) -def download_finewebedu(data_dir, tmp_dir): +def download_finewebedu(data_dir, tmp_dir=None): """Download FineWebEdu-10B.""" data_dir = os.path.join(data_dir, 'finewebedu') - tmp_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None \ - else os.path.expanduser("~/.cache/huggingface/datasets") + tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' + cache_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser('~/.cache/huggingface/datasets') + _maybe_mkdir(data_dir) _maybe_mkdir(tmp_dir) + _maybe_mkdir(cache_dir) - # Use local disk instead of NFS for temp storage os.environ["TMPDIR"] = tmp_dir ds = hf_datasets.load_dataset( 'HuggingFaceFW/fineweb-edu', name='sample-10BT', split='train', - cache_dir=tmp_dir + cache_dir=cache_dir ) - ds = ds.shuffle(seed=1996) # shuffle so that multiproc has shards of similar size + # Shuffle so that multiproc has shards of similar size. + ds = ds.shuffle(seed=1996) seq_len = 2048 max_seq_length = seq_len+1 @@ -754,11 +756,8 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: tokenizer.model_max_length = seq_len tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - + # Concat in chunks of max_seq_len - # TODO (nico): this might take to much memory - # TODO (nico): bug fix: Python's shutil.rmtree tried to delete .nfs file, but it was still in use (OSError: [Errno 16] Device or resource busy - # TODO (nico): make it sequential or increase batch_size in the map_setup def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} From cbb6ee67c6eb4828b574987d45fde508e5f1db67 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Tue, 18 Mar 2025 15:02:27 +0100 Subject: [PATCH 15/98] train/val split before concat --- datasets/dataset_setup.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index a68da3ff5..5e27211e8 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -756,8 +756,21 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: tokenizer.model_max_length = seq_len tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - - # Concat in chunks of max_seq_len + + # Find how many entries to take from dataset to have VAL_TOKENS in validation set. + VAL_TOKENS = 10_000_000 + tokens_accumulated, num_examples_for_val = 0, 0 + for example in tokenized_dataset: + tokens_accumulated += len(example['input_ids']) + num_examples_for_val += 1 + if tokens_accumulated >= VAL_TOKENS: + break + # Split in train and valid. 
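+  # Splitting on whole documents *before* concatenation keeps any single
+  # document from leaking across the train/validation boundary.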
+ val_dataset = tokenized_dataset.select(range(num_examples_for_val)) + train_dataset = tokenized_dataset.select(range(num_examples_for_val, len(tokenized_dataset))) + + # Concat in chunks of max_seq_len. + # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk. def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} @@ -769,18 +782,11 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: for k, t in concatenated_examples.items() } return result - lm_dataset = tokenized_dataset.map(concat_chunck, **map_setup) - n_tokens = len(lm_dataset) * max_seq_length - logging.info(f"Number of tokens in dataset: {n_tokens:_}") - - # Split dataset into training and validation sets - # TODO (nico): avoid (single doc) contamination, by splitting before concatenation - VAL_TOKENS = 10_000_000 - val_samples = VAL_TOKENS // max_seq_length + 1 - val_dataset = lm_dataset.select(range(val_samples)) - train_dataset = lm_dataset.select(range(val_samples, len(lm_dataset))) - logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length :_}") - logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length :_}") + # Concat text in validation and train sets. + val_dataset = val_dataset.map(concat_chunck, **map_setup) + train_dataset = train_dataset.map(concat_chunck, **map_setup) + logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") + logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}") # Save datasets train_dataset.save_to_disk(os.path.join(data_dir, f"train")) From 868987c2fd72ced8107048e20de44a7e303074e8 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 19 Mar 2025 09:41:05 +0100 Subject: [PATCH 16/98] renamed datasets to avoid conflict with HF --- {datasets => datasets_algoperf}/README.md | 0 .../dataset_setup.py | 17 ++++++++++------- .../librispeech_preprocess.py | 2 +- .../librispeech_tokenizer.py | 0 4 files changed, 11 insertions(+), 8 deletions(-) rename {datasets => datasets_algoperf}/README.md (100%) rename {datasets => datasets_algoperf}/dataset_setup.py (98%) rename {datasets => datasets_algoperf}/librispeech_preprocess.py (98%) rename {datasets => datasets_algoperf}/librispeech_tokenizer.py (100%) diff --git a/datasets/README.md b/datasets_algoperf/README.md similarity index 100% rename from datasets/README.md rename to datasets_algoperf/README.md diff --git a/datasets/dataset_setup.py b/datasets_algoperf/dataset_setup.py similarity index 98% rename from datasets/dataset_setup.py rename to datasets_algoperf/dataset_setup.py index 5e27211e8..21811e729 100644 --- a/datasets/dataset_setup.py +++ b/datasets_algoperf/dataset_setup.py @@ -56,7 +56,7 @@ Example command: -python3 datasets/dataset_setup.py \ +python3 datasets_algoperf/dataset_setup.py \ --data_dir=~/data \ --temp_dir=/tmp/mlcommons_data --imagenet \ @@ -126,15 +126,15 @@ flags.DEFINE_boolean('fastmri', False, 'If --all=false, whether or not to download FastMRI.') +flags.DEFINE_boolean('finewebedu', + False, + 'If --all=false, whether or not to download FineWebEdu.') flags.DEFINE_boolean('imagenet', False, 'If --all=false, whether or not to download Imagenet.') flags.DEFINE_boolean('librispeech', False, 'If --all=false, whether or not to download LibriSpeech.') 
-flags.DEFINE_boolean('finewebedu', - False, - 'If --all=false, whether or not to download FineWebEdu.') flags.DEFINE_boolean('mnist', False, 'If --all=false, whether or not to download MNIST.') @@ -727,6 +727,8 @@ def download_finewebedu(data_dir, tmp_dir=None): split='train', cache_dir=cache_dir ) + # TODO (nico): maybe save intermediate dataset to avoid re-downloading + # and allow re-chunking with different seq_len? # Shuffle so that multiproc has shards of similar size. ds = ds.shuffle(seed=1996) @@ -747,6 +749,7 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: return_attention_mask=False ) tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + logging.info(f"Tokenizing...") tokenized_dataset = ds.map( tokenize, remove_columns=['text', 'id', 'dump', 'url', 'file_path', 'language', @@ -783,6 +786,7 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: } return result # Concat text in validation and train sets. + logging.info(f"Concatenating and chunking...") val_dataset = val_dataset.map(concat_chunck, **map_setup) train_dataset = train_dataset.map(concat_chunck, **map_setup) logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") @@ -876,9 +880,8 @@ def main(_): download_wmt(data_dir) if FLAGS.all or FLAGS.finewebedu: - if not FLAGS.skip_download: - logging.info('Downloading FineWebEdu-10B...') - download_finewebedu(data_dir) + logging.info('Downloading FineWebEdu-10B...') + download_finewebedu(data_dir, tmp_dir) # pylint: enable=logging-format-interpolation diff --git a/datasets/librispeech_preprocess.py b/datasets_algoperf/librispeech_preprocess.py similarity index 98% rename from datasets/librispeech_preprocess.py rename to datasets_algoperf/librispeech_preprocess.py index a8c5cae1d..cd291e5b3 100644 --- a/datasets/librispeech_preprocess.py +++ b/datasets_algoperf/librispeech_preprocess.py @@ -15,7 +15,7 @@ from pydub import AudioSegment import tensorflow as tf -from datasets import librispeech_tokenizer +from datasets_algoperf import librispeech_tokenizer gfile = tf.io.gfile copy = tf.io.gfile.copy diff --git a/datasets/librispeech_tokenizer.py b/datasets_algoperf/librispeech_tokenizer.py similarity index 100% rename from datasets/librispeech_tokenizer.py rename to datasets_algoperf/librispeech_tokenizer.py From dd59dedc97f99e994221775b1e980d845bfb908c Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Wed, 19 Mar 2025 09:55:11 +0100 Subject: [PATCH 17/98] renamed datasets to dataset --- {datasets_algoperf => dataset}/README.md | 0 {datasets_algoperf => dataset}/dataset_setup.py | 6 +++--- {datasets_algoperf => dataset}/librispeech_preprocess.py | 2 +- {datasets_algoperf => dataset}/librispeech_tokenizer.py | 0 4 files changed, 4 insertions(+), 4 deletions(-) rename {datasets_algoperf => dataset}/README.md (100%) rename {datasets_algoperf => dataset}/dataset_setup.py (99%) rename {datasets_algoperf => dataset}/librispeech_preprocess.py (98%) rename {datasets_algoperf => dataset}/librispeech_tokenizer.py (100%) diff --git a/datasets_algoperf/README.md b/dataset/README.md similarity index 100% rename from datasets_algoperf/README.md rename to dataset/README.md diff --git a/datasets_algoperf/dataset_setup.py b/dataset/dataset_setup.py similarity index 99% rename from datasets_algoperf/dataset_setup.py rename to dataset/dataset_setup.py index 21811e729..0c7f33de6 100644 --- a/datasets_algoperf/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -56,7 +56,7 @@ Example command: 
-python3 datasets_algoperf/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir=~/data \ --temp_dir=/tmp/mlcommons_data --imagenet \ @@ -74,8 +74,8 @@ from algoperf.workloads.wmt import tokenizer from algoperf.workloads.wmt.input_pipeline import \ normalize_feature_names -from datasets import librispeech_preprocess -from datasets import librispeech_tokenizer +from dataset import librispeech_preprocess +from dataset import librispeech_tokenizer import datasets as hf_datasets from transformers import AutoTokenizer diff --git a/datasets_algoperf/librispeech_preprocess.py b/dataset/librispeech_preprocess.py similarity index 98% rename from datasets_algoperf/librispeech_preprocess.py rename to dataset/librispeech_preprocess.py index cd291e5b3..b96881332 100644 --- a/datasets_algoperf/librispeech_preprocess.py +++ b/dataset/librispeech_preprocess.py @@ -15,7 +15,7 @@ from pydub import AudioSegment import tensorflow as tf -from datasets_algoperf import librispeech_tokenizer +from dataset import librispeech_tokenizer gfile = tf.io.gfile copy = tf.io.gfile.copy diff --git a/datasets_algoperf/librispeech_tokenizer.py b/dataset/librispeech_tokenizer.py similarity index 100% rename from datasets_algoperf/librispeech_tokenizer.py rename to dataset/librispeech_tokenizer.py From 496b9c31f0bdd9a50e18a6907146969fd98e73cf Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 10:52:54 +0100 Subject: [PATCH 18/98] fix style --- .gitignore | 28 +++++++++++ algoperf/workloads/lm/input_pipeline.py | 50 ++++++++----------- algoperf/workloads/lm/lm_jax/workload.py | 15 +----- algoperf/workloads/lm/lm_pytorch/workload.py | 46 +++++++++-------- .../lm/tests/test_build_input_queue_torch.py | 18 +++---- algoperf/workloads/lm/workload.py | 12 ++--- 6 files changed, 86 insertions(+), 83 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..916a29ff4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,28 @@ +__pycache__/* +__pycache__ +*egg-info +*eggs +.vscode/ +env/ +venv/ +workdir/ +makefile +*.out +*.sh +*.swp +*/data/ +*events.out.tfevents* +algoperf/workloads/librispeech_conformer/data_dir +algoperf/workloads/librispeech_conformer/work_dir +*.flac +*.npy +*.csv +*.vocab +wandb/ +*.txt +scoring/plots/ + +!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv +!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv + +algoperf/_version.py \ No newline at end of file diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index bae1f5e45..53fe79276 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -1,24 +1,22 @@ """Input pipeline for a LM dataset.""" import functools import os +from typing import Optional -from datasets import Dataset, load_from_disk -from typing import Dict, List, Optional, Union - +from datasets import load_from_disk import jax -import numpy as np import tensorflow as tf -import tensorflow_datasets as tfds from algoperf import data_utils from algoperf.pytorch_utils import pytorch_setup RANK = pytorch_setup()[1] # Avoid multithreading in all processes but the first (rank 0). -# This ensures that only the primary process (RANK == 0) uses TensorFlow's +# This ensures that only the primary process (RANK == 0) uses TensorFlow's # automatic optimization (AUTOTUNE), while other processes disable it (None). 
-# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine the optimal -# number of elements to prefetch or parallelize for dataset operations, improving performance. +# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine +# the optimal number of elements to prefetch or parallelize for dataset +# operations, improving performance. AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None @@ -44,25 +42,24 @@ def tf_generator(): """Generates data in a TensorFlow-friendly format.""" for example in dataset: yield { - "inputs": example["input_ids"][:-1], - "targets": example["input_ids"][1:], + "inputs": example["input_ids"][:-1], + "targets": example["input_ids"][1:], } # Create a TensorFlow dataset ds = tf.data.Dataset.from_generator( - tf_generator, - output_signature={ - "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64), - "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64), - } - ) + tf_generator, + output_signature={ + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64), + }) # Avoid creating too many threads when using PyTorch DDP. # Limits TensorFlow's threading for non-primary processes (RANK != 0) - if RANK != 0: + if RANK != 0: options = tf.data.Options() - options.threading.private_threadpool_size = 1 # restrict dataset operations to a single thread - ds = ds.with_options(options) # apply threading restrictions + options.threading.private_threadpool_size = 1 + ds = ds.with_options(options) if shuffle: ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) @@ -70,10 +67,7 @@ def tf_generator(): if is_training: ds = ds.repeat() - # Batch the dataset, ensuring the last batch is dropped if not full during training - # i.e. it groups consecutive elements into fixed-size chunks. - # Instead of processing individual elements, the dataset yields batches (tensors with multiple elements), - # improving efficiency and parallelism in training + # Batch the dataset, grouping consecutive elements into fixed-size chunks. 
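+  # drop_remainder is only enabled for training, so evaluation keeps the final
+  # partial batch instead of discarding those examples.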
ds = ds.batch(global_batch_size, drop_remainder=is_training) ds = ds.prefetch(AUTOTUNE) @@ -83,9 +77,9 @@ def tf_generator(): # Shard the dataset across multiple GPUs/TPUs if necessary ds = map( - functools.partial( - data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + functools.partial( + data_utils.shard_and_maybe_pad_np, + global_batch_size=global_batch_size), + ds) - return ds \ No newline at end of file + return ds diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 84377b4bc..64d538dda 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -1,22 +1,11 @@ """LM workload implemented in Jax.""" -import functools -from typing import Any, Dict, Iterator, Optional, Tuple +from typing import Dict, Optional, Tuple -from absl import logging -from flax import jax_utils -from flax import linen as nn -from flax.training import common_utils -import jax -import jax.numpy as jnp -import numpy as np -import optax - -from algoperf import param_utils -from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload + class LmWorkload(BaseLmWorkload): """LM JAX workload.""" diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 404dc2532..e57d26390 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -3,16 +3,10 @@ import contextlib from typing import Dict, Iterator, Optional, Tuple -from absl import logging import jax -import tensorflow as tf import torch import torch.distributed as dist -from torch.nn import DataParallel as DP -import torch.nn.functional as F -from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload @@ -41,16 +35,17 @@ def model_fn( update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: pass - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: not_train = split != 'train' per_device_batch_size = int(global_batch_size / N_GPUS) - + seq_len = 2048 # TODO: define it somewehere else DTYPE = torch.int32 # TODO: decide between int32 and int64. @@ -65,20 +60,25 @@ def _build_input_queue(self, num_batches=num_batches, repeat_final_dataset=repeat_final_dataset) weights = None - + while True: # Only iterate over tf input pipeline in one Python process to # avoid creating too many threads. 
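      # Rank 0 is the only process that reads from the tf.data pipeline; each
      # step it materializes the full global batch and broadcasts the other
      # ranks' shards below.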
if RANK == 0: batch = next(np_iter) # pylint: disable=stop-iteration-return - inputs = torch.as_tensor(batch['inputs'], dtype=DTYPE, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) - targets = torch.as_tensor(batch['targets'], dtype=DTYPE, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + inputs = torch.as_tensor( + batch['inputs'], dtype=DTYPE, + device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + targets = torch.as_tensor( + batch['targets'], dtype=DTYPE, + device=DEVICE) # (N_GPUS, global_batch_size, seq_len) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: if not_train: # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.tensor(len(targets[0]), dtype=DTYPE, device=DEVICE) + per_device_batch_size = torch.tensor( + len(targets[0]), dtype=DTYPE, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) # We don't broadcast the shard for RANK 0. dist.broadcast(inputs[1:], src=0) @@ -95,12 +95,16 @@ def _build_input_queue(self, dist.broadcast(per_device_batch_size, src=0) # N_GPUS - 1 since we don't broadcast the shard for RANK 0. - inputs = torch.empty((N_GPUS-1, per_device_batch_size, seq_len), dtype=DTYPE, device=DEVICE) - targets = torch.empty((N_GPUS-1, per_device_batch_size, seq_len), dtype=DTYPE, device=DEVICE) + inputs = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), + dtype=DTYPE, + device=DEVICE) + targets = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), + dtype=DTYPE, + device=DEVICE) dist.broadcast(inputs, src=0) dist.broadcast(targets, src=0) # RANK - 1 since we don't broadcast the shard for RANK 0. - inputs, targets = inputs[RANK-1], targets[RANK-1] + inputs, targets = inputs[RANK - 1], targets[RANK - 1] if weights is None: weights = torch.ones(per_device_batch_size, device=DEVICE) diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py index 83a18ec15..639e71491 100644 --- a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py +++ b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py @@ -1,11 +1,6 @@ - import jax import torch -import pdb -import numpy as np - -from algoperf import random_utils as prng -from algoperf import spec + from algoperf.profiler import PassThroughProfiler from algoperf.pytorch_utils import pytorch_init from algoperf.pytorch_utils import pytorch_setup @@ -29,20 +24,20 @@ def test_dataloader_torch(): seq_len = 2048 local_batch_size = global_batch_size // N_GPUS - + workload = LmWorkload() data_rng = jax.random.PRNGKey(rng_seed) - + input_queue = workload._build_input_queue( data_rng=data_rng, split=split, data_dir=data_dir, global_batch_size=global_batch_size) - + print(f"RANK {RANK} of {N_GPUS}") sync_ddp() - + # batch = next(input_queue) # inputs, targets = batch['inputs'], batch['targets'] # print(f"inputs.shape: {inputs.shape}") @@ -71,7 +66,7 @@ def test_dataloader_torch(): assert inputs.shape == (local_batch_size, seq_len) assert targets.shape == (local_batch_size, seq_len) - assert torch.equal(inputs[:,1:], targets[:,:-1]) + assert torch.equal(inputs[:, 1:], targets[:, :-1]) print(f"=== ALL TEST PASSED ===") @@ -84,4 +79,3 @@ def main(): if __name__ == '__main__': main() - diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index e36d54625..3d04be3c5 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -3,14 +3,11 @@ import abc import math import os -from typing import Any, Dict, 
Optional, Tuple +from typing import Dict, Optional from absl import flags -import torch.distributed as dist - import jax -import numpy as np -import torch +import torch.distributed as dist from algoperf import spec from algoperf.workloads.lm import input_pipeline @@ -155,7 +152,7 @@ def _eval_model_on_split(self, global_batch_size, num_batches, repeat_final_dataset=True) - + for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) loss += self._eval_batch(params, eval_batch) @@ -179,6 +176,3 @@ def loss_fn( (not synced across devices). """ pass - - - From 50989eb6a8a54c43225a4243f770a4419d431a81 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 10:57:06 +0100 Subject: [PATCH 19/98] fix formatting --- algoperf/workloads/lm/lm_pytorch/workload.py | 1 - submission_runner.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index e57d26390..be6c94c46 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -1,6 +1,5 @@ """LM workload implemented in PyTorch.""" -import contextlib from typing import Dict, Iterator, Optional, Tuple import jax diff --git a/submission_runner.py b/submission_runner.py index d7df006bb..f8a66452d 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -234,7 +234,7 @@ def train_once( dropout_rate = hyperparameters.dropout_rate if hasattr(hyperparameters, 'aux_dropout_rate'): aux_dropout_rate = hyperparameters.aux_dropout_rate - model_params, model_state = workload.init_model_fn( + model_params, model_state = workload.init_model_fn( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ @@ -384,8 +384,8 @@ def train_once( train_step_end_time - train_state['last_step_end_time']) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) + >= workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). if prepare_for_eval is not None: From 5af0fdc1437d924e2e162de5100e66782d01a7e5 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 11:02:22 +0100 Subject: [PATCH 20/98] fix style --- algoperf/workloads/lm/lm_pytorch/workload.py | 16 ++++++++-------- algoperf/workloads/lm/workload.py | 1 + 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index be6c94c46..606f16ad7 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -45,8 +45,8 @@ def _build_input_queue( not_train = split != 'train' per_device_batch_size = int(global_batch_size / N_GPUS) - seq_len = 2048 # TODO: define it somewehere else - DTYPE = torch.int32 # TODO: decide between int32 and int64. + seq_len = self._seq_len # TODO: define it somewehere else? + dtype = torch.int32 # TODO: decide between int32 and int64. # Only create and iterate over tf input pipeline in one Python process to # avoid creating too many threads. 
@@ -66,10 +66,10 @@ def _build_input_queue( if RANK == 0: batch = next(np_iter) # pylint: disable=stop-iteration-return inputs = torch.as_tensor( - batch['inputs'], dtype=DTYPE, + batch['inputs'], dtype=dtype, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) targets = torch.as_tensor( - batch['targets'], dtype=DTYPE, + batch['targets'], dtype=dtype, device=DEVICE) # (N_GPUS, global_batch_size, seq_len) # Send batch to other devices when using DDP. @@ -77,7 +77,7 @@ def _build_input_queue( if not_train: # During eval, the batch size of the remainder might be different. per_device_batch_size = torch.tensor( - len(targets[0]), dtype=DTYPE, device=DEVICE) + len(targets[0]), dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) # We don't broadcast the shard for RANK 0. dist.broadcast(inputs[1:], src=0) @@ -90,15 +90,15 @@ def _build_input_queue( # Receive batch from rank 0. if not_train: # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.empty((1,), dtype=DTYPE, device=DEVICE) + per_device_batch_size = torch.empty((1,), dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) # N_GPUS - 1 since we don't broadcast the shard for RANK 0. inputs = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), - dtype=DTYPE, + dtype=dtype, device=DEVICE) targets = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), - dtype=DTYPE, + dtype=dtype, device=DEVICE) dist.broadcast(inputs, src=0) dist.broadcast(targets, src=0) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 3d04be3c5..aa6d188b3 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -21,6 +21,7 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" _vocab_size: int = 32000 + _seq_len: int = 2048 def __init__(self) -> None: super().__init__() From 26830999b92d26c729171cae141ee7abb3409463 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 11:32:47 +0100 Subject: [PATCH 21/98] fix style --- algoperf/workloads/lm/workload.py | 2 +- dataset/dataset_setup.py | 91 +++++++++++++++++++------------ 2 files changed, 56 insertions(+), 37 deletions(-) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index aa6d188b3..4eb6c74a5 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -24,7 +24,7 @@ class BaseLmWorkload(spec.Workload): _seq_len: int = 2048 def __init__(self) -> None: - super().__init__() + pass @property def target_metric_name(self) -> str: diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 0c7f33de6..8f0b09ab7 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -80,7 +80,6 @@ import datasets as hf_datasets from transformers import AutoTokenizer -import math import functools import itertools import os @@ -713,7 +712,9 @@ def download_finewebedu(data_dir, tmp_dir=None): data_dir = os.path.join(data_dir, 'finewebedu') tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' - cache_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser('~/.cache/huggingface/datasets') + cache_dir = os.path.join(tmp_dir, + 'lm') if tmp_dir is not None else os.path.expanduser( + '~/.cache/huggingface/datasets') _maybe_mkdir(data_dir) _maybe_mkdir(tmp_dir) @@ -722,75 +723,93 @@ def download_finewebedu(data_dir, tmp_dir=None): os.environ["TMPDIR"] = tmp_dir ds = hf_datasets.load_dataset( - 'HuggingFaceFW/fineweb-edu', - name='sample-10BT', - 
split='train', - cache_dir=cache_dir - ) - # TODO (nico): maybe save intermediate dataset to avoid re-downloading + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + cache_dir=cache_dir) + # TODO (nico): maybe save intermediate dataset to avoid re-downloading # and allow re-chunking with different seq_len? # Shuffle so that multiproc has shards of similar size. ds = ds.shuffle(seed=1996) seq_len = 2048 - max_seq_length = seq_len+1 + max_seq_length = seq_len + 1 map_setup = dict(batched=True, batch_size=1024, num_proc=8) # Tokenize - tokenizer = AutoTokenizer.from_pretrained('gpt2') - logging.info(f"Vocab size of tokenizer = {len(tokenizer)}") + lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}") + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq + add_eos = lambda seq: (seq + lm_tokenizer.eos_token) if seq else seq add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] - return tokenizer( - add_eos_batched(examples["text"]), - return_special_tokens_mask=False, - return_attention_mask=False - ) - tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + return lm_tokenizer( + add_eos_batched(examples["text"]), + return_special_tokens_mask=False, + return_attention_mask=False) + + lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization logging.info(f"Tokenizing...") tokenized_dataset = ds.map( - tokenize, - remove_columns=['text', 'id', 'dump', 'url', 'file_path', 'language', - 'language_score', 'token_count', 'score', 'int_score'], - **map_setup - ) - tokenizer.model_max_length = seq_len - + tokenize, + remove_columns=[ + 'text', + 'id', + 'dump', + 'url', + 'file_path', + 'language', + 'language_score', + 'token_count', + 'score', + 'int_score' + ], + **map_setup) + lm_tokenizer.model_max_length = seq_len + tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - # Find how many entries to take from dataset to have VAL_TOKENS in validation set. - VAL_TOKENS = 10_000_000 + # Find how many entries to take from dataset to have val_tokens in validation set. + val_tokens = 10_000_000 # TODO: decide this value. tokens_accumulated, num_examples_for_val = 0, 0 for example in tokenized_dataset: tokens_accumulated += len(example['input_ids']) num_examples_for_val += 1 - if tokens_accumulated >= VAL_TOKENS: - break + if tokens_accumulated >= val_tokens: + break # Split in train and valid. val_dataset = tokenized_dataset.select(range(num_examples_for_val)) - train_dataset = tokenized_dataset.select(range(num_examples_for_val, len(tokenized_dataset))) + train_dataset = tokenized_dataset.select( + range(num_examples_for_val, len(tokenized_dataset))) # Concat in chunks of max_seq_len. # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk. 
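+  # e.g. with max_seq_length = 2049, a mapped batch holding 5_000 tokens
+  # yields two chunks (4_098 tokens) and drops the remaining 902.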
def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: """Concatenate text and generate chunks of max_seq_length""" - concatenated_examples = {k: list(itertools.chain(*examples[k])) for k in examples.keys()} + concatenated_examples = { + k: list(itertools.chain(*examples[k])) for k in examples.keys() + } total_length = len(concatenated_examples[list(examples.keys())[0]]) if total_length >= max_seq_length: - total_length = (total_length // max_seq_length) * max_seq_length + total_length = (total_length // max_seq_length) * max_seq_length result = { - k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] - for k, t in concatenated_examples.items() + k: [ + t[i:i + max_seq_length] + for i in range(0, total_length, max_seq_length) + ] for k, t in concatenated_examples.items() } return result + # Concat text in validation and train sets. logging.info(f"Concatenating and chunking...") val_dataset = val_dataset.map(concat_chunck, **map_setup) train_dataset = train_dataset.map(concat_chunck, **map_setup) - logging.info(f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") - logging.info(f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}") + logging.info( + f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") + logging.info( + f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}" + ) # Save datasets train_dataset.save_to_disk(os.path.join(data_dir, f"train")) From 6b7ee29684ee9bf1f9564032f65c09373212c4a4 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 11:36:27 +0100 Subject: [PATCH 22/98] fix yapf --- submission_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index f8a66452d..468a04c7c 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -384,8 +384,8 @@ def train_once( train_step_end_time - train_state['last_step_end_time']) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) - >= workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) >= + workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). 
if prepare_for_eval is not None: From 46b645b2ac4a4f4b93fe4ee6324b07f412fb81b3 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 20 Mar 2025 11:38:40 +0100 Subject: [PATCH 23/98] fix style --- dataset/dataset_setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 8f0b09ab7..6587f1439 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -797,7 +797,8 @@ def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: k: [ t[i:i + max_seq_length] for i in range(0, total_length, max_seq_length) - ] for k, t in concatenated_examples.items() + ] for k, + t in concatenated_examples.items() } return result From b3ae6474be93f07c578f885bae484773b8a65515 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 27 Mar 2025 15:56:25 +0000 Subject: [PATCH 24/98] HF datasets pipeline --- algoperf/workloads/lm/input_pipeline.py | 75 ++++++++++- .../lm/tests/test_hf_input_pipeline.py | 116 ++++++++++++++++++ 2 files changed, 190 insertions(+), 1 deletion(-) create mode 100644 algoperf/workloads/lm/tests/test_hf_input_pipeline.py diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 53fe79276..ea4cb9d63 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -3,12 +3,17 @@ import os from typing import Optional -from datasets import load_from_disk import jax +import jax.numpy as jnp import tensorflow as tf +import torch +import torch.nn.functional as F +from transformers import GPT2Tokenizer from algoperf import data_utils from algoperf.pytorch_utils import pytorch_setup +from datasets import load_dataset +from datasets import load_from_disk RANK = pytorch_setup()[1] # Avoid multithreading in all processes but the first (rank 0). @@ -20,6 +25,74 @@ AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None +def get_hf_dataloader(cache_dir: str, + data_rng: jax.random.PRNGKey, + batch_size: int = 8, + seq_len: int = 32, + framework: str = "torch", + split="train"): + """ + Create a data loader from HuggingFace's FineWeb dataset. 
+ + Args: + cache_dir: Directory to cache the dataset + batch_size: Number of sequences per batch + seq_len: Length of each sequence + framework: Either "torch" or "jax" to specify output tensor type + split: Dataset split to load + """ + # Initialize tokenizer and get vocab size + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + vocab_size = tokenizer.vocab_size + # Load the FineWeb dataset in streaming mode + fw = load_dataset( + "HuggingFaceFW/fineweb-edu", + name="sample-10BT", + split=split, + streaming=True, + cache_dir=cache_dir) + fw = fw.batch(batch_size=batch_size, drop_last_batch=True) + if split in ['train', 'eval_train']: + fw = fw.shuffle(seed=int(data_rng[-1])) + + def _tokenize(x): + """Tokenize and pad text to seq_len+1 tokens.""" + if framework == "torch": + tokens = tokenizer(x, return_tensors="pt")["input_ids"].squeeze() + pad_length = seq_len - tokens.shape[0] + if pad_length > 0: + tokens = F.pad(tokens, pad_length, value=tokenizer.pad_token_id) + elif framework == "jax": + tokens = tokenizer(x, return_tensors="jax")["input_ids"].squeeze() + pad_length = seq_len - tokens.shape[0] + if pad_length > 0: + tokens = jnp.pad( + tokens, + pad_length, + mode="constant", + constant_values=tokenizer.pad_token_id) + return tokens[:seq_len + 1] + + def batch_iterator(): + for doc in fw: + if framework == "torch": + token_ids = torch.stack([_tokenize(x) for x in doc['text']]) + # Take first seq_len+1 tokens and convert to one-hot + tokens = F.one_hot(token_ids, num_classes=vocab_size).float() + # Split into input/target + inputs, targets = tokens[:, :-1, :], tokens[:, 1:, :] + inputs, targets = inputs.to("cuda"), targets.to("cuda") + elif framework == "jax": + token_ids = jnp.stack([_tokenize(x) for x in doc['text']]) + tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) + inputs, targets = tokens[:, :-1], tokens[:, 1:] + devices = jax.devices("gpu") + inputs, targets = jax.device_put(inputs), jax.device_put(targets) + yield inputs, targets + + return batch_iterator() + + def get_lm_dataset(data_rng: jax.random.PRNGKey, split: str, data_dir: str, diff --git a/algoperf/workloads/lm/tests/test_hf_input_pipeline.py b/algoperf/workloads/lm/tests/test_hf_input_pipeline.py new file mode 100644 index 000000000..36bab0d02 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_hf_input_pipeline.py @@ -0,0 +1,116 @@ +"""Tests for LM HuggingFace input pipeline.""" +import os + +import jax +import jax.numpy as jnp +import torch +from transformers import GPT2Tokenizer + +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader + + +def main(): + # Setup test environment + cache_dir = "/home/ak4605/data" + if not os.path.exists(cache_dir): + raise FileNotFoundError(f"Cache directory {cache_dir} not found") + + data_rng = jax.random.PRNGKey(42) + tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + vocab_size = tokenizer.vocab_size + + print("Running JAX output shapes and types test...") + batch_size = 8 + seq_len = 32 + loader = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=data_rng) + inputs, targets = next(loader) + assert inputs.shape == (batch_size, seq_len, vocab_size), \ + f"Expected inputs shape {(batch_size, seq_len, vocab_size)}, got {inputs.shape}" + assert targets.shape == (batch_size, seq_len, vocab_size), \ + f"Expected targets shape {(batch_size, seq_len, vocab_size)}, got {targets.shape}" + assert inputs.dtype == jnp.float32, \ + f"Expected 
inputs dtype float32, got {inputs.dtype}" + assert targets.dtype == jnp.float32, \ + f"Expected targets dtype float32, got {targets.dtype}" + assert jnp.all(jnp.sum(inputs, axis=-1) == 1), "Inputs should be one-hot encoded" + assert jnp.all(jnp.sum(targets, axis=-1) == 1), "Targets should be one-hot encoded" + print("✓ JAX test passed") + + print("\nRunning Torch output shapes and types test...") + loader = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="torch", + split="train", + data_rng=data_rng) + inputs, targets = next(loader) + assert inputs.shape == (batch_size, seq_len, vocab_size), \ + f"Expected inputs shape {(batch_size, seq_len, vocab_size)}, got {inputs.shape}" + assert targets.shape == (batch_size, seq_len, vocab_size), \ + f"Expected targets shape {(batch_size, seq_len, vocab_size)}, got {targets.shape}" + assert inputs.dtype == torch.float32, \ + f"Expected inputs dtype float32, got {inputs.dtype}" + assert targets.dtype == torch.float32, \ + f"Expected targets dtype float32, got {targets.dtype}" + assert torch.all(torch.sum(inputs, dim=-1) == 1), "Inputs should be one-hot encoded" + assert torch.all(torch.sum(targets, dim=-1) == 1), "Targets should be one-hot encoded" + print("✓ Torch test passed") + + print("\nTesting consistent batching with same seed...") + loader1 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=jax.random.PRNGKey(42)) + batch1 = next(loader1) + + loader2 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="train", + data_rng=jax.random.PRNGKey(42)) + batch2 = next(loader2) + + assert jnp.array_equal(batch1[0], batch2[0]), "Input batches should be identical with same seed" + assert jnp.array_equal(batch1[1], batch2[1]), "Target batches should be identical with same seed" + print("✓ Consistent batching test passed") + + print("\nTesting eval split doesn't shuffle...") + loader1 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="eval", + data_rng=jax.random.PRNGKey(42)) + batch1 = next(loader1) + + loader2 = get_hf_dataloader( + cache_dir=cache_dir, + batch_size=batch_size, + seq_len=seq_len, + framework="jax", + split="eval", + data_rng=jax.random.PRNGKey(999)) + batch2 = next(loader2) + + assert jnp.array_equal(batch1[0], batch2[0]), "Eval inputs should be identical regardless of seed" + assert jnp.array_equal(batch1[1], batch2[1]), "Eval targets should be identical regardless of seed" + print("✓ Eval no shuffling test passed") + + print("\nAll tests passed successfully!") + + +if __name__ == "__main__": + main() From f095d4b167dabc0e1aeb925b871f32f427fc22c8 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 27 Mar 2025 17:03:05 +0000 Subject: [PATCH 25/98] Testing with linear model --- algoperf/workloads/lm/input_pipeline.py | 1 - algoperf/workloads/lm/lm_jax/models.py | 18 +++++++++ algoperf/workloads/lm/lm_jax/workload.py | 26 +++++++++++-- algoperf/workloads/lm/lm_pytorch/models.py | 18 +++++++++ algoperf/workloads/lm/lm_pytorch/workload.py | 32 +++++++++++++-- .../workloads/lm/tests/test_linear_model.py | 39 +++++++++++++++++++ algoperf/workloads/lm/workload.py | 17 ++------ 7 files changed, 129 insertions(+), 22 deletions(-) create mode 100644 algoperf/workloads/lm/lm_jax/models.py create mode 100644 algoperf/workloads/lm/lm_pytorch/models.py create mode 100644 
algoperf/workloads/lm/tests/test_linear_model.py diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index ea4cb9d63..cc658501e 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -86,7 +86,6 @@ def batch_iterator(): token_ids = jnp.stack([_tokenize(x) for x in doc['text']]) tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) inputs, targets = tokens[:, :-1], tokens[:, 1:] - devices = jax.devices("gpu") inputs, targets = jax.device_put(inputs), jax.device_put(targets) yield inputs, targets diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py new file mode 100644 index 000000000..edfc102fa --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/models.py @@ -0,0 +1,18 @@ +from flax import linen as nn +import jax.numpy as jnp + +class LinearModel(nn.Module): + vocab_size: int + + @nn.compact + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + x = nn.Dense( + 512, + kernel_init=nn.initializers.normal(0.02), + bias_init=nn.initializers.zeros + )(inputs) + return nn.Dense( + self.vocab_size, + kernel_init=nn.initializers.normal(0.02), + bias_init=nn.initializers.zeros + )(x) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 64d538dda..30b0c7867 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -2,8 +2,12 @@ from typing import Dict, Optional, Tuple +import jax.numpy as jnp +from flax import jax_utils +from algoperf import param_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.lm.lm_jax.models import LinearModel class LmWorkload(BaseLmWorkload): @@ -14,18 +18,32 @@ def init_model_fn( rng: spec.RandomState, dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is used as attention_dropout_rate.""" - pass + + model = LinearModel(vocab_size=self._vocab_size) + input_shape = (1, self._seq_len, self._vocab_size) + variables = model.init(rng, jnp.ones(input_shape, jnp.float32)) + model_state, params = variables.pop('params') + + self._param_shapes = param_utils.jax_param_shapes(params) + self._param_types = param_utils.jax_param_types(self._param_shapes) + model_state = jax_utils.replicate(model_state) + params = jax_utils.replicate(params) + + return params, model_state def model_fn( self, params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - pass + + del mode, rng, update_batch_norm # Not used for linear model + inputs = batch['inputs'] + logits = self._model.apply({'params': params, **model_state}, inputs) + return logits, model_state def _eval_batch(self, params: spec.ParameterContainer, diff --git a/algoperf/workloads/lm/lm_pytorch/models.py b/algoperf/workloads/lm/lm_pytorch/models.py new file mode 100644 index 000000000..545763924 --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/models.py @@ -0,0 +1,18 @@ +import torch +import torch.nn as nn + +class LinearLayer(nn.Module): + def __init__(self, vocab_size: int): + super().__init__() + self.bottleneck = nn.Linear(vocab_size, 512) + self.output = nn.Linear(512, vocab_size) + self.reset_parameters() + + def 
reset_parameters(self): + nn.init.normal_(self.bottleneck.weight, std=0.02) + nn.init.zeros_(self.bottleneck.bias) + nn.init.normal_(self.output.weight, std=0.02) + nn.init.zeros_(self.output.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.output(self.bottleneck(x)) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 606f16ad7..3395aa08f 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -5,10 +5,13 @@ import jax import torch import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from algoperf import param_utils from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.lm.lm_pytorch.models import LinearLayer USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -21,18 +24,39 @@ def init_model_fn( rng: spec.RandomState, dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is used as attention_dropout_rate.""" - pass + + if hasattr(self, '_model'): + self._model.reset_parameters() + return self._model, None + + torch.manual_seed(rng[0]) + self._model = LinearLayer(vocab_size=self._vocab_size) + self._param_shapes = param_utils.pytorch_param_shapes(self._model) + self._param_types = param_utils.pytorch_param_types(self._param_shapes) + self._model.to(DEVICE) + + if N_GPUS > 1: + if USE_PYTORCH_DDP: + self._model = DDP(self._model, device_ids=[RANK], output_device=RANK) + else: + self._model = torch.nn.DataParallel(self._model) + + return self._model, None def model_fn( self, params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - pass + + del model_state, rng, update_batch_norm # Not used for linear model + model = params + inputs = batch['inputs'].float() # Convert one-hot to float + logits = model(inputs) + return logits, None def _build_input_queue( self, diff --git a/algoperf/workloads/lm/tests/test_linear_model.py b/algoperf/workloads/lm/tests/test_linear_model.py new file mode 100644 index 000000000..31cd1d577 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_linear_model.py @@ -0,0 +1,39 @@ +import jax +import jax.numpy as jnp +import torch + +TEST_SEQ_LEN = 512 + +def test_pytorch_linear(): + from algoperf.workloads.lm.lm_pytorch.models import LinearLayer + vocab_size = 32000 + model = LinearLayer(vocab_size) + + batch_size = 8 + seq_len = TEST_SEQ_LEN + inputs = torch.randn(batch_size, seq_len, vocab_size) + outputs = model(inputs) + + assert outputs.shape == (batch_size, seq_len, vocab_size) + assert not torch.isnan(outputs).any() + +def test_jax_linear(): + from algoperf.workloads.lm.lm_jax.models import LinearModel + + vocab_size = 32000 + seq_len = TEST_SEQ_LEN + batch_size = 8 + model = LinearModel(vocab_size) + rng = jax.random.PRNGKey(0) + params = model.init(rng, jnp.ones((1, seq_len, vocab_size))) + + inputs = jax.random.normal(rng, (batch_size, seq_len, vocab_size)) + outputs = model.apply(params, inputs) + + assert outputs.shape == (batch_size, seq_len, vocab_size) + assert not jnp.isnan(outputs).any() + +if __name__ == '__main__': + test_pytorch_linear() + 
test_jax_linear() + print("All tests passed!") diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 4eb6c74a5..a06b17fdc 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -20,8 +20,8 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" - _vocab_size: int = 32000 - _seq_len: int = 2048 + _vocab_size: int = 50257 + _seq_len: int = 512 def __init__(self) -> None: pass @@ -106,6 +106,7 @@ def activation(self) -> str: def glu(self) -> bool: return True + @abc.abstractmethod def _build_input_queue(self, data_rng: jax.random.PRNGKey, split: str, @@ -113,17 +114,7 @@ def _build_input_queue(self, global_batch_size: int, num_batches: Optional[int] = None, repeat_final_dataset: bool = False): - ds = input_pipeline.get_lm_dataset( - data_rng, - split, - data_dir, - vocab_size=self._vocab_size, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) - - for batch in iter(ds): - yield batch + """Build an input queue for the given split.""" @abc.abstractmethod def _eval_batch(self, From 0c22f3df420968cf820cbcc826f84a61751f95f5 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 3 Apr 2025 12:28:05 -0400 Subject: [PATCH 26/98] lm workload with linear model --- .../workloads/cifar/cifar_jax/workload.py | 11 -- algoperf/workloads/lm/input_pipeline.py | 2 +- algoperf/workloads/lm/lm_jax/models.py | 5 +- algoperf/workloads/lm/lm_jax/workload.py | 82 +++++++++-- algoperf/workloads/lm/lm_pytorch/workload.py | 129 ++++++++++-------- algoperf/workloads/lm/workload.py | 59 ++++---- pyproject.toml | 3 +- .../nesterov/jax/submission.py | 8 +- submission_runner.py | 6 +- 9 files changed, 187 insertions(+), 118 deletions(-) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index f827fac87..fd990eeaa 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -71,17 +71,6 @@ def _build_input_queue( cache, repeat_final_dataset) - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync the batch statistics across replicas.""" - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics - # and we average them. 
- avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state - def init_model_fn( self, rng: spec.RandomState, diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index cc658501e..440de64c1 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -87,7 +87,7 @@ def batch_iterator(): tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) inputs, targets = tokens[:, :-1], tokens[:, 1:] inputs, targets = jax.device_put(inputs), jax.device_put(targets) - yield inputs, targets + yield {'inputs': inputs, 'targets': targets} return batch_iterator() diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py index edfc102fa..72ee5bd83 100644 --- a/algoperf/workloads/lm/lm_jax/models.py +++ b/algoperf/workloads/lm/lm_jax/models.py @@ -7,12 +7,13 @@ class LinearModel(nn.Module): @nn.compact def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: x = nn.Dense( - 512, + 10, kernel_init=nn.initializers.normal(0.02), bias_init=nn.initializers.zeros )(inputs) return nn.Dense( self.vocab_size, kernel_init=nn.initializers.normal(0.02), - bias_init=nn.initializers.zeros + bias_init=nn.initializers.zeros, + name="output" )(x) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 30b0c7867..7cb50302f 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -2,33 +2,57 @@ from typing import Dict, Optional, Tuple +import jax import jax.numpy as jnp +import optax from flax import jax_utils from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload from algoperf.workloads.lm.lm_jax.models import LinearModel +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader class LmWorkload(BaseLmWorkload): """LM JAX workload.""" + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + """Build an input queue using HuggingFace FineWeb dataset.""" + del num_batches + del repeat_final_dataset + loader = get_hf_dataloader( + cache_dir=data_dir, + data_rng=data_rng, + batch_size=global_batch_size, + seq_len=self._seq_len, + framework="jax", + split=split) + return loader + def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - model = LinearModel(vocab_size=self._vocab_size) + self._model = LinearModel(vocab_size=self._vocab_size) input_shape = (1, self._seq_len, self._vocab_size) - variables = model.init(rng, jnp.ones(input_shape, jnp.float32)) - model_state, params = variables.pop('params') - + params_rng, init_rng = jax.random.split(rng) + print(params_rng) + # variables = model.init(init_rng, jnp.ones(input_shape, jnp.float32)) + variables = jax.jit(self._model.init)({'params': params_rng}, jnp.ones(input_shape, jnp.float32)) + params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) - + params = sharding_utils.shard_replicated(params) + model_state = None 
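    # Note: model_state stays None here because the linear model has no mutable
    # collections (no batch statistics or similar) to carry between steps.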
return params, model_state def model_fn( @@ -40,15 +64,51 @@ def model_fn( rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del mode, rng, update_batch_norm # Not used for linear model + del mode, rng, update_batch_norm, model_state inputs = batch['inputs'] - logits = self._model.apply({'params': params, **model_state}, inputs) - return logits, model_state + logits = self._model.apply({'params': params}, inputs) + return logits, None + + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling in JAX.""" + vocab_size = logits_batch.shape[-1] + + if len(label_batch.shape) == len(logits_batch.shape): + # One-hot labels + loss = -jnp.sum(label_batch * jax.nn.log_softmax(logits_batch, axis=-1)) + else: + # Dense labels + loss = -jax.nn.log_softmax(logits_batch)[jnp.arange(label_batch.shape[0]), label_batch] + + if mask_batch is not None: + loss = loss * mask_batch + + n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] + return { + 'summed': loss.sum(), + 'n_valid_examples': n_valid, + 'per_example': loss + } + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return param_name.contains('output') + def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" - pass + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + loss = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1)) + return loss diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 3395aa08f..0d0281690 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -66,68 +66,38 @@ def _build_input_queue( global_batch_size: int, num_batches: Optional[int] = None, repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: - not_train = split != 'train' - per_device_batch_size = int(global_batch_size / N_GPUS) - - seq_len = self._seq_len # TODO: define it somewehere else? - dtype = torch.int32 # TODO: decide between int32 and int64. - - # Only create and iterate over tf input pipeline in one Python process to - # avoid creating too many threads. - if RANK == 0: - np_iter = super()._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) + """Build an input queue for the given split.""" + from algoperf.workloads.lm.input_pipeline import get_hf_dataloader + + loader = get_hf_dataloader( + cache_dir=data_dir, + data_rng=data_rng, + batch_size=global_batch_size, + seq_len=self._seq_len, + framework="torch", + split=split) + seq_len = self._seq_len weights = None - - while True: - # Only iterate over tf input pipeline in one Python process to - # avoid creating too many threads. 
- if RANK == 0: - batch = next(np_iter) # pylint: disable=stop-iteration-return - inputs = torch.as_tensor( - batch['inputs'], dtype=dtype, - device=DEVICE) # (N_GPUS, global_batch_size, seq_len) - targets = torch.as_tensor( - batch['targets'], dtype=dtype, - device=DEVICE) # (N_GPUS, global_batch_size, seq_len) - - # Send batch to other devices when using DDP. - if USE_PYTORCH_DDP: - if not_train: - # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.tensor( - len(targets[0]), dtype=dtype, device=DEVICE) - dist.broadcast(per_device_batch_size, src=0) - # We don't broadcast the shard for RANK 0. - dist.broadcast(inputs[1:], src=0) - dist.broadcast(targets[1:], src=0) - - # RANK 0 extracts his shard. If not DDP, this just flattens. - inputs, targets = inputs[0], targets[0] - - else: - # Receive batch from rank 0. - if not_train: - # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.empty((1,), dtype=dtype, device=DEVICE) + + dtype = torch.long + is_train = split == 'train' + + for batch in loader: + inputs, targets = batch + + if USE_PYTORCH_DDP: + if not is_train: + # During eval, the batch size of the remainder might be different + per_device_batch_size = torch.tensor( + len(targets[0]), dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) - - # N_GPUS - 1 since we don't broadcast the shard for RANK 0. - inputs = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), - dtype=dtype, - device=DEVICE) - targets = torch.empty((N_GPUS - 1, per_device_batch_size, seq_len), - dtype=dtype, - device=DEVICE) + + # Broadcast to all devices dist.broadcast(inputs, src=0) dist.broadcast(targets, src=0) - # RANK - 1 since we don't broadcast the shard for RANK 0. 
- inputs, targets = inputs[RANK - 1], targets[RANK - 1] + + if weights is None: + weights = torch.ones(inputs.shape[0], device=DEVICE) if weights is None: weights = torch.ones(per_device_batch_size, device=DEVICE) @@ -138,10 +108,51 @@ def _build_input_queue( } yield batch + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return 'output.weight' in param_name or 'output.bias' in param_name + def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" - pass + model = params + logits, _ = self.model_fn( + model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + loss = -torch.sum(targets * log_probs) + return loss + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling in PyTorch.""" + vocab_size = logits_batch.shape[-1] + + if len(label_batch.shape) == len(logits_batch.shape): + # One-hot labels + log_probs = torch.nn.functional.log_softmax(logits_batch, dim=-1) + loss = -torch.sum(label_batch * log_probs, dim=-1) + else: + # Dense labels + loss = torch.nn.functional.cross_entropy( + logits_batch, + label_batch, + reduction='none') + + if mask_batch is not None: + loss = loss * mask_batch + + n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] + return { + 'summed': loss.sum(), + 'n_valid_examples': n_valid, + 'per_example': loss + } diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index a06b17fdc..c10bf13e8 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -11,6 +11,7 @@ from algoperf import spec from algoperf.workloads.lm import input_pipeline +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader FLAGS = flags.FLAGS @@ -21,10 +22,13 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" _vocab_size: int = 50257 - _seq_len: int = 512 + _seq_len: int = 5 + warmup_factor: float = 0.1 def __init__(self) -> None: - pass + super().__init__() + self._param_shapes = None + self._param_types = None @property def target_metric_name(self) -> str: @@ -36,14 +40,14 @@ def has_reached_validation_target(self, eval_result: float) -> bool: @property def validation_target_value(self) -> float: - pass + return 20.0 # Target perplexity - def has_reached_test_target(self, eval_result: float) -> bool: - return eval_result['test/ppl'] > self.test_target_value + def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: + return eval_result['test/ppl'] <= self.test_target_value @property def test_target_value(self) -> float: - pass + return 20.0 # Target perplexity @property def loss_type(self) -> spec.LossType: @@ -51,23 +55,23 @@ def loss_type(self) -> spec.LossType: @property def num_train_examples(self) -> int: - pass + return 1000000 # Example size @property def num_eval_train_examples(self) -> int: - pass + return 10000 # Subset for evaluation @property def num_validation_examples(self) -> int: - pass + return 50000 @property def num_test_examples(self) -> int: - pass + return 50000 @property def eval_batch_size(self) -> int: - pass + return 8 
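The perplexity targets above connect to the summed cross-entropy that the eval loop accumulates via the usual convention ppl = exp(mean loss); under that convention a target perplexity of 20 corresponds to a mean loss of about 3.0 nats. A minimal reference sketch (the exact reduction, per token versus per example, is whichever the workload's eval uses):

import math

def loss_to_ppl(mean_loss: float) -> float:
  # Perplexity is the exponential of the mean natural-log cross-entropy.
  return math.exp(mean_loss)

assert abs(loss_to_ppl(math.log(20.0)) - 20.0) < 1e-9  # loss ~= 3.00 <-> ppl 20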
@property def train_mean(self): @@ -79,16 +83,16 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - pass + return 3600 * 4 # 4 hours @property def eval_period_time_sec(self) -> int: - pass + return 600 # 10 minutes @property def step_hint(self) -> int: """Approx. steps the baseline can do in the allowed runtime budget.""" - pass + return 100000 @property def pre_ln(self) -> bool: @@ -116,13 +120,22 @@ def _build_input_queue(self, repeat_final_dataset: bool = False): """Build an input queue for the given split.""" - @abc.abstractmethod def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False) + + loss_dict = self.loss_fn(batch['targets'], logits) + return loss_dict['summed'] def _eval_model_on_split(self, split: str, @@ -145,9 +158,10 @@ def _eval_model_on_split(self, num_batches, repeat_final_dataset=True) + loss = 0.0 for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - loss += self._eval_batch(params, eval_batch) + loss += self._eval_batch(params, eval_batch, model_state, rng) if USE_PYTORCH_DDP: dist.all_reduce(loss) mean_loss = loss.item() / num_examples @@ -155,16 +169,11 @@ def _eval_model_on_split(self, # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. + @abc.abstractmethod def loss_fn( self, - label_batch: spec.Tensor, # Dense or one-hot labels. + label_batch: spec.Tensor, logits_batch: spec.Tensor, mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable - """Evaluate the (masked) loss function at (label_batch, logits_batch). - - Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} - (not synced across devices). 
- """ - pass + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling.""" diff --git a/pyproject.toml b/pyproject.toml index f4ebdaee3..745c6c680 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,7 @@ version_file = "algoperf/_version.py" [project.optional-dependencies] # All workloads full = [ - "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]", + "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt,lm]", ] # All workloads plus development dependencies full_dev = ["algoperf[full,dev]"] @@ -96,6 +96,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.18.0"] +lm = ["transformers", "datasets"] # Frameworks jax_core_deps = [ diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 49e46109b..c570e382b 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -90,12 +90,6 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) -# @functools.partial( -# jax.pmap, -# axis_name='batch', -# in_axes=(None, None, 0, 0, 0, 0, 0, None, None), -# static_broadcasted_argnums=(0, 1), -# donate_argnums=(2, 3, 4)) def train_step(workload, opt_update_fn, model_state, @@ -272,6 +266,8 @@ def get_batch_size(workload_name): return 16 elif workload_name == 'cifar': return 128 + elif workload_name == 'lm': + return 8 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/submission_runner.py b/submission_runner.py index fa300916e..fd1eb8259 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -250,7 +250,8 @@ def train_once( 'ogbg', 'criteo1tb', 'imagenet_vit', - 'librispeech_deepspeech' + 'librispeech_deepspeech', + 'lm' ] eager_backend_workloads = [] aot_eager_backend_workloads = [] @@ -712,7 +713,8 @@ def main(_): 'librispeech_conformer', 'librispeech_deepspeech', 'imagenet_vit', - 'criteo1tb' + 'criteo1tb', + 'lm' ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' From 99c7b9b70a374a25d6ac29c4f9a0f7c95e57c1aa Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 3 Apr 2025 12:46:53 -0400 Subject: [PATCH 27/98] add nanodo model --- algoperf/workloads/lm/lm_jax/nanodo_model.py | 345 ++++++++++++++++++ algoperf/workloads/lm/lm_jax/workload.py | 56 ++- .../paper_baselines/adamw/jax/submission.py | 4 +- 3 files changed, 386 insertions(+), 19 deletions(-) create mode 100644 algoperf/workloads/lm/lm_jax/nanodo_model.py diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py new file mode 100644 index 000000000..d21fd5090 --- /dev/null +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -0,0 +1,345 @@ +# Self-contained version of the DecoderOnly Transformer from NanoDO + +import dataclasses +from functools import partial + +from flax import linen as nn +import jax +import jax.numpy as jnp + +# =========== Transformer Decoder-only Model ========== + + + +@dataclasses.dataclass +class DoConfig: + """Hyper-parameters for Transformer decoder-only.""" + + D: int # model/embed dim = qkv dim + H: int # num attention heads + L: int # max context/sequence length + N: int # number of transformer block layers + V: int # vocab size + F: int # FF inner dimension + kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform() + 
embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling( + 1.0, "fan_in", "normal", out_axis=0 + ) + dtype: jnp.dtype = jnp.float32 + rmsnorm_epsilon: float = 1e-6 + multiple_of: int = 256 + tie_embeddings: bool = True # Whether to tie input and output embeddings + + +class Mlp(nn.Module): + """Multilayer perceptron with GLU activation.""" + + cfg: DoConfig + + @nn.compact + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + # Use Xavier uniform initialization explicitly + xavier_init = nn.initializers.xavier_uniform() + linear = partial( + nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype + ) + hidden_dim = cfg.multiple_of * ( + (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of + ) + # Double the hidden dimension for GLU + x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) + # Apply GLU activation + x_BxLxF = nn.glu(x_BxLx2F, axis=-1) + x_BxLxD = linear(cfg.D)(x_BxLxF) + return x_BxLxD + +@partial(jax.jit, static_argnums=(0,1,2)) +def init_rope(dim=256, seq_len=128, n_heads=4): + """Initialize rotary embeddings.""" + def precompute_freqs_cis_jax(dim, end, theta=10000.0): + inv_freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim)) + t = jnp.arange(end) / 1.0 + freqs = jnp.outer(t, inv_freqs).astype(jnp.float32) + return jnp.stack([ + jnp.cos(freqs)[None, :, None, :], + jnp.sin(freqs)[None, :, None, :] + ], axis=3) + + freqs_cis = precompute_freqs_cis_jax(dim // n_heads, seq_len, theta=500000) + return freqs_cis.transpose(0, 1, 2, 4, 3) + +@jax.jit +def apply_rope(q, k, freqs_cis): + """Apply rotary embeddings to Q and K.""" + def rotate_tensor(x): + # Split into real and imaginary parts + x_r2 = x.reshape(*x.shape[:-1], -1, 2) + L = x.shape[1] + freqs = freqs_cis[:, :L, :, :, :] + + # Apply rotation + rotated_x_r2 = jnp.stack([ + x_r2[..., 0] * freqs[..., 0] - x_r2[..., 1] * freqs[..., 1], + x_r2[..., 1] * freqs[..., 0] + x_r2[..., 0] * freqs[..., 1] + ], axis=-1) + + return rotated_x_r2.reshape(*x.shape) + + # Apply rotation to Q and K separately + rotated_q = rotate_tensor(q) + rotated_k = rotate_tensor(k) + + return rotated_q, rotated_k + + +class CausalAttn(nn.Module): + """Causal attention layer with rotary embeddings.""" + + cfg: DoConfig + + def setup(self): + cfg = self.cfg + assert cfg.D % cfg.H == 0, f"D {cfg.D} not divisible by H {cfg.H}" + self.Dh = cfg.D // cfg.H + + # Initialize rotary embeddings + self.freqs_cis = init_rope(cfg.D, cfg.L, cfg.H) + + # Maps D -> (H, Dh) + self.multilinear = partial( + nn.DenseGeneral, + axis=-1, + features=(cfg.H, self.Dh), + kernel_init=cfg.kernel_init, + use_bias=False, + dtype=cfg.dtype, + ) + + self.multilinear_query = self.multilinear(name="query") + self.multilinear_key = self.multilinear(name="key") + self.multilinear_value = self.multilinear(name="value") + self.output_projection = nn.DenseGeneral( + features=cfg.D, + name="attn_out_proj", + # axis=(-2, -1), # + kernel_init=cfg.kernel_init, + use_bias=False, + dtype=cfg.dtype, + ) + + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + + # Project inputs to Q, K, V + q_BxLxHxDh = self.multilinear_query(x_BxLxD) + k_BxLxHxDh = self.multilinear_key(x_BxLxD) + v_BxLxHxDh = self.multilinear_value(x_BxLxD) + + # Apply rotary embeddings to Q and K + q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) + + # Scale queries + q_BxLxHxDh /= self.Dh**0.5 + + # Compute attention scores + att_BxHxLxL = jnp.einsum("...qhd,...khd->...hqk", q_BxLxHxDh, k_BxLxHxDh) + + # Causal attention mask + L = x_BxLxD.shape[1] + mask_1x1xLxL = 
jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) + + # Apply mask and softmax + _NEG_INF = jnp.finfo(cfg.dtype).min + att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) + att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) + att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) + + # Compute attention output + out_BxLxHxDh = jnp.einsum("...hqk,...khd->...qhd", att_BxHxLxL, v_BxLxHxDh) + + # Reshape and project output + out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) + + # Output projection + out_BxLxD = self.output_projection(out_BxLxD) + + return out_BxLxD + + +class TBlock(nn.Module): + """Transformer Block.""" + + docfg: DoConfig + + @nn.compact + def __call__(self, in_BxLxD: jax.Array): + cfg = self.docfg + + # x = x + attn( attn_norm(x) ) + x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + in_BxLxD + ) + x_BxLxD = CausalAttn(cfg)(x_BxLxD) + x_BxLxD += in_BxLxD + + # x = x + mlp( mlp_norm(x) ) + z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + x_BxLxD + ) + z_BxLxD = Mlp(cfg)(z_BxLxD) + + return x_BxLxD + z_BxLxD + + +class TransformerDo(nn.Module): + """Transformer decoder-only.""" + + docfg: DoConfig + + def setup(self): + cfg = self.docfg + self.embed = nn.Embed( + num_embeddings=cfg.V, + features=cfg.D, + embedding_init=cfg.embed_init, + ) + + self.blocks = [TBlock(cfg) for _ in range(cfg.N)] + self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) + + # Output projection - tied to input embeddings if configured + if cfg.tie_embeddings: + self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) + else: + self.output_proj = nn.Dense( + cfg.V, + kernel_init=cfg.embed_init, + dtype=cfg.dtype, + name="output_proj" + ) + + def __call__(self, y_BxL: jax.Array): + # For training on concatenated examples. + y_BxLxD = self.embed(y_BxL) + for block in self.blocks: + y_BxLxD = block(y_BxLxD) + y_BxLxD = self.out_ln(y_BxLxD) + logits_BxLxV = self.output_proj(y_BxLxD) + return logits_BxLxV + + def predict(self, y_BxL: jax.Array, k: int = 1): + """Generate k tokens autoregressively. 
+ + Args: + y_BxL: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + cfg = self.docfg + batch_size = y_BxL.shape[0] + seq_len = y_BxL.shape[1] + + # Store original input + original_input = y_BxL + + # Make sure we don't exceed the model's context length + if seq_len + k > cfg.L: + raise ValueError( + f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.L})" + ) + + # Generate k tokens autoregressively + for _ in range(k): + # Get logits for the entire sequence + logits = self(y_BxL) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Get the most likely token + next_token = jnp.argmax(next_token_logits, axis=-1) + + # Append the predicted token to the sequence + y_BxL = jnp.concatenate([y_BxL, next_token[:, None]], axis=1) + + # Return original input and the k predicted tokens + return original_input, y_BxL[:, -k:] + + +# =========== Demo Code ========== + + +def main(): + """Create and run the DecoderOnly Transformer model.""" + # Initialize model configuration with smaller parameters for demo + B, L = (2, 128) # Batch size, sequence length + cfg = DoConfig(D=128, H=4, L=L, N=2, V=256, F=4 * 128) + model = TransformerDo(cfg) + + # Print model info + print(f"\nModel Configuration:") + print(f" - Model dimension (D): {cfg.D}") + print(f" - Number of heads (H): {cfg.H}") + print(f" - Max sequence length (L): {cfg.L}") + print(f" - Number of layers (N): {cfg.N}") + print(f" - Vocabulary size (V): {cfg.V}") + print(f" - Feed forward dimension (F): {cfg.F}") + + # Create random input tokens (simulated token IDs) + rng_key = jax.random.PRNGKey(42) + input_rng, init_rng = jax.random.split(rng_key) + + # Generate random token IDs (integers between 0 and vocab_size-1) + x_BxL = jax.random.randint( + input_rng, shape=(B, L), minval=0, maxval=cfg.V, dtype=jnp.int32 + ) + + # Initialize model parameters + print("\nInitializing model parameters...") + params = model.init(init_rng, x_BxL) + + # Print parameter count + param_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) + print(f"Total parameters: {param_count:,}") + + # Make a prediction (forward pass) + print("\nRunning forward pass...") + logits = model.apply(params, x_BxL) + + # Print output shape and sample values + print(f"\nOutput shape: {logits.shape} (batch_size, sequence_length, vocab_size)") + print(f"Output data type: {logits.dtype}") + + # Print sample logits (first 5 positions of the first sequence) + print("\nSample logits (first sequence, first 5 positions, first 5 values):") + for position in range(min(5, L)): + print(f" Position {position}: {logits[0, position, :5]}") + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print("\nPredicted token IDs (first sequence, first 10 positions):") + print(predictions[0, :10]) + + # Test the predict function + print("\nTesting predict function...") + # Use a shorter + short_seq = x_BxL[:, :10] + print(f"Input sequence shape: {short_seq.shape}") + + # Predict 5 tokens + k = 5 + original, predicted = model.apply(params, short_seq, k, method=model.predict) + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print("\nPredicted token IDs (first sequence, first 10 positions):") + print(predictions[0, :10]) + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git 
a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 7cb50302f..9fdfe6f60 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -10,7 +10,8 @@ from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload -from algoperf.workloads.lm.lm_jax.models import LinearModel +from algoperf.workloads.lm.lm_jax.nanodo_model import ( + TransformerDo, DoConfig, init_rope, apply_rope) from algoperf.workloads.lm.input_pipeline import get_hf_dataloader @@ -42,12 +43,22 @@ def init_model_fn( dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - self._model = LinearModel(vocab_size=self._vocab_size) - input_shape = (1, self._seq_len, self._vocab_size) + # Initialize NanoDO transformer model + cfg = DoConfig( + D=512, # model dim + H=8, # num heads + L=self._seq_len, + N=6, # num layers + V=self._vocab_size, + F=2048, # feedforward dim + dtype=jnp.float32 + ) + self._model = TransformerDo(cfg) + input_shape = (1, self._seq_len) # For token IDs + params_rng, init_rng = jax.random.split(rng) - print(params_rng) - # variables = model.init(init_rng, jnp.ones(input_shape, jnp.float32)) - variables = jax.jit(self._model.init)({'params': params_rng}, jnp.ones(input_shape, jnp.float32)) + variables = jax.jit(self._model.init)({'params': params_rng}, + jnp.ones(input_shape, jnp.int32)) params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -66,6 +77,11 @@ def model_fn( del mode, rng, update_batch_norm, model_state inputs = batch['inputs'] + + # Convert one-hot inputs to token IDs if needed + if inputs.ndim == 3: # one-hot encoded + inputs = jnp.argmax(inputs, axis=-1) + logits = self._model.apply({'params': params}, inputs) return logits, None @@ -76,23 +92,29 @@ def loss_fn( mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: """Compute cross-entropy loss for language modeling in JAX.""" - vocab_size = logits_batch.shape[-1] + # Convert one-hot labels to token IDs if needed + if len(label_batch.shape) == len(logits_batch.shape): # one-hot + label_batch = jnp.argmax(label_batch, axis=-1) - if len(label_batch.shape) == len(logits_batch.shape): - # One-hot labels - loss = -jnp.sum(label_batch * jax.nn.log_softmax(logits_batch, axis=-1)) - else: - # Dense labels - loss = -jax.nn.log_softmax(logits_batch)[jnp.arange(label_batch.shape[0]), label_batch] + # Reshape for sequence modeling + logits = logits_batch.reshape(-1, logits_batch.shape[-1]) + labels = label_batch.reshape(-1) + + # Compute cross-entropy loss + loss = -jnp.sum( + jax.nn.log_softmax(logits)[jnp.arange(labels.shape[0]), labels]) if mask_batch is not None: - loss = loss * mask_batch + mask = mask_batch.reshape(-1) + loss = loss * mask + n_valid = mask.sum() + else: + n_valid = labels.shape[0] - n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] return { - 'summed': loss.sum(), + 'summed': loss, 'n_valid_examples': n_valid, - 'per_example': loss + 'per_example': loss / n_valid # Return per-token loss } def is_output_params(self, param_name: str) -> bool: diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 6c6d19ef8..dca9a6b95 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py 
+++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -75,7 +75,6 @@ def _loss_fn(params): spec.ForwardPassMode.TRAIN, rng, update_batch_norm=True,) - jax.debug.print("logits: {logits}", logits=logits) loss_dict = workload.loss_fn( label_batch=batch['targets'], logits_batch=logits, @@ -163,7 +162,6 @@ def update_params( replicated, # loss replicated # grad_norm )) - # print(batch) new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, @@ -229,6 +227,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'lm': + return 4 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From 706d9f74046a0f1c90256ae584b45e30a38e4349 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 3 Apr 2025 13:26:15 -0400 Subject: [PATCH 28/98] torch model --- algoperf/param_utils.py | 2 + .../workloads/lm/lm_pytorch/plainlm_model.py | 298 ++++++++++++++++++ algoperf/workloads/lm/lm_pytorch/workload.py | 57 ++-- .../adamw/pytorch/submission.py | 2 + 4 files changed, 341 insertions(+), 18 deletions(-) create mode 100644 algoperf/workloads/lm/lm_pytorch/plainlm_model.py diff --git a/algoperf/param_utils.py b/algoperf/param_utils.py index 05d882404..24f981546 100644 --- a/algoperf/param_utils.py +++ b/algoperf/param_utils.py @@ -43,6 +43,8 @@ def pytorch_param_types( param_types[name] = spec.ParameterType.ATTENTION_BIAS elif 'in_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_QKV + elif 'qkv' in name: + param_types[name] = spec.ParameterType.ATTENTION_QKV elif 'kv_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_KV elif 'k_proj' in name or 'key' in name: diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py new file mode 100644 index 000000000..627a0e16d --- /dev/null +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -0,0 +1,298 @@ +import math +import torch +import torch.nn.functional as F +from torch import nn +from dataclasses import dataclass +from typing import Tuple + + + +@dataclass +class ModelConfig: + vocab_size: int + seq_len: int + dim: int + expand: float + n_layers: int + n_heads: int + rmsnorm_eps: float = 1e-6 + tie_embeddings: bool = False + + +class MLP(nn.Module): + + def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): + super().__init__() + hidden_dim = multiple_of * ( + (hidden_dim + multiple_of - 1) // multiple_of) + self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) + self.fc2 = nn.Linear(hidden_dim, dim, bias=False) + self.glu = nn.GLU(dim=2) + + # Initialize with Xavier uniform + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + + def forward(self, x): + # x: (bsz, T, dim) + return self.fc2(self.glu(self.fc1(x))) + + +def precompute_freqs_cis(dim: int, + end: int, + theta: float = 10000.0, + condense_ratio: int = 1): + inv_freqs = 1.0 / (theta**(torch.arange( + 0, dim, 2, dtype=torch.float32, device=torch.device("cpu")) / dim)) + t = torch.arange(end, dtype=torch.float32, + device=inv_freqs.device) / condense_ratio + freqs = torch.outer(t, inv_freqs).float() + return torch.stack([ + torch.cos(freqs)[None, :, None, :], + torch.sin(freqs)[None, :, None, :] + ], + dim=4) + + +def apply_rotary_emb_complex_like( + q: torch.Tensor, k: torch.Tensor, + freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # Rotate query and key vectors using RoPE + qk_r2 = torch.cat([q, k], 
dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() + rotated_qk_r2 = torch.stack( + [ + qk_r2[..., 0] * freqs_cis[..., 0] - + qk_r2[..., 1] * freqs_cis[..., 1], + qk_r2[..., 1] * freqs_cis[..., 0] + + qk_r2[..., 0] * freqs_cis[..., 1], + ], + -1, + ).flatten(3) + rotated_qk = rotated_qk_r2 + return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) + + +class Attention(nn.Module): + + def __init__(self, cfg: ModelConfig): + super().__init__() + assert cfg.dim % cfg.n_heads == 0 + self.dim = cfg.dim + self.n_heads = cfg.n_heads + self.head_dim = cfg.dim // cfg.n_heads + + self.w_qkv = nn.Linear(cfg.dim, 3 * cfg.dim, bias=False) + self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False) + + def forward(self, x, freqs_cis): + bsz, seqlen, d = x.shape # (bsz, seqlen, d) + + q, k, v = self.w_qkv(x).split(d, dim=2) # (bsz, seqlen, d) + q = q.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + k = k.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + v = v.view(bsz, seqlen, self.n_heads, + self.head_dim) # (bsz, seqlen, nh, h_dim) + + q, k = apply_rotary_emb_complex_like( + q, k, freqs_cis=freqs_cis) # (bsz, seqlen, nh, h_dim) + + q = q.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + k = k.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + v = v.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + + out = F.scaled_dot_product_attention( + q, k, v, is_causal=True) # (bsz, nh, seqlen, h_dim) + + out = out.transpose(1, 2).contiguous().view(bsz, seqlen, + d) # (bsz, seqlen, d) + + return self.w_out(out) + + +class Block(nn.Module): + + def __init__(self, layer_id: int, cfg: ModelConfig): + super().__init__() + self.attn = Attention(cfg) + self.attn_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.mlp = MLP(dim=cfg.dim, hidden_dim=int(cfg.expand * cfg.dim)) + self.mlp_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.layer_id = layer_id + + def forward(self, x, freqs_cis): + # x: (bsz, seqlen, dim) + x = x + self.attn(self.attn_norm(x), freqs_cis) + x = x + self.mlp(self.mlp_norm(x)) + return x + + +class Transformer(nn.Module): + + def __init__(self, cfg): + super().__init__() + self.n_layers = cfg.n_layers + self.cfg = cfg + head_dim = cfg.dim // cfg.n_heads + assert cfg.dim % cfg.n_heads == 0 + + self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.dim) + self.layers = nn.ModuleList( + [Block(idx, cfg) for idx in range(cfg.n_layers)]) + self.out_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) + + # Initialize freqs_cis on CPU first (more memory efficient) + self.register_buffer('freqs_cis', + precompute_freqs_cis(head_dim, cfg.seq_len, 500000)[0:cfg.seq_len], + persistent=False) + + # init all weights, scale residual branches + self.apply(self._init_weights) + self._scale_residual_branches() + + # Move model to device (which will also move freqs_cis) + if torch.cuda.is_available(): + self.cuda() + + if cfg.tie_embeddings: + self.tie_weights() + + def forward(self, x): + # x: (bsz, seqlen) + x = self.embed_tokens(x) # (bsz, seqlen, dim) + L = x.shape[1] + + # Make sure we have enough precomputed frequencies + if L > self.freqs_cis.shape[1]: + # Need to recompute for longer sequence + head_dim = self.cfg.dim // self.cfg.n_heads + new_freqs = precompute_freqs_cis(head_dim, max(L, self.cfg.seq_len), 500000) + self.register_buffer('freqs_cis', new_freqs[0:max(L, self.cfg.seq_len)], persistent=False) + if torch.cuda.is_available(): + self.freqs_cis = self.freqs_cis.cuda() + + # Select the frequencies 
for current sequence length and ensure correct device + freqs_cis = self.freqs_cis[:, :L, :].to(x.device) + + for layer in self.layers: + x = layer(x, freqs_cis) # (bsz, seqlen, dim) + return self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + + def predict(self, x, k=1): + """Generate k tokens autoregressively. + + Args: + x: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + # For debugging + predictions = [] + + batch_size = x.shape[0] + seq_len = x.shape[1] + + # Store original input + original_input = x.clone() + generated_input = x.clone() + + # Generate k tokens autoregressively + for i in range(k): + # Get logits for the entire sequence + logits = self(generated_input) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Zero out the last token ID to prevent repetition + # This is a common issue - the model gets stuck repeating the last token + last_token_id = generated_input[:, -1] + next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) + + # Print top 5 tokens for debugging + if i == 0: + print("\nPyTorch detailed prediction:") + top5_values, top5_indices = torch.topk(next_token_logits[0], 5) + for j, (idx, val) in enumerate(zip(top5_indices.tolist(), top5_values.tolist())): + prob = torch.softmax(next_token_logits[0], dim=-1)[idx].item() + print(f" Top {j+1}: Token {idx}, logit={val:.2f}, prob={prob:.6f}") + + # Get the most likely token + next_token = torch.argmax(next_token_logits, dim=-1) + predictions.append(next_token.item()) + + # Append the predicted token to the sequence + next_token = next_token.unsqueeze(1) # Add sequence dimension + generated_input = torch.cat([generated_input, next_token], dim=1) + + print(f" Full predictions step by step: {predictions}") + + # Return all tokens, not just the last k + return original_input, generated_input[:, -k:] + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def _scale_residual_branches(self): + for n, p in self.named_parameters(): + if n.endswith("fc2.weight"): # mlp/glu output layer + torch.nn.init.normal_(p, + mean=0.0, + std=0.02 / math.sqrt(2 * self.n_layers)) + if n.endswith("w_out.weight"): # attn output layer + torch.nn.init.normal_(p, + mean=0.0, + std=0.02 / math.sqrt(2 * self.n_layers)) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if (not self.lm_head.weight + is self.embed_tokens.weight): # if no weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + +def main(): + print("Initializing transformer model and running forward pass...") + + seq_length = 512 + + # Define model configuration + config = ModelConfig( + vocab_size=32000, # Common vocab size for tokenizers like BPE or SentencePiece + seq_len=seq_length, # Maximum sequence length + dim=768, # Embedding dimension + expand=4.0, # MLP expansion factor + n_layers=12, # Number of transformer layers + n_heads=12, # Number of attention heads + rmsnorm_eps=1e-6, # RMSNorm epsilon + tie_embeddings=True # Tie embedding and output weights 
+ ) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if (not self.lm_head.weight + is self.embed_tokens.weight): # if no weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 0d0281690..45ad0828f 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -11,7 +11,7 @@ from algoperf import pytorch_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload -from algoperf.workloads.lm.lm_pytorch.models import LinearLayer +from algoperf.workloads.lm.lm_pytorch.plainlm_model import Transformer, ModelConfig USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -26,11 +26,23 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: if hasattr(self, '_model'): - self._model.reset_parameters() + # Reinitialize weights but keep same config + self._model.apply(self._model._init_weights) + self._model._scale_residual_branches() return self._model, None torch.manual_seed(rng[0]) - self._model = LinearLayer(vocab_size=self._vocab_size) + cfg = ModelConfig( + vocab_size=self._vocab_size, + seq_len=self._seq_len, + dim=512, # Model dimension + expand=4, # MLP expansion factor + n_layers=6, # Number of transformer layers + n_heads=8, # Number of attention heads + rmsnorm_eps=1e-6, + tie_embeddings=True + ) + self._model = Transformer(cfg) self._param_shapes = param_utils.pytorch_param_shapes(self._model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) self._model.to(DEVICE) @@ -46,15 +58,20 @@ def init_model_fn( def model_fn( self, params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del model_state, rng, update_batch_norm # Not used for linear model + del model_state, rng, update_batch_norm model = params - inputs = batch['inputs'].float() # Convert one-hot to float + + # Convert one-hot inputs to token IDs if needed + inputs = augmented_and_preprocessed_input_batch['inputs'] + if inputs.dim() == 3: # one-hot encoded + inputs = inputs.argmax(dim=-1) + logits = model(inputs) return logits, None @@ -83,13 +100,14 @@ def _build_input_queue( is_train = split == 'train' for batch in loader: - inputs, targets = batch + inputs = batch['inputs'] + targets = batch['targets'] if USE_PYTORCH_DDP: if not is_train: # During eval, the batch size of the remainder might be different per_device_batch_size = torch.tensor( - len(targets[0]), dtype=dtype, device=DEVICE) + targets.shape[0], dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) # Broadcast to all devices @@ -97,10 +115,8 @@ def _build_input_queue( dist.broadcast(targets, src=0) if weights is None: - weights = torch.ones(inputs.shape[0], device=DEVICE) - - if weights is None: - weights = torch.ones(per_device_batch_size, device=DEVICE) + batch_size = targets.shape[0] if not USE_PYTORCH_DDP else per_device_batch_size.item() + weights = torch.ones((batch_size, seq_len), device=DEVICE) batch = { 'inputs': inputs, 'targets': 
targets, @@ -110,7 +126,7 @@ def _build_input_queue( def is_output_params(self, param_name: str) -> bool: """Return whether the given parameter is an output parameter.""" - return 'output.weight' in param_name or 'output.bias' in param_name + return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name def _eval_batch(self, params: spec.ParameterContainer, @@ -121,11 +137,17 @@ def _eval_batch(self, model = params logits, _ = self.model_fn( model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - targets = batch['targets'] - # Calculate cross-entropy loss - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - loss = -torch.sum(targets * log_probs) + # Handle both one-hot and token ID targets + targets = batch['targets'] + if targets.dim() == 3: # one-hot + loss = -torch.sum(targets * torch.nn.functional.log_softmax(logits, dim=-1)) + else: # token IDs + loss = torch.nn.functional.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.view(-1), + reduction='sum' + ) return loss def loss_fn( self, @@ -146,7 +168,6 @@ def loss_fn( logits_batch, label_batch, reduction='none') - if mask_batch is not None: loss = loss * mask_batch diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 21d9b6b57..bdeaaf95b 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -173,6 +173,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'lm': + return 4 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From c335e341913dc6b1a747f2d3407e71a8d8e66ab6 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 29 May 2025 14:22:50 +0000 Subject: [PATCH 29/98] lm workload dataset integration in jax --- .../workloads/cifar/cifar_jax/workload.py | 11 - algoperf/workloads/lm/input_pipeline.py | 12 +- algoperf/workloads/lm/lm_jax/models.py | 3 +- algoperf/workloads/lm/lm_jax/workload.py | 68 +++- algoperf/workloads/lm/lm_pytorch/workload.py | 49 +-- algoperf/workloads/lm/workload.py | 313 +++++++++--------- .../nesterov/jax/submission.py | 8 +- submission_runner.py | 6 +- 8 files changed, 261 insertions(+), 209 deletions(-) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index f827fac87..fd990eeaa 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -71,17 +71,6 @@ def _build_input_queue( cache, repeat_final_dataset) - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync the batch statistics across replicas.""" - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics - # and we average them. 
- avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state - def init_model_fn( self, rng: spec.RandomState, diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index cc658501e..8f68fcb55 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -87,19 +87,19 @@ def batch_iterator(): tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) inputs, targets = tokens[:, :-1], tokens[:, 1:] inputs, targets = jax.device_put(inputs), jax.device_put(targets) - yield inputs, targets - + batch = { + "inputs": inputs, + "targets": targets, + } + yield batch return batch_iterator() def get_lm_dataset(data_rng: jax.random.PRNGKey, split: str, data_dir: str, - vocab_size: int, global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False, - vocab_path: Optional[str] = None): + num_batches: Optional[int] = None): """Load HF dataset and return a TF dataset.""" dataset_path = os.path.join(data_dir, split) diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py index edfc102fa..7913f2c67 100644 --- a/algoperf/workloads/lm/lm_jax/models.py +++ b/algoperf/workloads/lm/lm_jax/models.py @@ -14,5 +14,6 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: return nn.Dense( self.vocab_size, kernel_init=nn.initializers.normal(0.02), - bias_init=nn.initializers.zeros + bias_init=nn.initializers.zeros, + name="output" )(x) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 30b0c7867..6ad0e7d3d 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -2,16 +2,36 @@ from typing import Dict, Optional, Tuple +import jax import jax.numpy as jnp +import optax from flax import jax_utils from algoperf import param_utils +from algoperf import sharding_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload from algoperf.workloads.lm.lm_jax.models import LinearModel +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader, get_lm_dataset class LmWorkload(BaseLmWorkload): """LM JAX workload.""" + def _build_input_queue(self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False): + """Build an input queue using pre-cached FineWeb dataset.""" + del num_batches + del repeat_final_dataset + loader = get_lm_dataset( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + return loader def init_model_fn( self, @@ -21,14 +41,15 @@ def init_model_fn( model = LinearModel(vocab_size=self._vocab_size) input_shape = (1, self._seq_len, self._vocab_size) - variables = model.init(rng, jnp.ones(input_shape, jnp.float32)) - model_state, params = variables.pop('params') - + params_rng, init_rng = jax.random.split(rng) + variables = jax.jit(model.init)({'params': params_rng}, + jnp.ones(input_shape, jnp.float32)) + params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) - + params = sharding_utils.shard_replicated(params) + model_state = None + self._model = model return params, 
model_state def model_fn( @@ -40,15 +61,40 @@ def model_fn( rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del mode, rng, update_batch_norm # Not used for linear model - inputs = batch['inputs'] - logits = self._model.apply({'params': params, **model_state}, inputs) - return logits, model_state + del mode, rng, update_batch_norm, model_state + inputs = jax.nn.one_hot(batch['inputs'], self._vocab_size, axis=-1) + logits = self._model.apply({'params': params}, inputs) + return logits, None + + def loss_fn( + self, + label_batch: spec.Tensor, # One-hot labels. + logits_batch: spec.Tensor, # Dense logits. + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: Optional[float] = 0.0) -> Dict[str, spec.Tensor]: + del mask_batch, label_smoothing + logits_flat = logits_batch.reshape(-1, self._vocab_size) + targets = jax.nn.one_hot(label_batch, self._vocab_size, axis=-1) + targets_flat = targets.reshape(-1, self._vocab_size) + # Cross-entropy loss + loss = -jnp.sum(targets_flat * jax.nn.log_softmax(logits_flat, axis=-1)) + n_valid_examples = logits_flat.shape[0] + return {'summed': loss, 'n_valid_examples': n_valid_examples} + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return param_name.contains('output') + def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" - pass + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + loss = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1)) + return loss diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 3395aa08f..2c6862160 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -66,35 +66,30 @@ def _build_input_queue( global_batch_size: int, num_batches: Optional[int] = None, repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: - not_train = split != 'train' - per_device_batch_size = int(global_batch_size / N_GPUS) - - seq_len = self._seq_len # TODO: define it somewehere else? - dtype = torch.int32 # TODO: decide between int32 and int64. - - # Only create and iterate over tf input pipeline in one Python process to - # avoid creating too many threads. - if RANK == 0: - np_iter = super()._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) + """Build an input queue for the given split.""" + from algoperf.workloads.lm.input_pipeline import get_hf_dataloader + + loader = get_hf_dataloader( + cache_dir=data_dir, + data_rng=data_rng, + batch_size=global_batch_size, + seq_len=self._seq_len, + framework="torch", + split=split) + seq_len = self._seq_len weights = None - + while True: # Only iterate over tf input pipeline in one Python process to # avoid creating too many threads. 
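# A rough sketch of the rank-0 + broadcast pattern this queue relies on,
# assuming torch.distributed is initialized; `loader` and `batch_shape` are
# illustrative names, not taken from this patch:
#
#   if RANK == 0:
#     t = torch.as_tensor(next(loader)['inputs'], dtype=dtype, device=DEVICE)
#   else:
#     t = torch.empty(batch_shape, dtype=dtype, device=DEVICE)
#   dist.broadcast(t, src=0)  # every rank now holds rank 0's batch
#
# Non-source ranks have to pre-allocate a tensor of matching shape and dtype,
# which is why the eval path broadcasts the remainder batch size before the
# tensors themselves.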
if RANK == 0: - batch = next(np_iter) # pylint: disable=stop-iteration-return + batch = next(dataset_iter) # pylint: disable=stop-iteration-return inputs = torch.as_tensor( batch['inputs'], dtype=dtype, - device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + device=DEVICE) # (N_GPUS, per_device_batch_size, seq_len) targets = torch.as_tensor( batch['targets'], dtype=dtype, - device=DEVICE) # (N_GPUS, global_batch_size, seq_len) + device=DEVICE) # (N_GPUS, per_device_batch_size, seq_len) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: @@ -138,10 +133,22 @@ def _build_input_queue( } yield batch + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return 'output.weight' in param_name or 'output.bias' in param_name + def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" - pass + model = params + logits, _ = self.model_fn( + model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + targets = batch['targets'] + + # Calculate cross-entropy loss + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + loss = -torch.sum(targets * log_probs) + return loss diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index a06b17fdc..e6b33e3e4 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -11,160 +11,171 @@ from algoperf import spec from algoperf.workloads.lm import input_pipeline +from algoperf.workloads.lm.input_pipeline import get_hf_dataloader FLAGS = flags.FLAGS -USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ +USE_PYTORCH_DDP = "LOCAL_RANK" in os.environ class BaseLmWorkload(spec.Workload): - """LM workload.""" - - _vocab_size: int = 50257 - _seq_len: int = 512 - - def __init__(self) -> None: - pass - - @property - def target_metric_name(self) -> str: - """The name of the target metric (useful for scoring/processing code).""" - return 'ppl' - - def has_reached_validation_target(self, eval_result: float) -> bool: - return eval_result['validation/ppl'] > self.validation_target_value - - @property - def validation_target_value(self) -> float: - pass - - def has_reached_test_target(self, eval_result: float) -> bool: - return eval_result['test/ppl'] > self.test_target_value - - @property - def test_target_value(self) -> float: - pass - - @property - def loss_type(self) -> spec.LossType: - return spec.LossType.SOFTMAX_CROSS_ENTROPY - - @property - def num_train_examples(self) -> int: - pass - - @property - def num_eval_train_examples(self) -> int: - pass - - @property - def num_validation_examples(self) -> int: - pass - - @property - def num_test_examples(self) -> int: - pass - - @property - def eval_batch_size(self) -> int: - pass - - @property - def train_mean(self): - raise NotImplementedError - - @property - def train_stddev(self): - raise NotImplementedError - - @property - def max_allowed_runtime_sec(self) -> int: - pass - - @property - def eval_period_time_sec(self) -> int: - pass - - @property - def step_hint(self) -> int: - """Approx. 
steps the baseline can do in the allowed runtime budget.""" - pass - - @property - def pre_ln(self) -> bool: - return True - - @property - def attention_temp(self) -> float: - return 1.0 - - @property - def activation(self) -> str: - return 'silu' - - @property - def glu(self) -> bool: - return True - - @abc.abstractmethod - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): - """Build an input queue for the given split.""" - - @abc.abstractmethod - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: - """Evaluate the model on a single batch.""" - - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: - """Run a full evaluation of the model.""" - num_batches = int(math.ceil(num_examples / global_batch_size)) - if split not in self._eval_iters: - # These iterators will repeat indefinitely. - self._eval_iters[split] = self._build_input_queue( - rng, - split, - data_dir, - global_batch_size, - num_batches, - repeat_final_dataset=True) - - for _ in range(num_batches): - eval_batch = next(self._eval_iters[split]) - loss += self._eval_batch(params, eval_batch) - if USE_PYTORCH_DDP: - dist.all_reduce(loss) - mean_loss = loss.item() / num_examples - return {'loss': mean_loss} - - # Does NOT apply regularization, which is left to the submitter to do in - # `update_params`. - def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable - """Evaluate the (masked) loss function at (label_batch, logits_batch). - - Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} - (not synced across devices). 
- """ - pass + """LM workload.""" + + _vocab_size: int = 50257 + _seq_len: int = 512 + + def __init__(self) -> None: + pass + + @property + def target_metric_name(self) -> str: + """The name of the target metric (useful for scoring/processing code).""" + return "ppl" + + def has_reached_validation_target(self, eval_result: float) -> bool: + return eval_result["validation/ppl"] > self.validation_target_value + + @property + def validation_target_value(self) -> float: + pass + + def has_reached_test_target(self, eval_result: float) -> bool: + return eval_result["test/ppl"] > self.test_target_value + + @property + def test_target_value(self) -> float: + pass + + @property + def loss_type(self) -> spec.LossType: + return spec.LossType.SOFTMAX_CROSS_ENTROPY + + @property + def num_train_examples(self) -> int: + pass + + @property + def num_eval_train_examples(self) -> int: + pass + + @property + def num_validation_examples(self) -> int: + pass + + @property + def num_test_examples(self) -> int: + pass + + @property + def eval_batch_size(self) -> int: + return 8 + + @property + def train_mean(self): + raise NotImplementedError + + @property + def train_stddev(self): + raise NotImplementedError + + @property + def max_allowed_runtime_sec(self) -> int: + pass + + @property + def eval_period_time_sec(self) -> int: + pass + + @property + def step_hint(self) -> int: + """Approx. steps the baseline can do in the allowed runtime budget.""" + # FIXME: should replace this with a real value later. + return 10000 + + @property + def pre_ln(self) -> bool: + return True + + @property + def attention_temp(self) -> float: + return 1.0 + + @property + def activation(self) -> str: + return "silu" + + @property + def glu(self) -> bool: + return True + + @abc.abstractmethod + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + ): + """Build an input queue for the given split.""" + + @abc.abstractmethod + def _eval_batch( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: + """Evaluate the model on a single batch.""" + + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: + """Run a full evaluation of the model.""" + num_batches = int(math.ceil(num_examples / global_batch_size)) + if split not in self._eval_iters: + # These iterators will repeat indefinitely. + self._eval_iters[split] = self._build_input_queue( + rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset=True, + ) + + loss = 0.0 + for _ in range(num_batches): + eval_batch = next(self._eval_iters[split]) + loss += self._eval_batch(params, eval_batch, model_state, rng) + if USE_PYTORCH_DDP: + dist.all_reduce(loss) + mean_loss = loss.item() / num_examples + return {"loss": mean_loss} + + # Does NOT apply regularization, which is left to the submitter to do in + # `update_params`. + def loss_fn( + self, + label_batch: spec.Tensor, # Dense or one-hot labels. 
+ logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable + """Evaluate the (masked) loss function at (label_batch, logits_batch). + + Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + (not synced across devices). + """ + pass diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 49e46109b..c570e382b 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -90,12 +90,6 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) -# @functools.partial( -# jax.pmap, -# axis_name='batch', -# in_axes=(None, None, 0, 0, 0, 0, 0, None, None), -# static_broadcasted_argnums=(0, 1), -# donate_argnums=(2, 3, 4)) def train_step(workload, opt_update_fn, model_state, @@ -272,6 +266,8 @@ def get_batch_size(workload_name): return 16 elif workload_name == 'cifar': return 128 + elif workload_name == 'lm': + return 8 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/submission_runner.py b/submission_runner.py index fa300916e..fd1eb8259 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -250,7 +250,8 @@ def train_once( 'ogbg', 'criteo1tb', 'imagenet_vit', - 'librispeech_deepspeech' + 'librispeech_deepspeech', + 'lm' ] eager_backend_workloads = [] aot_eager_backend_workloads = [] @@ -712,7 +713,8 @@ def main(_): 'librispeech_conformer', 'librispeech_deepspeech', 'imagenet_vit', - 'criteo1tb' + 'criteo1tb', + 'lm' ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' From af8cce4d61e7f79916d7293127121ebaa4a4d7ce Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 5 Jun 2025 03:20:46 +0000 Subject: [PATCH 30/98] set package versions for transformers and datasets --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 745c6c680..5e9c21f47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.18.0"] -lm = ["transformers", "datasets"] +lm = ["transformers==4.25.4", "datasets==3.6.0"] # Frameworks jax_core_deps = [ From d68c54e0aa023570abc94cea97f5757bfb0baca8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 5 Jun 2025 04:02:41 +0000 Subject: [PATCH 31/98] use train_test_split method to shuffle and split fineweb-edu dataset --- dataset/dataset_setup.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 6587f1439..7a83a03f6 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -770,18 +770,10 @@ def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) - # Find how many entries to take from dataset to have val_tokens in validation set. - val_tokens = 10_000_000 # TODO: decide this value. - tokens_accumulated, num_examples_for_val = 0, 0 - for example in tokenized_dataset: - tokens_accumulated += len(example['input_ids']) - num_examples_for_val += 1 - if tokens_accumulated >= val_tokens: - break # Split in train and valid. 
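# The hunk below replaces the manual token-count split with Hugging Face
# `Dataset.train_test_split`, which shuffles (by default) and splits in one
# call. A minimal sketch with toy data (values are illustrative only):
#
#   from datasets import Dataset
#   ds = Dataset.from_dict({'input_ids': [[1, 2], [3, 4], [5, 6], [7, 8]]})
#   splits = ds.train_test_split(test_size=0.25, seed=42)
#   train_ds, val_ds = splits['train'], splits['test']  # 3 / 1 examples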
- val_dataset = tokenized_dataset.select(range(num_examples_for_val)) - train_dataset = tokenized_dataset.select( - range(num_examples_for_val, len(tokenized_dataset))) + dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) + train_dataset = dataset_split_dict['train'] + val_dataset = dataset_split_dict['test'] # Concat in chunks of max_seq_len. # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk. From 9737367473f35b206333edc46f9c193ec8dda821 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Jun 2025 19:45:32 +0000 Subject: [PATCH 32/98] modifications to fwedu datasetup --- dataset/dataset_setup.py | 164 +++++++++++++++++---------------------- 1 file changed, 73 insertions(+), 91 deletions(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 7a83a03f6..584189c4a 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -191,6 +191,7 @@ flags.DEFINE_string('framework', None, 'Can be either jax or pytorch.') flags.DEFINE_boolean('skip_download', False, 'Skips data download.') +flags.DEFINE_boolean('skip_tokenization', False, 'Skip Fineweb-edu tokenization.') FLAGS = flags.FLAGS @@ -707,106 +708,87 @@ def download_wmt(data_dir): ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7) -def download_finewebedu(data_dir, tmp_dir=None): +def download_finewebedu(data_dir, + tmp_dir=None, + skip_download=False, + skip_tokenization=False): """Download FineWebEdu-10B.""" - data_dir = os.path.join(data_dir, 'finewebedu') - tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' - cache_dir = os.path.join(tmp_dir, - 'lm') if tmp_dir is not None else os.path.expanduser( - '~/.cache/huggingface/datasets') - - _maybe_mkdir(data_dir) - _maybe_mkdir(tmp_dir) - _maybe_mkdir(cache_dir) - - os.environ["TMPDIR"] = tmp_dir - - ds = hf_datasets.load_dataset( - 'HuggingFaceFW/fineweb-edu', - name='sample-10BT', - split='train', - cache_dir=cache_dir) - # TODO (nico): maybe save intermediate dataset to avoid re-downloading - # and allow re-chunking with different seq_len? - - # Shuffle so that multiproc has shards of similar size. 
- ds = ds.shuffle(seed=1996) - - seq_len = 2048 - max_seq_length = seq_len + 1 - map_setup = dict(batched=True, batch_size=1024, num_proc=8) - - # Tokenize - lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') - logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}") - - def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - add_eos = lambda seq: (seq + lm_tokenizer.eos_token) if seq else seq - add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] - return lm_tokenizer( - add_eos_batched(examples["text"]), - return_special_tokens_mask=False, - return_attention_mask=False) - - lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization - logging.info(f"Tokenizing...") - tokenized_dataset = ds.map( - tokenize, - remove_columns=[ - 'text', - 'id', - 'dump', - 'url', - 'file_path', - 'language', - 'language_score', - 'token_count', - 'score', - 'int_score' - ], - **map_setup) - lm_tokenizer.model_max_length = seq_len - - tokenized_dataset.save_to_disk(os.path.join(data_dir, f"fwedu_10B_tokenized")) + if not skip_download: + data_dir = os.path.join(data_dir, 'finewebedu') + tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' + cache_dir = os.path.join(tmp_dir, + 'lm') if tmp_dir is not None else os.path.expanduser( + '~/.cache/huggingface/datasets') + + _maybe_mkdir(data_dir) + _maybe_mkdir(tmp_dir) + _maybe_mkdir(cache_dir) + + os.environ["TMPDIR"] = tmp_dir + + ds = hf_datasets.load_dataset( + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + cache_dir=cache_dir) + ds.save_to_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) + else: + ds = hf_datasets.load_from_disk(tmp_dir, 'fwedu_10B_raw') + + if not skip_tokenization: + # Tokenize + lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}") + + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + + def add_eos(seq): + return seq + lm_tokenizer.eos_token if seq else seq + + def add_eos_batched(seqs): + return [add_eos(seq) for seq in seqs] + + return lm_tokenizer( + add_eos_batched(examples["text"]), + return_special_tokens_mask=False, + return_attention_mask=False) + + lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization + logging.info("Tokenizing...") + tokenized_dataset = ds.map( + tokenize, + remove_columns=[ + 'text', + 'id', + 'dump', + 'url', + 'file_path', + 'language', + 'language_score', + 'token_count', + 'score', + 'int_score' + ],) + + tokenized_dataset.save_to_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + else: + tokenized_dataset.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) # Split in train and valid. dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) train_dataset = dataset_split_dict['train'] val_dataset = dataset_split_dict['test'] - # Concat in chunks of max_seq_len. - # NOTE: expected token loss by batched concat_chunk. Truncates leftover tokens that don't fill a full max_seq_length chunk. 
- def concat_chunck(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - """Concatenate text and generate chunks of max_seq_length""" - concatenated_examples = { - k: list(itertools.chain(*examples[k])) for k in examples.keys() - } - total_length = len(concatenated_examples[list(examples.keys())[0]]) - if total_length >= max_seq_length: - total_length = (total_length // max_seq_length) * max_seq_length - result = { - k: [ - t[i:i + max_seq_length] - for i in range(0, total_length, max_seq_length) - ] for k, - t in concatenated_examples.items() - } - return result - - # Concat text in validation and train sets. - logging.info(f"Concatenating and chunking...") - val_dataset = val_dataset.map(concat_chunck, **map_setup) - train_dataset = train_dataset.map(concat_chunck, **map_setup) - logging.info( - f"Number of tokens in val_dataset: {len(val_dataset) * max_seq_length:_}") - logging.info( - f"Number of tokens in train_dataset: {len(train_dataset) * max_seq_length:_}" - ) + # Convert to tensorflow_datasets.Dataset objects + train_dataset = train_dataset.to_tf_dataset() + val_dataset = train_dataset.to_tf_dataset() # Save datasets - train_dataset.save_to_disk(os.path.join(data_dir, f"train")) - val_dataset.save_to_disk(os.path.join(data_dir, f"val")) + train_dataset.Save(os.path.join(data_dir, "train")) + val_dataset.save(os.path.join(data_dir, "val")) + + return def main(_): @@ -893,7 +875,7 @@ def main(_): if FLAGS.all or FLAGS.finewebedu: logging.info('Downloading FineWebEdu-10B...') - download_finewebedu(data_dir, tmp_dir) + download_finewebedu(data_dir, tmp_dir, FLAGS.skip_download, FLAGS.skip_tokenization) # pylint: enable=logging-format-interpolation From 1bf0750e094a695176e8e3bc45ffd979abe9e237 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Jun 2025 19:46:26 +0000 Subject: [PATCH 33/98] rename fwedu data dir --- dataset/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 584189c4a..ae27aab18 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -715,7 +715,7 @@ def download_finewebedu(data_dir, """Download FineWebEdu-10B.""" if not skip_download: - data_dir = os.path.join(data_dir, 'finewebedu') + data_dir = os.path.join(data_dir, 'fineweb_edu_10B') tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' cache_dir = os.path.join(tmp_dir, 'lm') if tmp_dir is not None else os.path.expanduser( From a33339117b4c79d5fa946f4f7ed029087ab5a630 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Jun 2025 20:46:21 +0000 Subject: [PATCH 34/98] fix --- dataset/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index ae27aab18..289a1faa6 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -734,7 +734,7 @@ def download_finewebedu(data_dir, cache_dir=cache_dir) ds.save_to_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) else: - ds = hf_datasets.load_from_disk(tmp_dir, 'fwedu_10B_raw') + ds = hf_datasets.load_from_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) if not skip_tokenization: # Tokenize From 05dc4dd7102670cebb8ac3a8875b34387d57b9b6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Jun 2025 21:22:57 +0000 Subject: [PATCH 35/98] add back batch mapping in tokenization for fwedu --- dataset/dataset_setup.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 289a1faa6..f50274615 100644 --- 
a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -769,7 +769,10 @@ def add_eos_batched(seqs): 'token_count', 'score', 'int_score' - ],) + ], + batched=True, + batch_size=1024, + num_proc=8) tokenized_dataset.save_to_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) else: From b374cf8db62e99e1594dea90b46a7f69a5bb04c6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 00:12:24 +0000 Subject: [PATCH 36/98] debugging --- dataset/dataset_setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index f50274615..2c46f4ebc 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -779,9 +779,11 @@ def add_eos_batched(seqs): tokenized_dataset.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) # Split in train and valid. + print(type(tokenized_dataset)) dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) train_dataset = dataset_split_dict['train'] val_dataset = dataset_split_dict['test'] + print(type(train_dataset)) # Convert to tensorflow_datasets.Dataset objects train_dataset = train_dataset.to_tf_dataset() From c0c1e3c32c46d65cb7511891b32429aeeb05f90c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 00:13:48 +0000 Subject: [PATCH 37/98] debugging --- dataset/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 2c46f4ebc..c18e72ea4 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -776,7 +776,7 @@ def add_eos_batched(seqs): tokenized_dataset.save_to_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) else: - tokenized_dataset.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) # Split in train and valid. print(type(tokenized_dataset)) From f76dc392fa83a1da25194d401aa03a9dd6dc9c6a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 00:23:24 +0000 Subject: [PATCH 38/98] debugging --- dataset/dataset_setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index c18e72ea4..414b78609 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -778,6 +778,7 @@ def add_eos_batched(seqs): else: tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + tokenized_dataset.to_tf_dataset() # Split in train and valid. print(type(tokenized_dataset)) dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) From e805fa7997daae83deea4e5336801af195270c1a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 10 Jun 2025 00:45:07 +0000 Subject: [PATCH 39/98] use tfds to shuffle and split dataset --- dataset/dataset_setup.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 414b78609..747d06d27 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -778,20 +778,18 @@ def add_eos_batched(seqs): else: tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) - tokenized_dataset.to_tf_dataset() - # Split in train and valid. 
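# The replacement hunk below materializes the tokenized data as a
# tf.data.Dataset, shuffles it once, and splits it with take/skip. A toy
# sketch of that pattern (sizes illustrative; pinning reshuffle_each_iteration
# is this sketch's addition, so repeated iteration keeps the two branches
# disjoint):
#
#   ds = tf.data.Dataset.range(10).shuffle(
#       10, seed=0, reshuffle_each_iteration=False)
#   train = ds.take(8)  # first 8 elements of the shuffled order
#   val = ds.skip(8)    # remaining 2 elements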
- print(type(tokenized_dataset)) - dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42) - train_dataset = dataset_split_dict['train'] - val_dataset = dataset_split_dict['test'] - print(type(train_dataset)) - # Convert to tensorflow_datasets.Dataset objects - train_dataset = train_dataset.to_tf_dataset() - val_dataset = train_dataset.to_tf_dataset() + tokenized_dataset = tokenized_dataset.to_tf_dataset() - # Save datasets - train_dataset.Save(os.path.join(data_dir, "train")) + # Shuffle dataset + dataset_size = tokenized_dataset.cardinality().numpy() + shuffled_dataset = tokenized_dataset.shuffle(dataset_size, seed=0) + train_size = int(0.9 * dataset_size) + train_dataset = shuffled_dataset.take(train_size) + val_dataset = shuffled_dataset.skip(train_size) + + # Split in train and valid. + train_dataset.save(os.path.join(data_dir, "train")) val_dataset.save(os.path.join(data_dir, "val")) return From c9e9abcdf0cc9c817c1683f7a40d94a9372752f3 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 2 Oct 2025 03:40:29 +0000 Subject: [PATCH 40/98] add command for fineweb-edu --- dataset/README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/dataset/README.md b/dataset/README.md index 1aeb83239..50ca11985 100644 --- a/dataset/README.md +++ b/dataset/README.md @@ -453,3 +453,13 @@ The preprocessing script will generate `.npy` files for audio data, `features.cs ```bash python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_vocab_path=$DATA_DIR/librispeech/spm_model.vocab ``` + +### Fineweb-EDU 10B +From `algorithmic-efficiency` run: + +```bash +python3 python3 datasets/dataset_setup.py \ +--data_dir $DATA_DIR \ +--temp_dir $DATA_DIR/tmp \ +--fineweb_edu +``` \ No newline at end of file From e4323deca83a86ad1d703f056157dfcb0e0b1650 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 2 Oct 2025 03:42:16 +0000 Subject: [PATCH 41/98] fix --- dataset/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataset/README.md b/dataset/README.md index 50ca11985..1bfd9bf73 100644 --- a/dataset/README.md +++ b/dataset/README.md @@ -458,7 +458,7 @@ python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_v From `algorithmic-efficiency` run: ```bash -python3 python3 datasets/dataset_setup.py \ +python3 datasets/dataset_setup.py \ --data_dir $DATA_DIR \ --temp_dir $DATA_DIR/tmp \ --fineweb_edu From f0c6e75ad70cb2c4242014c1522abb3b3bf9aa2e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 3 Oct 2025 06:23:26 +0000 Subject: [PATCH 42/98] update calls to sharing utils --- algoperf/workloads/lm/lm_jax/workload.py | 4 ++-- algoperf/workloads/lm/workload.py | 2 +- .../baselines/external_tuning/jax_nadamw_full_budget.py | 2 ++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index e73a5bfaf..81dde95fc 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -7,7 +7,7 @@ import optax from flax import jax_utils from algoperf import param_utils -from algoperf import sharding_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.lm.workload import BaseLmWorkload from algoperf.workloads.lm.lm_jax.models import LinearModel @@ -79,7 +79,7 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params 
= sharding_utils.shard_replicated(params) + params = jax_sharding_utils.replicate(params) model_state = None return params, model_state diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 6b71c7952..2a9777354 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -92,7 +92,7 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: """Approx. steps the baseline can do in the allowed runtime budget.""" - return 100000 + return 7000 @property def pre_ln(self) -> bool: diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index 0577cd4e0..6e40cdab1 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -394,6 +394,8 @@ def get_batch_size(workload_name): return 512 elif workload_name == 'wmt': return 128 + elif workload_name == 'lm': + return 128 elif workload_name == 'mnist': return 16 else: From f4ffbe709f6a867ea95ae55f4b47032caee98c4a Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 6 Oct 2025 17:09:11 +0000 Subject: [PATCH 43/98] Fix torch sharding issue, update input pipeline and workload classes to use int32 for tensor types and add dropout rate parameter --- algoperf/workloads/lm/input_pipeline.py | 4 +- algoperf/workloads/lm/lm_jax/workload.py | 5 ++- algoperf/workloads/lm/lm_pytorch/workload.py | 37 ++++++++++--------- .../lm/tests/test_build_input_queue_torch.py | 15 +++++--- algoperf/workloads/lm/workload.py | 3 +- 5 files changed, 37 insertions(+), 27 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index db345700e..c010b32af 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -119,8 +119,8 @@ def tf_generator(): ds = tf.data.Dataset.from_generator( tf_generator, output_signature={ - "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int64), - "targets": tf.TensorSpec(shape=(None,), dtype=tf.int64), + "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32), + "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32), }) # Avoid creating too many threads when using PyTorch DDP. 
diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 81dde95fc..1f6b3c2b2 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -90,8 +90,9 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del mode, rng, update_batch_norm, model_state + update_batch_norm: bool, + dropout_rate: float) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] # Convert one-hot inputs to token IDs if needed if inputs.ndim == 3: # one-hot encoded diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 36e441e7e..e5dafdd3c 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -6,7 +6,8 @@ import torch import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP - +from itertools import islice +from algoperf import data_utils from algoperf import param_utils from algoperf import pytorch_utils from algoperf import spec @@ -84,19 +85,22 @@ def _build_input_queue( num_batches: Optional[int] = None, repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: """Build an input queue for the given split.""" - from algoperf.workloads.lm.input_pipeline import get_hf_dataloader - - loader = get_hf_dataloader( - cache_dir=data_dir, + from algoperf.workloads.lm.input_pipeline import get_lm_dataset + local_batch_size = global_batch_size // N_GPUS + + loader = get_lm_dataset( data_rng=data_rng, - batch_size=global_batch_size, - seq_len=self._seq_len, - framework="torch", - split=split) + split=split, + data_dir=data_dir, + global_batch_size=local_batch_size, + num_batches=num_batches + ) + if USE_PYTORCH_DDP: + loader = islice(loader, RANK, None, N_GPUS) seq_len = self._seq_len weights = None - dtype = torch.long + dtype = torch.int32 is_train = split == 'train' for batch in loader: @@ -109,17 +113,16 @@ def _build_input_queue( per_device_batch_size = torch.tensor( targets.shape[0], dtype=dtype, device=DEVICE) dist.broadcast(per_device_batch_size, src=0) - + local_batch_size = per_device_batch_size.item() # Broadcast to all devices - dist.broadcast(inputs, src=0) - dist.broadcast(targets, src=0) + #dist.broadcast(inputs, src=0) + #dist.broadcast(targets, src=0) if weights is None: - batch_size = targets.shape[0] if not USE_PYTORCH_DDP else per_device_batch_size.item() - weights = torch.ones((batch_size, seq_len), device=DEVICE) + weights = torch.ones((local_batch_size, seq_len), device=DEVICE) batch = { - 'inputs': inputs, - 'targets': targets, + 'inputs': torch.tensor(inputs, device=DEVICE, dtype=dtype), + 'targets': torch.tensor(targets, device=DEVICE, dtype=dtype), 'weights': weights, } yield batch diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py index 639e71491..827272037 100644 --- a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py +++ b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py @@ -17,9 +17,9 @@ def sync_ddp(): def test_dataloader_torch(): # Test config. 
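# The DDP input queue above deals whole batches to ranks round-robin with
# itertools.islice; a toy sketch of that slicing (a world size of 4 is an
# assumption for illustration):
#
#   from itertools import islice
#   batches = iter(range(12))                  # stand-in for the batch loader
#   rank1 = list(islice(batches, 1, None, 4))  # -> [1, 5, 9]
#
# so each rank consumes a disjoint stream of per-device batches. This test
# exercises that path when launched with several processes, e.g.
# `torchrun --nproc_per_node=<N> algoperf/workloads/lm/tests/test_build_input_queue_torch.py`
# (launch command is an assumption, not part of this patch).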
rng_seed = 1996 - data_dir = '/fast/najroldi/data/finewebedu' + data_dir = '/home/ak4605/data/finewebedu/' split = 'train' - global_batch_size = 8 + global_batch_size = 64 dtype = torch.int32 seq_len = 2048 @@ -44,35 +44,40 @@ def test_dataloader_torch(): # print(f"inputs: {inputs}") # Start test. - for _ in range(100): + for _ in range(1): batch = next(input_queue) + print(f"RANK {RANK} got batch") assert type(batch) == dict assert 'inputs' in batch assert 'targets' in batch inputs, targets = batch['inputs'], batch['targets'] - + print(f"RANK {RANK} inputs.shape: {inputs.shape}") + print(f"RANK {RANK} targets.shape: {targets.shape}") + print(f"RANK {RANK} type(inputs): {type(inputs)}") assert type(inputs) == torch.Tensor assert type(targets) == torch.Tensor assert inputs.device == DEVICE assert targets.device == DEVICE - assert inputs.dtype == dtype assert targets.dtype == dtype + print(local_batch_size, seq_len) assert inputs.shape == (local_batch_size, seq_len) assert targets.shape == (local_batch_size, seq_len) assert torch.equal(inputs[:, 1:], targets[:, :-1]) + print(f"RANK {RANK} inputs[0, :10]: {inputs[0, :10]}") print(f"=== ALL TEST PASSED ===") def main(): profiler = PassThroughProfiler() + print(USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS) pytorch_init(USE_PYTORCH_DDP, RANK, profiler) test_dataloader_torch() diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 2a9777354..986a98297 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -132,7 +132,8 @@ def _eval_batch(self, model_state, spec.ForwardPassMode.EVAL, rng, - update_batch_norm=False) + update_batch_norm=False, + dropout_rate=None) loss_dict = self.loss_fn(batch['targets'], logits) return loss_dict['summed'] From 5c85c7e278ffa540d65b1d49f0bd1d0cad732052 Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 6 Oct 2025 17:39:35 +0000 Subject: [PATCH 44/98] test working, lm workload training not working (debugging) --- algoperf/workloads/lm/lm_jax/workload.py | 3 +- .../lm/tests/test_build_input_queue_jax.py | 60 +++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 algoperf/workloads/lm/tests/test_build_input_queue_jax.py diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 1f6b3c2b2..5401ad240 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -33,9 +33,10 @@ def _build_input_queue(self, split=split, data_dir=data_dir, global_batch_size=global_batch_size) + loader = map(jax_sharding_utils.shard_along_batch_dim, loader) return loader - def _build_input_queue(self, + def _build_hf_input_queue(self, data_rng: jax.random.PRNGKey, split: str, data_dir: str, diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_jax.py b/algoperf/workloads/lm/tests/test_build_input_queue_jax.py new file mode 100644 index 000000000..b9adc70d2 --- /dev/null +++ b/algoperf/workloads/lm/tests/test_build_input_queue_jax.py @@ -0,0 +1,60 @@ +import jax +import jax.numpy as jnp + +from algoperf.profiler import PassThroughProfiler +from algoperf.workloads.lm.lm_jax.workload import LmWorkload +import os + +RANK = os.environ.get('RANK', 0) + +def test_dataloader_jax(): + # Test config. 
+ rng_seed = 1996 + data_dir = '/home/ak4605/data/finewebedu/' + split = 'train' + global_batch_size = 64 + dtype = jnp.int32 + seq_len = 2048 + + workload = LmWorkload() + data_rng = jax.random.PRNGKey(rng_seed) + input_queue = workload._build_input_queue( + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size) + + for _ in range(1): + + batch = next(input_queue) + print(f"RANK {RANK} got batch") + + assert type(batch) == dict + assert 'inputs' in batch + assert 'targets' in batch + + inputs, targets = batch['inputs'], batch['targets'] + print(f"RANK {RANK} inputs.shape: {inputs.shape}") + print(f"RANK {RANK} targets.shape: {targets.shape}") + print(f"RANK {RANK} type(inputs): {type(inputs)}") + + jax.debug.inspect_array_sharding(inputs, callback=print) + assert inputs.dtype == dtype + assert targets.dtype == dtype + + assert inputs.shape == (global_batch_size, seq_len) + assert targets.shape == (global_batch_size, seq_len) + + assert jnp.equal(inputs[:, 1:], targets[:, :-1]).all() + print(f"RANK {RANK} inputs[0, :10]: {inputs[0, :10]}") + + print(f"=== ALL TEST PASSED ===") + + +def main(): + profiler = PassThroughProfiler() + test_dataloader_jax() + + +if __name__ == '__main__': + main() From a59dfda3a7ce87b5cad550f2332aaf049f59c8f6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 6 Oct 2025 18:33:29 +0000 Subject: [PATCH 45/98] updates to input_pipeline and model spec --- algoperf/workloads/lm/input_pipeline.py | 257 +++++++++---------- algoperf/workloads/lm/lm_jax/nanodo_model.py | 2 +- algoperf/workloads/lm/lm_jax/workload.py | 36 +-- algoperf/workloads/lm/lm_pytorch/workload.py | 5 +- algoperf/workloads/lm/workload.py | 98 +++---- 5 files changed, 187 insertions(+), 211 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index c010b32af..e674170e4 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -1,154 +1,129 @@ """Input pipeline for a LM dataset.""" + import functools import os from typing import Optional import jax -import jax.numpy as jnp import tensorflow as tf -import torch -import torch.nn.functional as F -from transformers import GPT2Tokenizer from algoperf import data_utils -from algoperf.pytorch_utils import pytorch_setup -from datasets import load_dataset -from datasets import load_from_disk - -RANK = pytorch_setup()[1] -# Avoid multithreading in all processes but the first (rank 0). -# This ensures that only the primary process (RANK == 0) uses TensorFlow's -# automatic optimization (AUTOTUNE), while other processes disable it (None). -# tf.data.AUTOTUNE is a constant that lets TensorFlow automatically determine -# the optimal number of elements to prefetch or parallelize for dataset -# operations, improving performance. -AUTOTUNE = tf.data.AUTOTUNE if RANK == 0 else None - - -def get_hf_dataloader(cache_dir: str, - data_rng: jax.random.PRNGKey, - batch_size: int = 8, - seq_len: int = 32, - framework: str = "torch", - split="train"): + +AUTOTUNE = tf.data.experimental.AUTOTUNE +PAD_ID = -1 + +TFDS_SPLIT_NAME = {'train': 'train', 'eval_train': 'train', 'validation': 'val'} + +SEQUENCE_LENGTH = 2048 +MAX_CORPUS_CHARS = 1_000_000_000 +SHUFFLE_BUFFER_SIZE = 1_000_000 +VOCAB_SIZE = 50_257 + + +def batch_with_padding( + dataset: tf.data.Dataset, + batch_size, + padded_shapes=None, + padding_id=PAD_ID, +): + """Batches a tf.data.Dataset and adds padding if len(dataset) is not divisible by the batch size. 
+ + Args: + dataset: tf.data.Dataset + batch_size: batch size of resulting batched dataset + padded_shapes: shapes of the padded batches + padding_id: value for padding, for elements in new batch + + Returns: """ - Create a data loader from HuggingFace's FineWeb dataset. - - Args: - cache_dir: Directory to cache the dataset - batch_size: Number of sequences per batch - seq_len: Length of each sequence - framework: Either "torch" or "jax" to specify output tensor type - split: Dataset split to load - """ - # Initialize tokenizer and get vocab size - tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") - vocab_size = tokenizer.vocab_size - # Load the FineWeb dataset in streaming mode - fw = load_dataset( - "HuggingFaceFW/fineweb-edu", - name="sample-10BT", - split=split, - streaming=True, - cache_dir=cache_dir) - fw = fw.batch(batch_size=batch_size, drop_last_batch=True) - if split in ['train', 'eval_train']: - fw = fw.shuffle(seed=int(data_rng[-1])) - - def _tokenize(x): - """Tokenize and pad text to seq_len+1 tokens.""" - if framework == "torch": - tokens = tokenizer(x, return_tensors="pt")["input_ids"].squeeze() - pad_length = seq_len - tokens.shape[0] - if pad_length > 0: - tokens = F.pad(tokens, pad_length, value=tokenizer.pad_token_id) - elif framework == "jax": - tokens = tokenizer(x, return_tensors="jax")["input_ids"].squeeze() - pad_length = seq_len - tokens.shape[0] - if pad_length > 0: - tokens = jnp.pad( - tokens, - pad_length, - mode="constant", - constant_values=tokenizer.pad_token_id) - return tokens[:seq_len + 1] - - def batch_iterator(): - for doc in fw: - if framework == "torch": - token_ids = torch.stack([_tokenize(x) for x in doc['text']]) - # Take first seq_len+1 tokens and convert to one-hot - tokens = F.one_hot(token_ids, num_classes=vocab_size).float() - # Split into input/target - inputs, targets = tokens[:, :-1, :], tokens[:, 1:, :] - inputs, targets = inputs.to("cuda"), targets.to("cuda") - elif framework == "jax": - token_ids = jnp.stack([_tokenize(x) for x in doc['text']]) - tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size) - inputs, targets = tokens[:, :-1], tokens[:, 1:] - inputs, targets = jax.device_put(inputs), jax.device_put(targets) - yield {'inputs': inputs, 'targets': targets} - - return batch_iterator() - - -def get_lm_dataset(data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None): + batched_dataset = dataset.batch(batch_size, drop_remainder=False) + + # tf.data.Dataset.padded.batch pads elements in the batch so we call it + # again with batch_size=1 to pad each element in original batch. + padded_batched_dataset = batched_dataset.padded_batch( + 1, padded_shapes=padded_shapes, padding_values=padding_id + ) + + # Remove extra dimension resulting from the batch_size=1. 
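  # As a toy illustration of the trick (assumed values, not part of the pipeline):
  # tf.data.Dataset.range(5).batch(2) yields [0 1], [2 3], [4]; padded_batch(1)
  # with padded_shapes=(2,) pads the ragged tail to [4 -1]; unbatch then restores
  # the original rank, so every batch seen downstream has the full batch_size.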
+ padded_batched_dataset = padded_batched_dataset.unbatch() + + return padded_batched_dataset + + +def get_data_iter(data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None,): + + ds = get_lm_dataset(data_rng, split, data_dir, global_batch_size, num_batches) + + it = map( + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + ), + ds, + ) + + return iter(it) + +def get_lm_dataset( + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, +): """Load HF dataset and return a TF dataset.""" - - dataset_path = os.path.join(data_dir, split) - dataset = load_from_disk(dataset_path) - - is_training = split == "train" - shuffle = split in ['train', 'eval_train'] - - dataset.set_format("tensorflow") # tf.int64 # TODO (nico): is this needed? - - def tf_generator(): - """Generates data in a TensorFlow-friendly format.""" - for example in dataset: - yield { - "inputs": example["input_ids"][:-1], - "targets": example["input_ids"][1:], - } - - # Create a TensorFlow dataset - ds = tf.data.Dataset.from_generator( - tf_generator, - output_signature={ - "inputs": tf.TensorSpec(shape=(None,), dtype=tf.int32), - "targets": tf.TensorSpec(shape=(None,), dtype=tf.int32), - }) - - # Avoid creating too many threads when using PyTorch DDP. - # Limits TensorFlow's threading for non-primary processes (RANK != 0) - if RANK != 0: - options = tf.data.Options() - options.threading.private_threadpool_size = 1 - ds = ds.with_options(options) - - if shuffle: - ds = ds.shuffle(buffer_size=1024, seed=data_rng[0]) - - if is_training: - ds = ds.repeat() - - # Batch the dataset, grouping consecutive elements into fixed-size chunks. 
- ds = ds.batch(global_batch_size, drop_remainder=is_training) - ds = ds.prefetch(AUTOTUNE) - - # Limit the dataset to a fixed number of batches if `num_batches` is specified - if num_batches: - ds = ds.take(num_batches) - - # Shard the dataset across multiple GPUs/TPUs if necessary - ds = map( - functools.partial( - data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + if split not in TFDS_SPLIT_NAME: + raise NotImplementedError + + shuffle_seed = jax.random.randint(data_rng, (), -2**31, 2**31-1) + + data_dir = os.path.join(data_dir, TFDS_SPLIT_NAME[split]) + tokens_ds = tf.data.Dataset.load(data_dir) + + # tokens + tokens_ds = tokens_ds.flat_map(tf.data.Dataset.from_tensor_slices) + + # sequences + sequences_ds = tokens_ds.batch(SEQUENCE_LENGTH + 1, drop_remainder=True) + + # get inputs and outputs + sequences_ds = sequences_ds.map( + lambda x: { + 'inputs': x['input_ids'][:SEQUENCE_LENGTH], + 'targets': x['input_ids'][1:], + }, + num_parallel_calls=AUTOTUNE, + ) + + # batch + if split == 'train': + shuffled_sequences_ds = sequences_ds.shuffle( + SHUFFLE_BUFFER_SIZE, seed=shuffle_seed + ) + repeated_sequences_dataset = shuffled_sequences_ds.repeat() + ds = repeated_sequences_dataset.batch( + global_batch_size, drop_remainder=False + ).take(100).prefetch(tf.data.experimental.AUTOTUNE) + elif split == 'eval_train': + ds = batch_with_padding( + sequences_ds, + global_batch_size, + padded_shapes={ + 'inputs': (global_batch_size, None), + 'targets': (global_batch_size, None), + }, + ).take(100).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size of validation + elif split == 'validation': + ds = batch_with_padding( + sequences_ds, + global_batch_size, + padded_shapes={ + 'inputs': (global_batch_size, None), + 'targets': (global_batch_size, None), + }, + ).take(100).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size return ds diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py index d21fd5090..ed469e1bd 100644 --- a/algoperf/workloads/lm/lm_jax/nanodo_model.py +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -3,9 +3,9 @@ import dataclasses from functools import partial -from flax import linen as nn import jax import jax.numpy as jnp +from flax import linen as nn # =========== Transformer Decoder-only Model ========== diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 5401ad240..49547fcef 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -4,16 +4,14 @@ import jax import jax.numpy as jnp -import optax -from flax import jax_utils -from algoperf import param_utils -from algoperf import jax_sharding_utils -from algoperf import spec -from algoperf.workloads.lm.workload import BaseLmWorkload -from algoperf.workloads.lm.lm_jax.models import LinearModel -from algoperf.workloads.lm.input_pipeline import get_hf_dataloader, get_lm_dataset + +from algoperf import jax_sharding_utils, param_utils, spec +from algoperf.workloads.lm.input_pipeline import get_data_iter from algoperf.workloads.lm.lm_jax.nanodo_model import ( - TransformerDo, DoConfig, init_rope, apply_rope) + DoConfig, + TransformerDo, +) +from algoperf.workloads.lm.workload import BaseLmWorkload class LmWorkload(BaseLmWorkload): @@ -28,7 +26,7 @@ def _build_input_queue(self, """Build an input queue using pre-cached FineWeb dataset.""" del num_batches del repeat_final_dataset - loader = get_lm_dataset( + loader = 
get_data_iter( data_rng=data_rng, split=split, data_dir=data_dir, @@ -46,14 +44,8 @@ def _build_hf_input_queue(self, """Build an input queue using HuggingFace FineWeb dataset.""" del num_batches del repeat_final_dataset - loader = get_hf_dataloader( - cache_dir=data_dir, - data_rng=data_rng, - batch_size=global_batch_size, - seq_len=self._seq_len, - framework="jax", - split=split) - return loader + iter = get_data_iter(data_rng, split, data_dir, global_batch_size) + return iter def init_model_fn( self, @@ -63,10 +55,10 @@ def init_model_fn( # Initialize NanoDO transformer model cfg = DoConfig( - D=512, # model dim - H=8, # num heads + D=2048, # model dim + H=16, # num heads L=self._seq_len, - N=6, # num layers + N=12, # num layers V=self._vocab_size, F=2048, # feedforward dim dtype=jnp.float32 @@ -92,7 +84,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] # Convert one-hot inputs to token IDs if needed diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index e5dafdd3c..5797de654 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -63,9 +63,10 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + dropout_rate: None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del model_state, rng, update_batch_norm + del model_state, rng, update_batch_norm, dropout_rate model = params # Convert one-hot inputs to token IDs if needed diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 986a98297..8f17553ff 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -1,21 +1,20 @@ """LM workload parent class.""" import abc +from absl import logging import math import os from typing import Dict, Optional -from absl import flags import jax import torch.distributed as dist +from absl import flags from algoperf import spec -from algoperf.workloads.lm import input_pipeline -from algoperf.workloads.lm.input_pipeline import get_hf_dataloader FLAGS = flags.FLAGS -USE_PYTORCH_DDP = "LOCAL_RANK" in os.environ +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ class BaseLmWorkload(spec.Workload): @@ -63,7 +62,7 @@ def num_eval_train_examples(self) -> int: @property def num_validation_examples(self) -> int: - return 50000 + return 50000 @property def num_test_examples(self) -> int: @@ -111,53 +110,60 @@ def glu(self) -> bool: return True @abc.abstractmethod - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + ): """Build an input queue for the given split.""" - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: + def _eval_batch( + self, + params: 
spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: """Evaluate the model on a single batch.""" logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False, - dropout_rate=None) - + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) + loss_dict = self.loss_fn(batch['targets'], logits) return loss_dict['summed'] - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" num_batches = int(math.ceil(num_examples / global_batch_size)) if split not in self._eval_iters: # These iterators will repeat indefinitely. self._eval_iters[split] = self._build_input_queue( - rng, - split, - data_dir, - global_batch_size, - num_batches, - repeat_final_dataset=True) + rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset=True, + ) loss = 0.0 for _ in range(num_batches): @@ -168,13 +174,15 @@ def _eval_model_on_split(self, mean_loss = loss.item() / num_examples return {'loss': mean_loss} + # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. @abc.abstractmethod def loss_fn( - self, - label_batch: spec.Tensor, - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: """Compute cross-entropy loss for language modeling.""" From 1c3cb6649b26c87e4bd7afd9c83fac84af9372ab Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 6 Oct 2025 22:15:38 +0000 Subject: [PATCH 46/98] add defaults for lm workload --- algoperf/workloads/lm/lm_jax/workload.py | 10 +++++----- algoperf/workloads/lm/workload.py | 6 +++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 49547fcef..76739b590 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -54,13 +54,13 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: # Initialize NanoDO transformer model - cfg = DoConfig( - D=2048, # model dim - H=16, # num heads + cfg = DoConfig(u + D=self._emb_dim, # embedding dim + H=self._n_heads, # num heads L=self._seq_len, - N=12, # num layers + N=self._n_layers, # num layers V=self._vocab_size, - F=2048, # feedforward dim + F=self._mlp_dim, # feedforward dim dtype=jnp.float32 ) self._model = TransformerDo(cfg) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 8f17553ff..5cc783dba 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -21,7 +21,11 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" _vocab_size: int = 50257 - _seq_len: int = 5 + _seq_len: int = 2048 + _emb_dim: 
int = 1024 + _n_heads: int = 8 + _n_layers: int = 12 + _mlp_dim: int = 4096 warmup_factor: float = 0.1 def __init__(self) -> None: From af91b120b2d5bd055f486aabdb3a881e28f3d231 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 7 Oct 2025 01:03:33 +0000 Subject: [PATCH 47/98] refactor eval pipeline and loss fn for lm --- algoperf/workloads/lm/input_pipeline.py | 8 +- algoperf/workloads/lm/lm_jax/workload.py | 92 +++++++++++-------- algoperf/workloads/lm/lm_pytorch/workload.py | 28 ++++-- algoperf/workloads/lm/workload.py | 52 +++++++---- .../external_tuning/jax_nadamw_full_budget.py | 4 +- submission_runner.py | 2 +- 6 files changed, 116 insertions(+), 70 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index e674170e4..91d6ae53c 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -10,13 +10,13 @@ from algoperf import data_utils AUTOTUNE = tf.data.experimental.AUTOTUNE -PAD_ID = -1 +PAD_ID = tf.constant(-1, dtype=tf.int64) TFDS_SPLIT_NAME = {'train': 'train', 'eval_train': 'train', 'validation': 'val'} -SEQUENCE_LENGTH = 2048 +SEQUENCE_LENGTH = 1024 MAX_CORPUS_CHARS = 1_000_000_000 -SHUFFLE_BUFFER_SIZE = 1_000_000 +SHUFFLE_BUFFER_SIZE = 1024 VOCAB_SIZE = 50_257 @@ -74,7 +74,7 @@ def get_lm_dataset( global_batch_size: int, num_batches: Optional[int] = None, ): - """Load HF dataset and return a TF dataset.""" + """Load preprocessed TF dataset.""" if split not in TFDS_SPLIT_NAME: raise NotImplementedError diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 76739b590..c3d84104b 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -1,9 +1,11 @@ """LM workload implemented in Jax.""" -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import jax import jax.numpy as jnp +import optax +from flax.training import common_utils from algoperf import jax_sharding_utils, param_utils, spec from algoperf.workloads.lm.input_pipeline import get_data_iter @@ -54,7 +56,7 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: # Initialize NanoDO transformer model - cfg = DoConfig(u + cfg = DoConfig( D=self._emb_dim, # embedding dim H=self._n_heads, # num heads L=self._seq_len, @@ -84,7 +86,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] # Convert one-hot inputs to token IDs if needed @@ -93,41 +95,58 @@ def model_fn( logits = self._model.apply({'params': params}, inputs) return logits, None - def loss_fn( - self, - label_batch: spec.Tensor, - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: - """Compute cross-entropy loss for language modeling in JAX.""" - # Convert one-hot labels to token IDs if needed - if len(label_batch.shape) == len(logits_batch.shape): # one-hot - label_batch = jnp.argmax(label_batch, axis=-1) - - # Reshape for sequence modeling - logits = logits_batch.reshape(-1, logits_batch.shape[-1]) - labels = label_batch.reshape(-1) - - # Compute cross-entropy loss - loss = -jnp.sum( - jax.nn.log_softmax(logits)[jnp.arange(labels.shape[0]), labels]) - - if 
mask_batch is not None: - mask = mask_batch.reshape(-1) - loss = loss * mask - n_valid = mask.sum() - else: - n_valid = labels.shape[0] + + def compute_weighted_cross_entropy( + self, + logits: spec.Tensor, + targets: spec.Tensor, + weights: Optional[spec.Tensor] = None, + label_smoothing: float = 0.1, + ) -> Dict[str, spec.Tensor]: # differentiable + """Compute weighted cross entropy and entropy for log probs and targets. + + Args: + logits: [batch, length, num_classes] float array. + targets: categorical targets [batch, length] int array. + weights: array of shape [batch, length]. + label_smoothing: label smoothing constant, used to determine the on and off + values. + + Returns: + {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + """ + if logits.ndim != targets.ndim + 1: + raise ValueError( + f'Incorrect shapes. Got shape {logits.shape} logits and ' + f'{targets.shape} targets.' + ) + smoothed_targets = optax.smooth_labels( + common_utils.onehot(targets, self._vocab_size), label_smoothing + ) + per_example_losses = -jnp.sum( + smoothed_targets * jax.nn.log_softmax(logits), axis=-1 + ) + if weights is None: + weights = jnp.ones_like(targets) + per_example_losses = jnp.where(weights, per_example_losses, 0.0) + summed_loss = per_example_losses.sum() + n_valid_examples = weights.sum() return { - 'summed': loss, - 'n_valid_examples': n_valid, - 'per_example': loss / n_valid # Return per-token loss + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } - def is_output_params(self, param_name: str) -> bool: - """Return whether the given parameter is an output parameter.""" - return param_name.contains('output') + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) + def _eval_batch(self, params: spec.ParameterContainer, @@ -140,5 +159,6 @@ def _eval_batch(self, targets = batch['targets'] # Calculate cross-entropy loss - loss = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1)) - return loss + # TODO(kasimbeg): add weights? 
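    # One padding-aware option (sketch, assuming eval batches carry the 0/1
    # `weights` mask produced by the input pipeline):
    #   loss_metrics = self.compute_weighted_cross_entropy(
    #       logits, targets, weights=batch.get('weights'))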
+ loss_metrics = self.compute_weighted_cross_entropy(logits, targets) + return loss_metrics diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 5797de654..ddf99204d 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -1,18 +1,19 @@ """LM workload implemented in PyTorch.""" -from typing import Dict, Iterator, Optional, Tuple +from itertools import islice +from typing import Any, Dict, Iterator, Optional, Tuple import jax import torch import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from itertools import islice -from algoperf import data_utils -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec + +from algoperf import data_utils, param_utils, pytorch_utils, spec +from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( + ModelConfig, + Transformer, +) from algoperf.workloads.lm.workload import BaseLmWorkload -from algoperf.workloads.lm.lm_pytorch.plainlm_model import Transformer, ModelConfig USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -153,6 +154,7 @@ def _eval_batch(self, reduction='sum' ) return loss + def loss_fn( self, label_batch: spec.Tensor, @@ -181,3 +183,15 @@ def loss_fn( 'n_valid_examples': n_valid, 'per_example': loss } + +def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + if USE_PYTORCH_DDP: + for metric in total_metrics.values(): + dist.all_reduce(metric) + total_metrics = {k: v.item() for k, v in total_metrics.items()} + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) \ No newline at end of file diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 5cc783dba..b1fa3d2a8 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -1,13 +1,11 @@ """LM workload parent class.""" import abc -from absl import logging import math import os -from typing import Dict, Optional +from typing import Any, Dict, Optional import jax -import torch.distributed as dist from absl import flags from algoperf import spec @@ -21,7 +19,7 @@ class BaseLmWorkload(spec.Workload): """LM workload.""" _vocab_size: int = 50257 - _seq_len: int = 2048 + _seq_len: int = 1024 _emb_dim: int = 1024 _n_heads: int = 8 _n_layers: int = 12 @@ -169,24 +167,38 @@ def _eval_model_on_split( repeat_final_dataset=True, ) - loss = 0.0 + eval_metrics = {} for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - loss += self._eval_batch(params, eval_batch, model_state, rng) - if USE_PYTORCH_DDP: - dist.all_reduce(loss) - mean_loss = loss.item() / num_examples - return {'loss': mean_loss} + metrics = self._eval_batch(params, eval_batch) + for metric_name, metric_value in metrics.items(): + if metric_name not in eval_metrics: + eval_metrics[metric_name] = 0.0 + eval_metrics[metric_name] += metric_value + eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) + + return eval_results - - # Does NOT apply regularization, which is left to the submitter to do in - # `update_params`. 
@abc.abstractmethod + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + def loss_fn( - self, - label_batch: spec.Tensor, - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0, - ) -> Dict[str, spec.Tensor]: - """Compute cross-entropy loss for language modeling.""" + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling in JAX.""" + return self.compute_weighted_cross_entropy( + logits_batch, + label_batch, + weights=mask_batch, + label_smoothing=label_smoothing + ) + + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return param_name.contains('output') \ No newline at end of file diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index 6e40cdab1..9b4192de2 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -11,7 +11,7 @@ Tuple, Union, ) - +from absl import logging # isort: on import chex import jax @@ -395,7 +395,7 @@ def get_batch_size(workload_name): elif workload_name == 'wmt': return 128 elif workload_name == 'lm': - return 128 + return 64 elif workload_name == 'mnist': return 16 else: diff --git a/submission_runner.py b/submission_runner.py index 1c51ec58f..64a67e781 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -53,7 +53,7 @@ # Environment variables os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Disables tensorRT, cuda warnings. # disable only for deepspeech if it works fine for other workloads -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false --xla_dump_to=/logs/xla_dump_jax_lm_10_06_bsz64_seq1028 --xla_dump_hlo_as_proto' # TODO(znado): make a nicer registry of workloads that lookup in. 
BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR From 6b55adf5a65184d09d62a734db8fd3b6c33fdce2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 7 Oct 2025 03:41:09 +0000 Subject: [PATCH 48/98] refactor evaluation pipeline for lm --- algoperf/workloads/lm/input_pipeline.py | 15 ++++++++++--- algoperf/workloads/lm/lm_jax/workload.py | 28 +++++++----------------- algoperf/workloads/lm/workload.py | 26 ++++++++++++---------- 3 files changed, 35 insertions(+), 34 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 91d6ae53c..3a2e46923 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -5,6 +5,7 @@ from typing import Optional import jax +import numpy as np import tensorflow as tf from algoperf import data_utils @@ -106,7 +107,7 @@ def get_lm_dataset( repeated_sequences_dataset = shuffled_sequences_ds.repeat() ds = repeated_sequences_dataset.batch( global_batch_size, drop_remainder=False - ).take(100).prefetch(tf.data.experimental.AUTOTUNE) + ).prefetch(tf.data.experimental.AUTOTUNE) elif split == 'eval_train': ds = batch_with_padding( sequences_ds, @@ -115,7 +116,11 @@ def get_lm_dataset( 'inputs': (global_batch_size, None), 'targets': (global_batch_size, None), }, - ).take(100).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size of validation + ) + ds = ds.map(lambda x: {'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) + ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size of validation elif split == 'validation': ds = batch_with_padding( sequences_ds, @@ -124,6 +129,10 @@ def get_lm_dataset( 'inputs': (global_batch_size, None), 'targets': (global_batch_size, None), }, - ).take(100).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size + ) + ds = ds.map(lambda x: {'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) + ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size return ds diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index c3d84104b..bb19d6c30 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -28,26 +28,13 @@ def _build_input_queue(self, """Build an input queue using pre-cached FineWeb dataset.""" del num_batches del repeat_final_dataset - loader = get_data_iter( + ds = get_data_iter( data_rng=data_rng, split=split, data_dir=data_dir, global_batch_size=global_batch_size) - loader = map(jax_sharding_utils.shard_along_batch_dim, loader) - return loader - - def _build_hf_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): - """Build an input queue using HuggingFace FineWeb dataset.""" - del num_batches - del repeat_final_dataset - iter = get_data_iter(data_rng, split, data_dir, global_batch_size) - return iter + ds = map(jax_sharding_utils.shard_along_batch_dim, ds) + return ds def init_model_fn( self, @@ -156,9 +143,10 @@ def _eval_batch(self, """Evaluate the model on a single batch.""" logits, _ = self.model_fn( params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - targets = batch['targets'] - # Calculate cross-entropy loss # TODO(kasimbeg): add weights? 
- loss_metrics = self.compute_weighted_cross_entropy(logits, targets) - return loss_metrics + metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) + return { + 'loss': metrics['summed'], + 'denominator': metrics['n_valid_examples'], + } diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index b1fa3d2a8..b8e1ea144 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -2,6 +2,7 @@ import abc import math +import numpy as np import os from typing import Any, Dict, Optional @@ -44,11 +45,11 @@ def validation_target_value(self) -> float: return 20.0 # Target perplexity def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: - return eval_result['test/ppl'] <= self.test_target_value + return True # No test targets @property def test_target_value(self) -> float: - return 20.0 # Target perplexity + return None # No test targets @property def loss_type(self) -> spec.LossType: @@ -60,19 +61,19 @@ def num_train_examples(self) -> int: @property def num_eval_train_examples(self) -> int: - return 10000 # Subset for evaluation + return 500 # Subset for evaluation. # TODO(kasimbeg): update @property def num_validation_examples(self) -> int: - return 50000 + return 500 # TODO(kasimbeg update) @property def num_test_examples(self) -> int: - return 50000 + return 0 @property def eval_batch_size(self) -> int: - return 8 + return 32 @property def train_mean(self): @@ -84,7 +85,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 3600 * 4 # 4 hours + return 3600 * 5 # 4 hours @property def eval_period_time_sec(self) -> int: @@ -93,7 +94,7 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: """Approx. steps the baseline can do in the allowed runtime budget.""" - return 7000 + return 54000 @property def pre_ln(self) -> bool: @@ -141,7 +142,7 @@ def _eval_batch( ) loss_dict = self.loss_fn(batch['targets'], logits) - return loss_dict['summed'] + return loss_dict def _eval_model_on_split( self, @@ -170,12 +171,15 @@ def _eval_model_on_split( eval_metrics = {} for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - metrics = self._eval_batch(params, eval_batch) + metrics = self._eval_batch(params, eval_batch, model_state, rng) for metric_name, metric_value in metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) + + eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) + eval_results['ppl'] = np.exp(eval_results['loss']) + print(eval_results) return eval_results From 210d671fe7e78502cf321a52c0dfcafe6fa3580c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 7 Oct 2025 03:43:42 +0000 Subject: [PATCH 49/98] remove temporary flag for hlo dumps --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 64a67e781..1c51ec58f 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -53,7 +53,7 @@ # Environment variables os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Disables tensorRT, cuda warnings. 
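# --xla_dump_to writes every compiled HLO module to the given directory, and
# --xla_dump_hlo_as_proto stores those dumps as protos instead of text; to
# reproduce such a dump locally one could, for example, append:
#   os.environ['XLA_FLAGS'] += ' --xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_proto'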
# disable only for deepspeech if it works fine for other workloads -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false --xla_dump_to=/logs/xla_dump_jax_lm_10_06_bsz64_seq1028 --xla_dump_hlo_as_proto' +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' # TODO(znado): make a nicer registry of workloads that lookup in. BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR From 0ad7788302fdc8c5ea22379a0f15c047f75988af Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 7 Oct 2025 03:45:45 +0000 Subject: [PATCH 50/98] fix in workload target condition check --- algoperf/workloads/lm/workload.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index b8e1ea144..374b91ce6 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -38,7 +38,7 @@ def target_metric_name(self) -> str: return 'ppl' def has_reached_validation_target(self, eval_result: float) -> bool: - return eval_result['validation/ppl'] > self.validation_target_value + return eval_result['validation/ppl'] < self.validation_target_value @property def validation_target_value(self) -> float: @@ -178,9 +178,7 @@ def _eval_model_on_split( eval_metrics[metric_name] += metric_value eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) - eval_results['ppl'] = np.exp(eval_results['loss']) - print(eval_results) - + eval_results['ppl'] = np.exp(eval_results['loss']) return eval_results @abc.abstractmethod From 01921d5f6d0068e1d92808ad224b50ab19b60b15 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 8 Oct 2025 23:36:28 +0000 Subject: [PATCH 51/98] fix in mlp for glu --- algoperf/workloads/lm/lm_jax/nanodo_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py index ed469e1bd..bd7213620 100644 --- a/algoperf/workloads/lm/lm_jax/nanodo_model.py +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -44,6 +44,10 @@ def __call__(self, x_BxLxD: jax.Array): linear = partial( nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype ) + # Adjust hidden dimension to keep the number of parameters invariant to + # the activation function used since the GLU MLP has 3 * hidden_dim * D + # parameters instead of 2 * hidden_dim * D parameters + hidden_dim = cfg.F * 2 / 3 hidden_dim = cfg.multiple_of * ( (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of ) From e42045083c1d28aba5fa5dd15f6993d4a8312880 Mon Sep 17 00:00:00 2001 From: rka97 Date: Fri, 10 Oct 2025 04:14:40 +0000 Subject: [PATCH 52/98] Fix OOM error in weighted cross entropy calculation --- algoperf/workloads/lm/lm_jax/workload.py | 44 +++++++++++-------- .../workloads/lm/lm_pytorch/plainlm_model.py | 2 +- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index bb19d6c30..c052794c8 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -84,21 +84,19 @@ def model_fn( def compute_weighted_cross_entropy( - self, - logits: spec.Tensor, - targets: spec.Tensor, - weights: Optional[spec.Tensor] = None, - label_smoothing: float = 0.1, - ) -> Dict[str, spec.Tensor]: # differentiable + self, + logits: spec.Tensor, + targets: spec.Tensor, + weights: Optional[spec.Tensor] = None, + label_smoothing: float = 0.1, + ) -> Dict[str, spec.Tensor]: # differentiable """Compute weighted cross entropy 
and entropy for log probs and targets. - Args: logits: [batch, length, num_classes] float array. targets: categorical targets [batch, length] int array. weights: array of shape [batch, length]. label_smoothing: label smoothing constant, used to determine the on and off values. - Returns: {'summed': scalar summed loss, 'n_valid_examples': scalar number of valid examples in batch, 'per_example': 1-d array of per-example losses} @@ -108,18 +106,26 @@ def compute_weighted_cross_entropy( f'Incorrect shapes. Got shape {logits.shape} logits and ' f'{targets.shape} targets.' ) - smoothed_targets = optax.smooth_labels( - common_utils.onehot(targets, self._vocab_size), label_smoothing - ) - - per_example_losses = -jnp.sum( - smoothed_targets * jax.nn.log_softmax(logits), axis=-1 - ) - if weights is None: - weights = jnp.ones_like(targets) - per_example_losses = jnp.where(weights, per_example_losses, 0.0) + # Compute log probabilities + log_probs = jax.nn.log_softmax(logits, axis=-1) + # Extract log probability of the target class + # Shape: [batch, length] + target_log_probs = jnp.take_along_axis( + log_probs, + targets[..., None], + axis=-1 + ).squeeze(-1) + # Cross-entropy with smoothing: -(1 - α) * log_p[target] - α * mean(log_p) + # The above formula is easy to derive from the definition of label smoothing and cross-entropy loss. + confidence = 1.0 - label_smoothing + smoothing_term = label_smoothing / self._vocab_size + per_example_losses = -1.0 * (confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1)) + if weights is not None: + per_example_losses = jnp.where(weights, per_example_losses, 0.0) + n_valid_examples = weights.sum() + else: + n_valid_examples = targets.shape[0] * targets.shape[1] summed_loss = per_example_losses.sum() - n_valid_examples = weights.sum() return { 'summed': summed_loss, 'n_valid_examples': n_valid_examples, diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py index 627a0e16d..225b98767 100644 --- a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -16,7 +16,7 @@ class ModelConfig: n_layers: int n_heads: int rmsnorm_eps: float = 1e-6 - tie_embeddings: bool = False + tie_embeddings: bool = True class MLP(nn.Module): From 3b31ad521d0037f80391de31582517cc291877be Mon Sep 17 00:00:00 2001 From: rka97 Date: Fri, 10 Oct 2025 04:15:27 +0000 Subject: [PATCH 53/98] fix issue with checkpointing bool --- algoperf/checkpoint_utils.py | 47 ++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index 2c8441d9c..00f05ba5d 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -5,7 +5,7 @@ """ import os -from typing import Sequence, Tuple +from typing import Sequence, Tuple, Optional import numpy as np import torch @@ -14,7 +14,8 @@ from flax.training import checkpoints as flax_checkpoints from flax.training.checkpoints import latest_checkpoint from tensorflow.io import gfile # pytype: disable=import-error - +import orbax.checkpoint as ocp +from orbax.checkpoint.type_handlers import NumpyHandler from algoperf import spec from algoperf.pytorch_utils import pytorch_setup @@ -29,6 +30,48 @@ int, ] +class BoolHandler(NumpyHandler): + """ + An implementation of TypeHandler for np.bool_ that inherits from NumpyHandler. + It works by treating the scalar as a 0-dimensional array. 
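  For example, np.asarray(np.True_) is a shape-() bool array that the parent
  NumpyHandler can serialize as-is; on restore, np.bool_(arr.item()) recovers
  the original scalar.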
+ """ + + def typestr(self) -> str: + """Unique string identifier for this handler.""" + return 'np.bool_' + + async def serialize( + self, + values: Sequence[np.bool_], + infos: Sequence, + args: Optional[Sequence[ocp.SaveArgs]] = None, + ): + """ + Serializes a sequence of np.bool_ scalars by first converting them + to 0-dim numpy arrays and then calling the parent NumpyHandler. + """ + # Convert each scalar np.bool_ to a 0-dimensional np.ndarray + array_values = [np.asarray(v, dtype=np.bool_) for v in values] + # Use the parent class's robust serialization logic + return await super().serialize(array_values, infos, args) + + async def deserialize( + self, + infos: Sequence, + args: Optional[Sequence[ocp.RestoreArgs]] = None, + ) -> Sequence[np.bool_]: + """ + Deserializes into a sequence of np.bool_ scalars by calling the + parent handler and then converting the resulting 0-dim arrays. + """ + # Parent deserialize will return a sequence of 0-dimensional np.ndarray + results = await super().deserialize(infos, args) + + # Convert each 0-d array back to an np.bool_ scalar using .item() + scalar_results = [np.bool_(r.item()) for r in results] + return scalar_results + +ocp.type_handlers.register_type_handler(np.bool_, BoolHandler(), override=True) def maybe_restore_checkpoint( framework: str, From bbc114fe730e351d3a721d78f6165f343e4c25cb Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 10 Oct 2025 04:33:15 +0000 Subject: [PATCH 54/98] increase buffer size --- algoperf/workloads/lm/input_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 3a2e46923..2fd27113a 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -17,7 +17,7 @@ SEQUENCE_LENGTH = 1024 MAX_CORPUS_CHARS = 1_000_000_000 -SHUFFLE_BUFFER_SIZE = 1024 +SHUFFLE_BUFFER_SIZE = 100_000 VOCAB_SIZE = 50_257 From 2b162e8d87603ad7ae2ac5020a26fd8c2bce974d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 10 Oct 2025 04:42:19 +0000 Subject: [PATCH 55/98] remove _eval_batch from jax workload --- algoperf/workloads/lm/lm_jax/workload.py | 17 ----------- algoperf/workloads/lm/workload.py | 36 +++++++++++------------- 2 files changed, 17 insertions(+), 36 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index c052794c8..801b1e0b4 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -139,20 +139,3 @@ def _normalize_eval_metrics( del num_examples eval_denominator = total_metrics.pop('denominator') return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) - - - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: - """Evaluate the model on a single batch.""" - logits, _ = self.model_fn( - params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - # Calculate cross-entropy loss - # TODO(kasimbeg): add weights? 
- metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) - return { - 'loss': metrics['summed'], - 'denominator': metrics['n_valid_examples'], - } diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 374b91ce6..f5d2cda38 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -124,25 +124,6 @@ def _build_input_queue( ): """Build an input queue for the given split.""" - def _eval_batch( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - ) -> spec.Tensor: - """Evaluate the model on a single batch.""" - logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False, - ) - - loss_dict = self.loss_fn(batch['targets'], logits) - return loss_dict def _eval_model_on_split( self, @@ -181,6 +162,23 @@ def _eval_model_on_split( eval_results['ppl'] = np.exp(eval_results['loss']) return eval_results + + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + # Calculate cross-entropy loss + metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) + return { + 'loss': metrics['summed'], + 'denominator': metrics['n_valid_examples'], + } + + @abc.abstractmethod def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any] From 617e1a3f3810bb73f15d998c25e54fa79ef04315 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 10 Oct 2025 04:45:44 +0000 Subject: [PATCH 56/98] add todo for pytorch _eval_batch cleanup --- algoperf/workloads/lm/lm_pytorch/workload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index ddf99204d..71a8afd93 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -148,6 +148,7 @@ def _eval_batch(self, if targets.dim() == 3: # one-hot loss = -torch.sum(targets * torch.nn.functional.log_softmax(logits, dim=-1)) else: # token IDs + # TODO(kasimbeg): before deleting make sure we have defined self.weighted_cross_entropy so that we can call the shared workload _eval_batch. 
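      # A torch counterpart to the JAX helper could look roughly like this
      # (sketch only; a `weights` tensor marking non-pad tokens is assumed):
      #   per_example = torch.nn.functional.cross_entropy(
      #       logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
      #   per_example = per_example * weights.view(-1)
      #   summed, n_valid = per_example.sum(), weights.sum()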
loss = torch.nn.functional.cross_entropy( logits.view(-1, logits.size(-1)), targets.view(-1), From 64ea658c04a2d13db75ab0b8fd1204cfe43f8746 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Oct 2025 01:34:05 +0000 Subject: [PATCH 57/98] add target setting algorithm for fineweb edu lm workload --- .../jax_nadamw_target_setting.py | 427 ++++++++++++++++++ .../fineweb_edu_lm/tuning_search_space.json | 11 + 2 files changed, 438 insertions(+) create mode 100644 algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py create mode 100644 algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py new file mode 100644 index 000000000..9fa6823d5 --- /dev/null +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py @@ -0,0 +1,427 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +# isort: on +import chex +import jax +import jax.numpy as jnp +import optax + +from algoperf import jax_sharding_utils, spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. 
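  A minimal usage sketch (illustrative values only):
    tx = nadamw(learning_rate=1e-3, weight_decay=1e-2)
    opt_state = tx.init(params)
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)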
+ """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + step_hint = 0.75 * step_hint + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) + return schedule_fn + + # Create optimizer + LR schedule. + lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) + optimizer_state = opt_init_fn(params_zeros_like) + + return optimizer_state, opt_update_fn + + +def train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, + dropout_rate, +): + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + dropout_rate=dropout_rate, + ) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container + ) + # Compute mean loss and grad + loss = summed_loss / n_valid_examples + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + 
current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + dropout_rate = hyperparameters.dropout_rate + + # Create shardings for each argument + replicated = jax_sharding_utils.get_replicate_sharding() # No partitioning + sharded = ( + jax_sharding_utils.get_batch_dim_sharding() + ) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated, # label_smoothing + replicated, # dropout_rate + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated, # grad_norm + ) + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings, + ) + + new_optimizer_state, new_params, new_model_state, loss, grad_norm = ( + jitted_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, + dropout_rate, + ) + ) + + # Log loss, grad_norm. + if global_step % 1 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step + ) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. 
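  # An equivalent table-driven form (sketch, same values as the branches below):
  #   _BATCH_SIZES = {'criteo1tb': 262_144, 'fastmri': 32, 'lm': 64, 'mnist': 16}
  #   return _BATCH_SIZES[workload_name]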
+ if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'lm': + return 64 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json new file mode 100644 index 000000000..e6945d69a --- /dev/null +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json @@ -0,0 +1,11 @@ +[ + { + "dropout_rate": 0.0, + "label_smoothing": 0.1, + "learning_rate": 0.0003955553491092581, + "one_minus_beta1": 0.06124602712, + "beta2": 0.9535169492059872, + "weight_decay": 0.03268700808664715, + "warmup_factor": 0.0375 + } +] \ No newline at end of file From b38ade083282348a5000220bf3ca11f79b5c9e9a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Oct 2025 01:34:49 +0000 Subject: [PATCH 58/98] update step hint for lm workload --- algoperf/workloads/lm/input_pipeline.py | 2 +- algoperf/workloads/lm/workload.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 2fd27113a..04bd90216 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -17,7 +17,7 @@ SEQUENCE_LENGTH = 1024 MAX_CORPUS_CHARS = 1_000_000_000 -SHUFFLE_BUFFER_SIZE = 100_000 +SHUFFLE_BUFFER_SIZE = 1000 VOCAB_SIZE = 50_257 diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index f5d2cda38..b9610f919 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -57,7 +57,7 @@ def loss_type(self) -> spec.LossType: @property def num_train_examples(self) -> int: - return 1000000 # Example size + return 8_749_870 # sequences of 1024 tokens each @property def num_eval_train_examples(self) -> int: @@ -94,7 +94,7 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: """Approx. 
steps the baseline can do in the allowed runtime budget.""" - return 54000 + return 72000 @property def pre_ln(self) -> bool: @@ -159,7 +159,7 @@ def _eval_model_on_split( eval_metrics[metric_name] += metric_value eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) - eval_results['ppl'] = np.exp(eval_results['loss']) + eval_results['ppl'] = np.exp(eval_results['loss']).item() return eval_results From 65369f239a3110748890473cef415dcb087fe6c0 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Oct 2025 01:36:42 +0000 Subject: [PATCH 59/98] update target --- algoperf/workloads/lm/workload.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index b9610f919..0bed0b34d 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -38,11 +38,11 @@ def target_metric_name(self) -> str: return 'ppl' def has_reached_validation_target(self, eval_result: float) -> bool: - return eval_result['validation/ppl'] < self.validation_target_value + return eval_result['validation/ppl'] <= self.validation_target_value @property def validation_target_value(self) -> float: - return 20.0 # Target perplexity + return 25.5477 # Target perplexity def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: return True # No test targets @@ -73,7 +73,7 @@ def num_test_examples(self) -> int: @property def eval_batch_size(self) -> int: - return 32 + return 64 @property def train_mean(self): @@ -85,16 +85,16 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 3600 * 5 # 4 hours + return 3600 * 5 # 4 hours TODO(kasimbeg): update @property def eval_period_time_sec(self) -> int: - return 600 # 10 minutes + return 600 # 10 minutes TODO(kasimbeg): update @property def step_hint(self) -> int: """Approx. 
steps the baseline can do in the allowed runtime budget.""" - return 72000 + return 72_000 @property def pre_ln(self) -> bool: From 6171b2d2fb6a0243993b10d03f0c284eb2c86801 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Oct 2025 23:04:56 +0000 Subject: [PATCH 60/98] update eval split sizes for lm workload and target setting point --- algoperf/workloads/lm/input_pipeline.py | 4 ++-- algoperf/workloads/lm/workload.py | 8 ++++---- .../fineweb_edu_lm/jax_nadamw_target_setting.py | 4 ++-- .../fineweb_edu_lm/tuning_search_space.json | 12 ++++++------ 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 04bd90216..79fdfbbcb 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -120,7 +120,7 @@ def get_lm_dataset( ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) - ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size of validation + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) elif split == 'validation': ds = batch_with_padding( sequences_ds, @@ -133,6 +133,6 @@ def get_lm_dataset( ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) - ds = ds.take(1000).prefetch(tf.data.experimental.AUTOTUNE) # todo(kasimbeg): set final size + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) return ds diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 0bed0b34d..466769d96 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -61,11 +61,11 @@ def num_train_examples(self) -> int: @property def num_eval_train_examples(self) -> int: - return 500 # Subset for evaluation. # TODO(kasimbeg): update + return 10_000 # Subset for evaluation. @property def num_validation_examples(self) -> int: - return 500 # TODO(kasimbeg update) + return 100_000 # sequences @property def num_test_examples(self) -> int: @@ -85,11 +85,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 3600 * 5 # 4 hours TODO(kasimbeg): update + return 3600 * 14 # 14 hours TODO(kasimbeg): update @property def eval_period_time_sec(self) -> int: - return 600 # 10 minutes TODO(kasimbeg): update + return 1200 # 20 minutes TODO(kasimbeg): update @property def step_hint(self) -> int: diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py index 9fa6823d5..1fef611ac 100644 --- a/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py @@ -170,8 +170,8 @@ def init_optimizer_state( del rng def jax_cosine_warmup(step_hint: int, hyperparameters): - # Create learning rate schedule. step_hint = 0.75 * step_hint + # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( init_value=0.0, @@ -343,7 +343,7 @@ def update_params( ) # Log loss, grad_norm. 
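As a rough standalone illustration of the schedule that jax_cosine_warmup builds (peak learning rate and warmup factor are assumed here, loosely matching the tuning point updated later in this patch): the rate warms up linearly over warmup_factor * (0.75 * step_hint) steps, then cosine-decays toward zero over the remaining steps.

import optax

step_hint = int(0.75 * 72_000)                 # scaled step hint, as above
warmup_steps = int(0.05 * step_hint)           # warmup_factor assumed to be 0.05
peak_lr = 3.8e-4                               # assumed peak learning rate
schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, peak_lr, warmup_steps),
        optax.cosine_decay_schedule(peak_lr, max(step_hint - warmup_steps, 1)),
    ],
    boundaries=[warmup_steps],
)
for step in (0, warmup_steps, step_hint):
    print(step, float(schedule(step)))         # 0 -> peak -> ~0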
- if global_step % 1 == 0 and workload.metrics_logger is not None: + if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step ) diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json index e6945d69a..ce0f75623 100644 --- a/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json @@ -1,11 +1,11 @@ [ { "dropout_rate": 0.0, - "label_smoothing": 0.1, - "learning_rate": 0.0003955553491092581, - "one_minus_beta1": 0.06124602712, - "beta2": 0.9535169492059872, - "weight_decay": 0.03268700808664715, - "warmup_factor": 0.0375 + "label_smoothing": 0.0, + "learning_rate": 0.00038418421332238876, + "one_minus_beta1": 0.01564758865, + "beta2": 0.992362328914093, + "weight_decay": 0.25551270901641954, + "warmup_factor": 0.05 } ] \ No newline at end of file From d7a885cd7270dfbd8203f41276c3313ddbd63929 Mon Sep 17 00:00:00 2001 From: rka97 Date: Fri, 17 Oct 2025 04:01:11 +0000 Subject: [PATCH 61/98] Porting workload input pipeline to torch - Added `limit_tf_threads` parameter to `pytorch_init` to control TensorFlow threading based on workload type. Dataloader was going OOM otherwise. - Updated input pipeline to support "None" for weights (for memory). - Modified Transformer model's `forward` method to optionally return loss during training. Should be better to fuse the loss later. - Adjusted torch LM workload configuration for model dimensions and parameters to match jax. - Updated transformers version in `pyproject.toml`, older version seems unavailable. --- algoperf/pytorch_utils.py | 6 +- algoperf/workloads/lm/input_pipeline.py | 8 +- .../workloads/lm/lm_pytorch/plainlm_model.py | 74 ++++++----- algoperf/workloads/lm/lm_pytorch/workload.py | 118 ++++++------------ pyproject.toml | 2 +- submission_runner.py | 3 +- 6 files changed, 90 insertions(+), 121 deletions(-) diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index af09e67fc..c7537a884 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -27,7 +27,7 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: return use_pytorch_ddp, rank, device, n_gpus -def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: +def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_threads = True) -> None: # Make sure no GPU memory is preallocated to Jax. os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # Only use CPU for Jax to avoid memory issues. @@ -39,7 +39,7 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: if use_pytorch_ddp: # Avoid tf input pipeline creating too many threads. - if rank != 0: + if rank != 0 and limit_tf_threads: tf.config.threading.set_intra_op_parallelism_threads(1) tf.config.threading.set_inter_op_parallelism_threads(1) @@ -47,10 +47,8 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: profiler.set_local_rank(rank) # Only log once (for local rank == 0). if rank != 0: - def logging_pass(*args): pass - logging.info = logging_pass # Initialize the process group. 
dist.init_process_group('nccl') diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 04bd90216..ee54427e1 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -107,7 +107,13 @@ def get_lm_dataset( repeated_sequences_dataset = shuffled_sequences_ds.repeat() ds = repeated_sequences_dataset.batch( global_batch_size, drop_remainder=False - ).prefetch(tf.data.experimental.AUTOTUNE) + ) + ds = ds.map(lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': None, + }) + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) elif split == 'eval_train': ds = batch_with_padding( sequences_ds, diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py index 225b98767..5de5bf310 100644 --- a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -159,7 +159,7 @@ def __init__(self, cfg): if cfg.tie_embeddings: self.tie_weights() - def forward(self, x): + def forward(self, x, targets=None): # x: (bsz, seqlen) x = self.embed_tokens(x) # (bsz, seqlen, dim) L = x.shape[1] @@ -178,7 +178,12 @@ def forward(self, x): for layer in self.layers: x = layer(x, freqs_cis) # (bsz, seqlen, dim) - return self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + out = self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + if targets is not None: + loss = F.cross_entropy( + out.view(-1, out.size(-1)), targets.view(-1), ignore_index=-100) + return out, loss + return out def predict(self, x, k=1): """Generate k tokens autoregressively. @@ -190,11 +195,6 @@ def predict(self, x, k=1): Returns: Tuple of (input_ids, predicted_ids) """ - # For debugging - predictions = [] - - batch_size = x.shape[0] - seq_len = x.shape[1] # Store original input original_input = x.clone() @@ -202,6 +202,7 @@ def predict(self, x, k=1): # Generate k tokens autoregressively for i in range(k): + # Get logits for the entire sequence logits = self(generated_input) @@ -212,24 +213,20 @@ def predict(self, x, k=1): # This is a common issue - the model gets stuck repeating the last token last_token_id = generated_input[:, -1] next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) - - # Print top 5 tokens for debugging - if i == 0: - print("\nPyTorch detailed prediction:") - top5_values, top5_indices = torch.topk(next_token_logits[0], 5) - for j, (idx, val) in enumerate(zip(top5_indices.tolist(), top5_values.tolist())): - prob = torch.softmax(next_token_logits[0], dim=-1)[idx].item() - print(f" Top {j+1}: Token {idx}, logit={val:.2f}, prob={prob:.6f}") - + # Get the most likely token next_token = torch.argmax(next_token_logits, dim=-1) - predictions.append(next_token.item()) # Append the predicted token to the sequence next_token = next_token.unsqueeze(1) # Add sequence dimension generated_input = torch.cat([generated_input, next_token], dim=1) - print(f" Full predictions step by step: {predictions}") + # For debugging, print predictions for the first item in the batch + print("\nPyTorch detailed prediction (first item in batch):") + predicted_sequence = generated_input[0, -k:].tolist() + print(f" Predicted token IDs: {predicted_sequence}") + for i, token_id in enumerate(predicted_sequence): + print(f" Step {i+1}: Predicted token {token_id}") # Return all tokens, not just the last k return original_input, generated_input[:, -k:] @@ -269,30 +266,43 @@ def count_params(self, non_embedding=True): def main(): 
print("Initializing transformer model and running forward pass...") - seq_length = 512 + seq_length = 1024 # Define model configuration config = ModelConfig( - vocab_size=32000, # Common vocab size for tokenizers like BPE or SentencePiece + vocab_size=50257, # Common vocab size for tokenizers like BPE or SentencePiece seq_len=seq_length, # Maximum sequence length - dim=768, # Embedding dimension + dim=1024, # Embedding dimension expand=4.0, # MLP expansion factor n_layers=12, # Number of transformer layers - n_heads=12, # Number of attention heads + n_heads=8, # Number of attention heads rmsnorm_eps=1e-6, # RMSNorm epsilon tie_embeddings=True # Tie embedding and output weights ) - def tie_weights(self): - self.lm_head.weight = self.embed_tokens.weight + # Instantiate the model + model = Transformer(config) + print(f"Model has {model.count_params():,} parameters.") - def count_params(self, non_embedding=True): - n_params = sum(p.numel() for p in self.parameters()) - if non_embedding: - n_params -= self.embed_tokens.weight.numel() - if (not self.lm_head.weight - is self.embed_tokens.weight): # if no weight tying - n_params -= self.lm_head.weight.numel() - return n_params + # Create some random input data + batch_size = 2 + input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length)) + + # Move data to the same device as the model + if torch.cuda.is_available(): + input_ids = input_ids.cuda() + + # Run a forward pass + print(f"Running forward pass with input shape: {input_ids.shape}") + logits = model(input_ids) + print(f"Output logits shape: {logits.shape}") + # Run prediction + print("Running prediction...") + original_input, predicted_ids = model.predict(input_ids[:, :10], k=5) + print(f"Original input shape for prediction: {original_input.shape}") + print(f"Predicted IDs shape: {predicted_ids.shape}") + print(f"Predicted IDs: {predicted_ids}") +if __name__ == "__main__": + main() diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 71a8afd93..e4c03c4f5 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -14,6 +14,7 @@ Transformer, ) from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.lm.input_pipeline import get_data_iter USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -37,10 +38,11 @@ def init_model_fn( cfg = ModelConfig( vocab_size=self._vocab_size, seq_len=self._seq_len, - dim=512, # Model dimension - expand=4, # MLP expansion factor - n_layers=6, # Number of transformer layers - n_heads=8, # Number of attention heads + dim=self._emb_dim, # Model dimension + expand=self._mlp_dim // self._emb_dim, # MLP expansion factor + # FIXME(rka97): fix expansion factor + n_layers=self._n_layers, # Number of transformer layers + n_heads=self._n_heads, # Number of attention heads rmsnorm_eps=1e-6, tie_embeddings=True ) @@ -65,7 +67,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state, rng, update_batch_norm, dropout_rate model = params @@ -87,10 +89,8 @@ def _build_input_queue( num_batches: Optional[int] = None, repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: """Build an input queue for the given split.""" - from algoperf.workloads.lm.input_pipeline import get_lm_dataset local_batch_size 
= global_batch_size // N_GPUS - - loader = get_lm_dataset( + loader = get_data_iter( data_rng=data_rng, split=split, data_dir=data_dir, @@ -99,33 +99,12 @@ def _build_input_queue( ) if USE_PYTORCH_DDP: loader = islice(loader, RANK, None, N_GPUS) - seq_len = self._seq_len - weights = None - dtype = torch.int32 - is_train = split == 'train' - for batch in loader: - inputs = batch['inputs'] - targets = batch['targets'] - - if USE_PYTORCH_DDP: - if not is_train: - # During eval, the batch size of the remainder might be different - per_device_batch_size = torch.tensor( - targets.shape[0], dtype=dtype, device=DEVICE) - dist.broadcast(per_device_batch_size, src=0) - local_batch_size = per_device_batch_size.item() - # Broadcast to all devices - #dist.broadcast(inputs, src=0) - #dist.broadcast(targets, src=0) - - if weights is None: - weights = torch.ones((local_batch_size, seq_len), device=DEVICE) batch = { - 'inputs': torch.tensor(inputs, device=DEVICE, dtype=dtype), - 'targets': torch.tensor(targets, device=DEVICE, dtype=dtype), - 'weights': weights, + 'inputs': torch.tensor(batch['inputs'], device=DEVICE, dtype=dtype), + 'targets': torch.tensor(batch['targets'], device=DEVICE, dtype=torch.int64), + 'weights': None, } yield batch @@ -133,66 +112,41 @@ def is_output_params(self, param_name: str) -> bool: """Return whether the given parameter is an output parameter.""" return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: - """Evaluate the model on a single batch.""" - model = params - logits, _ = self.model_fn( - model, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - - # Handle both one-hot and token ID targets - targets = batch['targets'] - if targets.dim() == 3: # one-hot - loss = -torch.sum(targets * torch.nn.functional.log_softmax(logits, dim=-1)) - else: # token IDs - # TODO(kasimbeg): before deleting make sure we have defined self.weighted_cross_entropy so that we can call the shared workload _eval_batch. 
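The DDP branch of _build_input_queue above hands each process every N_GPUS-th batch of the shared stream via itertools.islice; a tiny standalone sketch of that pattern (rank and world size assumed):

from itertools import islice

def make_loader():
  # Stand-in for the per-process iterator of local batches.
  yield from range(12)

rank, n_gpus = 1, 4                              # assumed DDP layout
local_batches = list(islice(make_loader(), rank, None, n_gpus))
print(local_batches)                             # [1, 5, 9]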
- loss = torch.nn.functional.cross_entropy( - logits.view(-1, logits.size(-1)), - targets.view(-1), - reduction='sum' - ) - return loss - - def loss_fn( - self, - label_batch: spec.Tensor, - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + # FIXME(rka97): Implement label smoothing + def compute_weighted_cross_entropy(self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor, label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: """Compute cross-entropy loss for language modeling in PyTorch.""" - vocab_size = logits_batch.shape[-1] + vocab_size = logits.size(-1) - if len(label_batch.shape) == len(logits_batch.shape): + if len(labels.shape) == len(logits.shape): # One-hot labels - log_probs = torch.nn.functional.log_softmax(logits_batch, dim=-1) - loss = -torch.sum(label_batch * log_probs, dim=-1) + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + loss = -torch.sum(labels * log_probs, dim=-1) else: # Dense labels loss = torch.nn.functional.cross_entropy( - logits_batch, - label_batch, + logits.view(-1, vocab_size), + labels.view(-1), reduction='none') - if mask_batch is not None: - loss = loss * mask_batch + loss = loss.view_as(labels) + + if weights is not None: + loss = loss * weights - n_valid = mask_batch.sum() if mask_batch is not None else label_batch.shape[0] + n_valid = weights.sum() if weights is not None else torch.tensor(labels.numel(), dtype=torch.float32, device=labels.device) return { 'summed': loss.sum(), 'n_valid_examples': n_valid, - 'per_example': loss + 'per_example': loss, } -def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, Any] - ) -> Dict[str, float]: - """Normalize eval metrics.""" - del num_examples - if USE_PYTORCH_DDP: - for metric in total_metrics.values(): - dist.all_reduce(metric) - total_metrics = {k: v.item() for k, v in total_metrics.items()} - eval_denominator = total_metrics.pop('denominator') - return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) \ No newline at end of file + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + if USE_PYTORCH_DDP: + for metric in total_metrics.values(): + dist.all_reduce(metric) + total_metrics = {k: v.item() for k, v in total_metrics.items()} + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 76bcfb7ca..b93c9794e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"] -lm = ["transformers==4.25.4", "datasets==3.6.0"] +lm = ["transformers==4.26", "datasets==3.6.0"] # Frameworks jax_core_deps = [ diff --git a/submission_runner.py b/submission_runner.py index 1c51ec58f..1c50cd6d9 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -784,7 +784,8 @@ def main(_): os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' if FLAGS.framework == 'pytorch': - pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + limit_tf_threads = (base_workload != 'lm') + pytorch_init(USE_PYTORCH_DDP, RANK, profiler, limit_tf_threads=limit_tf_threads) # TODO: remove once issue resolved. 
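A compact standalone sketch of the weighted token-level loss that compute_weighted_cross_entropy above implements (shapes and vocabulary size assumed): padding positions get weight 0, which removes them from both the summed loss and the n_valid_examples denominator.

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 11)                  # (batch, seq, vocab), assumed
labels = torch.randint(0, 11, (2, 5))
weights = torch.ones(2, 5)
weights[:, -1] = 0.0                            # e.g. a padded final position
per_token = F.cross_entropy(
    logits.view(-1, 11), labels.view(-1), reduction='none').view_as(labels)
per_token = per_token * weights
print(per_token.sum() / weights.sum())          # mean loss over valid tokens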
if FLAGS.pytorch_eval_num_workers != 0: From 1f0439aaf6bbb7f0670a4dc0564a41c86e509270 Mon Sep 17 00:00:00 2001 From: rka97 Date: Sat, 18 Oct 2025 06:41:33 +0000 Subject: [PATCH 62/98] Fix OOM bug in lm eval --- algoperf/random_utils.py | 4 +-- algoperf/workloads/lm/lm_pytorch/workload.py | 28 ++++++++++++++----- algoperf/workloads/lm/workload.py | 15 +++++++--- .../pytorch_nadamw_full_budget.py | 2 ++ 4 files changed, 36 insertions(+), 13 deletions(-) diff --git a/algoperf/random_utils.py b/algoperf/random_utils.py index 1dc773e80..07efa2bdf 100644 --- a/algoperf/random_utils.py +++ b/algoperf/random_utils.py @@ -35,13 +35,13 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.uint32) return [new_seed, data] def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) + return rng.randint(MIN_INT32, MAX_INT32, dtype=np.uint32, size=[num, 2]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index e4c03c4f5..b2ffac18e 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -1,5 +1,6 @@ """LM workload implemented in PyTorch.""" +import contextlib from itertools import islice from typing import Any, Dict, Iterator, Optional, Tuple @@ -8,7 +9,7 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import data_utils, param_utils, pytorch_utils, spec +from algoperf import param_utils, pytorch_utils, spec from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( ModelConfig, Transformer, @@ -72,12 +73,23 @@ def model_fn( del model_state, rng, update_batch_norm, dropout_rate model = params - # Convert one-hot inputs to token IDs if needed - inputs = augmented_and_preprocessed_input_batch['inputs'] - if inputs.dim() == 3: # one-hot encoded + # Set model to eval or train mode based on the mode parameter + if mode == spec.ForwardPassMode.EVAL: + model.eval() + elif mode == spec.ForwardPassMode.TRAIN: + model.train() + contexts = { + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + } + with contexts[mode](): + # Convert one-hot inputs to token IDs if needed + inputs = augmented_and_preprocessed_input_batch['inputs'] + if inputs.dim() == 3: # one-hot encoded inputs = inputs.argmax(dim=-1) - logits = model(inputs) + logits = model(inputs) + return logits, None def _build_input_queue( @@ -90,12 +102,14 @@ def _build_input_queue( repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: """Build an input queue for the given split.""" local_batch_size = global_batch_size // N_GPUS + # In DDP mode, pass local_device_count=1 to prevent shard_and_maybe_pad_np + # from seeing all GPUs via torch.cuda.device_count() loader = get_data_iter( data_rng=data_rng, split=split, data_dir=data_dir, global_batch_size=local_batch_size, - num_batches=num_batches + num_batches=num_batches, ) if USE_PYTORCH_DDP: loader = islice(loader, RANK, None, N_GPUS) @@ -104,7 +118,7 @@ def _build_input_queue( batch = { 'inputs': torch.tensor(batch['inputs'], 
device=DEVICE, dtype=dtype), 'targets': torch.tensor(batch['targets'], device=DEVICE, dtype=torch.int64), - 'weights': None, + 'weights': torch.tensor(batch['weights'], device=DEVICE, dtype=torch.float32) if batch['weights'] is not None else None, } yield batch diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 466769d96..73e784f3a 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -73,7 +73,7 @@ def num_test_examples(self) -> int: @property def eval_batch_size(self) -> int: - return 64 + return 256 @property def train_mean(self): @@ -138,6 +138,11 @@ def _eval_model_on_split( ) -> Dict[str, float]: """Run a full evaluation of the model.""" num_batches = int(math.ceil(num_examples / global_batch_size)) + + # Handle edge case where num_batches is 0 (e.g., test split with 0 examples) + if num_batches == 0: + return {'loss': 0.0, 'ppl': 1.0} + if split not in self._eval_iters: # These iterators will repeat indefinitely. self._eval_iters[split] = self._build_input_queue( @@ -159,7 +164,7 @@ def _eval_model_on_split( eval_metrics[metric_name] += metric_value eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) - eval_results['ppl'] = np.exp(eval_results['loss']).item() + eval_results['ppl'] = np.exp(eval_results['loss']).item() return eval_results @@ -173,9 +178,11 @@ def _eval_batch(self, params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) # Calculate cross-entropy loss metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) + # CRITICAL: Detach tensors to free computation graph and activations + # Without this, all intermediate activations are kept in memory! return { - 'loss': metrics['summed'], - 'denominator': metrics['n_valid_examples'], + 'loss': metrics['summed'].detach(), + 'denominator': metrics['n_valid_examples'].detach(), } diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 0b32199ba..9b544e380 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -372,6 +372,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From b11c1938447c3cb68a9635ffa75648ec97c3e5d2 Mon Sep 17 00:00:00 2001 From: rka97 Date: Sat, 18 Oct 2025 20:42:14 +0000 Subject: [PATCH 63/98] repeat dataset --- algoperf/workloads/lm/input_pipeline.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 7a55e81fd..ab7c64479 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -98,14 +98,12 @@ def get_lm_dataset( }, num_parallel_calls=AUTOTUNE, ) - - # batch + sequences_ds = sequences_ds.repeat() if split == 'train': - shuffled_sequences_ds = sequences_ds.shuffle( + ds = sequences_ds.shuffle( SHUFFLE_BUFFER_SIZE, seed=shuffle_seed ) - repeated_sequences_dataset = shuffled_sequences_ds.repeat() - ds = repeated_sequences_dataset.batch( + ds = ds.batch( global_batch_size, drop_remainder=False ) ds = ds.map(lambda x: { From 42d1d1a5379257015ca93847d539ef710e307067 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 20 Oct 2025 17:26:07 +0000 Subject: [PATCH 64/98] label 
smoothing default fix --- algoperf/workloads/lm/input_pipeline.py | 7 ++++--- algoperf/workloads/lm/lm_jax/workload.py | 4 +--- algoperf/workloads/lm/workload.py | 8 ++++---- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index 79fdfbbcb..1716399c0 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -119,7 +119,8 @@ def get_lm_dataset( ) ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], - 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0) + }) ds = ds.prefetch(tf.data.experimental.AUTOTUNE) elif split == 'validation': ds = batch_with_padding( @@ -132,7 +133,7 @@ def get_lm_dataset( ) ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], - 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0) + }) ds = ds.prefetch(tf.data.experimental.AUTOTUNE) - return ds diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 801b1e0b4..91a2592b4 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -4,8 +4,6 @@ import jax import jax.numpy as jnp -import optax -from flax.training import common_utils from algoperf import jax_sharding_utils, param_utils, spec from algoperf.workloads.lm.input_pipeline import get_data_iter @@ -88,7 +86,7 @@ def compute_weighted_cross_entropy( logits: spec.Tensor, targets: spec.Tensor, weights: Optional[spec.Tensor] = None, - label_smoothing: float = 0.1, + label_smoothing: float = 0.0, ) -> Dict[str, spec.Tensor]: # differentiable """Compute weighted cross entropy and entropy for log probs and targets. 
Args: diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 466769d96..21a7e8fbb 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -2,11 +2,11 @@ import abc import math -import numpy as np import os from typing import Any, Dict, Optional import jax +import numpy as np from absl import flags from algoperf import spec @@ -85,11 +85,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 3600 * 14 # 14 hours TODO(kasimbeg): update + return 3600 * 14 # 14 hours @property def eval_period_time_sec(self) -> int: - return 1200 # 20 minutes TODO(kasimbeg): update + return 1200 # 20 minutes @property def step_hint(self) -> int: @@ -172,7 +172,7 @@ def _eval_batch(self, logits, _ = self.model_fn( params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) # Calculate cross-entropy loss - metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) + metrics = self.loss_fn(batch['targets'], logits, batch['weights']) return { 'loss': metrics['summed'], 'denominator': metrics['n_valid_examples'], From d95f2bfb6290a47c0a81580f4c9e90e84c6bbd53 Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 21 Oct 2025 00:05:02 +0000 Subject: [PATCH 65/98] Make sure to take the correct number of batches in lm --- algoperf/workloads/lm/input_pipeline.py | 26 +++++++++++--------- algoperf/workloads/lm/lm_jax/workload.py | 13 +++++----- algoperf/workloads/lm/lm_pytorch/workload.py | 10 ++++---- algoperf/workloads/lm/workload.py | 12 ++++----- 4 files changed, 32 insertions(+), 29 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index ab7c64479..68cb54d1e 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -54,14 +54,14 @@ def batch_with_padding( def get_data_iter(data_rng: jax.random.PRNGKey, split: str, data_dir: str, - global_batch_size: int, + batch_size: int, num_batches: Optional[int] = None,): - ds = get_lm_dataset(data_rng, split, data_dir, global_batch_size, num_batches) + ds = get_lm_dataset(data_rng, split, data_dir, batch_size, num_batches) it = map( functools.partial( - data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + data_utils.shard_and_maybe_pad_np, global_batch_size=batch_size ), ds, ) @@ -72,7 +72,7 @@ def get_lm_dataset( data_rng: jax.random.PRNGKey, split: str, data_dir: str, - global_batch_size: int, + batch_size: int, num_batches: Optional[int] = None, ): """Load preprocessed TF dataset.""" @@ -104,8 +104,9 @@ def get_lm_dataset( SHUFFLE_BUFFER_SIZE, seed=shuffle_seed ) ds = ds.batch( - global_batch_size, drop_remainder=False + batch_size, drop_remainder=False ) + ds = ds.take(num_batches) if num_batches is not None else ds ds = ds.map(lambda x: { 'inputs': x['inputs'], 'targets': x['targets'], @@ -115,12 +116,13 @@ def get_lm_dataset( elif split == 'eval_train': ds = batch_with_padding( sequences_ds, - global_batch_size, + batch_size, padded_shapes={ - 'inputs': (global_batch_size, None), - 'targets': (global_batch_size, None), + 'inputs': (batch_size, None), + 'targets': (batch_size, None), }, ) + ds = ds.take(num_batches) if num_batches is not None else ds ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) @@ -128,15 +130,15 @@ def get_lm_dataset( elif split == 'validation': ds = batch_with_padding( sequences_ds, - global_batch_size, + 
batch_size, padded_shapes={ - 'inputs': (global_batch_size, None), - 'targets': (global_batch_size, None), + 'inputs': (batch_size, None), + 'targets': (batch_size, None), }, ) + ds = ds.take(num_batches) if num_batches is not None else ds ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)}) ds = ds.prefetch(tf.data.experimental.AUTOTUNE) - return ds diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 801b1e0b4..760b87306 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -23,16 +23,17 @@ def _build_input_queue(self, split: str, data_dir: str, global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None): """Build an input queue using pre-cached FineWeb dataset.""" - del num_batches - del repeat_final_dataset + del cache, repeat_final_dataset ds = get_data_iter( data_rng=data_rng, split=split, data_dir=data_dir, - global_batch_size=global_batch_size) + batch_size=global_batch_size, + num_batches=num_batches) ds = map(jax_sharding_utils.shard_along_batch_dim, ds) return ds @@ -73,7 +74,7 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] # Convert one-hot inputs to token IDs if needed diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index b2ffac18e..b5f93ce2e 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -98,17 +98,17 @@ def _build_input_queue( split: str, data_dir: str, global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: """Build an input queue for the given split.""" + del cache, repeat_final_dataset local_batch_size = global_batch_size // N_GPUS - # In DDP mode, pass local_device_count=1 to prevent shard_and_maybe_pad_np - # from seeing all GPUs via torch.cuda.device_count() loader = get_data_iter( data_rng=data_rng, split=split, data_dir=data_dir, - global_batch_size=local_batch_size, + batch_size=local_batch_size, num_batches=num_batches, ) if USE_PYTORCH_DDP: diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 73e784f3a..8f17fd930 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -4,7 +4,7 @@ import math import numpy as np import os -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Iterator import jax from absl import flags @@ -119,9 +119,10 @@ def _build_input_queue( split: str, data_dir: str, global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, num_batches: Optional[int] = None, - repeat_final_dataset: bool = False, - ): + ) -> Iterator[Dict[str, Any]]: """Build an input queue for the given split.""" @@ -150,8 +151,7 @@ def _eval_model_on_split( split, data_dir, 
global_batch_size, - num_batches, - repeat_final_dataset=True, + num_batches=num_batches ) eval_metrics = {} @@ -175,7 +175,7 @@ def _eval_batch(self, rng: spec.RandomState) -> spec.Tensor: """Evaluate the model on a single batch.""" logits, _ = self.model_fn( - params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False, 0.0) # Calculate cross-entropy loss metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) # CRITICAL: Detach tensors to free computation graph and activations From 0dc16db94c20973d2ae1f31231cfae91bef0801b Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 21 Oct 2025 00:22:23 +0000 Subject: [PATCH 66/98] Properly handle repetition in LM training and evaluation splits --- algoperf/workloads/lm/input_pipeline.py | 4 +++- algoperf/workloads/lm/workload.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index aeeff80a9..e701d1bcb 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -98,7 +98,6 @@ def get_lm_dataset( }, num_parallel_calls=AUTOTUNE, ) - sequences_ds = sequences_ds.repeat() if split == 'train': ds = sequences_ds.shuffle( SHUFFLE_BUFFER_SIZE, seed=shuffle_seed @@ -107,6 +106,7 @@ def get_lm_dataset( batch_size, drop_remainder=False ) ds = ds.take(num_batches) if num_batches is not None else ds + ds = ds.repeat() ds = ds.map(lambda x: { 'inputs': x['inputs'], 'targets': x['targets'], @@ -123,6 +123,7 @@ def get_lm_dataset( }, ) ds = ds.take(num_batches) if num_batches is not None else ds + ds = ds.repeat() ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0) @@ -138,6 +139,7 @@ def get_lm_dataset( }, ) ds = ds.take(num_batches) if num_batches is not None else ds + ds = ds.repeat() ds = ds.map(lambda x: {'inputs': x['inputs'], 'targets': x['targets'], 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 1f966ca03..8f17fd930 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -2,11 +2,11 @@ import abc import math +import numpy as np import os from typing import Any, Dict, Optional, Iterator import jax -import numpy as np from absl import flags from algoperf import spec @@ -85,11 +85,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 3600 * 14 # 14 hours + return 3600 * 14 # 14 hours TODO(kasimbeg): update @property def eval_period_time_sec(self) -> int: - return 1200 # 20 minutes + return 1200 # 20 minutes TODO(kasimbeg): update @property def step_hint(self) -> int: From 7edb702c2f4a4eb8a88bd35a40ea7a255e6f09d8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Oct 2025 01:41:23 +0000 Subject: [PATCH 67/98] move eval_batch from shared class to framework specific classes since pytorch calls detatch --- algoperf/workloads/lm/lm_jax/workload.py | 17 +++++++++++++++++ algoperf/workloads/lm/lm_pytorch/workload.py | 18 ++++++++++++++++++ algoperf/workloads/lm/workload.py | 18 ------------------ 3 files changed, 35 insertions(+), 18 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 91a2592b4..3809c8258 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ 
b/algoperf/workloads/lm/lm_jax/workload.py @@ -130,6 +130,23 @@ def compute_weighted_cross_entropy( 'per_example': per_example_losses, } + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + # Calculate cross-entropy loss + metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) + # CRITICAL: Detach tensors to free computation graph and activations + # Without this, all intermediate activations are kept in memory! + return { + 'loss': metrics['summed'], + 'denominator': metrics['n_valid_examples'], + } + def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any] ) -> Dict[str, float]: diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index b2ffac18e..4d87c5ba7 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -153,6 +153,24 @@ def compute_weighted_cross_entropy(self, logits: spec.Tensor, labels: spec.Tenso 'per_example': loss, } + + def _eval_batch(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) + # Calculate cross-entropy loss + metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) + # CRITICAL: Detach tensors to free computation graph and activations + # Without this, all intermediate activations are kept in memory! + return { + 'loss': metrics['summed'].detatch(), + 'denominator': metrics['n_valid_examples'].detatch(), + } + def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any] ) -> Dict[str, float]: diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index b4e7d7bb6..1c4c53fc8 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -168,24 +168,6 @@ def _eval_model_on_split( return eval_results - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: - """Evaluate the model on a single batch.""" - logits, _ = self.model_fn( - params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - # Calculate cross-entropy loss - metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) - # CRITICAL: Detach tensors to free computation graph and activations - # Without this, all intermediate activations are kept in memory! 
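For context on the detach calls discussed above, a minimal standalone PyTorch sketch (toy model assumed): without torch.no_grad() or .detach(), a tensor returned from a forward pass stays attached to the autograd graph, so accumulating such tensors across eval batches keeps each batch's intermediate activations alive.

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
x = torch.randn(8, 4)
loss = model(x).square().sum()
print(loss.requires_grad)             # True: still attached to the graph
print(loss.detach().requires_grad)    # False: safe to accumulate across batches
with torch.no_grad():                 # alternative: never build the graph at all
  print(model(x).square().sum().requires_grad)  # False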
- return { - 'loss': metrics['summed'].detach(), - 'denominator': metrics['n_valid_examples'].detach(), - } - - @abc.abstractmethod def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any] From 73e3ea6679d9b7fc5ccf6d75a2e4b1c9021e8d22 Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 21 Oct 2025 02:13:04 +0000 Subject: [PATCH 68/98] Refactor imports and clean up unused code in LM workload and related modules --- algoperf/checkpoint_utils.py | 7 +- algoperf/workloads/lm/input_pipeline.py | 1 - algoperf/workloads/lm/lm_jax/models.py | 3 +- algoperf/workloads/lm/lm_jax/nanodo_model.py | 2 +- algoperf/workloads/lm/lm_pytorch/models.py | 1 + .../workloads/lm/lm_pytorch/plainlm_model.py | 9 +- algoperf/workloads/lm/lm_pytorch/workload.py | 9 +- .../lm/tests/test_build_input_queue_jax.py | 60 --------- .../lm/tests/test_build_input_queue_torch.py | 86 ------------- .../lm/tests/test_hf_input_pipeline.py | 116 ------------------ .../workloads/lm/tests/test_linear_model.py | 39 ------ algoperf/workloads/lm/workload.py | 4 +- .../external_tuning/jax_nadamw_full_budget.py | 2 +- dataset/dataset_setup.py | 1 - 14 files changed, 18 insertions(+), 322 deletions(-) delete mode 100644 algoperf/workloads/lm/tests/test_build_input_queue_jax.py delete mode 100644 algoperf/workloads/lm/tests/test_build_input_queue_torch.py delete mode 100644 algoperf/workloads/lm/tests/test_hf_input_pipeline.py delete mode 100644 algoperf/workloads/lm/tests/test_linear_model.py diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index 00f05ba5d..6d61e9d7f 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -5,17 +5,18 @@ """ import os -from typing import Sequence, Tuple, Optional +from typing import Optional, Sequence, Tuple import numpy as np +import orbax.checkpoint as ocp import torch from absl import logging from flax import jax_utils from flax.training import checkpoints as flax_checkpoints from flax.training.checkpoints import latest_checkpoint -from tensorflow.io import gfile # pytype: disable=import-error -import orbax.checkpoint as ocp from orbax.checkpoint.type_handlers import NumpyHandler +from tensorflow.io import gfile # pytype: disable=import-error + from algoperf import spec from algoperf.pytorch_utils import pytorch_setup diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index e701d1bcb..cfa2f36cd 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -5,7 +5,6 @@ from typing import Optional import jax -import numpy as np import tensorflow as tf from algoperf import data_utils diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py index 72ee5bd83..ae8d935bf 100644 --- a/algoperf/workloads/lm/lm_jax/models.py +++ b/algoperf/workloads/lm/lm_jax/models.py @@ -1,5 +1,6 @@ -from flax import linen as nn import jax.numpy as jnp +from flax import linen as nn + class LinearModel(nn.Module): vocab_size: int diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py index bd7213620..9126d31e8 100644 --- a/algoperf/workloads/lm/lm_jax/nanodo_model.py +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -284,7 +284,7 @@ def main(): model = TransformerDo(cfg) # Print model info - print(f"\nModel Configuration:") + print("\nModel Configuration:") print(f" - Model dimension (D): {cfg.D}") print(f" - Number of heads (H): {cfg.H}") print(f" - Max sequence length (L): {cfg.L}") diff 
--git a/algoperf/workloads/lm/lm_pytorch/models.py b/algoperf/workloads/lm/lm_pytorch/models.py index 545763924..b88e457d8 100644 --- a/algoperf/workloads/lm/lm_pytorch/models.py +++ b/algoperf/workloads/lm/lm_pytorch/models.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn + class LinearLayer(nn.Module): def __init__(self, vocab_size: int): super().__init__() diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py index 5de5bf310..9dc8be522 100644 --- a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -1,10 +1,10 @@ import math -import torch -import torch.nn.functional as F -from torch import nn from dataclasses import dataclass from typing import Tuple +import torch +import torch.nn.functional as F +from torch import nn @dataclass @@ -257,8 +257,7 @@ def count_params(self, non_embedding=True): n_params = sum(p.numel() for p in self.parameters()) if non_embedding: n_params -= self.embed_tokens.weight.numel() - if (not self.lm_head.weight - is self.embed_tokens.weight): # if no weight tying + if (self.lm_head.weight is not self.embed_tokens.weight): # if no weight tying n_params -= self.lm_head.weight.numel() return n_params diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 6a59770bb..9713e84b0 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -10,12 +10,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP from algoperf import param_utils, pytorch_utils, spec +from algoperf.workloads.lm.input_pipeline import get_data_iter from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( ModelConfig, Transformer, ) from algoperf.workloads.lm.workload import BaseLmWorkload -from algoperf.workloads.lm.input_pipeline import get_data_iter USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -162,13 +162,10 @@ def _eval_batch(self, """Evaluate the model on a single batch.""" logits, _ = self.model_fn( params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - # Calculate cross-entropy loss metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) - # CRITICAL: Detach tensors to free computation graph and activations - # Without this, all intermediate activations are kept in memory! return { - 'loss': metrics['summed'].detatch(), - 'denominator': metrics['n_valid_examples'].detatch(), + 'loss': metrics['summed'].detach(), + 'denominator': metrics['n_valid_examples'].detach(), } def _normalize_eval_metrics( diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_jax.py b/algoperf/workloads/lm/tests/test_build_input_queue_jax.py deleted file mode 100644 index b9adc70d2..000000000 --- a/algoperf/workloads/lm/tests/test_build_input_queue_jax.py +++ /dev/null @@ -1,60 +0,0 @@ -import jax -import jax.numpy as jnp - -from algoperf.profiler import PassThroughProfiler -from algoperf.workloads.lm.lm_jax.workload import LmWorkload -import os - -RANK = os.environ.get('RANK', 0) - -def test_dataloader_jax(): - # Test config. 
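The count_params tweak above relies on weight tying making lm_head.weight and embed_tokens.weight the same Parameter object, so identity (is) rather than equality is the right check; a small standalone sketch:

import torch.nn as nn

emb = nn.Embedding(100, 16)
head = nn.Linear(16, 100, bias=False)
head.weight = emb.weight                 # tie output projection to embeddings
print(head.weight is emb.weight)         # True -> count the shared weight once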
- rng_seed = 1996 - data_dir = '/home/ak4605/data/finewebedu/' - split = 'train' - global_batch_size = 64 - dtype = jnp.int32 - seq_len = 2048 - - workload = LmWorkload() - data_rng = jax.random.PRNGKey(rng_seed) - input_queue = workload._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size) - - for _ in range(1): - - batch = next(input_queue) - print(f"RANK {RANK} got batch") - - assert type(batch) == dict - assert 'inputs' in batch - assert 'targets' in batch - - inputs, targets = batch['inputs'], batch['targets'] - print(f"RANK {RANK} inputs.shape: {inputs.shape}") - print(f"RANK {RANK} targets.shape: {targets.shape}") - print(f"RANK {RANK} type(inputs): {type(inputs)}") - - jax.debug.inspect_array_sharding(inputs, callback=print) - assert inputs.dtype == dtype - assert targets.dtype == dtype - - assert inputs.shape == (global_batch_size, seq_len) - assert targets.shape == (global_batch_size, seq_len) - - assert jnp.equal(inputs[:, 1:], targets[:, :-1]).all() - print(f"RANK {RANK} inputs[0, :10]: {inputs[0, :10]}") - - print(f"=== ALL TEST PASSED ===") - - -def main(): - profiler = PassThroughProfiler() - test_dataloader_jax() - - -if __name__ == '__main__': - main() diff --git a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py b/algoperf/workloads/lm/tests/test_build_input_queue_torch.py deleted file mode 100644 index 827272037..000000000 --- a/algoperf/workloads/lm/tests/test_build_input_queue_torch.py +++ /dev/null @@ -1,86 +0,0 @@ -import jax -import torch - -from algoperf.profiler import PassThroughProfiler -from algoperf.pytorch_utils import pytorch_init -from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.lm.lm_pytorch.workload import LmWorkload - -USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() - - -def sync_ddp(): - if torch.cuda.is_available(): - torch.cuda.synchronize() - - -def test_dataloader_torch(): - # Test config. - rng_seed = 1996 - data_dir = '/home/ak4605/data/finewebedu/' - split = 'train' - global_batch_size = 64 - dtype = torch.int32 - seq_len = 2048 - - local_batch_size = global_batch_size // N_GPUS - - workload = LmWorkload() - - data_rng = jax.random.PRNGKey(rng_seed) - - input_queue = workload._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size) - - print(f"RANK {RANK} of {N_GPUS}") - sync_ddp() - - # batch = next(input_queue) - # inputs, targets = batch['inputs'], batch['targets'] - # print(f"inputs.shape: {inputs.shape}") - # print(f"inputs: {inputs}") - - # Start test. 
- for _ in range(1): - - batch = next(input_queue) - print(f"RANK {RANK} got batch") - - assert type(batch) == dict - assert 'inputs' in batch - assert 'targets' in batch - - inputs, targets = batch['inputs'], batch['targets'] - print(f"RANK {RANK} inputs.shape: {inputs.shape}") - print(f"RANK {RANK} targets.shape: {targets.shape}") - print(f"RANK {RANK} type(inputs): {type(inputs)}") - assert type(inputs) == torch.Tensor - assert type(targets) == torch.Tensor - - assert inputs.device == DEVICE - assert targets.device == DEVICE - assert inputs.dtype == dtype - assert targets.dtype == dtype - - print(local_batch_size, seq_len) - assert inputs.shape == (local_batch_size, seq_len) - assert targets.shape == (local_batch_size, seq_len) - - assert torch.equal(inputs[:, 1:], targets[:, :-1]) - print(f"RANK {RANK} inputs[0, :10]: {inputs[0, :10]}") - - print(f"=== ALL TEST PASSED ===") - - -def main(): - profiler = PassThroughProfiler() - print(USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS) - pytorch_init(USE_PYTORCH_DDP, RANK, profiler) - test_dataloader_torch() - - -if __name__ == '__main__': - main() diff --git a/algoperf/workloads/lm/tests/test_hf_input_pipeline.py b/algoperf/workloads/lm/tests/test_hf_input_pipeline.py deleted file mode 100644 index 36bab0d02..000000000 --- a/algoperf/workloads/lm/tests/test_hf_input_pipeline.py +++ /dev/null @@ -1,116 +0,0 @@ -"""Tests for LM HuggingFace input pipeline.""" -import os - -import jax -import jax.numpy as jnp -import torch -from transformers import GPT2Tokenizer - -from algoperf.workloads.lm.input_pipeline import get_hf_dataloader - - -def main(): - # Setup test environment - cache_dir = "/home/ak4605/data" - if not os.path.exists(cache_dir): - raise FileNotFoundError(f"Cache directory {cache_dir} not found") - - data_rng = jax.random.PRNGKey(42) - tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") - vocab_size = tokenizer.vocab_size - - print("Running JAX output shapes and types test...") - batch_size = 8 - seq_len = 32 - loader = get_hf_dataloader( - cache_dir=cache_dir, - batch_size=batch_size, - seq_len=seq_len, - framework="jax", - split="train", - data_rng=data_rng) - inputs, targets = next(loader) - assert inputs.shape == (batch_size, seq_len, vocab_size), \ - f"Expected inputs shape {(batch_size, seq_len, vocab_size)}, got {inputs.shape}" - assert targets.shape == (batch_size, seq_len, vocab_size), \ - f"Expected targets shape {(batch_size, seq_len, vocab_size)}, got {targets.shape}" - assert inputs.dtype == jnp.float32, \ - f"Expected inputs dtype float32, got {inputs.dtype}" - assert targets.dtype == jnp.float32, \ - f"Expected targets dtype float32, got {targets.dtype}" - assert jnp.all(jnp.sum(inputs, axis=-1) == 1), "Inputs should be one-hot encoded" - assert jnp.all(jnp.sum(targets, axis=-1) == 1), "Targets should be one-hot encoded" - print("✓ JAX test passed") - - print("\nRunning Torch output shapes and types test...") - loader = get_hf_dataloader( - cache_dir=cache_dir, - batch_size=batch_size, - seq_len=seq_len, - framework="torch", - split="train", - data_rng=data_rng) - inputs, targets = next(loader) - assert inputs.shape == (batch_size, seq_len, vocab_size), \ - f"Expected inputs shape {(batch_size, seq_len, vocab_size)}, got {inputs.shape}" - assert targets.shape == (batch_size, seq_len, vocab_size), \ - f"Expected targets shape {(batch_size, seq_len, vocab_size)}, got {targets.shape}" - assert inputs.dtype == torch.float32, \ - f"Expected inputs dtype float32, got {inputs.dtype}" - assert targets.dtype == 
torch.float32, \ - f"Expected targets dtype float32, got {targets.dtype}" - assert torch.all(torch.sum(inputs, dim=-1) == 1), "Inputs should be one-hot encoded" - assert torch.all(torch.sum(targets, dim=-1) == 1), "Targets should be one-hot encoded" - print("✓ Torch test passed") - - print("\nTesting consistent batching with same seed...") - loader1 = get_hf_dataloader( - cache_dir=cache_dir, - batch_size=batch_size, - seq_len=seq_len, - framework="jax", - split="train", - data_rng=jax.random.PRNGKey(42)) - batch1 = next(loader1) - - loader2 = get_hf_dataloader( - cache_dir=cache_dir, - batch_size=batch_size, - seq_len=seq_len, - framework="jax", - split="train", - data_rng=jax.random.PRNGKey(42)) - batch2 = next(loader2) - - assert jnp.array_equal(batch1[0], batch2[0]), "Input batches should be identical with same seed" - assert jnp.array_equal(batch1[1], batch2[1]), "Target batches should be identical with same seed" - print("✓ Consistent batching test passed") - - print("\nTesting eval split doesn't shuffle...") - loader1 = get_hf_dataloader( - cache_dir=cache_dir, - batch_size=batch_size, - seq_len=seq_len, - framework="jax", - split="eval", - data_rng=jax.random.PRNGKey(42)) - batch1 = next(loader1) - - loader2 = get_hf_dataloader( - cache_dir=cache_dir, - batch_size=batch_size, - seq_len=seq_len, - framework="jax", - split="eval", - data_rng=jax.random.PRNGKey(999)) - batch2 = next(loader2) - - assert jnp.array_equal(batch1[0], batch2[0]), "Eval inputs should be identical regardless of seed" - assert jnp.array_equal(batch1[1], batch2[1]), "Eval targets should be identical regardless of seed" - print("✓ Eval no shuffling test passed") - - print("\nAll tests passed successfully!") - - -if __name__ == "__main__": - main() diff --git a/algoperf/workloads/lm/tests/test_linear_model.py b/algoperf/workloads/lm/tests/test_linear_model.py deleted file mode 100644 index 31cd1d577..000000000 --- a/algoperf/workloads/lm/tests/test_linear_model.py +++ /dev/null @@ -1,39 +0,0 @@ -import jax -import jax.numpy as jnp -import torch - -TEST_SEQ_LEN = 512 - -def test_pytorch_linear(): - from algoperf.workloads.lm.lm_pytorch.models import LinearLayer - vocab_size = 32000 - model = LinearLayer(vocab_size) - - batch_size = 8 - seq_len = TEST_SEQ_LEN - inputs = torch.randn(batch_size, seq_len, vocab_size) - outputs = model(inputs) - - assert outputs.shape == (batch_size, seq_len, vocab_size) - assert not torch.isnan(outputs).any() - -def test_jax_linear(): - from algoperf.workloads.lm.lm_jax.models import LinearModel - - vocab_size = 32000 - seq_len = TEST_SEQ_LEN - batch_size = 8 - model = LinearModel(vocab_size) - rng = jax.random.PRNGKey(0) - params = model.init(rng, jnp.ones((1, seq_len, vocab_size))) - - inputs = jax.random.normal(rng, (batch_size, seq_len, vocab_size)) - outputs = model.apply(params, inputs) - - assert outputs.shape == (batch_size, seq_len, vocab_size) - assert not jnp.isnan(outputs).any() - -if __name__ == '__main__': - test_pytorch_linear() - test_jax_linear() - print("All tests passed!") diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index e0af589e3..f15e4b8a7 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -2,11 +2,11 @@ import abc import math -import numpy as np import os -from typing import Any, Dict, Optional, Iterator +from typing import Any, Dict, Iterator, Optional import jax +import numpy as np from absl import flags from algoperf import spec diff --git 
a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index 9b4192de2..ccfa25360 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -11,7 +11,7 @@ Tuple, Union, ) -from absl import logging + # isort: on import chex import jax diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 872e2ef0b..8fecaf419 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -81,7 +81,6 @@ from transformers import AutoTokenizer import functools -import itertools import os import shutil import subprocess From 91988af436f452021e98a61e5144a89d14418e20 Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 21 Oct 2025 02:20:31 +0000 Subject: [PATCH 69/98] pass linter checks --- algoperf/checkpoint_utils.py | 73 +-- algoperf/pytorch_utils.py | 6 +- algoperf/workloads/lm/input_pipeline.py | 54 +- algoperf/workloads/lm/lm_jax/models.py | 20 - algoperf/workloads/lm/lm_jax/nanodo_model.py | 575 +++++++++--------- algoperf/workloads/lm/lm_jax/workload.py | 134 ++-- algoperf/workloads/lm/lm_pytorch/models.py | 19 - .../workloads/lm/lm_pytorch/plainlm_model.py | 556 +++++++++-------- algoperf/workloads/lm/lm_pytorch/workload.py | 182 +++--- algoperf/workloads/lm/workload.py | 33 +- algoperf/workloads/workloads.py | 286 ++++----- dataset/dataset_setup.py | 171 +++--- submission_runner.py | 28 +- 13 files changed, 1091 insertions(+), 1046 deletions(-) delete mode 100644 algoperf/workloads/lm/lm_jax/models.py delete mode 100644 algoperf/workloads/lm/lm_pytorch/models.py diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index 6d61e9d7f..af05111cd 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -31,49 +31,52 @@ int, ] + class BoolHandler(NumpyHandler): + """ + An implementation of TypeHandler for np.bool_ that inherits from NumpyHandler. + It works by treating the scalar as a 0-dimensional array. + """ + + def typestr(self) -> str: + """Unique string identifier for this handler.""" + return 'np.bool_' + + async def serialize( + self, + values: Sequence[np.bool_], + infos: Sequence, + args: Optional[Sequence[ocp.SaveArgs]] = None, + ): """ - An implementation of TypeHandler for np.bool_ that inherits from NumpyHandler. - It works by treating the scalar as a 0-dimensional array. + Serializes a sequence of np.bool_ scalars by first converting them + to 0-dim numpy arrays and then calling the parent NumpyHandler. """ + # Convert each scalar np.bool_ to a 0-dimensional np.ndarray + array_values = [np.asarray(v, dtype=np.bool_) for v in values] + # Use the parent class's robust serialization logic + return await super().serialize(array_values, infos, args) + + async def deserialize( + self, + infos: Sequence, + args: Optional[Sequence[ocp.RestoreArgs]] = None, + ) -> Sequence[np.bool_]: + """ + Deserializes into a sequence of np.bool_ scalars by calling the + parent handler and then converting the resulting 0-dim arrays. 
+ """ + # Parent deserialize will return a sequence of 0-dimensional np.ndarray + results = await super().deserialize(infos, args) - def typestr(self) -> str: - """Unique string identifier for this handler.""" - return 'np.bool_' + # Convert each 0-d array back to an np.bool_ scalar using .item() + scalar_results = [np.bool_(r.item()) for r in results] + return scalar_results - async def serialize( - self, - values: Sequence[np.bool_], - infos: Sequence, - args: Optional[Sequence[ocp.SaveArgs]] = None, - ): - """ - Serializes a sequence of np.bool_ scalars by first converting them - to 0-dim numpy arrays and then calling the parent NumpyHandler. - """ - # Convert each scalar np.bool_ to a 0-dimensional np.ndarray - array_values = [np.asarray(v, dtype=np.bool_) for v in values] - # Use the parent class's robust serialization logic - return await super().serialize(array_values, infos, args) - - async def deserialize( - self, - infos: Sequence, - args: Optional[Sequence[ocp.RestoreArgs]] = None, - ) -> Sequence[np.bool_]: - """ - Deserializes into a sequence of np.bool_ scalars by calling the - parent handler and then converting the resulting 0-dim arrays. - """ - # Parent deserialize will return a sequence of 0-dimensional np.ndarray - results = await super().deserialize(infos, args) - - # Convert each 0-d array back to an np.bool_ scalar using .item() - scalar_results = [np.bool_(r.item()) for r in results] - return scalar_results ocp.type_handlers.register_type_handler(np.bool_, BoolHandler(), override=True) + def maybe_restore_checkpoint( framework: str, optimizer_state: spec.OptimizerState, diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index c7537a884..e24b0f141 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -27,7 +27,9 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: return use_pytorch_ddp, rank, device, n_gpus -def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_threads = True) -> None: +def pytorch_init( + use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_threads=True +) -> None: # Make sure no GPU memory is preallocated to Jax. os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # Only use CPU for Jax to avoid memory issues. @@ -47,8 +49,10 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_ profiler.set_local_rank(rank) # Only log once (for local rank == 0). if rank != 0: + def logging_pass(*args): pass + logging.info = logging_pass # Initialize the process group. 
dist.init_process_group('nccl') diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/lm/input_pipeline.py index cfa2f36cd..3007371fc 100644 --- a/algoperf/workloads/lm/input_pipeline.py +++ b/algoperf/workloads/lm/input_pipeline.py @@ -50,14 +50,15 @@ def batch_with_padding( return padded_batched_dataset -def get_data_iter(data_rng: jax.random.PRNGKey, +def get_data_iter( + data_rng: jax.random.PRNGKey, split: str, data_dir: str, batch_size: int, - num_batches: Optional[int] = None,): - + num_batches: Optional[int] = None, +): ds = get_lm_dataset(data_rng, split, data_dir, batch_size, num_batches) - + it = map( functools.partial( data_utils.shard_and_maybe_pad_np, global_batch_size=batch_size @@ -67,6 +68,7 @@ def get_data_iter(data_rng: jax.random.PRNGKey, return iter(it) + def get_lm_dataset( data_rng: jax.random.PRNGKey, split: str, @@ -78,7 +80,7 @@ def get_lm_dataset( if split not in TFDS_SPLIT_NAME: raise NotImplementedError - shuffle_seed = jax.random.randint(data_rng, (), -2**31, 2**31-1) + shuffle_seed = jax.random.randint(data_rng, (), -(2**31), 2**31 - 1) data_dir = os.path.join(data_dir, TFDS_SPLIT_NAME[split]) tokens_ds = tf.data.Dataset.load(data_dir) @@ -98,19 +100,17 @@ def get_lm_dataset( num_parallel_calls=AUTOTUNE, ) if split == 'train': - ds = sequences_ds.shuffle( - SHUFFLE_BUFFER_SIZE, seed=shuffle_seed - ) - ds = ds.batch( - batch_size, drop_remainder=False - ) + ds = sequences_ds.shuffle(SHUFFLE_BUFFER_SIZE, seed=shuffle_seed) + ds = ds.batch(batch_size, drop_remainder=False) ds = ds.take(num_batches) if num_batches is not None else ds ds = ds.repeat() - ds = ds.map(lambda x: { - 'inputs': x['inputs'], - 'targets': x['targets'], - 'weights': None, - }) + ds = ds.map( + lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': None, + } + ) ds = ds.prefetch(tf.data.experimental.AUTOTUNE) elif split == 'eval_train': ds = batch_with_padding( @@ -123,10 +123,13 @@ def get_lm_dataset( ) ds = ds.take(num_batches) if num_batches is not None else ds ds = ds.repeat() - ds = ds.map(lambda x: {'inputs': x['inputs'], - 'targets': x['targets'], - 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0) - }) + ds = ds.map( + lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0), + } + ) ds = ds.prefetch(tf.data.experimental.AUTOTUNE) elif split == 'validation': ds = batch_with_padding( @@ -139,9 +142,12 @@ def get_lm_dataset( ) ds = ds.take(num_batches) if num_batches is not None else ds ds = ds.repeat() - ds = ds.map(lambda x: {'inputs': x['inputs'], - 'targets': x['targets'], - 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0) - }) + ds = ds.map( + lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0), + } + ) ds = ds.prefetch(tf.data.experimental.AUTOTUNE) return ds diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/lm/lm_jax/models.py deleted file mode 100644 index ae8d935bf..000000000 --- a/algoperf/workloads/lm/lm_jax/models.py +++ /dev/null @@ -1,20 +0,0 @@ -import jax.numpy as jnp -from flax import linen as nn - - -class LinearModel(nn.Module): - vocab_size: int - - @nn.compact - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: - x = nn.Dense( - 10, - kernel_init=nn.initializers.normal(0.02), - bias_init=nn.initializers.zeros - )(inputs) - return nn.Dense( - self.vocab_size, - kernel_init=nn.initializers.normal(0.02), - 
bias_init=nn.initializers.zeros, - name="output" - )(x) diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py index 9126d31e8..a1644f569 100644 --- a/algoperf/workloads/lm/lm_jax/nanodo_model.py +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -1,4 +1,7 @@ -# Self-contained version of the DecoderOnly Transformer from NanoDO +""" +Originally based on code from the NanoDO repository under the Apache 2.0 license: +https://github.com/google-deepmind/nanodo +""" import dataclasses from functools import partial @@ -7,343 +10,345 @@ import jax.numpy as jnp from flax import linen as nn -# =========== Transformer Decoder-only Model ========== - - @dataclasses.dataclass class DoConfig: - """Hyper-parameters for Transformer decoder-only.""" - - D: int # model/embed dim = qkv dim - H: int # num attention heads - L: int # max context/sequence length - N: int # number of transformer block layers - V: int # vocab size - F: int # FF inner dimension - kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform() - embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling( - 1.0, "fan_in", "normal", out_axis=0 - ) - dtype: jnp.dtype = jnp.float32 - rmsnorm_epsilon: float = 1e-6 - multiple_of: int = 256 - tie_embeddings: bool = True # Whether to tie input and output embeddings + """Hyper-parameters for Transformer decoder-only.""" + + D: int # model/embed dim = qkv dim + H: int # num attention heads + L: int # max context/sequence length + N: int # number of transformer block layers + V: int # vocab size + F: int # FF inner dimension + kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform() + embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'normal', out_axis=0 + ) + dtype: jnp.dtype = jnp.float32 + rmsnorm_epsilon: float = 1e-6 + multiple_of: int = 256 + tie_embeddings: bool = True # Whether to tie input and output embeddings class Mlp(nn.Module): - """Multilayer perceptron with GLU activation.""" - - cfg: DoConfig - - @nn.compact - def __call__(self, x_BxLxD: jax.Array): - cfg = self.cfg - # Use Xavier uniform initialization explicitly - xavier_init = nn.initializers.xavier_uniform() - linear = partial( - nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype - ) - # Adjust hidden dimension to keep the number of parameters invariant to - # the activation function used since the GLU MLP has 3 * hidden_dim * D - # parameters instead of 2 * hidden_dim * D parameters - hidden_dim = cfg.F * 2 / 3 - hidden_dim = cfg.multiple_of * ( - (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of - ) - # Double the hidden dimension for GLU - x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) - # Apply GLU activation - x_BxLxF = nn.glu(x_BxLx2F, axis=-1) - x_BxLxD = linear(cfg.D)(x_BxLxF) - return x_BxLxD - -@partial(jax.jit, static_argnums=(0,1,2)) -def init_rope(dim=256, seq_len=128, n_heads=4): - """Initialize rotary embeddings.""" - def precompute_freqs_cis_jax(dim, end, theta=10000.0): - inv_freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim)) - t = jnp.arange(end) / 1.0 - freqs = jnp.outer(t, inv_freqs).astype(jnp.float32) - return jnp.stack([ - jnp.cos(freqs)[None, :, None, :], - jnp.sin(freqs)[None, :, None, :] - ], axis=3) - - freqs_cis = precompute_freqs_cis_jax(dim // n_heads, seq_len, theta=500000) - return freqs_cis.transpose(0, 1, 2, 4, 3) + """Multilayer perceptron with GLU activation.""" -@jax.jit -def apply_rope(q, k, freqs_cis): - """Apply rotary 
embeddings to Q and K.""" - def rotate_tensor(x): - # Split into real and imaginary parts - x_r2 = x.reshape(*x.shape[:-1], -1, 2) - L = x.shape[1] - freqs = freqs_cis[:, :L, :, :, :] + cfg: DoConfig - # Apply rotation - rotated_x_r2 = jnp.stack([ - x_r2[..., 0] * freqs[..., 0] - x_r2[..., 1] * freqs[..., 1], - x_r2[..., 1] * freqs[..., 0] + x_r2[..., 0] * freqs[..., 1] - ], axis=-1) + @nn.compact + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + # Use Xavier uniform initialization explicitly + xavier_init = nn.initializers.xavier_uniform() + linear = partial( + nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype + ) + # Adjust hidden dimension to keep the number of parameters invariant to + # the activation function used since the GLU MLP has 3 * hidden_dim * D + # parameters instead of 2 * hidden_dim * D parameters + hidden_dim = cfg.F * 2 / 3 + hidden_dim = cfg.multiple_of * ( + (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of + ) + # Double the hidden dimension for GLU + x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) + # Apply GLU activation + x_BxLxF = nn.glu(x_BxLx2F, axis=-1) + x_BxLxD = linear(cfg.D)(x_BxLxF) + return x_BxLxD - return rotated_x_r2.reshape(*x.shape) - # Apply rotation to Q and K separately - rotated_q = rotate_tensor(q) - rotated_k = rotate_tensor(k) +@partial(jax.jit, static_argnums=(0, 1, 2)) +def init_rope(dim=256, seq_len=128, n_heads=4): + """Initialize rotary embeddings.""" + + def precompute_freqs_cis_jax(dim, end, theta=10000.0): + inv_freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim)) + t = jnp.arange(end) / 1.0 + freqs = jnp.outer(t, inv_freqs).astype(jnp.float32) + return jnp.stack( + [jnp.cos(freqs)[None, :, None, :], jnp.sin(freqs)[None, :, None, :]], + axis=3, + ) - return rotated_q, rotated_k + freqs_cis = precompute_freqs_cis_jax(dim // n_heads, seq_len, theta=500000) + return freqs_cis.transpose(0, 1, 2, 4, 3) -class CausalAttn(nn.Module): - """Causal attention layer with rotary embeddings.""" +@jax.jit +def apply_rope(q, k, freqs_cis): + """Apply rotary embeddings to Q and K.""" + + def rotate_tensor(x): + # Split into real and imaginary parts + x_r2 = x.reshape(*x.shape[:-1], -1, 2) + L = x.shape[1] + freqs = freqs_cis[:, :L, :, :, :] + + # Apply rotation + rotated_x_r2 = jnp.stack( + [ + x_r2[..., 0] * freqs[..., 0] - x_r2[..., 1] * freqs[..., 1], + x_r2[..., 1] * freqs[..., 0] + x_r2[..., 0] * freqs[..., 1], + ], + axis=-1, + ) - cfg: DoConfig + return rotated_x_r2.reshape(*x.shape) - def setup(self): - cfg = self.cfg - assert cfg.D % cfg.H == 0, f"D {cfg.D} not divisible by H {cfg.H}" - self.Dh = cfg.D // cfg.H + # Apply rotation to Q and K separately + rotated_q = rotate_tensor(q) + rotated_k = rotate_tensor(k) - # Initialize rotary embeddings - self.freqs_cis = init_rope(cfg.D, cfg.L, cfg.H) + return rotated_q, rotated_k - # Maps D -> (H, Dh) - self.multilinear = partial( - nn.DenseGeneral, - axis=-1, - features=(cfg.H, self.Dh), - kernel_init=cfg.kernel_init, - use_bias=False, - dtype=cfg.dtype, - ) - self.multilinear_query = self.multilinear(name="query") - self.multilinear_key = self.multilinear(name="key") - self.multilinear_value = self.multilinear(name="value") - self.output_projection = nn.DenseGeneral( - features=cfg.D, - name="attn_out_proj", - # axis=(-2, -1), # - kernel_init=cfg.kernel_init, - use_bias=False, - dtype=cfg.dtype, - ) +class CausalAttn(nn.Module): + """Causal attention layer with rotary embeddings.""" + + cfg: DoConfig + + def setup(self): + cfg = self.cfg + assert cfg.D % cfg.H == 0, f'D 
{cfg.D} not divisible by H {cfg.H}' + self.Dh = cfg.D // cfg.H + + # Initialize rotary embeddings + self.freqs_cis = init_rope(cfg.D, cfg.L, cfg.H) + + # Maps D -> (H, Dh) + self.multilinear = partial( + nn.DenseGeneral, + axis=-1, + features=(cfg.H, self.Dh), + kernel_init=cfg.kernel_init, + use_bias=False, + dtype=cfg.dtype, + ) + + self.multilinear_query = self.multilinear(name='query') + self.multilinear_key = self.multilinear(name='key') + self.multilinear_value = self.multilinear(name='value') + self.output_projection = nn.DenseGeneral( + features=cfg.D, + name='attn_out_proj', + # axis=(-2, -1), # + kernel_init=cfg.kernel_init, + use_bias=False, + dtype=cfg.dtype, + ) - def __call__(self, x_BxLxD: jax.Array): - cfg = self.cfg + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg - # Project inputs to Q, K, V - q_BxLxHxDh = self.multilinear_query(x_BxLxD) - k_BxLxHxDh = self.multilinear_key(x_BxLxD) - v_BxLxHxDh = self.multilinear_value(x_BxLxD) + # Project inputs to Q, K, V + q_BxLxHxDh = self.multilinear_query(x_BxLxD) + k_BxLxHxDh = self.multilinear_key(x_BxLxD) + v_BxLxHxDh = self.multilinear_value(x_BxLxD) - # Apply rotary embeddings to Q and K - q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) + # Apply rotary embeddings to Q and K + q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) - # Scale queries - q_BxLxHxDh /= self.Dh**0.5 + # Scale queries + q_BxLxHxDh /= self.Dh**0.5 - # Compute attention scores - att_BxHxLxL = jnp.einsum("...qhd,...khd->...hqk", q_BxLxHxDh, k_BxLxHxDh) + # Compute attention scores + att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh) - # Causal attention mask - L = x_BxLxD.shape[1] - mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) + # Causal attention mask + L = x_BxLxD.shape[1] + mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) - # Apply mask and softmax - _NEG_INF = jnp.finfo(cfg.dtype).min - att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) - att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) - att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) + # Apply mask and softmax + _NEG_INF = jnp.finfo(cfg.dtype).min + att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) + att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) + att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) - # Compute attention output - out_BxLxHxDh = jnp.einsum("...hqk,...khd->...qhd", att_BxHxLxL, v_BxLxHxDh) + # Compute attention output + out_BxLxHxDh = jnp.einsum('...hqk,...khd->...qhd', att_BxHxLxL, v_BxLxHxDh) - # Reshape and project output - out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) + # Reshape and project output + out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) - # Output projection - out_BxLxD = self.output_projection(out_BxLxD) + # Output projection + out_BxLxD = self.output_projection(out_BxLxD) - return out_BxLxD + return out_BxLxD class TBlock(nn.Module): - """Transformer Block.""" + """Transformer Block.""" - docfg: DoConfig + docfg: DoConfig - @nn.compact - def __call__(self, in_BxLxD: jax.Array): - cfg = self.docfg + @nn.compact + def __call__(self, in_BxLxD: jax.Array): + cfg = self.docfg - # x = x + attn( attn_norm(x) ) - x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( - in_BxLxD - ) - x_BxLxD = CausalAttn(cfg)(x_BxLxD) - x_BxLxD += in_BxLxD + # x = x + attn( attn_norm(x) ) + x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + in_BxLxD + ) + x_BxLxD = CausalAttn(cfg)(x_BxLxD) + x_BxLxD += in_BxLxD - 
# x = x + mlp( mlp_norm(x) ) - z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( - x_BxLxD - ) - z_BxLxD = Mlp(cfg)(z_BxLxD) + # x = x + mlp( mlp_norm(x) ) + z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + x_BxLxD + ) + z_BxLxD = Mlp(cfg)(z_BxLxD) - return x_BxLxD + z_BxLxD + return x_BxLxD + z_BxLxD class TransformerDo(nn.Module): - """Transformer decoder-only.""" - - docfg: DoConfig - - def setup(self): - cfg = self.docfg - self.embed = nn.Embed( - num_embeddings=cfg.V, - features=cfg.D, - embedding_init=cfg.embed_init, - ) - - self.blocks = [TBlock(cfg) for _ in range(cfg.N)] - self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) - - # Output projection - tied to input embeddings if configured - if cfg.tie_embeddings: - self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) - else: - self.output_proj = nn.Dense( - cfg.V, - kernel_init=cfg.embed_init, - dtype=cfg.dtype, - name="output_proj" - ) - - def __call__(self, y_BxL: jax.Array): - # For training on concatenated examples. - y_BxLxD = self.embed(y_BxL) - for block in self.blocks: - y_BxLxD = block(y_BxLxD) - y_BxLxD = self.out_ln(y_BxLxD) - logits_BxLxV = self.output_proj(y_BxLxD) - return logits_BxLxV - - def predict(self, y_BxL: jax.Array, k: int = 1): - """Generate k tokens autoregressively. - - Args: - y_BxL: Input token sequence of shape (batch_size, seq_len) - k: Number of tokens to predict - - Returns: - Tuple of (input_ids, predicted_ids) - """ - cfg = self.docfg - batch_size = y_BxL.shape[0] - seq_len = y_BxL.shape[1] - - # Store original input - original_input = y_BxL - - # Make sure we don't exceed the model's context length - if seq_len + k > cfg.L: - raise ValueError( - f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.L})" - ) - - # Generate k tokens autoregressively - for _ in range(k): - # Get logits for the entire sequence - logits = self(y_BxL) - - # Get the logits for the last token in each sequence - next_token_logits = logits[:, -1, :] - - # Get the most likely token - next_token = jnp.argmax(next_token_logits, axis=-1) - - # Append the predicted token to the sequence - y_BxL = jnp.concatenate([y_BxL, next_token[:, None]], axis=1) - - # Return original input and the k predicted tokens - return original_input, y_BxL[:, -k:] + """Transformer decoder-only.""" + docfg: DoConfig -# =========== Demo Code ========== + def setup(self): + cfg = self.docfg + self.embed = nn.Embed( + num_embeddings=cfg.V, + features=cfg.D, + embedding_init=cfg.embed_init, + ) + self.blocks = [TBlock(cfg) for _ in range(cfg.N)] + self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) -def main(): - """Create and run the DecoderOnly Transformer model.""" - # Initialize model configuration with smaller parameters for demo - B, L = (2, 128) # Batch size, sequence length - cfg = DoConfig(D=128, H=4, L=L, N=2, V=256, F=4 * 128) - model = TransformerDo(cfg) - - # Print model info - print("\nModel Configuration:") - print(f" - Model dimension (D): {cfg.D}") - print(f" - Number of heads (H): {cfg.H}") - print(f" - Max sequence length (L): {cfg.L}") - print(f" - Number of layers (N): {cfg.N}") - print(f" - Vocabulary size (V): {cfg.V}") - print(f" - Feed forward dimension (F): {cfg.F}") - - # Create random input tokens (simulated token IDs) - rng_key = jax.random.PRNGKey(42) - input_rng, init_rng = jax.random.split(rng_key) - - # Generate random token IDs (integers between 0 and vocab_size-1) - x_BxL = 
jax.random.randint( - input_rng, shape=(B, L), minval=0, maxval=cfg.V, dtype=jnp.int32 - ) + # Output projection - tied to input embeddings if configured + if cfg.tie_embeddings: + self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) + else: + self.output_proj = nn.Dense( + cfg.V, kernel_init=cfg.embed_init, dtype=cfg.dtype, name='output_proj' + ) - # Initialize model parameters - print("\nInitializing model parameters...") - params = model.init(init_rng, x_BxL) + def __call__(self, y_BxL: jax.Array): + # For training on concatenated examples. + y_BxLxD = self.embed(y_BxL) + for block in self.blocks: + y_BxLxD = block(y_BxLxD) + y_BxLxD = self.out_ln(y_BxLxD) + logits_BxLxV = self.output_proj(y_BxLxD) + return logits_BxLxV - # Print parameter count - param_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) - print(f"Total parameters: {param_count:,}") + def predict(self, y_BxL: jax.Array, k: int = 1): + """Generate k tokens autoregressively. - # Make a prediction (forward pass) - print("\nRunning forward pass...") - logits = model.apply(params, x_BxL) + Args: + y_BxL: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict - # Print output shape and sample values - print(f"\nOutput shape: {logits.shape} (batch_size, sequence_length, vocab_size)") - print(f"Output data type: {logits.dtype}") + Returns: + Tuple of (input_ids, predicted_ids) + """ + cfg = self.docfg + seq_len = y_BxL.shape[1] - # Print sample logits (first 5 positions of the first sequence) - print("\nSample logits (first sequence, first 5 positions, first 5 values):") - for position in range(min(5, L)): - print(f" Position {position}: {logits[0, position, :5]}") + # Store original input + original_input = y_BxL - # Get predictions (token with highest logit at each position) - predictions = jnp.argmax(logits, axis=-1) - print("\nPredicted token IDs (first sequence, first 10 positions):") - print(predictions[0, :10]) + # Make sure we don't exceed the model's context length + if seq_len + k > cfg.L: + raise ValueError( + f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.L})" + ) - # Test the predict function - print("\nTesting predict function...") - # Use a shorter - short_seq = x_BxL[:, :10] - print(f"Input sequence shape: {short_seq.shape}") + # Generate k tokens autoregressively + for _ in range(k): + # Get logits for the entire sequence + logits = self(y_BxL) - # Predict 5 tokens - k = 5 - original, predicted = model.apply(params, short_seq, k, method=model.predict) + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] - # Get predictions (token with highest logit at each position) - predictions = jnp.argmax(logits, axis=-1) - print("\nPredicted token IDs (first sequence, first 10 positions):") - print(predictions[0, :10]) + # Get the most likely token + next_token = jnp.argmax(next_token_logits, axis=-1) - print("\nDone!") + # Append the predicted token to the sequence + y_BxL = jnp.concatenate([y_BxL, next_token[:, None]], axis=1) + # Return original input and the k predicted tokens + return original_input, y_BxL[:, -k:] -if __name__ == "__main__": - main() + +# =========== Demo Code ========== + + +def main(): + """Create and run the DecoderOnly Transformer model.""" + # Initialize model configuration with smaller parameters for demo + B, L = (2, 128) # Batch size, sequence length + cfg = DoConfig(D=128, H=4, L=L, N=2, V=256, F=4 * 128) + model = TransformerDo(cfg) + + # Print model info + 
print('\nModel Configuration:') + print(f' - Model dimension (D): {cfg.D}') + print(f' - Number of heads (H): {cfg.H}') + print(f' - Max sequence length (L): {cfg.L}') + print(f' - Number of layers (N): {cfg.N}') + print(f' - Vocabulary size (V): {cfg.V}') + print(f' - Feed forward dimension (F): {cfg.F}') + + # Create random input tokens (simulated token IDs) + rng_key = jax.random.PRNGKey(42) + input_rng, init_rng = jax.random.split(rng_key) + + # Generate random token IDs (integers between 0 and vocab_size-1) + x_BxL = jax.random.randint( + input_rng, shape=(B, L), minval=0, maxval=cfg.V, dtype=jnp.int32 + ) + + # Initialize model parameters + print('\nInitializing model parameters...') + params = model.init(init_rng, x_BxL) + + # Print parameter count + param_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) + print(f'Total parameters: {param_count:,}') + + # Make a prediction (forward pass) + print('\nRunning forward pass...') + logits = model.apply(params, x_BxL) + + # Print output shape and sample values + print( + f'\nOutput shape: {logits.shape} (batch_size, sequence_length, vocab_size)' + ) + print(f'Output data type: {logits.dtype}') + + # Print sample logits (first 5 positions of the first sequence) + print('\nSample logits (first sequence, first 5 positions, first 5 values):') + for position in range(min(5, L)): + print(f' Position {position}: {logits[0, position, :5]}') + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print('\nPredicted token IDs (first sequence, first 10 positions):') + print(predictions[0, :10]) + + # Test the predict function + print('\nTesting predict function...') + # Use a shorter + short_seq = x_BxL[:, :10] + print(f'Input sequence shape: {short_seq.shape}') + + # Predict 5 tokens + k = 5 + original, predicted = model.apply(params, short_seq, k, method=model.predict) + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print('\nPredicted token IDs (first sequence, first 10 positions):') + print(predictions[0, :10]) + + print('\nDone!') + + +if __name__ == '__main__': + main() diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index ad7eac8aa..5b736fad7 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -16,47 +16,52 @@ class LmWorkload(BaseLmWorkload): """LM JAX workload.""" - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None): + + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ): """Build an input queue using pre-cached FineWeb dataset.""" del cache, repeat_final_dataset ds = get_data_iter( - data_rng=data_rng, - split=split, - data_dir=data_dir, - batch_size=global_batch_size, - num_batches=num_batches) + data_rng=data_rng, + split=split, + data_dir=data_dir, + batch_size=global_batch_size, + num_batches=num_batches, + ) ds = map(jax_sharding_utils.shard_along_batch_dim, ds) return ds def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - + self, + 
rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: # Initialize NanoDO transformer model cfg = DoConfig( - D=self._emb_dim, # embedding dim - H=self._n_heads, # num heads - L=self._seq_len, - N=self._n_layers, # num layers - V=self._vocab_size, - F=self._mlp_dim, # feedforward dim - dtype=jnp.float32 + D=self._emb_dim, # embedding dim + H=self._n_heads, # num heads + L=self._seq_len, + N=self._n_layers, # num layers + V=self._vocab_size, + F=self._mlp_dim, # feedforward dim + dtype=jnp.float32, ) self._model = TransformerDo(cfg) input_shape = (1, self._seq_len) # For token IDs params_rng, init_rng = jax.random.split(rng) - variables = jax.jit(self._model.init)({'params': params_rng}, - jnp.ones(input_shape, jnp.int32)) + variables = jax.jit(self._model.init)( + {'params': params_rng}, jnp.ones(input_shape, jnp.int32) + ) params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -65,14 +70,15 @@ def init_model_fn( return params, model_state def model_fn( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = 0.0, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] # Convert one-hot inputs to token IDs if needed @@ -81,14 +87,13 @@ def model_fn( logits = self._model.apply({'params': params}, inputs) return logits, None - def compute_weighted_cross_entropy( - self, - logits: spec.Tensor, - targets: spec.Tensor, - weights: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0, - ) -> Dict[str, spec.Tensor]: # differentiable + self, + logits: spec.Tensor, + targets: spec.Tensor, + weights: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Compute weighted cross entropy and entropy for log probs and targets. Args: logits: [batch, length, num_classes] float array. @@ -110,15 +115,15 @@ def compute_weighted_cross_entropy( # Extract log probability of the target class # Shape: [batch, length] target_log_probs = jnp.take_along_axis( - log_probs, - targets[..., None], - axis=-1 + log_probs, targets[..., None], axis=-1 ).squeeze(-1) # Cross-entropy with smoothing: -(1 - α) * log_p[target] - α * mean(log_p) # The above formula is easy to derive from the definition of label smoothing and cross-entropy loss. 
confidence = 1.0 - label_smoothing smoothing_term = label_smoothing / self._vocab_size - per_example_losses = -1.0 * (confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1)) + per_example_losses = -1.0 * ( + confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1) + ) if weights is not None: per_example_losses = jnp.where(weights, per_example_losses, 0.0) n_valid_examples = weights.sum() @@ -131,22 +136,27 @@ def compute_weighted_cross_entropy( 'per_example': per_example_losses, } - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: - """Evaluate the model on a single batch.""" - logits, _ = self.model_fn( - params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - # Calculate cross-entropy loss - metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) - # CRITICAL: Detach tensors to free computation graph and activations - # Without this, all intermediate activations are kept in memory! - return { - 'loss': metrics['summed'], - 'denominator': metrics['n_valid_examples'], - } + def _eval_batch( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False + ) + # Calculate cross-entropy loss + metrics = self.compute_weighted_cross_entropy( + logits, batch['targets'], batch['weights'] + ) + # CRITICAL: Detach tensors to free computation graph and activations + # Without this, all intermediate activations are kept in memory! + return { + 'loss': metrics['summed'], + 'denominator': metrics['n_valid_examples'], + } def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any] diff --git a/algoperf/workloads/lm/lm_pytorch/models.py b/algoperf/workloads/lm/lm_pytorch/models.py deleted file mode 100644 index b88e457d8..000000000 --- a/algoperf/workloads/lm/lm_pytorch/models.py +++ /dev/null @@ -1,19 +0,0 @@ -import torch -import torch.nn as nn - - -class LinearLayer(nn.Module): - def __init__(self, vocab_size: int): - super().__init__() - self.bottleneck = nn.Linear(vocab_size, 512) - self.output = nn.Linear(512, vocab_size) - self.reset_parameters() - - def reset_parameters(self): - nn.init.normal_(self.bottleneck.weight, std=0.02) - nn.init.zeros_(self.bottleneck.bias) - nn.init.normal_(self.output.weight, std=0.02) - nn.init.zeros_(self.output.bias) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.output(self.bottleneck(x)) diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py index 9dc8be522..f7e7f9e62 100644 --- a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -1,3 +1,9 @@ +""" +Originally based on the plainLM codebase: +https://github.com/Niccolo-Ajroldi/plainLM +under the MIT license https://github.com/Niccolo-Ajroldi/plainLM/blob/main/LICENSE. 
+""" + import math from dataclasses import dataclass from typing import Tuple @@ -9,299 +15,313 @@ @dataclass class ModelConfig: - vocab_size: int - seq_len: int - dim: int - expand: float - n_layers: int - n_heads: int - rmsnorm_eps: float = 1e-6 - tie_embeddings: bool = True + vocab_size: int + seq_len: int + dim: int + expand: float + n_layers: int + n_heads: int + rmsnorm_eps: float = 1e-6 + tie_embeddings: bool = True class MLP(nn.Module): - - def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): - super().__init__() - hidden_dim = multiple_of * ( - (hidden_dim + multiple_of - 1) // multiple_of) - self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) - self.fc2 = nn.Linear(hidden_dim, dim, bias=False) - self.glu = nn.GLU(dim=2) - - # Initialize with Xavier uniform - nn.init.xavier_uniform_(self.fc1.weight) - nn.init.xavier_uniform_(self.fc2.weight) - - def forward(self, x): - # x: (bsz, T, dim) - return self.fc2(self.glu(self.fc1(x))) - - -def precompute_freqs_cis(dim: int, - end: int, - theta: float = 10000.0, - condense_ratio: int = 1): - inv_freqs = 1.0 / (theta**(torch.arange( - 0, dim, 2, dtype=torch.float32, device=torch.device("cpu")) / dim)) - t = torch.arange(end, dtype=torch.float32, - device=inv_freqs.device) / condense_ratio - freqs = torch.outer(t, inv_freqs).float() - return torch.stack([ - torch.cos(freqs)[None, :, None, :], - torch.sin(freqs)[None, :, None, :] - ], - dim=4) + def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): + super().__init__() + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) + self.fc2 = nn.Linear(hidden_dim, dim, bias=False) + self.glu = nn.GLU(dim=2) + + # Initialize with Xavier uniform + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + + def forward(self, x): + # x: (bsz, T, dim) + return self.fc2(self.glu(self.fc1(x))) + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, condense_ratio: int = 1 +): + inv_freqs = 1.0 / ( + theta + ** ( + torch.arange(0, dim, 2, dtype=torch.float32, device=torch.device('cpu')) + / dim + ) + ) + t = ( + torch.arange(end, dtype=torch.float32, device=inv_freqs.device) + / condense_ratio + ) + freqs = torch.outer(t, inv_freqs).float() + return torch.stack( + [torch.cos(freqs)[None, :, None, :], torch.sin(freqs)[None, :, None, :]], + dim=4, + ) def apply_rotary_emb_complex_like( - q: torch.Tensor, k: torch.Tensor, - freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - # Rotate query and key vectors using RoPE - qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() - rotated_qk_r2 = torch.stack( - [ - qk_r2[..., 0] * freqs_cis[..., 0] - - qk_r2[..., 1] * freqs_cis[..., 1], - qk_r2[..., 1] * freqs_cis[..., 0] + - qk_r2[..., 0] * freqs_cis[..., 1], - ], - -1, - ).flatten(3) - rotated_qk = rotated_qk_r2 - return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) + q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + # Rotate query and key vectors using RoPE + qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() + rotated_qk_r2 = torch.stack( + [ + qk_r2[..., 0] * freqs_cis[..., 0] - qk_r2[..., 1] * freqs_cis[..., 1], + qk_r2[..., 1] * freqs_cis[..., 0] + qk_r2[..., 0] * freqs_cis[..., 1], + ], + -1, + ).flatten(3) + rotated_qk = rotated_qk_r2 + return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) class Attention(nn.Module): + def 
__init__(self, cfg: ModelConfig): + super().__init__() + assert cfg.dim % cfg.n_heads == 0 + self.dim = cfg.dim + self.n_heads = cfg.n_heads + self.head_dim = cfg.dim // cfg.n_heads - def __init__(self, cfg: ModelConfig): - super().__init__() - assert cfg.dim % cfg.n_heads == 0 - self.dim = cfg.dim - self.n_heads = cfg.n_heads - self.head_dim = cfg.dim // cfg.n_heads - - self.w_qkv = nn.Linear(cfg.dim, 3 * cfg.dim, bias=False) - self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False) + self.w_qkv = nn.Linear(cfg.dim, 3 * cfg.dim, bias=False) + self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False) - def forward(self, x, freqs_cis): - bsz, seqlen, d = x.shape # (bsz, seqlen, d) + def forward(self, x, freqs_cis): + bsz, seqlen, d = x.shape # (bsz, seqlen, d) - q, k, v = self.w_qkv(x).split(d, dim=2) # (bsz, seqlen, d) - q = q.view(bsz, seqlen, self.n_heads, - self.head_dim) # (bsz, seqlen, nh, h_dim) - k = k.view(bsz, seqlen, self.n_heads, - self.head_dim) # (bsz, seqlen, nh, h_dim) - v = v.view(bsz, seqlen, self.n_heads, - self.head_dim) # (bsz, seqlen, nh, h_dim) + q, k, v = self.w_qkv(x).split(d, dim=2) # (bsz, seqlen, d) + q = q.view( + bsz, seqlen, self.n_heads, self.head_dim + ) # (bsz, seqlen, nh, h_dim) + k = k.view( + bsz, seqlen, self.n_heads, self.head_dim + ) # (bsz, seqlen, nh, h_dim) + v = v.view( + bsz, seqlen, self.n_heads, self.head_dim + ) # (bsz, seqlen, nh, h_dim) - q, k = apply_rotary_emb_complex_like( - q, k, freqs_cis=freqs_cis) # (bsz, seqlen, nh, h_dim) + q, k = apply_rotary_emb_complex_like( + q, k, freqs_cis=freqs_cis + ) # (bsz, seqlen, nh, h_dim) - q = q.transpose(1, 2) # (bsz, nh, seqlen, h_dim) - k = k.transpose(1, 2) # (bsz, nh, seqlen, h_dim) - v = v.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + q = q.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + k = k.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + v = v.transpose(1, 2) # (bsz, nh, seqlen, h_dim) - out = F.scaled_dot_product_attention( - q, k, v, is_causal=True) # (bsz, nh, seqlen, h_dim) + out = F.scaled_dot_product_attention( + q, k, v, is_causal=True + ) # (bsz, nh, seqlen, h_dim) - out = out.transpose(1, 2).contiguous().view(bsz, seqlen, - d) # (bsz, seqlen, d) + out = ( + out.transpose(1, 2).contiguous().view(bsz, seqlen, d) + ) # (bsz, seqlen, d) - return self.w_out(out) + return self.w_out(out) class Block(nn.Module): + def __init__(self, layer_id: int, cfg: ModelConfig): + super().__init__() + self.attn = Attention(cfg) + self.attn_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.mlp = MLP(dim=cfg.dim, hidden_dim=int(cfg.expand * cfg.dim)) + self.mlp_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.layer_id = layer_id - def __init__(self, layer_id: int, cfg: ModelConfig): - super().__init__() - self.attn = Attention(cfg) - self.attn_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) - self.mlp = MLP(dim=cfg.dim, hidden_dim=int(cfg.expand * cfg.dim)) - self.mlp_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) - self.layer_id = layer_id - - def forward(self, x, freqs_cis): - # x: (bsz, seqlen, dim) - x = x + self.attn(self.attn_norm(x), freqs_cis) - x = x + self.mlp(self.mlp_norm(x)) - return x + def forward(self, x, freqs_cis): + # x: (bsz, seqlen, dim) + x = x + self.attn(self.attn_norm(x), freqs_cis) + x = x + self.mlp(self.mlp_norm(x)) + return x class Transformer(nn.Module): - - def __init__(self, cfg): - super().__init__() - self.n_layers = cfg.n_layers - self.cfg = cfg - head_dim = cfg.dim // cfg.n_heads - assert cfg.dim % cfg.n_heads == 0 - - self.embed_tokens = nn.Embedding(cfg.vocab_size, 
cfg.dim) - self.layers = nn.ModuleList( - [Block(idx, cfg) for idx in range(cfg.n_layers)]) - self.out_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) - self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) - - # Initialize freqs_cis on CPU first (more memory efficient) - self.register_buffer('freqs_cis', - precompute_freqs_cis(head_dim, cfg.seq_len, 500000)[0:cfg.seq_len], - persistent=False) - - # init all weights, scale residual branches - self.apply(self._init_weights) - self._scale_residual_branches() - - # Move model to device (which will also move freqs_cis) - if torch.cuda.is_available(): - self.cuda() - - if cfg.tie_embeddings: - self.tie_weights() - - def forward(self, x, targets=None): - # x: (bsz, seqlen) - x = self.embed_tokens(x) # (bsz, seqlen, dim) - L = x.shape[1] - - # Make sure we have enough precomputed frequencies - if L > self.freqs_cis.shape[1]: - # Need to recompute for longer sequence - head_dim = self.cfg.dim // self.cfg.n_heads - new_freqs = precompute_freqs_cis(head_dim, max(L, self.cfg.seq_len), 500000) - self.register_buffer('freqs_cis', new_freqs[0:max(L, self.cfg.seq_len)], persistent=False) - if torch.cuda.is_available(): - self.freqs_cis = self.freqs_cis.cuda() - - # Select the frequencies for current sequence length and ensure correct device - freqs_cis = self.freqs_cis[:, :L, :].to(x.device) - - for layer in self.layers: - x = layer(x, freqs_cis) # (bsz, seqlen, dim) - out = self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) - if targets is not None: - loss = F.cross_entropy( - out.view(-1, out.size(-1)), targets.view(-1), ignore_index=-100) - return out, loss - return out - - def predict(self, x, k=1): - """Generate k tokens autoregressively. - - Args: - x: Input token sequence of shape (batch_size, seq_len) - k: Number of tokens to predict - - Returns: - Tuple of (input_ids, predicted_ids) - """ - - # Store original input - original_input = x.clone() - generated_input = x.clone() - - # Generate k tokens autoregressively - for i in range(k): - - # Get logits for the entire sequence - logits = self(generated_input) - - # Get the logits for the last token in each sequence - next_token_logits = logits[:, -1, :] - - # Zero out the last token ID to prevent repetition - # This is a common issue - the model gets stuck repeating the last token - last_token_id = generated_input[:, -1] - next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) - - # Get the most likely token - next_token = torch.argmax(next_token_logits, dim=-1) - - # Append the predicted token to the sequence - next_token = next_token.unsqueeze(1) # Add sequence dimension - generated_input = torch.cat([generated_input, next_token], dim=1) - - # For debugging, print predictions for the first item in the batch - print("\nPyTorch detailed prediction (first item in batch):") - predicted_sequence = generated_input[0, -k:].tolist() - print(f" Predicted token IDs: {predicted_sequence}") - for i, token_id in enumerate(predicted_sequence): - print(f" Step {i+1}: Predicted token {token_id}") - - # Return all tokens, not just the last k - return original_input, generated_input[:, -k:] - - def _init_weights(self, module): - if isinstance(module, nn.Linear): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - if module.bias is not None: - torch.nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - - def _scale_residual_branches(self): - for n, p in self.named_parameters(): - if 
n.endswith("fc2.weight"): # mlp/glu output layer - torch.nn.init.normal_(p, - mean=0.0, - std=0.02 / math.sqrt(2 * self.n_layers)) - if n.endswith("w_out.weight"): # attn output layer - torch.nn.init.normal_(p, - mean=0.0, - std=0.02 / math.sqrt(2 * self.n_layers)) - - def tie_weights(self): - self.lm_head.weight = self.embed_tokens.weight - - def count_params(self, non_embedding=True): - n_params = sum(p.numel() for p in self.parameters()) - if non_embedding: - n_params -= self.embed_tokens.weight.numel() - if (self.lm_head.weight is not self.embed_tokens.weight): # if no weight tying - n_params -= self.lm_head.weight.numel() - return n_params - - -def main(): - print("Initializing transformer model and running forward pass...") - - seq_length = 1024 - - # Define model configuration - config = ModelConfig( - vocab_size=50257, # Common vocab size for tokenizers like BPE or SentencePiece - seq_len=seq_length, # Maximum sequence length - dim=1024, # Embedding dimension - expand=4.0, # MLP expansion factor - n_layers=12, # Number of transformer layers - n_heads=8, # Number of attention heads - rmsnorm_eps=1e-6, # RMSNorm epsilon - tie_embeddings=True # Tie embedding and output weights + def __init__(self, cfg): + super().__init__() + self.n_layers = cfg.n_layers + self.cfg = cfg + head_dim = cfg.dim // cfg.n_heads + assert cfg.dim % cfg.n_heads == 0 + + self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.dim) + self.layers = nn.ModuleList( + [Block(idx, cfg) for idx in range(cfg.n_layers)] + ) + self.out_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) + + # Initialize freqs_cis on CPU first (more memory efficient) + self.register_buffer( + 'freqs_cis', + precompute_freqs_cis(head_dim, cfg.seq_len, 500000)[0 : cfg.seq_len], + persistent=False, ) - # Instantiate the model - model = Transformer(config) - print(f"Model has {model.count_params():,} parameters.") - - # Create some random input data - batch_size = 2 - input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length)) + # init all weights, scale residual branches + self.apply(self._init_weights) + self._scale_residual_branches() - # Move data to the same device as the model + # Move model to device (which will also move freqs_cis) if torch.cuda.is_available(): - input_ids = input_ids.cuda() - - # Run a forward pass - print(f"Running forward pass with input shape: {input_ids.shape}") - logits = model(input_ids) - print(f"Output logits shape: {logits.shape}") - - # Run prediction - print("Running prediction...") - original_input, predicted_ids = model.predict(input_ids[:, :10], k=5) - print(f"Original input shape for prediction: {original_input.shape}") - print(f"Predicted IDs shape: {predicted_ids.shape}") - print(f"Predicted IDs: {predicted_ids}") - -if __name__ == "__main__": - main() + self.cuda() + + if cfg.tie_embeddings: + self.tie_weights() + + def forward(self, x, targets=None): + # x: (bsz, seqlen) + x = self.embed_tokens(x) # (bsz, seqlen, dim) + L = x.shape[1] + + # Make sure we have enough precomputed frequencies + if L > self.freqs_cis.shape[1]: + # Need to recompute for longer sequence + head_dim = self.cfg.dim // self.cfg.n_heads + new_freqs = precompute_freqs_cis( + head_dim, max(L, self.cfg.seq_len), 500000 + ) + self.register_buffer( + 'freqs_cis', new_freqs[0 : max(L, self.cfg.seq_len)], persistent=False + ) + if torch.cuda.is_available(): + self.freqs_cis = self.freqs_cis.cuda() + + # Select the frequencies for current sequence length and 
ensure correct device + freqs_cis = self.freqs_cis[:, :L, :].to(x.device) + + for layer in self.layers: + x = layer(x, freqs_cis) # (bsz, seqlen, dim) + out = self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + if targets is not None: + loss = F.cross_entropy( + out.view(-1, out.size(-1)), targets.view(-1), ignore_index=-100 + ) + return out, loss + return out + + def predict(self, x, k=1): + """Generate k tokens autoregressively. + + Args: + x: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + + # Store original input + original_input = x.clone() + generated_input = x.clone() + + # Generate k tokens autoregressively + for i in range(k): + # Get logits for the entire sequence + logits = self(generated_input) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Zero out the last token ID to prevent repetition + # This is a common issue - the model gets stuck repeating the last token + last_token_id = generated_input[:, -1] + next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) + + # Get the most likely token + next_token = torch.argmax(next_token_logits, dim=-1) + + # Append the predicted token to the sequence + next_token = next_token.unsqueeze(1) # Add sequence dimension + generated_input = torch.cat([generated_input, next_token], dim=1) + + # For debugging, print predictions for the first item in the batch + print('\nPyTorch detailed prediction (first item in batch):') + predicted_sequence = generated_input[0, -k:].tolist() + print(f' Predicted token IDs: {predicted_sequence}') + for i, token_id in enumerate(predicted_sequence): + print(f' Step {i + 1}: Predicted token {token_id}') + + # Return all tokens, not just the last k + return original_input, generated_input[:, -k:] + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def _scale_residual_branches(self): + for n, p in self.named_parameters(): + if n.endswith('fc2.weight'): # mlp/glu output layer + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers) + ) + if n.endswith('w_out.weight'): # attn output layer + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers) + ) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if ( + self.lm_head.weight is not self.embed_tokens.weight + ): # if no weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + +def main(): + print('Initializing transformer model and running forward pass...') + + seq_length = 1024 + + # Define model configuration + config = ModelConfig( + vocab_size=50257, # Common vocab size for tokenizers like BPE or SentencePiece + seq_len=seq_length, # Maximum sequence length + dim=1024, # Embedding dimension + expand=4.0, # MLP expansion factor + n_layers=12, # Number of transformer layers + n_heads=8, # Number of attention heads + rmsnorm_eps=1e-6, # RMSNorm epsilon + tie_embeddings=True, # Tie embedding and output weights + ) + + # Instantiate the model + model = Transformer(config) + print(f'Model has 
{model.count_params():,} parameters.') + + # Create some random input data + batch_size = 2 + input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length)) + + # Move data to the same device as the model + if torch.cuda.is_available(): + input_ids = input_ids.cuda() + + # Run a forward pass + print(f'Running forward pass with input shape: {input_ids.shape}') + logits = model(input_ids) + print(f'Output logits shape: {logits.shape}') + + # Run prediction + print('Running prediction...') + original_input, predicted_ids = model.predict(input_ids[:, :10], k=5) + print(f'Original input shape for prediction: {original_input.shape}') + print(f'Predicted IDs shape: {predicted_ids.shape}') + print(f'Predicted IDs: {predicted_ids}') + + +if __name__ == '__main__': + main() diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 9713e84b0..115fae4f6 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -24,28 +24,28 @@ class LmWorkload(BaseLmWorkload): """LM PyTorch workload.""" def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: if hasattr(self, '_model'): - # Reinitialize weights but keep same config - self._model.apply(self._model._init_weights) - self._model._scale_residual_branches() - return self._model, None + # Reinitialize weights but keep same config + self._model.apply(self._model._init_weights) + self._model._scale_residual_branches() + return self._model, None torch.manual_seed(rng[0]) cfg = ModelConfig( - vocab_size=self._vocab_size, - seq_len=self._seq_len, - dim=self._emb_dim, # Model dimension - expand=self._mlp_dim // self._emb_dim, # MLP expansion factor - # FIXME(rka97): fix expansion factor - n_layers=self._n_layers, # Number of transformer layers - n_heads=self._n_heads, # Number of attention heads - rmsnorm_eps=1e-6, - tie_embeddings=True + vocab_size=self._vocab_size, + seq_len=self._seq_len, + dim=self._emb_dim, # Model dimension + expand=self._mlp_dim // self._emb_dim, # MLP expansion factor + # FIXME(rka97): fix expansion factor + n_layers=self._n_layers, # Number of transformer layers + n_heads=self._n_heads, # Number of attention heads + rmsnorm_eps=1e-6, + tie_embeddings=True, ) self._model = Transformer(cfg) self._param_shapes = param_utils.pytorch_param_shapes(self._model) @@ -53,23 +53,23 @@ def init_model_fn( self._model.to(DEVICE) if N_GPUS > 1: - if USE_PYTORCH_DDP: - self._model = DDP(self._model, device_ids=[RANK], output_device=RANK) - else: - self._model = torch.nn.DataParallel(self._model) + if USE_PYTORCH_DDP: + self._model = DDP(self._model, device_ids=[RANK], output_device=RANK) + else: + self._model = torch.nn.DataParallel(self._model) return self._model, None def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + 
update_batch_norm: bool, + dropout_rate: float = 0.0, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state, rng, update_batch_norm, dropout_rate model = params @@ -93,32 +93,39 @@ def model_fn( return logits, None def _build_input_queue( - self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: """Build an input queue for the given split.""" del cache, repeat_final_dataset local_batch_size = global_batch_size // N_GPUS loader = get_data_iter( - data_rng=data_rng, - split=split, - data_dir=data_dir, - batch_size=local_batch_size, - num_batches=num_batches, + data_rng=data_rng, + split=split, + data_dir=data_dir, + batch_size=local_batch_size, + num_batches=num_batches, ) if USE_PYTORCH_DDP: - loader = islice(loader, RANK, None, N_GPUS) + loader = islice(loader, RANK, None, N_GPUS) dtype = torch.int32 for batch in loader: batch = { - 'inputs': torch.tensor(batch['inputs'], device=DEVICE, dtype=dtype), - 'targets': torch.tensor(batch['targets'], device=DEVICE, dtype=torch.int64), - 'weights': torch.tensor(batch['weights'], device=DEVICE, dtype=torch.float32) if batch['weights'] is not None else None, + 'inputs': torch.tensor(batch['inputs'], device=DEVICE, dtype=dtype), + 'targets': torch.tensor( + batch['targets'], device=DEVICE, dtype=torch.int64 + ), + 'weights': torch.tensor( + batch['weights'], device=DEVICE, dtype=torch.float32 + ) + if batch['weights'] is not None + else None, } yield batch @@ -127,7 +134,13 @@ def is_output_params(self, param_name: str) -> bool: return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name # FIXME(rka97): Implement label smoothing - def compute_weighted_cross_entropy(self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor, label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + def compute_weighted_cross_entropy( + self, + logits: spec.Tensor, + labels: spec.Tensor, + weights: spec.Tensor, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: """Compute cross-entropy loss for language modeling in PyTorch.""" vocab_size = logits.size(-1) @@ -138,44 +151,53 @@ def compute_weighted_cross_entropy(self, logits: spec.Tensor, labels: spec.Tenso else: # Dense labels loss = torch.nn.functional.cross_entropy( - logits.view(-1, vocab_size), - labels.view(-1), - reduction='none') + logits.view(-1, vocab_size), labels.view(-1), reduction='none' + ) loss = loss.view_as(labels) if weights is not None: loss = loss * weights - n_valid = weights.sum() if weights is not None else torch.tensor(labels.numel(), dtype=torch.float32, device=labels.device) + n_valid = ( + weights.sum() + if weights is not None + else torch.tensor( + labels.numel(), dtype=torch.float32, device=labels.device + ) + ) return { - 'summed': loss.sum(), - 'n_valid_examples': n_valid, - 'per_example': loss, + 'summed': loss.sum(), + 'n_valid_examples': n_valid, + 'per_example': loss, } - - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> spec.Tensor: - """Evaluate the model on a single batch.""" - logits, _ 
= self.model_fn( - params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False) - metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) - return { - 'loss': metrics['summed'].detach(), - 'denominator': metrics['n_valid_examples'].detach(), - } + def _eval_batch( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False + ) + metrics = self.compute_weighted_cross_entropy( + logits, batch['targets'], batch['weights'] + ) + return { + 'loss': metrics['summed'].detach(), + 'denominator': metrics['n_valid_examples'].detach(), + } def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, Any] - ) -> Dict[str, float]: - """Normalize eval metrics.""" - del num_examples - if USE_PYTORCH_DDP: - for metric in total_metrics.values(): - dist.all_reduce(metric) - total_metrics = {k: v.item() for k, v in total_metrics.items()} - eval_denominator = total_metrics.pop('denominator') - return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) \ No newline at end of file + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + if USE_PYTORCH_DDP: + for metric in total_metrics.values(): + dist.all_reduce(metric) + total_metrics = {k: v.item() for k, v in total_metrics.items()} + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index f15e4b8a7..43dd60ab5 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -45,11 +45,11 @@ def validation_target_value(self) -> float: return 25.5477 # Target perplexity def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: - return True # No test targets + return True # No test targets @property def test_target_value(self) -> float: - return None # No test targets + return None # No test targets @property def loss_type(self) -> spec.LossType: @@ -61,11 +61,11 @@ def num_train_examples(self) -> int: @property def num_eval_train_examples(self) -> int: - return 10_000 # Subset for evaluation. + return 10_000 # Subset for evaluation. @property def num_validation_examples(self) -> int: - return 100_000 # sequences + return 100_000 # sequences @property def num_test_examples(self) -> int: @@ -85,7 +85,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 3600 * 14 # 14 hours TODO(kasimbeg): update + return 3600 * 14 # 14 hours TODO(kasimbeg): update @property def eval_period_time_sec(self) -> int: @@ -125,7 +125,6 @@ def _build_input_queue( ) -> Iterator[Dict[str, Any]]: """Build an input queue for the given split.""" - def _eval_model_on_split( self, split: str, @@ -147,11 +146,7 @@ def _eval_model_on_split( if split not in self._eval_iters: # These iterators will repeat indefinitely. 
self._eval_iters[split] = self._build_input_queue( - rng, - split, - data_dir, - global_batch_size, - num_batches=num_batches + rng, split, data_dir, global_batch_size, num_batches=num_batches ) eval_metrics = {} @@ -167,7 +162,6 @@ def _eval_model_on_split( eval_results['ppl'] = np.exp(eval_results['loss']).item() return eval_results - @abc.abstractmethod def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any] @@ -175,19 +169,20 @@ def _normalize_eval_metrics( """Normalize eval metrics.""" def loss_fn( - self, - label_batch: spec.Tensor, - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: """Compute cross-entropy loss for language modeling in JAX.""" return self.compute_weighted_cross_entropy( logits_batch, label_batch, weights=mask_batch, - label_smoothing=label_smoothing + label_smoothing=label_smoothing, ) def is_output_params(self, param_name: str) -> bool: """Return whether the given parameter is an output parameter.""" - return param_name.contains('output') \ No newline at end of file + return param_name.contains('output') diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py index 114b1adb4..391f16f51 100644 --- a/algoperf/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -9,151 +9,153 @@ BASE_WORKLOADS_DIR = 'algoperf/workloads/' WORKLOADS = { - 'cifar': { - 'workload_path': 'cifar/cifar', 'workload_class_name': 'CifarWorkload' - }, - 'criteo1tb': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallWorkload', - }, - 'criteo1tb_test': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload', - }, - 'criteo1tb_layernorm': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload' - }, - 'criteo1tb_embed_init': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallEmbedInitWorkload' - }, - 'criteo1tb_resnet': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload' - }, - 'fastmri': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRIWorkload', - }, - 'fastmri_model_size': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRIModelSizeWorkload', - }, - 'fastmri_tanh': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRITanhWorkload', - }, - 'fastmri_layernorm': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRILayerNormWorkload', - }, - 'imagenet_resnet': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetWorkload', - }, - 'imagenet_resnet_silu': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetSiLUWorkload', - }, - 'imagenet_resnet_gelu': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetGELUWorkload', - }, - 'imagenet_resnet_large_bn_init': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetLargeBNScaleWorkload', - }, - 'imagenet_vit': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitWorkload', - }, - 'imagenet_vit_glu': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 
'ImagenetVitGluWorkload', - }, - 'imagenet_vit_post_ln': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitPostLNWorkload', - }, - 'imagenet_vit_map': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitMapWorkload', - }, - 'librispeech_conformer': { - 'workload_path': 'librispeech_conformer/librispeech', - 'workload_class_name': 'LibriSpeechConformerWorkload', - }, - 'librispeech_conformer_attention_temperature': { - 'workload_path': - 'librispeech_conformer/librispeech', - 'workload_class_name': - 'LibriSpeechConformerAttentionTemperatureWorkload', - }, - 'librispeech_conformer_layernorm': { - 'workload_path': 'librispeech_conformer/librispeech', - 'workload_class_name': 'LibriSpeechConformerLayerNormWorkload', - }, - 'librispeech_conformer_gelu': { - 'workload_path': 'librispeech_conformer/librispeech', - 'workload_class_name': 'LibriSpeechConformerGeluWorkload', - }, - 'librispeech_deepspeech': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechWorkload', - }, - 'librispeech_deepspeech_tanh': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechTanhWorkload', - }, - 'librispeech_deepspeech_no_resnet': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechNoResNetWorkload', - }, - 'librispeech_deepspeech_norm_and_spec_aug': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', - }, - 'lm': {'workload_path': 'lm/lm', 'workload_class_name': 'LmWorkload'}, - 'mnist': { - 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload' - }, - 'ogbg': { - 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload' - }, - 'ogbg_gelu': { - 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgGeluWorkload' - }, - 'ogbg_silu': { - 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgSiluWorkload' - }, - 'ogbg_model_size': { - 'workload_path': 'ogbg/ogbg', - 'workload_class_name': 'OgbgModelSizeWorkload' - }, - 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, - 'wmt_post_ln': { - 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadPostLN' - }, - 'wmt_attention_temp': { - 'workload_path': 'wmt/wmt', - 'workload_class_name': 'WmtWorkloadAttentionTemp' - }, - 'wmt_glu_tanh': { - 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadGLUTanH' - }, + 'cifar': { + 'workload_path': 'cifar/cifar', + 'workload_class_name': 'CifarWorkload', + }, + 'criteo1tb': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallWorkload', + }, + 'criteo1tb_test': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload', + }, + 'criteo1tb_layernorm': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload', + }, + 'criteo1tb_embed_init': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallEmbedInitWorkload', + }, + 'criteo1tb_resnet': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload', + }, + 'fastmri': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRIWorkload', + }, + 'fastmri_model_size': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRIModelSizeWorkload', + }, + 'fastmri_tanh': { + 
'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRITanhWorkload', + }, + 'fastmri_layernorm': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRILayerNormWorkload', + }, + 'imagenet_resnet': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetWorkload', + }, + 'imagenet_resnet_silu': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetSiLUWorkload', + }, + 'imagenet_resnet_gelu': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetGELUWorkload', + }, + 'imagenet_resnet_large_bn_init': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetLargeBNScaleWorkload', + }, + 'imagenet_vit': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitWorkload', + }, + 'imagenet_vit_glu': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitGluWorkload', + }, + 'imagenet_vit_post_ln': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitPostLNWorkload', + }, + 'imagenet_vit_map': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitMapWorkload', + }, + 'librispeech_conformer': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerWorkload', + }, + 'librispeech_conformer_attention_temperature': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerAttentionTemperatureWorkload', + }, + 'librispeech_conformer_layernorm': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerLayerNormWorkload', + }, + 'librispeech_conformer_gelu': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerGeluWorkload', + }, + 'librispeech_deepspeech': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechWorkload', + }, + 'librispeech_deepspeech_tanh': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechTanhWorkload', + }, + 'librispeech_deepspeech_no_resnet': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechNoResNetWorkload', + }, + 'librispeech_deepspeech_norm_and_spec_aug': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', + }, + 'lm': {'workload_path': 'lm/lm', 'workload_class_name': 'LmWorkload'}, + 'mnist': { + 'workload_path': 'mnist/mnist', + 'workload_class_name': 'MnistWorkload', + }, + 'ogbg': {'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload'}, + 'ogbg_gelu': { + 'workload_path': 'ogbg/ogbg', + 'workload_class_name': 'OgbgGeluWorkload', + }, + 'ogbg_silu': { + 'workload_path': 'ogbg/ogbg', + 'workload_class_name': 'OgbgSiluWorkload', + }, + 'ogbg_model_size': { + 'workload_path': 'ogbg/ogbg', + 'workload_class_name': 'OgbgModelSizeWorkload', + }, + 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, + 'wmt_post_ln': { + 'workload_path': 'wmt/wmt', + 'workload_class_name': 'WmtWorkloadPostLN', + }, + 'wmt_attention_temp': { + 'workload_path': 'wmt/wmt', + 'workload_class_name': 'WmtWorkloadAttentionTemp', + }, + 'wmt_glu_tanh': { + 'workload_path': 'wmt/wmt', + 'workload_class_name': 'WmtWorkloadGLUTanH', + }, } BASE_WORKLOADS = [ - 
'criteo1tb', - 'fastmri', - 'imagenet_resnet', - 'imagenet_vit', - 'librispeech_conformer', - 'librispeech_deepspeech', - 'lm', - 'ogbg', - 'wmt' + 'criteo1tb', + 'fastmri', + 'imagenet_resnet', + 'imagenet_vit', + 'librispeech_conformer', + 'librispeech_deepspeech', + 'lm', + 'ogbg', + 'wmt', ] diff --git a/dataset/dataset_setup.py b/dataset/dataset_setup.py index 8fecaf419..de5e9d271 100644 --- a/dataset/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -72,8 +72,7 @@ from torchvision.datasets import CIFAR10 from algoperf.workloads.wmt import tokenizer -from algoperf.workloads.wmt.input_pipeline import \ - normalize_feature_names +from algoperf.workloads.wmt.input_pipeline import normalize_feature_names from dataset import librispeech_preprocess from dataset import librispeech_tokenizer @@ -111,38 +110,41 @@ 'files will be deleted.', ) flags.DEFINE_boolean( - 'all', - False, - 'Whether or not to download all datasets. If false, can download some ' - 'combination of datasets by setting the individual dataset flags below.') - -flags.DEFINE_boolean('criteo1tb', - False, - 'If --all=false, whether or not to download Criteo 1TB.') -flags.DEFINE_boolean('cifar', - False, - 'If --all=false, whether or not to download CIFAR-10.') -flags.DEFINE_boolean('fastmri', - False, - 'If --all=false, whether or not to download FastMRI.') -flags.DEFINE_boolean('finewebedu', - False, - 'If --all=false, whether or not to download FineWebEdu.') -flags.DEFINE_boolean('imagenet', - False, - 'If --all=false, whether or not to download Imagenet.') -flags.DEFINE_boolean('librispeech', - False, - 'If --all=false, whether or not to download LibriSpeech.') -flags.DEFINE_boolean('mnist', - False, - 'If --all=false, whether or not to download MNIST.') -flags.DEFINE_boolean('ogbg', - False, - 'If --all=false, whether or not to download OGBG.') -flags.DEFINE_boolean('wmt', - False, - 'If --all=false, whether or not to download WMT.') + 'all', + False, + 'Whether or not to download all datasets. If false, can download some ' + 'combination of datasets by setting the individual dataset flags below.', +) + +flags.DEFINE_boolean( + 'criteo1tb', False, 'If --all=false, whether or not to download Criteo 1TB.' +) +flags.DEFINE_boolean( + 'cifar', False, 'If --all=false, whether or not to download CIFAR-10.' +) +flags.DEFINE_boolean( + 'fastmri', False, 'If --all=false, whether or not to download FastMRI.' +) +flags.DEFINE_boolean( + 'finewebedu', False, 'If --all=false, whether or not to download FineWebEdu.' +) +flags.DEFINE_boolean( + 'imagenet', False, 'If --all=false, whether or not to download Imagenet.' +) +flags.DEFINE_boolean( + 'librispeech', + False, + 'If --all=false, whether or not to download LibriSpeech.', +) +flags.DEFINE_boolean( + 'mnist', False, 'If --all=false, whether or not to download MNIST.' +) +flags.DEFINE_boolean( + 'ogbg', False, 'If --all=false, whether or not to download OGBG.' +) +flags.DEFINE_boolean( + 'wmt', False, 'If --all=false, whether or not to download WMT.' +) flags.DEFINE_string( 'data_dir', @@ -199,7 +201,9 @@ flags.DEFINE_string('framework', None, 'Can be either jax or pytorch.') flags.DEFINE_boolean('skip_download', False, 'Skips data download.') -flags.DEFINE_boolean('skip_tokenization', False, 'Skip Fineweb-edu tokenization.') +flags.DEFINE_boolean( + 'skip_tokenization', False, 'Skip Fineweb-edu tokenization.' 
+) FLAGS = flags.FLAGS @@ -773,30 +777,32 @@ def download_wmt(data_dir): ) -def download_finewebedu(data_dir, - tmp_dir=None, - skip_download=False, - skip_tokenization=False): +def download_finewebedu( + data_dir, tmp_dir=None, skip_download=False, skip_tokenization=False +): """Download FineWebEdu-10B.""" - if not skip_download: + if not skip_download: data_dir = os.path.join(data_dir, 'fineweb_edu_10B') tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' - cache_dir = os.path.join(tmp_dir, - 'lm') if tmp_dir is not None else os.path.expanduser( - '~/.cache/huggingface/datasets') + cache_dir = ( + os.path.join(tmp_dir, 'lm') + if tmp_dir is not None + else os.path.expanduser('~/.cache/huggingface/datasets') + ) _maybe_mkdir(data_dir) _maybe_mkdir(tmp_dir) _maybe_mkdir(cache_dir) - os.environ["TMPDIR"] = tmp_dir + os.environ['TMPDIR'] = tmp_dir ds = hf_datasets.load_dataset( - 'HuggingFaceFW/fineweb-edu', - name='sample-10BT', - split='train', - cache_dir=cache_dir) + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + cache_dir=cache_dir, + ) ds.save_to_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) else: ds = hf_datasets.load_from_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) @@ -804,10 +810,9 @@ def download_finewebedu(data_dir, if not skip_tokenization: # Tokenize lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') - logging.info(f"Vocab size of lm_tokenizer = {len(lm_tokenizer)}") + logging.info(f'Vocab size of lm_tokenizer = {len(lm_tokenizer)}') def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - def add_eos(seq): return seq + lm_tokenizer.eos_token if seq else seq @@ -815,33 +820,41 @@ def add_eos_batched(seqs): return [add_eos(seq) for seq in seqs] return lm_tokenizer( - add_eos_batched(examples["text"]), - return_special_tokens_mask=False, - return_attention_mask=False) + add_eos_batched(examples['text']), + return_special_tokens_mask=False, + return_attention_mask=False, + ) - lm_tokenizer.model_max_length = 1e30 # prevent truncation during tokenization - logging.info("Tokenizing...") + lm_tokenizer.model_max_length = ( + 1e30 # prevent truncation during tokenization + ) + logging.info('Tokenizing...') tokenized_dataset = ds.map( - tokenize, - remove_columns=[ - 'text', - 'id', - 'dump', - 'url', - 'file_path', - 'language', - 'language_score', - 'token_count', - 'score', - 'int_score' - ], - batched=True, - batch_size=1024, - num_proc=8) - - tokenized_dataset.save_to_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + tokenize, + remove_columns=[ + 'text', + 'id', + 'dump', + 'url', + 'file_path', + 'language', + 'language_score', + 'token_count', + 'score', + 'int_score', + ], + batched=True, + batch_size=1024, + num_proc=8, + ) + + tokenized_dataset.save_to_disk( + os.path.join(data_dir, 'fwedu_10B_tokenized') + ) else: - tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized")) + tokenized_dataset = hf_datasets.load_from_disk( + os.path.join(data_dir, 'fwedu_10B_tokenized') + ) # Convert to tensorflow_datasets.Dataset objects tokenized_dataset = tokenized_dataset.to_tf_dataset() @@ -854,10 +867,10 @@ def add_eos_batched(seqs): val_dataset = shuffled_dataset.skip(train_size) # Split in train and valid. 
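# --- Editorial aside (not part of the patch above): the train/validation split
# here uses tf.data's take/skip pattern, so the first `train_size` examples of
# the shuffled dataset become the training split and the remainder the
# validation split. A toy illustration with made-up sizes:
import tensorflow as tf

ds = tf.data.Dataset.range(10)
train_size = 8
train_split = ds.take(train_size)  # elements 0..7
val_split = ds.skip(train_size)    # elements 8..9
print(list(train_split.as_numpy_iterator()))  # [0, 1, ..., 7]
print(list(val_split.as_numpy_iterator()))    # [8, 9]
# --- End of editorial aside.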
- train_dataset.save(os.path.join(data_dir, "train")) - val_dataset.save(os.path.join(data_dir, "val")) + train_dataset.save(os.path.join(data_dir, 'train')) + val_dataset.save(os.path.join(data_dir, 'val')) - return + return def main(_): @@ -949,7 +962,9 @@ def main(_): if FLAGS.all or FLAGS.finewebedu: logging.info('Downloading FineWebEdu-10B...') - download_finewebedu(data_dir, tmp_dir, FLAGS.skip_download, FLAGS.skip_tokenization) + download_finewebedu( + data_dir, tmp_dir, FLAGS.skip_download, FLAGS.skip_tokenization + ) # pylint: enable=logging-format-interpolation diff --git a/submission_runner.py b/submission_runner.py index 1c50cd6d9..857d4479f 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -253,12 +253,12 @@ def train_once( model_params, model_state = workload.init_model_fn(model_init_rng) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ - 'librispeech_conformer', - 'ogbg', - 'criteo1tb', - 'imagenet_vit', - 'librispeech_deepspeech', - 'lm' + 'librispeech_conformer', + 'ogbg', + 'criteo1tb', + 'imagenet_vit', + 'librispeech_deepspeech', + 'lm', ] eager_backend_workloads = [] aot_eager_backend_workloads = [] @@ -784,8 +784,10 @@ def main(_): os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' if FLAGS.framework == 'pytorch': - limit_tf_threads = (base_workload != 'lm') - pytorch_init(USE_PYTORCH_DDP, RANK, profiler, limit_tf_threads=limit_tf_threads) + limit_tf_threads = base_workload != 'lm' + pytorch_init( + USE_PYTORCH_DDP, RANK, profiler, limit_tf_threads=limit_tf_threads + ) # TODO: remove once issue resolved. if FLAGS.pytorch_eval_num_workers != 0: @@ -797,11 +799,11 @@ def main(_): workload_metadata = WORKLOADS[FLAGS.workload] if base_workload in [ - 'librispeech_conformer', - 'librispeech_deepspeech', - 'imagenet_vit', - 'criteo1tb', - 'lm' + 'librispeech_conformer', + 'librispeech_deepspeech', + 'imagenet_vit', + 'criteo1tb', + 'lm', ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' From bb4a3809ee6978ce0cd0f7e5c15ba1aef2ca6a6e Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 21 Oct 2025 04:33:40 +0000 Subject: [PATCH 70/98] Refactor loss function in LM workloads to unify label handling and improve clarity --- algoperf/workloads/lm/lm_jax/workload.py | 52 ++++++------- algoperf/workloads/lm/lm_pytorch/workload.py | 77 ++++++++++++-------- algoperf/workloads/lm/workload.py | 19 +++-- 3 files changed, 87 insertions(+), 61 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 5b736fad7..13738086a 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -87,35 +87,38 @@ def model_fn( logits = self._model.apply({'params': params}, inputs) return logits, None - def compute_weighted_cross_entropy( + def loss_fn( self, - logits: spec.Tensor, - targets: spec.Tensor, - weights: Optional[spec.Tensor] = None, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0, ) -> Dict[str, spec.Tensor]: # differentiable - """Compute weighted cross entropy and entropy for log probs and targets. + """Compute weighted cross entropy. + Args: - logits: [batch, length, num_classes] float array. - targets: categorical targets [batch, length] int array. - weights: array of shape [batch, length]. - label_smoothing: label smoothing constant, used to determine the on and off - values. + label_batch: categorical targets [batch, length] int array. 
+ logits_batch: [batch, length, num_classes] float array. + mask_batch: weights array of shape [batch, length]. + label_smoothing: Label smoothing factor in [0, 1]. When > 0, the target + distribution becomes (1 - label_smoothing) for the correct class and + label_smoothing / vocab_size for all other classes. Default is 0.0 (no smoothing). + Returns: {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} + valid examples in batch, 'per_example': 2d array of per-example losses} """ - if logits.ndim != targets.ndim + 1: + if logits_batch.ndim != label_batch.ndim + 1: raise ValueError( - f'Incorrect shapes. Got shape {logits.shape} logits and ' - f'{targets.shape} targets.' + f'Incorrect shapes. Got shape {logits_batch.shape} logits and ' + f'{label_batch.shape} targets.' ) # Compute log probabilities - log_probs = jax.nn.log_softmax(logits, axis=-1) + log_probs = jax.nn.log_softmax(logits_batch, axis=-1) # Extract log probability of the target class # Shape: [batch, length] target_log_probs = jnp.take_along_axis( - log_probs, targets[..., None], axis=-1 + log_probs, label_batch[..., None], axis=-1 ).squeeze(-1) # Cross-entropy with smoothing: -(1 - α) * log_p[target] - α * mean(log_p) # The above formula is easy to derive from the definition of label smoothing and cross-entropy loss. @@ -124,11 +127,11 @@ def compute_weighted_cross_entropy( per_example_losses = -1.0 * ( confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1) ) - if weights is not None: - per_example_losses = jnp.where(weights, per_example_losses, 0.0) - n_valid_examples = weights.sum() + if mask_batch is not None: + per_example_losses = mask_batch * per_example_losses + n_valid_examples = mask_batch.sum() else: - n_valid_examples = targets.shape[0] * targets.shape[1] + n_valid_examples = label_batch.shape[0] * label_batch.shape[1] summed_loss = per_example_losses.sum() return { 'summed': summed_loss, @@ -147,12 +150,11 @@ def _eval_batch( logits, _ = self.model_fn( params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False ) - # Calculate cross-entropy loss - metrics = self.compute_weighted_cross_entropy( - logits, batch['targets'], batch['weights'] + metrics = self.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch['weights'], ) - # CRITICAL: Detach tensors to free computation graph and activations - # Without this, all intermediate activations are kept in memory! 
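# --- Editorial aside (not part of the patch above): a quick numerical check of
# the label-smoothing identity that loss_fn relies on. With smoothing factor a
# and vocab size V, the smoothed target places (1 - a) + a/V on the true class
# and a/V on every other class, so the per-token loss reduces to
#   -(1 - a) * log p[target] - (a / V) * sum_v log p[v],
# which is what torch.nn.functional.cross_entropy(..., label_smoothing=a)
# computes as well. Shapes and values below are illustrative only.
import torch
import torch.nn.functional as F

a, V = 0.1, 11
logits = torch.randn(4, V)           # [batch, vocab]
targets = torch.randint(0, V, (4,))  # [batch]

log_probs = F.log_softmax(logits, dim=-1)
manual = -(
    (1 - a) * log_probs.gather(1, targets[:, None]).squeeze(1)
    + (a / V) * log_probs.sum(dim=-1)
)
reference = F.cross_entropy(logits, targets, reduction='none', label_smoothing=a)
assert torch.allclose(manual, reference, atol=1e-6)
# --- End of editorial aside.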
return { 'loss': metrics['summed'], 'denominator': metrics['n_valid_examples'], diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 115fae4f6..2f5c33ebf 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -133,42 +133,59 @@ def is_output_params(self, param_name: str) -> bool: """Return whether the given parameter is an output parameter.""" return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name - # FIXME(rka97): Implement label smoothing - def compute_weighted_cross_entropy( + def loss_fn( self, - logits: spec.Tensor, - labels: spec.Tensor, - weights: spec.Tensor, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: spec.Tensor, label_smoothing: float = 0.0, ) -> Dict[str, spec.Tensor]: - """Compute cross-entropy loss for language modeling in PyTorch.""" - vocab_size = logits.size(-1) - - if len(labels.shape) == len(logits.shape): - # One-hot labels - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - loss = -torch.sum(labels * log_probs, dim=-1) - else: - # Dense labels - loss = torch.nn.functional.cross_entropy( - logits.view(-1, vocab_size), labels.view(-1), reduction='none' - ) - loss = loss.view_as(labels) + """Compute weighted cross-entropy loss. + + Args: + label_batch: Target labels of shape [batch, length] (int). + logits_batch: Predicted logits of shape [batch, length, vocab_size] (float). + mask_batch: Optional weights of shape [batch, length] (float). Used to mask + out padding tokens or weight examples differently. If None, all examples + are weighted equally. + label_smoothing: Label smoothing factor in [0, 1]. When > 0, the target + distribution becomes (1 - label_smoothing) for the correct class and + label_smoothing / vocab_size for all other classes. Default is 0.0 (no smoothing). + + Returns: + Dictionary containing: + - 'summed': Scalar tensor with the sum of all weighted losses. + - 'n_valid_examples': Scalar tensor with the count of valid (non-masked) examples. + - 'per_example': Tensor of shape [batch, length] with individual losses per example. 
+ """ + vocab_size = logits_batch.size(-1) + + # Compute cross-entropy loss with label smoothing + per_example_losses = torch.nn.functional.cross_entropy( + logits_batch.view(-1, vocab_size), + label_batch.view(-1), + reduction='none', + label_smoothing=label_smoothing, + ) + per_example_losses = per_example_losses.view_as(label_batch) - if weights is not None: - loss = loss * weights + # Apply weights if provided + if mask_batch is not None: + per_example_losses = per_example_losses * mask_batch - n_valid = ( - weights.sum() - if weights is not None + # Calculate number of valid examples + n_valid_examples = ( + mask_batch.sum() + if mask_batch is not None else torch.tensor( - labels.numel(), dtype=torch.float32, device=labels.device + label_batch.numel(), dtype=torch.float32, device=label_batch.device ) ) + return { - 'summed': loss.sum(), - 'n_valid_examples': n_valid, - 'per_example': loss, + 'summed': per_example_losses.sum(), + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } def _eval_batch( @@ -182,8 +199,10 @@ def _eval_batch( logits, _ = self.model_fn( params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False ) - metrics = self.compute_weighted_cross_entropy( - logits, batch['targets'], batch['weights'] + metrics = self.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch['weights'], ) return { 'loss': metrics['summed'].detach(), diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 43dd60ab5..56d9fabcc 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -125,6 +125,16 @@ def _build_input_queue( ) -> Iterator[Dict[str, Any]]: """Build an input queue for the given split.""" + @abc.abstractmethod + def _eval_batch( + self, + params: spec.ParameterContainer, + eval_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[str, float]: + """Evaluate the model on a single batch.""" + def _eval_model_on_split( self, split: str, @@ -168,6 +178,7 @@ def _normalize_eval_metrics( ) -> Dict[str, float]: """Normalize eval metrics.""" + @abc.abstractmethod def loss_fn( self, label_batch: spec.Tensor, @@ -175,13 +186,7 @@ def loss_fn( mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0, ) -> Dict[str, spec.Tensor]: - """Compute cross-entropy loss for language modeling in JAX.""" - return self.compute_weighted_cross_entropy( - logits_batch, - label_batch, - weights=mask_batch, - label_smoothing=label_smoothing, - ) + """Compute cross-entropy loss for language modeling.""" def is_output_params(self, param_name: str) -> bool: """Return whether the given parameter is an output parameter.""" From a58fbd57ebd9fde597d96b3eba34f89929ffcab4 Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 21 Oct 2025 08:46:00 +0000 Subject: [PATCH 71/98] Fix init in both models to be the same, add lm model diff test --- algoperf/workloads/lm/lm_jax/nanodo_model.py | 25 +- .../workloads/lm/lm_pytorch/plainlm_model.py | 20 +- tests/modeldiffs/lm/compare.py | 868 ++++++++++++++++++ 3 files changed, 893 insertions(+), 20 deletions(-) create mode 100644 tests/modeldiffs/lm/compare.py diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py index a1644f569..1227e57b2 100644 --- a/algoperf/workloads/lm/lm_jax/nanodo_model.py +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -21,14 +21,17 @@ class DoConfig: N: int # number of transformer block layers V: int # vocab size 
F: int # FF inner dimension - kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform() - embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'normal', out_axis=0 - ) + attention_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) + linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) + embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) + use_residual_scaling: bool = True dtype: jnp.dtype = jnp.float32 rmsnorm_epsilon: float = 1e-6 multiple_of: int = 256 - tie_embeddings: bool = True # Whether to tie input and output embeddings + tie_embeddings: bool = True # Whether to tie input and output embed + + def __post_init__(self): + self.residual_init = nn.initializers.normal(stddev=0.02/jnp.sqrt(2 * self.N)) class Mlp(nn.Module): @@ -40,9 +43,8 @@ class Mlp(nn.Module): def __call__(self, x_BxLxD: jax.Array): cfg = self.cfg # Use Xavier uniform initialization explicitly - xavier_init = nn.initializers.xavier_uniform() linear = partial( - nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype + nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype ) # Adjust hidden dimension to keep the number of parameters invariant to # the activation function used since the GLU MLP has 3 * hidden_dim * D @@ -55,7 +57,7 @@ def __call__(self, x_BxLxD: jax.Array): x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) # Apply GLU activation x_BxLxF = nn.glu(x_BxLx2F, axis=-1) - x_BxLxD = linear(cfg.D)(x_BxLxF) + x_BxLxD = nn.Dense(cfg.D, use_bias=False, dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init)(x_BxLxF) return x_BxLxD @@ -122,7 +124,7 @@ def setup(self): nn.DenseGeneral, axis=-1, features=(cfg.H, self.Dh), - kernel_init=cfg.kernel_init, + kernel_init=cfg.attention_init, use_bias=False, dtype=cfg.dtype, ) @@ -134,7 +136,7 @@ def setup(self): features=cfg.D, name='attn_out_proj', # axis=(-2, -1), # - kernel_init=cfg.kernel_init, + kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init, use_bias=False, dtype=cfg.dtype, ) @@ -265,6 +267,9 @@ def predict(self, y_BxL: jax.Array, k: int = 1): # Get the logits for the last token in each sequence next_token_logits = logits[:, -1, :] + last_token_id = y_BxL[:, -1] + # Prevent predicting the same token consecutively + next_token_logits = next_token_logits.at[jnp.arange(len(last_token_id)), last_token_id].set(float('-inf')) # Get the most likely token next_token = jnp.argmax(next_token_logits, axis=-1) diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py index f7e7f9e62..af4232b7e 100644 --- a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -23,6 +23,7 @@ class ModelConfig: n_heads: int rmsnorm_eps: float = 1e-6 tie_embeddings: bool = True + use_residual_scaling: bool = True class MLP(nn.Module): @@ -32,10 +33,8 @@ def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) self.fc2 = nn.Linear(hidden_dim, dim, bias=False) self.glu = nn.GLU(dim=2) - - # Initialize with Xavier uniform - nn.init.xavier_uniform_(self.fc1.weight) - nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.weight, std=0.02) + nn.init.normal_(self.fc2.weight, std=0.02) def forward(self, x): # x: (bsz, T, dim) @@ -89,6 +88,11 @@ def __init__(self, cfg: ModelConfig): self.w_qkv = 
nn.Linear(cfg.dim, 3 * cfg.dim, bias=False) self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False) + # Split into Q, K, V sections + wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0) + for w in [wq, wk, wv]: + nn.init.normal_(w, std=0.02) + nn.init.normal_(self.w_out.weight, std=0.02) def forward(self, x, freqs_cis): bsz, seqlen, d = x.shape # (bsz, seqlen, d) @@ -254,15 +258,11 @@ def _init_weights(self, module): if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + torch.nn.init.normal_(module.weight, std=0.02) def _scale_residual_branches(self): for n, p in self.named_parameters(): - if n.endswith('fc2.weight'): # mlp/glu output layer - torch.nn.init.normal_( - p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers) - ) - if n.endswith('w_out.weight'): # attn output layer + if n.endswith('fc2.weight') or n.endswith('w_out.weight'): # mlp/glu output layer torch.nn.init.normal_( p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers) ) diff --git a/tests/modeldiffs/lm/compare.py b/tests/modeldiffs/lm/compare.py new file mode 100644 index 000000000..5b95f934c --- /dev/null +++ b/tests/modeldiffs/lm/compare.py @@ -0,0 +1,868 @@ +""" +Test file to verify that JAX and PyTorch implementations produce identical outputs +when given the same weights and inputs. + +Tests are performed module-by-module: +1. RMSNorm +2. RoPE (Rotary Position Embeddings) +3. MLP +4. Attention +5. Transformer Block +6. Full Model +""" + +import os +import sys + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import jax.numpy as jnp +import numpy as np +import torch +import torch.nn.functional as F +from absl import flags, logging +from absl.testing import absltest, parameterized + +# Import JAX implementation +from algoperf.workloads.lm.lm_jax.nanodo_model import ( + CausalAttn, + DoConfig, + Mlp, + TBlock, + TransformerDo, + apply_rope, + init_rope, +) + +# Import PyTorch implementation +from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( + MLP, + Attention, + Block, + ModelConfig, + Transformer, + apply_rotary_emb_complex_like, + precompute_freqs_cis, +) + +FLAGS = flags.FLAGS +# Needed to avoid UnparsedFlagAccessError +FLAGS(sys.argv) + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def assert_close(jax_output, torch_output, rtol=1e-5, atol=1e-6, name=''): + """Assert that JAX and PyTorch outputs are close.""" + jax_np = np.array(jax_output) + torch_np = torch_output.detach().cpu().numpy() + + mse = np.mean((jax_np - torch_np) ** 2) + max_diff = np.max(np.abs(jax_np - torch_np)) + + logging.info(f'\n{name} Comparison:') + logging.info(f' MSE: {mse:.8e}') + logging.info(f' Max Difference: {max_diff:.8e}') + + np.testing.assert_allclose( + jax_np, + torch_np, + rtol=rtol, + atol=atol, + err_msg=f'{name} outputs do not match', + ) + + +# ============================================================================ +# Test Functions (unchanged) +# ============================================================================ + + +def test_rmsnorm(): + """Test that RMSNorm produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing RMSNorm') + logging.info('=' * 70) + + batch_size, seq_len, dim = 2, 10, 256 + eps = 1e-6 + + # Create random input + np_input = np.random.randn(batch_size, seq_len, 
dim).astype(np.float32) + + # Initialize PyTorch RMSNorm + torch_norm = torch.nn.RMSNorm(dim, eps=eps) + torch_input = torch.tensor(np_input) + + # Initialize JAX RMSNorm (using Flax's RMSNorm from nanodo) + from flax import linen as nn + + flax_norm = nn.RMSNorm(epsilon=eps) + jax_input = jnp.array(np_input) + flax_params = flax_norm.init(jax.random.PRNGKey(0), jax_input) + + # Copy weights from PyTorch to JAX + with torch.no_grad(): + flax_params['params']['scale'] = jnp.array(torch_norm.weight.numpy()) + + # Forward pass + with torch.no_grad(): + torch_output = torch_norm(torch_input) + + jax_output = flax_norm.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, name='RMSNorm') + logging.info('✓ RMSNorm test passed') + + +def test_rope(): + """Test that RoPE produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing RoPE (Rotary Position Embeddings)') + logging.info('=' * 70) + + batch_size, seq_len, n_heads, dim = 2, 16, 4, 128 + head_dim = dim // n_heads + + # Initialize RoPE + torch_freqs = precompute_freqs_cis(head_dim, seq_len, theta=500000) + jax_freqs = init_rope(dim, seq_len, n_heads) + + # Create random Q and K + np_q = np.random.randn(batch_size, seq_len, n_heads, head_dim).astype( + np.float32 + ) + np_k = np.random.randn(batch_size, seq_len, n_heads, head_dim).astype( + np.float32 + ) + + # PyTorch forward + torch_q = torch.tensor(np_q) + torch_k = torch.tensor(np_k) + with torch.no_grad(): + torch_q_rot, torch_k_rot = apply_rotary_emb_complex_like( + torch_q, torch_k, freqs_cis=torch_freqs + ) + + # JAX forward + jax_q = jnp.array(np_q) + jax_k = jnp.array(np_k) + jax_q_rot, jax_k_rot = apply_rope(jax_q, jax_k, jax_freqs) + + # Compare + assert_close(jax_q_rot, torch_q_rot, name='RoPE Q') + assert_close(jax_k_rot, torch_k_rot, name='RoPE K') + logging.info('✓ RoPE test passed') + + +def copy_mlp_params(pytorch_mlp, flax_params): + """Copy MLP parameters from PyTorch to JAX.""" + new_params = flax_params.copy() + + # Handle compiled models + if hasattr(pytorch_mlp, '_orig_mod'): + pytorch_mlp = pytorch_mlp._orig_mod + + # Copy fc1 and fc2 weights (transposed for JAX) + new_params['params']['Dense_0']['kernel'] = ( + pytorch_mlp.fc1.weight.detach().numpy().T + ) + new_params['params']['Dense_1']['kernel'] = ( + pytorch_mlp.fc2.weight.detach().numpy().T + ) + + return new_params + + +def test_mlp(): + """Test that MLP produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing MLP') + logging.info('=' * 70) + + batch_size, seq_len, dim = 2, 10, 256 + hidden_dim = 1024 + + # Initialize PyTorch MLP + pytorch_mlp = MLP(dim=dim, hidden_dim=hidden_dim) + + # Initialize JAX MLP + cfg = DoConfig( + D=dim, + H=4, + L=128, + N=2, + V=1000, + F=hidden_dim, + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + ) + flax_mlp = Mlp(cfg) + + # Initialize JAX params + dummy_input = jnp.ones((batch_size, seq_len, dim)) + flax_params = flax_mlp.init(jax.random.PRNGKey(0), dummy_input) + + # Copy weights + flax_params = copy_mlp_params(pytorch_mlp, flax_params) + + # Create input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + torch_input = torch.tensor(np_input) + jax_input = jnp.array(np_input) + + # Forward pass + with torch.no_grad(): + torch_output = pytorch_mlp(torch_input) + + jax_output = flax_mlp.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, name='MLP') + logging.info('✓ MLP test passed') + + +def copy_attention_params(pytorch_attn, 
flax_params): + """Copy attention parameters from PyTorch to JAX.""" + # Handle compiled models + if hasattr(pytorch_attn, '_orig_mod'): + pytorch_attn = pytorch_attn._orig_mod + + n_heads = pytorch_attn.n_heads + head_dim = pytorch_attn.head_dim + dim = pytorch_attn.dim + + # Split PyTorch's combined qkv weights + w_qkv = pytorch_attn.w_qkv.weight + q_weight, k_weight, v_weight = [ + u.detach().numpy() for u in w_qkv.split(dim, dim=0) + ] + + # Reshape for Flax's DenseGeneral format [D, H, Dh] + def reshape_for_flax(w, n_heads, head_dim): + return w.reshape(n_heads, head_dim, -1).transpose(2, 0, 1) + + new_params = { + 'query': {'kernel': reshape_for_flax(q_weight, n_heads, head_dim)}, + 'key': {'kernel': reshape_for_flax(k_weight, n_heads, head_dim)}, + 'value': {'kernel': reshape_for_flax(v_weight, n_heads, head_dim)}, + 'attn_out_proj': {'kernel': pytorch_attn.w_out.weight.detach().numpy().T}, + } + + return {'params': new_params} + + +def test_attention(): + """Test that Attention produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Attention') + logging.info('=' * 70) + + batch_size, seq_len, dim, n_heads = 2, 16, 256, 4 + + # Initialize PyTorch Attention + config = ModelConfig( + vocab_size=1000, + seq_len=seq_len, + dim=dim, + expand=4.0, + n_layers=1, + n_heads=n_heads, + rmsnorm_eps=1e-6, + ) + pytorch_attn = Attention(config) + freqs_cis = precompute_freqs_cis(dim // n_heads, seq_len, theta=500000) + + # Initialize JAX Attention + cfg = DoConfig( + D=dim, + H=n_heads, + L=seq_len, + N=1, + V=1000, + F=1024, + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + ) + flax_attn = CausalAttn(cfg) + + # Initialize JAX params + dummy_input = jnp.ones((batch_size, seq_len, dim)) + flax_params = flax_attn.init(jax.random.PRNGKey(0), dummy_input) + + # Copy weights + flax_params = copy_attention_params(pytorch_attn, flax_params) + + # Create input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + torch_input = torch.tensor(np_input) + jax_input = jnp.array(np_input) + + # Forward pass + with torch.no_grad(): + torch_output = pytorch_attn(torch_input, freqs_cis) + + jax_output = flax_attn.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, rtol=1e-4, atol=1e-5, name='Attention') + logging.info('✓ Attention test passed') + + +def copy_block_params(pytorch_block, flax_params): + """Copy block parameters from PyTorch to JAX.""" + # Copy attention parameters + attn_params = copy_attention_params(pytorch_block.attn, {'params': {}})[ + 'params' + ] + + # Copy MLP parameters + pytorch_mlp = pytorch_block.mlp + mlp_params = { + 'Dense_0': {'kernel': pytorch_mlp.fc1.weight.detach().numpy().T}, + 'Dense_1': {'kernel': pytorch_mlp.fc2.weight.detach().numpy().T}, + } + + # Copy RMSNorm parameters + norm_params = { + 'attn_norm': {'scale': pytorch_block.attn_norm.weight.detach().numpy()}, + 'mlp_norm': {'scale': pytorch_block.mlp_norm.weight.detach().numpy()}, + } + + return { + 'params': { + 'CausalAttn_0': attn_params, + 'Mlp_0': mlp_params, + 'RMSNorm_0': norm_params['attn_norm'], + 'RMSNorm_1': norm_params['mlp_norm'], + } + } + + +def test_block(): + """Test that Transformer Block produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Transformer Block') + logging.info('=' * 70) + + batch_size, seq_len, dim, n_heads = 2, 16, 256, 4 + expand = 4.0 + + # Initialize PyTorch Block + config = ModelConfig( + vocab_size=1000, + seq_len=seq_len, + dim=dim, + expand=expand, + n_layers=1, + 
n_heads=n_heads, + rmsnorm_eps=1e-6, + ) + pytorch_block = Block(layer_id=0, cfg=config) + freqs_cis = precompute_freqs_cis(dim // n_heads, seq_len, theta=500000) + + # Initialize JAX Block + cfg = DoConfig( + D=dim, + H=n_heads, + L=seq_len, + N=1, + V=1000, + F=int(dim * expand), + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + ) + flax_block = TBlock(cfg) + + # Initialize JAX params + dummy_input = jnp.ones((batch_size, seq_len, dim)) + flax_params = flax_block.init(jax.random.PRNGKey(0), dummy_input) + + # Copy weights + flax_params = copy_block_params(pytorch_block, flax_params) + + # Create input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + torch_input = torch.tensor(np_input) + jax_input = jnp.array(np_input) + + # Forward pass + with torch.no_grad(): + torch_output = pytorch_block(torch_input, freqs_cis) + + jax_output = flax_block.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, rtol=1e-4, atol=1e-5, name='Block') + logging.info('✓ Block test passed') + + +def copy_full_model_params(pytorch_model, flax_params, config): + """Copy all parameters from PyTorch model to JAX model.""" + # Handle tied embeddings case + if hasattr(pytorch_model, '_orig_mod'): + pytorch_model = pytorch_model._orig_mod + + n_layers = config.n_layers + n_heads = config.n_heads + dim = config.dim + head_dim = dim // n_heads + + new_params = {'params': {}} + + # Copy embedding weights + new_params['params']['embed'] = { + 'embedding': pytorch_model.embed_tokens.weight.detach().numpy() + } + + # Copy each transformer block + for i in range(n_layers): + pytorch_block = pytorch_model.layers[i] + + # Attention params + w_qkv = pytorch_block.attn.w_qkv.weight + q_weight, k_weight, v_weight = [ + u.detach().numpy() for u in w_qkv.split(dim, dim=0) + ] + + def reshape_for_flax(w, n_heads, head_dim): + return w.reshape(n_heads, head_dim, -1).transpose(2, 0, 1) + + attn_params = { + 'query': {'kernel': reshape_for_flax(q_weight, n_heads, head_dim)}, + 'key': {'kernel': reshape_for_flax(k_weight, n_heads, head_dim)}, + 'value': {'kernel': reshape_for_flax(v_weight, n_heads, head_dim)}, + 'attn_out_proj': { + 'kernel': pytorch_block.attn.w_out.weight.detach().numpy().T + }, + } + + # MLP params + mlp_params = { + 'Dense_0': {'kernel': pytorch_block.mlp.fc1.weight.detach().numpy().T}, + 'Dense_1': {'kernel': pytorch_block.mlp.fc2.weight.detach().numpy().T}, + } + + # Norm params + attn_norm = {'scale': pytorch_block.attn_norm.weight.detach().numpy()} + mlp_norm = {'scale': pytorch_block.mlp_norm.weight.detach().numpy()} + + # Assemble block params + block_key = f'blocks_{i}' + new_params['params'][block_key] = { + 'CausalAttn_0': attn_params, + 'Mlp_0': mlp_params, + 'RMSNorm_0': attn_norm, + 'RMSNorm_1': mlp_norm, + } + + # Copy output norm + new_params['params']['out_ln'] = { + 'scale': pytorch_model.out_norm.weight.detach().numpy() + } + + # Handle output projection (tied or untied) + if not config.tie_embeddings: + new_params['params']['output_proj'] = { + 'kernel': pytorch_model.lm_head.weight.detach().numpy().T + } + + return new_params + + +def test_full_model(): + """Test that full Transformer model produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Full Transformer Model') + logging.info('=' * 70) + + batch_size, seq_len = 2, 32 + vocab_size = 256 + dim = 128 + n_heads = 4 + n_layers = 2 + expand = 4.0 + + # Initialize PyTorch model + pytorch_config = ModelConfig( + vocab_size=vocab_size, + seq_len=seq_len, + 
dim=dim, + expand=expand, + n_layers=n_layers, + n_heads=n_heads, + rmsnorm_eps=1e-6, + tie_embeddings=True, + ) + pytorch_model = Transformer(pytorch_config) + pytorch_model.eval() + + # Initialize JAX model + jax_config = DoConfig( + D=dim, + H=n_heads, + L=seq_len, + N=n_layers, + V=vocab_size, + F=int(dim * expand), + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + jax_model = TransformerDo(jax_config) + + # Create input tokens + np_tokens = np.random.randint( + 0, vocab_size, size=(batch_size, seq_len), dtype=np.int32 + ) + torch_tokens = torch.tensor(np_tokens, dtype=torch.long) + jax_tokens = jnp.array(np_tokens, dtype=jnp.int32) + + # Initialize JAX params + jax_params = jax_model.init(jax.random.PRNGKey(0), jax_tokens) + + # Copy weights from PyTorch to JAX + jax_params = copy_full_model_params(pytorch_model, jax_params, pytorch_config) + + # Forward pass + with torch.no_grad(): + torch_logits = pytorch_model(torch_tokens) + + jax_logits = jax_model.apply(jax_params, jax_tokens) + + # Compare + assert_close( + jax_logits, torch_logits, rtol=1e-4, atol=1e-5, name='Full Model' + ) + logging.info('✓ Full Model test passed') + + +def test_prediction(): + """Test that autoregressive generation produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Autoregressive Prediction') + logging.info('=' * 70) + + batch_size, seq_len = 1, 10 + vocab_size = 256 + dim = 128 + n_heads = 4 + n_layers = 2 + expand = 4.0 + k = 5 # Number of tokens to predict + + # Initialize PyTorch model + pytorch_config = ModelConfig( + vocab_size=vocab_size, + seq_len=seq_len + k, + dim=dim, + expand=expand, + n_layers=n_layers, + n_heads=n_heads, + rmsnorm_eps=1e-6, + tie_embeddings=True, + ) + pytorch_model = Transformer(pytorch_config) + pytorch_model.eval() + + # Initialize JAX model + jax_config = DoConfig( + D=dim, + H=n_heads, + L=seq_len + k, + N=n_layers, + V=vocab_size, + F=int(dim * expand), + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + jax_model = TransformerDo(jax_config) + + # Create input tokens + np_tokens = np.random.randint( + 0, vocab_size, size=(batch_size, seq_len), dtype=np.int32 + ) + torch_tokens = torch.tensor(np_tokens, dtype=torch.long) + jax_tokens = jnp.array(np_tokens, dtype=jnp.int32) + + # Initialize JAX params + jax_params = jax_model.init(jax.random.PRNGKey(0), jax_tokens) + + # Copy weights from PyTorch to JAX + jax_params = copy_full_model_params(pytorch_model, jax_params, pytorch_config) + + # Predict k tokens + with torch.no_grad(): + _, torch_predictions = pytorch_model.predict(torch_tokens, k=k) + + _, jax_predictions = jax_model.apply( + jax_params, jax_tokens, k, method=jax_model.predict + ) + + # Compare predictions + torch_pred_np = torch_predictions.cpu().numpy() + jax_pred_np = np.array(jax_predictions) + + logging.info(f'\nPyTorch predictions: {torch_pred_np[0]}') + logging.info(f'JAX predictions: {jax_pred_np[0]}') + + # Check if predictions match exactly + if np.array_equal(torch_pred_np, jax_pred_np): + logging.info('✓ Predictions match exactly!') + else: + matching = np.sum(torch_pred_np == jax_pred_np) + total = torch_pred_np.size + logging.info( + f'⚠ Predictions differ: {matching}/{total} tokens match ({matching / total * 100:.1f}%)' + ) + logging.info( + ' (Note: Small numerical differences can lead to different argmax results)' + ) + + +def test_initialization_statistics(): + """Verify initialization follows expected distributions.""" + logging.info('\n' + '=' * 70) + 
logging.info('Testing Initialization Statistics') + logging.info('=' * 70) + + # Initialize models + jax_cfg = DoConfig(D=512, H=8, L=1024, N=12, V=50000, F=2048) + jax_model = TransformerDo(jax_cfg) + jax_params = jax_model.init( + jax.random.PRNGKey(42), jnp.ones((1, 10), dtype=jnp.int32) + ) + + pytorch_cfg = ModelConfig( + vocab_size=50000, seq_len=1024, dim=512, expand=4.0, n_layers=12, n_heads=8 + ) + pytorch_model = Transformer(pytorch_cfg) + + logging.info('Initialization Statistics Check:') + + # Check embedding + jax_embed = jax_params['params']['embed']['embedding'] + torch_embed = pytorch_model.embed_tokens.weight.detach().numpy() + + logging.info('\nToken Embedding (should be ~0.02 std):') + logging.info( + f' JAX: mean={jax_embed.mean():.6f}, std={jax_embed.std():.6f}' + ) + logging.info( + f' PyTorch: mean={torch_embed.mean():.6f}, std={torch_embed.std():.6f}' + ) + + # Assert embedding std is close to 0.02 + assert abs(jax_embed.std() - 0.02) < 0.005, ( + f'JAX embedding std {jax_embed.std():.6f} not close to 0.02' + ) + assert abs(torch_embed.std() - 0.02) < 0.005, ( + f'PyTorch embedding std {torch_embed.std():.6f} not close to 0.02' + ) + assert abs(jax_embed.mean()) < 0.01, ( + f'JAX embedding mean {jax_embed.mean():.6f} not close to 0' + ) + assert abs(torch_embed.mean()) < 0.01, ( + f'PyTorch embedding mean {torch_embed.mean():.6f} not close to 0' + ) + + # Check first layer attention Q + jax_q = jax_params['params']['blocks_0']['CausalAttn_0']['query']['kernel'] + torch_q_weight = ( + pytorch_model.layers[0].attn.w_qkv.weight[:512].detach().numpy() + ) + + logging.info('\nAttention Q:') + logging.info(f' JAX: mean={jax_q.mean():.6f}, std={jax_q.std():.6f}') + logging.info( + f' PyTorch: mean={torch_q_weight.mean():.6f}, std={torch_q_weight.std():.6f}' + ) + + # Check means are close to 0 + assert abs(jax_q.mean()) < 0.01, ( + f'JAX Q mean {jax_q.mean():.6f} not close to 0' + ) + assert abs(torch_q_weight.mean()) < 0.01, ( + f'PyTorch Q mean {torch_q_weight.mean():.6f} not close to 0' + ) + + # Check stds are similar + # Allow 20% difference due to random initialization + assert abs(jax_q.std() - torch_q_weight.std()) / torch_q_weight.std() < 0.2, ( + f'Q std differs too much: JAX {jax_q.std():.6f} vs PyTorch {torch_q_weight.std():.6f}' + ) + + # Check first layer attention output (should be scaled) + jax_attn_out = jax_params['params']['blocks_0']['CausalAttn_0'][ + 'attn_out_proj' + ]['kernel'] + torch_attn_out = pytorch_model.layers[0].attn.w_out.weight.detach().numpy() + + logging.info('\nAttention Output:') + logging.info( + f' JAX: mean={jax_attn_out.mean():.6f}, std={jax_attn_out.std():.6f}' + ) + logging.info( + f' PyTorch: mean={torch_attn_out.mean():.6f}, std={torch_attn_out.std():.6f}' + ) + + # Check means are close to 0 + assert abs(jax_attn_out.mean()) < 0.01, ( + f'JAX attn out mean {jax_attn_out.mean():.6f} not close to 0' + ) + assert abs(torch_attn_out.mean()) < 0.01, ( + f'PyTorch attn out mean {torch_attn_out.mean():.6f} not close to 0' + ) + + # Check stds are similar + assert ( + abs(jax_attn_out.std() - torch_attn_out.std()) / torch_attn_out.std() < 0.2 + ), ( + f'Attention output std differs too much: JAX {jax_attn_out.std():.6f} vs PyTorch {torch_attn_out.std():.6f}' + ) + + # Check MLP fc2 (should be scaled) + jax_mlp_out = jax_params['params']['blocks_0']['Mlp_0']['Dense_1']['kernel'] + torch_mlp_out = pytorch_model.layers[0].mlp.fc2.weight.detach().numpy() + + logging.info('\nMLP Output:') + logging.info( + f' JAX: 
mean={jax_mlp_out.mean():.6f}, std={jax_mlp_out.std():.6f}' + ) + logging.info( + f' PyTorch: mean={torch_mlp_out.mean():.6f}, std={torch_mlp_out.std():.6f}' + ) + + # Check means are close to 0 + assert abs(jax_mlp_out.mean()) < 0.01, ( + f'JAX MLP out mean {jax_mlp_out.mean():.6f} not close to 0' + ) + assert abs(torch_mlp_out.mean()) < 0.01, ( + f'PyTorch MLP out mean {torch_mlp_out.mean():.6f} not close to 0' + ) + + # Check stds are similar + assert ( + abs(jax_mlp_out.std() - torch_mlp_out.std()) / torch_mlp_out.std() < 0.2 + ), ( + f'MLP output std differs too much: JAX {jax_mlp_out.std():.6f} vs PyTorch {torch_mlp_out.std():.6f}' + ) + + logging.info('\n✓ Initialization statistics test passed') + + +def test_initialization_impact(): + """Test that initialization produces similar initial losses.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Initialization Impact') + logging.info('=' * 70) + + # Create identical inputs + batch_size, seq_len = 4, 128 + vocab_size = 50000 + + np.random.seed(42) + tokens = np.random.randint(0, vocab_size, size=(batch_size, seq_len)) + + # Initialize both models with same seed + jax_cfg = DoConfig(D=512, H=8, L=seq_len, N=12, V=vocab_size, F=2048) + jax_model = TransformerDo(jax_cfg) + jax_params = jax_model.init( + jax.random.PRNGKey(42), jnp.array(tokens, dtype=jnp.int32) + ) + + torch.manual_seed(42) + pytorch_cfg = ModelConfig( + vocab_size=vocab_size, + seq_len=seq_len, + dim=512, + expand=4.0, + n_layers=12, + n_heads=8, + ) + pytorch_model = Transformer(pytorch_cfg) + + # Forward pass + jax_logits = jax_model.apply(jax_params, jnp.array(tokens, dtype=jnp.int32)) + + with torch.no_grad(): + torch_logits = pytorch_model(torch.tensor(tokens, dtype=torch.long)) + + # Compute losses + targets = tokens[:, 1:] + jax_loss = -jax.nn.log_softmax(jax_logits[:, :-1]).mean() + torch_loss = F.cross_entropy( + torch_logits[:, :-1].reshape(-1, vocab_size), + torch.tensor(targets.reshape(-1), dtype=torch.long), + ) + + logging.info('\nInitial Loss Comparison:') + logging.info(f' JAX: {jax_loss:.4f}') + logging.info(f' PyTorch: {torch_loss.item():.4f}') + logging.info(f' Difference: {abs(jax_loss - torch_loss.item()):.6f}') + + # Check that losses are in reasonable range for random init + # With vocab_size=50000, random init should give loss around log(50000) ≈ 10.82 + expected_loss = np.log(vocab_size) + + assert 8.0 < jax_loss < 13.0, ( + f'JAX loss {jax_loss:.4f} outside expected range [8.0, 13.0]' + ) + assert 8.0 < torch_loss.item() < 13.0, ( + f'PyTorch loss {torch_loss.item():.4f} outside expected range [8.0, 13.0]' + ) + + # Both losses should be within 10% of log(vocab_size) + assert abs(jax_loss - expected_loss) / expected_loss < 0.1, ( + f'JAX loss {jax_loss:.4f} too far from expected {expected_loss:.4f}' + ) + assert abs(torch_loss.item() - expected_loss) / expected_loss < 0.1, ( + f'PyTorch loss {torch_loss.item():.4f} too far from expected {expected_loss:.4f}' + ) + + logging.info( + '\nNote: Losses are in expected range for random initialization.' 
+ ) + logging.info(f' Expected ~log(vocab_size) = {expected_loss:.4f}') + logging.info('\n✓ Initialization impact test passed') + + +# ============================================================================ +# Test Class +# ============================================================================ + +named_parameters = [ + dict(testcase_name='rmsnorm', test_fn=test_rmsnorm), + dict(testcase_name='rope', test_fn=test_rope), + dict(testcase_name='mlp', test_fn=test_mlp), + dict(testcase_name='attention', test_fn=test_attention), + dict(testcase_name='block', test_fn=test_block), + dict(testcase_name='full_model', test_fn=test_full_model), + dict(testcase_name='prediction', test_fn=test_prediction), + dict( + testcase_name='initialization_statistics', + test_fn=test_initialization_statistics, + ), + dict( + testcase_name='initialization_impact', test_fn=test_initialization_impact + ), +] + + +class ModelMatchingTest(parameterized.TestCase): + """Tests for JAX vs PyTorch model matching.""" + + @parameterized.named_parameters(*named_parameters) + def test_model_matching(self, test_fn): + """Run individual model matching test.""" + test_fn() + + +if __name__ == '__main__': + absltest.main() From b59afa0120f98e7aecc04b3393addb2acbdafe23 Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 21 Oct 2025 09:06:29 +0000 Subject: [PATCH 72/98] Refactor model configuration classes to make them consistent between JAX and PyTorch, also unify initialization to be the same in both --- algoperf/workloads/lm/lm_jax/nanodo_model.py | 79 ++++---- algoperf/workloads/lm/lm_jax/workload.py | 16 +- .../workloads/lm/lm_pytorch/plainlm_model.py | 61 +++---- algoperf/workloads/lm/lm_pytorch/workload.py | 11 +- tests/modeldiffs/lm/compare.py | 169 ++++++++++-------- 5 files changed, 180 insertions(+), 156 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py index 1227e57b2..2b47c1735 100644 --- a/algoperf/workloads/lm/lm_jax/nanodo_model.py +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -12,32 +12,33 @@ @dataclasses.dataclass -class DoConfig: +class ModelConfig: """Hyper-parameters for Transformer decoder-only.""" - D: int # model/embed dim = qkv dim - H: int # num attention heads - L: int # max context/sequence length - N: int # number of transformer block layers - V: int # vocab size - F: int # FF inner dimension + model_dim: int # model/embed dim = qkv dim + num_heads: int # num attention heads + seq_len: int # max context/sequence length + num_layers: int # number of transformer block layers + vocab_size: int # vocab size + expanded_model_dim: int # FF inner dimension + multiple_of: int = 256 + rmsnorm_epsilon: float = 1e-6 + use_residual_scaling: bool = True + tie_embeddings: bool = True # Whether to tie input and output embed + + dtype: jnp.dtype = jnp.float32 attention_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) - use_residual_scaling: bool = True - dtype: jnp.dtype = jnp.float32 - rmsnorm_epsilon: float = 1e-6 - multiple_of: int = 256 - tie_embeddings: bool = True # Whether to tie input and output embed def __post_init__(self): - self.residual_init = nn.initializers.normal(stddev=0.02/jnp.sqrt(2 * self.N)) + self.residual_init = nn.initializers.normal(stddev=0.02/jnp.sqrt(2 * self.num_layers)) class Mlp(nn.Module): """Multilayer perceptron with GLU 
activation.""" - cfg: DoConfig + cfg: ModelConfig @nn.compact def __call__(self, x_BxLxD: jax.Array): @@ -49,15 +50,15 @@ def __call__(self, x_BxLxD: jax.Array): # Adjust hidden dimension to keep the number of parameters invariant to # the activation function used since the GLU MLP has 3 * hidden_dim * D # parameters instead of 2 * hidden_dim * D parameters - hidden_dim = cfg.F * 2 / 3 + hidden_dim = cfg.expanded_model_dim * 2 / 3 hidden_dim = cfg.multiple_of * ( - (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of + (cfg.expanded_model_dim + cfg.multiple_of - 1) // cfg.multiple_of ) # Double the hidden dimension for GLU x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) # Apply GLU activation x_BxLxF = nn.glu(x_BxLx2F, axis=-1) - x_BxLxD = nn.Dense(cfg.D, use_bias=False, dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init)(x_BxLxF) + x_BxLxD = nn.Dense(cfg.model_dim, use_bias=False, dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init)(x_BxLxF) return x_BxLxD @@ -109,21 +110,21 @@ def rotate_tensor(x): class CausalAttn(nn.Module): """Causal attention layer with rotary embeddings.""" - cfg: DoConfig + cfg: ModelConfig def setup(self): cfg = self.cfg - assert cfg.D % cfg.H == 0, f'D {cfg.D} not divisible by H {cfg.H}' - self.Dh = cfg.D // cfg.H + assert cfg.model_dim % cfg.num_heads == 0, f'D {cfg.model_dim} not divisible by H {cfg.num_heads}' + self.Dh = cfg.model_dim // cfg.num_heads # Initialize rotary embeddings - self.freqs_cis = init_rope(cfg.D, cfg.L, cfg.H) + self.freqs_cis = init_rope(cfg.model_dim, cfg.seq_len, cfg.num_heads) # Maps D -> (H, Dh) self.multilinear = partial( nn.DenseGeneral, axis=-1, - features=(cfg.H, self.Dh), + features=(cfg.num_heads, self.Dh), kernel_init=cfg.attention_init, use_bias=False, dtype=cfg.dtype, @@ -133,7 +134,7 @@ def setup(self): self.multilinear_key = self.multilinear(name='key') self.multilinear_value = self.multilinear(name='value') self.output_projection = nn.DenseGeneral( - features=cfg.D, + features=cfg.model_dim, name='attn_out_proj', # axis=(-2, -1), # kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init, @@ -183,7 +184,7 @@ def __call__(self, x_BxLxD: jax.Array): class TBlock(nn.Module): """Transformer Block.""" - docfg: DoConfig + docfg: ModelConfig @nn.compact def __call__(self, in_BxLxD: jax.Array): @@ -208,17 +209,17 @@ def __call__(self, in_BxLxD: jax.Array): class TransformerDo(nn.Module): """Transformer decoder-only.""" - docfg: DoConfig + docfg: ModelConfig def setup(self): cfg = self.docfg self.embed = nn.Embed( - num_embeddings=cfg.V, - features=cfg.D, + num_embeddings=cfg.vocab_size, + features=cfg.model_dim, embedding_init=cfg.embed_init, ) - self.blocks = [TBlock(cfg) for _ in range(cfg.N)] + self.blocks = [TBlock(cfg) for _ in range(cfg.num_layers)] self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) # Output projection - tied to input embeddings if configured @@ -226,7 +227,7 @@ def setup(self): self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) else: self.output_proj = nn.Dense( - cfg.V, kernel_init=cfg.embed_init, dtype=cfg.dtype, name='output_proj' + cfg.vocab_size, kernel_init=cfg.embed_init, dtype=cfg.dtype, name='output_proj' ) def __call__(self, y_BxL: jax.Array): @@ -255,9 +256,9 @@ def predict(self, y_BxL: jax.Array, k: int = 1): original_input = y_BxL # Make sure we don't exceed the model's context length - if seq_len + k > cfg.L: + if seq_len + k > cfg.seq_len: raise 
ValueError( - f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.L})" + f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.seq_len})" ) # Generate k tokens autoregressively @@ -288,17 +289,17 @@ def main(): """Create and run the DecoderOnly Transformer model.""" # Initialize model configuration with smaller parameters for demo B, L = (2, 128) # Batch size, sequence length - cfg = DoConfig(D=128, H=4, L=L, N=2, V=256, F=4 * 128) + cfg = ModelConfig(model_dim=128, num_heads=4, seq_len=L, num_layers=2, vocab_size=256, expanded_model_dim=4 * 128) model = TransformerDo(cfg) # Print model info print('\nModel Configuration:') - print(f' - Model dimension (D): {cfg.D}') - print(f' - Number of heads (H): {cfg.H}') - print(f' - Max sequence length (L): {cfg.L}') - print(f' - Number of layers (N): {cfg.N}') - print(f' - Vocabulary size (V): {cfg.V}') - print(f' - Feed forward dimension (F): {cfg.F}') + print(f' - Model dimension (D): {cfg.model_dim}') + print(f' - Number of heads (H): {cfg.num_heads}') + print(f' - Max sequence length (L): {cfg.seq_len}') + print(f' - Number of layers (N): {cfg.num_layers}') + print(f' - Vocabulary size (V): {cfg.vocab_size}') + print(f' - Feed forward dimension (F): {cfg.expanded_model_dim}') # Create random input tokens (simulated token IDs) rng_key = jax.random.PRNGKey(42) @@ -306,7 +307,7 @@ def main(): # Generate random token IDs (integers between 0 and vocab_size-1) x_BxL = jax.random.randint( - input_rng, shape=(B, L), minval=0, maxval=cfg.V, dtype=jnp.int32 + input_rng, shape=(B, L), minval=0, maxval=cfg.vocab_size, dtype=jnp.int32 ) # Initialize model parameters diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index 13738086a..effb12089 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -8,7 +8,7 @@ from algoperf import jax_sharding_utils, param_utils, spec from algoperf.workloads.lm.input_pipeline import get_data_iter from algoperf.workloads.lm.lm_jax.nanodo_model import ( - DoConfig, + ModelConfig, TransformerDo, ) from algoperf.workloads.lm.workload import BaseLmWorkload @@ -46,13 +46,13 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None, ) -> spec.ModelInitState: # Initialize NanoDO transformer model - cfg = DoConfig( - D=self._emb_dim, # embedding dim - H=self._n_heads, # num heads - L=self._seq_len, - N=self._n_layers, # num layers - V=self._vocab_size, - F=self._mlp_dim, # feedforward dim + cfg = ModelConfig( + model_dim=self._emb_dim, # embedding dim + num_heads=self._n_heads, # num heads + seq_len=self._seq_len, + num_layers=self._n_layers, # num layers + vocab_size=self._vocab_size, + expanded_model_dim=self._mlp_dim, # feedforward dim dtype=jnp.float32, ) self._model = TransformerDo(cfg) diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py index af4232b7e..8186638e7 100644 --- a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -15,15 +15,16 @@ @dataclass class ModelConfig: - vocab_size: int + model_dim: int + num_heads: int seq_len: int - dim: int - expand: float - n_layers: int - n_heads: int - rmsnorm_eps: float = 1e-6 - tie_embeddings: bool = True + num_layers: int + vocab_size: int + expanded_model_dim: int + multiple_of: int = 256 + rmsnorm_epsilon: float = 1e-6 use_residual_scaling: bool = True + tie_embeddings: bool = True class MLP(nn.Module): @@ -81,13 
+82,13 @@ def apply_rotary_emb_complex_like( class Attention(nn.Module): def __init__(self, cfg: ModelConfig): super().__init__() - assert cfg.dim % cfg.n_heads == 0 - self.dim = cfg.dim - self.n_heads = cfg.n_heads - self.head_dim = cfg.dim // cfg.n_heads + assert cfg.model_dim % cfg.num_heads == 0 + self.dim = cfg.model_dim + self.n_heads = cfg.num_heads + self.head_dim = cfg.model_dim // cfg.num_heads - self.w_qkv = nn.Linear(cfg.dim, 3 * cfg.dim, bias=False) - self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False) + self.w_qkv = nn.Linear(cfg.model_dim, 3 * cfg.model_dim, bias=False) + self.w_out = nn.Linear(cfg.model_dim, cfg.model_dim, bias=False) # Split into Q, K, V sections wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0) for w in [wq, wk, wv]: @@ -131,9 +132,9 @@ class Block(nn.Module): def __init__(self, layer_id: int, cfg: ModelConfig): super().__init__() self.attn = Attention(cfg) - self.attn_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) - self.mlp = MLP(dim=cfg.dim, hidden_dim=int(cfg.expand * cfg.dim)) - self.mlp_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) + self.attn_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.mlp = MLP(dim=cfg.model_dim, hidden_dim=cfg.expanded_model_dim, multiple_of=cfg.multiple_of) + self.mlp_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) self.layer_id = layer_id def forward(self, x, freqs_cis): @@ -144,19 +145,19 @@ def forward(self, x, freqs_cis): class Transformer(nn.Module): - def __init__(self, cfg): + def __init__(self, cfg: ModelConfig): super().__init__() - self.n_layers = cfg.n_layers + self.n_layers = cfg.num_layers self.cfg = cfg - head_dim = cfg.dim // cfg.n_heads - assert cfg.dim % cfg.n_heads == 0 + head_dim = cfg.model_dim // cfg.num_heads + assert cfg.model_dim % cfg.num_heads == 0 - self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.dim) + self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.model_dim) self.layers = nn.ModuleList( - [Block(idx, cfg) for idx in range(cfg.n_layers)] + [Block(idx, cfg) for idx in range(cfg.num_layers)] ) - self.out_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) - self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) + self.out_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.lm_head = nn.Linear(cfg.model_dim, cfg.vocab_size, bias=False) # Initialize freqs_cis on CPU first (more memory efficient) self.register_buffer( @@ -184,7 +185,7 @@ def forward(self, x, targets=None): # Make sure we have enough precomputed frequencies if L > self.freqs_cis.shape[1]: # Need to recompute for longer sequence - head_dim = self.cfg.dim // self.cfg.n_heads + head_dim = self.cfg.model_dim // self.cfg.num_heads new_freqs = precompute_freqs_cis( head_dim, max(L, self.cfg.seq_len), 500000 ) @@ -290,11 +291,11 @@ def main(): config = ModelConfig( vocab_size=50257, # Common vocab size for tokenizers like BPE or SentencePiece seq_len=seq_length, # Maximum sequence length - dim=1024, # Embedding dimension - expand=4.0, # MLP expansion factor - n_layers=12, # Number of transformer layers - n_heads=8, # Number of attention heads - rmsnorm_eps=1e-6, # RMSNorm epsilon + model_dim=1024, # Embedding dimension + expanded_model_dim=4.0, # MLP expansion factor + num_layers=12, # Number of transformer layers + num_heads=8, # Number of attention heads + rmsnorm_epsilon=1e-6, # RMSNorm epsilon tie_embeddings=True, # Tie embedding and output weights ) diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 
2f5c33ebf..3d185636b 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -39,12 +39,11 @@ def init_model_fn( cfg = ModelConfig( vocab_size=self._vocab_size, seq_len=self._seq_len, - dim=self._emb_dim, # Model dimension - expand=self._mlp_dim // self._emb_dim, # MLP expansion factor - # FIXME(rka97): fix expansion factor - n_layers=self._n_layers, # Number of transformer layers - n_heads=self._n_heads, # Number of attention heads - rmsnorm_eps=1e-6, + model_dim=self._emb_dim, # Model dimension + expanded_model_dim=self._mlp_dim, # MLP expansion factor + num_layers=self._n_layers, # Number of transformer layers + num_heads=self._n_heads, # Number of attention heads + rmsnorm_epsilon=1e-6, tie_embeddings=True, ) self._model = Transformer(cfg) diff --git a/tests/modeldiffs/lm/compare.py b/tests/modeldiffs/lm/compare.py index 5b95f934c..f681597d8 100644 --- a/tests/modeldiffs/lm/compare.py +++ b/tests/modeldiffs/lm/compare.py @@ -28,24 +28,28 @@ # Import JAX implementation from algoperf.workloads.lm.lm_jax.nanodo_model import ( CausalAttn, - DoConfig, Mlp, TBlock, TransformerDo, apply_rope, init_rope, ) +from algoperf.workloads.lm.lm_jax.nanodo_model import ( + ModelConfig as JaxModelConfig, +) # Import PyTorch implementation from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( MLP, Attention, Block, - ModelConfig, Transformer, apply_rotary_emb_complex_like, precompute_freqs_cis, ) +from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( + ModelConfig as PyTorchModelConfig, +) FLAGS = flags.FLAGS # Needed to avoid UnparsedFlagAccessError @@ -192,13 +196,13 @@ def test_mlp(): pytorch_mlp = MLP(dim=dim, hidden_dim=hidden_dim) # Initialize JAX MLP - cfg = DoConfig( - D=dim, - H=4, - L=128, - N=2, - V=1000, - F=hidden_dim, + cfg = JaxModelConfig( + model_dim=dim, + num_heads=4, + seq_len=128, + num_layers=2, + vocab_size=1000, + expanded_model_dim=hidden_dim, dtype=jnp.float32, rmsnorm_epsilon=1e-6, ) @@ -266,26 +270,26 @@ def test_attention(): batch_size, seq_len, dim, n_heads = 2, 16, 256, 4 # Initialize PyTorch Attention - config = ModelConfig( + config = PyTorchModelConfig( vocab_size=1000, seq_len=seq_len, - dim=dim, - expand=4.0, - n_layers=1, - n_heads=n_heads, - rmsnorm_eps=1e-6, + model_dim=dim, + expanded_model_dim=1024, + num_layers=1, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, ) pytorch_attn = Attention(config) freqs_cis = precompute_freqs_cis(dim // n_heads, seq_len, theta=500000) # Initialize JAX Attention - cfg = DoConfig( - D=dim, - H=n_heads, - L=seq_len, - N=1, - V=1000, - F=1024, + cfg = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len, + num_layers=1, + vocab_size=1000, + expanded_model_dim=1024, dtype=jnp.float32, rmsnorm_epsilon=1e-6, ) @@ -354,26 +358,26 @@ def test_block(): expand = 4.0 # Initialize PyTorch Block - config = ModelConfig( + config = PyTorchModelConfig( vocab_size=1000, seq_len=seq_len, - dim=dim, - expand=expand, - n_layers=1, - n_heads=n_heads, - rmsnorm_eps=1e-6, + model_dim=dim, + expanded_model_dim=int(dim * expand), + num_layers=1, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, ) pytorch_block = Block(layer_id=0, cfg=config) freqs_cis = precompute_freqs_cis(dim // n_heads, seq_len, theta=500000) # Initialize JAX Block - cfg = DoConfig( - D=dim, - H=n_heads, - L=seq_len, - N=1, - V=1000, - F=int(dim * expand), + cfg = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len, + num_layers=1, + vocab_size=1000, + expanded_model_dim=int(dim * 
expand), dtype=jnp.float32, rmsnorm_epsilon=1e-6, ) @@ -408,9 +412,9 @@ def copy_full_model_params(pytorch_model, flax_params, config): if hasattr(pytorch_model, '_orig_mod'): pytorch_model = pytorch_model._orig_mod - n_layers = config.n_layers - n_heads = config.n_heads - dim = config.dim + n_layers = config.num_layers + n_heads = config.num_heads + dim = config.model_dim head_dim = dim // n_heads new_params = {'params': {}} @@ -489,27 +493,27 @@ def test_full_model(): expand = 4.0 # Initialize PyTorch model - pytorch_config = ModelConfig( + pytorch_config = PyTorchModelConfig( vocab_size=vocab_size, seq_len=seq_len, - dim=dim, - expand=expand, - n_layers=n_layers, - n_heads=n_heads, - rmsnorm_eps=1e-6, + model_dim=dim, + expanded_model_dim=int(dim * expand), + num_layers=n_layers, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, tie_embeddings=True, ) pytorch_model = Transformer(pytorch_config) pytorch_model.eval() # Initialize JAX model - jax_config = DoConfig( - D=dim, - H=n_heads, - L=seq_len, - N=n_layers, - V=vocab_size, - F=int(dim * expand), + jax_config = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len, + num_layers=n_layers, + vocab_size=vocab_size, + expanded_model_dim=int(dim * expand), dtype=jnp.float32, rmsnorm_epsilon=1e-6, tie_embeddings=True, @@ -557,27 +561,27 @@ def test_prediction(): k = 5 # Number of tokens to predict # Initialize PyTorch model - pytorch_config = ModelConfig( + pytorch_config = PyTorchModelConfig( vocab_size=vocab_size, seq_len=seq_len + k, - dim=dim, - expand=expand, - n_layers=n_layers, - n_heads=n_heads, - rmsnorm_eps=1e-6, + model_dim=dim, + expanded_model_dim=int(dim * expand), + num_layers=n_layers, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, tie_embeddings=True, ) pytorch_model = Transformer(pytorch_config) pytorch_model.eval() # Initialize JAX model - jax_config = DoConfig( - D=dim, - H=n_heads, - L=seq_len + k, - N=n_layers, - V=vocab_size, - F=int(dim * expand), + jax_config = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len + k, + num_layers=n_layers, + vocab_size=vocab_size, + expanded_model_dim=int(dim * expand), dtype=jnp.float32, rmsnorm_epsilon=1e-6, tie_embeddings=True, @@ -633,14 +637,26 @@ def test_initialization_statistics(): logging.info('=' * 70) # Initialize models - jax_cfg = DoConfig(D=512, H=8, L=1024, N=12, V=50000, F=2048) + jax_cfg = JaxModelConfig( + model_dim=512, + num_heads=8, + seq_len=1024, + num_layers=12, + vocab_size=50000, + expanded_model_dim=2048, + dtype=jnp.float32) jax_model = TransformerDo(jax_cfg) jax_params = jax_model.init( jax.random.PRNGKey(42), jnp.ones((1, 10), dtype=jnp.int32) ) - pytorch_cfg = ModelConfig( - vocab_size=50000, seq_len=1024, dim=512, expand=4.0, n_layers=12, n_heads=8 + pytorch_cfg = PyTorchModelConfig( + vocab_size=50000, + seq_len=1024, + model_dim=512, + expanded_model_dim=2048, + num_layers=12, + num_heads=8, ) pytorch_model = Transformer(pytorch_cfg) @@ -771,20 +787,27 @@ def test_initialization_impact(): tokens = np.random.randint(0, vocab_size, size=(batch_size, seq_len)) # Initialize both models with same seed - jax_cfg = DoConfig(D=512, H=8, L=seq_len, N=12, V=vocab_size, F=2048) + jax_cfg = JaxModelConfig( + model_dim=512, + num_heads=8, + seq_len=seq_len, + num_layers=12, + vocab_size=vocab_size, + expanded_model_dim=2048, + ) jax_model = TransformerDo(jax_cfg) jax_params = jax_model.init( jax.random.PRNGKey(42), jnp.array(tokens, dtype=jnp.int32) ) torch.manual_seed(42) - pytorch_cfg = ModelConfig( + pytorch_cfg = 
PyTorchModelConfig( vocab_size=vocab_size, seq_len=seq_len, - dim=512, - expand=4.0, - n_layers=12, - n_heads=8, + model_dim=512, + expanded_model_dim=2048, + num_layers=12, + num_heads=8, ) pytorch_model = Transformer(pytorch_cfg) From d35cddebdb2f62f49665313a79188510684c12df Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 23 Oct 2025 17:00:58 +0000 Subject: [PATCH 73/98] Add query-key normalization to CausalAttn and Attention classes, including learned scaling factor --- algoperf/workloads/lm/lm_jax/nanodo_model.py | 61 +++++++++++++++---- .../workloads/lm/lm_pytorch/plainlm_model.py | 26 ++++++-- tests/modeldiffs/lm/compare.py | 3 +- 3 files changed, 72 insertions(+), 18 deletions(-) diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/nanodo_model.py index 2b47c1735..d08e9b7bf 100644 --- a/algoperf/workloads/lm/lm_jax/nanodo_model.py +++ b/algoperf/workloads/lm/lm_jax/nanodo_model.py @@ -25,14 +25,19 @@ class ModelConfig: rmsnorm_epsilon: float = 1e-6 use_residual_scaling: bool = True tie_embeddings: bool = True # Whether to tie input and output embed + qknorm_epsilon: float = 1e-6 dtype: jnp.dtype = jnp.float32 - attention_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) + attention_init: nn.initializers.Initializer = nn.initializers.normal( + stddev=0.02 + ) linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) def __post_init__(self): - self.residual_init = nn.initializers.normal(stddev=0.02/jnp.sqrt(2 * self.num_layers)) + self.residual_init = nn.initializers.normal( + stddev=0.02 / jnp.sqrt(2 * self.num_layers) + ) class Mlp(nn.Module): @@ -43,7 +48,6 @@ class Mlp(nn.Module): @nn.compact def __call__(self, x_BxLxD: jax.Array): cfg = self.cfg - # Use Xavier uniform initialization explicitly linear = partial( nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype ) @@ -58,7 +62,14 @@ def __call__(self, x_BxLxD: jax.Array): x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) # Apply GLU activation x_BxLxF = nn.glu(x_BxLx2F, axis=-1) - x_BxLxD = nn.Dense(cfg.model_dim, use_bias=False, dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init)(x_BxLxF) + x_BxLxD = nn.Dense( + cfg.model_dim, + use_bias=False, + dtype=cfg.dtype, + kernel_init=cfg.residual_init + if cfg.use_residual_scaling + else cfg.linear_init, + )(x_BxLxF) return x_BxLxD @@ -114,8 +125,11 @@ class CausalAttn(nn.Module): def setup(self): cfg = self.cfg - assert cfg.model_dim % cfg.num_heads == 0, f'D {cfg.model_dim} not divisible by H {cfg.num_heads}' + assert cfg.model_dim % cfg.num_heads == 0, ( + f'D {cfg.model_dim} not divisible by H {cfg.num_heads}' + ) self.Dh = cfg.model_dim // cfg.num_heads + self.eps = cfg.qknorm_epsilon # Initialize rotary embeddings self.freqs_cis = init_rope(cfg.model_dim, cfg.seq_len, cfg.num_heads) @@ -129,15 +143,22 @@ def setup(self): use_bias=False, dtype=cfg.dtype, ) - self.multilinear_query = self.multilinear(name='query') self.multilinear_key = self.multilinear(name='key') self.multilinear_value = self.multilinear(name='value') + # See Henry et al. 
(2020) "Query Key Normalization for Transformers" + seq_len = cfg.seq_len + attn_scale0 = jnp.log2(seq_len**2 - seq_len) + self.attn_scale = self.param( + 'attn_scale', nn.initializers.constant(attn_scale0), () + ) self.output_projection = nn.DenseGeneral( features=cfg.model_dim, name='attn_out_proj', # axis=(-2, -1), # - kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init, + kernel_init=cfg.residual_init + if cfg.use_residual_scaling + else cfg.linear_init, use_bias=False, dtype=cfg.dtype, ) @@ -153,8 +174,9 @@ def __call__(self, x_BxLxD: jax.Array): # Apply rotary embeddings to Q and K q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) - # Scale queries - q_BxLxHxDh /= self.Dh**0.5 + # Apply QK normalization + q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps + k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps # Compute attention scores att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh) @@ -166,6 +188,9 @@ def __call__(self, x_BxLxD: jax.Array): # Apply mask and softmax _NEG_INF = jnp.finfo(cfg.dtype).min att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) + att_BxHxLxL = ( + self.attn_scale * att_BxHxLxL + ) # Learned scaling factor for QK norm att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) @@ -227,7 +252,10 @@ def setup(self): self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) else: self.output_proj = nn.Dense( - cfg.vocab_size, kernel_init=cfg.embed_init, dtype=cfg.dtype, name='output_proj' + cfg.vocab_size, + kernel_init=cfg.embed_init, + dtype=cfg.dtype, + name='output_proj', ) def __call__(self, y_BxL: jax.Array): @@ -270,7 +298,9 @@ def predict(self, y_BxL: jax.Array, k: int = 1): next_token_logits = logits[:, -1, :] last_token_id = y_BxL[:, -1] # Prevent predicting the same token consecutively - next_token_logits = next_token_logits.at[jnp.arange(len(last_token_id)), last_token_id].set(float('-inf')) + next_token_logits = next_token_logits.at[ + jnp.arange(len(last_token_id)), last_token_id + ].set(float('-inf')) # Get the most likely token next_token = jnp.argmax(next_token_logits, axis=-1) @@ -289,7 +319,14 @@ def main(): """Create and run the DecoderOnly Transformer model.""" # Initialize model configuration with smaller parameters for demo B, L = (2, 128) # Batch size, sequence length - cfg = ModelConfig(model_dim=128, num_heads=4, seq_len=L, num_layers=2, vocab_size=256, expanded_model_dim=4 * 128) + cfg = ModelConfig( + model_dim=128, + num_heads=4, + seq_len=L, + num_layers=2, + vocab_size=256, + expanded_model_dim=4 * 128, + ) model = TransformerDo(cfg) # Print model info diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py index 8186638e7..edee8318c 100644 --- a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py +++ b/algoperf/workloads/lm/lm_pytorch/plainlm_model.py @@ -23,6 +23,7 @@ class ModelConfig: expanded_model_dim: int multiple_of: int = 256 rmsnorm_epsilon: float = 1e-6 + qknorm_epsilon: float = 1e-6 use_residual_scaling: bool = True tie_embeddings: bool = True @@ -92,9 +93,14 @@ def __init__(self, cfg: ModelConfig): # Split into Q, K, V sections wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0) for w in [wq, wk, wv]: - nn.init.normal_(w, std=0.02) + nn.init.normal_(w, std=0.02) nn.init.normal_(self.w_out.weight, std=0.02) + self.eps = cfg.qknorm_epsilon # e.g., 1e-6 + seq_len = cfg.seq_len + 
attn_scale0 = math.log2(seq_len**2 - seq_len) + self.attn_scale = nn.Parameter(torch.tensor(attn_scale0)) + def forward(self, x, freqs_cis): bsz, seqlen, d = x.shape # (bsz, seqlen, d) @@ -117,10 +123,14 @@ def forward(self, x, freqs_cis): k = k.transpose(1, 2) # (bsz, nh, seqlen, h_dim) v = v.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + # Apply QK normalization + q = q / torch.norm(q, dim=-1, keepdim=True) + self.eps + k = k / torch.norm(k, dim=-1, keepdim=True) + self.eps + q *= self.attn_scale + out = F.scaled_dot_product_attention( - q, k, v, is_causal=True + q, k, v, is_causal=True, scale=1.0 ) # (bsz, nh, seqlen, h_dim) - out = ( out.transpose(1, 2).contiguous().view(bsz, seqlen, d) ) # (bsz, seqlen, d) @@ -133,7 +143,11 @@ def __init__(self, layer_id: int, cfg: ModelConfig): super().__init__() self.attn = Attention(cfg) self.attn_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) - self.mlp = MLP(dim=cfg.model_dim, hidden_dim=cfg.expanded_model_dim, multiple_of=cfg.multiple_of) + self.mlp = MLP( + dim=cfg.model_dim, + hidden_dim=cfg.expanded_model_dim, + multiple_of=cfg.multiple_of, + ) self.mlp_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) self.layer_id = layer_id @@ -263,7 +277,9 @@ def _init_weights(self, module): def _scale_residual_branches(self): for n, p in self.named_parameters(): - if n.endswith('fc2.weight') or n.endswith('w_out.weight'): # mlp/glu output layer + if n.endswith('fc2.weight') or n.endswith( + 'w_out.weight' + ): # mlp/glu output layer torch.nn.init.normal_( p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers) ) diff --git a/tests/modeldiffs/lm/compare.py b/tests/modeldiffs/lm/compare.py index f681597d8..e1d85eba7 100644 --- a/tests/modeldiffs/lm/compare.py +++ b/tests/modeldiffs/lm/compare.py @@ -644,7 +644,8 @@ def test_initialization_statistics(): num_layers=12, vocab_size=50000, expanded_model_dim=2048, - dtype=jnp.float32) + dtype=jnp.float32, + ) jax_model = TransformerDo(jax_cfg) jax_params = jax_model.init( jax.random.PRNGKey(42), jnp.ones((1, 10), dtype=jnp.int32) From ffb816329d1a9f5272956a8ad04ba2e307401ee2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 24 Oct 2025 19:46:11 +0000 Subject: [PATCH 74/98] update target --- algoperf/workloads/lm/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index e0af589e3..79f65040c 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -42,7 +42,7 @@ def has_reached_validation_target(self, eval_result: float) -> bool: @property def validation_target_value(self) -> float: - return 25.5477 # Target perplexity + return 22.432 # Target perplexity def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: return True # No test targets From 202e5cb79e237178d47fdb391fc29c4c4fbc3b8a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sun, 26 Oct 2025 23:58:00 +0000 Subject: [PATCH 75/98] add pytorch nadamw_target_setting --- .../pytorch_nadamw_target_setting.py | 403 ++++++++++++++++++ submission_runner.py | 2 +- 2 files changed, 404 insertions(+), 1 deletion(-) create mode 100644 algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..196c1f809 --- /dev/null +++ 
b/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py @@ -0,0 +1,403 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import torch +import torch.distributed.nn as dist_nn +from absl import logging +from torch import Tensor +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR + +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step'] + ) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
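+
+    Returns:
+      The loss returned by the closure, or None if no closure was given.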
+ """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.0) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) + + return loss + + +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. 
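+ # After this in-place update, exp_avg temporarily holds the Nesterov
+ # lookahead estimate beta1 * m_t + (1 - beta1) * g_t, which is used in
+ # place of m_t for the parameter step below; the sub_/div_ pair after
+ # the step restores exp_avg to m_t.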
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + optimizer_state = { + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + step_hint = step_hint * 0.75 + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) + + return optimizer_state + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + dropout_rate=hyperparameters.dropout_rate, + ) + + label_smoothing = ( + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. 
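+ # dist_nn.all_reduce is the autograd-aware collective: summing both the
+ # loss and the valid-example count across ranks (and dividing below)
+ # reproduces the global mean loss and keeps per-rank gradient scaling
+ # consistent with single-process training.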
+ summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip + ) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) + + return (optimizer_state, current_param_container, new_model_state) + + +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + elif workload_name == 'lm': + return 64 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
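+ Batches are consumed in order; this baseline simply returns the next
+ batch from the queue unchanged.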
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/submission_runner.py b/submission_runner.py index 857d4479f..1bb763cf2 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -258,7 +258,6 @@ def train_once( 'criteo1tb', 'imagenet_vit', 'librispeech_deepspeech', - 'lm', ] eager_backend_workloads = [] aot_eager_backend_workloads = [] @@ -267,6 +266,7 @@ def train_once( 'librispeech_deepspeech', 'ogbg', 'wmt', + 'lm' ] base_workload = workloads.get_base_workload_name(workload_name) if base_workload in compile_error_workloads: From 98e491ad712ba2be1d17452f67be877d2abaf679 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 27 Oct 2025 00:42:12 +0000 Subject: [PATCH 76/98] docker updates for a100 --- docker/scripts/startup.sh | 6 +++--- scoring/utils/run_workloads.py | 1 + scoring/utils/workload_metadata_external_tuning.json | 4 ++++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 35ac30461..d92107e90 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -174,7 +174,7 @@ fi # Check if arguments are valid VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \ - "wmt" "mnist") + "wmt" "mnist" "fineweb_edu_10B") VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_resnet_gelu" \ "imagenet_resnet_large_bn_init" "imagenet_vit" "imagenet_vit_glu" \ "imagenet_vit_post_ln" "imagenet_vit_map" "fastmri" "ogbg" \ @@ -185,7 +185,7 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_ "librispeech_conformer_gelu" "fastmri_model_size" "fastmri_tanh" \ "librispeech_deepspeech_tanh" \ "librispeech_deepspeech_no_resnet" "librispeech_deepspeech_norm_and_spec_aug" - "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size") + "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size" "lm") VALID_RULESETS=("self" "external") # Set data and experiment paths @@ -221,7 +221,7 @@ TUNING_RULESET_FLAG="--tuning_ruleset=${TUNING_RULESET}" if [[ "${FRAMEWORK}" == "jax" ]]; then COMMAND_PREFIX="python" else - COMMAND_PREFIX="torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8" + COMMAND_PREFIX="torchrun --redirects 1:0,2:0,3:0 --standalone --nnodes=1 --nproc_per_node=4" fi # Set data directory and bucket (bucket is only relevant in internal mode) diff --git a/scoring/utils/run_workloads.py b/scoring/utils/run_workloads.py index 273881c5a..c76ef6e32 100644 --- a/scoring/utils/run_workloads.py +++ b/scoring/utils/run_workloads.py @@ -270,6 +270,7 @@ def main(_): 'docker run -t -d -v /home/kasimbeg/data/:/data/ ' '-v /home/kasimbeg/experiment_runs/:/experiment_runs ' '-v /home/kasimbeg/experiment_runs/logs:/logs ' + '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency' f'{mount_repo_flag}' '--gpus all --ipc=host ' f'{docker_image_url} ' diff --git a/scoring/utils/workload_metadata_external_tuning.json b/scoring/utils/workload_metadata_external_tuning.json index c7d4ae195..5138e9acf 100644 --- a/scoring/utils/workload_metadata_external_tuning.json +++ b/scoring/utils/workload_metadata_external_tuning.json @@ -30,5 +30,9 @@ "librispeech_conformer": { "max_steps": 80000, "dataset": "librispeech" + }, + "lm" : { + "max_steps": 55000, + "dataset":"fineweb_edu_10B" } } From 02f835d0961227e56750360e73c7a8ed4213801b Mon Sep 17 00:00:00 2001 From: 
rka97 Date: Thu, 6 Nov 2025 16:59:54 +0000 Subject: [PATCH 77/98] rename models.py --- .../workloads/lm/lm_jax/{nanodo_model.py => models.py} | 0 algoperf/workloads/lm/lm_jax/workload.py | 2 +- .../lm/lm_pytorch/{plainlm_model.py => models.py} | 0 algoperf/workloads/lm/lm_pytorch/workload.py | 2 +- tests/modeldiffs/lm/compare.py | 8 ++++---- 5 files changed, 6 insertions(+), 6 deletions(-) rename algoperf/workloads/lm/lm_jax/{nanodo_model.py => models.py} (100%) rename algoperf/workloads/lm/lm_pytorch/{plainlm_model.py => models.py} (100%) diff --git a/algoperf/workloads/lm/lm_jax/nanodo_model.py b/algoperf/workloads/lm/lm_jax/models.py similarity index 100% rename from algoperf/workloads/lm/lm_jax/nanodo_model.py rename to algoperf/workloads/lm/lm_jax/models.py diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/lm/lm_jax/workload.py index effb12089..3862b73dc 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/lm/lm_jax/workload.py @@ -7,7 +7,7 @@ from algoperf import jax_sharding_utils, param_utils, spec from algoperf.workloads.lm.input_pipeline import get_data_iter -from algoperf.workloads.lm.lm_jax.nanodo_model import ( +from algoperf.workloads.lm.lm_jax.models import ( ModelConfig, TransformerDo, ) diff --git a/algoperf/workloads/lm/lm_pytorch/plainlm_model.py b/algoperf/workloads/lm/lm_pytorch/models.py similarity index 100% rename from algoperf/workloads/lm/lm_pytorch/plainlm_model.py rename to algoperf/workloads/lm/lm_pytorch/models.py diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/lm/lm_pytorch/workload.py index 3d185636b..a052a8452 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/lm/lm_pytorch/workload.py @@ -11,7 +11,7 @@ from algoperf import param_utils, pytorch_utils, spec from algoperf.workloads.lm.input_pipeline import get_data_iter -from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( +from algoperf.workloads.lm.lm_pytorch.models import ( ModelConfig, Transformer, ) diff --git a/tests/modeldiffs/lm/compare.py b/tests/modeldiffs/lm/compare.py index e1d85eba7..e1ca8e06c 100644 --- a/tests/modeldiffs/lm/compare.py +++ b/tests/modeldiffs/lm/compare.py @@ -26,7 +26,7 @@ from absl.testing import absltest, parameterized # Import JAX implementation -from algoperf.workloads.lm.lm_jax.nanodo_model import ( +from algoperf.workloads.lm.lm_jax.models import ( CausalAttn, Mlp, TBlock, @@ -34,12 +34,12 @@ apply_rope, init_rope, ) -from algoperf.workloads.lm.lm_jax.nanodo_model import ( +from algoperf.workloads.lm.lm_jax.models import ( ModelConfig as JaxModelConfig, ) # Import PyTorch implementation -from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( +from algoperf.workloads.lm.lm_pytorch.models import ( MLP, Attention, Block, @@ -47,7 +47,7 @@ apply_rotary_emb_complex_like, precompute_freqs_cis, ) -from algoperf.workloads.lm.lm_pytorch.plainlm_model import ( +from algoperf.workloads.lm.lm_pytorch.models import ( ModelConfig as PyTorchModelConfig, ) From 0abf39d3e8888103dbecf952ab75d08be85657a9 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 6 Nov 2025 17:05:57 +0000 Subject: [PATCH 78/98] rename workload --- algoperf/workloads/{lm => finewebedu_lm}/__init__.py | 0 .../finewebedu_lm_jax}/__init__.py | 0 .../lm_jax => finewebedu_lm/finewebedu_lm_jax}/models.py | 0 .../finewebedu_lm_jax}/workload.py | 6 +++--- .../finewebedu_lm_pytorch}/__init__.py | 0 .../finewebedu_lm_pytorch}/models.py | 0 .../finewebedu_lm_pytorch}/workload.py | 6 +++--- 
.../workloads/{lm => finewebedu_lm}/input_pipeline.py | 0 algoperf/workloads/{lm => finewebedu_lm}/workload.py | 0 algoperf/workloads/ogbg/workload.py | 2 +- algoperf/workloads/workloads.py | 7 +++++-- .../archived_paper_baselines/adamw/pytorch/submission.py | 4 ++-- .../archived_paper_baselines/nesterov/jax/submission.py | 4 ++-- .../baselines/external_tuning/jax_nadamw_full_budget.py | 2 +- .../external_tuning/pytorch_nadamw_full_budget.py | 2 +- .../fineweb_edu_lm/jax_nadamw_target_setting.py | 2 +- .../fineweb_edu_lm/pytorch_nadamw_target_setting.py | 2 +- scoring/performance_profile.py | 2 +- scoring/utils/workload_metadata_external_tuning.json | 2 +- submission_runner.py | 6 +++--- tests/modeldiffs/lm/compare.py | 8 ++++---- 21 files changed, 29 insertions(+), 26 deletions(-) rename algoperf/workloads/{lm => finewebedu_lm}/__init__.py (100%) rename algoperf/workloads/{lm/lm_jax => finewebedu_lm/finewebedu_lm_jax}/__init__.py (100%) rename algoperf/workloads/{lm/lm_jax => finewebedu_lm/finewebedu_lm_jax}/models.py (100%) rename algoperf/workloads/{lm/lm_jax => finewebedu_lm/finewebedu_lm_jax}/workload.py (96%) rename algoperf/workloads/{lm/lm_pytorch => finewebedu_lm/finewebedu_lm_pytorch}/__init__.py (100%) rename algoperf/workloads/{lm/lm_pytorch => finewebedu_lm/finewebedu_lm_pytorch}/models.py (100%) rename algoperf/workloads/{lm/lm_pytorch => finewebedu_lm/finewebedu_lm_pytorch}/workload.py (97%) rename algoperf/workloads/{lm => finewebedu_lm}/input_pipeline.py (100%) rename algoperf/workloads/{lm => finewebedu_lm}/workload.py (100%) diff --git a/algoperf/workloads/lm/__init__.py b/algoperf/workloads/finewebedu_lm/__init__.py similarity index 100% rename from algoperf/workloads/lm/__init__.py rename to algoperf/workloads/finewebedu_lm/__init__.py diff --git a/algoperf/workloads/lm/lm_jax/__init__.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/__init__.py similarity index 100% rename from algoperf/workloads/lm/lm_jax/__init__.py rename to algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/__init__.py diff --git a/algoperf/workloads/lm/lm_jax/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py similarity index 100% rename from algoperf/workloads/lm/lm_jax/models.py rename to algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py diff --git a/algoperf/workloads/lm/lm_jax/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py similarity index 96% rename from algoperf/workloads/lm/lm_jax/workload.py rename to algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py index 3862b73dc..ee4cffbbc 100644 --- a/algoperf/workloads/lm/lm_jax/workload.py +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py @@ -6,12 +6,12 @@ import jax.numpy as jnp from algoperf import jax_sharding_utils, param_utils, spec -from algoperf.workloads.lm.input_pipeline import get_data_iter -from algoperf.workloads.lm.lm_jax.models import ( +from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( ModelConfig, TransformerDo, ) -from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter +from algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload class LmWorkload(BaseLmWorkload): diff --git a/algoperf/workloads/lm/lm_pytorch/__init__.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/__init__.py similarity index 100% rename from algoperf/workloads/lm/lm_pytorch/__init__.py rename to 
algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/__init__.py diff --git a/algoperf/workloads/lm/lm_pytorch/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py similarity index 100% rename from algoperf/workloads/lm/lm_pytorch/models.py rename to algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py diff --git a/algoperf/workloads/lm/lm_pytorch/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py similarity index 97% rename from algoperf/workloads/lm/lm_pytorch/workload.py rename to algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py index a052a8452..a25ca334a 100644 --- a/algoperf/workloads/lm/lm_pytorch/workload.py +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py @@ -10,12 +10,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP from algoperf import param_utils, pytorch_utils, spec -from algoperf.workloads.lm.input_pipeline import get_data_iter -from algoperf.workloads.lm.lm_pytorch.models import ( +from algoperf.workloads.finewebedu_lm.finewebedu_lm_pytorch.models import ( ModelConfig, Transformer, ) -from algoperf.workloads.lm.workload import BaseLmWorkload +from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter +from algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algoperf/workloads/lm/input_pipeline.py b/algoperf/workloads/finewebedu_lm/input_pipeline.py similarity index 100% rename from algoperf/workloads/lm/input_pipeline.py rename to algoperf/workloads/finewebedu_lm/input_pipeline.py diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/finewebedu_lm/workload.py similarity index 100% rename from algoperf/workloads/lm/workload.py rename to algoperf/workloads/finewebedu_lm/workload.py diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 002576268..771b103a0 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -92,7 +92,7 @@ def max_allowed_runtime_sec(self) -> int: @property def eval_period_time_sec(self) -> int: - return 452 # approx 25 evals + return 452 # approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py index 391f16f51..e90300a36 100644 --- a/algoperf/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -113,7 +113,10 @@ 'workload_path': 'librispeech_deepspeech/librispeech', 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', }, - 'lm': {'workload_path': 'lm/lm', 'workload_class_name': 'LmWorkload'}, + 'finewebedu_lm': { + 'workload_path': 'finewebedu_lm/finewebedu_lm', + 'workload_class_name': 'LmWorkload', + }, 'mnist': { 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload', @@ -153,7 +156,7 @@ 'imagenet_vit', 'librispeech_conformer', 'librispeech_deepspeech', - 'lm', + 'finewebedu_lm', 'ogbg', 'wmt', ] diff --git a/algorithms/archived_paper_baselines/adamw/pytorch/submission.py b/algorithms/archived_paper_baselines/adamw/pytorch/submission.py index 8fa4e27f6..7c50ff4ff 100644 --- a/algorithms/archived_paper_baselines/adamw/pytorch/submission.py +++ b/algorithms/archived_paper_baselines/adamw/pytorch/submission.py @@ -189,8 +189,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 - elif workload_name == 'lm': - return 4 + elif workload_name == 'finewebedu_lm': + return 64 else: raise 
ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/archived_paper_baselines/nesterov/jax/submission.py b/algorithms/archived_paper_baselines/nesterov/jax/submission.py index cc8eba3c5..061acc3de 100644 --- a/algorithms/archived_paper_baselines/nesterov/jax/submission.py +++ b/algorithms/archived_paper_baselines/nesterov/jax/submission.py @@ -292,8 +292,8 @@ def get_batch_size(workload_name): return 16 elif workload_name == 'cifar': return 128 - elif workload_name == 'lm': - return 8 + elif workload_name == 'finewebedu_lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index ccfa25360..323022598 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -394,7 +394,7 @@ def get_batch_size(workload_name): return 512 elif workload_name == 'wmt': return 128 - elif workload_name == 'lm': + elif workload_name == 'finewebedu_lm': return 64 elif workload_name == 'mnist': return 16 diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 9b544e380..2abf74c73 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -372,7 +372,7 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 - elif workload_name == 'lm': + elif workload_name == 'finewebedu_lm': return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py index 1fef611ac..b7adf6cd6 100644 --- a/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py @@ -395,7 +395,7 @@ def get_batch_size(workload_name): return 512 elif workload_name == 'wmt': return 128 - elif workload_name == 'lm': + elif workload_name == 'finewebedu_lm': return 64 elif workload_name == 'mnist': return 16 diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py index 196c1f809..b881747d8 100644 --- a/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py @@ -373,7 +373,7 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 - elif workload_name == 'lm': + elif workload_name == 'finewebedu_lm': return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index b200c6865..043a65791 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -71,7 +71,7 @@ 'wer', 'l1_loss', 'loss', - 'ppl' + 'ppl', ] MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] diff --git a/scoring/utils/workload_metadata_external_tuning.json b/scoring/utils/workload_metadata_external_tuning.json index f133f2462..0ba0d99ee 100644 --- 
a/scoring/utils/workload_metadata_external_tuning.json +++ b/scoring/utils/workload_metadata_external_tuning.json @@ -31,7 +31,7 @@ "max_steps": 80000, "dataset": "librispeech" }, - "lm" : { + "finewebedu_lm" : { "max_steps": 55000, "dataset":"fineweb_edu_10B" } diff --git a/submission_runner.py b/submission_runner.py index 1bb763cf2..01d9894d8 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -266,7 +266,7 @@ def train_once( 'librispeech_deepspeech', 'ogbg', 'wmt', - 'lm' + 'finewebedu_lm', ] base_workload = workloads.get_base_workload_name(workload_name) if base_workload in compile_error_workloads: @@ -784,7 +784,7 @@ def main(_): os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' if FLAGS.framework == 'pytorch': - limit_tf_threads = base_workload != 'lm' + limit_tf_threads = base_workload != 'finewebedu_lm' pytorch_init( USE_PYTORCH_DDP, RANK, profiler, limit_tf_threads=limit_tf_threads ) @@ -803,7 +803,7 @@ def main(_): 'librispeech_deepspeech', 'imagenet_vit', 'criteo1tb', - 'lm', + 'finewebedu_lm', ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' diff --git a/tests/modeldiffs/lm/compare.py b/tests/modeldiffs/lm/compare.py index e1ca8e06c..709e3125f 100644 --- a/tests/modeldiffs/lm/compare.py +++ b/tests/modeldiffs/lm/compare.py @@ -26,7 +26,7 @@ from absl.testing import absltest, parameterized # Import JAX implementation -from algoperf.workloads.lm.lm_jax.models import ( +from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( CausalAttn, Mlp, TBlock, @@ -34,12 +34,12 @@ apply_rope, init_rope, ) -from algoperf.workloads.lm.lm_jax.models import ( +from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( ModelConfig as JaxModelConfig, ) # Import PyTorch implementation -from algoperf.workloads.lm.lm_pytorch.models import ( +from algoperf.workloads.finewebedu_lm.finewebedu_lm_pytorch.models import ( MLP, Attention, Block, @@ -47,7 +47,7 @@ apply_rotary_emb_complex_like, precompute_freqs_cis, ) -from algoperf.workloads.lm.lm_pytorch.models import ( +from algoperf.workloads.finewebedu_lm.finewebedu_lm_pytorch.models import ( ModelConfig as PyTorchModelConfig, ) From 7f31b0276d0003689bd365f28cd51bd84322a8ad Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 8 Nov 2025 01:03:47 +0000 Subject: [PATCH 79/98] update budget for lm --- algoperf/workloads/lm/workload.py | 4 ++-- docker/build_docker_images.sh | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/algoperf/workloads/lm/workload.py b/algoperf/workloads/lm/workload.py index 14b02e085..5d6e3d742 100644 --- a/algoperf/workloads/lm/workload.py +++ b/algoperf/workloads/lm/workload.py @@ -85,11 +85,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 3600 * 14 # 14 hours TODO(kasimbeg): update + return 31_967 # 8.9 hours @property def eval_period_time_sec(self) -> int: - return 1200 # 20 minutes TODO(kasimbeg): update + return 2_571 # approximately 25 evals @property def step_hint(self) -> int: diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index aa94222ea..22590b9fd 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -45,10 +45,10 @@ do echo "On branch: ${GIT_BRANCH}" echo $DOCKER_BUILD_COMMAND eval $DOCKER_BUILD_COMMAND - echo $DOCKER_TAG_COMMAND - eval $DOCKER_TAG_COMMAND - echo $DOCKER_PUSH_COMMAND - eval $DOCKER_PUSH_COMMAND - echo "To pull container run: " - echo $DOCKER_PULL_COMMAND + # echo $DOCKER_TAG_COMMAND + # eval $DOCKER_TAG_COMMAND + # 
echo $DOCKER_PUSH_COMMAND + # eval $DOCKER_PUSH_COMMAND + # echo "To pull container run: " + # echo $DOCKER_PULL_COMMAND done From 8da0c79d97bda034523c10c55c608ca5ddf11cdf Mon Sep 17 00:00:00 2001 From: rka97 Date: Sun, 30 Nov 2025 01:54:55 +0000 Subject: [PATCH 80/98] add mixed precision training --- .../finewebedu_lm/finewebedu_lm_jax/models.py | 91 ++++++++------- .../finewebedu_lm_jax/workload.py | 49 +++++++- .../finewebedu_lm_pytorch/models.py | 110 ++++++++++++------ .../finewebedu_lm_pytorch/workload.py | 95 +++++++++------ algoperf/workloads/finewebedu_lm/workload.py | 23 +++- pyproject.toml | 3 +- 6 files changed, 249 insertions(+), 122 deletions(-) diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py index d08e9b7bf..3419fe6fb 100644 --- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py @@ -8,6 +8,7 @@ import jax import jax.numpy as jnp +import jmp from flax import linen as nn @@ -26,18 +27,24 @@ class ModelConfig: use_residual_scaling: bool = True tie_embeddings: bool = True # Whether to tie input and output embed qknorm_epsilon: float = 1e-6 - - dtype: jnp.dtype = jnp.float32 attention_init: nn.initializers.Initializer = nn.initializers.normal( stddev=0.02 ) linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) + param_dtype: jnp.dtype = jnp.float32 + compute_dtype: jnp.dtype = jnp.bfloat16 + output_dtype: jnp.dtype = jnp.bfloat16 def __post_init__(self): self.residual_init = nn.initializers.normal( stddev=0.02 / jnp.sqrt(2 * self.num_layers) ) + self.mp_policy = jmp.Policy( + compute_dtype=self.compute_dtype, + param_dtype=self.param_dtype, + output_dtype=self.output_dtype, + ) class Mlp(nn.Module): @@ -49,7 +56,11 @@ class Mlp(nn.Module): def __call__(self, x_BxLxD: jax.Array): cfg = self.cfg linear = partial( - nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype + nn.Dense, + kernel_init=cfg.linear_init, + use_bias=False, + dtype=cfg.compute_dtype, + param_dtype=cfg.param_dtype, ) # Adjust hidden dimension to keep the number of parameters invariant to # the activation function used since the GLU MLP has 3 * hidden_dim * D @@ -65,7 +76,8 @@ def __call__(self, x_BxLxD: jax.Array): x_BxLxD = nn.Dense( cfg.model_dim, use_bias=False, - dtype=cfg.dtype, + dtype=cfg.compute_dtype, + param_dtype=cfg.param_dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init, @@ -96,7 +108,7 @@ def apply_rope(q, k, freqs_cis): def rotate_tensor(x): # Split into real and imaginary parts - x_r2 = x.reshape(*x.shape[:-1], -1, 2) + x_r2 = x.reshape(*x.shape[:-1], -1, 2).astype(jnp.float32) L = x.shape[1] freqs = freqs_cis[:, :L, :, :, :] @@ -109,7 +121,7 @@ def rotate_tensor(x): axis=-1, ) - return rotated_x_r2.reshape(*x.shape) + return rotated_x_r2.reshape(*x.shape).astype(x.dtype) # Apply rotation to Q and K separately rotated_q = rotate_tensor(q) @@ -141,7 +153,8 @@ def setup(self): features=(cfg.num_heads, self.Dh), kernel_init=cfg.attention_init, use_bias=False, - dtype=cfg.dtype, + dtype=cfg.compute_dtype, + param_dtype=cfg.param_dtype, ) self.multilinear_query = self.multilinear(name='query') self.multilinear_key = self.multilinear(name='key') @@ -150,7 +163,9 @@ def setup(self): seq_len = cfg.seq_len attn_scale0 = jnp.log2(seq_len**2 - seq_len) self.attn_scale = self.param( - 
'attn_scale', nn.initializers.constant(attn_scale0), () + 'attn_scale', + nn.initializers.constant(attn_scale0, dtype=cfg.compute_dtype), + (), ) self.output_projection = nn.DenseGeneral( features=cfg.model_dim, @@ -160,7 +175,8 @@ def setup(self): if cfg.use_residual_scaling else cfg.linear_init, use_bias=False, - dtype=cfg.dtype, + dtype=cfg.compute_dtype, + param_dtype=cfg.param_dtype, ) def __call__(self, x_BxLxD: jax.Array): @@ -177,32 +193,17 @@ def __call__(self, x_BxLxD: jax.Array): # Apply QK normalization q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps - - # Compute attention scores - att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh) - - # Causal attention mask - L = x_BxLxD.shape[1] - mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) - - # Apply mask and softmax - _NEG_INF = jnp.finfo(cfg.dtype).min - att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) - att_BxHxLxL = ( - self.attn_scale * att_BxHxLxL - ) # Learned scaling factor for QK norm - att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) - att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) - - # Compute attention output - out_BxLxHxDh = jnp.einsum('...hqk,...khd->...qhd', att_BxHxLxL, v_BxLxHxDh) - - # Reshape and project output + q_BxLxHxDh *= self.attn_scale + out_BxLxHxDh = jax.nn.dot_product_attention( + query=q_BxLxHxDh, + key=k_BxLxHxDh, + value=v_BxLxHxDh, + is_causal=True, + scale=1.0, + implementation='cudnn' if cfg.compute_dtype is not jnp.float32 else None, + ) out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) - - # Output projection out_BxLxD = self.output_projection(out_BxLxD) - return out_BxLxD @@ -216,16 +217,16 @@ def __call__(self, in_BxLxD: jax.Array): cfg = self.docfg # x = x + attn( attn_norm(x) ) - x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( - in_BxLxD - ) + x_BxLxD = nn.RMSNorm( + param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon + )(in_BxLxD) x_BxLxD = CausalAttn(cfg)(x_BxLxD) x_BxLxD += in_BxLxD # x = x + mlp( mlp_norm(x) ) - z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( - x_BxLxD - ) + z_BxLxD = nn.RMSNorm( + param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon + )(x_BxLxD) z_BxLxD = Mlp(cfg)(z_BxLxD) return x_BxLxD + z_BxLxD @@ -242,19 +243,24 @@ def setup(self): num_embeddings=cfg.vocab_size, features=cfg.model_dim, embedding_init=cfg.embed_init, + dtype=cfg.compute_dtype, + param_dtype=cfg.param_dtype, ) self.blocks = [TBlock(cfg) for _ in range(cfg.num_layers)] - self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) + self.out_ln = nn.RMSNorm( + param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon + ) # Output projection - tied to input embeddings if configured if cfg.tie_embeddings: - self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) + self.output_proj = lambda x: self.embed.attend(x) else: self.output_proj = nn.Dense( cfg.vocab_size, kernel_init=cfg.embed_init, - dtype=cfg.dtype, + dtype=cfg.compute_dtype, + param_dtype=cfg.param_dtype, name='output_proj', ) @@ -357,6 +363,7 @@ def main(): # Make a prediction (forward pass) print('\nRunning forward pass...') + params, x_BxL = cfg.mp_policy.cast_to_compute((params, x_BxL)) logits = model.apply(params, x_BxL) # Print output shape and sample values diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py 
index ee4cffbbc..14366d9ea 100644 --- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py @@ -1,9 +1,11 @@ """LM workload implemented in Jax.""" +from functools import partial from typing import Any, Dict, Optional, Tuple import jax import jax.numpy as jnp +import jmp from algoperf import jax_sharding_utils, param_utils, spec from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( @@ -13,10 +15,33 @@ from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter from algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload +replicated_sharding = jax_sharding_utils.get_replicate_sharding() +batch_sharding = jax_sharding_utils.get_batch_dim_sharding() + +# Dtype mapping from string to JAX dtype +DTYPE_MAP = { + 'float32': jnp.float32, + 'float16': jnp.float16, + 'bfloat16': jnp.bfloat16, +} + class LmWorkload(BaseLmWorkload): """LM JAX workload.""" + # Convert dtype strings from base class to JAX dtypes + @property + def _compute_dtype(self) -> Any: + return DTYPE_MAP[self._compute_dtype_str] + + @property + def _param_dtype(self) -> Any: + return DTYPE_MAP[self._param_dtype_str] + + @property + def _output_dtype(self) -> Any: + return DTYPE_MAP[self._output_dtype_str] + def _build_input_queue( self, data_rng: jax.random.PRNGKey, @@ -53,8 +78,14 @@ def init_model_fn( num_layers=self._n_layers, # num layers vocab_size=self._vocab_size, expanded_model_dim=self._mlp_dim, # feedforward dim - dtype=jnp.float32, + rmsnorm_epsilon=self._rmsnorm_epsilon, + qknorm_epsilon=self._qknorm_epsilon, + tie_embeddings=self._tie_embeddings, + param_dtype=self._param_dtype, + compute_dtype=self._compute_dtype, + output_dtype=self._output_dtype, ) + self._mp_policy: jmp.Policy = cfg.mp_policy self._model = TransformerDo(cfg) input_shape = (1, self._seq_len) # For token IDs @@ -66,8 +97,7 @@ def init_model_fn( self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) params = jax_sharding_utils.replicate(params) - model_state = None - return params, model_state + return params, None def model_fn( self, @@ -81,10 +111,12 @@ def model_fn( ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] + params, inputs = self._mp_policy.cast_to_compute((params, inputs)) # Convert one-hot inputs to token IDs if needed if inputs.ndim == 3: # one-hot encoded inputs = jnp.argmax(inputs, axis=-1) logits = self._model.apply({'params': params}, inputs) + logits = self._mp_policy.cast_to_output(logits) return logits, None def loss_fn( @@ -139,6 +171,17 @@ def loss_fn( 'per_example': per_example_losses, } + @partial( + jax.jit, + static_argnums=(0,), + in_shardings=( + replicated_sharding, + batch_sharding, + replicated_sharding, + replicated_sharding, + ), + out_shardings=(replicated_sharding), + ) def _eval_batch( self, params: spec.ParameterContainer, diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py index edee8318c..4c60198cc 100644 --- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py @@ -26,14 +26,24 @@ class ModelConfig: qknorm_epsilon: float = 1e-6 use_residual_scaling: bool = True tie_embeddings: bool = True + compute_dtype: torch.dtype = torch.bfloat16 + param_dtype: torch.dtype = 
torch.float32 class MLP(nn.Module): - def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + dtype: torch.dtype = torch.float32, + ): super().__init__() - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) - self.fc2 = nn.Linear(hidden_dim, dim, bias=False) + hidden_dim = int( + multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + ) + self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False, dtype=dtype) + self.fc2 = nn.Linear(hidden_dim, dim, bias=False, dtype=dtype) self.glu = nn.GLU(dim=2) nn.init.normal_(self.fc1.weight, std=0.02) nn.init.normal_(self.fc2.weight, std=0.02) @@ -88,8 +98,12 @@ def __init__(self, cfg: ModelConfig): self.n_heads = cfg.num_heads self.head_dim = cfg.model_dim // cfg.num_heads - self.w_qkv = nn.Linear(cfg.model_dim, 3 * cfg.model_dim, bias=False) - self.w_out = nn.Linear(cfg.model_dim, cfg.model_dim, bias=False) + self.w_qkv = nn.Linear( + cfg.model_dim, 3 * cfg.model_dim, bias=False, dtype=cfg.param_dtype + ) + self.w_out = nn.Linear( + cfg.model_dim, cfg.model_dim, bias=False, dtype=cfg.param_dtype + ) # Split into Q, K, V sections wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0) for w in [wq, wk, wv]: @@ -99,7 +113,9 @@ def __init__(self, cfg: ModelConfig): self.eps = cfg.qknorm_epsilon # e.g., 1e-6 seq_len = cfg.seq_len attn_scale0 = math.log2(seq_len**2 - seq_len) - self.attn_scale = nn.Parameter(torch.tensor(attn_scale0)) + self.attn_scale = nn.Parameter( + torch.tensor(attn_scale0, dtype=cfg.param_dtype) + ) def forward(self, x, freqs_cis): bsz, seqlen, d = x.shape # (bsz, seqlen, d) @@ -142,13 +158,18 @@ class Block(nn.Module): def __init__(self, layer_id: int, cfg: ModelConfig): super().__init__() self.attn = Attention(cfg) - self.attn_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.attn_norm = nn.RMSNorm( + cfg.model_dim, eps=cfg.rmsnorm_epsilon, dtype=cfg.param_dtype + ) self.mlp = MLP( dim=cfg.model_dim, hidden_dim=cfg.expanded_model_dim, multiple_of=cfg.multiple_of, + dtype=cfg.param_dtype, + ) + self.mlp_norm = nn.RMSNorm( + cfg.model_dim, eps=cfg.rmsnorm_epsilon, dtype=cfg.param_dtype ) - self.mlp_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) self.layer_id = layer_id def forward(self, x, freqs_cis): @@ -166,12 +187,18 @@ def __init__(self, cfg: ModelConfig): head_dim = cfg.model_dim // cfg.num_heads assert cfg.model_dim % cfg.num_heads == 0 - self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.model_dim) + self.embed_tokens = nn.Embedding( + cfg.vocab_size, cfg.model_dim, dtype=cfg.param_dtype + ) self.layers = nn.ModuleList( [Block(idx, cfg) for idx in range(cfg.num_layers)] ) - self.out_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) - self.lm_head = nn.Linear(cfg.model_dim, cfg.vocab_size, bias=False) + self.out_norm = nn.RMSNorm( + cfg.model_dim, eps=cfg.rmsnorm_epsilon, dtype=cfg.param_dtype + ) + self.lm_head = nn.Linear( + cfg.model_dim, cfg.vocab_size, bias=False, dtype=cfg.param_dtype + ) # Initialize freqs_cis on CPU first (more memory efficient) self.register_buffer( @@ -215,6 +242,7 @@ def forward(self, x, targets=None): for layer in self.layers: x = layer(x, freqs_cis) # (bsz, seqlen, dim) out = self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + if targets is not None: loss = F.cross_entropy( out.view(-1, out.size(-1)), targets.view(-1), ignore_index=-100 @@ -232,40 +260,43 @@ def predict(self, 
x, k=1): Returns: Tuple of (input_ids, predicted_ids) """ + # Determine device type for autocast + device_type = 'cuda' if x.is_cuda else 'cpu' - # Store original input - original_input = x.clone() - generated_input = x.clone() + with torch.autocast(device_type=device_type, dtype=self.cfg.compute_dtype): + # Store original input + original_input = x.clone() + generated_input = x.clone() - # Generate k tokens autoregressively - for i in range(k): - # Get logits for the entire sequence - logits = self(generated_input) + # Generate k tokens autoregressively + for i in range(k): + # Get logits for the entire sequence + logits = self(generated_input) - # Get the logits for the last token in each sequence - next_token_logits = logits[:, -1, :] + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] - # Zero out the last token ID to prevent repetition - # This is a common issue - the model gets stuck repeating the last token - last_token_id = generated_input[:, -1] - next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) + # Zero out the last token ID to prevent repetition + # This is a common issue - the model gets stuck repeating the last token + last_token_id = generated_input[:, -1] + next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) - # Get the most likely token - next_token = torch.argmax(next_token_logits, dim=-1) + # Get the most likely token + next_token = torch.argmax(next_token_logits, dim=-1) - # Append the predicted token to the sequence - next_token = next_token.unsqueeze(1) # Add sequence dimension - generated_input = torch.cat([generated_input, next_token], dim=1) + # Append the predicted token to the sequence + next_token = next_token.unsqueeze(1) # Add sequence dimension + generated_input = torch.cat([generated_input, next_token], dim=1) - # For debugging, print predictions for the first item in the batch - print('\nPyTorch detailed prediction (first item in batch):') - predicted_sequence = generated_input[0, -k:].tolist() - print(f' Predicted token IDs: {predicted_sequence}') - for i, token_id in enumerate(predicted_sequence): - print(f' Step {i + 1}: Predicted token {token_id}') + # For debugging, print predictions for the first item in the batch + print('\nPyTorch detailed prediction (first item in batch):') + predicted_sequence = generated_input[0, -k:].tolist() + print(f' Predicted token IDs: {predicted_sequence}') + for i, token_id in enumerate(predicted_sequence): + print(f' Step {i + 1}: Predicted token {token_id}') - # Return all tokens, not just the last k - return original_input, generated_input[:, -k:] + # Return all tokens, not just the last k + return original_input, generated_input[:, -k:] def _init_weights(self, module): if isinstance(module, nn.Linear): @@ -318,6 +349,8 @@ def main(): # Instantiate the model model = Transformer(config) print(f'Model has {model.count_params():,} parameters.') + for n, p in model.named_parameters(): + print(f'{n}.dtype == {p.dtype}') # Create some random input data batch_size = 2 @@ -330,6 +363,7 @@ def main(): # Run a forward pass print(f'Running forward pass with input shape: {input_ids.shape}') logits = model(input_ids) + print(f'Output logits dtype: {logits.dtype}') print(f'Output logits shape: {logits.shape}') # Run prediction diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py index a25ca334a..ed922f9c2 100644 --- 
a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py @@ -19,10 +19,25 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() +# Dtype mapping from string to PyTorch dtype +DTYPE_MAP = { + 'float32': torch.float32, + 'float16': torch.float16, + 'bfloat16': torch.bfloat16, +} + class LmWorkload(BaseLmWorkload): """LM PyTorch workload.""" + @property + def _compute_dtype(self) -> torch.dtype: + return DTYPE_MAP[self._compute_dtype_str] + + @property + def _param_dtype(self) -> torch.dtype: + return DTYPE_MAP[self._param_dtype_str] + def init_model_fn( self, rng: spec.RandomState, @@ -40,11 +55,14 @@ def init_model_fn( vocab_size=self._vocab_size, seq_len=self._seq_len, model_dim=self._emb_dim, # Model dimension - expanded_model_dim=self._mlp_dim, # MLP expansion factor - num_layers=self._n_layers, # Number of transformer layers - num_heads=self._n_heads, # Number of attention heads - rmsnorm_epsilon=1e-6, - tie_embeddings=True, + expanded_model_dim=self._mlp_dim, # MLP expanded dim + num_layers=self._n_layers, + num_heads=self._n_heads, + rmsnorm_epsilon=self._rmsnorm_epsilon, + qknorm_epsilon=self._qknorm_epsilon, + tie_embeddings=self._tie_embeddings, + compute_dtype=self._compute_dtype, + param_dtype=self._param_dtype, ) self._model = Transformer(cfg) self._param_shapes = param_utils.pytorch_param_shapes(self._model) @@ -81,13 +99,18 @@ def model_fn( spec.ForwardPassMode.EVAL: torch.no_grad, spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } + + # Determine device type for autocast + device_type = 'cuda' if DEVICE.type == 'cuda' else 'cpu' + with contexts[mode](): - # Convert one-hot inputs to token IDs if needed - inputs = augmented_and_preprocessed_input_batch['inputs'] - if inputs.dim() == 3: # one-hot encoded - inputs = inputs.argmax(dim=-1) + with torch.autocast(device_type=device_type, dtype=self._compute_dtype): + # Convert one-hot inputs to token IDs if needed + inputs = augmented_and_preprocessed_input_batch['inputs'] + if inputs.dim() == 3: # one-hot encoded + inputs = inputs.argmax(dim=-1) - logits = model(inputs) + logits = model(inputs) return logits, None @@ -121,7 +144,7 @@ def _build_input_queue( batch['targets'], device=DEVICE, dtype=torch.int64 ), 'weights': torch.tensor( - batch['weights'], device=DEVICE, dtype=torch.float32 + batch['weights'], device=DEVICE, dtype=self._param_dtype ) if batch['weights'] is not None else None, @@ -157,29 +180,35 @@ def loss_fn( - 'n_valid_examples': Scalar tensor with the count of valid (non-masked) examples. - 'per_example': Tensor of shape [batch, length] with individual losses per example. 
""" - vocab_size = logits_batch.size(-1) - - # Compute cross-entropy loss with label smoothing - per_example_losses = torch.nn.functional.cross_entropy( - logits_batch.view(-1, vocab_size), - label_batch.view(-1), - reduction='none', - label_smoothing=label_smoothing, - ) - per_example_losses = per_example_losses.view_as(label_batch) - - # Apply weights if provided - if mask_batch is not None: - per_example_losses = per_example_losses * mask_batch - - # Calculate number of valid examples - n_valid_examples = ( - mask_batch.sum() - if mask_batch is not None - else torch.tensor( - label_batch.numel(), dtype=torch.float32, device=label_batch.device + # Determine device type for autocast + device_type = 'cuda' if logits_batch.is_cuda else 'cpu' + + with torch.autocast(device_type=device_type, dtype=self._compute_dtype): + vocab_size = logits_batch.size(-1) + + # Compute cross-entropy loss with label smoothing + per_example_losses = torch.nn.functional.cross_entropy( + logits_batch.view(-1, vocab_size), + label_batch.view(-1), + reduction='none', + label_smoothing=label_smoothing, + ) + per_example_losses = per_example_losses.view_as(label_batch) + + # Apply weights if provided + if mask_batch is not None: + per_example_losses = per_example_losses * mask_batch + + # Calculate number of valid examples + n_valid_examples = ( + mask_batch.sum() + if mask_batch is not None + else torch.tensor( + label_batch.numel(), + dtype=self._param_dtype, + device=label_batch.device, + ) ) - ) return { 'summed': per_example_losses.sum(), diff --git a/algoperf/workloads/finewebedu_lm/workload.py b/algoperf/workloads/finewebedu_lm/workload.py index 5d6e3d742..b5a258f6f 100644 --- a/algoperf/workloads/finewebedu_lm/workload.py +++ b/algoperf/workloads/finewebedu_lm/workload.py @@ -27,6 +27,16 @@ class BaseLmWorkload(spec.Workload): _mlp_dim: int = 4096 warmup_factor: float = 0.1 + # Model configuration + _rmsnorm_epsilon: float = 1e-6 + _qknorm_epsilon: float = 1e-6 + _tie_embeddings: bool = True + + # Dtype configuration (as strings, to be converted by framework-specific subclasses) + _compute_dtype_str: str = 'bfloat16' + _param_dtype_str: str = 'float32' + _output_dtype_str: str = 'bfloat16' # Only used by JAX + def __init__(self) -> None: super().__init__() self._param_shapes = None @@ -85,11 +95,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 31_967 # 8.9 hours + return 31_967 # 8.9 hours @property def eval_period_time_sec(self) -> int: - return 2_571 # approximately 25 evals + return 2_571 # approximately 25 evals @property def step_hint(self) -> int: @@ -164,9 +174,12 @@ def _eval_model_on_split( eval_batch = next(self._eval_iters[split]) metrics = self._eval_batch(params, eval_batch, model_state, rng) for metric_name, metric_value in metrics.items(): - if metric_name not in eval_metrics: - eval_metrics[metric_name] = 0.0 - eval_metrics[metric_name] += metric_value + eval_metrics.update( + {metric_name: eval_metrics.get(metric_name, 0.0) + metric_value} + ) + print( + f"Completed eval batch {_ + 1}/{num_batches} for split '{split}' at global step {global_step}." 
+ ) eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) eval_results['ppl'] = np.exp(eval_results['loss']).item() diff --git a/pyproject.toml b/pyproject.toml index 006e7e5cd..e3d86df3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"] -lm = ["transformers==4.26", "datasets==3.6.0"] +lm = ["transformers==4.26.0", "datasets==3.6.0"] # Frameworks jax_core_deps = [ @@ -99,6 +99,7 @@ jax_core_deps = [ "chex==0.1.86", "ml_dtypes==0.5.1", "protobuf==4.25.5", + "jmp>=0.0.4" ] jax_cpu = [ "jax==0.7.0", From 1645d1fa13d6c08cd6293f2d60f183a317a00011 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 11 Dec 2025 03:04:06 +0000 Subject: [PATCH 81/98] update step time calculation --- scoring/score_submissions.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 4b7bed2b5..51a9c54f5 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -119,9 +119,14 @@ def get_summary_df(workload, workload_df, include_test_split=False): axis=1, ) - summary_df['step_time (s)'] = ( - workload_df['accumulated_submission_time'] / workload_df['global_step'] - ).iloc[-1][-1] + # compute the step times + def delta(series): + return series.shift(1, fill_value=0) - series + accumulated_time_intervals = delta(workload_df['accumulated_submission_time']) + step_intervals = delta(workload_df['global_step']) + + summary_df['step_time (s)'] = np.median((accumulated_time_intervals / step_intervals).iloc[0]) + summary_df['step_hint'] = scoring_utils.get_workload_stephint(workload) From 12289fa229b2789891190efc6abe1933e16aab2f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 11 Dec 2025 03:06:29 +0000 Subject: [PATCH 82/98] update target --- algoperf/workloads/finewebedu_lm/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/finewebedu_lm/workload.py b/algoperf/workloads/finewebedu_lm/workload.py index 5d6e3d742..e6e2e9ba5 100644 --- a/algoperf/workloads/finewebedu_lm/workload.py +++ b/algoperf/workloads/finewebedu_lm/workload.py @@ -42,7 +42,7 @@ def has_reached_validation_target(self, eval_result: float) -> bool: @property def validation_target_value(self) -> float: - return 22.432 # Target perplexity + return 22.2995 # Target perplexity def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: return True # No test targets From fc6d6f715c7e602970fdc40f960c1e98510f0974 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 11 Dec 2025 03:08:17 +0000 Subject: [PATCH 83/98] add lm workload to list of valid workloads --- docker/scripts/startup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index d92107e90..1cd676d2a 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -185,7 +185,7 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_ "librispeech_conformer_gelu" "fastmri_model_size" "fastmri_tanh" \ "librispeech_deepspeech_tanh" \ "librispeech_deepspeech_no_resnet" "librispeech_deepspeech_norm_and_spec_aug" - "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size" "lm") + "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size" "finewebedu_lm") VALID_RULESETS=("self" "external") # Set data and experiment paths From 3e0e07c6886b0c3f3d49af37c06b3ab5af4caa62 Mon Sep 17 
00:00:00 2001 From: Priya Kasimbeg Date: Tue, 16 Dec 2025 06:54:21 +0000 Subject: [PATCH 84/98] set matmuls, conv and rnn to tf32 for torch.cuda --- submission_runner.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/submission_runner.py b/submission_runner.py index 01d9894d8..20c621a96 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -36,6 +36,13 @@ jax.config.update('jax_default_prng_impl', 'threefry2x32') jax.config.update('jax_threefry_partitionable', True) +# PyTorch set TF32 +torch.backends.fp32_precision = "ieee" +torch.backends.cuda.matmul.fp32_precision = "tf32" +torch.backends.cudnn.fp32_precision = "ieee" +torch.backends.cudnn.conv.fp32_precision = "tf32" +torch.backends.cudnn.rnn.fp32_precision = "tf32" + # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.set_visible_devices([], 'GPU') From 113c48198164ab6b4d01a31136f272361f5a5ea5 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 16 Dec 2025 17:33:18 +0000 Subject: [PATCH 85/98] add include_submissions option --- algoperf/workloads/workloads.py | 4 ++++ scoring/score_submissions.py | 32 +++++++++++++++++++++----------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py index e90300a36..1bb0e4e21 100644 --- a/algoperf/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -117,6 +117,10 @@ 'workload_path': 'finewebedu_lm/finewebedu_lm', 'workload_class_name': 'LmWorkload', }, + 'lm': { + 'workload_path': 'finewebedu_lm/finewebedu_lm', + 'workload_class_name': 'LmWorkload', + }, 'mnist': { 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload', }, diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 51a9c54f5..55d824dd4 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -67,6 +67,11 @@ '', 'Optional comma seperated list of names of submissions to exclude from scoring.', ) +flags.DEFINE_string( + 'include_submissions', + '', + 'Optional comma separated list of names of submissions to include in scoring.'
+) FLAGS = flags.FLAGS @@ -210,18 +215,23 @@ def main(_): ) as f: results = pickle.load(f) else: - for submission in os.listdir(FLAGS.submission_directory): + all_submission_dirs = list(os.listdir(FLAGS.submission_directory)) + if not FLAGS.include_submissions: + include_submissions = all_submission_dirs + else: + include_submissions = FLAGS.include_submissions.split(',') + + for submission in all_submission_dirs: print(submission) - if submission in FLAGS.exclude_submissions.split(','): - continue - experiment_path = os.path.join(FLAGS.submission_directory, submission) - df = scoring_utils.get_experiment_df(experiment_path) - results[submission] = df - summary_df = get_submission_summary(df) - with open( - os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), 'w' - ) as fout: - summary_df.to_csv(fout) + if submission not in FLAGS.exclude_submissions.split(',') and (submission in include_submissions): + experiment_path = os.path.join(FLAGS.submission_directory, submission) + df = scoring_utils.get_experiment_df(experiment_path) + results[submission] = df + summary_df = get_submission_summary(df) + with open( + os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), 'w' + ) as fout: + summary_df.to_csv(fout) # Optionally save results to filename if FLAGS.save_results_to_filename: From 4e82d1c79b6054c0aaae6189599d7db6eb2d2530 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 18 Dec 2025 18:45:49 +0000 Subject: [PATCH 86/98] remove new api settings for tf32 --- algoperf/pytorch_utils.py | 7 +++++++ submission_runner.py | 7 ------- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index e24b0f141..c5aae1050 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -20,6 +20,13 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: + # PyTorch set TF32 + # torch.backends.fp32_precision = "ieee" + # torch.backends.cuda.matmul.fp32_precision = "tf32" + # torch.backends.cudnn.fp32_precision = "ieee" + # torch.backends.cudnn.conv.fp32_precision = "tf32" + # torch.backends.cudnn.rnn.fp32_precision = "tf32" + use_pytorch_ddp = 'LOCAL_RANK' in os.environ rank = int(os.environ['LOCAL_RANK']) if use_pytorch_ddp else 0 device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu') diff --git a/submission_runner.py b/submission_runner.py index 20c621a96..01d9894d8 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -36,13 +36,6 @@ jax.config.update('jax_default_prng_impl', 'threefry2x32') jax.config.update('jax_threefry_partitionable', True) -# PyTorch set TF32 -torch.backends.fp32_precision = "ieee" -torch.backends.cuda.matmul.fp32_precision = "tf32" -torch.backends.cudnn.fp32_precision = "ieee" -torch.backends.cudnn.conv.fp32_precision = "tf32" -torch.backends.cudnn.rnn.fp32_precision = "tf32" - # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. 
tf.config.set_visible_devices([], 'GPU') From 85705d774cccfc72220dd50a1723389cc8301a5a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 18 Dec 2025 18:55:59 +0000 Subject: [PATCH 87/98] document how to use new api for torch.backends precision --- algoperf/pytorch_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index c5aae1050..cb9780817 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -26,6 +26,7 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: # torch.backends.cudnn.fp32_precision = "ieee" # torch.backends.cudnn.conv.fp32_precision = "tf32" # torch.backends.cudnn.rnn.fp32_precision = "tf32" + use_pytorch_ddp = 'LOCAL_RANK' in os.environ rank = int(os.environ['LOCAL_RANK']) if use_pytorch_ddp else 0 From d24264dfdb4f599fb43ba95b88eef125f908c21d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 15 Jan 2026 19:02:17 +0000 Subject: [PATCH 88/98] eval with same number of workers --- algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index b31998822..1a569a65e 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -254,7 +254,7 @@ def _build_dataset( batch_size=ds_iter_batch_size, shuffle=not USE_PYTORCH_DDP and is_train, sampler=sampler, - num_workers=5 * N_GPUS if is_train else self.eval_num_workers, + num_workers=5 * N_GPUS, pin_memory=True, drop_last=is_train, persistent_workers=is_train, From 56337f8e209ae4b172d13d6ca1fff71a130911a8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 15 Jan 2026 23:22:24 +0000 Subject: [PATCH 89/98] temporarily disable cache for imagenet --- .../workloads/imagenet_resnet/imagenet_pytorch/workload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 1a569a65e..2ea830fd6 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -220,10 +220,10 @@ def _build_dataset( ) folder = 'train' if 'train' in split else 'val' - dataset = CachedImageFolder( + dataset = ImageFolder( os.path.join(data_dir, folder), transform=transform_config, - cache_file='.imagenet_cache_index.json', + # cache_file='.imagenet_cache_index.json', ) if split == 'eval_train': From d4dc1b9dd86138159d64fb48e508ddaf9740d043 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 20 Jan 2026 19:35:36 +0000 Subject: [PATCH 90/98] remove print statement --- algoperf/workloads/finewebedu_lm/workload.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/algoperf/workloads/finewebedu_lm/workload.py b/algoperf/workloads/finewebedu_lm/workload.py index b5a258f6f..d95da48ec 100644 --- a/algoperf/workloads/finewebedu_lm/workload.py +++ b/algoperf/workloads/finewebedu_lm/workload.py @@ -177,9 +177,6 @@ def _eval_model_on_split( eval_metrics.update( {metric_name: eval_metrics.get(metric_name, 0.0) + metric_value} ) - print( - f"Completed eval batch {_ + 1}/{num_batches} for split '{split}' at global step {global_step}."
- ) eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) eval_results['ppl'] = np.exp(eval_results['loss']).item() From c7ad36d3433f978bb06ad30da212134d769c9b86 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 21 Jan 2026 08:54:31 -0800 Subject: [PATCH 91/98] Revert "add mixed precision training for lm workload" --- .../finewebedu_lm/finewebedu_lm_jax/models.py | 91 +++++++-------- .../finewebedu_lm_jax/workload.py | 49 +------- .../finewebedu_lm_pytorch/models.py | 110 ++++++------------ .../finewebedu_lm_pytorch/workload.py | 95 ++++++--------- algoperf/workloads/finewebedu_lm/workload.py | 20 +--- pyproject.toml | 3 +- 6 files changed, 122 insertions(+), 246 deletions(-) diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py index 3419fe6fb..d08e9b7bf 100644 --- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py @@ -8,7 +8,6 @@ import jax import jax.numpy as jnp -import jmp from flax import linen as nn @@ -27,24 +26,18 @@ class ModelConfig: use_residual_scaling: bool = True tie_embeddings: bool = True # Whether to tie input and output embed qknorm_epsilon: float = 1e-6 + + dtype: jnp.dtype = jnp.float32 attention_init: nn.initializers.Initializer = nn.initializers.normal( stddev=0.02 ) linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) - param_dtype: jnp.dtype = jnp.float32 - compute_dtype: jnp.dtype = jnp.bfloat16 - output_dtype: jnp.dtype = jnp.bfloat16 def __post_init__(self): self.residual_init = nn.initializers.normal( stddev=0.02 / jnp.sqrt(2 * self.num_layers) ) - self.mp_policy = jmp.Policy( - compute_dtype=self.compute_dtype, - param_dtype=self.param_dtype, - output_dtype=self.output_dtype, - ) class Mlp(nn.Module): @@ -56,11 +49,7 @@ class Mlp(nn.Module): def __call__(self, x_BxLxD: jax.Array): cfg = self.cfg linear = partial( - nn.Dense, - kernel_init=cfg.linear_init, - use_bias=False, - dtype=cfg.compute_dtype, - param_dtype=cfg.param_dtype, + nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype ) # Adjust hidden dimension to keep the number of parameters invariant to # the activation function used since the GLU MLP has 3 * hidden_dim * D @@ -76,8 +65,7 @@ def __call__(self, x_BxLxD: jax.Array): x_BxLxD = nn.Dense( cfg.model_dim, use_bias=False, - dtype=cfg.compute_dtype, - param_dtype=cfg.param_dtype, + dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init, @@ -108,7 +96,7 @@ def apply_rope(q, k, freqs_cis): def rotate_tensor(x): # Split into real and imaginary parts - x_r2 = x.reshape(*x.shape[:-1], -1, 2).astype(jnp.float32) + x_r2 = x.reshape(*x.shape[:-1], -1, 2) L = x.shape[1] freqs = freqs_cis[:, :L, :, :, :] @@ -121,7 +109,7 @@ def rotate_tensor(x): axis=-1, ) - return rotated_x_r2.reshape(*x.shape).astype(x.dtype) + return rotated_x_r2.reshape(*x.shape) # Apply rotation to Q and K separately rotated_q = rotate_tensor(q) @@ -153,8 +141,7 @@ def setup(self): features=(cfg.num_heads, self.Dh), kernel_init=cfg.attention_init, use_bias=False, - dtype=cfg.compute_dtype, - param_dtype=cfg.param_dtype, + dtype=cfg.dtype, ) self.multilinear_query = self.multilinear(name='query') self.multilinear_key = self.multilinear(name='key') @@ -163,9 +150,7 @@ def setup(self): seq_len = cfg.seq_len attn_scale0 = jnp.log2(seq_len**2 - 
seq_len) self.attn_scale = self.param( - 'attn_scale', - nn.initializers.constant(attn_scale0, dtype=cfg.compute_dtype), - (), + 'attn_scale', nn.initializers.constant(attn_scale0), () ) self.output_projection = nn.DenseGeneral( features=cfg.model_dim, @@ -175,8 +160,7 @@ def setup(self): if cfg.use_residual_scaling else cfg.linear_init, use_bias=False, - dtype=cfg.compute_dtype, - param_dtype=cfg.param_dtype, + dtype=cfg.dtype, ) def __call__(self, x_BxLxD: jax.Array): @@ -193,17 +177,32 @@ def __call__(self, x_BxLxD: jax.Array): # Apply QK normalization q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps - q_BxLxHxDh *= self.attn_scale - out_BxLxHxDh = jax.nn.dot_product_attention( - query=q_BxLxHxDh, - key=k_BxLxHxDh, - value=v_BxLxHxDh, - is_causal=True, - scale=1.0, - implementation='cudnn' if cfg.compute_dtype is not jnp.float32 else None, - ) + + # Compute attention scores + att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh) + + # Causal attention mask + L = x_BxLxD.shape[1] + mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) + + # Apply mask and softmax + _NEG_INF = jnp.finfo(cfg.dtype).min + att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) + att_BxHxLxL = ( + self.attn_scale * att_BxHxLxL + ) # Learned scaling factor for QK norm + att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) + att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) + + # Compute attention output + out_BxLxHxDh = jnp.einsum('...hqk,...khd->...qhd', att_BxHxLxL, v_BxLxHxDh) + + # Reshape and project output out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) + + # Output projection out_BxLxD = self.output_projection(out_BxLxD) + return out_BxLxD @@ -217,16 +216,16 @@ def __call__(self, in_BxLxD: jax.Array): cfg = self.docfg # x = x + attn( attn_norm(x) ) - x_BxLxD = nn.RMSNorm( - param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon - )(in_BxLxD) + x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + in_BxLxD + ) x_BxLxD = CausalAttn(cfg)(x_BxLxD) x_BxLxD += in_BxLxD # x = x + mlp( mlp_norm(x) ) - z_BxLxD = nn.RMSNorm( - param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon - )(x_BxLxD) + z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + x_BxLxD + ) z_BxLxD = Mlp(cfg)(z_BxLxD) return x_BxLxD + z_BxLxD @@ -243,24 +242,19 @@ def setup(self): num_embeddings=cfg.vocab_size, features=cfg.model_dim, embedding_init=cfg.embed_init, - dtype=cfg.compute_dtype, - param_dtype=cfg.param_dtype, ) self.blocks = [TBlock(cfg) for _ in range(cfg.num_layers)] - self.out_ln = nn.RMSNorm( - param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon - ) + self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) # Output projection - tied to input embeddings if configured if cfg.tie_embeddings: - self.output_proj = lambda x: self.embed.attend(x) + self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) else: self.output_proj = nn.Dense( cfg.vocab_size, kernel_init=cfg.embed_init, - dtype=cfg.compute_dtype, - param_dtype=cfg.param_dtype, + dtype=cfg.dtype, name='output_proj', ) @@ -363,7 +357,6 @@ def main(): # Make a prediction (forward pass) print('\nRunning forward pass...') - params, x_BxL = cfg.mp_policy.cast_to_compute((params, x_BxL)) logits = model.apply(params, x_BxL) # Print output shape and sample values diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py 
b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py index 14366d9ea..ee4cffbbc 100644 --- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py @@ -1,11 +1,9 @@ """LM workload implemented in Jax.""" -from functools import partial from typing import Any, Dict, Optional, Tuple import jax import jax.numpy as jnp -import jmp from algoperf import jax_sharding_utils, param_utils, spec from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( @@ -15,33 +13,10 @@ from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter from algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload -replicated_sharding = jax_sharding_utils.get_replicate_sharding() -batch_sharding = jax_sharding_utils.get_batch_dim_sharding() - -# Dtype mapping from string to JAX dtype -DTYPE_MAP = { - 'float32': jnp.float32, - 'float16': jnp.float16, - 'bfloat16': jnp.bfloat16, -} - class LmWorkload(BaseLmWorkload): """LM JAX workload.""" - # Convert dtype strings from base class to JAX dtypes - @property - def _compute_dtype(self) -> Any: - return DTYPE_MAP[self._compute_dtype_str] - - @property - def _param_dtype(self) -> Any: - return DTYPE_MAP[self._param_dtype_str] - - @property - def _output_dtype(self) -> Any: - return DTYPE_MAP[self._output_dtype_str] - def _build_input_queue( self, data_rng: jax.random.PRNGKey, @@ -78,14 +53,8 @@ def init_model_fn( num_layers=self._n_layers, # num layers vocab_size=self._vocab_size, expanded_model_dim=self._mlp_dim, # feedforward dim - rmsnorm_epsilon=self._rmsnorm_epsilon, - qknorm_epsilon=self._qknorm_epsilon, - tie_embeddings=self._tie_embeddings, - param_dtype=self._param_dtype, - compute_dtype=self._compute_dtype, - output_dtype=self._output_dtype, + dtype=jnp.float32, ) - self._mp_policy: jmp.Policy = cfg.mp_policy self._model = TransformerDo(cfg) input_shape = (1, self._seq_len) # For token IDs @@ -97,7 +66,8 @@ def init_model_fn( self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) params = jax_sharding_utils.replicate(params) - return params, None + model_state = None + return params, model_state def model_fn( self, @@ -111,12 +81,10 @@ def model_fn( ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] - params, inputs = self._mp_policy.cast_to_compute((params, inputs)) # Convert one-hot inputs to token IDs if needed if inputs.ndim == 3: # one-hot encoded inputs = jnp.argmax(inputs, axis=-1) logits = self._model.apply({'params': params}, inputs) - logits = self._mp_policy.cast_to_output(logits) return logits, None def loss_fn( @@ -171,17 +139,6 @@ def loss_fn( 'per_example': per_example_losses, } - @partial( - jax.jit, - static_argnums=(0,), - in_shardings=( - replicated_sharding, - batch_sharding, - replicated_sharding, - replicated_sharding, - ), - out_shardings=(replicated_sharding), - ) def _eval_batch( self, params: spec.ParameterContainer, diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py index 4c60198cc..edee8318c 100644 --- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py @@ -26,24 +26,14 @@ class ModelConfig: qknorm_epsilon: float = 1e-6 use_residual_scaling: bool = True tie_embeddings: bool = True - 
compute_dtype: torch.dtype = torch.bfloat16 - param_dtype: torch.dtype = torch.float32 class MLP(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int = 256, - dtype: torch.dtype = torch.float32, - ): + def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): super().__init__() - hidden_dim = int( - multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - ) - self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False, dtype=dtype) - self.fc2 = nn.Linear(hidden_dim, dim, bias=False, dtype=dtype) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) + self.fc2 = nn.Linear(hidden_dim, dim, bias=False) self.glu = nn.GLU(dim=2) nn.init.normal_(self.fc1.weight, std=0.02) nn.init.normal_(self.fc2.weight, std=0.02) @@ -98,12 +88,8 @@ def __init__(self, cfg: ModelConfig): self.n_heads = cfg.num_heads self.head_dim = cfg.model_dim // cfg.num_heads - self.w_qkv = nn.Linear( - cfg.model_dim, 3 * cfg.model_dim, bias=False, dtype=cfg.param_dtype - ) - self.w_out = nn.Linear( - cfg.model_dim, cfg.model_dim, bias=False, dtype=cfg.param_dtype - ) + self.w_qkv = nn.Linear(cfg.model_dim, 3 * cfg.model_dim, bias=False) + self.w_out = nn.Linear(cfg.model_dim, cfg.model_dim, bias=False) # Split into Q, K, V sections wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0) for w in [wq, wk, wv]: @@ -113,9 +99,7 @@ def __init__(self, cfg: ModelConfig): self.eps = cfg.qknorm_epsilon # e.g., 1e-6 seq_len = cfg.seq_len attn_scale0 = math.log2(seq_len**2 - seq_len) - self.attn_scale = nn.Parameter( - torch.tensor(attn_scale0, dtype=cfg.param_dtype) - ) + self.attn_scale = nn.Parameter(torch.tensor(attn_scale0)) def forward(self, x, freqs_cis): bsz, seqlen, d = x.shape # (bsz, seqlen, d) @@ -158,18 +142,13 @@ class Block(nn.Module): def __init__(self, layer_id: int, cfg: ModelConfig): super().__init__() self.attn = Attention(cfg) - self.attn_norm = nn.RMSNorm( - cfg.model_dim, eps=cfg.rmsnorm_epsilon, dtype=cfg.param_dtype - ) + self.attn_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) self.mlp = MLP( dim=cfg.model_dim, hidden_dim=cfg.expanded_model_dim, multiple_of=cfg.multiple_of, - dtype=cfg.param_dtype, - ) - self.mlp_norm = nn.RMSNorm( - cfg.model_dim, eps=cfg.rmsnorm_epsilon, dtype=cfg.param_dtype ) + self.mlp_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) self.layer_id = layer_id def forward(self, x, freqs_cis): @@ -187,18 +166,12 @@ def __init__(self, cfg: ModelConfig): head_dim = cfg.model_dim // cfg.num_heads assert cfg.model_dim % cfg.num_heads == 0 - self.embed_tokens = nn.Embedding( - cfg.vocab_size, cfg.model_dim, dtype=cfg.param_dtype - ) + self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.model_dim) self.layers = nn.ModuleList( [Block(idx, cfg) for idx in range(cfg.num_layers)] ) - self.out_norm = nn.RMSNorm( - cfg.model_dim, eps=cfg.rmsnorm_epsilon, dtype=cfg.param_dtype - ) - self.lm_head = nn.Linear( - cfg.model_dim, cfg.vocab_size, bias=False, dtype=cfg.param_dtype - ) + self.out_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.lm_head = nn.Linear(cfg.model_dim, cfg.vocab_size, bias=False) # Initialize freqs_cis on CPU first (more memory efficient) self.register_buffer( @@ -242,7 +215,6 @@ def forward(self, x, targets=None): for layer in self.layers: x = layer(x, freqs_cis) # (bsz, seqlen, dim) out = self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) - if targets is not None: loss = F.cross_entropy( out.view(-1, out.size(-1)), 
targets.view(-1), ignore_index=-100 @@ -260,43 +232,40 @@ def predict(self, x, k=1): Returns: Tuple of (input_ids, predicted_ids) """ - # Determine device type for autocast - device_type = 'cuda' if x.is_cuda else 'cpu' - with torch.autocast(device_type=device_type, dtype=self.cfg.compute_dtype): - # Store original input - original_input = x.clone() - generated_input = x.clone() + # Store original input + original_input = x.clone() + generated_input = x.clone() - # Generate k tokens autoregressively - for i in range(k): - # Get logits for the entire sequence - logits = self(generated_input) + # Generate k tokens autoregressively + for i in range(k): + # Get logits for the entire sequence + logits = self(generated_input) - # Get the logits for the last token in each sequence - next_token_logits = logits[:, -1, :] + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] - # Zero out the last token ID to prevent repetition - # This is a common issue - the model gets stuck repeating the last token - last_token_id = generated_input[:, -1] - next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) + # Zero out the last token ID to prevent repetition + # This is a common issue - the model gets stuck repeating the last token + last_token_id = generated_input[:, -1] + next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) - # Get the most likely token - next_token = torch.argmax(next_token_logits, dim=-1) + # Get the most likely token + next_token = torch.argmax(next_token_logits, dim=-1) - # Append the predicted token to the sequence - next_token = next_token.unsqueeze(1) # Add sequence dimension - generated_input = torch.cat([generated_input, next_token], dim=1) + # Append the predicted token to the sequence + next_token = next_token.unsqueeze(1) # Add sequence dimension + generated_input = torch.cat([generated_input, next_token], dim=1) - # For debugging, print predictions for the first item in the batch - print('\nPyTorch detailed prediction (first item in batch):') - predicted_sequence = generated_input[0, -k:].tolist() - print(f' Predicted token IDs: {predicted_sequence}') - for i, token_id in enumerate(predicted_sequence): - print(f' Step {i + 1}: Predicted token {token_id}') + # For debugging, print predictions for the first item in the batch + print('\nPyTorch detailed prediction (first item in batch):') + predicted_sequence = generated_input[0, -k:].tolist() + print(f' Predicted token IDs: {predicted_sequence}') + for i, token_id in enumerate(predicted_sequence): + print(f' Step {i + 1}: Predicted token {token_id}') - # Return all tokens, not just the last k - return original_input, generated_input[:, -k:] + # Return all tokens, not just the last k + return original_input, generated_input[:, -k:] def _init_weights(self, module): if isinstance(module, nn.Linear): @@ -349,8 +318,6 @@ def main(): # Instantiate the model model = Transformer(config) print(f'Model has {model.count_params():,} parameters.') - for n, p in model.named_parameters(): - print(f'{n}.dtype == {p.dtype}') # Create some random input data batch_size = 2 @@ -363,7 +330,6 @@ def main(): # Run a forward pass print(f'Running forward pass with input shape: {input_ids.shape}') logits = model(input_ids) - print(f'Output logits dtype: {logits.dtype}') print(f'Output logits shape: {logits.shape}') # Run prediction diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py 
index ed922f9c2..a25ca334a 100644 --- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py @@ -19,25 +19,10 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() -# Dtype mapping from string to PyTorch dtype -DTYPE_MAP = { - 'float32': torch.float32, - 'float16': torch.float16, - 'bfloat16': torch.bfloat16, -} - class LmWorkload(BaseLmWorkload): """LM PyTorch workload.""" - @property - def _compute_dtype(self) -> torch.dtype: - return DTYPE_MAP[self._compute_dtype_str] - - @property - def _param_dtype(self) -> torch.dtype: - return DTYPE_MAP[self._param_dtype_str] - def init_model_fn( self, rng: spec.RandomState, @@ -55,14 +40,11 @@ def init_model_fn( vocab_size=self._vocab_size, seq_len=self._seq_len, model_dim=self._emb_dim, # Model dimension - expanded_model_dim=self._mlp_dim, # MLP expanded dim - num_layers=self._n_layers, - num_heads=self._n_heads, - rmsnorm_epsilon=self._rmsnorm_epsilon, - qknorm_epsilon=self._qknorm_epsilon, - tie_embeddings=self._tie_embeddings, - compute_dtype=self._compute_dtype, - param_dtype=self._param_dtype, + expanded_model_dim=self._mlp_dim, # MLP expansion factor + num_layers=self._n_layers, # Number of transformer layers + num_heads=self._n_heads, # Number of attention heads + rmsnorm_epsilon=1e-6, + tie_embeddings=True, ) self._model = Transformer(cfg) self._param_shapes = param_utils.pytorch_param_shapes(self._model) @@ -99,18 +81,13 @@ def model_fn( spec.ForwardPassMode.EVAL: torch.no_grad, spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } - - # Determine device type for autocast - device_type = 'cuda' if DEVICE.type == 'cuda' else 'cpu' - with contexts[mode](): - with torch.autocast(device_type=device_type, dtype=self._compute_dtype): - # Convert one-hot inputs to token IDs if needed - inputs = augmented_and_preprocessed_input_batch['inputs'] - if inputs.dim() == 3: # one-hot encoded - inputs = inputs.argmax(dim=-1) + # Convert one-hot inputs to token IDs if needed + inputs = augmented_and_preprocessed_input_batch['inputs'] + if inputs.dim() == 3: # one-hot encoded + inputs = inputs.argmax(dim=-1) - logits = model(inputs) + logits = model(inputs) return logits, None @@ -144,7 +121,7 @@ def _build_input_queue( batch['targets'], device=DEVICE, dtype=torch.int64 ), 'weights': torch.tensor( - batch['weights'], device=DEVICE, dtype=self._param_dtype + batch['weights'], device=DEVICE, dtype=torch.float32 ) if batch['weights'] is not None else None, @@ -180,35 +157,29 @@ def loss_fn( - 'n_valid_examples': Scalar tensor with the count of valid (non-masked) examples. - 'per_example': Tensor of shape [batch, length] with individual losses per example. 
""" - # Determine device type for autocast - device_type = 'cuda' if logits_batch.is_cuda else 'cpu' - - with torch.autocast(device_type=device_type, dtype=self._compute_dtype): - vocab_size = logits_batch.size(-1) - - # Compute cross-entropy loss with label smoothing - per_example_losses = torch.nn.functional.cross_entropy( - logits_batch.view(-1, vocab_size), - label_batch.view(-1), - reduction='none', - label_smoothing=label_smoothing, - ) - per_example_losses = per_example_losses.view_as(label_batch) - - # Apply weights if provided - if mask_batch is not None: - per_example_losses = per_example_losses * mask_batch - - # Calculate number of valid examples - n_valid_examples = ( - mask_batch.sum() - if mask_batch is not None - else torch.tensor( - label_batch.numel(), - dtype=self._param_dtype, - device=label_batch.device, - ) + vocab_size = logits_batch.size(-1) + + # Compute cross-entropy loss with label smoothing + per_example_losses = torch.nn.functional.cross_entropy( + logits_batch.view(-1, vocab_size), + label_batch.view(-1), + reduction='none', + label_smoothing=label_smoothing, + ) + per_example_losses = per_example_losses.view_as(label_batch) + + # Apply weights if provided + if mask_batch is not None: + per_example_losses = per_example_losses * mask_batch + + # Calculate number of valid examples + n_valid_examples = ( + mask_batch.sum() + if mask_batch is not None + else torch.tensor( + label_batch.numel(), dtype=torch.float32, device=label_batch.device ) + ) return { 'summed': per_example_losses.sum(), diff --git a/algoperf/workloads/finewebedu_lm/workload.py b/algoperf/workloads/finewebedu_lm/workload.py index 3abb9c138..e6e2e9ba5 100644 --- a/algoperf/workloads/finewebedu_lm/workload.py +++ b/algoperf/workloads/finewebedu_lm/workload.py @@ -27,16 +27,6 @@ class BaseLmWorkload(spec.Workload): _mlp_dim: int = 4096 warmup_factor: float = 0.1 - # Model configuration - _rmsnorm_epsilon: float = 1e-6 - _qknorm_epsilon: float = 1e-6 - _tie_embeddings: bool = True - - # Dtype configuration (as strings, to be converted by framework-specific subclasses) - _compute_dtype_str: str = 'bfloat16' - _param_dtype_str: str = 'float32' - _output_dtype_str: str = 'bfloat16' # Only used by JAX - def __init__(self) -> None: super().__init__() self._param_shapes = None @@ -95,11 +85,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 31_967 # 8.9 hours + return 31_967 # 8.9 hours @property def eval_period_time_sec(self) -> int: - return 2_571 # approximately 25 evals + return 2_571 # approximately 25 evals @property def step_hint(self) -> int: @@ -174,9 +164,9 @@ def _eval_model_on_split( eval_batch = next(self._eval_iters[split]) metrics = self._eval_batch(params, eval_batch, model_state, rng) for metric_name, metric_value in metrics.items(): - eval_metrics.update( - {metric_name: eval_metrics.get(metric_name, 0.0) + metric_value} - ) + if metric_name not in eval_metrics: + eval_metrics[metric_name] = 0.0 + eval_metrics[metric_name] += metric_value eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) eval_results['ppl'] = np.exp(eval_results['loss']).item() diff --git a/pyproject.toml b/pyproject.toml index e3d86df3d..006e7e5cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"] -lm = ["transformers==4.26.0", "datasets==3.6.0"] +lm = ["transformers==4.26", "datasets==3.6.0"] # Frameworks jax_core_deps = [ @@ -99,7 +99,6 @@ 
jax_core_deps = [ "chex==0.1.86", "ml_dtypes==0.5.1", "protobuf==4.25.5", - "jmp>=0.0.4" ] jax_cpu = [ "jax==0.7.0", From 5f733e16625840fba5e00838ce68d0da748aa958 Mon Sep 17 00:00:00 2001 From: rka97 Date: Tue, 27 Jan 2026 19:48:38 +0000 Subject: [PATCH 92/98] fix formatting errors --- algoperf/pytorch_utils.py | 1 - algoperf/workloads/finewebedu_lm/workload.py | 4 ++-- scoring/score_submissions.py | 12 ++++++++---- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index b00bb6ea3..08dda7de2 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -27,7 +27,6 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: # torch.backends.cudnn.fp32_precision = "ieee" # torch.backends.cudnn.conv.fp32_precision = "tf32" # torch.backends.cudnn.rnn.fp32_precision = "tf32" - use_pytorch_ddp = 'LOCAL_RANK' in os.environ rank = int(os.environ['LOCAL_RANK']) if use_pytorch_ddp else 0 diff --git a/algoperf/workloads/finewebedu_lm/workload.py b/algoperf/workloads/finewebedu_lm/workload.py index e6e2e9ba5..59f70380f 100644 --- a/algoperf/workloads/finewebedu_lm/workload.py +++ b/algoperf/workloads/finewebedu_lm/workload.py @@ -85,11 +85,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 31_967 # 8.9 hours + return 31_967 # 8.9 hours @property def eval_period_time_sec(self) -> int: - return 2_571 # approximately 25 evals + return 2_571 # approximately 25 evals @property def step_hint(self) -> int: diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 55d824dd4..5cb7d25e7 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -70,7 +70,7 @@ flags.DEFINE_string( 'include_submissions', '', - 'Optional comma seperated list of names of submissions to include from scoring.' 
+ 'Optional comma seperated list of names of submissions to include from scoring.', ) FLAGS = flags.FLAGS @@ -127,11 +127,13 @@ def get_summary_df(workload, workload_df, include_test_split=False): # compute the step times def delta(series): return series.shift(1, fill_value=0) - series + accumulated_time_intervals = delta(workload_df['accumulated_submission_time']) step_intervals = delta(workload_df['global_step']) - summary_df['step_time (s)'] = np.median((accumulated_time_intervals / step_intervals).iloc[0]) - + summary_df['step_time (s)'] = np.median( + (accumulated_time_intervals / step_intervals).iloc[0] + ) summary_df['step_hint'] = scoring_utils.get_workload_stephint(workload) @@ -223,7 +225,9 @@ def main(_): for submission in all_submission_dirs: print(submission) - if submission not in FLAGS.exclude_submissions.split(',') and (submission in include_submissions): + if submission not in FLAGS.exclude_submissions.split(',') and ( + submission in include_submissions + ): experiment_path = os.path.join(FLAGS.submission_directory, submission) df = scoring_utils.get_experiment_df(experiment_path) results[submission] = df From 5ac8fa624b5e44e05105677c42fd39be4f3b87a2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 Jan 2026 03:27:29 +0000 Subject: [PATCH 93/98] logging warning in utils script --- scoring/score_submissions.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 55d824dd4..7b87d5b2f 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -76,6 +76,7 @@ def get_summary_df(workload, workload_df, include_test_split=False): + print(f" WORKLOAD: {workload}") validation_metric, validation_target = ( scoring_utils.get_workload_metrics_and_targets(workload, split='validation') ) @@ -127,12 +128,12 @@ def get_summary_df(workload, workload_df, include_test_split=False): # compute the step times def delta(series): return series.shift(1, fill_value=0) - series - accumulated_time_intervals = delta(workload_df['accumulated_submission_time']) - step_intervals = delta(workload_df['global_step']) + accumulated_time_intervals = delta(workload_df['accumulated_submission_time']) # exclude first step + step_intervals = delta(workload_df['global_step']) # exclude time up to first step + if len(accumulated_time_intervals) < 2: + print(f"WARNING: The number of evals may be too low to calculate reliable step time for {workload}") summary_df['step_time (s)'] = np.median((accumulated_time_intervals / step_intervals).iloc[0]) - - summary_df['step_hint'] = scoring_utils.get_workload_stephint(workload) # test metrics From 9ec8a9102d79195a0cafae147ddeb411e191d594 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 29 Jan 2026 18:08:06 +0000 Subject: [PATCH 94/98] format --- scoring/score_submissions.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 0aad76606..efe276a33 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -76,7 +76,7 @@ def get_summary_df(workload, workload_df, include_test_split=False): - print(f" WORKLOAD: {workload}") + print(f' WORKLOAD: {workload}') validation_metric, validation_target = ( scoring_utils.get_workload_metrics_and_targets(workload, split='validation') ) @@ -128,13 +128,17 @@ def get_summary_df(workload, workload_df, include_test_split=False): # compute the step times def delta(series): return series.shift(1, fill_value=0) - series 
+ accumulated_time_intervals = delta(workload_df['accumulated_submission_time']) step_intervals = delta(workload_df['global_step']) if len(accumulated_time_intervals) < 2: - print(f"WARNING: The number of evals may be too low to calculate reliable step time for {workload}") - - summary_df['step_time (s)'] = np.median((accumulated_time_intervals / step_intervals).iloc[0]) + print( + f'WARNING: The number of evals may be too low to calculate reliable step time for {workload}' + ) + summary_df['step_time (s)'] = np.median( + (accumulated_time_intervals / step_intervals).iloc[0] + ) summary_df['step_hint'] = scoring_utils.get_workload_stephint(workload) From 4dc8e868acd316d8428f65e904bd86d28199f51b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 Jan 2026 22:09:00 +0000 Subject: [PATCH 95/98] use cache --- algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 4a6ffe27a..cd476e37f 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -220,7 +220,7 @@ def _build_dataset( ) folder = 'train' if 'train' in split else 'val' - dataset = ImageFolder( + dataset = CachedImageFolder( os.path.join(data_dir, folder), transform=transform_config, cache_file='.imagenet_{}_cache_index.json'.format(split), From b2725c25140cf39e9faf1eea78147f09c809a7a2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 Jan 2026 22:43:29 +0000 Subject: [PATCH 96/98] remove debugging statements --- algoperf/pytorch_utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index 08dda7de2..706a4fffd 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -21,12 +21,6 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: torch.set_float32_matmul_precision('high') - # PyTorch set TF32 - # torch.backends.fp32_precision = "ieee" - # torch.backends.cuda.matmul.fp32_precision = "tf32" - # torch.backends.cudnn.fp32_precision = "ieee" - # torch.backends.cudnn.conv.fp32_precision = "tf32" - # torch.backends.cudnn.rnn.fp32_precision = "tf32" use_pytorch_ddp = 'LOCAL_RANK' in os.environ rank = int(os.environ['LOCAL_RANK']) if use_pytorch_ddp else 0 From 3414aea06fa0f329e4d0119c561b9454fd032db7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 Jan 2026 22:56:29 +0000 Subject: [PATCH 97/98] update documentation --- dataset/README.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/dataset/README.md b/dataset/README.md index 1bfd9bf73..221637e64 100644 --- a/dataset/README.md +++ b/dataset/README.md @@ -24,7 +24,7 @@ This document provides instructions on downloading and preparing all datasets ut *TL;DR to download and prepare a dataset, run `dataset_setup.py`:* ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir=~/data \ -- -- @@ -88,7 +88,7 @@ By default, a user will be prompted before any files are deleted. 
If you do not From `algorithmic-efficiency` run: ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --ogbg ``` @@ -124,7 +124,7 @@ In total, it should contain 13 files (via `find -type f | wc -l`) for a total of From `algorithmic-efficiency` run: ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --wmt ``` @@ -194,7 +194,7 @@ you should get an email containing the URLS for "knee_singlecoil_train", "knee_singlecoil_val" and "knee_singlecoil_test". ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --fastmri \ --fastmri_knee_singlecoil_train_url '' \ @@ -229,13 +229,13 @@ In total, it should contain 1280 files (via `find -type f | wc -l`) for a total Register on and follow directions to obtain the URLS for the ILSVRC2012 train and validation images. -The script will additionally automatically download the `matched-frequency` version of [ImageNet v2](https://www.tensorflow.org/datasets/catalog/imagenet_v2#imagenet_v2matched-frequency_default_config), which is used as the test set of the ImageNet workloads. +The script will additionally automatically download the `matched-frequency` version of [ImageNet v2](https://www.tensorflow.org/dataset/catalog/imagenet_v2#imagenet_v2matched-frequency_default_config), which is used as the test set of the ImageNet workloads. The ImageNet data pipeline differs between the PyTorch and JAX workloads. Therefore, you will have to specify the framework (either `pytorch` or `jax`) through the framework flag. ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --imagenet \ --temp_dir $DATA_DIR/tmp \ @@ -349,7 +349,7 @@ In total, it should contain 20 files (via `find -type f | wc -l`) for a total of ### Criteo1TB ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --temp_dir $DATA_DIR/tmp \ --criteo1tb @@ -378,7 +378,7 @@ In total, it should contain 885 files (via `find -type f | wc -l`) for a total o To download, train a tokenizer and preprocess the librispeech dataset: ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --temp_dir $DATA_DIR/tmp \ --librispeech @@ -458,7 +458,7 @@ python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_v From `algorithmic-efficiency` run: ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --temp_dir $DATA_DIR/tmp \ --fineweb_edu From 5d4cee91d640e8fe68c29764d804329d31dbbf99 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 Jan 2026 22:57:22 +0000 Subject: [PATCH 98/98] revert changes to docker build command --- docker/build_docker_images.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 22590b9fd..aa94222ea 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -45,10 +45,10 @@ do echo "On branch: ${GIT_BRANCH}" echo $DOCKER_BUILD_COMMAND eval $DOCKER_BUILD_COMMAND - # echo $DOCKER_TAG_COMMAND - # eval $DOCKER_TAG_COMMAND - # echo $DOCKER_PUSH_COMMAND - # eval $DOCKER_PUSH_COMMAND - # echo "To pull container run: " - # echo $DOCKER_PULL_COMMAND + echo $DOCKER_TAG_COMMAND + eval $DOCKER_TAG_COMMAND + echo $DOCKER_PUSH_COMMAND + eval $DOCKER_PUSH_COMMAND + echo "To pull container run: " + echo $DOCKER_PULL_COMMAND 
done
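
Editor's note on the step-time metric that patches 92–94 adjust in `scoring/score_submissions.py`: the scoring code estimates wall-clock seconds per optimization step from the per-eval measurement log, which records `accumulated_submission_time` and `global_step` at each evaluation, and patch 93 adds a warning when there are too few evals to make that estimate reliable. The sketch below is illustrative only — the toy DataFrame, its values, and the exact aggregation are assumptions for clarity and do not reproduce the indexing used in the patch itself.

```python
import numpy as np
import pandas as pd

# Hypothetical measurement log: one row per evaluation (column names mirror
# the ones referenced in the scoring patch; the values are made up).
workload_df = pd.DataFrame({
    'accumulated_submission_time': [120.0, 245.0, 372.0, 495.0],  # seconds
    'global_step': [1000, 2000, 3000, 4000],
})

# Per-interval differences between consecutive evals; the first interval is
# measured from the start of training (time 0, step 0), so NaNs from diff()
# are filled with the first row's absolute values.
time_intervals = workload_df['accumulated_submission_time'].diff().fillna(
    workload_df['accumulated_submission_time']
)
step_intervals = workload_df['global_step'].diff().fillna(
    workload_df['global_step']
)

if len(workload_df) < 2:
    # Mirrors the spirit of the warning added in patch 93: with a single eval
    # there is only one interval, so the estimate can be unreliable.
    print('WARNING: too few evals to estimate step time reliably')

# Median seconds per step across eval intervals (~0.124 s/step here).
step_time_s = np.median(time_intervals / step_intervals)
print(f'estimated step time: {step_time_s:.4f} s/step')
```

Taking a median over per-interval ratios, as sketched here, damps the effect of a single slow interval (e.g. one that includes checkpointing or compilation); whether that matches the repository's final aggregation should be checked against `score_submissions.py` itself.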