diff --git a/.github/workflows/simplexity.yaml b/.github/workflows/simplexity.yaml index 386d3e8b..044b99b7 100644 --- a/.github/workflows/simplexity.yaml +++ b/.github/workflows/simplexity.yaml @@ -121,4 +121,4 @@ jobs: slug: Astera-org/simplexity verbose: true files: ./coverage.xml - fail_ci_if_error: false + fail_ci_if_error: false \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 3f47c89c..4d4f56f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,7 +110,6 @@ ignored-modules = ["simplexity"] [tool.pyright] typeCheckingMode = "standard" -reportUnnecessaryEllipsis = false [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/simplexity/generative_processes/structures/fully_conditional.py b/simplexity/generative_processes/structures/fully_conditional.py index ece4f92e..3f8fbb8c 100644 --- a/simplexity/generative_processes/structures/fully_conditional.py +++ b/simplexity/generative_processes/structures/fully_conditional.py @@ -6,6 +6,8 @@ from __future__ import annotations +from typing import Literal + import equinox as eqx import jax import jax.numpy as jnp @@ -29,6 +31,8 @@ class FullyConditional(eqx.Module): perms_py: Axis permutations to align conditional distributions vocab_sizes_py: Python int tuple of vocab sizes for shape operations joint_vocab_size: Total vocabulary size (product of all V_i) + fallback_strategy: Strategy used when unnormalized mass is zero + fallback_epsilon: Additive smoothing for epsilon fallback """ control_maps: tuple[jax.Array, ...] @@ -37,11 +41,15 @@ class FullyConditional(eqx.Module): perms_py: tuple[tuple[int, ...], ...] vocab_sizes_py: tuple[int, ...] joint_vocab_size: int + fallback_strategy: Literal["uniform", "epsilon_smooth"] + fallback_epsilon: float def __init__( self, control_maps: tuple[jax.Array, ...], vocab_sizes: jax.Array, + fallback_strategy: Literal["uniform", "epsilon_smooth"] = "uniform", + fallback_epsilon: float = 1e-12, ): """Initialize fully conditional structure. @@ -50,9 +58,15 @@ def __init__( have shape [prod(V_j for j!=i)] mapping other-factor tokens to variant index for factor i. vocab_sizes: Array of shape [F] with vocab sizes per factor + fallback_strategy: How to recover when total unnormalized mass is zero. + - "uniform": use uniform distribution over joint vocabulary + - "epsilon_smooth": add epsilon and renormalize + fallback_epsilon: Additive smoothing value for "epsilon_smooth" """ self.control_maps = tuple(jnp.asarray(cm, dtype=jnp.int32) for cm in control_maps) self.vocab_sizes_py = tuple(int(v) for v in vocab_sizes) + self.fallback_strategy = fallback_strategy + self.fallback_epsilon = float(fallback_epsilon) num_factors = len(self.vocab_sizes_py) if num_factors == 0: @@ -61,6 +75,12 @@ def __init__( raise ValueError(f"Expected {num_factors} control maps (one per factor), got {len(self.control_maps)}") if any(v <= 0 for v in self.vocab_sizes_py): raise ValueError(f"All vocab sizes must be positive, got {self.vocab_sizes_py}") + if self.fallback_strategy not in ("uniform", "epsilon_smooth"): + raise ValueError( + f"fallback_strategy must be one of {{'uniform', 'epsilon_smooth'}}, got '{self.fallback_strategy}'" + ) + if self.fallback_strategy == "epsilon_smooth" and self.fallback_epsilon <= 0.0: + raise ValueError(f"fallback_epsilon must be positive for epsilon smoothing, got {self.fallback_epsilon}") # Compute joint vocab size jv = 1 @@ -158,7 +178,13 @@ def get_dist_i(k: jax.Array, i: int = i) -> jax.Array: for log_p in parts[1:]: log_joint = log_joint + log_p log_z = jax.nn.logsumexp(log_joint) - fallback = jnp.ones_like(log_joint) / self.joint_vocab_size + + if self.fallback_strategy == "uniform": + fallback = jnp.ones_like(log_joint) / self.joint_vocab_size + else: + unnormalized = jnp.exp(log_joint) + smoothed = unnormalized + self.fallback_epsilon + fallback = smoothed / jnp.sum(smoothed) norm_j = jax.lax.cond( jnp.isfinite(log_z),