From 8cf93b4b3a3302688b22bb1ca401f85e7f70fb63 Mon Sep 17 00:00:00 2001 From: Ahmed Khaled Date: Mon, 9 Feb 2026 20:49:48 -0800 Subject: [PATCH] Add training scripts for the diffusion language model workload PiperOrigin-RevId: 867908182 --- init2winit/trainer_lib/base_trainer.py | 14 +- .../trainer_lib/test_mdlm_integration.py | 263 ++++++++++++++++++ init2winit/trainer_lib/trainer.py | 4 + init2winit/trainer_lib/trainer_utils.py | 19 +- 4 files changed, 292 insertions(+), 8 deletions(-) create mode 100644 init2winit/trainer_lib/test_mdlm_integration.py diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py index 248c54c1..456dfe7d 100644 --- a/init2winit/trainer_lib/base_trainer.py +++ b/init2winit/trainer_lib/base_trainer.py @@ -424,7 +424,7 @@ def _check_early_stopping(self, report): self._early_stopping_target_value) return early_stopping_condition - def _eval(self, start_step, start_time, save=True): + def _eval(self, start_step, start_time, eval_rng, save=True): """Evaluate. Has the side-effects of: @@ -437,12 +437,14 @@ def _eval(self, start_step, start_time, save=True): Args: start_step: the training start step. start_time: the training start time. + eval_rng: rng seed used in eval (chiefly for the MDLM workload). save: flag to save a checkpoint to disk. defaults to True. Returns: A Dict[str, Any] eval report, originally created in trainer_utils.eval_metrics. """ + time_since_last_eval = time.time() - self._time_at_prev_eval_end if self._eval_use_ema: @@ -452,6 +454,8 @@ def _eval(self, start_step, start_time, save=True): else: eval_params = self._params + eval_rng = jax.random.fold_in(eval_rng, self._global_step) + report, eval_time = trainer_utils.eval_metrics( eval_params, self._batch_stats, @@ -461,6 +465,7 @@ def _eval(self, start_step, start_time, save=True): self._eval_train_num_batches, self._evaluate_batch_jitted, self.finalize_batch_fn, + eval_rng=eval_rng, ) self._run_eval_callbacks(report) if save: @@ -618,8 +623,7 @@ def train(self): # across hosts. rng, init_rng = jax.random.split(self._rng) rng = jax.random.fold_in(rng, jax.process_index()) - rng, data_rng = jax.random.split(rng) - rng, callback_rng = jax.random.split(rng) + rng, data_rng, callback_rng, eval_rng = jax.random.split(rng, 4) if jax.process_index() == 0: logging.info('Let the training begin!') @@ -705,7 +709,7 @@ def train(self): self._global_step, self._eval_frequency, self._eval_steps ): try: - report = self._eval(start_step, start_time) + report = self._eval(start_step, start_time, eval_rng) except utils.TrainingDivergedError as e: self.wait_until_orbax_checkpointer_finished() raise utils.TrainingDivergedError( @@ -720,7 +724,7 @@ def train(self): # If we moved where in the loop body evals happen then we would not need # this test. if self._prev_eval_step != self._num_train_steps: - report = self._eval(start_step, start_time) + report = self._eval(start_step, start_time, eval_rng) yield report # To make sure the last checkpoint was correctly saved. self.wait_until_orbax_checkpointer_finished() diff --git a/init2winit/trainer_lib/test_mdlm_integration.py b/init2winit/trainer_lib/test_mdlm_integration.py new file mode 100644 index 00000000..a6ec52b2 --- /dev/null +++ b/init2winit/trainer_lib/test_mdlm_integration.py @@ -0,0 +1,263 @@ +# coding=utf-8 +# Copyright 2026 The init2winit Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test for MDLM training with a patterned fake dataset. + +Verifies that the full training loop (model init -> training -> eval) works +end-to-end and that loss decreases on a simple repeating pattern. + +""" + +import os +import shutil +import tempfile + +from absl import logging +from absl.testing import absltest +from init2winit import utils +from init2winit.dataset_lib import data_utils +from init2winit.init_lib import initializers +from init2winit.model_lib import models +from init2winit.trainer_lib import trainer +import jax +import jax.numpy as jnp +from ml_collections.config_dict import config_dict +import numpy as np +import pandas +import tensorflow.compat.v1 as tf + +Dataset = data_utils.Dataset + +# Small vocab and sequence length so the test runs quickly on CPU. +_VOCAB_SIZE = 16 +_SEQ_LEN = 32 +_BATCH_SIZE = 16 +_EVAL_NUM_BATCHES = 10 + + +def _make_patterned_batch(batch_size, vocab_size, seq_len): + """Creates a batch where each row is a cyclic shift of [0, 1, ..., V-1]. + + Row i = [(i % V), (i+1 % V), ..., (i+seq_len-1 % V)]. + This gives the model a simple and learnable pattern. + + Args: + batch_size: Number of sequences in the batch. + vocab_size: Size of the vocabulary. + seq_len: Length of each sequence. + + Returns: + A dict with 'inputs', 'targets', and 'weights'. + """ + rows = [] + for i in range(batch_size): + row = [(i + j) % vocab_size for j in range(seq_len)] + rows.append(row) + tokens = jnp.array(rows, dtype=jnp.int32) + return { + 'inputs': tokens, + 'targets': tokens, # MDLM: inputs == targets. + 'weights': jnp.ones(tokens.shape), + } + + +def _get_patterned_mdlm_dataset(batch_size, eval_num_batches): + """Returns a fake MDLM dataset with a cyclic-shift pattern.""" + + def train_iterator_fn(): + while True: + batch = _make_patterned_batch(batch_size, _VOCAB_SIZE, _SEQ_LEN) + yield batch + + def eval_train_epoch(num_batches=None): + if num_batches is None: + num_batches = eval_num_batches + for _ in range(num_batches): + batch = _make_patterned_batch(batch_size, _VOCAB_SIZE, _SEQ_LEN) + yield batch + + meta_data = { + 'apply_one_hot_in_loss': False, + 'shift_inputs': False, + 'causal': False, + } + return ( + Dataset( + train_iterator_fn, + eval_train_epoch, + eval_train_epoch, + eval_train_epoch, + ), + meta_data, + ) + + +class MDLMIntegrationTest(absltest.TestCase): + """Integration test: train MDLM and verify loss decreases.""" + + def setUp(self): + super().setUp() + self.test_dir = tempfile.mkdtemp() + self.trainer = None + + def tearDown(self): + if self.trainer is not None: + self.trainer.wait_until_orbax_checkpointer_finished() + shutil.rmtree(self.test_dir) + super().tearDown() + + def test_loss_decreases_on_pattern(self): + """MDLM should learn a trivial cyclic pattern and decrease loss.""" + rng = jax.random.PRNGKey(0) + + model_str = 'mdlm_rope_nanodo' + model_cls = models.get_model(model_str) + loss_name = 'passthrough' + metrics_name = 'mdlm_metrics' + + hps = config_dict.ConfigDict({ + 'batch_size': _BATCH_SIZE, + 'emb_dim': 32, + 'num_heads': 2, + 'num_layers': 2, + 'mlp_dim': 64, + 'vocab_size': _VOCAB_SIZE, + 'input_shape': (_SEQ_LEN,), + 'output_shape': (_SEQ_LEN, _VOCAB_SIZE), + 'computation_dtype': 'float32', + 'model_dtype': 'float32', + 'normalization': 'rmsnorm', + 'mlp_activation': 'glu', + 'qk_norm': True, + 'tie_embeddings': True, + 'noise_schedule': 'log_linear', + 'optimizer': 'adam', + 'opt_hparams': { + 'beta1': 0.9, + 'beta2': 0.999, + 'epsilon': 1e-8, + 'weight_decay': 0.0, + }, + 'lr_hparams': { + 'base_lr': 0.003, + 'schedule': 'constant', + }, + 'l2_decay_factor': 0.0, + 'l2_decay_rank_threshold': 2, + 'grad_clip': None, + 'label_smoothing': 0.0, + 'use_shallue_label_smoothing': False, + 'rng_seed': 0, + 'train_size': _BATCH_SIZE * 100, + 'num_device_prefetches': 0, + 'epsilon': 1e-9, + }) + + dataset, dataset_meta_data = _get_patterned_mdlm_dataset( + _BATCH_SIZE, _EVAL_NUM_BATCHES + ) + model = model_cls(hps, dataset_meta_data, loss_name, metrics_name) + initializer = initializers.get_initializer('noop') + + num_train_steps = 1200 + eval_frequency = 200 + + metrics_logger, init_logger = utils.set_up_loggers(self.test_dir) + self.trainer = trainer.Trainer( + train_dir=self.test_dir, + model=model, + dataset_builder=lambda *unused_args, **unused_kwargs: dataset, + initializer=initializer, + num_train_steps=num_train_steps, + hps=hps, + rng=rng, + eval_batch_size=_BATCH_SIZE, + eval_use_ema=False, + eval_num_batches=_EVAL_NUM_BATCHES, + test_num_batches=0, + eval_train_num_batches=_EVAL_NUM_BATCHES, + eval_frequency=eval_frequency, + checkpoint_steps=[], + metrics_logger=metrics_logger, + init_logger=init_logger, + ) + _ = list(self.trainer.train()) + + # ---- Check loss trajectory ---- + with tf.io.gfile.GFile( + os.path.join(self.test_dir, 'measurements.csv') + ) as f: + df = pandas.read_csv(f) + train_cost = df['train_cost'].values + self.assertGreater( + train_cost[0], + train_cost[-1], + msg=( + 'Expected loss to decrease. ' + f'Initial: {train_cost[0]:.4f}, Final: {train_cost[-1]:.4f}' + ), + ) + self.assertLess( + train_cost[-1], + 0.5, + msg=( + 'Expected final loss well below random baseline. ' + f'Final: {train_cost[-1]:.4f}' + ), + ) + + valid_ce = df['valid/ce_loss'].values + valid_ppl = df['valid/perplexity'].values + self.assertTrue( + all(np.isfinite(valid_ce)), + msg=f'valid/ce_loss contains non-finite: {valid_ce}', + ) + self.assertTrue( + all(np.isfinite(valid_ppl)), + msg=f'valid/perplexity contains non-finite: {valid_ppl}', + ) + self.assertLess( + valid_ce[-1], + valid_ce[0], + msg=( + 'Expected valid/ce_loss to decrease. ' + f'Initial: {valid_ce[0]:.4f}, Final: {valid_ce[-1]:.4f}' + ), + ) + self.assertGreater( + valid_ppl[0], + valid_ppl[-1], + msg=( + 'Expected valid/perplexity to decrease. ' + f'Initial: {valid_ppl[0]:.4f}, Final: {valid_ppl[-1]:.4f}' + ), + ) + + # ---- Verify evaluate_batch ---- + params = self.trainer.get_params() + batch = _make_patterned_batch(_BATCH_SIZE, _VOCAB_SIZE, _SEQ_LEN) + batch['eval_rng'] = jax.random.PRNGKey(42) + eval_metrics = model.evaluate_batch(params, batch_stats=None, batch=batch) + eval_results = eval_metrics.compute() + self.assertTrue(np.isfinite(eval_results['ce_loss'])) + self.assertTrue(np.isfinite(eval_results['perplexity'])) + logging.info( + 'Direct evaluate_batch: ce_loss=%.4f, perplexity=%.4f', + eval_results['ce_loss'], + eval_results['perplexity'], + ) + +if __name__ == '__main__': + absltest.main() diff --git a/init2winit/trainer_lib/trainer.py b/init2winit/trainer_lib/trainer.py index 260bbf74..c35c39b1 100644 --- a/init2winit/trainer_lib/trainer.py +++ b/init2winit/trainer_lib/trainer.py @@ -138,3 +138,7 @@ def finalize_batch_fn(self, batch): """Finalize the batch by making a global array out of the shards.""" return trainer_utils.make_finalize_batch_fn(self._mesh)(batch) + + def get_params(self): + """Returns the model parameters.""" + return self._params diff --git a/init2winit/trainer_lib/trainer_utils.py b/init2winit/trainer_lib/trainer_utils.py index 705ae2aa..2a966e17 100644 --- a/init2winit/trainer_lib/trainer_utils.py +++ b/init2winit/trainer_lib/trainer_utils.py @@ -147,6 +147,7 @@ def evaluate( batch_iter, evaluate_batch_jitted, finalize_batch_fn, + eval_rng=None, ): """Compute aggregated metrics on the given data iterator. @@ -173,6 +174,8 @@ def evaluate( finalize_batch_fn: Function to finalize the batch before passing to evaluate_batch_jitted. For sharding or reshaping if necessary. Can be a no-op otherwise. + eval_rng: Optional JAX PRNG key. When provided, a unique sub-key is injected + into each batch as batch['eval_rng']. Returns: A dictionary of aggregated metrics. The keys will match the keys returned by @@ -180,8 +183,10 @@ def evaluate( """ metrics = None - for batch in batch_iter: + for batch_idx, batch in enumerate(batch_iter): batch = finalize_batch_fn(batch) + if eval_rng is not None: + batch['eval_rng'] = jax.random.fold_in(eval_rng, batch_idx) # Returns a clu.metrics.Collection object. We assume that # `evaluate_batch_jitted` calls CLU's `single_from_model_outputs`. computed_metrics = evaluate_batch_jitted( @@ -226,6 +231,7 @@ def eval_metrics( eval_train_num_batches, evaluate_batch_jitted, finalize_batch_fn=None, + eval_rng=None, ): """Evaluates the given network on the train, validation, and test sets. @@ -252,6 +258,8 @@ def eval_metrics( evaluate_batch_jitted: Computes the metrics on a sharded batch. finalize_batch_fn: Function to finalize the batch before passing to evaluate_batch_jitted. For sharding or reshaping. + eval_rng: Optional JAX PRNG key for stochastic evaluation. Used for the + masked diffusion language model. Returns: A dictionary of all computed metrics. @@ -261,15 +269,20 @@ def eval_metrics( test_iter = dataset.test_epoch(test_num_batches) metrics = {} - for split_iter, split_name in zip([train_iter, valid_iter, test_iter], - ['train', 'valid', 'test']): + for split_idx, (split_iter, split_name) in enumerate( + zip([train_iter, valid_iter, test_iter], ['train', 'valid', 'test']) + ): logging.info('Evaluating split: %s', split_name) + split_rng = None + if eval_rng is not None: + split_rng = jax.random.fold_in(eval_rng, split_idx) split_metrics = evaluate( params, batch_stats, split_iter, evaluate_batch_jitted, finalize_batch_fn, + eval_rng=split_rng, ) # Metrics are None if the dataset doesn't have that split if split_metrics is not None: