diff --git a/transformer_lens/components/mlps/can_be_used_as_mlp.py b/transformer_lens/components/mlps/can_be_used_as_mlp.py index b0945276b..6ff660e0c 100644 --- a/transformer_lens/components/mlps/can_be_used_as_mlp.py +++ b/transformer_lens/components/mlps/can_be_used_as_mlp.py @@ -17,6 +17,7 @@ from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.utilities.activation_functions import ActivationFunction +from transformer_lens.utils import XIELU class CanBeUsedAsMLP(nn.Module): @@ -65,7 +66,10 @@ def select_activation_function(self) -> None: ValueError: If the configure activation function is not supported. """ - self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg) + if self.cfg.act_fn == "xielu": + self.act_fn = XIELU() + else: + self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg) if self.cfg.is_layer_norm_activation(): self.hook_mid = HookPoint() diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 8bfb6315d..0a54c3304 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -24,6 +24,7 @@ import transformer_lens.utils as utils from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.pretrained.weight_conversions import ( + convert_apertus_weights, convert_bert_weights, convert_bloom_weights, convert_coder_weights, @@ -245,6 +246,8 @@ "microsoft/phi-2", "microsoft/Phi-3-mini-4k-instruct", "microsoft/phi-4", + "swiss-ai/Apertus-8B-2509", + "swiss-ai/Apertus-8B-Instruct-2509", "google/gemma-2b", "google/gemma-7b", "google/gemma-2b-it", @@ -701,6 +704,8 @@ "microsoft/phi-2": ["phi-2"], "microsoft/Phi-3-mini-4k-instruct": ["phi-3"], "microsoft/phi-4": ["phi-4"], + "swiss-ai/Apertus-8B-2509": ["apertus-8b", "apertus"], + "swiss-ai/Apertus-8B-Instruct-2509": 
["apertus-8b-instruct", "apertus-instruct"], "google/gemma-2b": ["gemma-2b"], "google/gemma-7b": ["gemma-7b"], "google/gemma-2b-it": ["gemma-2b-it"], @@ -742,6 +747,7 @@ "microsoft/phi-2", "microsoft/Phi-3-mini-4k-instruct", "microsoft/phi-4", + "swiss-ai/Apertus-", ) @@ -1436,6 +1442,43 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): "parallel_attn_mlp": False, "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, } + elif architecture == "ApertusForCausalLM": + n_heads = hf_config.num_attention_heads + d_head = hf_config.hidden_size // n_heads + num_kv_heads = getattr(hf_config, "num_key_value_heads", n_heads) + n_kv_heads = num_kv_heads if num_kv_heads != n_heads else None + cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": d_head, + "n_heads": n_heads, + "n_key_value_heads": n_kv_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": hf_config.max_position_embeddings, + "eps": hf_config.rms_norm_eps, + "d_vocab": hf_config.vocab_size, + "act_fn": hf_config.hidden_act, + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "rotary_dim": d_head, + "rotary_base": getattr(hf_config, "rope_theta", None), + "gated_mlp": False, + "final_rms": True, + "use_qk_norm": getattr(hf_config, "qk_norm", False), + } + rope_scaling = getattr(hf_config, "rope_scaling", None) + if rope_scaling: + rope_type = (rope_scaling.get("type") or rope_scaling.get("rope_type") or "").lower() + else: + rope_type = "" + if rope_type == "llama3": + cfg_dict["use_NTK_by_parts_rope"] = True + cfg_dict["NTK_original_ctx_len"] = rope_scaling.get( + "original_max_position_embeddings", hf_config.max_position_embeddings + ) + cfg_dict["NTK_by_parts_low_freq_factor"] = rope_scaling.get("low_freq_factor", 1.0) + cfg_dict["NTK_by_parts_high_freq_factor"] = rope_scaling.get("high_freq_factor", 4.0) + cfg_dict["NTK_by_parts_factor"] = rope_scaling.get("factor", 1.0) elif 
official_model_name.startswith("google/gemma-2b"): # Architecture for Gemma 2b and Gemma 2b Instruct models @@ -1986,6 +2029,8 @@ def get_pretrained_state_dict( state_dict = convert_gemma_weights(hf_model, cfg) elif cfg.original_architecture == "Gemma2ForCausalLM": state_dict = convert_gemma_weights(hf_model, cfg) + elif cfg.original_architecture == "ApertusForCausalLM": + state_dict = convert_apertus_weights(hf_model, cfg) else: raise ValueError( f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py index c5ea9581b..b0defcf4c 100644 --- a/transformer_lens/pretrained/weight_conversions/__init__.py +++ b/transformer_lens/pretrained/weight_conversions/__init__.py @@ -19,3 +19,4 @@ from .nanogpt import convert_nanogpt_weights from .t5 import convert_t5_weights from .neel_solu_old import convert_neel_solu_old_weights +from .apertus import convert_apertus_weights diff --git a/transformer_lens/pretrained/weight_conversions/apertus.py b/transformer_lens/pretrained/weight_conversions/apertus.py new file mode 100644 index 000000000..739a84d83 --- /dev/null +++ b/transformer_lens/pretrained/weight_conversions/apertus.py @@ -0,0 +1,123 @@ +""" +Apertus is Llama like model architecture from Swiss AI. 
def convert_apertus_weights(apertus, cfg: "HookedTransformerConfig"):
    """Convert a HuggingFace Apertus model's weights to the HookedTransformer format.

    Apertus is a Llama-like architecture from Swiss AI: RMSNorm, rotary
    positional embeddings, optional grouped-query attention, and a non-gated
    MLP with a trainable xIELU activation whose per-layer parameters
    (alpha_p, alpha_n, beta) are also copied into the state dict.

    Args:
        apertus: A loaded HuggingFace ``ApertusForCausalLM`` model (possibly
            4-bit quantized when ``cfg.load_in_4bit`` is set).
        cfg: The HookedTransformerConfig describing the target architecture.

    Returns:
        A dict mapping HookedTransformer parameter names to tensors.
    """
    state_dict = {}

    state_dict["embed.W_E"] = apertus.model.embed_tokens.weight

    # Grouped-query attention: K/V parameter names gain an underscore prefix.
    using_gqa = cfg.n_key_value_heads is not None
    gqa_uscore = "_" if using_gqa else ""
    n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads)

    assert cfg.d_mlp is not None  # keep mypy happy

    for l in range(cfg.n_layers):
        layer = apertus.model.layers[l]

        state_dict[f"blocks.{l}.ln1.w"] = layer.attention_layernorm.weight
        # RMSNorm has no bias; HookedTransformer still expects a zero one.
        state_dict[f"blocks.{l}.ln1.b"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=cfg.device
        )

        W_Q = layer.self_attn.q_proj.weight
        W_K = layer.self_attn.k_proj.weight
        W_V = layer.self_attn.v_proj.weight

        # In case of quantization, parameters should stay as
        # bitsandbytes.nn.modules.Params4bit (packed, cannot be reshaped).
        if not cfg.load_in_4bit:
            W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
            W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads)
            W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads)

        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
        state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K
        state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V

        # Apertus attention projections carry no biases; fill with zeros.
        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
        )
        state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros(
            n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
        )
        state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros(
            n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
        )

        W_O = layer.self_attn.o_proj.weight
        if not cfg.load_in_4bit:
            W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
        state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device)

        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=cfg.device
        )

        state_dict[f"blocks.{l}.ln2.w"] = layer.feedforward_layernorm.weight
        state_dict[f"blocks.{l}.ln2.b"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=cfg.device
        )

        # Non-gated MLP. Transpose to HookedTransformer's (d_model, d_mlp)
        # layout, except under 4-bit quantization where weights stay packed.
        if not cfg.load_in_4bit:
            state_dict[f"blocks.{l}.mlp.W_in"] = layer.mlp.up_proj.weight.T
            state_dict[f"blocks.{l}.mlp.W_out"] = layer.mlp.down_proj.weight.T
        else:
            state_dict[f"blocks.{l}.mlp.W_in"] = layer.mlp.up_proj.weight
            state_dict[f"blocks.{l}.mlp.W_out"] = layer.mlp.down_proj.weight

        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(
            cfg.d_mlp, dtype=cfg.dtype, device=cfg.device
        )
        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=cfg.device
        )

        # Trainable xIELU activation parameters.
        alpha_p, alpha_n, beta = _extract_xielu_params(layer.mlp, l, cfg)
        state_dict[f"blocks.{l}.mlp.act_fn.alpha_p"] = alpha_p
        state_dict[f"blocks.{l}.mlp.act_fn.alpha_n"] = alpha_n
        state_dict[f"blocks.{l}.mlp.act_fn.beta"] = beta

    state_dict["ln_final.w"] = apertus.model.norm.weight
    state_dict["ln_final.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype, device=cfg.device)

    state_dict["unembed.W_U"] = apertus.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device)

    return state_dict


