Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion transformer_lens/components/mlps/can_be_used_as_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utilities.activation_functions import ActivationFunction
from transformer_lens.utils import XIELU


class CanBeUsedAsMLP(nn.Module):
Expand Down Expand Up @@ -65,7 +66,10 @@ def select_activation_function(self) -> None:
ValueError: If the configure activation function is not supported.
"""

self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg)
if self.cfg.act_fn == "xielu":
self.act_fn = XIELU()
else:
self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg)

if self.cfg.is_layer_norm_activation():
self.hook_mid = HookPoint()
Expand Down
45 changes: 45 additions & 0 deletions transformer_lens/loading_from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import transformer_lens.utils as utils
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.pretrained.weight_conversions import (
convert_apertus_weights,
convert_bert_weights,
convert_bloom_weights,
convert_coder_weights,
Expand Down Expand Up @@ -245,6 +246,8 @@
"microsoft/phi-2",
"microsoft/Phi-3-mini-4k-instruct",
"microsoft/phi-4",
"swiss-ai/Apertus-8B-2509",
"swiss-ai/Apertus-8B-Instruct-2509",
"google/gemma-2b",
"google/gemma-7b",
"google/gemma-2b-it",
Expand Down Expand Up @@ -701,6 +704,8 @@
"microsoft/phi-2": ["phi-2"],
"microsoft/Phi-3-mini-4k-instruct": ["phi-3"],
"microsoft/phi-4": ["phi-4"],
"swiss-ai/Apertus-8B-2509": ["apertus-8b", "apertus"],
"swiss-ai/Apertus-8B-Instruct-2509": ["apertus-8b-instruct", "apertus-instruct"],
"google/gemma-2b": ["gemma-2b"],
"google/gemma-7b": ["gemma-7b"],
"google/gemma-2b-it": ["gemma-2b-it"],
Expand Down Expand Up @@ -742,6 +747,7 @@
"microsoft/phi-2",
"microsoft/Phi-3-mini-4k-instruct",
"microsoft/phi-4",
"swiss-ai/Apertus-",
)


Expand Down Expand Up @@ -1436,6 +1442,43 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
"parallel_attn_mlp": False,
"rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
}
elif architecture == "ApertusForCausalLM":
n_heads = hf_config.num_attention_heads
d_head = hf_config.hidden_size // n_heads
num_kv_heads = getattr(hf_config, "num_key_value_heads", n_heads)
n_kv_heads = num_kv_heads if num_kv_heads != n_heads else None
cfg_dict = {
"d_model": hf_config.hidden_size,
"d_head": d_head,
"n_heads": n_heads,
"n_key_value_heads": n_kv_heads,
"d_mlp": hf_config.intermediate_size,
"n_layers": hf_config.num_hidden_layers,
"n_ctx": hf_config.max_position_embeddings,
"eps": hf_config.rms_norm_eps,
"d_vocab": hf_config.vocab_size,
"act_fn": hf_config.hidden_act,
"normalization_type": "RMS",
"positional_embedding_type": "rotary",
"rotary_dim": d_head,
"rotary_base": getattr(hf_config, "rope_theta", None),
"gated_mlp": False,
"final_rms": True,
"use_qk_norm": getattr(hf_config, "qk_norm", False),
}
rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling:
rope_type = (rope_scaling.get("type") or rope_scaling.get("rope_type") or "").lower()
else:
rope_type = ""
if rope_type == "llama3":
cfg_dict["use_NTK_by_parts_rope"] = True
cfg_dict["NTK_original_ctx_len"] = rope_scaling.get(
"original_max_position_embeddings", hf_config.max_position_embeddings
)
cfg_dict["NTK_by_parts_low_freq_factor"] = rope_scaling.get("low_freq_factor", 1.0)
cfg_dict["NTK_by_parts_high_freq_factor"] = rope_scaling.get("high_freq_factor", 4.0)
cfg_dict["NTK_by_parts_factor"] = rope_scaling.get("factor", 1.0)

elif official_model_name.startswith("google/gemma-2b"):
# Architecture for Gemma 2b and Gemma 2b Instruct models
Expand Down Expand Up @@ -1986,6 +2029,8 @@ def get_pretrained_state_dict(
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "Gemma2ForCausalLM":
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "ApertusForCausalLM":
state_dict = convert_apertus_weights(hf_model, cfg)
else:
raise ValueError(
f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
Expand Down
1 change: 1 addition & 0 deletions transformer_lens/pretrained/weight_conversions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from .nanogpt import convert_nanogpt_weights
from .t5 import convert_t5_weights
from .neel_solu_old import convert_neel_solu_old_weights
from .apertus import convert_apertus_weights
123 changes: 123 additions & 0 deletions transformer_lens/pretrained/weight_conversions/apertus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
Apertus is Llama like model architecture from Swiss AI.
convert weights to standardized format for HookedTransformer
"""

