Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions delphi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ class RunConfig(Serializable):
directory containing their weights. Models must be loadable with sparsify
or gemmascope."""

random: bool = False
"""Whether to initialize the sparse models with random weights."""

hookpoints: list[str] = list_field()
"""list of model hookpoints to attach sparse models to."""

Expand Down
37 changes: 33 additions & 4 deletions delphi/sparse_coders/load_sparsify.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def load_sparsify_sparse_coders(
name: str,
hookpoints: list[str],
device: str | torch.device,
random: bool = False,
compile: bool = False,
) -> dict[str, PotentiallyWrappedSparseCoder]:
"""
Expand All @@ -88,18 +89,45 @@ def load_sparsify_sparse_coders(
name_path = Path(name)
if name_path.exists():
for hookpoint in hookpoints:
sparse_model_dict[hookpoint] = SparseCoder.load_from_disk(
name_path / hookpoint, device=device
sparse_model = SparseCoder.load_from_disk(
name_path / hookpoint, device="cpu"
)
# if random, initialize a new sparse model with random weights
if random:
config = sparse_model.cfg
d_in = sparse_model.d_in
dtype = sparse_model.dtype
sparse_model = SparseCoder(
d_in,
config,
device=device,
dtype=dtype,
decoder=False,
)

sparse_model_dict[hookpoint] = sparse_model
if compile:
sparse_model_dict[hookpoint] = torch.compile(
sparse_model_dict[hookpoint]
)
else:
# Load on CPU first to not run out of memory
sparse_models = SparseCoder.load_many(name, device="cpu")

for hookpoint in hookpoints:
sparse_model_dict[hookpoint] = sparse_models[hookpoint].to(device)
sparse_model = sparse_models[hookpoint]
if random:
config = sparse_model.cfg
d_in = sparse_model.d_in
dtype = sparse_model.dtype
sparse_model = SparseCoder(
d_in,
config,
device=device,
dtype=dtype,
decoder=False,
)
sparse_model_dict[hookpoint] = sparse_model.to(device)
if compile:
sparse_model_dict[hookpoint] = torch.compile(
sparse_model_dict[hookpoint]
Expand All @@ -113,6 +141,7 @@ def load_sparsify_hooks(
model: PreTrainedModel,
name: str,
hookpoints: list[str],
random: bool = False,
device: str | torch.device | None = None,
compile: bool = False,
) -> tuple[dict[str, Callable], bool]:
Expand All @@ -136,6 +165,7 @@ def load_sparsify_hooks(
name,
hookpoints,
device,
random,
compile,
)
hookpoint_to_sparse_encode = {}
Expand All @@ -145,7 +175,6 @@ def load_sparsify_hooks(
path_segments = resolve_path(model, hookpoint.split("."))
if path_segments is None:
raise ValueError(f"Could not find valid path for hookpoint: {hookpoint}")

hookpoint_to_sparse_encode[".".join(path_segments)] = partial(
sae_dense_latents, sae=sparse_model
)
Expand Down
4 changes: 3 additions & 1 deletion delphi/sparse_coders/sparse_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def load_hooks_sparse_coders(
model,
run_cfg.sparse_model,
run_cfg.hookpoints,
random=run_cfg.random,
compile=compile,
)
else:
Expand Down Expand Up @@ -96,7 +97,8 @@ def load_sparse_coders(
run_cfg.sparse_model,
run_cfg.hookpoints,
device,
compile,
random=run_cfg.random,
compile=compile,
)
else:
# model path will always be of the form google/gemma-scope-<size>-pt-<type>/
Expand Down
57 changes: 55 additions & 2 deletions tests/test_autoencoders/test_sparse_coders.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from delphi.config import RunConfig

# Import the function to be tested
from delphi.sparse_coders import load_hooks_sparse_coders
from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders


# A simple dummy run configuration for testing.
class DummyRunConfig:
def __init__(self, sparse_model, hookpoints):
def __init__(self, sparse_model, hookpoints, random=False):
self.sparse_model = sparse_model
self.hookpoints = hookpoints
self.random = random
# Additional required fields can be added here if needed.
self.model = "dummy_model"
self.hf_token = ""
Expand Down Expand Up @@ -62,6 +63,7 @@ def run_cfg_sparsify():
return DummyRunConfig(
sparse_model="EleutherAI/sae-pythia-70m-32k",
hookpoints=["layers.4.mlp", "layers.0"],
random=False,
)


Expand All @@ -75,6 +77,7 @@ def run_cfg_gemma():
"layer_12/width_131k/average_l0_67",
"layer_12/width_16k/average_l0_22",
],
random=False,
)


Expand Down Expand Up @@ -127,3 +130,53 @@ def test_retrieve_autoencoders_from_gemma(dummy_model, run_cfg_gemma):
f"Autoencoder '{key}' from the Gemma branch failed when called:"
f"\n{repr(e)}"
)


def test_load_sparse_coders_forwards_random_and_compile(monkeypatch):
    """Ensure random and compile flags are forwarded for the sparsify path."""
    recorded: dict[str, object] = {}

    def spy_loader(name, hookpoints, device, random=False, compile=False):
        # Record every argument so the forwarding can be asserted afterwards.
        recorded.update(
            name=name,
            hookpoints=hookpoints,
            device=device,
            random=random,
            compile=compile,
        )
        return {"layers.0": object()}

    # Replace the real loader with the spy at the call site used by
    # load_sparse_coders.
    monkeypatch.setattr(
        "delphi.sparse_coders.sparse_model.load_sparsify_sparse_coders",
        spy_loader,
    )

    run_cfg = DummyRunConfig(
        sparse_model="EleutherAI/sae-pythia-70m-32k",
        hookpoints=["layers.0"],
        random=True,
    )

    loaded = load_sparse_coders(run_cfg, device="cpu", compile=True)

    assert isinstance(loaded, dict)
    expected = {
        "name": "EleutherAI/sae-pythia-70m-32k",
        "hookpoints": ["layers.0"],
        "device": "cpu",
        "random": True,
        "compile": True,
    }
    assert recorded == expected


def test_load_sparse_coders_requires_random_field():
    """The run config must explicitly provide a random field."""

    class ConfigWithoutRandom:
        # Deliberately omits the ``random`` attribute so the loader's
        # attribute access fails.
        sparse_model = "EleutherAI/sae-pythia-70m-32k"
        hookpoints = ["layers.0"]

        @property
        def __class__(self) -> type:  # type: ignore
            # Report RunConfig as the class — presumably so any
            # isinstance(cfg, RunConfig) dispatch inside
            # load_sparse_coders takes the sparsify path; confirm
            # against that function's implementation.
            return RunConfig

    with pytest.raises(AttributeError):
        load_sparse_coders(ConfigWithoutRandom(), device="cpu")
Loading