From 16a292aec3749f51bf5be772a9e040f718ff4071 Mon Sep 17 00:00:00 2001 From: Vikas Ummadisetty Date: Mon, 23 Mar 2026 21:52:19 -0700 Subject: [PATCH] calc param stat in f32 --- clu/parameter_overview.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/clu/parameter_overview.py b/clu/parameter_overview.py index 13add0b..17fa473 100644 --- a/clu/parameter_overview.py +++ b/clu/parameter_overview.py @@ -52,12 +52,21 @@ class _ParamRowWithStatsAndSharding(_ParamRowWithStats): sharding: tuple[int | None, ...] | str +def _upcast(x): + """Upcast low-precision floats to float32 for numerically stable stats.""" + if hasattr(x, "dtype") and jnp.issubdtype(x.dtype, jnp.floating) and x.dtype != jnp.float32: + return x.astype(jnp.float32) + return x + + @jax.jit def _mean_std_jit(x): + x = jax.tree_util.tree_map(_upcast, x) return jax.tree_util.tree_map(jnp.mean, x), jax.tree_util.tree_map(jnp.std, x) def _mean_std(x): + x = jax.tree_util.tree_map(_upcast, x) mean = jax.tree_util.tree_map(lambda x: x.mean(), x) std = jax.tree_util.tree_map(lambda x: x.std(), x) return mean, std