Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions init2winit/trainer_lib/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def _check_early_stopping(self, report):
self._early_stopping_target_value)
return early_stopping_condition

def _eval(self, start_step, start_time, save=True):
def _eval(self, start_step, start_time, eval_rng, save=True):
"""Evaluate.

Has the side-effects of:
Expand All @@ -437,12 +437,14 @@ def _eval(self, start_step, start_time, save=True):
Args:
start_step: the training start step.
start_time: the training start time.
eval_rng: rng seed used in eval (chiefly for the MDLM workload).
save: flag to save a checkpoint to disk. defaults to True.

Returns:
A Dict[str, Any] eval report, originally created in
trainer_utils.eval_metrics.
"""

time_since_last_eval = time.time() - self._time_at_prev_eval_end

if self._eval_use_ema:
Expand All @@ -452,6 +454,8 @@ def _eval(self, start_step, start_time, save=True):
else:
eval_params = self._params

eval_rng = jax.random.fold_in(eval_rng, self._global_step)

report, eval_time = trainer_utils.eval_metrics(
eval_params,
self._batch_stats,
Expand All @@ -461,6 +465,7 @@ def _eval(self, start_step, start_time, save=True):
self._eval_train_num_batches,
self._evaluate_batch_jitted,
self.finalize_batch_fn,
eval_rng=eval_rng,
)
self._run_eval_callbacks(report)
if save:
Expand Down Expand Up @@ -618,8 +623,7 @@ def train(self):
# across hosts.
rng, init_rng = jax.random.split(self._rng)
rng = jax.random.fold_in(rng, jax.process_index())
rng, data_rng = jax.random.split(rng)
rng, callback_rng = jax.random.split(rng)
rng, data_rng, callback_rng, eval_rng = jax.random.split(rng, 4)

