diff --git a/tests/acceptance/test_activation_cache.py b/tests/acceptance/test_activation_cache.py index cfb4d39af..9af5a96db 100644 --- a/tests/acceptance/test_activation_cache.py +++ b/tests/acceptance/test_activation_cache.py @@ -245,7 +245,7 @@ def test_accumulated_resid_with_apply_ln(): # Get accumulated resid and apply ln seperately (cribbed notebook code) accumulated_residual = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1) ref_scaled_residual_stack = cache.apply_ln_to_stack( - accumulated_residual, layer=-1, pos_slice=-1 + accumulated_residual, layer=-1, pos_slice=-1, recompute_ln=True ) # Get scaled_residual_stack using apply_ln parameter @@ -271,6 +271,22 @@ def test_accumulated_resid_with_apply_ln(): assert labels == expected_labels +@torch.no_grad +def test_apply_ln_recompute_ln_differs_from_cached(): + model = load_model("solu-2l") + tokens, _ = get_ioi_tokens_and_answer_tokens(model) + _, cache = model.run_with_cache(tokens) + + accumulated = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1) + with_recompute = cache.apply_ln_to_stack(accumulated, layer=-1, pos_slice=-1, recompute_ln=True) + with_cached = cache.apply_ln_to_stack(accumulated, layer=-1, pos_slice=-1, recompute_ln=False) + + assert with_recompute.shape == with_cached.shape + assert not torch.isclose( + with_recompute, with_cached, atol=1e-7 + ).all(), "recompute_ln=True and recompute_ln=False should differ for accumulated residual stack" + + @torch.no_grad def test_decompose_resid_with_apply_ln(): # Load solu-2l diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index f76549ebe..b62282a52 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -443,8 +443,13 @@ def accumulated_resid( components_list = [pos_slice.apply(c, dim=-2) for c in components_list] components = torch.stack(components_list, dim=0) if apply_ln: + recompute_ln = layer == self.model.cfg.n_layers components = self.apply_ln_to_stack( - components, layer, pos_slice=pos_slice, mlp_input=mlp_input + components, + layer, + pos_slice=pos_slice, + mlp_input=mlp_input, + recompute_ln=recompute_ln, ) if return_labels: return components, labels @@ -952,6 +957,7 @@ def apply_ln_to_stack( pos_slice: Union[Slice, SliceInput] = None, batch_slice: Union[Slice, SliceInput] = None, has_batch_dim: bool = True, + recompute_ln: bool = False, ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]: """Apply Layer Norm to a Stack. @@ -964,6 +970,10 @@ def apply_ln_to_stack( element and position, which is why we need to use the cached scale factors rather than just applying a new LayerNorm. + When recompute_ln=True and the target layer is the final layer (unembed), each + component is normalized using stats recomputed from that component; use this for logit lens + analysis. When recompute_ln=False, a single cached scale is used for all components. + If the model does not use LayerNorm or RMSNorm, it returns the residual stack unchanged. Args: @@ -986,6 +996,9 @@ def apply_ln_to_stack( The slice to take on the batch dimension. Defaults to None, do nothing. has_batch_dim: Whether residual_stack has a batch dimension. + recompute_ln: + If True and target layer is the unembed (final layer), apply the final layer norm + to each component with statistics recomputed from that component. Defaults to False. """ if self.model.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]: @@ -1004,6 +1017,22 @@ def apply_ln_to_stack( # Apply batch slice to the stack residual_stack = batch_slice.apply(residual_stack, dim=1) + # Logit lens: apply final layer norm to each component with recomputed statistics + if recompute_ln and layer == self.model.cfg.n_layers and hasattr(self.model, "ln_final"): + ln_final = self.model.ln_final + had_pos_dim = residual_stack.ndim == 4 + results = [] + for i in range(residual_stack.shape[0]): + x = residual_stack[i] + # ln_final expects (batch, pos, d_model); ensure pos dim present + if x.ndim == 2: + x = x.unsqueeze(1) + out = ln_final(x) + if not had_pos_dim: + out = out.squeeze(1) + results.append(out) + return torch.stack(results, dim=0) + # Center the stack onlny if the model uses LayerNorm if self.model.cfg.normalization_type in ["LN", "LNPre"]: residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True)