diff --git a/tests/unit/components/test_abstract_attention.py b/tests/unit/components/test_abstract_attention.py
index 7820c1690..912188169 100644
--- a/tests/unit/components/test_abstract_attention.py
+++ b/tests/unit/components/test_abstract_attention.py
@@ -1,6 +1,7 @@
 import torch
 
-from transformer_lens.components import AbstractAttention
+from transformer_lens.components import AbstractAttention, RotaryEmbedding
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
 
 
 def test_create_alibi_slope():
@@ -38,3 +39,31 @@ def test_create_alibi_bias():
     assert torch.equal(
         torch.tril(matrix, diagonal=-1), torch.tril(ref_lower_triangle, diagonal=-1)
     )
+
+
+def test_rotary_attribute_access():
+    cfg = HookedTransformerConfig(
+        n_layers=12,
+        d_model=512,
+        n_ctx=1024,
+        d_head=64,
+        n_heads=8,
+        load_in_4bit=False,
+        dtype=torch.float32,
+        act_fn="relu",
+        rotary_dim=64,
+        rotary_base=10000,
+        rotary_adjacent_pairs=True,
+    )
+
+    rotary_module = RotaryEmbedding(cfg)
+
+    class DummyAttention(AbstractAttention):
+        def __init__(self):
+            super().__init__(cfg)
+            self.rotary_module = rotary_module
+
+    attention = DummyAttention()
+
+    assert torch.equal(attention.rotary_sin, rotary_module.rotary_sin), "rotary_sin does not match!"
+    assert torch.equal(attention.rotary_cos, rotary_module.rotary_cos), "rotary_cos does not match!"
diff --git a/transformer_lens/components/__init__.py b/transformer_lens/components/__init__.py
index 44c98f6e7..768884b61 100644
--- a/transformer_lens/components/__init__.py
+++ b/transformer_lens/components/__init__.py
@@ -24,6 +24,7 @@
 from .grouped_query_attention import GroupedQueryAttention
 from .mlps.gated_mlp import GatedMLP
 from .mlps.mlp import MLP
+from .rotary_embeddings import RotaryEmbedding
 
 # Interdependent modules
 from .bert_block import BertBlock
diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index 0aee43814..5374c60f4 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -1,4 +1,3 @@
-import math
 from abc import ABC
 from typing import Dict, Optional, Tuple, Union
 
@@ -12,11 +11,11 @@
 from transformer_lens.components.rms_norm import RMSNorm
 from transformer_lens.FactoredMatrix import FactoredMatrix
+from transformer_lens.factories.rotary_embedding_factory import RotaryEmbeddingFactory
 from transformer_lens.hook_points import HookPoint
 from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
 from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry
 from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear
-from transformer_lens.utils import get_offset_position_ids
 
 if is_bitsandbytes_available():
     import bitsandbytes as bnb
@@ -137,14 +136,7 @@
             self.hook_rot_q = HookPoint()
             if self.cfg.rotary_dim is None:  # keep mypy happy
                 raise ValueError("Rotary dim must be provided for rotary positional embeddings")
-            sin, cos = self.calculate_sin_cos_rotary(
-                self.cfg.rotary_dim,
-                self.cfg.n_ctx,
-                base=self.cfg.rotary_base,
-                dtype=self.cfg.dtype,
-            )
-            self.register_buffer("rotary_sin", sin)
-            self.register_buffer("rotary_cos", cos)
+            self.rotary_module = RotaryEmbeddingFactory.create_rotary(self.cfg)
         elif self.cfg.positional_embedding_type == "alibi":
             # ALiBi bias will be constructed on the first forward pass.
             # Note: While computationally efficient, initializing a bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage.
@@ -154,6 +146,14 @@
         # will be overwritten by the child T5Attention class
         self.has_relative_attention_bias = False
 
