Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion tests/unit/components/test_abstract_attention.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch

from transformer_lens.components import AbstractAttention
from transformer_lens.components import AbstractAttention, RotaryEmbedding
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def test_create_alibi_slope():
Expand Down Expand Up @@ -38,3 +39,31 @@ def test_create_alibi_bias():
assert torch.equal(
torch.tril(matrix, diagonal=-1), torch.tril(ref_lower_triangle, diagonal=-1)
)


def test_rotary_attribute_access():
    """The rotary_sin / rotary_cos properties on AbstractAttention should
    transparently expose the buffers of the attached rotary module."""
    config = HookedTransformerConfig(
        n_layers=12,
        d_model=512,
        n_ctx=1024,
        d_head=64,
        n_heads=8,
        load_in_4bit=False,
        dtype=torch.float32,
        act_fn="relu",
        rotary_dim=64,
        rotary_base=10000,
        rotary_adjacent_pairs=True,
    )

    embedding = RotaryEmbedding(config)

    class _StubAttention(AbstractAttention):
        # Minimal concrete subclass: wire in a known rotary module so the
        # delegating properties can be checked against it directly.
        def __init__(self):
            super().__init__(config)
            self.rotary_module = embedding

    attention = _StubAttention()

    assert torch.equal(attention.rotary_sin, embedding.rotary_sin), "rotary_sin does not match!"
    assert torch.equal(attention.rotary_cos, embedding.rotary_cos), "rotary_cos does not match!"
1 change: 1 addition & 0 deletions transformer_lens/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .grouped_query_attention import GroupedQueryAttention
from .mlps.gated_mlp import GatedMLP
from .mlps.mlp import MLP
from .rotary_embeddings import RotaryEmbedding

# Interdependent modules
from .bert_block import BertBlock
Expand Down
26 changes: 12 additions & 14 deletions transformer_lens/components/abstract_attention.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
from abc import ABC
from typing import Dict, Optional, Tuple, Union

Expand All @@ -12,11 +11,11 @@

from transformer_lens.components.rms_norm import RMSNorm
from transformer_lens.FactoredMatrix import FactoredMatrix
from transformer_lens.factories.rotary_embedding_factory import RotaryEmbeddingFactory
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry
from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear
from transformer_lens.utils import get_offset_position_ids

