From 6b305a0df3cbcaecd68bffa19bbf015dc2f18f2c Mon Sep 17 00:00:00 2001 From: Evelyn Yen Date: Fri, 20 Feb 2026 11:25:47 -0500 Subject: [PATCH 1/4] add recompute_ln=True --- tests/acceptance/test_activation_cache.py | 4 +-- transformer_lens/ActivationCache.py | 38 ++++++++++++++++++++++- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/tests/acceptance/test_activation_cache.py b/tests/acceptance/test_activation_cache.py index cfb4d39af..bfad22086 100644 --- a/tests/acceptance/test_activation_cache.py +++ b/tests/acceptance/test_activation_cache.py @@ -242,10 +242,10 @@ def test_accumulated_resid_with_apply_ln(): # Run the model and cache all activations _, cache = model.run_with_cache(tokens) - # Get accumulated resid and apply ln seperately (cribbed notebook code) + # Get accumulated resid and apply ln with recompute_ln (standard logit lens) accumulated_residual = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1) ref_scaled_residual_stack = cache.apply_ln_to_stack( - accumulated_residual, layer=-1, pos_slice=-1 + accumulated_residual, layer=-1, pos_slice=-1, recompute_ln=True ) # Get scaled_residual_stack using apply_ln parameter diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index f76549ebe..4aba4ddbc 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -443,8 +443,14 @@ def accumulated_resid( components_list = [pos_slice.apply(c, dim=-2) for c in components_list] components = torch.stack(components_list, dim=0) if apply_ln: + # Use recomputed layer norm per component for standard logit lens + recompute_ln = layer == self.model.cfg.n_layers components = self.apply_ln_to_stack( - components, layer, pos_slice=pos_slice, mlp_input=mlp_input + components, + layer, + pos_slice=pos_slice, + mlp_input=mlp_input, + recompute_ln=recompute_ln, ) if return_labels: return components, labels @@ -952,6 +958,7 @@ def apply_ln_to_stack( pos_slice: Union[Slice, SliceInput] = None, batch_slice: Union[Slice, SliceInput] = None, has_batch_dim: bool = True, + recompute_ln: bool = False, ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]: """Apply Layer Norm to a Stack. @@ -964,6 +971,11 @@ def apply_ln_to_stack( element and position, which is why we need to use the cached scale factors rather than just applying a new LayerNorm. + When ``recompute_ln=True`` and the target layer is the final layer (unembed), each + component is normalized using statistics recomputed from that component (standard logit + lens). This gives correct intermediate "belief" distributions; use this for logit lens + analysis. When ``recompute_ln=False``, a single cached scale is used for all components. + If the model does not use LayerNorm or RMSNorm, it returns the residual stack unchanged. Args: @@ -986,6 +998,10 @@ def apply_ln_to_stack( The slice to take on the batch dimension. Defaults to None, do nothing. has_batch_dim: Whether residual_stack has a batch dimension. + recompute_ln: + If True and target layer is the unembed (final layer), apply the final layer norm + to each component with statistics recomputed from that component (standard logit + lens). Defaults to False (use single cached scale for all components). """ if self.model.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]: @@ -1004,6 +1020,26 @@ def apply_ln_to_stack( # Apply batch slice to the stack residual_stack = batch_slice.apply(residual_stack, dim=1) + # Logit lens: apply final layer norm to each component with recomputed statistics + if ( + recompute_ln + and layer == self.model.cfg.n_layers + and hasattr(self.model, "ln_final") + ): + ln_final = self.model.ln_final + had_pos_dim = residual_stack.ndim == 4 + results = [] + for i in range(residual_stack.shape[0]): + x = residual_stack[i] + # ln_final expects (batch, pos, d_model); ensure pos dim present + if x.ndim == 2: + x = x.unsqueeze(1) + out = ln_final(x) + if not had_pos_dim: + out = out.squeeze(1) + results.append(out) + return torch.stack(results, dim=0) + # Center the stack onlny if the model uses LayerNorm if self.model.cfg.normalization_type in ["LN", "LNPre"]: residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True) From a81b245e26bc9d0c2166e8ee288e2ec9f869c1b5 Mon Sep 17 00:00:00 2001 From: Evelyn Yen Date: Fri, 20 Feb 2026 11:35:12 -0500 Subject: [PATCH 2/4] add test --- tests/acceptance/test_activation_cache.py | 22 +++++++++++++++++++++- transformer_lens/ActivationCache.py | 11 ++++------- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/tests/acceptance/test_activation_cache.py b/tests/acceptance/test_activation_cache.py index bfad22086..d8ba7630b 100644 --- a/tests/acceptance/test_activation_cache.py +++ b/tests/acceptance/test_activation_cache.py @@ -242,7 +242,7 @@ def test_accumulated_resid_with_apply_ln(): # Run the model and cache all activations _, cache = model.run_with_cache(tokens) - # Get accumulated resid and apply ln with recompute_ln (standard logit lens) + # Get accumulated resid and apply ln seperately (cribbed notebook code) accumulated_residual = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1) ref_scaled_residual_stack = cache.apply_ln_to_stack( accumulated_residual, layer=-1, pos_slice=-1, recompute_ln=True @@ -271,6 +271,26 @@ def test_accumulated_resid_with_apply_ln(): assert labels == expected_labels +@torch.no_grad +def test_apply_ln_recompute_ln_differs_from_cached(): + model = load_model("solu-2l") + tokens, _ = get_ioi_tokens_and_answer_tokens(model) + _, cache = model.run_with_cache(tokens) + + accumulated = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1) + with_recompute = cache.apply_ln_to_stack( + accumulated, layer=-1, pos_slice=-1, recompute_ln=True + ) + with_cached = cache.apply_ln_to_stack( + accumulated, layer=-1, pos_slice=-1, recompute_ln=False + ) + + assert with_recompute.shape == with_cached.shape + assert not torch.isclose(with_recompute, with_cached, atol=1e-7).all(), ( + "recompute_ln=True and recompute_ln=False should differ for accumulated residual stack" + ) + + @torch.no_grad def test_decompose_resid_with_apply_ln(): # Load solu-2l diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index 4aba4ddbc..e76c8b331 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -443,7 +443,6 @@ def accumulated_resid( components_list = [pos_slice.apply(c, dim=-2) for c in components_list] components = torch.stack(components_list, dim=0) if apply_ln: - # Use recomputed layer norm per component for standard logit lens recompute_ln = layer == self.model.cfg.n_layers components = self.apply_ln_to_stack( components, @@ -971,10 +970,9 @@ def apply_ln_to_stack( element and position, which is why we need to use the cached scale factors rather than just applying a new LayerNorm. - When ``recompute_ln=True`` and the target layer is the final layer (unembed), each - component is normalized using statistics recomputed from that component (standard logit - lens). This gives correct intermediate "belief" distributions; use this for logit lens - analysis. When ``recompute_ln=False``, a single cached scale is used for all components. + When recompute_ln=True and the target layer is the final layer (unembed), each + component is normalized using stats recomputed from that component; use this for logit lens + analysis. When recompute_ln=False, a single cached scale is used for all components. If the model does not use LayerNorm or RMSNorm, it returns the residual stack unchanged. @@ -1000,8 +998,7 @@ def apply_ln_to_stack( Whether residual_stack has a batch dimension. recompute_ln: If True and target layer is the unembed (final layer), apply the final layer norm - to each component with statistics recomputed from that component (standard logit - lens). Defaults to False (use single cached scale for all components). + to each component with statistics recomputed from that component. Defaults to False. """ if self.model.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]: From 0c41a543b5fcbff6a660b46286da5cd83afc0e42 Mon Sep 17 00:00:00 2001 From: Evelyn Yen Date: Fri, 20 Feb 2026 13:15:01 -0500 Subject: [PATCH 3/4] fix formatting --- tests/acceptance/test_activation_cache.py | 14 +++++--------- transformer_lens/ActivationCache.py | 12 +++--------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/tests/acceptance/test_activation_cache.py b/tests/acceptance/test_activation_cache.py index d8ba7630b..9af5a96db 100644 --- a/tests/acceptance/test_activation_cache.py +++ b/tests/acceptance/test_activation_cache.py @@ -278,17 +278,13 @@ def test_apply_ln_recompute_ln_differs_from_cached(): _, cache = model.run_with_cache(tokens) accumulated = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1) - with_recompute = cache.apply_ln_to_stack( - accumulated, layer=-1, pos_slice=-1, recompute_ln=True - ) - with_cached = cache.apply_ln_to_stack( - accumulated, layer=-1, pos_slice=-1, recompute_ln=False - ) + with_recompute = cache.apply_ln_to_stack(accumulated, layer=-1, pos_slice=-1, recompute_ln=True) + with_cached = cache.apply_ln_to_stack(accumulated, layer=-1, pos_slice=-1, recompute_ln=False) assert with_recompute.shape == with_cached.shape - assert not torch.isclose(with_recompute, with_cached, atol=1e-7).all(), ( - "recompute_ln=True and recompute_ln=False should differ for accumulated residual stack" - ) + assert not torch.isclose( + with_recompute, with_cached, atol=1e-7 + ).all(), "recompute_ln=True and recompute_ln=False should differ for accumulated residual stack" @torch.no_grad diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index e76c8b331..b880ccd72 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -553,11 +553,9 @@ def logit_attrs( incorrect_tokens_for_shape_check = torch.as_tensor(incorrect_tokens_for_shape_check) if tokens_for_shape_check.shape != incorrect_tokens_for_shape_check.shape: - raise ValueError( - f"tokens and incorrect_tokens must have the same shape! \ + raise ValueError(f"tokens and incorrect_tokens must have the same shape! \ (tokens.shape={tokens_for_shape_check.shape}, \ - incorrect_tokens.shape={incorrect_tokens_for_shape_check.shape})" - ) + incorrect_tokens.shape={incorrect_tokens_for_shape_check.shape})") # If incorrect_tokens was provided, take the logit difference logit_directions = logit_directions - self.model.tokens_to_residual_directions( @@ -1018,11 +1016,7 @@ def apply_ln_to_stack( residual_stack = batch_slice.apply(residual_stack, dim=1) # Logit lens: apply final layer norm to each component with recomputed statistics - if ( - recompute_ln - and layer == self.model.cfg.n_layers - and hasattr(self.model, "ln_final") - ): + if recompute_ln and layer == self.model.cfg.n_layers and hasattr(self.model, "ln_final"): ln_final = self.model.ln_final had_pos_dim = residual_stack.ndim == 4 results = [] From 2e7ce0500575829b99c3335514d7b059495fb095 Mon Sep 17 00:00:00 2001 From: Evelyn Yen Date: Fri, 20 Feb 2026 14:45:50 -0500 Subject: [PATCH 4/4] change formatting --- transformer_lens/ActivationCache.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index b880ccd72..b62282a52 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -553,9 +553,11 @@ def logit_attrs( incorrect_tokens_for_shape_check = torch.as_tensor(incorrect_tokens_for_shape_check) if tokens_for_shape_check.shape != incorrect_tokens_for_shape_check.shape: - raise ValueError(f"tokens and incorrect_tokens must have the same shape! \ + raise ValueError( + f"tokens and incorrect_tokens must have the same shape! \ (tokens.shape={tokens_for_shape_check.shape}, \ - incorrect_tokens.shape={incorrect_tokens_for_shape_check.shape})") + incorrect_tokens.shape={incorrect_tokens_for_shape_check.shape})" + ) # If incorrect_tokens was provided, take the logit difference logit_directions = logit_directions - self.model.tokens_to_residual_directions(