diff --git a/.gitignore b/.gitignore index 7d35f0ccc..916a29ff4 100644 --- a/.gitignore +++ b/.gitignore @@ -25,4 +25,4 @@ scoring/plots/ !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv -algoperf/_version.py +algoperf/_version.py \ No newline at end of file diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index 2c8441d9c..af05111cd 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -5,14 +5,16 @@ """ import os -from typing import Sequence, Tuple +from typing import Optional, Sequence, Tuple import numpy as np +import orbax.checkpoint as ocp import torch from absl import logging from flax import jax_utils from flax.training import checkpoints as flax_checkpoints from flax.training.checkpoints import latest_checkpoint +from orbax.checkpoint.type_handlers import NumpyHandler from tensorflow.io import gfile # pytype: disable=import-error from algoperf import spec @@ -30,6 +32,51 @@ ] +class BoolHandler(NumpyHandler): + """ + An implementation of TypeHandler for np.bool_ that inherits from NumpyHandler. + It works by treating the scalar as a 0-dimensional array. + """ + + def typestr(self) -> str: + """Unique string identifier for this handler.""" + return 'np.bool_' + + async def serialize( + self, + values: Sequence[np.bool_], + infos: Sequence, + args: Optional[Sequence[ocp.SaveArgs]] = None, + ): + """ + Serializes a sequence of np.bool_ scalars by first converting them + to 0-dim numpy arrays and then calling the parent NumpyHandler. + """ + # Convert each scalar np.bool_ to a 0-dimensional np.ndarray + array_values = [np.asarray(v, dtype=np.bool_) for v in values] + # Use the parent class's robust serialization logic + return await super().serialize(array_values, infos, args) + + async def deserialize( + self, + infos: Sequence, + args: Optional[Sequence[ocp.RestoreArgs]] = None, + ) -> Sequence[np.bool_]: + """ + Deserializes into a sequence of np.bool_ scalars by calling the + parent handler and then converting the resulting 0-dim arrays. + """ + # Parent deserialize will return a sequence of 0-dimensional np.ndarray + results = await super().deserialize(infos, args) + + # Convert each 0-d array back to an np.bool_ scalar using .item() + scalar_results = [np.bool_(r.item()) for r in results] + return scalar_results + + +ocp.type_handlers.register_type_handler(np.bool_, BoolHandler(), override=True) + + def maybe_restore_checkpoint( framework: str, optimizer_state: spec.OptimizerState, diff --git a/algoperf/param_utils.py b/algoperf/param_utils.py index 908ef0f27..26a351bb4 100644 --- a/algoperf/param_utils.py +++ b/algoperf/param_utils.py @@ -44,6 +44,8 @@ def pytorch_param_types( param_types[name] = spec.ParameterType.ATTENTION_BIAS elif 'in_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_QKV + elif 'qkv' in name: + param_types[name] = spec.ParameterType.ATTENTION_QKV elif 'kv_proj' in name: param_types[name] = spec.ParameterType.ATTENTION_KV elif 'k_proj' in name or 'key' in name: diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index 937001b87..706a4fffd 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -21,6 +21,7 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: torch.set_float32_matmul_precision('high') + use_pytorch_ddp = 'LOCAL_RANK' in os.environ rank = int(os.environ['LOCAL_RANK']) if use_pytorch_ddp else 0 device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu') @@ -28,7 +29,9 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: return use_pytorch_ddp, rank, device, n_gpus -def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: +def pytorch_init( + use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_threads=True +) -> None: # Make sure no GPU memory is preallocated to Jax. os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # Only use CPU for Jax to avoid memory issues. @@ -40,7 +43,7 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: if use_pytorch_ddp: # Avoid tf input pipeline creating too many threads. - if rank != 0: + if rank != 0 and limit_tf_threads: tf.config.threading.set_intra_op_parallelism_threads(1) tf.config.threading.set_inter_op_parallelism_threads(1) diff --git a/algoperf/random_utils.py b/algoperf/random_utils.py index 1dc773e80..07efa2bdf 100644 --- a/algoperf/random_utils.py +++ b/algoperf/random_utils.py @@ -35,13 +35,13 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.uint32) return [new_seed, data] def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) + return rng.randint(MIN_INT32, MAX_INT32, dtype=np.uint32, size=[num, 2]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name diff --git a/algoperf/workloads/finewebedu_lm/__init__.py b/algoperf/workloads/finewebedu_lm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/__init__.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py new file mode 100644 index 000000000..d08e9b7bf --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py @@ -0,0 +1,397 @@ +""" +Originally based on code from the NanoDO repository under the Apache 2.0 license: +https://github.com/google-deepmind/nanodo +""" + +import dataclasses +from functools import partial + +import jax +import jax.numpy as jnp +from flax import linen as nn + + +@dataclasses.dataclass +class ModelConfig: + """Hyper-parameters for Transformer decoder-only.""" + + model_dim: int # model/embed dim = qkv dim + num_heads: int # num attention heads + seq_len: int # max context/sequence length + num_layers: int # number of transformer block layers + vocab_size: int # vocab size + expanded_model_dim: int # FF inner dimension + multiple_of: int = 256 + rmsnorm_epsilon: float = 1e-6 + use_residual_scaling: bool = True + tie_embeddings: bool = True # Whether to tie input and output embed + qknorm_epsilon: float = 1e-6 + + dtype: jnp.dtype = jnp.float32 + attention_init: nn.initializers.Initializer = nn.initializers.normal( + stddev=0.02 + ) + linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) + embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) + + def __post_init__(self): + self.residual_init = nn.initializers.normal( + stddev=0.02 / jnp.sqrt(2 * self.num_layers) + ) + + +class Mlp(nn.Module): + """Multilayer perceptron with GLU activation.""" + + cfg: ModelConfig + + @nn.compact + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + linear = partial( + nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype + ) + # Adjust hidden dimension to keep the number of parameters invariant to + # the activation function used since the GLU MLP has 3 * hidden_dim * D + # parameters instead of 2 * hidden_dim * D parameters + hidden_dim = cfg.expanded_model_dim * 2 / 3 + hidden_dim = cfg.multiple_of * ( + (cfg.expanded_model_dim + cfg.multiple_of - 1) // cfg.multiple_of + ) + # Double the hidden dimension for GLU + x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) + # Apply GLU activation + x_BxLxF = nn.glu(x_BxLx2F, axis=-1) + x_BxLxD = nn.Dense( + cfg.model_dim, + use_bias=False, + dtype=cfg.dtype, + kernel_init=cfg.residual_init + if cfg.use_residual_scaling + else cfg.linear_init, + )(x_BxLxF) + return x_BxLxD + + +@partial(jax.jit, static_argnums=(0, 1, 2)) +def init_rope(dim=256, seq_len=128, n_heads=4): + """Initialize rotary embeddings.""" + + def precompute_freqs_cis_jax(dim, end, theta=10000.0): + inv_freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim)) + t = jnp.arange(end) / 1.0 + freqs = jnp.outer(t, inv_freqs).astype(jnp.float32) + return jnp.stack( + [jnp.cos(freqs)[None, :, None, :], jnp.sin(freqs)[None, :, None, :]], + axis=3, + ) + + freqs_cis = precompute_freqs_cis_jax(dim // n_heads, seq_len, theta=500000) + return freqs_cis.transpose(0, 1, 2, 4, 3) + + +@jax.jit +def apply_rope(q, k, freqs_cis): + """Apply rotary embeddings to Q and K.""" + + def rotate_tensor(x): + # Split into real and imaginary parts + x_r2 = x.reshape(*x.shape[:-1], -1, 2) + L = x.shape[1] + freqs = freqs_cis[:, :L, :, :, :] + + # Apply rotation + rotated_x_r2 = jnp.stack( + [ + x_r2[..., 0] * freqs[..., 0] - x_r2[..., 1] * freqs[..., 1], + x_r2[..., 1] * freqs[..., 0] + x_r2[..., 0] * freqs[..., 1], + ], + axis=-1, + ) + + return rotated_x_r2.reshape(*x.shape) + + # Apply rotation to Q and K separately + rotated_q = rotate_tensor(q) + rotated_k = rotate_tensor(k) + + return rotated_q, rotated_k + + +class CausalAttn(nn.Module): + """Causal attention layer with rotary embeddings.""" + + cfg: ModelConfig + + def setup(self): + cfg = self.cfg + assert cfg.model_dim % cfg.num_heads == 0, ( + f'D {cfg.model_dim} not divisible by H {cfg.num_heads}' + ) + self.Dh = cfg.model_dim // cfg.num_heads + self.eps = cfg.qknorm_epsilon + + # Initialize rotary embeddings + self.freqs_cis = init_rope(cfg.model_dim, cfg.seq_len, cfg.num_heads) + + # Maps D -> (H, Dh) + self.multilinear = partial( + nn.DenseGeneral, + axis=-1, + features=(cfg.num_heads, self.Dh), + kernel_init=cfg.attention_init, + use_bias=False, + dtype=cfg.dtype, + ) + self.multilinear_query = self.multilinear(name='query') + self.multilinear_key = self.multilinear(name='key') + self.multilinear_value = self.multilinear(name='value') + # See Henry et al. (2020) "Query Key Normalization for Transformers" + seq_len = cfg.seq_len + attn_scale0 = jnp.log2(seq_len**2 - seq_len) + self.attn_scale = self.param( + 'attn_scale', nn.initializers.constant(attn_scale0), () + ) + self.output_projection = nn.DenseGeneral( + features=cfg.model_dim, + name='attn_out_proj', + # axis=(-2, -1), # + kernel_init=cfg.residual_init + if cfg.use_residual_scaling + else cfg.linear_init, + use_bias=False, + dtype=cfg.dtype, + ) + + def __call__(self, x_BxLxD: jax.Array): + cfg = self.cfg + + # Project inputs to Q, K, V + q_BxLxHxDh = self.multilinear_query(x_BxLxD) + k_BxLxHxDh = self.multilinear_key(x_BxLxD) + v_BxLxHxDh = self.multilinear_value(x_BxLxD) + + # Apply rotary embeddings to Q and K + q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) + + # Apply QK normalization + q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps + k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps + + # Compute attention scores + att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh) + + # Causal attention mask + L = x_BxLxD.shape[1] + mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) + + # Apply mask and softmax + _NEG_INF = jnp.finfo(cfg.dtype).min + att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) + att_BxHxLxL = ( + self.attn_scale * att_BxHxLxL + ) # Learned scaling factor for QK norm + att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) + att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) + + # Compute attention output + out_BxLxHxDh = jnp.einsum('...hqk,...khd->...qhd', att_BxHxLxL, v_BxLxHxDh) + + # Reshape and project output + out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) + + # Output projection + out_BxLxD = self.output_projection(out_BxLxD) + + return out_BxLxD + + +class TBlock(nn.Module): + """Transformer Block.""" + + docfg: ModelConfig + + @nn.compact + def __call__(self, in_BxLxD: jax.Array): + cfg = self.docfg + + # x = x + attn( attn_norm(x) ) + x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + in_BxLxD + ) + x_BxLxD = CausalAttn(cfg)(x_BxLxD) + x_BxLxD += in_BxLxD + + # x = x + mlp( mlp_norm(x) ) + z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( + x_BxLxD + ) + z_BxLxD = Mlp(cfg)(z_BxLxD) + + return x_BxLxD + z_BxLxD + + +class TransformerDo(nn.Module): + """Transformer decoder-only.""" + + docfg: ModelConfig + + def setup(self): + cfg = self.docfg + self.embed = nn.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.model_dim, + embedding_init=cfg.embed_init, + ) + + self.blocks = [TBlock(cfg) for _ in range(cfg.num_layers)] + self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) + + # Output projection - tied to input embeddings if configured + if cfg.tie_embeddings: + self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) + else: + self.output_proj = nn.Dense( + cfg.vocab_size, + kernel_init=cfg.embed_init, + dtype=cfg.dtype, + name='output_proj', + ) + + def __call__(self, y_BxL: jax.Array): + # For training on concatenated examples. + y_BxLxD = self.embed(y_BxL) + for block in self.blocks: + y_BxLxD = block(y_BxLxD) + y_BxLxD = self.out_ln(y_BxLxD) + logits_BxLxV = self.output_proj(y_BxLxD) + return logits_BxLxV + + def predict(self, y_BxL: jax.Array, k: int = 1): + """Generate k tokens autoregressively. + + Args: + y_BxL: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + cfg = self.docfg + seq_len = y_BxL.shape[1] + + # Store original input + original_input = y_BxL + + # Make sure we don't exceed the model's context length + if seq_len + k > cfg.seq_len: + raise ValueError( + f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.seq_len})" + ) + + # Generate k tokens autoregressively + for _ in range(k): + # Get logits for the entire sequence + logits = self(y_BxL) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + last_token_id = y_BxL[:, -1] + # Prevent predicting the same token consecutively + next_token_logits = next_token_logits.at[ + jnp.arange(len(last_token_id)), last_token_id + ].set(float('-inf')) + + # Get the most likely token + next_token = jnp.argmax(next_token_logits, axis=-1) + + # Append the predicted token to the sequence + y_BxL = jnp.concatenate([y_BxL, next_token[:, None]], axis=1) + + # Return original input and the k predicted tokens + return original_input, y_BxL[:, -k:] + + +# =========== Demo Code ========== + + +def main(): + """Create and run the DecoderOnly Transformer model.""" + # Initialize model configuration with smaller parameters for demo + B, L = (2, 128) # Batch size, sequence length + cfg = ModelConfig( + model_dim=128, + num_heads=4, + seq_len=L, + num_layers=2, + vocab_size=256, + expanded_model_dim=4 * 128, + ) + model = TransformerDo(cfg) + + # Print model info + print('\nModel Configuration:') + print(f' - Model dimension (D): {cfg.model_dim}') + print(f' - Number of heads (H): {cfg.num_heads}') + print(f' - Max sequence length (L): {cfg.seq_len}') + print(f' - Number of layers (N): {cfg.num_layers}') + print(f' - Vocabulary size (V): {cfg.vocab_size}') + print(f' - Feed forward dimension (F): {cfg.expanded_model_dim}') + + # Create random input tokens (simulated token IDs) + rng_key = jax.random.PRNGKey(42) + input_rng, init_rng = jax.random.split(rng_key) + + # Generate random token IDs (integers between 0 and vocab_size-1) + x_BxL = jax.random.randint( + input_rng, shape=(B, L), minval=0, maxval=cfg.vocab_size, dtype=jnp.int32 + ) + + # Initialize model parameters + print('\nInitializing model parameters...') + params = model.init(init_rng, x_BxL) + + # Print parameter count + param_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) + print(f'Total parameters: {param_count:,}') + + # Make a prediction (forward pass) + print('\nRunning forward pass...') + logits = model.apply(params, x_BxL) + + # Print output shape and sample values + print( + f'\nOutput shape: {logits.shape} (batch_size, sequence_length, vocab_size)' + ) + print(f'Output data type: {logits.dtype}') + + # Print sample logits (first 5 positions of the first sequence) + print('\nSample logits (first sequence, first 5 positions, first 5 values):') + for position in range(min(5, L)): + print(f' Position {position}: {logits[0, position, :5]}') + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print('\nPredicted token IDs (first sequence, first 10 positions):') + print(predictions[0, :10]) + + # Test the predict function + print('\nTesting predict function...') + # Use a shorter + short_seq = x_BxL[:, :10] + print(f'Input sequence shape: {short_seq.shape}') + + # Predict 5 tokens + k = 5 + original, predicted = model.apply(params, short_seq, k, method=model.predict) + + # Get predictions (token with highest logit at each position) + predictions = jnp.argmax(logits, axis=-1) + print('\nPredicted token IDs (first sequence, first 10 positions):') + print(predictions[0, :10]) + + print('\nDone!') + + +if __name__ == '__main__': + main() diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py new file mode 100644 index 000000000..ee4cffbbc --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py @@ -0,0 +1,169 @@ +"""LM workload implemented in Jax.""" + +from typing import Any, Dict, Optional, Tuple + +import jax +import jax.numpy as jnp + +from algoperf import jax_sharding_utils, param_utils, spec +from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( + ModelConfig, + TransformerDo, +) +from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter +from algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload + + +class LmWorkload(BaseLmWorkload): + """LM JAX workload.""" + + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ): + """Build an input queue using pre-cached FineWeb dataset.""" + del cache, repeat_final_dataset + ds = get_data_iter( + data_rng=data_rng, + split=split, + data_dir=data_dir, + batch_size=global_batch_size, + num_batches=num_batches, + ) + ds = map(jax_sharding_utils.shard_along_batch_dim, ds) + return ds + + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: + # Initialize NanoDO transformer model + cfg = ModelConfig( + model_dim=self._emb_dim, # embedding dim + num_heads=self._n_heads, # num heads + seq_len=self._seq_len, + num_layers=self._n_layers, # num layers + vocab_size=self._vocab_size, + expanded_model_dim=self._mlp_dim, # feedforward dim + dtype=jnp.float32, + ) + self._model = TransformerDo(cfg) + input_shape = (1, self._seq_len) # For token IDs + + params_rng, init_rng = jax.random.split(rng) + variables = jax.jit(self._model.init)( + {'params': params_rng}, jnp.ones(input_shape, jnp.int32) + ) + params = variables['params'] + self._param_shapes = param_utils.jax_param_shapes(params) + self._param_types = param_utils.jax_param_types(self._param_shapes) + params = jax_sharding_utils.replicate(params) + model_state = None + return params, model_state + + def model_fn( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = 0.0, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + del mode, rng, update_batch_norm, model_state, dropout_rate + inputs = batch['inputs'] + # Convert one-hot inputs to token IDs if needed + if inputs.ndim == 3: # one-hot encoded + inputs = jnp.argmax(inputs, axis=-1) + logits = self._model.apply({'params': params}, inputs) + return logits, None + + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable + """Compute weighted cross entropy. + + Args: + label_batch: categorical targets [batch, length] int array. + logits_batch: [batch, length, num_classes] float array. + mask_batch: weights array of shape [batch, length]. + label_smoothing: Label smoothing factor in [0, 1]. When > 0, the target + distribution becomes (1 - label_smoothing) for the correct class and + label_smoothing / vocab_size for all other classes. Default is 0.0 (no smoothing). + + Returns: + {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 2d array of per-example losses} + """ + if logits_batch.ndim != label_batch.ndim + 1: + raise ValueError( + f'Incorrect shapes. Got shape {logits_batch.shape} logits and ' + f'{label_batch.shape} targets.' + ) + # Compute log probabilities + log_probs = jax.nn.log_softmax(logits_batch, axis=-1) + # Extract log probability of the target class + # Shape: [batch, length] + target_log_probs = jnp.take_along_axis( + log_probs, label_batch[..., None], axis=-1 + ).squeeze(-1) + # Cross-entropy with smoothing: -(1 - α) * log_p[target] - α * mean(log_p) + # The above formula is easy to derive from the definition of label smoothing and cross-entropy loss. + confidence = 1.0 - label_smoothing + smoothing_term = label_smoothing / self._vocab_size + per_example_losses = -1.0 * ( + confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1) + ) + if mask_batch is not None: + per_example_losses = mask_batch * per_example_losses + n_valid_examples = mask_batch.sum() + else: + n_valid_examples = label_batch.shape[0] * label_batch.shape[1] + summed_loss = per_example_losses.sum() + return { + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, + } + + def _eval_batch( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False + ) + metrics = self.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch['weights'], + ) + return { + 'loss': metrics['summed'], + 'denominator': metrics['n_valid_examples'], + } + + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/__init__.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py new file mode 100644 index 000000000..edee8318c --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py @@ -0,0 +1,344 @@ +""" +Originally based on the plainLM codebase: +https://github.com/Niccolo-Ajroldi/plainLM +under the MIT license https://github.com/Niccolo-Ajroldi/plainLM/blob/main/LICENSE. +""" + +import math +from dataclasses import dataclass +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + + +@dataclass +class ModelConfig: + model_dim: int + num_heads: int + seq_len: int + num_layers: int + vocab_size: int + expanded_model_dim: int + multiple_of: int = 256 + rmsnorm_epsilon: float = 1e-6 + qknorm_epsilon: float = 1e-6 + use_residual_scaling: bool = True + tie_embeddings: bool = True + + +class MLP(nn.Module): + def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): + super().__init__() + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) + self.fc2 = nn.Linear(hidden_dim, dim, bias=False) + self.glu = nn.GLU(dim=2) + nn.init.normal_(self.fc1.weight, std=0.02) + nn.init.normal_(self.fc2.weight, std=0.02) + + def forward(self, x): + # x: (bsz, T, dim) + return self.fc2(self.glu(self.fc1(x))) + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, condense_ratio: int = 1 +): + inv_freqs = 1.0 / ( + theta + ** ( + torch.arange(0, dim, 2, dtype=torch.float32, device=torch.device('cpu')) + / dim + ) + ) + t = ( + torch.arange(end, dtype=torch.float32, device=inv_freqs.device) + / condense_ratio + ) + freqs = torch.outer(t, inv_freqs).float() + return torch.stack( + [torch.cos(freqs)[None, :, None, :], torch.sin(freqs)[None, :, None, :]], + dim=4, + ) + + +def apply_rotary_emb_complex_like( + q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + # Rotate query and key vectors using RoPE + qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() + rotated_qk_r2 = torch.stack( + [ + qk_r2[..., 0] * freqs_cis[..., 0] - qk_r2[..., 1] * freqs_cis[..., 1], + qk_r2[..., 1] * freqs_cis[..., 0] + qk_r2[..., 0] * freqs_cis[..., 1], + ], + -1, + ).flatten(3) + rotated_qk = rotated_qk_r2 + return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) + + +class Attention(nn.Module): + def __init__(self, cfg: ModelConfig): + super().__init__() + assert cfg.model_dim % cfg.num_heads == 0 + self.dim = cfg.model_dim + self.n_heads = cfg.num_heads + self.head_dim = cfg.model_dim // cfg.num_heads + + self.w_qkv = nn.Linear(cfg.model_dim, 3 * cfg.model_dim, bias=False) + self.w_out = nn.Linear(cfg.model_dim, cfg.model_dim, bias=False) + # Split into Q, K, V sections + wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0) + for w in [wq, wk, wv]: + nn.init.normal_(w, std=0.02) + nn.init.normal_(self.w_out.weight, std=0.02) + + self.eps = cfg.qknorm_epsilon # e.g., 1e-6 + seq_len = cfg.seq_len + attn_scale0 = math.log2(seq_len**2 - seq_len) + self.attn_scale = nn.Parameter(torch.tensor(attn_scale0)) + + def forward(self, x, freqs_cis): + bsz, seqlen, d = x.shape # (bsz, seqlen, d) + + q, k, v = self.w_qkv(x).split(d, dim=2) # (bsz, seqlen, d) + q = q.view( + bsz, seqlen, self.n_heads, self.head_dim + ) # (bsz, seqlen, nh, h_dim) + k = k.view( + bsz, seqlen, self.n_heads, self.head_dim + ) # (bsz, seqlen, nh, h_dim) + v = v.view( + bsz, seqlen, self.n_heads, self.head_dim + ) # (bsz, seqlen, nh, h_dim) + + q, k = apply_rotary_emb_complex_like( + q, k, freqs_cis=freqs_cis + ) # (bsz, seqlen, nh, h_dim) + + q = q.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + k = k.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + v = v.transpose(1, 2) # (bsz, nh, seqlen, h_dim) + + # Apply QK normalization + q = q / torch.norm(q, dim=-1, keepdim=True) + self.eps + k = k / torch.norm(k, dim=-1, keepdim=True) + self.eps + q *= self.attn_scale + + out = F.scaled_dot_product_attention( + q, k, v, is_causal=True, scale=1.0 + ) # (bsz, nh, seqlen, h_dim) + out = ( + out.transpose(1, 2).contiguous().view(bsz, seqlen, d) + ) # (bsz, seqlen, d) + + return self.w_out(out) + + +class Block(nn.Module): + def __init__(self, layer_id: int, cfg: ModelConfig): + super().__init__() + self.attn = Attention(cfg) + self.attn_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.mlp = MLP( + dim=cfg.model_dim, + hidden_dim=cfg.expanded_model_dim, + multiple_of=cfg.multiple_of, + ) + self.mlp_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.layer_id = layer_id + + def forward(self, x, freqs_cis): + # x: (bsz, seqlen, dim) + x = x + self.attn(self.attn_norm(x), freqs_cis) + x = x + self.mlp(self.mlp_norm(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, cfg: ModelConfig): + super().__init__() + self.n_layers = cfg.num_layers + self.cfg = cfg + head_dim = cfg.model_dim // cfg.num_heads + assert cfg.model_dim % cfg.num_heads == 0 + + self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.model_dim) + self.layers = nn.ModuleList( + [Block(idx, cfg) for idx in range(cfg.num_layers)] + ) + self.out_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.lm_head = nn.Linear(cfg.model_dim, cfg.vocab_size, bias=False) + + # Initialize freqs_cis on CPU first (more memory efficient) + self.register_buffer( + 'freqs_cis', + precompute_freqs_cis(head_dim, cfg.seq_len, 500000)[0 : cfg.seq_len], + persistent=False, + ) + + # init all weights, scale residual branches + self.apply(self._init_weights) + self._scale_residual_branches() + + # Move model to device (which will also move freqs_cis) + if torch.cuda.is_available(): + self.cuda() + + if cfg.tie_embeddings: + self.tie_weights() + + def forward(self, x, targets=None): + # x: (bsz, seqlen) + x = self.embed_tokens(x) # (bsz, seqlen, dim) + L = x.shape[1] + + # Make sure we have enough precomputed frequencies + if L > self.freqs_cis.shape[1]: + # Need to recompute for longer sequence + head_dim = self.cfg.model_dim // self.cfg.num_heads + new_freqs = precompute_freqs_cis( + head_dim, max(L, self.cfg.seq_len), 500000 + ) + self.register_buffer( + 'freqs_cis', new_freqs[0 : max(L, self.cfg.seq_len)], persistent=False + ) + if torch.cuda.is_available(): + self.freqs_cis = self.freqs_cis.cuda() + + # Select the frequencies for current sequence length and ensure correct device + freqs_cis = self.freqs_cis[:, :L, :].to(x.device) + + for layer in self.layers: + x = layer(x, freqs_cis) # (bsz, seqlen, dim) + out = self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + if targets is not None: + loss = F.cross_entropy( + out.view(-1, out.size(-1)), targets.view(-1), ignore_index=-100 + ) + return out, loss + return out + + def predict(self, x, k=1): + """Generate k tokens autoregressively. + + Args: + x: Input token sequence of shape (batch_size, seq_len) + k: Number of tokens to predict + + Returns: + Tuple of (input_ids, predicted_ids) + """ + + # Store original input + original_input = x.clone() + generated_input = x.clone() + + # Generate k tokens autoregressively + for i in range(k): + # Get logits for the entire sequence + logits = self(generated_input) + + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] + + # Zero out the last token ID to prevent repetition + # This is a common issue - the model gets stuck repeating the last token + last_token_id = generated_input[:, -1] + next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) + + # Get the most likely token + next_token = torch.argmax(next_token_logits, dim=-1) + + # Append the predicted token to the sequence + next_token = next_token.unsqueeze(1) # Add sequence dimension + generated_input = torch.cat([generated_input, next_token], dim=1) + + # For debugging, print predictions for the first item in the batch + print('\nPyTorch detailed prediction (first item in batch):') + predicted_sequence = generated_input[0, -k:].tolist() + print(f' Predicted token IDs: {predicted_sequence}') + for i, token_id in enumerate(predicted_sequence): + print(f' Step {i + 1}: Predicted token {token_id}') + + # Return all tokens, not just the last k + return original_input, generated_input[:, -k:] + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, std=0.02) + + def _scale_residual_branches(self): + for n, p in self.named_parameters(): + if n.endswith('fc2.weight') or n.endswith( + 'w_out.weight' + ): # mlp/glu output layer + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers) + ) + + def tie_weights(self): + self.lm_head.weight = self.embed_tokens.weight + + def count_params(self, non_embedding=True): + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.embed_tokens.weight.numel() + if ( + self.lm_head.weight is not self.embed_tokens.weight + ): # if no weight tying + n_params -= self.lm_head.weight.numel() + return n_params + + +def main(): + print('Initializing transformer model and running forward pass...') + + seq_length = 1024 + + # Define model configuration + config = ModelConfig( + vocab_size=50257, # Common vocab size for tokenizers like BPE or SentencePiece + seq_len=seq_length, # Maximum sequence length + model_dim=1024, # Embedding dimension + expanded_model_dim=4.0, # MLP expansion factor + num_layers=12, # Number of transformer layers + num_heads=8, # Number of attention heads + rmsnorm_epsilon=1e-6, # RMSNorm epsilon + tie_embeddings=True, # Tie embedding and output weights + ) + + # Instantiate the model + model = Transformer(config) + print(f'Model has {model.count_params():,} parameters.') + + # Create some random input data + batch_size = 2 + input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length)) + + # Move data to the same device as the model + if torch.cuda.is_available(): + input_ids = input_ids.cuda() + + # Run a forward pass + print(f'Running forward pass with input shape: {input_ids.shape}') + logits = model(input_ids) + print(f'Output logits shape: {logits.shape}') + + # Run prediction + print('Running prediction...') + original_input, predicted_ids = model.predict(input_ids[:, :10], k=5) + print(f'Original input shape for prediction: {original_input.shape}') + print(f'Predicted IDs shape: {predicted_ids.shape}') + print(f'Predicted IDs: {predicted_ids}') + + +if __name__ == '__main__': + main() diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py new file mode 100644 index 000000000..a25ca334a --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py @@ -0,0 +1,221 @@ +"""LM workload implemented in PyTorch.""" + +import contextlib +from itertools import islice +from typing import Any, Dict, Iterator, Optional, Tuple + +import jax +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +from algoperf import param_utils, pytorch_utils, spec +from algoperf.workloads.finewebedu_lm.finewebedu_lm_pytorch.models import ( + ModelConfig, + Transformer, +) +from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter +from algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload + +USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() + + +class LmWorkload(BaseLmWorkload): + """LM PyTorch workload.""" + + def init_model_fn( + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: + if hasattr(self, '_model'): + # Reinitialize weights but keep same config + self._model.apply(self._model._init_weights) + self._model._scale_residual_branches() + return self._model, None + + torch.manual_seed(rng[0]) + cfg = ModelConfig( + vocab_size=self._vocab_size, + seq_len=self._seq_len, + model_dim=self._emb_dim, # Model dimension + expanded_model_dim=self._mlp_dim, # MLP expansion factor + num_layers=self._n_layers, # Number of transformer layers + num_heads=self._n_heads, # Number of attention heads + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + self._model = Transformer(cfg) + self._param_shapes = param_utils.pytorch_param_shapes(self._model) + self._param_types = param_utils.pytorch_param_types(self._param_shapes) + self._model.to(DEVICE) + + if N_GPUS > 1: + if USE_PYTORCH_DDP: + self._model = DDP(self._model, device_ids=[RANK], output_device=RANK) + else: + self._model = torch.nn.DataParallel(self._model) + + return self._model, None + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = 0.0, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + del model_state, rng, update_batch_norm, dropout_rate + model = params + + # Set model to eval or train mode based on the mode parameter + if mode == spec.ForwardPassMode.EVAL: + model.eval() + elif mode == spec.ForwardPassMode.TRAIN: + model.train() + contexts = { + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + } + with contexts[mode](): + # Convert one-hot inputs to token IDs if needed + inputs = augmented_and_preprocessed_input_batch['inputs'] + if inputs.dim() == 3: # one-hot encoded + inputs = inputs.argmax(dim=-1) + + logits = model(inputs) + + return logits, None + + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: + """Build an input queue for the given split.""" + del cache, repeat_final_dataset + local_batch_size = global_batch_size // N_GPUS + loader = get_data_iter( + data_rng=data_rng, + split=split, + data_dir=data_dir, + batch_size=local_batch_size, + num_batches=num_batches, + ) + if USE_PYTORCH_DDP: + loader = islice(loader, RANK, None, N_GPUS) + dtype = torch.int32 + for batch in loader: + batch = { + 'inputs': torch.tensor(batch['inputs'], device=DEVICE, dtype=dtype), + 'targets': torch.tensor( + batch['targets'], device=DEVICE, dtype=torch.int64 + ), + 'weights': torch.tensor( + batch['weights'], device=DEVICE, dtype=torch.float32 + ) + if batch['weights'] is not None + else None, + } + yield batch + + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return 'lm_head.weight' in param_name or 'lm_head.bias' in param_name + + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: spec.Tensor, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: + """Compute weighted cross-entropy loss. + + Args: + label_batch: Target labels of shape [batch, length] (int). + logits_batch: Predicted logits of shape [batch, length, vocab_size] (float). + mask_batch: Optional weights of shape [batch, length] (float). Used to mask + out padding tokens or weight examples differently. If None, all examples + are weighted equally. + label_smoothing: Label smoothing factor in [0, 1]. When > 0, the target + distribution becomes (1 - label_smoothing) for the correct class and + label_smoothing / vocab_size for all other classes. Default is 0.0 (no smoothing). + + Returns: + Dictionary containing: + - 'summed': Scalar tensor with the sum of all weighted losses. + - 'n_valid_examples': Scalar tensor with the count of valid (non-masked) examples. + - 'per_example': Tensor of shape [batch, length] with individual losses per example. + """ + vocab_size = logits_batch.size(-1) + + # Compute cross-entropy loss with label smoothing + per_example_losses = torch.nn.functional.cross_entropy( + logits_batch.view(-1, vocab_size), + label_batch.view(-1), + reduction='none', + label_smoothing=label_smoothing, + ) + per_example_losses = per_example_losses.view_as(label_batch) + + # Apply weights if provided + if mask_batch is not None: + per_example_losses = per_example_losses * mask_batch + + # Calculate number of valid examples + n_valid_examples = ( + mask_batch.sum() + if mask_batch is not None + else torch.tensor( + label_batch.numel(), dtype=torch.float32, device=label_batch.device + ) + ) + + return { + 'summed': per_example_losses.sum(), + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, + } + + def _eval_batch( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> spec.Tensor: + """Evaluate the model on a single batch.""" + logits, _ = self.model_fn( + params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False + ) + metrics = self.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch['weights'], + ) + return { + 'loss': metrics['summed'].detach(), + 'denominator': metrics['n_valid_examples'].detach(), + } + + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + del num_examples + if USE_PYTORCH_DDP: + for metric in total_metrics.values(): + dist.all_reduce(metric) + total_metrics = {k: v.item() for k, v in total_metrics.items()} + eval_denominator = total_metrics.pop('denominator') + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) diff --git a/algoperf/workloads/finewebedu_lm/input_pipeline.py b/algoperf/workloads/finewebedu_lm/input_pipeline.py new file mode 100644 index 000000000..3007371fc --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/input_pipeline.py @@ -0,0 +1,153 @@ +"""Input pipeline for a LM dataset.""" + +import functools +import os +from typing import Optional + +import jax +import tensorflow as tf + +from algoperf import data_utils + +AUTOTUNE = tf.data.experimental.AUTOTUNE +PAD_ID = tf.constant(-1, dtype=tf.int64) + +TFDS_SPLIT_NAME = {'train': 'train', 'eval_train': 'train', 'validation': 'val'} + +SEQUENCE_LENGTH = 1024 +MAX_CORPUS_CHARS = 1_000_000_000 +SHUFFLE_BUFFER_SIZE = 1000 +VOCAB_SIZE = 50_257 + + +def batch_with_padding( + dataset: tf.data.Dataset, + batch_size, + padded_shapes=None, + padding_id=PAD_ID, +): + """Batches a tf.data.Dataset and adds padding if len(dataset) is not divisible by the batch size. + + Args: + dataset: tf.data.Dataset + batch_size: batch size of resulting batched dataset + padded_shapes: shapes of the padded batches + padding_id: value for padding, for elements in new batch + + Returns: + """ + batched_dataset = dataset.batch(batch_size, drop_remainder=False) + + # tf.data.Dataset.padded.batch pads elements in the batch so we call it + # again with batch_size=1 to pad each element in original batch. + padded_batched_dataset = batched_dataset.padded_batch( + 1, padded_shapes=padded_shapes, padding_values=padding_id + ) + + # Remove extra dimension resulting from the batch_size=1. + padded_batched_dataset = padded_batched_dataset.unbatch() + + return padded_batched_dataset + + +def get_data_iter( + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + batch_size: int, + num_batches: Optional[int] = None, +): + ds = get_lm_dataset(data_rng, split, data_dir, batch_size, num_batches) + + it = map( + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=batch_size + ), + ds, + ) + + return iter(it) + + +def get_lm_dataset( + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + batch_size: int, + num_batches: Optional[int] = None, +): + """Load preprocessed TF dataset.""" + if split not in TFDS_SPLIT_NAME: + raise NotImplementedError + + shuffle_seed = jax.random.randint(data_rng, (), -(2**31), 2**31 - 1) + + data_dir = os.path.join(data_dir, TFDS_SPLIT_NAME[split]) + tokens_ds = tf.data.Dataset.load(data_dir) + + # tokens + tokens_ds = tokens_ds.flat_map(tf.data.Dataset.from_tensor_slices) + + # sequences + sequences_ds = tokens_ds.batch(SEQUENCE_LENGTH + 1, drop_remainder=True) + + # get inputs and outputs + sequences_ds = sequences_ds.map( + lambda x: { + 'inputs': x['input_ids'][:SEQUENCE_LENGTH], + 'targets': x['input_ids'][1:], + }, + num_parallel_calls=AUTOTUNE, + ) + if split == 'train': + ds = sequences_ds.shuffle(SHUFFLE_BUFFER_SIZE, seed=shuffle_seed) + ds = ds.batch(batch_size, drop_remainder=False) + ds = ds.take(num_batches) if num_batches is not None else ds + ds = ds.repeat() + ds = ds.map( + lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': None, + } + ) + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) + elif split == 'eval_train': + ds = batch_with_padding( + sequences_ds, + batch_size, + padded_shapes={ + 'inputs': (batch_size, None), + 'targets': (batch_size, None), + }, + ) + ds = ds.take(num_batches) if num_batches is not None else ds + ds = ds.repeat() + ds = ds.map( + lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0), + } + ) + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) + elif split == 'validation': + ds = batch_with_padding( + sequences_ds, + batch_size, + padded_shapes={ + 'inputs': (batch_size, None), + 'targets': (batch_size, None), + }, + ) + ds = ds.take(num_batches) if num_batches is not None else ds + ds = ds.repeat() + ds = ds.map( + lambda x: { + 'inputs': x['inputs'], + 'targets': x['targets'], + 'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0), + } + ) + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) + return ds diff --git a/algoperf/workloads/finewebedu_lm/workload.py b/algoperf/workloads/finewebedu_lm/workload.py new file mode 100644 index 000000000..59f70380f --- /dev/null +++ b/algoperf/workloads/finewebedu_lm/workload.py @@ -0,0 +1,193 @@ +"""LM workload parent class.""" + +import abc +import math +import os +from typing import Any, Dict, Iterator, Optional + +import jax +import numpy as np +from absl import flags + +from algoperf import spec + +FLAGS = flags.FLAGS + +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ + + +class BaseLmWorkload(spec.Workload): + """LM workload.""" + + _vocab_size: int = 50257 + _seq_len: int = 1024 + _emb_dim: int = 1024 + _n_heads: int = 8 + _n_layers: int = 12 + _mlp_dim: int = 4096 + warmup_factor: float = 0.1 + + def __init__(self) -> None: + super().__init__() + self._param_shapes = None + self._param_types = None + + @property + def target_metric_name(self) -> str: + """The name of the target metric (useful for scoring/processing code).""" + return 'ppl' + + def has_reached_validation_target(self, eval_result: float) -> bool: + return eval_result['validation/ppl'] <= self.validation_target_value + + @property + def validation_target_value(self) -> float: + return 22.2995 # Target perplexity + + def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: + return True # No test targets + + @property + def test_target_value(self) -> float: + return None # No test targets + + @property + def loss_type(self) -> spec.LossType: + return spec.LossType.SOFTMAX_CROSS_ENTROPY + + @property + def num_train_examples(self) -> int: + return 8_749_870 # sequences of 1024 tokens each + + @property + def num_eval_train_examples(self) -> int: + return 10_000 # Subset for evaluation. + + @property + def num_validation_examples(self) -> int: + return 100_000 # sequences + + @property + def num_test_examples(self) -> int: + return 0 + + @property + def eval_batch_size(self) -> int: + return 256 + + @property + def train_mean(self): + raise NotImplementedError + + @property + def train_stddev(self): + raise NotImplementedError + + @property + def max_allowed_runtime_sec(self) -> int: + return 31_967 # 8.9 hours + + @property + def eval_period_time_sec(self) -> int: + return 2_571 # approximately 25 evals + + @property + def step_hint(self) -> int: + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 72_000 + + @property + def pre_ln(self) -> bool: + return True + + @property + def attention_temp(self) -> float: + return 1.0 + + @property + def activation(self) -> str: + return 'silu' + + @property + def glu(self) -> bool: + return True + + @abc.abstractmethod + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, Any]]: + """Build an input queue for the given split.""" + + @abc.abstractmethod + def _eval_batch( + self, + params: spec.ParameterContainer, + eval_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[str, float]: + """Evaluate the model on a single batch.""" + + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: + """Run a full evaluation of the model.""" + num_batches = int(math.ceil(num_examples / global_batch_size)) + + # Handle edge case where num_batches is 0 (e.g., test split with 0 examples) + if num_batches == 0: + return {'loss': 0.0, 'ppl': 1.0} + + if split not in self._eval_iters: + # These iterators will repeat indefinitely. + self._eval_iters[split] = self._build_input_queue( + rng, split, data_dir, global_batch_size, num_batches=num_batches + ) + + eval_metrics = {} + for _ in range(num_batches): + eval_batch = next(self._eval_iters[split]) + metrics = self._eval_batch(params, eval_batch, model_state, rng) + for metric_name, metric_value in metrics.items(): + if metric_name not in eval_metrics: + eval_metrics[metric_name] = 0.0 + eval_metrics[metric_name] += metric_value + + eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) + eval_results['ppl'] = np.exp(eval_results['loss']).item() + return eval_results + + @abc.abstractmethod + def _normalize_eval_metrics( + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: + """Normalize eval metrics.""" + + @abc.abstractmethod + def loss_fn( + self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: + """Compute cross-entropy loss for language modeling.""" + + def is_output_params(self, param_name: str) -> bool: + """Return whether the given parameter is an output parameter.""" + return param_name.contains('output') diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 289136bfb..cd476e37f 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -253,7 +253,7 @@ def _build_dataset( batch_size=ds_iter_batch_size, shuffle=not USE_PYTORCH_DDP and is_train, sampler=sampler, - num_workers=5 * N_GPUS if is_train else self.eval_num_workers, + num_workers=5 * N_GPUS, pin_memory=True, drop_last=is_train, persistent_workers=is_train, diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py index 4dd4717e9..1bb0e4e21 100644 --- a/algoperf/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -113,6 +113,14 @@ 'workload_path': 'librispeech_deepspeech/librispeech', 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', }, + 'finewebedu_lm': { + 'workload_path': 'finewebedu_lm/finewebedu_lm', + 'workload_class_name': 'LmWorkload', + }, + 'lm': { + 'workload_path': 'finewebedu_lm/finewebedu_lm', + 'workload_class_name': 'LmWorkload', + }, 'mnist': { 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload', @@ -152,6 +160,7 @@ 'imagenet_vit', 'librispeech_conformer', 'librispeech_deepspeech', + 'finewebedu_lm', 'ogbg', 'wmt', ] diff --git a/algorithms/archived_paper_baselines/adamw/pytorch/submission.py b/algorithms/archived_paper_baselines/adamw/pytorch/submission.py index 761ce5cb1..7c50ff4ff 100644 --- a/algorithms/archived_paper_baselines/adamw/pytorch/submission.py +++ b/algorithms/archived_paper_baselines/adamw/pytorch/submission.py @@ -189,6 +189,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'finewebedu_lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/archived_paper_baselines/nesterov/jax/submission.py b/algorithms/archived_paper_baselines/nesterov/jax/submission.py index e199fb2b9..061acc3de 100644 --- a/algorithms/archived_paper_baselines/nesterov/jax/submission.py +++ b/algorithms/archived_paper_baselines/nesterov/jax/submission.py @@ -292,6 +292,8 @@ def get_batch_size(workload_name): return 16 elif workload_name == 'cifar': return 128 + elif workload_name == 'finewebedu_lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index cf431de24..6d2808593 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -388,6 +388,8 @@ def get_batch_size(workload_name): return 512 elif workload_name == 'wmt': return 128 + elif workload_name == 'finewebedu_lm': + return 64 elif workload_name == 'mnist': return 16 else: diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 494ada4c8..92027887f 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -349,6 +349,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'finewebedu_lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py new file mode 100644 index 000000000..b7adf6cd6 --- /dev/null +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/jax_nadamw_target_setting.py @@ -0,0 +1,427 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +# isort: on +import chex +import jax +import jax.numpy as jnp +import optax + +from algoperf import jax_sharding_utils, spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree.map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree.map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + step_hint = 0.75 * step_hint + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) + return schedule_fn + + # Create optimizer + LR schedule. + lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) + optimizer_state = opt_init_fn(params_zeros_like) + + return optimizer_state, opt_update_fn + + +def train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, + dropout_rate, +): + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + dropout_rate=dropout_rate, + ) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container + ) + # Compute mean loss and grad + loss = summed_loss / n_valid_examples + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + dropout_rate = hyperparameters.dropout_rate + + # Create shardings for each argument + replicated = jax_sharding_utils.get_replicate_sharding() # No partitioning + sharded = ( + jax_sharding_utils.get_batch_dim_sharding() + ) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated, # label_smoothing + replicated, # dropout_rate + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated, # grad_norm + ) + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings, + ) + + new_optimizer_state, new_params, new_model_state, loss, grad_norm = ( + jitted_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, + dropout_rate, + ) + ) + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step + ) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'finewebedu_lm': + return 64 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py b/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..b881747d8 --- /dev/null +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/pytorch_nadamw_target_setting.py @@ -0,0 +1,403 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import torch +import torch.distributed.nn as dist_nn +from absl import logging +from torch import Tensor +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR + +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step'] + ) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.0) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) + + return loss + + +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + optimizer_state = { + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + step_hint = step_hint * 0.75 + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) + + return optimizer_state + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + dropout_rate=hyperparameters.dropout_rate, + ) + + label_smoothing = ( + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip + ) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) + + return (optimizer_state, current_param_container, new_model_state) + + +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + elif workload_name == 'finewebedu_lm': + return 64 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json new file mode 100644 index 000000000..ce0f75623 --- /dev/null +++ b/algorithms/target_setting_algorithms/fineweb_edu_lm/tuning_search_space.json @@ -0,0 +1,11 @@ +[ + { + "dropout_rate": 0.0, + "label_smoothing": 0.0, + "learning_rate": 0.00038418421332238876, + "one_minus_beta1": 0.01564758865, + "beta2": 0.992362328914093, + "weight_decay": 0.25551270901641954, + "warmup_factor": 0.05 + } +] \ No newline at end of file diff --git a/datasets/README.md b/dataset/README.md similarity index 96% rename from datasets/README.md rename to dataset/README.md index 1aeb83239..221637e64 100644 --- a/datasets/README.md +++ b/dataset/README.md @@ -24,7 +24,7 @@ This document provides instructions on downloading and preparing all datasets ut *TL;DR to download and prepare a dataset, run `dataset_setup.py`:* ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir=~/data \ -- -- @@ -88,7 +88,7 @@ By default, a user will be prompted before any files are deleted. If you do not From `algorithmic-efficiency` run: ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --ogbg ``` @@ -124,7 +124,7 @@ In total, it should contain 13 files (via `find -type f | wc -l`) for a total of From `algorithmic-efficiency` run: ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --wmt ``` @@ -194,7 +194,7 @@ you should get an email containing the URLS for "knee_singlecoil_train", "knee_singlecoil_val" and "knee_singlecoil_test". ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --fastmri \ --fastmri_knee_singlecoil_train_url '' \ @@ -229,13 +229,13 @@ In total, it should contain 1280 files (via `find -type f | wc -l`) for a total Register on and follow directions to obtain the URLS for the ILSVRC2012 train and validation images. -The script will additionally automatically download the `matched-frequency` version of [ImageNet v2](https://www.tensorflow.org/datasets/catalog/imagenet_v2#imagenet_v2matched-frequency_default_config), which is used as the test set of the ImageNet workloads. +The script will additionally automatically download the `matched-frequency` version of [ImageNet v2](https://www.tensorflow.org/dataset/catalog/imagenet_v2#imagenet_v2matched-frequency_default_config), which is used as the test set of the ImageNet workloads. The ImageNet data pipeline differs between the PyTorch and JAX workloads. Therefore, you will have to specify the framework (either `pytorch` or `jax`) through the framework flag. ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --imagenet \ --temp_dir $DATA_DIR/tmp \ @@ -349,7 +349,7 @@ In total, it should contain 20 files (via `find -type f | wc -l`) for a total of ### Criteo1TB ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --temp_dir $DATA_DIR/tmp \ --criteo1tb @@ -378,7 +378,7 @@ In total, it should contain 885 files (via `find -type f | wc -l`) for a total o To download, train a tokenizer and preprocess the librispeech dataset: ```bash -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir $DATA_DIR \ --temp_dir $DATA_DIR/tmp \ --librispeech @@ -453,3 +453,13 @@ The preprocessing script will generate `.npy` files for audio data, `features.cs ```bash python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_vocab_path=$DATA_DIR/librispeech/spm_model.vocab ``` + +### Fineweb-EDU 10B +From `algorithmic-efficiency` run: + +```bash +python3 dataset/dataset_setup.py \ +--data_dir $DATA_DIR \ +--temp_dir $DATA_DIR/tmp \ +--fineweb_edu +``` \ No newline at end of file diff --git a/datasets/dataset_setup.py b/dataset/dataset_setup.py similarity index 89% rename from datasets/dataset_setup.py rename to dataset/dataset_setup.py index e110930cd..de5e9d271 100644 --- a/datasets/dataset_setup.py +++ b/dataset/dataset_setup.py @@ -56,7 +56,7 @@ Example command: -python3 datasets/dataset_setup.py \ +python3 dataset/dataset_setup.py \ --data_dir=~/data \ --temp_dir=/tmp/mlcommons_data --imagenet \ @@ -73,8 +73,11 @@ from algoperf.workloads.wmt import tokenizer from algoperf.workloads.wmt.input_pipeline import normalize_feature_names -from datasets import librispeech_preprocess -from datasets import librispeech_tokenizer +from dataset import librispeech_preprocess +from dataset import librispeech_tokenizer + +import datasets as hf_datasets +from transformers import AutoTokenizer import functools import os @@ -82,6 +85,7 @@ import subprocess import tarfile +from typing import Dict, List, Any from absl import app from absl import flags from absl import logging @@ -121,6 +125,9 @@ flags.DEFINE_boolean( 'fastmri', False, 'If --all=false, whether or not to download FastMRI.' ) +flags.DEFINE_boolean( + 'finewebedu', False, 'If --all=false, whether or not to download FineWebEdu.' +) flags.DEFINE_boolean( 'imagenet', False, 'If --all=false, whether or not to download Imagenet.' ) @@ -194,6 +201,9 @@ flags.DEFINE_string('framework', None, 'Can be either jax or pytorch.') flags.DEFINE_boolean('skip_download', False, 'Skips data download.') +flags.DEFINE_boolean( + 'skip_tokenization', False, 'Skip Fineweb-edu tokenization.' +) FLAGS = flags.FLAGS @@ -767,6 +777,102 @@ def download_wmt(data_dir): ) +def download_finewebedu( + data_dir, tmp_dir=None, skip_download=False, skip_tokenization=False +): + """Download FineWebEdu-10B.""" + + if not skip_download: + data_dir = os.path.join(data_dir, 'fineweb_edu_10B') + tmp_dir = tmp_dir if tmp_dir is not None else '/tmp' + cache_dir = ( + os.path.join(tmp_dir, 'lm') + if tmp_dir is not None + else os.path.expanduser('~/.cache/huggingface/datasets') + ) + + _maybe_mkdir(data_dir) + _maybe_mkdir(tmp_dir) + _maybe_mkdir(cache_dir) + + os.environ['TMPDIR'] = tmp_dir + + ds = hf_datasets.load_dataset( + 'HuggingFaceFW/fineweb-edu', + name='sample-10BT', + split='train', + cache_dir=cache_dir, + ) + ds.save_to_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) + else: + ds = hf_datasets.load_from_disk(os.path.join(tmp_dir, 'fwedu_10B_raw')) + + if not skip_tokenization: + # Tokenize + lm_tokenizer = AutoTokenizer.from_pretrained('gpt2') + logging.info(f'Vocab size of lm_tokenizer = {len(lm_tokenizer)}') + + def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + def add_eos(seq): + return seq + lm_tokenizer.eos_token if seq else seq + + def add_eos_batched(seqs): + return [add_eos(seq) for seq in seqs] + + return lm_tokenizer( + add_eos_batched(examples['text']), + return_special_tokens_mask=False, + return_attention_mask=False, + ) + + lm_tokenizer.model_max_length = ( + 1e30 # prevent truncation during tokenization + ) + logging.info('Tokenizing...') + tokenized_dataset = ds.map( + tokenize, + remove_columns=[ + 'text', + 'id', + 'dump', + 'url', + 'file_path', + 'language', + 'language_score', + 'token_count', + 'score', + 'int_score', + ], + batched=True, + batch_size=1024, + num_proc=8, + ) + + tokenized_dataset.save_to_disk( + os.path.join(data_dir, 'fwedu_10B_tokenized') + ) + else: + tokenized_dataset = hf_datasets.load_from_disk( + os.path.join(data_dir, 'fwedu_10B_tokenized') + ) + + # Convert to tensorflow_datasets.Dataset objects + tokenized_dataset = tokenized_dataset.to_tf_dataset() + + # Shuffle dataset + dataset_size = tokenized_dataset.cardinality().numpy() + shuffled_dataset = tokenized_dataset.shuffle(dataset_size, seed=0) + train_size = int(0.9 * dataset_size) + train_dataset = shuffled_dataset.take(train_size) + val_dataset = shuffled_dataset.skip(train_size) + + # Split in train and valid. + train_dataset.save(os.path.join(data_dir, 'train')) + val_dataset.save(os.path.join(data_dir, 'val')) + + return + + def main(_): data_dir = FLAGS.data_dir tmp_dir = FLAGS.temp_dir @@ -854,6 +960,12 @@ def main(_): logging.info('Downloading WMT...') download_wmt(data_dir) + if FLAGS.all or FLAGS.finewebedu: + logging.info('Downloading FineWebEdu-10B...') + download_finewebedu( + data_dir, tmp_dir, FLAGS.skip_download, FLAGS.skip_tokenization + ) + # pylint: enable=logging-format-interpolation # pylint: enable=consider-using-with diff --git a/datasets/librispeech_preprocess.py b/dataset/librispeech_preprocess.py similarity index 99% rename from datasets/librispeech_preprocess.py rename to dataset/librispeech_preprocess.py index 1c216db46..878f10f2a 100644 --- a/datasets/librispeech_preprocess.py +++ b/dataset/librispeech_preprocess.py @@ -14,7 +14,7 @@ from absl import logging from pydub import AudioSegment -from datasets import librispeech_tokenizer +from dataset import librispeech_tokenizer gfile = tf.io.gfile copy = tf.io.gfile.copy diff --git a/datasets/librispeech_tokenizer.py b/dataset/librispeech_tokenizer.py similarity index 100% rename from datasets/librispeech_tokenizer.py rename to dataset/librispeech_tokenizer.py diff --git a/pyproject.toml b/pyproject.toml index 534f5d678..ae2f2c8fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,9 @@ version_file = "algoperf/_version.py" ############################################################################### [project.optional-dependencies] # All workloads -full = ["algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]"] +full = [ + "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt,lm]", +] # All workloads plus development dependencies full_dev = ["algoperf[full,dev]"] # Dependencies for developing the package @@ -88,6 +90,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"] +lm = ["transformers==4.26", "datasets==3.6.0"] # Frameworks jax_core_deps = [ diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 4b7bed2b5..efe276a33 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -67,10 +67,16 @@ '', 'Optional comma seperated list of names of submissions to exclude from scoring.', ) +flags.DEFINE_string( + 'include_submissions', + '', + 'Optional comma seperated list of names of submissions to include from scoring.', +) FLAGS = flags.FLAGS def get_summary_df(workload, workload_df, include_test_split=False): + print(f' WORKLOAD: {workload}') validation_metric, validation_target = ( scoring_utils.get_workload_metrics_and_targets(workload, split='validation') ) @@ -119,9 +125,20 @@ def get_summary_df(workload, workload_df, include_test_split=False): axis=1, ) - summary_df['step_time (s)'] = ( - workload_df['accumulated_submission_time'] / workload_df['global_step'] - ).iloc[-1][-1] + # compute the step times + def delta(series): + return series.shift(1, fill_value=0) - series + + accumulated_time_intervals = delta(workload_df['accumulated_submission_time']) + step_intervals = delta(workload_df['global_step']) + if len(accumulated_time_intervals) < 2: + print( + f'WARNING: The number of evals may be too low to calculate reliable step time for {workload}' + ) + + summary_df['step_time (s)'] = np.median( + (accumulated_time_intervals / step_intervals).iloc[0] + ) summary_df['step_hint'] = scoring_utils.get_workload_stephint(workload) @@ -205,18 +222,25 @@ def main(_): ) as f: results = pickle.load(f) else: - for submission in os.listdir(FLAGS.submission_directory): + all_submission_dirs = list(os.listdir(FLAGS.submission_directory)) + if not FLAGS.include_submissions: + include_submissions = all_submission_dirs + else: + include_submissions = FLAGS.include_submissions.split(',') + + for submission in all_submission_dirs: print(submission) - if submission in FLAGS.exclude_submissions.split(','): - continue - experiment_path = os.path.join(FLAGS.submission_directory, submission) - df = scoring_utils.get_experiment_df(experiment_path) - results[submission] = df - summary_df = get_submission_summary(df) - with open( - os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), 'w' - ) as fout: - summary_df.to_csv(fout) + if submission not in FLAGS.exclude_submissions.split(',') and ( + submission in include_submissions + ): + experiment_path = os.path.join(FLAGS.submission_directory, submission) + df = scoring_utils.get_experiment_df(experiment_path) + results[submission] = df + summary_df = get_submission_summary(df) + with open( + os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), 'w' + ) as fout: + summary_df.to_csv(fout) # Optionally save results to filename if FLAGS.save_results_to_filename: diff --git a/scoring/utils/workload_metadata_external_tuning.json b/scoring/utils/workload_metadata_external_tuning.json index 3d9f78ca1..0ba0d99ee 100644 --- a/scoring/utils/workload_metadata_external_tuning.json +++ b/scoring/utils/workload_metadata_external_tuning.json @@ -30,5 +30,9 @@ "librispeech_conformer": { "max_steps": 80000, "dataset": "librispeech" + }, + "finewebedu_lm" : { + "max_steps": 55000, + "dataset":"fineweb_edu_10B" } } diff --git a/submission_runner.py b/submission_runner.py index d15bda74b..b557c4f40 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -265,6 +265,7 @@ def train_once( 'librispeech_deepspeech', 'ogbg', 'wmt', + 'finewebedu_lm', 'imagenet_vit', ] base_workload = workloads.get_base_workload_name(workload_name) @@ -782,7 +783,10 @@ def main(_): os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' if FLAGS.framework == 'pytorch': - pytorch_init(USE_PYTORCH_DDP, RANK, profiler) + limit_tf_threads = base_workload != 'finewebedu_lm' + pytorch_init( + USE_PYTORCH_DDP, RANK, profiler, limit_tf_threads=limit_tf_threads + ) # TODO: remove once issue resolved. if FLAGS.pytorch_eval_num_workers != 0: @@ -798,6 +802,7 @@ def main(_): 'librispeech_deepspeech', 'imagenet_vit', 'criteo1tb', + 'finewebedu_lm', ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' diff --git a/tests/modeldiffs/lm/compare.py b/tests/modeldiffs/lm/compare.py new file mode 100644 index 000000000..709e3125f --- /dev/null +++ b/tests/modeldiffs/lm/compare.py @@ -0,0 +1,892 @@ +""" +Test file to verify that JAX and PyTorch implementations produce identical outputs +when given the same weights and inputs. + +Tests are performed module-by-module: +1. RMSNorm +2. RoPE (Rotary Position Embeddings) +3. MLP +4. Attention +5. Transformer Block +6. Full Model +""" + +import os +import sys + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import jax.numpy as jnp +import numpy as np +import torch +import torch.nn.functional as F +from absl import flags, logging +from absl.testing import absltest, parameterized + +# Import JAX implementation +from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( + CausalAttn, + Mlp, + TBlock, + TransformerDo, + apply_rope, + init_rope, +) +from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( + ModelConfig as JaxModelConfig, +) + +# Import PyTorch implementation +from algoperf.workloads.finewebedu_lm.finewebedu_lm_pytorch.models import ( + MLP, + Attention, + Block, + Transformer, + apply_rotary_emb_complex_like, + precompute_freqs_cis, +) +from algoperf.workloads.finewebedu_lm.finewebedu_lm_pytorch.models import ( + ModelConfig as PyTorchModelConfig, +) + +FLAGS = flags.FLAGS +# Needed to avoid UnparsedFlagAccessError +FLAGS(sys.argv) + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def assert_close(jax_output, torch_output, rtol=1e-5, atol=1e-6, name=''): + """Assert that JAX and PyTorch outputs are close.""" + jax_np = np.array(jax_output) + torch_np = torch_output.detach().cpu().numpy() + + mse = np.mean((jax_np - torch_np) ** 2) + max_diff = np.max(np.abs(jax_np - torch_np)) + + logging.info(f'\n{name} Comparison:') + logging.info(f' MSE: {mse:.8e}') + logging.info(f' Max Difference: {max_diff:.8e}') + + np.testing.assert_allclose( + jax_np, + torch_np, + rtol=rtol, + atol=atol, + err_msg=f'{name} outputs do not match', + ) + + +# ============================================================================ +# Test Functions (unchanged) +# ============================================================================ + + +def test_rmsnorm(): + """Test that RMSNorm produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing RMSNorm') + logging.info('=' * 70) + + batch_size, seq_len, dim = 2, 10, 256 + eps = 1e-6 + + # Create random input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + + # Initialize PyTorch RMSNorm + torch_norm = torch.nn.RMSNorm(dim, eps=eps) + torch_input = torch.tensor(np_input) + + # Initialize JAX RMSNorm (using Flax's RMSNorm from nanodo) + from flax import linen as nn + + flax_norm = nn.RMSNorm(epsilon=eps) + jax_input = jnp.array(np_input) + flax_params = flax_norm.init(jax.random.PRNGKey(0), jax_input) + + # Copy weights from PyTorch to JAX + with torch.no_grad(): + flax_params['params']['scale'] = jnp.array(torch_norm.weight.numpy()) + + # Forward pass + with torch.no_grad(): + torch_output = torch_norm(torch_input) + + jax_output = flax_norm.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, name='RMSNorm') + logging.info('✓ RMSNorm test passed') + + +def test_rope(): + """Test that RoPE produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing RoPE (Rotary Position Embeddings)') + logging.info('=' * 70) + + batch_size, seq_len, n_heads, dim = 2, 16, 4, 128 + head_dim = dim // n_heads + + # Initialize RoPE + torch_freqs = precompute_freqs_cis(head_dim, seq_len, theta=500000) + jax_freqs = init_rope(dim, seq_len, n_heads) + + # Create random Q and K + np_q = np.random.randn(batch_size, seq_len, n_heads, head_dim).astype( + np.float32 + ) + np_k = np.random.randn(batch_size, seq_len, n_heads, head_dim).astype( + np.float32 + ) + + # PyTorch forward + torch_q = torch.tensor(np_q) + torch_k = torch.tensor(np_k) + with torch.no_grad(): + torch_q_rot, torch_k_rot = apply_rotary_emb_complex_like( + torch_q, torch_k, freqs_cis=torch_freqs + ) + + # JAX forward + jax_q = jnp.array(np_q) + jax_k = jnp.array(np_k) + jax_q_rot, jax_k_rot = apply_rope(jax_q, jax_k, jax_freqs) + + # Compare + assert_close(jax_q_rot, torch_q_rot, name='RoPE Q') + assert_close(jax_k_rot, torch_k_rot, name='RoPE K') + logging.info('✓ RoPE test passed') + + +def copy_mlp_params(pytorch_mlp, flax_params): + """Copy MLP parameters from PyTorch to JAX.""" + new_params = flax_params.copy() + + # Handle compiled models + if hasattr(pytorch_mlp, '_orig_mod'): + pytorch_mlp = pytorch_mlp._orig_mod + + # Copy fc1 and fc2 weights (transposed for JAX) + new_params['params']['Dense_0']['kernel'] = ( + pytorch_mlp.fc1.weight.detach().numpy().T + ) + new_params['params']['Dense_1']['kernel'] = ( + pytorch_mlp.fc2.weight.detach().numpy().T + ) + + return new_params + + +def test_mlp(): + """Test that MLP produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing MLP') + logging.info('=' * 70) + + batch_size, seq_len, dim = 2, 10, 256 + hidden_dim = 1024 + + # Initialize PyTorch MLP + pytorch_mlp = MLP(dim=dim, hidden_dim=hidden_dim) + + # Initialize JAX MLP + cfg = JaxModelConfig( + model_dim=dim, + num_heads=4, + seq_len=128, + num_layers=2, + vocab_size=1000, + expanded_model_dim=hidden_dim, + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + ) + flax_mlp = Mlp(cfg) + + # Initialize JAX params + dummy_input = jnp.ones((batch_size, seq_len, dim)) + flax_params = flax_mlp.init(jax.random.PRNGKey(0), dummy_input) + + # Copy weights + flax_params = copy_mlp_params(pytorch_mlp, flax_params) + + # Create input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + torch_input = torch.tensor(np_input) + jax_input = jnp.array(np_input) + + # Forward pass + with torch.no_grad(): + torch_output = pytorch_mlp(torch_input) + + jax_output = flax_mlp.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, name='MLP') + logging.info('✓ MLP test passed') + + +def copy_attention_params(pytorch_attn, flax_params): + """Copy attention parameters from PyTorch to JAX.""" + # Handle compiled models + if hasattr(pytorch_attn, '_orig_mod'): + pytorch_attn = pytorch_attn._orig_mod + + n_heads = pytorch_attn.n_heads + head_dim = pytorch_attn.head_dim + dim = pytorch_attn.dim + + # Split PyTorch's combined qkv weights + w_qkv = pytorch_attn.w_qkv.weight + q_weight, k_weight, v_weight = [ + u.detach().numpy() for u in w_qkv.split(dim, dim=0) + ] + + # Reshape for Flax's DenseGeneral format [D, H, Dh] + def reshape_for_flax(w, n_heads, head_dim): + return w.reshape(n_heads, head_dim, -1).transpose(2, 0, 1) + + new_params = { + 'query': {'kernel': reshape_for_flax(q_weight, n_heads, head_dim)}, + 'key': {'kernel': reshape_for_flax(k_weight, n_heads, head_dim)}, + 'value': {'kernel': reshape_for_flax(v_weight, n_heads, head_dim)}, + 'attn_out_proj': {'kernel': pytorch_attn.w_out.weight.detach().numpy().T}, + } + + return {'params': new_params} + + +def test_attention(): + """Test that Attention produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Attention') + logging.info('=' * 70) + + batch_size, seq_len, dim, n_heads = 2, 16, 256, 4 + + # Initialize PyTorch Attention + config = PyTorchModelConfig( + vocab_size=1000, + seq_len=seq_len, + model_dim=dim, + expanded_model_dim=1024, + num_layers=1, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, + ) + pytorch_attn = Attention(config) + freqs_cis = precompute_freqs_cis(dim // n_heads, seq_len, theta=500000) + + # Initialize JAX Attention + cfg = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len, + num_layers=1, + vocab_size=1000, + expanded_model_dim=1024, + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + ) + flax_attn = CausalAttn(cfg) + + # Initialize JAX params + dummy_input = jnp.ones((batch_size, seq_len, dim)) + flax_params = flax_attn.init(jax.random.PRNGKey(0), dummy_input) + + # Copy weights + flax_params = copy_attention_params(pytorch_attn, flax_params) + + # Create input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + torch_input = torch.tensor(np_input) + jax_input = jnp.array(np_input) + + # Forward pass + with torch.no_grad(): + torch_output = pytorch_attn(torch_input, freqs_cis) + + jax_output = flax_attn.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, rtol=1e-4, atol=1e-5, name='Attention') + logging.info('✓ Attention test passed') + + +def copy_block_params(pytorch_block, flax_params): + """Copy block parameters from PyTorch to JAX.""" + # Copy attention parameters + attn_params = copy_attention_params(pytorch_block.attn, {'params': {}})[ + 'params' + ] + + # Copy MLP parameters + pytorch_mlp = pytorch_block.mlp + mlp_params = { + 'Dense_0': {'kernel': pytorch_mlp.fc1.weight.detach().numpy().T}, + 'Dense_1': {'kernel': pytorch_mlp.fc2.weight.detach().numpy().T}, + } + + # Copy RMSNorm parameters + norm_params = { + 'attn_norm': {'scale': pytorch_block.attn_norm.weight.detach().numpy()}, + 'mlp_norm': {'scale': pytorch_block.mlp_norm.weight.detach().numpy()}, + } + + return { + 'params': { + 'CausalAttn_0': attn_params, + 'Mlp_0': mlp_params, + 'RMSNorm_0': norm_params['attn_norm'], + 'RMSNorm_1': norm_params['mlp_norm'], + } + } + + +def test_block(): + """Test that Transformer Block produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Transformer Block') + logging.info('=' * 70) + + batch_size, seq_len, dim, n_heads = 2, 16, 256, 4 + expand = 4.0 + + # Initialize PyTorch Block + config = PyTorchModelConfig( + vocab_size=1000, + seq_len=seq_len, + model_dim=dim, + expanded_model_dim=int(dim * expand), + num_layers=1, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, + ) + pytorch_block = Block(layer_id=0, cfg=config) + freqs_cis = precompute_freqs_cis(dim // n_heads, seq_len, theta=500000) + + # Initialize JAX Block + cfg = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len, + num_layers=1, + vocab_size=1000, + expanded_model_dim=int(dim * expand), + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + ) + flax_block = TBlock(cfg) + + # Initialize JAX params + dummy_input = jnp.ones((batch_size, seq_len, dim)) + flax_params = flax_block.init(jax.random.PRNGKey(0), dummy_input) + + # Copy weights + flax_params = copy_block_params(pytorch_block, flax_params) + + # Create input + np_input = np.random.randn(batch_size, seq_len, dim).astype(np.float32) + torch_input = torch.tensor(np_input) + jax_input = jnp.array(np_input) + + # Forward pass + with torch.no_grad(): + torch_output = pytorch_block(torch_input, freqs_cis) + + jax_output = flax_block.apply(flax_params, jax_input) + + # Compare + assert_close(jax_output, torch_output, rtol=1e-4, atol=1e-5, name='Block') + logging.info('✓ Block test passed') + + +def copy_full_model_params(pytorch_model, flax_params, config): + """Copy all parameters from PyTorch model to JAX model.""" + # Handle tied embeddings case + if hasattr(pytorch_model, '_orig_mod'): + pytorch_model = pytorch_model._orig_mod + + n_layers = config.num_layers + n_heads = config.num_heads + dim = config.model_dim + head_dim = dim // n_heads + + new_params = {'params': {}} + + # Copy embedding weights + new_params['params']['embed'] = { + 'embedding': pytorch_model.embed_tokens.weight.detach().numpy() + } + + # Copy each transformer block + for i in range(n_layers): + pytorch_block = pytorch_model.layers[i] + + # Attention params + w_qkv = pytorch_block.attn.w_qkv.weight + q_weight, k_weight, v_weight = [ + u.detach().numpy() for u in w_qkv.split(dim, dim=0) + ] + + def reshape_for_flax(w, n_heads, head_dim): + return w.reshape(n_heads, head_dim, -1).transpose(2, 0, 1) + + attn_params = { + 'query': {'kernel': reshape_for_flax(q_weight, n_heads, head_dim)}, + 'key': {'kernel': reshape_for_flax(k_weight, n_heads, head_dim)}, + 'value': {'kernel': reshape_for_flax(v_weight, n_heads, head_dim)}, + 'attn_out_proj': { + 'kernel': pytorch_block.attn.w_out.weight.detach().numpy().T + }, + } + + # MLP params + mlp_params = { + 'Dense_0': {'kernel': pytorch_block.mlp.fc1.weight.detach().numpy().T}, + 'Dense_1': {'kernel': pytorch_block.mlp.fc2.weight.detach().numpy().T}, + } + + # Norm params + attn_norm = {'scale': pytorch_block.attn_norm.weight.detach().numpy()} + mlp_norm = {'scale': pytorch_block.mlp_norm.weight.detach().numpy()} + + # Assemble block params + block_key = f'blocks_{i}' + new_params['params'][block_key] = { + 'CausalAttn_0': attn_params, + 'Mlp_0': mlp_params, + 'RMSNorm_0': attn_norm, + 'RMSNorm_1': mlp_norm, + } + + # Copy output norm + new_params['params']['out_ln'] = { + 'scale': pytorch_model.out_norm.weight.detach().numpy() + } + + # Handle output projection (tied or untied) + if not config.tie_embeddings: + new_params['params']['output_proj'] = { + 'kernel': pytorch_model.lm_head.weight.detach().numpy().T + } + + return new_params + + +def test_full_model(): + """Test that full Transformer model produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Full Transformer Model') + logging.info('=' * 70) + + batch_size, seq_len = 2, 32 + vocab_size = 256 + dim = 128 + n_heads = 4 + n_layers = 2 + expand = 4.0 + + # Initialize PyTorch model + pytorch_config = PyTorchModelConfig( + vocab_size=vocab_size, + seq_len=seq_len, + model_dim=dim, + expanded_model_dim=int(dim * expand), + num_layers=n_layers, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + pytorch_model = Transformer(pytorch_config) + pytorch_model.eval() + + # Initialize JAX model + jax_config = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len, + num_layers=n_layers, + vocab_size=vocab_size, + expanded_model_dim=int(dim * expand), + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + jax_model = TransformerDo(jax_config) + + # Create input tokens + np_tokens = np.random.randint( + 0, vocab_size, size=(batch_size, seq_len), dtype=np.int32 + ) + torch_tokens = torch.tensor(np_tokens, dtype=torch.long) + jax_tokens = jnp.array(np_tokens, dtype=jnp.int32) + + # Initialize JAX params + jax_params = jax_model.init(jax.random.PRNGKey(0), jax_tokens) + + # Copy weights from PyTorch to JAX + jax_params = copy_full_model_params(pytorch_model, jax_params, pytorch_config) + + # Forward pass + with torch.no_grad(): + torch_logits = pytorch_model(torch_tokens) + + jax_logits = jax_model.apply(jax_params, jax_tokens) + + # Compare + assert_close( + jax_logits, torch_logits, rtol=1e-4, atol=1e-5, name='Full Model' + ) + logging.info('✓ Full Model test passed') + + +def test_prediction(): + """Test that autoregressive generation produces identical outputs.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Autoregressive Prediction') + logging.info('=' * 70) + + batch_size, seq_len = 1, 10 + vocab_size = 256 + dim = 128 + n_heads = 4 + n_layers = 2 + expand = 4.0 + k = 5 # Number of tokens to predict + + # Initialize PyTorch model + pytorch_config = PyTorchModelConfig( + vocab_size=vocab_size, + seq_len=seq_len + k, + model_dim=dim, + expanded_model_dim=int(dim * expand), + num_layers=n_layers, + num_heads=n_heads, + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + pytorch_model = Transformer(pytorch_config) + pytorch_model.eval() + + # Initialize JAX model + jax_config = JaxModelConfig( + model_dim=dim, + num_heads=n_heads, + seq_len=seq_len + k, + num_layers=n_layers, + vocab_size=vocab_size, + expanded_model_dim=int(dim * expand), + dtype=jnp.float32, + rmsnorm_epsilon=1e-6, + tie_embeddings=True, + ) + jax_model = TransformerDo(jax_config) + + # Create input tokens + np_tokens = np.random.randint( + 0, vocab_size, size=(batch_size, seq_len), dtype=np.int32 + ) + torch_tokens = torch.tensor(np_tokens, dtype=torch.long) + jax_tokens = jnp.array(np_tokens, dtype=jnp.int32) + + # Initialize JAX params + jax_params = jax_model.init(jax.random.PRNGKey(0), jax_tokens) + + # Copy weights from PyTorch to JAX + jax_params = copy_full_model_params(pytorch_model, jax_params, pytorch_config) + + # Predict k tokens + with torch.no_grad(): + _, torch_predictions = pytorch_model.predict(torch_tokens, k=k) + + _, jax_predictions = jax_model.apply( + jax_params, jax_tokens, k, method=jax_model.predict + ) + + # Compare predictions + torch_pred_np = torch_predictions.cpu().numpy() + jax_pred_np = np.array(jax_predictions) + + logging.info(f'\nPyTorch predictions: {torch_pred_np[0]}') + logging.info(f'JAX predictions: {jax_pred_np[0]}') + + # Check if predictions match exactly + if np.array_equal(torch_pred_np, jax_pred_np): + logging.info('✓ Predictions match exactly!') + else: + matching = np.sum(torch_pred_np == jax_pred_np) + total = torch_pred_np.size + logging.info( + f'⚠ Predictions differ: {matching}/{total} tokens match ({matching / total * 100:.1f}%)' + ) + logging.info( + ' (Note: Small numerical differences can lead to different argmax results)' + ) + + +def test_initialization_statistics(): + """Verify initialization follows expected distributions.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Initialization Statistics') + logging.info('=' * 70) + + # Initialize models + jax_cfg = JaxModelConfig( + model_dim=512, + num_heads=8, + seq_len=1024, + num_layers=12, + vocab_size=50000, + expanded_model_dim=2048, + dtype=jnp.float32, + ) + jax_model = TransformerDo(jax_cfg) + jax_params = jax_model.init( + jax.random.PRNGKey(42), jnp.ones((1, 10), dtype=jnp.int32) + ) + + pytorch_cfg = PyTorchModelConfig( + vocab_size=50000, + seq_len=1024, + model_dim=512, + expanded_model_dim=2048, + num_layers=12, + num_heads=8, + ) + pytorch_model = Transformer(pytorch_cfg) + + logging.info('Initialization Statistics Check:') + + # Check embedding + jax_embed = jax_params['params']['embed']['embedding'] + torch_embed = pytorch_model.embed_tokens.weight.detach().numpy() + + logging.info('\nToken Embedding (should be ~0.02 std):') + logging.info( + f' JAX: mean={jax_embed.mean():.6f}, std={jax_embed.std():.6f}' + ) + logging.info( + f' PyTorch: mean={torch_embed.mean():.6f}, std={torch_embed.std():.6f}' + ) + + # Assert embedding std is close to 0.02 + assert abs(jax_embed.std() - 0.02) < 0.005, ( + f'JAX embedding std {jax_embed.std():.6f} not close to 0.02' + ) + assert abs(torch_embed.std() - 0.02) < 0.005, ( + f'PyTorch embedding std {torch_embed.std():.6f} not close to 0.02' + ) + assert abs(jax_embed.mean()) < 0.01, ( + f'JAX embedding mean {jax_embed.mean():.6f} not close to 0' + ) + assert abs(torch_embed.mean()) < 0.01, ( + f'PyTorch embedding mean {torch_embed.mean():.6f} not close to 0' + ) + + # Check first layer attention Q + jax_q = jax_params['params']['blocks_0']['CausalAttn_0']['query']['kernel'] + torch_q_weight = ( + pytorch_model.layers[0].attn.w_qkv.weight[:512].detach().numpy() + ) + + logging.info('\nAttention Q:') + logging.info(f' JAX: mean={jax_q.mean():.6f}, std={jax_q.std():.6f}') + logging.info( + f' PyTorch: mean={torch_q_weight.mean():.6f}, std={torch_q_weight.std():.6f}' + ) + + # Check means are close to 0 + assert abs(jax_q.mean()) < 0.01, ( + f'JAX Q mean {jax_q.mean():.6f} not close to 0' + ) + assert abs(torch_q_weight.mean()) < 0.01, ( + f'PyTorch Q mean {torch_q_weight.mean():.6f} not close to 0' + ) + + # Check stds are similar + # Allow 20% difference due to random initialization + assert abs(jax_q.std() - torch_q_weight.std()) / torch_q_weight.std() < 0.2, ( + f'Q std differs too much: JAX {jax_q.std():.6f} vs PyTorch {torch_q_weight.std():.6f}' + ) + + # Check first layer attention output (should be scaled) + jax_attn_out = jax_params['params']['blocks_0']['CausalAttn_0'][ + 'attn_out_proj' + ]['kernel'] + torch_attn_out = pytorch_model.layers[0].attn.w_out.weight.detach().numpy() + + logging.info('\nAttention Output:') + logging.info( + f' JAX: mean={jax_attn_out.mean():.6f}, std={jax_attn_out.std():.6f}' + ) + logging.info( + f' PyTorch: mean={torch_attn_out.mean():.6f}, std={torch_attn_out.std():.6f}' + ) + + # Check means are close to 0 + assert abs(jax_attn_out.mean()) < 0.01, ( + f'JAX attn out mean {jax_attn_out.mean():.6f} not close to 0' + ) + assert abs(torch_attn_out.mean()) < 0.01, ( + f'PyTorch attn out mean {torch_attn_out.mean():.6f} not close to 0' + ) + + # Check stds are similar + assert ( + abs(jax_attn_out.std() - torch_attn_out.std()) / torch_attn_out.std() < 0.2 + ), ( + f'Attention output std differs too much: JAX {jax_attn_out.std():.6f} vs PyTorch {torch_attn_out.std():.6f}' + ) + + # Check MLP fc2 (should be scaled) + jax_mlp_out = jax_params['params']['blocks_0']['Mlp_0']['Dense_1']['kernel'] + torch_mlp_out = pytorch_model.layers[0].mlp.fc2.weight.detach().numpy() + + logging.info('\nMLP Output:') + logging.info( + f' JAX: mean={jax_mlp_out.mean():.6f}, std={jax_mlp_out.std():.6f}' + ) + logging.info( + f' PyTorch: mean={torch_mlp_out.mean():.6f}, std={torch_mlp_out.std():.6f}' + ) + + # Check means are close to 0 + assert abs(jax_mlp_out.mean()) < 0.01, ( + f'JAX MLP out mean {jax_mlp_out.mean():.6f} not close to 0' + ) + assert abs(torch_mlp_out.mean()) < 0.01, ( + f'PyTorch MLP out mean {torch_mlp_out.mean():.6f} not close to 0' + ) + + # Check stds are similar + assert ( + abs(jax_mlp_out.std() - torch_mlp_out.std()) / torch_mlp_out.std() < 0.2 + ), ( + f'MLP output std differs too much: JAX {jax_mlp_out.std():.6f} vs PyTorch {torch_mlp_out.std():.6f}' + ) + + logging.info('\n✓ Initialization statistics test passed') + + +def test_initialization_impact(): + """Test that initialization produces similar initial losses.""" + logging.info('\n' + '=' * 70) + logging.info('Testing Initialization Impact') + logging.info('=' * 70) + + # Create identical inputs + batch_size, seq_len = 4, 128 + vocab_size = 50000 + + np.random.seed(42) + tokens = np.random.randint(0, vocab_size, size=(batch_size, seq_len)) + + # Initialize both models with same seed + jax_cfg = JaxModelConfig( + model_dim=512, + num_heads=8, + seq_len=seq_len, + num_layers=12, + vocab_size=vocab_size, + expanded_model_dim=2048, + ) + jax_model = TransformerDo(jax_cfg) + jax_params = jax_model.init( + jax.random.PRNGKey(42), jnp.array(tokens, dtype=jnp.int32) + ) + + torch.manual_seed(42) + pytorch_cfg = PyTorchModelConfig( + vocab_size=vocab_size, + seq_len=seq_len, + model_dim=512, + expanded_model_dim=2048, + num_layers=12, + num_heads=8, + ) + pytorch_model = Transformer(pytorch_cfg) + + # Forward pass + jax_logits = jax_model.apply(jax_params, jnp.array(tokens, dtype=jnp.int32)) + + with torch.no_grad(): + torch_logits = pytorch_model(torch.tensor(tokens, dtype=torch.long)) + + # Compute losses + targets = tokens[:, 1:] + jax_loss = -jax.nn.log_softmax(jax_logits[:, :-1]).mean() + torch_loss = F.cross_entropy( + torch_logits[:, :-1].reshape(-1, vocab_size), + torch.tensor(targets.reshape(-1), dtype=torch.long), + ) + + logging.info('\nInitial Loss Comparison:') + logging.info(f' JAX: {jax_loss:.4f}') + logging.info(f' PyTorch: {torch_loss.item():.4f}') + logging.info(f' Difference: {abs(jax_loss - torch_loss.item()):.6f}') + + # Check that losses are in reasonable range for random init + # With vocab_size=50000, random init should give loss around log(50000) ≈ 10.82 + expected_loss = np.log(vocab_size) + + assert 8.0 < jax_loss < 13.0, ( + f'JAX loss {jax_loss:.4f} outside expected range [8.0, 13.0]' + ) + assert 8.0 < torch_loss.item() < 13.0, ( + f'PyTorch loss {torch_loss.item():.4f} outside expected range [8.0, 13.0]' + ) + + # Both losses should be within 10% of log(vocab_size) + assert abs(jax_loss - expected_loss) / expected_loss < 0.1, ( + f'JAX loss {jax_loss:.4f} too far from expected {expected_loss:.4f}' + ) + assert abs(torch_loss.item() - expected_loss) / expected_loss < 0.1, ( + f'PyTorch loss {torch_loss.item():.4f} too far from expected {expected_loss:.4f}' + ) + + logging.info( + '\nNote: Losses are in expected range for random initialization.' + ) + logging.info(f' Expected ~log(vocab_size) = {expected_loss:.4f}') + logging.info('\n✓ Initialization impact test passed') + + +# ============================================================================ +# Test Class +# ============================================================================ + +named_parameters = [ + dict(testcase_name='rmsnorm', test_fn=test_rmsnorm), + dict(testcase_name='rope', test_fn=test_rope), + dict(testcase_name='mlp', test_fn=test_mlp), + dict(testcase_name='attention', test_fn=test_attention), + dict(testcase_name='block', test_fn=test_block), + dict(testcase_name='full_model', test_fn=test_full_model), + dict(testcase_name='prediction', test_fn=test_prediction), + dict( + testcase_name='initialization_statistics', + test_fn=test_initialization_statistics, + ), + dict( + testcase_name='initialization_impact', test_fn=test_initialization_impact + ), +] + + +class ModelMatchingTest(parameterized.TestCase): + """Tests for JAX vs PyTorch model matching.""" + + @parameterized.named_parameters(*named_parameters) + def test_model_matching(self, test_fn): + """Run individual model matching test.""" + test_fn() + + +if __name__ == '__main__': + absltest.main()