+    @property
+    def rotary_sin(self):
+        return self.rotary_module.rotary_sin
+
+    @property
+    def rotary_cos(self):
+        return self.rotary_module.rotary_cos
+
     @property
     def OV(self) -> FactoredMatrix:
         """
@@ -218,10 +218,8 @@ def forward(
         kv_cache_pos_offset = 0
 
         if self.cfg.positional_embedding_type == "rotary":
-            q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask))
-            k = self.hook_rot_k(
-                self.apply_rotary(k, 0, attention_mask)
-            )  # keys are cached so no offset
+            q = self.hook_rot_q(self.rotary_module(q, kv_cache_pos_offset, attention_mask))
+            k = self.hook_rot_k(self.rotary_module(k, 0, attention_mask))  # keys are cached so no offset
 
         if self.cfg.dtype not in [torch.float32, torch.float64]:
             # If using 16 bits, increase the precision to avoid numerical instabilities
diff --git a/transformer_lens/components/rotary_embeddings.py b/transformer_lens/components/rotary_embeddings.py
new file mode 100644
index 000000000..93fef12e7
--- /dev/null
+++ b/transformer_lens/components/rotary_embeddings.py
@@ -0,0 +1,141 @@
+import math
+from typing import Optional, Tuple, cast
+
+import einops
+import torch
+import torch.nn as nn
+from jaxtyping import Float, Int
+
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+from transformer_lens.utils import get_offset_position_ids
+
+
+class RotaryEmbedding(nn.Module):
+    def __init__(self, cfg: HookedTransformerConfig):
+        super().__init__()
+        self.cfg: HookedTransformerConfig = cfg
+        rotary_dim = cast(int, self.cfg.rotary_dim)
+        sin, cos = self.calculate_sin_cos_rotary(
+            rotary_dim=rotary_dim, n_ctx=cfg.n_ctx, base=cfg.rotary_base, dtype=cfg.dtype
+        )
+        self.register_buffer("rotary_sin", sin)
+        self.register_buffer("rotary_cos", cos)
+
+    def calculate_sin_cos_rotary(
+        self,
+        rotary_dim: int,
+        n_ctx: int,
+        base: int = 10000,
+        dtype: torch.dtype = torch.float32,
+    ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
+        """
+        Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details.
+
+        Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q are rotated, while in GPT-NeoX the pair of elements at k and k+n//2 are rotated (i.e. folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why; it should be completely equivalent.
+        To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is.
+ """ + high_precision = torch.float32 if dtype != torch.float64 else torch.float64 + pos = torch.arange(n_ctx, dtype=high_precision) + dim = torch.arange(rotary_dim // 2, dtype=high_precision) + freq = base ** (dim / (rotary_dim / 2)) + freq = einops.repeat(freq, "d -> (d 2)") + # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency + angles = pos[:, None] / freq[None, :] + return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) + + def forward( + self, + x: Float[torch.Tensor, "batch pos head_index d_head"], + past_kv_pos_offset=0, + attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, + ) -> Float[torch.Tensor, "batch pos head_index d_head"]: + # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions) + x_pos = x.size(1) + x_rot = x[..., : self.cfg.rotary_dim] + x_pass = x[..., self.cfg.rotary_dim :] + x_flip = self.rotate_every_two(x_rot) + + if attention_mask is None: + rotary_cos = self.rotary_cos[ + None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, : + ] + rotary_sin = self.rotary_sin[ + None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, : + ] + x_rotated = x_rot * rotary_cos + x_flip * rotary_sin + else: + offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask) + offset_position_ids = offset_position_ids.to(self.rotary_cos.device) + mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :] + mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :] + x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin + + return torch.cat([x_rotated, x_pass], dim=-1) + + def rotate_every_two( + self, x: Float[torch.Tensor, "... rotary_dim"] + ) -> Float[torch.Tensor, "... rotary_dim"]: + """ + Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0] + + The final axis of x must have even length. + + GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details. 
+ """ + rot_x = x.clone() + if self.cfg.rotary_adjacent_pairs: + rot_x[..., ::2] = -x[..., 1::2] + rot_x[..., 1::2] = x[..., ::2] + else: + n = x.size(-1) // 2 + rot_x[..., :n] = -x[..., n:] + rot_x[..., n:] = x[..., :n] + return rot_x + + +class DynamicNTKScalingRotary(RotaryEmbedding): + def calculate_sin_cos_rotary( + self, + rotary_dim: int, + n_ctx: int, + base: int = 10000, + dtype: torch.dtype = torch.float32, + ): + # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071 + # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310 + high_precision = torch.float32 if dtype != torch.float64 else torch.float64 + pos = torch.arange(n_ctx, dtype=high_precision) + + inv_freq = 1.0 / ( + base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim) + ) + factor = self.cfg.NTK_by_parts_factor + low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor + high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor + old_context_len = n_ctx + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / inv_freq + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + smooth_factor = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smoothed_inv_freq = ( + 1 - smooth_factor + ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + freq = 1 / inv_freq_llama + freq = einops.repeat(freq, "d -> (d 2)") + angles = pos[:, None] / freq[None, :] + return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) diff --git a/transformer_lens/factories/rotary_embedding_factory.py b/transformer_lens/factories/rotary_embedding_factory.py new file mode 100644 index 000000000..b53d91db5 --- /dev/null +++ b/transformer_lens/factories/rotary_embedding_factory.py @@ -0,0 +1,14 @@ +from transformer_lens.components.rotary_embeddings import ( + DynamicNTKScalingRotary, + RotaryEmbedding, +) +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + + +class RotaryEmbeddingFactory: + @staticmethod + def create_rotary(cfg: HookedTransformerConfig) -> RotaryEmbedding: + if cfg.use_NTK_by_parts_rope: + return DynamicNTKScalingRotary(cfg) + else: + return RotaryEmbedding(cfg)