diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 751734e53..3844c2ccc 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -157,6 +157,7 @@ def __init__(
             add_bos_token = self.cfg.original_architecture not in [
                 "OlmoForCausalLM",
                 "OlmoeForCausalLM",
+                "Olmo2ForCausalLM",
             ]
             self.set_tokenizer(
                 AutoTokenizer.from_pretrained(
@@ -695,6 +696,7 @@ def set_tokenizer(
         if self.cfg.original_architecture not in [
             "OlmoForCausalLM",
             "OlmoeForCausalLM",
+            "Olmo2ForCausalLM",
         ]:
            tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)
         else:
@@ -1781,13 +1783,17 @@ def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]):
         W_out. This is done by subtracting the mean of the weights from the weights themselves. This
         is done in-place. See fold_layer_norm for more details.
         """
-        state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean(
-            -1, keepdim=True
-        )
-        if self.cfg.positional_embedding_type != "rotary":
-            state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[
-                "pos_embed.W_pos"
-            ].mean(-1, keepdim=True)
+        if self.cfg.original_architecture == "Olmo2ForCausalLM":
+            print("Not centering embedding weights for Olmo2ForCausalLM")
+            # Skipped because OLMo 2 does not normalize the input to the first layer's attention.
+        else:
+            state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean(
+                -1, keepdim=True
+            )
+            if self.cfg.positional_embedding_type != "rotary":
+                state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[
+                    "pos_embed.W_pos"
+                ].mean(-1, keepdim=True)
         for l in range(self.cfg.n_layers):
             state_dict[f"blocks.{l}.attn.W_O"] = state_dict[f"blocks.{l}.attn.W_O"] - state_dict[
                 f"blocks.{l}.attn.W_O"
diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index c6838b4e7..6bf05a97f 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -141,9 +141,10 @@ def __init__(
             # will be overwritten by the child T5Attention class
             self.has_relative_attention_bias = False

-        if self.cfg.original_architecture == "OlmoeForCausalLM":
-            self.q_norm = RMSNorm(cfg, cfg.d_model)
-            self.k_norm = RMSNorm(cfg, cfg.d_head * cfg.n_key_value_heads)
+        if self.cfg.original_architecture in ["OlmoeForCausalLM", "Olmo2ForCausalLM"]:
+            self.q_norm = RMSNorm(self.cfg, self.cfg.d_model)
+            k_norm_dim = self.cfg.d_model if self.cfg.original_architecture == "Olmo2ForCausalLM" else self.cfg.d_head * self.cfg.n_key_value_heads
+            self.k_norm = RMSNorm(self.cfg, k_norm_dim)

     @property
     def OV(self) -> FactoredMatrix:
@@ -201,7 +202,7 @@ def forward(
         q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

-        # OLMoE uses QK-norm.
-        if self.cfg.original_architecture == "OlmoeForCausalLM":
+        # OLMoE and OLMo 2 apply QK-norm.
+        if self.cfg.original_architecture in ["OlmoeForCausalLM", "Olmo2ForCausalLM"]:
             q = einops.rearrange(
                 self.q_norm(
                     einops.rearrange(
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index f04fe8d46..bef896b54 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -17,6 +17,7 @@
     AutoModelForCausalLM,
     BertForPreTraining,
+    PretrainedConfig,
     T5ForConditionalGeneration,
 )

 import transformer_lens.utils as utils
@@ -37,6 +38,7 @@
     convert_neox_weights,
     convert_olmo_weights,
     convert_olmoe_weights,
+    convert_olmo2_weights,
     convert_opt_weights,
     convert_phi3_weights,
     convert_phi_weights,
@@ -264,6 +266,7 @@
     "allenai/OLMoE-1B-7B-0924",
     "allenai/OLMoE-1B-7B-0924-SFT",
     "allenai/OLMoE-1B-7B-0924-Instruct",
+    "allenai/OLMo-2-1124-7B",
 ]
 """Official model names for models on HuggingFace."""

@@ -764,7 +767,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
         architecture = "GemmaForCausalLM"
     else:
         huggingface_token = os.environ.get("HF_TOKEN", None)
-        hf_config = AutoConfig.from_pretrained(
+        hf_config: PretrainedConfig = AutoConfig.from_pretrained(
             official_model_name,
             token=huggingface_token,
             **kwargs,
@@ -1488,6 +1491,24 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "positional_embedding_type": "rotary",
             "gated_mlp": True,
         }
+    elif official_model_name == "allenai/OLMo-2-1124-7B":
+        cfg_dict = {
+            "d_model": 4096,
+            "d_head": 128,
+            "n_heads": 32,
+            "d_mlp": 11008,
+            "n_layers": 32,
+            "n_ctx": 4096,
+            "eps": 1e-06,
+            "d_vocab": 100352,
+            "act_fn": "silu",
+            "initializer_range": 0.02,
+            "normalization_type": "RMSPre",
+            "rotary_base": 500000.0,
+            "attn_types": ["global"] * 32,
+            "positional_embedding_type": "rotary",
+            "gated_mlp": True,
+        }
     elif architecture == "OlmoeForCausalLM":
         cfg_dict = {
             "d_model": hf_config.hidden_size,
@@ -1932,6 +1953,8 @@ def get_pretrained_state_dict(
         state_dict = convert_gemma_weights(hf_model, cfg)
     elif cfg.original_architecture == "OlmoForCausalLM":
         state_dict = convert_olmo_weights(hf_model, cfg)
+    elif cfg.original_architecture == "Olmo2ForCausalLM":
+        state_dict = convert_olmo2_weights(hf_model, cfg)
     elif cfg.original_architecture == "OlmoeForCausalLM":
         state_dict = convert_olmoe_weights(hf_model, cfg)
     else:
diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py
index bb2146832..88d5a76cc 100644
--- a/transformer_lens/pretrained/weight_conversions/__init__.py
+++ b/transformer_lens/pretrained/weight_conversions/__init__.py
@@ -20,3 +20,4 @@
 from .neel_solu_old import convert_neel_solu_old_weights
 from .olmo import convert_olmo_weights
 from .olmoe import convert_olmoe_weights
+from .olmo2 import convert_olmo2_weights
\ No newline at end of file
diff --git a/transformer_lens/pretrained/weight_conversions/olmo2.py b/transformer_lens/pretrained/weight_conversions/olmo2.py
new file mode 100644
index 000000000..e531bf0f6
--- /dev/null
+++ b/transformer_lens/pretrained/weight_conversions/olmo2.py
@@ -0,0 +1,60 @@
+import einops
+import torch
+
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+from transformers.models.olmo2.modeling_olmo2 import Olmo2ForCausalLM, Olmo2DecoderLayer
+
+
+def convert_olmo2_weights(olmo2: Olmo2ForCausalLM, cfg: HookedTransformerConfig):
+    state_dict = {}
+
+    assert cfg.d_mlp is not None
+
+    state_dict["embed.W_E"] = olmo2.model.embed_tokens.weight
+
+    for l in range(cfg.n_layers):
+        olmo2_layer: Olmo2DecoderLayer = olmo2.model.layers[l]
+
+        W_Q = olmo2_layer.self_attn.q_proj.weight
+        W_K = olmo2_layer.self_attn.k_proj.weight
+        W_V = olmo2_layer.self_attn.v_proj.weight
+        W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
+        W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads)
+        W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_heads)
+        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
+        state_dict[f"blocks.{l}.attn.W_K"] = W_K
+        state_dict[f"blocks.{l}.attn.W_V"] = W_V
+        state_dict[f"blocks.{l}.attn.q_norm.w"] = olmo2_layer.self_attn.q_norm.weight
+        state_dict[f"blocks.{l}.attn.k_norm.w"] = olmo2_layer.self_attn.k_norm.weight
+
+        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
+        state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(
+            cfg.n_heads, cfg.d_head, dtype=cfg.dtype
+        )
+        state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(
+            cfg.n_heads, cfg.d_head, dtype=cfg.dtype
+        )
+
+        W_O = olmo2_layer.self_attn.o_proj.weight
+        W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
+        state_dict[f"blocks.{l}.attn.W_O"] = W_O
+
+        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+        state_dict[f"blocks.{l}.ln1.w"] = olmo2_layer.post_attention_layernorm.weight
+
+        state_dict[f"blocks.{l}.mlp.W_in"] = olmo2_layer.mlp.up_proj.weight.T
+        state_dict[f"blocks.{l}.mlp.W_gate"] = olmo2_layer.mlp.gate_proj.weight.T
+        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
+
+        state_dict[f"blocks.{l}.mlp.W_out"] = olmo2_layer.mlp.down_proj.weight.T
+        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+        state_dict[f"blocks.{l}.ln2.w"] = olmo2_layer.post_feedforward_layernorm.weight
+
+    state_dict["ln_final.w"] = olmo2.model.norm.weight
+
+    state_dict["unembed.W_U"] = olmo2.lm_head.weight.T
+    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
+
+    return state_dict
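
For review convenience, a small self-contained check (not part of the patch) of the rearrange convention used in convert_olmo2_weights: TransformerLens stores W_Q as [n_heads, d_model, d_head], while HF's q_proj.weight is [n_heads * d_head, d_model]. The toy sizes below are hypothetical:

    import einops
    import torch

    n_heads, d_head, d_model = 4, 8, 32
    w_hf = torch.randn(n_heads * d_head, d_model)  # HF q_proj.weight layout
    w_tl = einops.rearrange(w_hf, "(n h) m->n m h", n=n_heads)
    assert w_tl.shape == (n_heads, d_model, d_head)
    # Head 0 of the converted weight is the first d_head rows of the HF matrix, transposed.
    assert torch.equal(w_tl[0], w_hf[:d_head].T)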
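A minimal end-to-end smoke test for the new model entry (also not part of the patch; assumes the allenai/OLMo-2-1124-7B checkpoint is accessible and fits in memory):

    from transformer_lens import HookedTransformer

    # from_pretrained routes through convert_olmo2_weights via get_pretrained_state_dict.
    model = HookedTransformer.from_pretrained("allenai/OLMo-2-1124-7B")
    loss = model("The capital of France is Paris.", return_type="loss")
    print(float(loss))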