diff --git a/demos/Colab_Compatibility.ipynb b/demos/Colab_Compatibility.ipynb
index 71425d1bf..46b42c2ba 100644
--- a/demos/Colab_Compatibility.ipynb
+++ b/demos/Colab_Compatibility.ipynb
@@ -651,7 +651,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "TransformerLens currently supports 231 models out of the box.\n"
+      "TransformerLens currently supports 232 models out of the box.\n"
      ]
     }
    ],
@@ -1086,6 +1086,7 @@
     "    \"meta-llama/Meta-Llama-3-70B-Instruct\",\n",
     "    \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
     "    \"mistralai/Mixtral-8x7B-v0.1\",\n",
+    "    \"openai/gpt-oss-20b\",\n",
     "    \"Qwen/Qwen2.5-32B\",\n",
     "    \"Qwen/Qwen2.5-32B-Instruct\",\n",
     "    \"Qwen/Qwen2.5-72B\",\n",
diff --git a/demos/GPT_OSS_Demo.ipynb b/demos/GPT_OSS_Demo.ipynb
new file mode 100644
index 000000000..74de6c04c
--- /dev/null
+++ b/demos/GPT_OSS_Demo.ipynb
@@ -0,0 +1,512 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# GPT-OSS-20B Demo\n",
+    "\n",
+    "This notebook loads OpenAI's [GPT-OSS-20B](https://huggingface.co/openai/gpt-oss-20b) into TransformerLens for mechanistic interpretability.\n",
+    "\n",
+    "GPT-OSS-20B is a Mixture of Experts (MoE) model with:\n",
+    "- 24 layers, d_model=2880, 32 experts, 4 experts per token\n",
+    "- MXFP4 quantized weights on HuggingFace (dequantized to BF16 during loading)\n",
+    "- Custom GLU activation and post-top-k softmax routing\n",
+    "\n",
+    "**Memory:** The full model is ~40GB in BF16. This notebook loads directly from safetensors, bypassing the HuggingFace model pipeline to keep peak memory manageable. Use `N_LAYERS` to load fewer layers if needed.\n",
+    "\n",
+    "**Requirements:** `transformers`, `safetensors`, `einops`, `psutil`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import gc\n",
+    "import json\n",
+    "from pathlib import Path\n",
+    "\n",
+    "import einops\n",
+    "import torch\n",
+    "from safetensors import safe_open\n",
+    "from transformers import AutoTokenizer\n",
+    "from transformers.integrations.mxfp4 import convert_moe_packed_tensors\n",
+    "\n",
+    "from transformer_lens import HookedTransformer\n",
+    "from transformer_lens.HookedTransformerConfig import HookedTransformerConfig"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Configuration\n",
+    "\n",
+    "Set `N_LAYERS` to control how many layers to load. Each layer is ~1.6GB. Use fewer layers to save memory."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "N_LAYERS = 24  # Full model. Set to 3-6 for quick testing."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Model Loading Utilities"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_model_path():\n",
+    "    \"\"\"Get the cached model path, downloading if necessary.\"\"\"\n",
+    "    cache_path = Path.home() / \".cache/huggingface/hub/models--openai--gpt-oss-20b\"\n",
+    "    snapshots = cache_path / \"snapshots\"\n",
+    "\n",
+    "    if snapshots.exists():\n",
+    "        snapshot_dirs = list(snapshots.iterdir())\n",
+    "        if snapshot_dirs:\n",
+    "            return snapshot_dirs[0]\n",
+    "\n",
+    "    print(\"Model not found in cache. Downloading...\")\n",
Downloading...\")\n", + " from huggingface_hub import snapshot_download\n", + "\n", + " return Path(snapshot_download(\"openai/gpt-oss-20b\"))\n", + "\n", + "\n", + "def create_config(n_layers=24):\n", + " \"\"\"Create TransformerLens config for GPT-OSS-20B.\"\"\"\n", + " return HookedTransformerConfig(\n", + " n_layers=n_layers,\n", + " d_model=2880,\n", + " d_head=64,\n", + " n_heads=64,\n", + " d_mlp=2880,\n", + " n_ctx=4096,\n", + " d_vocab=201088,\n", + " act_fn=\"silu\",\n", + " normalization_type=\"RMS\",\n", + " positional_embedding_type=\"rotary\",\n", + " rotary_base=150000,\n", + " eps=1e-5,\n", + " n_key_value_heads=8,\n", + " gated_mlp=True,\n", + " use_local_attn=False,\n", + " rotary_dim=64,\n", + " num_experts=32,\n", + " experts_per_token=4,\n", + " dtype=torch.bfloat16,\n", + " device=\"cpu\",\n", + " original_architecture=\"GptOssForCausalLM\",\n", + " model_name=\"openai/gpt-oss-20b\",\n", + " )\n", + "\n", + "\n", + "_open_files = {}\n", + "\n", + "\n", + "def _get_tensor(hf_name, wmap, model_path):\n", + " \"\"\"Load a single tensor from the correct safetensors shard.\"\"\"\n", + " st_file = wmap[hf_name]\n", + " filepath = str(model_path / st_file)\n", + " if filepath not in _open_files:\n", + " _open_files[filepath] = safe_open(filepath, framework=\"pt\", device=\"cpu\")\n", + " return _open_files[filepath].get_tensor(hf_name)\n", + "\n", + "\n", + "def load_layer_weights(l, cfg, index, model_path):\n", + " \"\"\"Load and convert weights for one transformer layer from safetensors.\"\"\"\n", + " state_dict = {}\n", + " wmap = index[\"weight_map\"]\n", + " prefix = f\"model.layers.{l}\"\n", + "\n", + " def gt(name):\n", + " return _get_tensor(name, wmap, model_path)\n", + "\n", + " # LayerNorms\n", + " state_dict[f\"blocks.{l}.ln1.w\"] = gt(f\"{prefix}.input_layernorm.weight\")\n", + " state_dict[f\"blocks.{l}.ln2.w\"] = gt(f\"{prefix}.post_attention_layernorm.weight\")\n", + "\n", + " # Attention weights\n", + " q_w = gt(f\"{prefix}.self_attn.q_proj.weight\")\n", + " k_w = gt(f\"{prefix}.self_attn.k_proj.weight\")\n", + " v_w = gt(f\"{prefix}.self_attn.v_proj.weight\")\n", + " o_w = gt(f\"{prefix}.self_attn.o_proj.weight\")\n", + "\n", + " state_dict[f\"blocks.{l}.attn.W_Q\"] = einops.rearrange(q_w, \"(n h) m -> n m h\", n=cfg.n_heads)\n", + " state_dict[f\"blocks.{l}.attn._W_K\"] = einops.rearrange(\n", + " k_w, \"(n h) m -> n m h\", n=cfg.n_key_value_heads\n", + " )\n", + " state_dict[f\"blocks.{l}.attn._W_V\"] = einops.rearrange(\n", + " v_w, \"(n h) m -> n m h\", n=cfg.n_key_value_heads\n", + " )\n", + " state_dict[f\"blocks.{l}.attn.W_O\"] = einops.rearrange(\n", + " o_w, \"m (n h) -> n h m\", n=cfg.n_heads\n", + " )\n", + " del q_w, k_w, v_w, o_w\n", + "\n", + " # Attention biases\n", + " q_bias_key = f\"{prefix}.self_attn.q_proj.bias\"\n", + " if q_bias_key in wmap:\n", + " state_dict[f\"blocks.{l}.attn.b_Q\"] = einops.rearrange(\n", + " gt(q_bias_key), \"(n h) -> n h\", n=cfg.n_heads\n", + " )\n", + " state_dict[f\"blocks.{l}.attn._b_K\"] = einops.rearrange(\n", + " gt(f\"{prefix}.self_attn.k_proj.bias\"), \"(n h) -> n h\", n=cfg.n_key_value_heads\n", + " )\n", + " state_dict[f\"blocks.{l}.attn._b_V\"] = einops.rearrange(\n", + " gt(f\"{prefix}.self_attn.v_proj.bias\"), \"(n h) -> n h\", n=cfg.n_key_value_heads\n", + " )\n", + " else:\n", + " state_dict[f\"blocks.{l}.attn.b_Q\"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)\n", + " state_dict[f\"blocks.{l}.attn._b_K\"] = torch.zeros(\n", + " cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype\n", + " 
)\n", + " state_dict[f\"blocks.{l}.attn._b_V\"] = torch.zeros(\n", + " cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype\n", + " )\n", + "\n", + " o_bias_key = f\"{prefix}.self_attn.o_proj.bias\"\n", + " if o_bias_key in wmap:\n", + " state_dict[f\"blocks.{l}.attn.b_O\"] = gt(o_bias_key)\n", + " else:\n", + " state_dict[f\"blocks.{l}.attn.b_O\"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)\n", + "\n", + " # Router\n", + " state_dict[f\"blocks.{l}.mlp.W_gate.weight\"] = gt(f\"{prefix}.mlp.router.weight\")\n", + " state_dict[f\"blocks.{l}.mlp.W_gate.bias\"] = gt(f\"{prefix}.mlp.router.bias\")\n", + "\n", + " # Expert weights - dequantize MXFP4 to BF16\n", + " gate_up_blocks = gt(f\"{prefix}.mlp.experts.gate_up_proj_blocks\")\n", + " gate_up_scales = gt(f\"{prefix}.mlp.experts.gate_up_proj_scales\")\n", + " gate_up_bias = gt(f\"{prefix}.mlp.experts.gate_up_proj_bias\")\n", + "\n", + " print(f\" Dequantizing layer {l} gate_up_proj...\", end=\"\", flush=True)\n", + " gate_up_proj = convert_moe_packed_tensors(gate_up_blocks, gate_up_scales)\n", + " del gate_up_blocks, gate_up_scales\n", + " print(\" done\")\n", + "\n", + " down_blocks = gt(f\"{prefix}.mlp.experts.down_proj_blocks\")\n", + " down_scales = gt(f\"{prefix}.mlp.experts.down_proj_scales\")\n", + " down_bias = gt(f\"{prefix}.mlp.experts.down_proj_bias\")\n", + "\n", + " print(f\" Dequantizing layer {l} down_proj...\", end=\"\", flush=True)\n", + " down_proj = convert_moe_packed_tensors(down_blocks, down_scales)\n", + " del down_blocks, down_scales\n", + " print(\" done\")\n", + "\n", + " # Split merged expert tensors into per-expert weights\n", + " # Even columns -> gate, Odd columns -> up\n", + " for e in range(cfg.num_experts):\n", + " state_dict[f\"blocks.{l}.mlp.experts.{e}.W_gate.weight\"] = gate_up_proj[\n", + " e, :, ::2\n", + " ].T.contiguous()\n", + " state_dict[f\"blocks.{l}.mlp.experts.{e}.W_gate.bias\"] = gate_up_bias[\n", + " e, ::2\n", + " ].contiguous()\n", + " state_dict[f\"blocks.{l}.mlp.experts.{e}.W_in.weight\"] = gate_up_proj[\n", + " e, :, 1::2\n", + " ].T.contiguous()\n", + " state_dict[f\"blocks.{l}.mlp.experts.{e}.W_in.bias\"] = gate_up_bias[\n", + " e, 1::2\n", + " ].contiguous()\n", + " state_dict[f\"blocks.{l}.mlp.experts.{e}.W_out.weight\"] = down_proj[e].T.contiguous()\n", + " state_dict[f\"blocks.{l}.mlp.experts.{e}.W_out.bias\"] = down_bias[e].contiguous()\n", + "\n", + " del gate_up_proj, gate_up_bias, down_proj, down_bias\n", + " return state_dict" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_path = get_model_path()\n", + "print(f\"Model path: {model_path}\")\n", + "\n", + "with open(model_path / \"model.safetensors.index.json\") as f:\n", + " index = json.load(f)\n", + "\n", + "cfg = create_config(n_layers=N_LAYERS)\n", + "tokenizer = AutoTokenizer.from_pretrained(str(model_path))\n", + "model = HookedTransformer(cfg, tokenizer, move_to_device=False)\n", + "\n", + "# Load embeddings\n", + "wmap = index[\"weight_map\"]\n", + "model.load_state_dict(\n", + " {\"embed.W_E\": _get_tensor(\"model.embed_tokens.weight\", wmap, model_path)},\n", + " strict=False,\n", + ")\n", + "gc.collect()\n", + "\n", + "# Load layers one at a time\n", + "for l in range(N_LAYERS):\n", + " print(f\"Loading layer {l}/{N_LAYERS-1}...\")\n", + " layer_dict = load_layer_weights(l, cfg, index, model_path)\n", + " for key in list(layer_dict.keys()):\n", + " 
+    "        model.load_state_dict({key: layer_dict[key]}, strict=False)\n",
+    "        del layer_dict[key]\n",
+    "    del layer_dict\n",
+    "    gc.collect()\n",
+    "\n",
+    "# Load final LayerNorm and unembed\n",
+    "model.load_state_dict(\n",
+    "    {\"ln_final.w\": _get_tensor(\"model.norm.weight\", wmap, model_path)},\n",
+    "    strict=False,\n",
+    ")\n",
+    "model.load_state_dict(\n",
+    "    {\"unembed.W_U\": _get_tensor(\"lm_head.weight\", wmap, model_path).T},\n",
+    "    strict=False,\n",
+    ")\n",
+    "model.load_state_dict(\n",
+    "    {\"unembed.b_U\": torch.zeros(cfg.d_vocab, dtype=cfg.dtype)},\n",
+    "    strict=False,\n",
+    ")\n",
+    "gc.collect()\n",
+    "\n",
+    "print(f\"\\nModel loaded! Layers: {cfg.n_layers}, Experts: {cfg.num_experts}, d_model: {cfg.d_model}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Basic Inference"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prompts = [\n",
+    "    \"The capital of France is\",\n",
+    "    \"2 + 2 =\",\n",
+    "    \"The opposite of hot is\",\n",
+    "]\n",
+    "\n",
+    "for prompt in prompts:\n",
+    "    tokens = model.to_tokens(prompt)\n",
+    "    with torch.no_grad():\n",
+    "        logits = model(tokens)\n",
+    "    probs = torch.softmax(logits[0, -1].float(), dim=-1)\n",
+    "    top5 = probs.topk(5)\n",
+    "    print(f\"Prompt: '{prompt}'\")\n",
+    "    for i in range(5):\n",
+    "        token_str = model.to_string(top5.indices[i])\n",
+    "        print(f\" {token_str!r}: {top5.values[i]:.4f}\")\n",
+    "    print()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Caching and Residual Stream"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokens = model.to_tokens(\"The cat sat on the mat\")\n",
+    "with torch.no_grad():\n",
+    "    logits, cache = model.run_with_cache(tokens)\n",
+    "\n",
+    "print(\"Cached activation keys (first 10):\")\n",
+    "for key in list(cache.keys())[:10]:\n",
+    "    print(f\" {key}: {cache[key].shape}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Expert Routing Analysis\n",
+    "\n",
+    "GPT-OSS routes each token to 4 of 32 experts. We can inspect which experts are selected and their weights."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "str_tokens = model.to_str_tokens(\"The cat sat on the mat\")\n",
+    "\n",
+    "for layer in range(min(N_LAYERS, 3)):\n",
+    "    expert_weights = cache[f\"blocks.{layer}.mlp.hook_expert_weights\"]\n",
+    "    expert_indices = cache[f\"blocks.{layer}.mlp.hook_expert_indices\"]\n",
+    "\n",
+    "    print(f\"\\nLayer {layer} expert routing:\")\n",
+    "    for t in range(len(str_tokens)):\n",
+    "        indices = expert_indices[t].tolist()\n",
+    "        weights = [expert_weights[t, idx].item() for idx in indices]\n",
+    "        pairs = \", \".join(f\"E{idx}({w:.2f})\" for idx, w in zip(indices, weights))\n",
+    "        print(f\" {str_tokens[t]:>10s} -> {pairs}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Logit Lens\n",
+    "\n",
+    "Apply the unembedding matrix to intermediate residual streams to see how the model's prediction evolves across layers."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prompt = \"The capital of France is\"\n",
+    "tokens = model.to_tokens(prompt)\n",
+    "with torch.no_grad():\n",
+    "    logits, cache = model.run_with_cache(tokens)\n",
+    "\n",
+    "print(f\"Prompt: '{prompt}'\")\n",
+    "print(f\"{'Layer':>6s} {'Top prediction':>20s} {'Prob':>6s}\")\n",
+    "print(\"-\" * 38)\n",
+    "\n",
+    "for layer in range(N_LAYERS):\n",
+    "    resid = cache[f\"blocks.{layer}.hook_resid_post\"]\n",
+    "    normed = model.ln_final(resid)\n",
+    "    layer_logits = normed @ model.unembed.W_U + model.unembed.b_U\n",
+    "    probs = torch.softmax(layer_logits[0, -1].float(), dim=-1)\n",
+    "    top_idx = probs.argmax().item()\n",
+    "    top_word = model.to_string(torch.tensor(top_idx))\n",
+    "    print(f\"{layer:>6d} {top_word:>20s} {probs[top_idx]:.4f}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Attention Patterns"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prompt = \"The cat sat on the mat\"\n",
+    "tokens = model.to_tokens(prompt)\n",
+    "str_tokens = model.to_str_tokens(prompt)\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    logits, cache = model.run_with_cache(tokens)\n",
+    "\n",
+    "# Show attention pattern for layer 0, head 0\n",
+    "pattern = cache[\"blocks.0.attn.hook_pattern\"][0, 0]  # [dest_pos, src_pos]\n",
+    "print(\"Layer 0, Head 0 attention pattern:\")\n",
+    "print(f\"{'':>12s}\", \" \".join(f\"{t:>6s}\" for t in str_tokens))\n",
+    "# Rows are destination (query) tokens; columns are the source tokens they attend to\n",
+    "for i, dest_token in enumerate(str_tokens):\n",
+    "    row = \" \".join(f\"{pattern[i, j]:.3f}\" for j in range(len(str_tokens)))\n",
+    "    print(f\"{dest_token:>12s} {row}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Activation Patching\n",
+    "\n",
+    "Patch the residual stream from a clean run into a corrupted run to measure causal effects."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "clean_prompt = \"The capital of France is\"\n",
+    "corrupt_prompt = \"The capital of Germany is\"\n",
+    "\n",
+    "clean_tokens = model.to_tokens(clean_prompt)\n",
+    "corrupt_tokens = model.to_tokens(corrupt_prompt)\n",
+    "\n",
+    "# Get clean activations\n",
+    "captured_clean = {}\n",
+    "\n",
+    "def save_clean(tensor, hook):\n",
+    "    captured_clean[hook.name] = tensor.detach().clone()\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    clean_logits = model.run_with_hooks(\n",
+    "        clean_tokens,\n",
+    "        fwd_hooks=[(f\"blocks.{l}.hook_resid_post\", save_clean) for l in range(N_LAYERS)],\n",
+    "    )\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    corrupt_logits = model(corrupt_tokens)\n",
+    "\n",
+    "clean_pred = model.to_string(clean_logits[0, -1].argmax())\n",
+    "corrupt_pred = model.to_string(corrupt_logits[0, -1].argmax())\n",
+    "print(f\"Clean prediction: '{clean_pred}'\")\n",
+    "print(f\"Corrupt prediction: '{corrupt_pred}'\")\n",
+    "\n",
+    "# Patch each layer's residual stream and measure effect\n",
+    "print(f\"\\n{'Layer':>6s} {'Patched prediction':>20s} {'Logit diff':>10s}\")\n",
+    "print(\"-\" * 42)\n",
+    "\n",
+    "for layer in range(N_LAYERS):\n",
+    "    hook_name = f\"blocks.{layer}.hook_resid_post\"\n",
+    "\n",
+    "    def patch_hook(tensor, hook, clean_act=captured_clean[hook_name]):\n",
+    "        return clean_act\n",
+    "\n",
+    "    with torch.no_grad():\n",
+    "        patched_logits = model.run_with_hooks(\n",
+    "            corrupt_tokens,\n",
+    "            fwd_hooks=[(hook_name, patch_hook)],\n",
+    "        )\n",
+    "\n",
+    "    patched_pred = model.to_string(patched_logits[0, -1].argmax())\n",
+    "    diff = (patched_logits[0, -1] - corrupt_logits[0, -1]).abs().sum().item()\n",
+    "    print(f\"{layer:>6d} {patched_pred:>20s} {diff:>10.1f}\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.12.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
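A minimal standalone sketch of the clamped GLU described in the notebook intro and in the module docstring below. This is plain PyTorch with toy inputs made up for illustration; the constants 1.702 and 7.0 correspond to GPT_OSS_ALPHA and GPT_OSS_LIMIT in the new component.

import torch

def gpt_oss_glu(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """Clamped GLU used by GPT-OSS experts: (up + 1) * gate * sigmoid(1.702 * gate)."""
    gate = gate.clamp(max=7.0)
    up = up.clamp(min=-7.0, max=7.0)
    return (up + 1) * gate * torch.sigmoid(1.702 * gate)

# Toy check: for large gate values the clamp and the saturating sigmoid kick in,
# so the output approaches (up + 1) * 7.0.
print(gpt_oss_glu(torch.tensor([10.0]), torch.tensor([0.5])))  # ~ tensor([10.5])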
diff --git a/transformer_lens/components/mlps/gpt_oss_moe.py b/transformer_lens/components/mlps/gpt_oss_moe.py
new file mode 100644
index 000000000..bfc575280
--- /dev/null
+++ b/transformer_lens/components/mlps/gpt_oss_moe.py
@@ -0,0 +1,126 @@
+"""GPT-OSS Mixture of Experts implementation for TransformerLens.
+
+GPT-OSS uses a unique MoE architecture:
+- Merged expert weights (gate_up_proj with interleaved gate/up columns)
+- Custom GLU activation: gate * sigmoid(gate * 1.702) * (up + 1), with clamping
+- Router with bias, softmax applied AFTER top-k selection
+- Expert projections have biases
+"""
+
+from typing import Dict, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from jaxtyping import Float
+
+from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
+from transformer_lens.hook_points import HookPoint
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+GPT_OSS_ALPHA = 1.702
+GPT_OSS_LIMIT = 7.0
+
+
+class GptOssExpert(nn.Module):
+    """Single GPT-OSS expert with custom GLU activation.
+
+    The activation differs from standard SiLU:
+        gate = clamp(x @ W_gate + b_gate, max=7.0)
+        up = clamp(x @ W_in + b_in, min=-7.0, max=7.0)
+        glu = gate * sigmoid(gate * 1.702)
+        out = (up + 1) * glu
+        result = out @ W_out + b_out
+    """
+
+    def __init__(self, cfg: HookedTransformerConfig):
+        super().__init__()
+        self.cfg = cfg
+        assert cfg.d_mlp is not None
+
+        self.W_gate = nn.Linear(cfg.d_model, cfg.d_mlp, bias=True, dtype=cfg.dtype)
+        self.W_in = nn.Linear(cfg.d_model, cfg.d_mlp, bias=True, dtype=cfg.dtype)
+        self.W_out = nn.Linear(cfg.d_mlp, cfg.d_model, bias=True, dtype=cfg.dtype)
+
+        self.hook_gate = HookPoint()
+        self.hook_pre = HookPoint()
+        self.hook_post = HookPoint()
+
+    def forward(self, x: Float[torch.Tensor, "pos d_model"]) -> Float[torch.Tensor, "pos d_model"]:
+        gate = self.hook_gate(self.W_gate(x))
+        up = self.hook_pre(self.W_in(x))
+
+        # GPT-OSS custom activation
+        gate = gate.clamp(max=GPT_OSS_LIMIT)
+        up = up.clamp(min=-GPT_OSS_LIMIT, max=GPT_OSS_LIMIT)
+        glu = gate * torch.sigmoid(gate * GPT_OSS_ALPHA)
+        post = self.hook_post((up + 1) * glu)
+
+        return self.W_out(post)
+
+
+class GptOssMoE(CanBeUsedAsMLP):
+    """GPT-OSS Mixture of Experts layer.
+
+    Differences from standard TransformerLens MoE (Mixtral):
+    - Router has bias
+    - Softmax applied AFTER top-k selection (not before)
+    - Experts use custom GLU activation (not SiLU)
+    - Expert projections have biases
+    """
+
+    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
+        super().__init__(cfg)
+
+        assert self.cfg.num_experts is not None
+        assert self.cfg.experts_per_token is not None
+
+        self.num_experts: int = self.cfg.num_experts
+        self.experts_per_token: int = self.cfg.experts_per_token
+
+        self.experts = nn.ModuleList([GptOssExpert(self.cfg) for _ in range(self.num_experts)])
+        # GPT-OSS router has bias (unlike Mixtral)
+        self.W_gate = nn.Linear(
+            self.cfg.d_model, self.cfg.num_experts, bias=True, dtype=self.cfg.dtype
+        )
+
+        self.hook_expert_weights = HookPoint()
+        self.hook_expert_indices = HookPoint()
+
+    def forward(
+        self, x: Float[torch.Tensor, "batch pos d_model"]
+    ) -> Float[torch.Tensor, "batch pos d_model"]:
+        batch, pos, d_model = x.shape
+        x = x.view(-1, d_model)
+
+        # GPT-OSS routing: softmax AFTER top-k (differs from Mixtral)
+        gate_logits = self.W_gate(x)
+        top_values, expert_indices = torch.topk(gate_logits, self.experts_per_token, dim=-1)
+        # Softmax over just the selected experts
+        top_weights = F.softmax(top_values, dim=-1, dtype=torch.float)
+
+        # Build full routing weights tensor for hooks (num_tokens, num_experts)
+        routing_weights = torch.zeros_like(gate_logits, dtype=torch.float)
+        routing_weights.scatter_(1, expert_indices, top_weights)
+
+        routing_weights = self.hook_expert_weights(routing_weights)
+        expert_indices = self.hook_expert_indices(expert_indices)
+        routing_weights = routing_weights.to(x.dtype)
+
+        results = torch.zeros((batch * pos, d_model), dtype=x.dtype, device=x.device)
+        expert_mask = F.one_hot(expert_indices, num_classes=self.num_experts).permute(2, 1, 0)
+
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            if top_x.numel() == 0:
+                continue
+
+            current_state = x[top_x]
+            current_hidden_states = (
+                expert_layer(current_state) * routing_weights[top_x, expert_idx, None]
+            )
+            results.index_add_(0, top_x, current_hidden_states.to(x.dtype))
+
+        return results.reshape(batch, pos, d_model)
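To make the "softmax after top-k" routing concrete, a small sketch with made-up router logits over 8 experts (GPT-OSS uses 32, top-4). It mirrors the topk / softmax / scatter_ sequence in GptOssMoE.forward above, and shows the dense tensor exposed via hook_expert_weights.

import torch
import torch.nn.functional as F

# One token's made-up router logits over 8 experts.
gate_logits = torch.tensor([[1.5, -0.3, 0.9, 2.1, -1.0, 0.2, 0.0, 1.1]])

# GPT-OSS order: select top-k on the raw logits, then softmax over only those k values.
top_values, expert_indices = torch.topk(gate_logits, k=4, dim=-1)
top_weights = F.softmax(top_values, dim=-1)  # sums to 1 over the 4 selected experts

# Scatter back to a dense (tokens, experts) tensor; unselected experts keep weight 0.
routing_weights = torch.zeros_like(gate_logits).scatter_(-1, expert_indices, top_weights)
print(expert_indices)               # tensor([[3, 0, 7, 2]])
print(routing_weights.sum(dim=-1))  # tensor([1.])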
diff --git a/transformer_lens/factories/mlp_factory.py b/transformer_lens/factories/mlp_factory.py
index fe4dbbab7..273c9c85c 100644
--- a/transformer_lens/factories/mlp_factory.py
+++ b/transformer_lens/factories/mlp_factory.py
@@ -6,6 +6,7 @@
 from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
 from transformer_lens.components.mlps.gated_mlp import GatedMLP
 from transformer_lens.components.mlps.gated_mlp_4bit import GatedMLP4Bit
+from transformer_lens.components.mlps.gpt_oss_moe import GptOssMoE
 from transformer_lens.components.mlps.mlp import MLP
 from transformer_lens.components.mlps.moe import MoE
 from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
@@ -15,6 +16,8 @@ class MLPFactory:
     @staticmethod
     def create_mlp(cfg: HookedTransformerConfig) -> CanBeUsedAsMLP:
         if cfg.num_experts:
+            if cfg.original_architecture == "GptOssForCausalLM":
+                return GptOssMoE(cfg)
             return MoE(cfg)
         elif cfg.gated_mlp:
             return GatedMLP(cfg) if not cfg.load_in_4bit else GatedMLP4Bit(cfg)
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 1a8ef9ddc..c4870e40d 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -29,6 +29,7 @@
     convert_coder_weights,
     convert_gemma_weights,
     convert_gpt2_weights,
+    convert_gpt_oss_weights,
     convert_gptj_weights,
     convert_llama_weights,
     convert_mingpt_weights,
@@ -192,6 +193,7 @@
     "mistralai/Mistral-Nemo-Base-2407",
     "mistralai/Mixtral-8x7B-v0.1",
     "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    "openai/gpt-oss-20b",
     "bigscience/bloom-560m",
     "bigscience/bloom-1b1",
     "bigscience/bloom-1b7",
@@ -663,6 +665,7 @@
         "mixtral-instruct",
         "mixtral-8x7b-instruct",
     ],
+    "openai/gpt-oss-20b": ["gpt-oss-20b", "gpt-oss"],
     "bigscience/bloom-560m": ["bloom-560m"],
     "bigscience/bloom-1b1": ["bloom-1b1"],
     "bigscience/bloom-1b7": ["bloom-1b7"],
@@ -1283,6 +1286,28 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             "num_experts": hf_config.num_local_experts,
             "experts_per_token": hf_config.num_experts_per_tok,
         }
+    elif architecture == "GptOssForCausalLM":
+        cfg_dict = {
+            "dtype": torch.bfloat16,
+            "d_model": hf_config.hidden_size,
+            "d_head": hf_config.head_dim,
+            "n_heads": hf_config.num_attention_heads,
+            "d_mlp": hf_config.intermediate_size,
+            "n_layers": hf_config.num_hidden_layers,
+            "n_ctx": hf_config.max_position_embeddings,
+            "d_vocab": hf_config.vocab_size,
+            "act_fn": hf_config.hidden_act,
+            "normalization_type": "RMS",
+            "positional_embedding_type": "rotary",
+            "rotary_base": hf_config.rope_theta,
+            "eps": hf_config.rms_norm_eps,
+            "n_key_value_heads": hf_config.num_key_value_heads,
+            "gated_mlp": True,
+            "use_local_attn": False,
+            "rotary_dim": hf_config.head_dim,
+            "num_experts": hf_config.num_local_experts,
+            "experts_per_token": hf_config.num_experts_per_tok,
+        }
     elif architecture == "BloomForCausalLM":
         cfg_dict = {
             "d_model": hf_config.hidden_size,
@@ -2379,6 +2404,8 @@ def get_pretrained_state_dict(
         state_dict = convert_mistral_weights(hf_model, cfg)
     elif cfg.original_architecture == "MixtralForCausalLM":
         state_dict = convert_mixtral_weights(hf_model, cfg)
+    elif cfg.original_architecture == "GptOssForCausalLM":
+        state_dict = convert_gpt_oss_weights(hf_model, cfg)
     elif cfg.original_architecture == "BloomForCausalLM":
         state_dict = convert_bloom_weights(hf_model, cfg)
     elif cfg.original_architecture == "GPT2LMHeadCustomModel":
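With the registry entries and the conversion hook above in place, the intended end-to-end path is the usual from_pretrained call. The following is only a sketch: it assumes this branch is installed, tens of GB of free RAM, and time to dequantize the MXFP4 checkpoint; the demo notebook earlier in this diff deliberately loads layer-by-layer from safetensors instead to keep peak memory lower.

import torch
from transformer_lens import HookedTransformer

# "gpt-oss-20b" is one of the aliases registered above for "openai/gpt-oss-20b".
model = HookedTransformer.from_pretrained(
    "gpt-oss-20b",
    device="cpu",
    dtype=torch.bfloat16,  # matches the dtype set in the GptOssForCausalLM config branch
)

# The routing hooks added by GptOssMoE show up in the cache like any other hook point.
logits, cache = model.run_with_cache("The capital of France is")
print(cache["blocks.0.mlp.hook_expert_indices"].shape)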
diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py
index c5ea9581b..b58cce705 100644
--- a/transformer_lens/pretrained/weight_conversions/__init__.py
+++ b/transformer_lens/pretrained/weight_conversions/__init__.py
@@ -19,3 +19,4 @@
 from .nanogpt import convert_nanogpt_weights
 from .t5 import convert_t5_weights
 from .neel_solu_old import convert_neel_solu_old_weights
+from .openai import convert_gpt_oss_weights
diff --git a/transformer_lens/pretrained/weight_conversions/openai.py b/transformer_lens/pretrained/weight_conversions/openai.py
new file mode 100644
index 000000000..17b9b40dd
--- /dev/null
+++ b/transformer_lens/pretrained/weight_conversions/openai.py
@@ -0,0 +1,109 @@
+"""Weight conversion for OpenAI GPT-OSS models.
+
+GPT-OSS has a unique MoE architecture:
+- GptOssExperts stores all expert weights in merged tensors (not individual modules)
+- gate_up_proj: (num_experts, hidden_size, 2*expert_dim) with interleaved gate/up columns
+- down_proj: (num_experts, expert_dim, hidden_size)
+- Router (GptOssTopKRouter) uses weight + bias
+"""
+
+import einops
+import torch
+
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+def convert_gpt_oss_weights(gpt_oss, cfg: HookedTransformerConfig):
+    state_dict = {}
+
+    assert cfg.n_key_value_heads is not None
+    assert cfg.d_mlp is not None
+    assert cfg.num_experts is not None
+
+    state_dict["embed.W_E"] = gpt_oss.model.embed_tokens.weight
+
+    for l in range(cfg.n_layers):
+        layer = gpt_oss.model.layers[l]
+
+        # LayerNorms
+        state_dict[f"blocks.{l}.ln1.w"] = layer.input_layernorm.weight
+        state_dict[f"blocks.{l}.ln2.w"] = layer.post_attention_layernorm.weight
+
+        # Attention
+        W_Q = einops.rearrange(layer.self_attn.q_proj.weight, "(n h) m -> n m h", n=cfg.n_heads)
+        W_K = einops.rearrange(
+            layer.self_attn.k_proj.weight, "(n h) m -> n m h", n=cfg.n_key_value_heads
+        )
+        W_V = einops.rearrange(
+            layer.self_attn.v_proj.weight, "(n h) m -> n m h", n=cfg.n_key_value_heads
+        )
+        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
+        state_dict[f"blocks.{l}.attn._W_K"] = W_K
+        state_dict[f"blocks.{l}.attn._W_V"] = W_V
+
+        if layer.self_attn.q_proj.bias is not None:
+            state_dict[f"blocks.{l}.attn.b_Q"] = einops.rearrange(
+                layer.self_attn.q_proj.bias, "(n h) -> n h", n=cfg.n_heads
+            )
+            state_dict[f"blocks.{l}.attn._b_K"] = einops.rearrange(
+                layer.self_attn.k_proj.bias, "(n h) -> n h", n=cfg.n_key_value_heads
+            )
+            state_dict[f"blocks.{l}.attn._b_V"] = einops.rearrange(
+                layer.self_attn.v_proj.bias, "(n h) -> n h", n=cfg.n_key_value_heads
+            )
+        else:
+            state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
+                cfg.n_heads, cfg.d_head, dtype=cfg.dtype
+            )
+            state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
+                cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
+            )
+            state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
+                cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
+            )
+
+        W_O = einops.rearrange(layer.self_attn.o_proj.weight, "m (n h) -> n h m", n=cfg.n_heads)
+        state_dict[f"blocks.{l}.attn.W_O"] = W_O
+
+        if hasattr(layer.self_attn.o_proj, "bias") and layer.self_attn.o_proj.bias is not None:
+            state_dict[f"blocks.{l}.attn.b_O"] = layer.self_attn.o_proj.bias
+        else:
+            state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+        # MoE - Router (GPT-OSS uses 'router' with bias)
+        state_dict[f"blocks.{l}.mlp.W_gate.weight"] = layer.mlp.router.weight
+        state_dict[f"blocks.{l}.mlp.W_gate.bias"] = layer.mlp.router.bias
+
+        # MoE - Experts
+        # GPT-OSS stores all experts in merged tensors:
+        #   gate_up_proj: (num_experts, hidden_size, 2*expert_dim) - interleaved gate/up
+        #   down_proj: (num_experts, expert_dim, hidden_size)
+        experts = layer.mlp.experts
+        gate_up_proj = experts.gate_up_proj  # (num_experts, hidden_size, 2*expert_dim)
+        gate_up_bias = experts.gate_up_proj_bias  # (num_experts, 2*expert_dim)
+        down_proj = experts.down_proj  # (num_experts, expert_dim, hidden_size)
+        down_bias = experts.down_proj_bias  # (num_experts, hidden_size)
+
+        for e in range(cfg.num_experts):
+            # Split interleaved gate_up_proj into separate gate and up (in) projections
+            # Even columns → gate path, Odd columns → up/in path
+            state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = gate_up_proj[
+                e, :, ::2
+            ].T.contiguous()
+            state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.bias"] = gate_up_bias[
+                e, ::2
+            ].contiguous()
+
+            state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = gate_up_proj[
+                e, :, 1::2
+            ].T.contiguous()
+            state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.bias"] = gate_up_bias[e, 1::2].contiguous()
+
+            state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.weight"] = down_proj[e].T.contiguous()
+            state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.bias"] = down_bias[e].contiguous()
+
+    state_dict["ln_final.w"] = gpt_oss.model.norm.weight
+    state_dict["unembed.W_U"] = gpt_oss.lm_head.weight.T
+    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
+
+    return state_dict
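The even/odd column split used above can be sanity-checked on a toy tensor. The shapes below are made up (2 experts, hidden_size=3, expert_dim=2) and stand in for the real merged (32, 2880, 5760) weights.

import torch

# Toy stand-in for experts.gate_up_proj: (num_experts, hidden_size, 2 * expert_dim),
# with gate and up columns interleaved as [g0, u0, g1, u1, ...].
num_experts, hidden_size, expert_dim = 2, 3, 2
gate_up_proj = torch.arange(num_experts * hidden_size * 2 * expert_dim, dtype=torch.float32)
gate_up_proj = gate_up_proj.reshape(num_experts, hidden_size, 2 * expert_dim)

e = 0
W_gate = gate_up_proj[e, :, ::2].T.contiguous()  # (expert_dim, hidden_size) - even columns
W_in = gate_up_proj[e, :, 1::2].T.contiguous()   # (expert_dim, hidden_size) - odd columns

print(W_gate.shape, W_in.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
# Together the two slices account for every column of the merged tensor exactly once.
assert W_gate.numel() + W_in.numel() == gate_up_proj[e].numel()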