From fd56e84236393fa5ba8242e92fbe950ba7b597ce Mon Sep 17 00:00:00 2001 From: Shutong Li Date: Thu, 8 Jan 2026 09:41:15 -0800 Subject: [PATCH] Fix a type incompatibility issues. PiperOrigin-RevId: 853779316 --- ffn/jax/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ffn/jax/train.py b/ffn/jax/train.py index d0df882..13f8126 100644 --- a/ffn/jax/train.py +++ b/ffn/jax/train.py @@ -340,7 +340,7 @@ def _get_tf_writer(writers) -> metric_writers.SummaryWriter | None: def _get_ocp_args( train_iter: DataIterator, restore: bool = True -) -> DataIterator: +) -> DataIterator | ocp.args.CheckpointArgs: if isinstance(train_iter, tf.data.Iterator): return DatasetArgs(train_iter)