def _extract_xielu_params(mlp, l: int, cfg: "HookedTransformerConfig"):
    """Locate the trainable xIELU parameters (alpha_p, alpha_n, beta) on an MLP.

    Different HF Apertus revisions appear to store the activation module under
    ``act_fn``, ``act``, or directly on the MLP — TODO confirm against the
    upstream modeling code. Falls back to the paper defaults with a warning
    when none of the attributes is present.
    """
    try:
        if hasattr(mlp, "act_fn"):
            holder = mlp.act_fn
        elif hasattr(mlp, "act"):
            holder = mlp.act
        else:
            holder = mlp
        return holder.alpha_p, holder.alpha_n, holder.beta
    except AttributeError:
        # Local import: only needed on this rarely-taken fallback path.
        import warnings

        warnings.warn(f"Activation parameters not found in layer {l}, using defaults")

        def _default(value: float) -> torch.Tensor:
            return torch.tensor(value, dtype=cfg.dtype, device=cfg.device)

        return _default(0.8), _default(0.8), _default(0.5)
+ """ + def __init__(self, alpha_p_init: float = 0.8, alpha_n_init: float = 0.8, beta_init: float = 0.5, eps: float = -1e-6): + super().__init__() + self.alpha_p = nn.Parameter(torch.tensor(alpha_p_init, dtype=torch.float32)) + self.alpha_n = nn.Parameter(torch.tensor(alpha_n_init, dtype=torch.float32)) + self.beta = nn.Parameter(torch.tensor(beta_init, dtype=torch.float32)) + self.eps = eps + + def forward(self, input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]: + return torch.where( + input > 0, + self.alpha_p * input ** 2 + self.beta * input, + self.alpha_n * torch.expm1(torch.clamp_max(input, self.eps)) - self.alpha_n * input + self.beta * input + ) + + +def xielu( + input: Float[torch.Tensor, "batch pos d_mlp"] +) -> Float[torch.Tensor, "batch pos d_mlp"]: + """ + xIELU activation function as described by + https://arxiv.org/abs/2411.13010 + + and original code in: + https://github.com/rubber-duck-debug/xielu + + Defined as + + f(x) = { + α_p * x² + β * x, if x > 0 + α_n * (exp(min(x, ε)) - 1) - α_n * x + β * x, if x ≤ 0 + } + + in this function the values are FIXED. However, the script can_be_used_as_mlp.py correctly used the XIELU class with trainable parameters, so the parameters can be trained if desired. + """ + alpha_p: float = 0.8 + alpha_n: float = 0.8 + beta: float = 0.5 + eps: float = -1e-6 + + # The core calculation logic: + return torch.where(input > 0, + alpha_p * input * input + beta * input, + alpha_n * torch.expm1(torch.clamp_max(input, eps)) - alpha_n * input + beta * input) + + ACTIVATION_FN_DICT = { "solu": solu, "solu_ln": solu, @@ -212,6 +270,7 @@ def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, " "relu": F.relu, "gelu": F.gelu, "gelu_pytorch_tanh": gelu_pytorch_tanh, + "xielu": xielu, }