from typing import cast

import einops
import torch

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def _get_xielu_params(mlp):
    """Return the (alpha_p, alpha_n, beta) xIELU parameters of an Apertus MLP.

    Checkpoints store the trainable activation parameters either on an
    ``act_fn`` submodule, an ``act`` submodule, or directly on the MLP
    module itself; each known location is tried in turn.

    Raises:
        AttributeError: if the chosen location does not hold all three
            parameters (the caller falls back to the published defaults).
    """
    for attr in ("act_fn", "act"):
        if hasattr(mlp, attr):
            holder = getattr(mlp, attr)
            return holder.alpha_p, holder.alpha_n, holder.beta
    return mlp.alpha_p, mlp.alpha_n, mlp.beta


def convert_apertus_weights(apertus, cfg: HookedTransformerConfig):
    """Convert HuggingFace Apertus weights into the HookedTransformer state-dict layout.

    Args:
        apertus: A loaded HF ``ApertusForCausalLM`` model (only attribute
            access is used, so any object with the same structure works).
        cfg: The HookedTransformerConfig describing the target model.

    Returns:
        dict mapping HookedTransformer parameter names to tensors.
    """
    state_dict = {}

    state_dict["embed.W_E"] = apertus.model.embed_tokens.weight

    # Grouped-query attention stores K/V under underscore-prefixed names.
    using_gqa = cfg.n_key_value_heads is not None
    gqa_uscore = "_" if using_gqa else ""
    n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads)

    assert cfg.d_mlp is not None  # keep mypy happy

    for layer in range(cfg.n_layers):
        block = apertus.model.layers[layer]

        state_dict[f"blocks.{layer}.ln1.w"] = block.attention_layernorm.weight
        state_dict[f"blocks.{layer}.ln1.b"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=cfg.device
        )

        W_Q = block.self_attn.q_proj.weight
        W_K = block.self_attn.k_proj.weight
        W_V = block.self_attn.v_proj.weight

        # In case of quantization, parameters should stay as
        # bitsandbytes.nn.modules.Params4bit (so no rearrange).
        if not cfg.load_in_4bit:
            W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
            W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads)
            W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads)

        state_dict[f"blocks.{layer}.attn.W_Q"] = W_Q
        state_dict[f"blocks.{layer}.attn.{gqa_uscore}W_K"] = W_K
        state_dict[f"blocks.{layer}.attn.{gqa_uscore}W_V"] = W_V

        state_dict[f"blocks.{layer}.attn.b_Q"] = torch.zeros(
            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
        )
        state_dict[f"blocks.{layer}.attn.{gqa_uscore}b_K"] = torch.zeros(
            n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
        )
        state_dict[f"blocks.{layer}.attn.{gqa_uscore}b_V"] = torch.zeros(
            n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
        )

        W_O = block.self_attn.o_proj.weight
        if not cfg.load_in_4bit:
            W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
        state_dict[f"blocks.{layer}.attn.W_O"] = W_O.to(device=cfg.device)

        state_dict[f"blocks.{layer}.attn.b_O"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=cfg.device
        )

        state_dict[f"blocks.{layer}.ln2.w"] = block.feedforward_layernorm.weight
        state_dict[f"blocks.{layer}.ln2.b"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=cfg.device
        )

        # In case of quantization, parameters should stay as
        # bitsandbytes.nn.modules.Params4bit (hence no transpose).
        if not cfg.load_in_4bit:
            state_dict[f"blocks.{layer}.mlp.W_in"] = block.mlp.up_proj.weight.T
            state_dict[f"blocks.{layer}.mlp.W_out"] = block.mlp.down_proj.weight.T
        else:
            state_dict[f"blocks.{layer}.mlp.W_in"] = block.mlp.up_proj.weight
            state_dict[f"blocks.{layer}.mlp.W_out"] = block.mlp.down_proj.weight

        state_dict[f"blocks.{layer}.mlp.b_in"] = torch.zeros(
            cfg.d_mlp, dtype=cfg.dtype, device=cfg.device
        )
        state_dict[f"blocks.{layer}.mlp.b_out"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=cfg.device
        )

        # Extract the trainable xIELU activation parameters; fall back to the
        # defaults used by the XIELU module when a checkpoint lacks them.
        try:
            alpha_p, alpha_n, beta = _get_xielu_params(block.mlp)
            state_dict[f"blocks.{layer}.mlp.act_fn.alpha_p"] = alpha_p
            state_dict[f"blocks.{layer}.mlp.act_fn.alpha_n"] = alpha_n
            state_dict[f"blocks.{layer}.mlp.act_fn.beta"] = beta
        except AttributeError:
            print(f"Activation parameters not found in layer {layer}, using defaults")
            state_dict[f"blocks.{layer}.mlp.act_fn.alpha_p"] = torch.tensor(
                0.8, dtype=cfg.dtype, device=cfg.device
            )
            state_dict[f"blocks.{layer}.mlp.act_fn.alpha_n"] = torch.tensor(
                0.8, dtype=cfg.dtype, device=cfg.device
            )
            state_dict[f"blocks.{layer}.mlp.act_fn.beta"] = torch.tensor(
                0.5, dtype=cfg.dtype, device=cfg.device
            )

    state_dict["ln_final.w"] = apertus.model.norm.weight
    state_dict["ln_final.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype, device=cfg.device)

    state_dict["unembed.W_U"] = apertus.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device)

    return state_dict
