diff --git a/clu/parameter_overview.py b/clu/parameter_overview.py index 13add0b..17fa473 100644 --- a/clu/parameter_overview.py +++ b/clu/parameter_overview.py @@ -52,12 +52,21 @@ class _ParamRowWithStatsAndSharding(_ParamRowWithStats): sharding: tuple[int | None, ...] | str +def _upcast(x): + """Upcast low-precision floats to float32 for numerically stable stats.""" + if hasattr(x, "dtype") and jnp.issubdtype(x.dtype, jnp.floating) and x.dtype != jnp.float32: + return x.astype(jnp.float32) + return x + + @jax.jit def _mean_std_jit(x): + x = jax.tree_util.tree_map(_upcast, x) return jax.tree_util.tree_map(jnp.mean, x), jax.tree_util.tree_map(jnp.std, x) def _mean_std(x): + x = jax.tree_util.tree_map(_upcast, x) mean = jax.tree_util.tree_map(lambda x: x.mean(), x) std = jax.tree_util.tree_map(lambda x: x.std(), x) return mean, std