if jax.process_index() == 0:
logging.info('Let the training begin!')
Expand Down Expand Up @@ -705,7 +709,7 @@ def train(self):
self._global_step, self._eval_frequency, self._eval_steps
):
try:
report = self._eval(start_step, start_time)
report = self._eval(start_step, start_time, eval_rng)
except utils.TrainingDivergedError as e:
self.wait_until_orbax_checkpointer_finished()
raise utils.TrainingDivergedError(
Expand All @@ -720,7 +724,7 @@ def train(self):
# If we moved where in the loop body evals happen then we would not need
# this test.
if self._prev_eval_step != self._num_train_steps:
report = self._eval(start_step, start_time)
report = self._eval(start_step, start_time, eval_rng)
yield report
# To make sure the last checkpoint was correctly saved.
self.wait_until_orbax_checkpointer_finished()
Expand Down
263 changes: 263 additions & 0 deletions init2winit/trainer_lib/test_mdlm_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
# coding=utf-8
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Integration test for MDLM training with a patterned fake dataset.

Verifies that the full training loop (model init -> training -> eval) works
end-to-end and that loss decreases on a simple repeating pattern.

"""

import os
import shutil
import tempfile

from absl import logging
from absl.testing import absltest
from init2winit import utils
from init2winit.dataset_lib import data_utils
from init2winit.init_lib import initializers
from init2winit.model_lib import models
from init2winit.trainer_lib import trainer
import jax
import jax.numpy as jnp
from ml_collections.config_dict import config_dict
import numpy as np
import pandas
import tensorflow.compat.v1 as tf

Dataset = data_utils.Dataset

# Small vocab and sequence length so the test runs quickly on CPU.
_VOCAB_SIZE = 16
_SEQ_LEN = 32
_BATCH_SIZE = 16
_EVAL_NUM_BATCHES = 10


def _make_patterned_batch(batch_size, vocab_size, seq_len):
"""Creates a batch where each row is a cyclic shift of [0, 1, ..., V-1].

Row i = [(i % V), (i+1 % V), ..., (i+seq_len-1 % V)].
This gives the model a simple and learnable pattern.

Args:
batch_size: Number of sequences in the batch.
vocab_size: Size of the vocabulary.
seq_len: Length of each sequence.

Returns:
A dict with 'inputs', 'targets', and 'weights'.
"""
rows = []
for i in range(batch_size):
row = [(i + j) % vocab_size for j in range(seq_len)]
rows.append(row)
tokens = jnp.array(rows, dtype=jnp.int32)
return {
'inputs': tokens,
'targets': tokens, # MDLM: inputs == targets.
'weights': jnp.ones(tokens.shape),
}


def _get_patterned_mdlm_dataset(batch_size, eval_num_batches):
  """Returns a fake MDLM dataset with a cyclic-shift pattern."""

  def train_iterator_fn():
    # Infinite stream: every training batch carries the same pattern.
    while True:
      yield _make_patterned_batch(batch_size, _VOCAB_SIZE, _SEQ_LEN)

  def eval_train_epoch(num_batches=None):
    # Finite epoch; defaults to the configured eval batch count.
    count = eval_num_batches if num_batches is None else num_batches
    for _ in range(count):
      yield _make_patterned_batch(batch_size, _VOCAB_SIZE, _SEQ_LEN)

  # The same finite generator serves eval-train, validation, and test splits.
  dataset = Dataset(
      train_iterator_fn,
      eval_train_epoch,
      eval_train_epoch,
      eval_train_epoch,
  )
  meta_data = {
      'apply_one_hot_in_loss': False,
      'shift_inputs': False,
      'causal': False,
  }
  return dataset, meta_data


class MDLMIntegrationTest(absltest.TestCase):
  """Integration test: train MDLM and verify loss decreases."""

  def setUp(self):
    super().setUp()
    # Fresh scratch directory for checkpoints/metrics; removed in tearDown.
    self.test_dir = tempfile.mkdtemp()
    # Set by the test body so tearDown can flush pending checkpoint writes.
    self.trainer = None

  def tearDown(self):
    # Let any async Orbax checkpoint writes finish before deleting their dir.
    if self.trainer is not None:
      self.trainer.wait_until_orbax_checkpointer_finished()
    shutil.rmtree(self.test_dir)
    super().tearDown()

  def test_loss_decreases_on_pattern(self):
    """MDLM should learn a trivial cyclic pattern and decrease loss."""
    rng = jax.random.PRNGKey(0)

    model_str = 'mdlm_rope_nanodo'
    model_cls = models.get_model(model_str)
    # 'passthrough' loss: the model computes its own loss internally.
    loss_name = 'passthrough'
    metrics_name = 'mdlm_metrics'

    # Deliberately tiny model so the test runs quickly on CPU.
    hps = config_dict.ConfigDict({
        'batch_size': _BATCH_SIZE,
        'emb_dim': 32,
        'num_heads': 2,
        'num_layers': 2,
        'mlp_dim': 64,
        'vocab_size': _VOCAB_SIZE,
        'input_shape': (_SEQ_LEN,),
        'output_shape': (_SEQ_LEN, _VOCAB_SIZE),
        'computation_dtype': 'float32',
        'model_dtype': 'float32',
        'normalization': 'rmsnorm',
        'mlp_activation': 'glu',
        'qk_norm': True,
        'tie_embeddings': True,
        'noise_schedule': 'log_linear',
        'optimizer': 'adam',
        'opt_hparams': {
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': 1e-8,
            'weight_decay': 0.0,
        },
        'lr_hparams': {
            'base_lr': 0.003,
            'schedule': 'constant',
        },
        'l2_decay_factor': 0.0,
        'l2_decay_rank_threshold': 2,
        'grad_clip': None,
        'label_smoothing': 0.0,
        'use_shallue_label_smoothing': False,
        'rng_seed': 0,
        'train_size': _BATCH_SIZE * 100,
        'num_device_prefetches': 0,
        'epsilon': 1e-9,
    })

    dataset, dataset_meta_data = _get_patterned_mdlm_dataset(
        _BATCH_SIZE, _EVAL_NUM_BATCHES
    )
    model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)
    # 'noop' initializer: use the model's own parameter initialization.
    initializer = initializers.get_initializer('noop')

    num_train_steps = 1200
    eval_frequency = 200

    metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
    self.trainer = trainer.Trainer(
        train_dir=self.test_dir,
        model=model,
        # The trainer calls the builder with its own args; ignore them and
        # always hand back the pre-built fake dataset.
        dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
        initializer=initializer,
        num_train_steps=num_train_steps,
        hps=hps,
        rng=rng,
        eval_batch_size=_BATCH_SIZE,
        eval_use_ema=False,
        eval_num_batches=_EVAL_NUM_BATCHES,
        test_num_batches=0,
        eval_train_num_batches=_EVAL_NUM_BATCHES,
        eval_frequency=eval_frequency,
        checkpoint_steps=[],
        metrics_logger=metrics_logger,
        init_logger=init_logger,
    )
    # train() is a generator of eval reports; drain it to run to completion.
    _ = list(self.trainer.train())

    # ---- Check loss trajectory ----
    # Metrics are written by the logger as CSV rows, one per eval.
    with tf.io.gfile.GFile(
        os.path.join(self.test_dir, 'measurements.csv')
    ) as f:
      df = pandas.read_csv(f)
    train_cost = df['train_cost'].values
    self.assertGreater(
        train_cost[0],
        train_cost[-1],
        msg=(
            'Expected loss to decrease. '
            f'Initial: {train_cost[0]:.4f}, Final: {train_cost[-1]:.4f}'
        ),
    )
    # 0.5 is well below ln(vocab) ~ 2.77, the uniform-guess baseline.
    self.assertLess(
        train_cost[-1],
        0.5,
        msg=(
            'Expected final loss well below random baseline. '
            f'Final: {train_cost[-1]:.4f}'
        ),
    )

    valid_ce = df['valid/ce_loss'].values
    valid_ppl = df['valid/perplexity'].values
    # Finiteness first: NaN/inf would make the trend checks meaningless.
    self.assertTrue(
        all(np.isfinite(valid_ce)),
        msg=f'valid/ce_loss contains non-finite: {valid_ce}',
    )
    self.assertTrue(
        all(np.isfinite(valid_ppl)),
        msg=f'valid/perplexity contains non-finite: {valid_ppl}',
    )
    self.assertLess(
        valid_ce[-1],
        valid_ce[0],
        msg=(
            'Expected valid/ce_loss to decrease. '
            f'Initial: {valid_ce[0]:.4f}, Final: {valid_ce[-1]:.4f}'
        ),
    )
    self.assertGreater(
        valid_ppl[0],
        valid_ppl[-1],
        msg=(
            'Expected valid/perplexity to decrease. '
            f'Initial: {valid_ppl[0]:.4f}, Final: {valid_ppl[-1]:.4f}'
        ),
    )

    # ---- Verify evaluate_batch ----
    # Call evaluate_batch directly with the trained params to confirm the
    # eval path works outside the trainer loop as well.
    params = self.trainer.get_params()
    batch = _make_patterned_batch(_BATCH_SIZE, _VOCAB_SIZE, _SEQ_LEN)
    # evaluate_batch reads its rng from the batch dict (MDLM masking noise).
    batch['eval_rng'] = jax.random.PRNGKey(42)
    eval_metrics = model.evaluate_batch(params, batch_stats=None, batch=batch)
    eval_results = eval_metrics.compute()
    self.assertTrue(np.isfinite(eval_results['ce_loss']))
    self.assertTrue(np.isfinite(eval_results['perplexity']))
    logging.info(
        'Direct evaluate_batch: ce_loss=%.4f, perplexity=%.4f',
        eval_results['ce_loss'],
        eval_results['perplexity'],
    )


if __name__ == '__main__':
  absltest.main()
4 changes: 4 additions & 0 deletions init2winit/trainer_lib/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,7 @@ def finalize_batch_fn(self, batch):
"""Finalize the batch by making a global array out of the shards."""

return trainer_utils.make_finalize_batch_fn(self._mesh)(batch)

def get_params(self):
  """Return the trainer's current model parameter pytree."""
  current_params = self._params
  return current_params
Loading
Loading