Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion tests/acceptance/test_activation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_accumulated_resid_with_apply_ln():
    # Get accumulated resid and apply ln separately (cribbed notebook code)
accumulated_residual = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1)
ref_scaled_residual_stack = cache.apply_ln_to_stack(
accumulated_residual, layer=-1, pos_slice=-1
accumulated_residual, layer=-1, pos_slice=-1, recompute_ln=True
)

# Get scaled_residual_stack using apply_ln parameter
Expand All @@ -271,6 +271,22 @@ def test_accumulated_resid_with_apply_ln():
assert labels == expected_labels


@torch.no_grad
def test_apply_ln_recompute_ln_differs_from_cached():
    """Check that recompute_ln=True yields different values than the cached-scale path.

    Builds the accumulated residual stack, normalizes it once with recomputed
    LN statistics and once with the cached scale factors, and asserts the two
    results disagree somewhere (shapes must still match).
    """
    model = load_model("solu-2l")
    tokens, _ = get_ioi_tokens_and_answer_tokens(model)
    _, cache = model.run_with_cache(tokens)

    stack = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1)
    recomputed = cache.apply_ln_to_stack(stack, layer=-1, pos_slice=-1, recompute_ln=True)
    cached = cache.apply_ln_to_stack(stack, layer=-1, pos_slice=-1, recompute_ln=False)

    assert recomputed.shape == cached.shape
    # At least one element must differ beyond tolerance between the two modes.
    mismatch = ~torch.isclose(recomputed, cached, atol=1e-7)
    assert mismatch.any(), "recompute_ln=True and recompute_ln=False should differ for accumulated residual stack"


@torch.no_grad
def test_decompose_resid_with_apply_ln():
# Load solu-2l
Expand Down
31 changes: 30 additions & 1 deletion transformer_lens/ActivationCache.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,13 @@ def accumulated_resid(
components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
components = torch.stack(components_list, dim=0)
if apply_ln:
recompute_ln = layer == self.model.cfg.n_layers
components = self.apply_ln_to_stack(
components, layer, pos_slice=pos_slice, mlp_input=mlp_input
components,
layer,
pos_slice=pos_slice,
mlp_input=mlp_input,
recompute_ln=recompute_ln,
)
if return_labels:
return components, labels
Expand Down Expand Up @@ -952,6 +957,7 @@ def apply_ln_to_stack(
pos_slice: Union[Slice, SliceInput] = None,
batch_slice: Union[Slice, SliceInput] = None,
has_batch_dim: bool = True,
recompute_ln: bool = False,
) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]:
"""Apply Layer Norm to a Stack.

Expand All @@ -964,6 +970,10 @@ def apply_ln_to_stack(
element and position, which is why we need to use the cached scale factors rather than just
applying a new LayerNorm.

When recompute_ln=True and the target layer is the final layer (unembed), each
component is normalized using stats recomputed from that component; use this for logit lens
analysis. When recompute_ln=False, a single cached scale is used for all components.

If the model does not use LayerNorm or RMSNorm, it returns the residual stack unchanged.

Args:
Expand All @@ -986,6 +996,9 @@ def apply_ln_to_stack(
The slice to take on the batch dimension. Defaults to None, do nothing.
has_batch_dim:
Whether residual_stack has a batch dimension.
recompute_ln:
If True and target layer is the unembed (final layer), apply the final layer norm
to each component with statistics recomputed from that component. Defaults to False.

"""
if self.model.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
Expand All @@ -1004,6 +1017,22 @@ def apply_ln_to_stack(
# Apply batch slice to the stack
residual_stack = batch_slice.apply(residual_stack, dim=1)

# Logit lens: apply final layer norm to each component with recomputed statistics
if recompute_ln and layer == self.model.cfg.n_layers and hasattr(self.model, "ln_final"):
ln_final = self.model.ln_final
had_pos_dim = residual_stack.ndim == 4
results = []
for i in range(residual_stack.shape[0]):
x = residual_stack[i]
# ln_final expects (batch, pos, d_model); ensure pos dim present
if x.ndim == 2:
x = x.unsqueeze(1)
out = ln_final(x)
if not had_pos_dim:
out = out.squeeze(1)
results.append(out)
return torch.stack(results, dim=0)

    # Center the stack only if the model uses LayerNorm
if self.model.cfg.normalization_type in ["LN", "LNPre"]:
residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True)
Expand Down
Loading