diff --git a/delphi/config.py b/delphi/config.py
index de806157..bb6083f9 100644
--- a/delphi/config.py
+++ b/delphi/config.py
@@ -120,6 +120,9 @@ class RunConfig(Serializable):
     directory containing their weights. Models must be loadable with sparsify
     or gemmascope."""
 
+    random: bool = False
+    """Whether to initialize the sparse models with random weights."""
+
     hookpoints: list[str] = list_field()
     """list of model hookpoints to attach sparse models to."""
 
diff --git a/delphi/sparse_coders/load_sparsify.py b/delphi/sparse_coders/load_sparsify.py
index 19cc698f..cc99f281 100644
--- a/delphi/sparse_coders/load_sparsify.py
+++ b/delphi/sparse_coders/load_sparsify.py
@@ -65,6 +65,7 @@ def load_sparsify_sparse_coders(
     name: str,
     hookpoints: list[str],
     device: str | torch.device,
+    random: bool = False,
     compile: bool = False,
 ) -> dict[str, PotentiallyWrappedSparseCoder]:
     """
@@ -88,9 +89,23 @@ def load_sparsify_sparse_coders(
     name_path = Path(name)
     if name_path.exists():
         for hookpoint in hookpoints:
-            sparse_model_dict[hookpoint] = SparseCoder.load_from_disk(
-                name_path / hookpoint, device=device
+            sparse_model = SparseCoder.load_from_disk(
+                name_path / hookpoint, device="cpu"
             )
+            # if random, initialize a new sparse model with random weights
+            if random:
+                config = sparse_model.cfg
+                d_in = sparse_model.d_in
+                dtype = sparse_model.dtype
+                sparse_model = SparseCoder(
+                    d_in,
+                    config,
+                    device=device,
+                    dtype=dtype,
+                    decoder=False,
+                )
+
+            sparse_model_dict[hookpoint] = sparse_model.to(device)
             if compile:
                 sparse_model_dict[hookpoint] = torch.compile(
                     sparse_model_dict[hookpoint]
@@ -98,8 +113,21 @@ def load_sparsify_sparse_coders(
     else:
         # Load on CPU first to not run out of memory
         sparse_models = SparseCoder.load_many(name, device="cpu")
+
         for hookpoint in hookpoints:
-            sparse_model_dict[hookpoint] = sparse_models[hookpoint].to(device)
+            sparse_model = sparse_models[hookpoint]
+            if random:
+                config = sparse_model.cfg
+                d_in = sparse_model.d_in
+                dtype = sparse_model.dtype
+                sparse_model = SparseCoder(
+                    d_in,
+                    config,
+                    device=device,
+                    dtype=dtype,
+                    decoder=False,
+                )
+            sparse_model_dict[hookpoint] = sparse_model.to(device)
             if compile:
                 sparse_model_dict[hookpoint] = torch.compile(
                     sparse_model_dict[hookpoint]
@@ -113,6 +141,7 @@ def load_sparsify_hooks(
     model: PreTrainedModel,
     name: str,
     hookpoints: list[str],
+    random: bool = False,
     device: str | torch.device | None = None,
     compile: bool = False,
 ) -> tuple[dict[str, Callable], bool]:
@@ -136,6 +165,7 @@ def load_sparsify_hooks(
         name,
         hookpoints,
         device,
+        random,
         compile,
     )
     hookpoint_to_sparse_encode = {}
@@ -145,7 +175,6 @@ def load_sparsify_hooks(
         path_segments = resolve_path(model, hookpoint.split("."))
         if path_segments is None:
             raise ValueError(f"Could not find valid path for hookpoint: {hookpoint}")
-
         hookpoint_to_sparse_encode[".".join(path_segments)] = partial(
             sae_dense_latents, sae=sparse_model
         )
diff --git a/delphi/sparse_coders/sparse_model.py b/delphi/sparse_coders/sparse_model.py
index f8df3700..356cac86 100644
--- a/delphi/sparse_coders/sparse_model.py
+++ b/delphi/sparse_coders/sparse_model.py
@@ -36,6 +36,7 @@ def load_hooks_sparse_coders(
             model,
             run_cfg.sparse_model,
             run_cfg.hookpoints,
+            random=run_cfg.random,
             compile=compile,
         )
     else:
@@ -96,7 +97,8 @@
             run_cfg.sparse_model,
             run_cfg.hookpoints,
             device,
-            compile,
+            random=run_cfg.random,
+            compile=compile,
         )
     else:
         # model path will always be of the form google/gemma-scope--pt-/
diff --git a/tests/test_autoencoders/test_sparse_coders.py b/tests/test_autoencoders/test_sparse_coders.py
index 4eca8b49..9aa162e0 100644
--- a/tests/test_autoencoders/test_sparse_coders.py
+++ b/tests/test_autoencoders/test_sparse_coders.py
@@ -6,14 +6,15 @@
 from delphi.config import RunConfig
 
 # Import the function to be tested
-from delphi.sparse_coders import load_hooks_sparse_coders
+from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders
 
 
 # A simple dummy run configuration for testing.
 class DummyRunConfig:
-    def __init__(self, sparse_model, hookpoints):
+    def __init__(self, sparse_model, hookpoints, random=False):
         self.sparse_model = sparse_model
         self.hookpoints = hookpoints
+        self.random = random
         # Additional required fields can be added here if needed.
         self.model = "dummy_model"
         self.hf_token = ""
@@ -62,6 +63,7 @@ def run_cfg_sparsify():
     return DummyRunConfig(
         sparse_model="EleutherAI/sae-pythia-70m-32k",
         hookpoints=["layers.4.mlp", "layers.0"],
+        random=False,
     )
 
 
@@ -75,6 +77,7 @@ def run_cfg_gemma():
             "layer_12/width_131k/average_l0_67",
             "layer_12/width_16k/average_l0_22",
         ],
+        random=False,
     )
 
 
@@ -127,3 +130,53 @@
             f"Autoencoder '{key}' from the Gemma branch failed when called:"
             f"\n{repr(e)}"
         )
+
+
+def test_load_sparse_coders_forwards_random_and_compile(monkeypatch):
+    """Ensure random and compile flags are forwarded for the sparsify path."""
+    captured: dict[str, object] = {}
+
+    def fake_loader(name, hookpoints, device, random=False, compile=False):
+        captured["name"] = name
+        captured["hookpoints"] = hookpoints
+        captured["device"] = device
+        captured["random"] = random
+        captured["compile"] = compile
+        return {"layers.0": object()}
+
+    monkeypatch.setattr(
+        "delphi.sparse_coders.sparse_model.load_sparsify_sparse_coders",
+        fake_loader,
+    )
+
+    cfg = DummyRunConfig(
+        sparse_model="EleutherAI/sae-pythia-70m-32k",
+        hookpoints=["layers.0"],
+        random=True,
+    )
+
+    result = load_sparse_coders(cfg, device="cpu", compile=True)
+
+    assert isinstance(result, dict)
+    assert captured == {
+        "name": "EleutherAI/sae-pythia-70m-32k",
+        "hookpoints": ["layers.0"],
+        "device": "cpu",
+        "random": True,
+        "compile": True,
+    }
+
+
+def test_load_sparse_coders_requires_random_field():
+    """The run config must explicitly provide a random field."""
+
+    class MissingRandomRunConfig:
+        sparse_model = "EleutherAI/sae-pythia-70m-32k"
+        hookpoints = ["layers.0"]
+
+        @property
+        def __class__(self) -> type:  # type: ignore
+            return RunConfig
+
+    with pytest.raises(AttributeError):
+        load_sparse_coders(MissingRandomRunConfig(), device="cpu")