Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
3091ac5
Factored processes -> Dev (#131)
casperlchristensen Dec 9, 2025
2dd3e8c
Fix resolve base config bug (#134)
ealt Dec 10, 2025
21efdbb
Update github workflows for dev branch (#133)
ealt Dec 11, 2025
9514ce5
Activation visualizations -> dev (#132)
casperlchristensen Dec 16, 2025
e6c2a6c
Add subspace orthogonality analysis for factored processes (#136)
loren-ac Dec 16, 2025
44f53ee
Expose ability to compute subspace orthogonality in LinearRegressionA…
ealt Dec 17, 2025
6681393
Add CONTRIBUTING.md with PR requirements for dev and main (#138)
loren-ac Dec 17, 2025
7a62e72
Fix/dropdown slider interaction (#143)
casperlchristensen Dec 17, 2025
919adee
Automatically save log files at the end of managed runs (#142)
ealt Dec 17, 2025
0b1b6a2
Add simplexity-multirun CLI for parallel experiment execution (#144)
adamimos Dec 19, 2025
67ff8b0
save more path-specific visualizations (#145)
casperlchristensen Dec 19, 2025
9549554
Casper/generic resolution (#161)
casperlchristensen Jan 7, 2026
861bfbb
Improve metric naming for length and readability (#153)
loren-ac Jan 7, 2026
fb28491
reduce number of metrics returned from variance analysis (#162)
casperlchristensen Jan 7, 2026
4050f50
return targets (#163)
casperlchristensen Jan 7, 2026
1c06eaa
Extend format_layer_spec to handle all TransformerLens layer patterns…
loren-ac Jan 13, 2026
83a9528
option for no deduplication (#166)
casperlchristensen Jan 13, 2026
d4f6021
Add IndependentFactoredGenerativeProcess for frozen factor support (#…
loren-ac Jan 15, 2026
c982f44
noises process option (#165)
casperlchristensen Jan 16, 2026
c0ed354
get rid of visualization
casperlchristensen Feb 24, 2026
e4b540e
get rid of visualization (#171)
casperlchristensen Feb 24, 2026
57f0646
Merge branch 'dev' of https://github.com/Astera-org/simplexity into dev
casperlchristensen Feb 24, 2026
71e5d78
PR feedback
casperlchristensen Feb 24, 2026
cf93513
more test coverage
casperlchristensen Feb 24, 2026
a4a8b20
test docstrings
casperlchristensen Feb 24, 2026
e78a015
ensure pytorch tests are included
casperlchristensen Feb 24, 2026
1678ab9
coverage again
casperlchristensen Feb 24, 2026
559afbd
undo
casperlchristensen Feb 24, 2026
eccb0cd
full coverage information for main
casperlchristensen Feb 24, 2026
2b74891
simplify svd validation
casperlchristensen Feb 25, 2026
e035b45
Merge branch 'main' into dev
casperlchristensen Feb 25, 2026
0a1c8fe
use only keyword
casperlchristensen Feb 25, 2026
38a3bc5
Merge branch 'dev' of https://github.com/Astera-org/simplexity into dev
casperlchristensen Feb 25, 2026
a61861c
simplify
casperlchristensen Feb 25, 2026
ae5eca6
consolidation
casperlchristensen Feb 25, 2026
1b26b3c
remove
casperlchristensen Feb 25, 2026
acfbffd
docstring simplification
casperlchristensen Feb 25, 2026
48683dd
math simplification
casperlchristensen Feb 25, 2026
ea21872
harden fully-conditional PoC normalization and indexing
ealt Feb 26, 2026
126f536
Merge remote-tracking branch 'origin/dev' into dev
ealt Feb 26, 2026
4f80885
Remove unused code
ealt Feb 27, 2026
c887e30
Merge remote-tracking branch 'origin/main' into eric/dev
ealt Mar 4, 2026
bf5c28e
Remove unnecessary reportUnnecessaryEllipsis setting from pyproject.toml
ealt Mar 4, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/simplexity.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,4 @@ jobs:
slug: Astera-org/simplexity
verbose: true
files: ./coverage.xml
fail_ci_if_error: false
fail_ci_if_error: false
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ ignored-modules = ["simplexity"]

[tool.pyright]
typeCheckingMode = "standard"
reportUnnecessaryEllipsis = false

[tool.pytest.ini_options]
testpaths = ["tests"]
Expand Down
28 changes: 27 additions & 1 deletion simplexity/generative_processes/structures/fully_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from __future__ import annotations

from typing import Literal

import equinox as eqx
import jax
import jax.numpy as jnp
Expand All @@ -29,6 +31,8 @@ class FullyConditional(eqx.Module):
perms_py: Axis permutations to align conditional distributions
vocab_sizes_py: Python int tuple of vocab sizes for shape operations
joint_vocab_size: Total vocabulary size (product of all V_i)
fallback_strategy: Strategy used when unnormalized mass is zero
fallback_epsilon: Additive smoothing constant used by the "epsilon_smooth" fallback strategy
"""

control_maps: tuple[jax.Array, ...]
Expand All @@ -37,11 +41,15 @@ class FullyConditional(eqx.Module):
perms_py: tuple[tuple[int, ...], ...]
vocab_sizes_py: tuple[int, ...]
joint_vocab_size: int
fallback_strategy: Literal["uniform", "epsilon_smooth"]
fallback_epsilon: float

def __init__(
self,
control_maps: tuple[jax.Array, ...],
vocab_sizes: jax.Array,
fallback_strategy: Literal["uniform", "epsilon_smooth"] = "uniform",
fallback_epsilon: float = 1e-12,
):
"""Initialize fully conditional structure.

Expand All @@ -50,9 +58,15 @@ def __init__(
have shape [prod(V_j for j!=i)] mapping other-factor tokens
to variant index for factor i.
vocab_sizes: Array of shape [F] with vocab sizes per factor
fallback_strategy: How to recover when total unnormalized mass is zero.
- "uniform": use uniform distribution over joint vocabulary
- "epsilon_smooth": add fallback_epsilon to every unnormalized joint probability and renormalize
fallback_epsilon: Additive smoothing value for "epsilon_smooth"; must be positive when that strategy is selected
"""
self.control_maps = tuple(jnp.asarray(cm, dtype=jnp.int32) for cm in control_maps)
self.vocab_sizes_py = tuple(int(v) for v in vocab_sizes)
self.fallback_strategy = fallback_strategy
self.fallback_epsilon = float(fallback_epsilon)
num_factors = len(self.vocab_sizes_py)

if num_factors == 0:
Expand All @@ -61,6 +75,12 @@ def __init__(
raise ValueError(f"Expected {num_factors} control maps (one per factor), got {len(self.control_maps)}")
if any(v <= 0 for v in self.vocab_sizes_py):
raise ValueError(f"All vocab sizes must be positive, got {self.vocab_sizes_py}")
if self.fallback_strategy not in ("uniform", "epsilon_smooth"):
raise ValueError(
f"fallback_strategy must be one of {{'uniform', 'epsilon_smooth'}}, got '{self.fallback_strategy}'"
)
if self.fallback_strategy == "epsilon_smooth" and self.fallback_epsilon <= 0.0:
raise ValueError(f"fallback_epsilon must be positive for epsilon smoothing, got {self.fallback_epsilon}")

# Compute joint vocab size
jv = 1
Expand Down Expand Up @@ -158,7 +178,13 @@ def get_dist_i(k: jax.Array, i: int = i) -> jax.Array:
for log_p in parts[1:]:
log_joint = log_joint + log_p
log_z = jax.nn.logsumexp(log_joint)
fallback = jnp.ones_like(log_joint) / self.joint_vocab_size

if self.fallback_strategy == "uniform":
fallback = jnp.ones_like(log_joint) / self.joint_vocab_size
else:
unnormalized = jnp.exp(log_joint)
smoothed = unnormalized + self.fallback_epsilon
fallback = smoothed / jnp.sum(smoothed)

norm_j = jax.lax.cond(
jnp.isfinite(log_z),
Expand Down
Loading