3 changes: 2 additions & 1 deletion transformer_lens/utilities/activation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
import torch.nn.functional as F

from transformer_lens.utils import gelu_fast, gelu_new, solu
from transformer_lens.utils import gelu_fast, gelu_new, solu, xielu

# Convenient type for the format of each activation function
ActivationFunction = Callable[..., torch.Tensor]
Expand All @@ -23,4 +23,5 @@
"relu": F.relu,
"gelu": F.gelu,
"gelu_pytorch_tanh": lambda tensor: F.gelu(tensor, approximate="tanh"),
"xielu": xielu,
}
59 changes: 59 additions & 0 deletions transformer_lens/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from transformer_lens.FactoredMatrix import FactoredMatrix


CACHE_DIR = constants.HUGGINGFACE_HUB_CACHE
USE_DEFAULT_VALUE = None

Expand Down Expand Up @@ -203,6 +204,63 @@ def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "
return input * F.softmax(input, dim=-1)


class XIELU(nn.Module):
    """Trainable xIELU activation (https://arxiv.org/abs/2411.13010).

    Piecewise definition::

        f(x) = alpha_p * x^2 + beta * x                                   if x > 0
        f(x) = alpha_n * (exp(min(x, eps)) - 1) - alpha_n * x + beta * x  if x <= 0

    ``alpha_p``, ``alpha_n`` and ``beta`` are learnable scalars; ``eps`` is a
    fixed small negative clamp that keeps the exponential bounded.
    """

    def __init__(self, alpha_p_init: float = 0.8, alpha_n_init: float = 0.8, beta_init: float = 0.5, eps: float = -1e-6):
        super().__init__()
        # Each coefficient is a scalar learnable parameter stored in float32.
        self.alpha_p = nn.Parameter(torch.tensor(alpha_p_init, dtype=torch.float32))
        self.alpha_n = nn.Parameter(torch.tensor(alpha_n_init, dtype=torch.float32))
        self.beta = nn.Parameter(torch.tensor(beta_init, dtype=torch.float32))
        self.eps = eps

    def forward(self, input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]:
        # Quadratic-plus-linear branch for positive inputs.
        positive_branch = self.alpha_p * input ** 2 + self.beta * input
        # Saturating-exponential branch for non-positive inputs; the clamp
        # caps the argument of expm1 at eps (a tiny negative value).
        clamped = torch.clamp_max(input, self.eps)
        negative_branch = self.alpha_n * torch.expm1(clamped) - self.alpha_n * input + self.beta * input
        return torch.where(input > 0, positive_branch, negative_branch)


def xielu(
    input: Float[torch.Tensor, "batch pos d_mlp"]
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """xIELU activation with fixed coefficients.

    Described by https://arxiv.org/abs/2411.13010, with original code at
    https://github.com/rubber-duck-debug/xielu

    Piecewise definition::

        f(x) = alpha_p * x^2 + beta * x                                   if x > 0
        f(x) = alpha_n * (exp(min(x, eps)) - 1) - alpha_n * x + beta * x  if x <= 0

    Here the coefficients are FIXED constants. For the trainable variant,
    ``can_be_used_as_mlp.py`` uses the ``XIELU`` class, whose parameters can
    be learned if desired.
    """
    alpha_p: float = 0.8
    alpha_n: float = 0.8
    beta: float = 0.5
    eps: float = -1e-6

    # Evaluate both branches and select elementwise by sign.
    positive_branch = alpha_p * input * input + beta * input
    negative_branch = (
        alpha_n * torch.expm1(torch.clamp_max(input, eps)) - alpha_n * input + beta * input
    )
    return torch.where(input > 0, positive_branch, negative_branch)


ACTIVATION_FN_DICT = {
"solu": solu,
"solu_ln": solu,
Expand All @@ -212,6 +270,7 @@ def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "
"relu": F.relu,
"gelu": F.gelu,
"gelu_pytorch_tanh": gelu_pytorch_tanh,
"xielu": xielu,
}


Expand Down