if is_bitsandbytes_available():
import bitsandbytes as bnb
Expand Down Expand Up @@ -137,14 +136,7 @@ def __init__(
self.hook_rot_q = HookPoint()
if self.cfg.rotary_dim is None: # keep mypy happy
raise ValueError("Rotary dim must be provided for rotary positional embeddings")
sin, cos = self.calculate_sin_cos_rotary(
self.cfg.rotary_dim,
self.cfg.n_ctx,
base=self.cfg.rotary_base,
dtype=self.cfg.dtype,
)
self.register_buffer("rotary_sin", sin)
self.register_buffer("rotary_cos", cos)
self.rotary_module = RotaryEmbeddingFactory.create_rotary(self.cfg)
elif self.cfg.positional_embedding_type == "alibi":
# ALiBi bias will be constructed on the first forward pass.
# Note: While computationally efficient, initializing a bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage.
Expand All @@ -154,6 +146,14 @@ def __init__(
# will be overwritten by the child T5Attention class
self.has_relative_attention_bias = False

    @property
    def rotary_sin(self):
        """Sine table of the attached rotary module.

        Delegates to ``self.rotary_module`` so callers that previously read
        the ``rotary_sin`` buffer directly off the attention layer keep working.
        """
        return self.rotary_module.rotary_sin

    @property
    def rotary_cos(self):
        """Cosine table of the attached rotary module.

        Delegates to ``self.rotary_module`` so callers that previously read
        the ``rotary_cos`` buffer directly off the attention layer keep working.
        """
        return self.rotary_module.rotary_cos

@property
def OV(self) -> FactoredMatrix:
"""
Expand Down Expand Up @@ -218,10 +218,8 @@ def forward(
kv_cache_pos_offset = 0

if self.cfg.positional_embedding_type == "rotary":
q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask))
k = self.hook_rot_k(
self.apply_rotary(k, 0, attention_mask)
) # keys are cached so no offset
q = self.hook_rot_q(self.rotary_module(q, kv_cache_pos_offset, attention_mask))
k = self.hook_rot_k(self.rotary_module(k, 0, attention_mask))

if self.cfg.dtype not in [torch.float32, torch.float64]:
# If using 16 bits, increase the precision to avoid numerical instabilities
Expand Down
133 changes: 133 additions & 0 deletions transformer_lens/components/rotary_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import math
from typing import Optional, Tuple, cast

import einops
import torch
import torch.nn as nn
from jaxtyping import Float, Int

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utils import get_offset_position_ids


class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) module.

    Precomputes sin/cos angle tables for all ``n_ctx`` positions at init time
    (registered as buffers so they follow the module across devices) and
    rotates the first ``rotary_dim`` dimensions of the q/k activations in
    ``forward``.
    """

    def __init__(self, cfg: HookedTransformerConfig):
        super().__init__()
        self.cfg: HookedTransformerConfig = cfg
        # rotary_dim is Optional on the config; callers only construct this
        # module when it is set, hence the cast.
        rotary_dim = cast(int, self.cfg.rotary_dim)
        sin, cos = self.calculate_sin_cos_rotary(
            rotary_dim=rotary_dim, n_ctx=cfg.n_ctx, base=cfg.rotary_base, dtype=cfg.dtype
        )
        # Buffers, not parameters: no gradients, but saved/moved with the module.
        self.register_buffer("rotary_sin", sin)
        self.register_buffer("rotary_cos", cos)

    def calculate_sin_cos_rotary(
        self,
        rotary_dim: int,
        n_ctx: int,
        base: int = 10000,
        dtype: torch.dtype = torch.float32,
    ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
        """
        Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details

        Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q are rotated, in GPT-NeoX the pair of elements at k and k+n//2 are rotated (ie folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why, it should be completely equivalent.
        To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is.
        """
        # Compute angles in high precision, cast to the target dtype only at the end.
        high_precision = torch.float32 if dtype != torch.float64 else torch.float64
        pos = torch.arange(n_ctx, dtype=high_precision)
        dim = torch.arange(rotary_dim // 2, dtype=high_precision)
        freq = base ** (dim / (rotary_dim / 2))
        # BUG FIX: the frequency layout must match rotate_every_two (which
        # branches on rotary_adjacent_pairs). GPT-J interleaves adjacent pairs
        # -> (d 2); GPT-NeoX pairs element k with k + rotary_dim//2 -> (2 d).
        # Previously (d 2) was used unconditionally, giving wrong angles for
        # non-adjacent-pairs (GPT-NeoX style) models.
        if self.cfg.rotary_adjacent_pairs:
            freq = einops.repeat(freq, "d -> (d 2)")
        else:
            freq = einops.repeat(freq, "d -> (2 d)")
        # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency
        angles = pos[:, None] / freq[None, :]
        return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype)

    def forward(
        self,
        x: Float[torch.Tensor, "batch pos head_index d_head"],
        past_kv_pos_offset=0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
        """Apply the rotary rotation to ``x``.

        Only the first ``rotary_dim`` dimensions are rotated (eg, if
        rotary_dim=64 and d_head=256, only the first 1/4 of dimensions); the
        rest pass through unchanged. ``past_kv_pos_offset`` shifts the
        position index when a KV cache is in use; ``attention_mask`` (if
        given) is used to derive per-token positions for left-padded batches.
        """
        x_pos = x.size(1)
        x_rot = x[..., : self.cfg.rotary_dim]
        x_pass = x[..., self.cfg.rotary_dim :]
        x_flip = self.rotate_every_two(x_rot)

        if attention_mask is None:
            rotary_cos = self.rotary_cos[
                None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
            ]
            rotary_sin = self.rotary_sin[
                None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, :
            ]
            x_rotated = x_rot * rotary_cos + x_flip * rotary_sin
        else:
            # Per-token positions account for left padding in the mask.
            offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask)
            offset_position_ids = offset_position_ids.to(self.rotary_cos.device)
            mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :]
            mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :]
            x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin

        return torch.cat([x_rotated, x_pass], dim=-1)

    def rotate_every_two(
        self, x: Float[torch.Tensor, "... rotary_dim"]
    ) -> Float[torch.Tensor, "... rotary_dim"]:
        """
        Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0]

        The final axis of x must have even length.

        GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details.
        """
        rot_x = x.clone()
        if self.cfg.rotary_adjacent_pairs:
            rot_x[..., ::2] = -x[..., 1::2]
            rot_x[..., 1::2] = x[..., ::2]
        else:
            n = x.size(-1) // 2
            rot_x[..., :n] = -x[..., n:]
            rot_x[..., n:] = x[..., :n]
        return rot_x


class DynamicNTKScalingRotary(RotaryEmbedding):
    """Rotary embedding with NTK-by-parts scaling, as used by Llama-3.1.

    Overrides only the sin/cos table computation; the rotation itself is
    inherited from RotaryEmbedding.
    """

    def calculate_sin_cos_rotary(
        self,
        rotary_dim: int,
        n_ctx: int,
        base: int = 10000,
        dtype: torch.dtype = torch.float32,
    ):
        """Return (sin, cos) tables of shape [n_ctx, rotary_dim] with
        NTK-by-parts frequency scaling applied."""
        # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071
        # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310
        high_precision = torch.float32 if dtype != torch.float64 else torch.float64
        pos = torch.arange(n_ctx, dtype=high_precision)

        # Base RoPE inverse frequencies.
        inv_freq = 1.0 / (
            base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim)
        )
        factor = self.cfg.NTK_by_parts_factor
        low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor
        high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor
        old_context_len = n_ctx

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        # Low frequencies (long wavelengths) are scaled down by `factor`;
        # high frequencies are untouched; medium frequencies are smoothly
        # interpolated between the two regimes.
        wavelen = 2 * math.pi / inv_freq
        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
        smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
        smoothed_inv_freq = (
            1 - smooth_factor
        ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
        is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
        freq = 1 / inv_freq_llama
        # BUG FIX: match the layout expected by rotate_every_two — GPT-J style
        # interleaves adjacent pairs -> (d 2); GPT-NeoX style pairs k with
        # k + rotary_dim//2 -> (2 d). Previously (d 2) was used unconditionally.
        if self.cfg.rotary_adjacent_pairs:
            freq = einops.repeat(freq, "d -> (d 2)")
        else:
            freq = einops.repeat(freq, "d -> (2 d)")
        angles = pos[:, None] / freq[None, :]
        return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype)
14 changes: 14 additions & 0 deletions transformer_lens/factories/rotary_embedding_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from transformer_lens.components.rotary_embeddings import (
DynamicNTKScalingRotary,
RotaryEmbedding,
)
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


class RotaryEmbeddingFactory:
    """Builds the rotary embedding variant that matches a model config."""

    @staticmethod
    def create_rotary(cfg: HookedTransformerConfig) -> RotaryEmbedding:
        """Return a DynamicNTKScalingRotary when NTK-by-parts RoPE is
        configured, otherwise a standard RotaryEmbedding."""
        rotary_cls = DynamicNTKScalingRotary if cfg.use_NTK_by_parts_rope else RotaryEmbedding
        return rotary_cls(cfg)
Loading