From 572d7ee19a226d300118bbef2d3c85c71ec3ff66 Mon Sep 17 00:00:00 2001 From: Carl Gross Date: Fri, 6 Mar 2026 16:54:04 -0800 Subject: [PATCH 1/6] Add GPT-OSS-20B model support with MoE architecture Adds support for loading and running OpenAI's GPT-OSS-20B model, which uses a unique MoE architecture with 32 experts, custom GLU activation, and MXFP4 quantized weights on HuggingFace. New files: - gpt_oss_moe.py: GptOssExpert (custom activation) and GptOssMoE (routing) - openai.py: Weight converter for HF model -> TransformerLens state dict - run_gpt_oss.py: Direct safetensors loader that dequantizes MXFP4 on CPU, enabling the model to run on machines with <40GB RAM via swap Modified files: - loading_from_pretrained.py: Register model, config extraction, weight dispatch - mlp_factory.py: Route GptOssForCausalLM to GptOssMoE - weight_conversions/__init__.py: Export convert_gpt_oss_weights Co-Authored-By: Claude Opus 4.6 --- run_gpt_oss.py | 296 ++++++++++++++++++ .../components/mlps/gpt_oss_moe.py | 122 ++++++++ transformer_lens/factories/mlp_factory.py | 3 + transformer_lens/loading_from_pretrained.py | 27 ++ .../pretrained/weight_conversions/__init__.py | 1 + .../pretrained/weight_conversions/openai.py | 93 ++++++ 6 files changed, 542 insertions(+) create mode 100644 run_gpt_oss.py create mode 100644 transformer_lens/components/mlps/gpt_oss_moe.py create mode 100644 transformer_lens/pretrained/weight_conversions/openai.py diff --git a/run_gpt_oss.py b/run_gpt_oss.py new file mode 100644 index 000000000..1590e5111 --- /dev/null +++ b/run_gpt_oss.py @@ -0,0 +1,296 @@ +"""Load GPT-OSS-20B directly from safetensors into TransformerLens. + +Bypasses the HuggingFace model loading pipeline to avoid doubling memory usage. +The model is ~40GB in BF16 — loading via HF would require ~80GB peak (HF model + state dict). + +Instead, we: +1. Create the TransformerLens model structure (~40GB, filled with empty tensors) +2. Load weights from safetensors one layer at a time +3. Dequantize MXFP4 expert weights on the fly using HF's convert_moe_packed_tensors +4. Copy directly into TL model parameters, freeing temp data immediately + +Peak memory: ~42GB (model + one layer's temp data). Works on a 38.7GB Mac via swap. +""" + +import gc +import json +from pathlib import Path + +import einops +import torch +from safetensors import safe_open +from transformers import AutoTokenizer +from transformers.integrations.mxfp4 import convert_moe_packed_tensors + +from transformer_lens import HookedTransformer +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + + +def get_model_path(): + """Get the cached model path, downloading if necessary.""" + cache_path = Path.home() / ".cache/huggingface/hub/models--openai--gpt-oss-20b" + snapshots = cache_path / "snapshots" + + if snapshots.exists(): + # Use the first (usually only) snapshot + snapshot_dirs = list(snapshots.iterdir()) + if snapshot_dirs: + return snapshot_dirs[0] + + # Not cached — download + print("Model not found in cache. Downloading...") + from huggingface_hub import snapshot_download + return Path(snapshot_download("openai/gpt-oss-20b")) + + +def create_config(n_layers=24): + """Create TransformerLens config for GPT-OSS-20B.""" + return HookedTransformerConfig( + n_layers=n_layers, + d_model=2880, + d_head=64, + n_heads=64, + d_mlp=2880, + n_ctx=4096, # Reduced from 131072 to save memory + d_vocab=201088, + act_fn="silu", + normalization_type="RMS", + positional_embedding_type="rotary", + rotary_base=150000, + eps=1e-5, + n_key_value_heads=8, + gated_mlp=True, + use_local_attn=False, + rotary_dim=64, + num_experts=32, + experts_per_token=4, + dtype=torch.bfloat16, + device="cpu", + original_architecture="GptOssForCausalLM", + model_name="openai/gpt-oss-20b", + ) + + +def _get_tensor(hf_name, wmap, model_path, _open_files={}): + """Load a single tensor from the correct safetensors shard.""" + st_file = wmap[hf_name] + filepath = str(model_path / st_file) + if filepath not in _open_files: + _open_files[filepath] = safe_open(filepath, framework="pt", device="cpu") + return _open_files[filepath].get_tensor(hf_name) + + +def load_layer_weights(l, cfg, index, model_path): + """Load and convert weights for one transformer layer from safetensors.""" + state_dict = {} + wmap = index["weight_map"] + prefix = f"model.layers.{l}" + + def gt(name): + return _get_tensor(name, wmap, model_path) + + # LayerNorms + state_dict[f"blocks.{l}.ln1.w"] = gt(f"{prefix}.input_layernorm.weight") + state_dict[f"blocks.{l}.ln2.w"] = gt(f"{prefix}.post_attention_layernorm.weight") + + # Attention weights + q_w = gt(f"{prefix}.self_attn.q_proj.weight") + k_w = gt(f"{prefix}.self_attn.k_proj.weight") + v_w = gt(f"{prefix}.self_attn.v_proj.weight") + o_w = gt(f"{prefix}.self_attn.o_proj.weight") + + state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange(q_w, "(n h) m -> n m h", n=cfg.n_heads) + state_dict[f"blocks.{l}.attn._W_K"] = einops.rearrange(k_w, "(n h) m -> n m h", n=cfg.n_key_value_heads) + state_dict[f"blocks.{l}.attn._W_V"] = einops.rearrange(v_w, "(n h) m -> n m h", n=cfg.n_key_value_heads) + state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange(o_w, "m (n h) -> n h m", n=cfg.n_heads) + del q_w, k_w, v_w, o_w + + # Attention biases + q_bias_key = f"{prefix}.self_attn.q_proj.bias" + if q_bias_key in wmap: + state_dict[f"blocks.{l}.attn.b_Q"] = einops.rearrange( + gt(q_bias_key), "(n h) -> n h", n=cfg.n_heads + ) + state_dict[f"blocks.{l}.attn._b_K"] = einops.rearrange( + gt(f"{prefix}.self_attn.k_proj.bias"), "(n h) -> n h", n=cfg.n_key_value_heads + ) + state_dict[f"blocks.{l}.attn._b_V"] = einops.rearrange( + gt(f"{prefix}.self_attn.v_proj.bias"), "(n h) -> n h", n=cfg.n_key_value_heads + ) + else: + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype) + + o_bias_key = f"{prefix}.self_attn.o_proj.bias" + if o_bias_key in wmap: + state_dict[f"blocks.{l}.attn.b_O"] = gt(o_bias_key) + else: + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + # Router + state_dict[f"blocks.{l}.mlp.W_gate.weight"] = gt(f"{prefix}.mlp.router.weight") + state_dict[f"blocks.{l}.mlp.W_gate.bias"] = gt(f"{prefix}.mlp.router.bias") + + # Expert weights — dequantize MXFP4 to BF16 + gate_up_blocks = gt(f"{prefix}.mlp.experts.gate_up_proj_blocks") + gate_up_scales = gt(f"{prefix}.mlp.experts.gate_up_proj_scales") + gate_up_bias = gt(f"{prefix}.mlp.experts.gate_up_proj_bias") + + # Dequantize gate_up_proj: [32, 5760, 90, 16] + [32, 5760, 90] -> [32, 2880, 5760] + print(f" Dequantizing layer {l} gate_up_proj...", end="", flush=True) + gate_up_proj = convert_moe_packed_tensors(gate_up_blocks, gate_up_scales) + del gate_up_blocks, gate_up_scales + print(" done") + + down_blocks = gt(f"{prefix}.mlp.experts.down_proj_blocks") + down_scales = gt(f"{prefix}.mlp.experts.down_proj_scales") + down_bias = gt(f"{prefix}.mlp.experts.down_proj_bias") + + # Dequantize down_proj: [32, 2880, 90, 16] + [32, 2880, 90] -> [32, 2880, 2880] + print(f" Dequantizing layer {l} down_proj...", end="", flush=True) + down_proj = convert_moe_packed_tensors(down_blocks, down_scales) + del down_blocks, down_scales + print(" done") + + # Split merged expert tensors into per-expert weights + # gate_up_proj shape: [num_experts, hidden_size, 2*expert_dim] + # Even columns -> gate, Odd columns -> up + for e in range(cfg.num_experts): + state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = gate_up_proj[e, :, ::2].T.contiguous() + state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.bias"] = gate_up_bias[e, ::2].contiguous() + state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = gate_up_proj[e, :, 1::2].T.contiguous() + state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.bias"] = gate_up_bias[e, 1::2].contiguous() + state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.weight"] = down_proj[e].T.contiguous() + state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.bias"] = down_bias[e].contiguous() + + del gate_up_proj, gate_up_bias, down_proj, down_bias + return state_dict + + +def main(): + import argparse + parser = argparse.ArgumentParser(description="Load GPT-OSS-20B into TransformerLens") + parser.add_argument("--layers", type=int, default=24, + help="Number of layers to load (default: 24, use fewer to save memory)") + parser.add_argument("--prompt", type=str, default=None, + help="Custom prompt to test (default: built-in test prompts)") + args = parser.parse_args() + + print("=" * 60) + print("GPT-OSS-20B via TransformerLens (Direct SafeTensors)") + print("=" * 60) + + import psutil + ram = psutil.virtual_memory() + print(f"\nPyTorch: {torch.__version__}") + print(f"MPS: {torch.backends.mps.is_available()}") + print(f"RAM: {ram.total/1e9:.1f}GB total, {ram.available/1e9:.1f}GB available") + + n_layers = args.layers + if n_layers < 24: + print(f"\nLoading first {n_layers} of 24 layers (reduced memory mode)") + est_gb = 2.4 + n_layers * 1.64 + print(f"Estimated memory: ~{est_gb:.0f}GB") + else: + print(f"\nLoading all 24 layers (~42GB, will use swap on <40GB RAM machines)") + + model_path = get_model_path() + print(f"Model path: {model_path}") + + with open(model_path / "model.safetensors.index.json") as f: + index = json.load(f) + + # Create config + cfg = create_config(n_layers=n_layers) + + # Load tokenizer + print("\nLoading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(str(model_path)) + + # Create TransformerLens model (allocates parameter storage) + print("Creating TransformerLens model structure...") + model = HookedTransformer(cfg, tokenizer, move_to_device=False) + + # Load embeddings + print("\nLoading embeddings...") + embed_file = str(model_path / index["weight_map"]["model.embed_tokens.weight"]) + with safe_open(embed_file, framework="pt", device="cpu") as f: + embed_w = f.get_tensor("model.embed_tokens.weight") + model.load_state_dict({"embed.W_E": embed_w}, strict=False) + del embed_w + gc.collect() + + # Load layers one at a time + for l in range(n_layers): + print(f"\nLoading layer {l}/{n_layers-1}...") + layer_dict = load_layer_weights(l, cfg, index, model_path) + + # Load into model one key at a time to minimize peak memory + keys = list(layer_dict.keys()) + for key in keys: + model.load_state_dict({key: layer_dict[key]}, strict=False) + del layer_dict[key] + del layer_dict + gc.collect() + + ram = psutil.virtual_memory() + print(f" RAM: {ram.used/1e9:.1f}GB used, {ram.available/1e9:.1f}GB available") + + # Load final LayerNorm and unembed + print("\nLoading final layers...") + final_file = str(model_path / index["weight_map"]["model.norm.weight"]) + with safe_open(final_file, framework="pt", device="cpu") as f: + ln_w = f.get_tensor("model.norm.weight") + unembed_w = f.get_tensor("lm_head.weight").T + + model.load_state_dict({"ln_final.w": ln_w}, strict=False) + del ln_w + model.load_state_dict({"unembed.W_U": unembed_w}, strict=False) + del unembed_w + model.load_state_dict({"unembed.b_U": torch.zeros(cfg.d_vocab, dtype=cfg.dtype)}, strict=False) + gc.collect() + + print("\n" + "=" * 60) + print("Model loaded successfully!") + print(f"Architecture: {cfg.original_architecture}") + print(f"Layers: {cfg.n_layers}") + print(f"Experts: {cfg.num_experts}") + print(f"d_model: {cfg.d_model}") + + ram = psutil.virtual_memory() + print(f"RAM: {ram.used/1e9:.1f}GB used, {ram.available/1e9:.1f}GB available") + + # Test inference + if args.prompt: + prompts = [args.prompt] + else: + prompts = [ + "The capital of France is", + "2 + 2 =", + "The opposite of hot is", + ] + + for prompt in prompts: + print(f"\n{'='*60}") + print(f"Prompt: '{prompt}'") + tokens = model.to_tokens(prompt) + with torch.no_grad(): + logits = model(tokens) + pred = model.to_string(logits[0, -1].argmax()) + print(f"Prediction: '{pred}'") + + # Show top 5 predictions + probs = torch.softmax(logits[0, -1].float(), dim=-1) + top5 = probs.topk(5) + print("Top 5:") + for i in range(5): + token_str = model.to_string(top5.indices[i]) + print(f" {token_str!r}: {top5.values[i]:.4f}") + + print(f"\n{'='*60}") + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/transformer_lens/components/mlps/gpt_oss_moe.py b/transformer_lens/components/mlps/gpt_oss_moe.py new file mode 100644 index 000000000..401d7ac1f --- /dev/null +++ b/transformer_lens/components/mlps/gpt_oss_moe.py @@ -0,0 +1,122 @@ +"""GPT-OSS Mixture of Experts implementation for TransformerLens. + +GPT-OSS uses a unique MoE architecture: +- Merged expert weights (gate_up_proj with interleaved gate/up columns) +- Custom GLU activation: gate * sigmoid(gate * 1.702) * (up + 1), with clamping +- Router with bias, softmax applied AFTER top-k selection +- Expert projections have biases +""" + +from typing import Dict, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from jaxtyping import Float + +from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP +from transformer_lens.hook_points import HookPoint +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + +GPT_OSS_ALPHA = 1.702 +GPT_OSS_LIMIT = 7.0 + + +class GptOssExpert(nn.Module): + """Single GPT-OSS expert with custom GLU activation. + + The activation differs from standard SiLU: + gate = clamp(x @ W_gate + b_gate, max=7.0) + up = clamp(x @ W_in + b_in, min=-7.0, max=7.0) + glu = gate * sigmoid(gate * 1.702) + out = (up + 1) * glu + result = out @ W_out + b_out + """ + + def __init__(self, cfg: HookedTransformerConfig): + super().__init__() + self.cfg = cfg + assert cfg.d_mlp is not None + + self.W_gate = nn.Linear(cfg.d_model, cfg.d_mlp, bias=True, dtype=cfg.dtype) + self.W_in = nn.Linear(cfg.d_model, cfg.d_mlp, bias=True, dtype=cfg.dtype) + self.W_out = nn.Linear(cfg.d_mlp, cfg.d_model, bias=True, dtype=cfg.dtype) + + self.hook_gate = HookPoint() + self.hook_pre = HookPoint() + self.hook_post = HookPoint() + + def forward(self, x: Float[torch.Tensor, "pos d_model"]) -> Float[torch.Tensor, "pos d_model"]: + gate = self.hook_gate(self.W_gate(x)) + up = self.hook_pre(self.W_in(x)) + + # GPT-OSS custom activation + gate = gate.clamp(max=GPT_OSS_LIMIT) + up = up.clamp(min=-GPT_OSS_LIMIT, max=GPT_OSS_LIMIT) + glu = gate * torch.sigmoid(gate * GPT_OSS_ALPHA) + post = self.hook_post((up + 1) * glu) + + return self.W_out(post) + + +class GptOssMoE(CanBeUsedAsMLP): + """GPT-OSS Mixture of Experts layer. + + Differences from standard TransformerLens MoE (Mixtral): + - Router has bias + - Softmax applied AFTER top-k selection (not before) + - Experts use custom GLU activation (not SiLU) + - Expert projections have biases + """ + + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__(cfg) + + assert self.cfg.num_experts is not None + assert self.cfg.experts_per_token is not None + + self.num_experts: int = self.cfg.num_experts + self.experts_per_token: int = self.cfg.experts_per_token + + self.experts = nn.ModuleList([GptOssExpert(self.cfg) for _ in range(self.num_experts)]) + # GPT-OSS router has bias (unlike Mixtral) + self.W_gate = nn.Linear(self.cfg.d_model, self.cfg.num_experts, bias=True, dtype=self.cfg.dtype) + + self.hook_expert_weights = HookPoint() + self.hook_expert_indices = HookPoint() + + def forward( + self, x: Float[torch.Tensor, "batch pos d_model"] + ) -> Float[torch.Tensor, "batch pos d_model"]: + batch, pos, d_model = x.shape + x = x.view(-1, d_model) + + # GPT-OSS routing: softmax AFTER top-k (differs from Mixtral) + gate_logits = self.W_gate(x) + top_values, expert_indices = torch.topk(gate_logits, self.experts_per_token, dim=-1) + # Softmax over just the selected experts + top_weights = F.softmax(top_values, dim=-1, dtype=torch.float) + + # Build full routing weights tensor for hooks (num_tokens, num_experts) + routing_weights = torch.zeros_like(gate_logits, dtype=torch.float) + routing_weights.scatter_(1, expert_indices, top_weights) + + routing_weights = self.hook_expert_weights(routing_weights) + expert_indices = self.hook_expert_indices(expert_indices) + routing_weights = routing_weights.to(x.dtype) + + results = torch.zeros((batch * pos, d_model), dtype=x.dtype, device=x.device) + expert_mask = F.one_hot(expert_indices, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.numel() == 0: + continue + + current_state = x[top_x] + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, expert_idx, None] + results.index_add_(0, top_x, current_hidden_states.to(x.dtype)) + + return results.reshape(batch, pos, d_model) diff --git a/transformer_lens/factories/mlp_factory.py b/transformer_lens/factories/mlp_factory.py index fe4dbbab7..273c9c85c 100644 --- a/transformer_lens/factories/mlp_factory.py +++ b/transformer_lens/factories/mlp_factory.py @@ -6,6 +6,7 @@ from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP from transformer_lens.components.mlps.gated_mlp import GatedMLP from transformer_lens.components.mlps.gated_mlp_4bit import GatedMLP4Bit +from transformer_lens.components.mlps.gpt_oss_moe import GptOssMoE from transformer_lens.components.mlps.mlp import MLP from transformer_lens.components.mlps.moe import MoE from transformer_lens.HookedTransformerConfig import HookedTransformerConfig @@ -15,6 +16,8 @@ class MLPFactory: @staticmethod def create_mlp(cfg: HookedTransformerConfig) -> CanBeUsedAsMLP: if cfg.num_experts: + if cfg.original_architecture == "GptOssForCausalLM": + return GptOssMoE(cfg) return MoE(cfg) elif cfg.gated_mlp: return GatedMLP(cfg) if not cfg.load_in_4bit else GatedMLP4Bit(cfg) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 1a8ef9ddc..2603e8e06 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -28,6 +28,7 @@ convert_bloom_weights, convert_coder_weights, convert_gemma_weights, + convert_gpt_oss_weights, convert_gpt2_weights, convert_gptj_weights, convert_llama_weights, @@ -192,6 +193,7 @@ "mistralai/Mistral-Nemo-Base-2407", "mistralai/Mixtral-8x7B-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1", + "openai/gpt-oss-20b", "bigscience/bloom-560m", "bigscience/bloom-1b1", "bigscience/bloom-1b7", @@ -663,6 +665,7 @@ "mixtral-instruct", "mixtral-8x7b-instruct", ], + "openai/gpt-oss-20b": ["gpt-oss-20b", "gpt-oss"], "bigscience/bloom-560m": ["bloom-560m"], "bigscience/bloom-1b1": ["bloom-1b1"], "bigscience/bloom-1b7": ["bloom-1b7"], @@ -1283,6 +1286,28 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): "num_experts": hf_config.num_local_experts, "experts_per_token": hf_config.num_experts_per_tok, } + elif architecture == "GptOssForCausalLM": + cfg_dict = { + "dtype": torch.bfloat16, + "d_model": hf_config.hidden_size, + "d_head": hf_config.head_dim, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": hf_config.max_position_embeddings, + "d_vocab": hf_config.vocab_size, + "act_fn": hf_config.hidden_act, + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "rotary_base": hf_config.rope_theta, + "eps": hf_config.rms_norm_eps, + "n_key_value_heads": hf_config.num_key_value_heads, + "gated_mlp": True, + "use_local_attn": False, + "rotary_dim": hf_config.head_dim, + "num_experts": hf_config.num_local_experts, + "experts_per_token": hf_config.num_experts_per_tok, + } elif architecture == "BloomForCausalLM": cfg_dict = { "d_model": hf_config.hidden_size, @@ -2379,6 +2404,8 @@ def get_pretrained_state_dict( state_dict = convert_mistral_weights(hf_model, cfg) elif cfg.original_architecture == "MixtralForCausalLM": state_dict = convert_mixtral_weights(hf_model, cfg) + elif cfg.original_architecture == "GptOssForCausalLM": + state_dict = convert_gpt_oss_weights(hf_model, cfg) elif cfg.original_architecture == "BloomForCausalLM": state_dict = convert_bloom_weights(hf_model, cfg) elif cfg.original_architecture == "GPT2LMHeadCustomModel": diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py index c5ea9581b..b58cce705 100644 --- a/transformer_lens/pretrained/weight_conversions/__init__.py +++ b/transformer_lens/pretrained/weight_conversions/__init__.py @@ -19,3 +19,4 @@ from .nanogpt import convert_nanogpt_weights from .t5 import convert_t5_weights from .neel_solu_old import convert_neel_solu_old_weights +from .openai import convert_gpt_oss_weights diff --git a/transformer_lens/pretrained/weight_conversions/openai.py b/transformer_lens/pretrained/weight_conversions/openai.py new file mode 100644 index 000000000..90f4e6f4e --- /dev/null +++ b/transformer_lens/pretrained/weight_conversions/openai.py @@ -0,0 +1,93 @@ +"""Weight conversion for OpenAI GPT-OSS models. + +GPT-OSS has a unique MoE architecture: +- GptOssExperts stores all expert weights in merged tensors (not individual modules) +- gate_up_proj: (num_experts, hidden_size, 2*expert_dim) with interleaved gate/up columns +- down_proj: (num_experts, expert_dim, hidden_size) +- Router (GptOssTopKRouter) uses weight + bias +""" + +import einops +import torch + +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + + +def convert_gpt_oss_weights(gpt_oss, cfg: HookedTransformerConfig): + state_dict = {} + + assert cfg.n_key_value_heads is not None + assert cfg.d_mlp is not None + assert cfg.num_experts is not None + + state_dict["embed.W_E"] = gpt_oss.model.embed_tokens.weight + + for l in range(cfg.n_layers): + layer = gpt_oss.model.layers[l] + + # LayerNorms + state_dict[f"blocks.{l}.ln1.w"] = layer.input_layernorm.weight + state_dict[f"blocks.{l}.ln2.w"] = layer.post_attention_layernorm.weight + + # Attention + W_Q = einops.rearrange(layer.self_attn.q_proj.weight, "(n h) m -> n m h", n=cfg.n_heads) + W_K = einops.rearrange(layer.self_attn.k_proj.weight, "(n h) m -> n m h", n=cfg.n_key_value_heads) + W_V = einops.rearrange(layer.self_attn.v_proj.weight, "(n h) m -> n m h", n=cfg.n_key_value_heads) + state_dict[f"blocks.{l}.attn.W_Q"] = W_Q + state_dict[f"blocks.{l}.attn._W_K"] = W_K + state_dict[f"blocks.{l}.attn._W_V"] = W_V + + if layer.self_attn.q_proj.bias is not None: + state_dict[f"blocks.{l}.attn.b_Q"] = einops.rearrange( + layer.self_attn.q_proj.bias, "(n h) -> n h", n=cfg.n_heads + ) + state_dict[f"blocks.{l}.attn._b_K"] = einops.rearrange( + layer.self_attn.k_proj.bias, "(n h) -> n h", n=cfg.n_key_value_heads + ) + state_dict[f"blocks.{l}.attn._b_V"] = einops.rearrange( + layer.self_attn.v_proj.bias, "(n h) -> n h", n=cfg.n_key_value_heads + ) + else: + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype) + + W_O = einops.rearrange(layer.self_attn.o_proj.weight, "m (n h) -> n h m", n=cfg.n_heads) + state_dict[f"blocks.{l}.attn.W_O"] = W_O + + if hasattr(layer.self_attn.o_proj, "bias") and layer.self_attn.o_proj.bias is not None: + state_dict[f"blocks.{l}.attn.b_O"] = layer.self_attn.o_proj.bias + else: + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + # MoE - Router (GPT-OSS uses 'router' with bias) + state_dict[f"blocks.{l}.mlp.W_gate.weight"] = layer.mlp.router.weight + state_dict[f"blocks.{l}.mlp.W_gate.bias"] = layer.mlp.router.bias + + # MoE - Experts + # GPT-OSS stores all experts in merged tensors: + # gate_up_proj: (num_experts, hidden_size, 2*expert_dim) - interleaved gate/up + # down_proj: (num_experts, expert_dim, hidden_size) + experts = layer.mlp.experts + gate_up_proj = experts.gate_up_proj # (num_experts, hidden_size, 2*expert_dim) + gate_up_bias = experts.gate_up_proj_bias # (num_experts, 2*expert_dim) + down_proj = experts.down_proj # (num_experts, expert_dim, hidden_size) + down_bias = experts.down_proj_bias # (num_experts, hidden_size) + + for e in range(cfg.num_experts): + # Split interleaved gate_up_proj into separate gate and up (in) projections + # Even columns → gate path, Odd columns → up/in path + state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = gate_up_proj[e, :, ::2].T.contiguous() + state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.bias"] = gate_up_bias[e, ::2].contiguous() + + state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = gate_up_proj[e, :, 1::2].T.contiguous() + state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.bias"] = gate_up_bias[e, 1::2].contiguous() + + state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.weight"] = down_proj[e].T.contiguous() + state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.bias"] = down_bias[e].contiguous() + + state_dict["ln_final.w"] = gpt_oss.model.norm.weight + state_dict["unembed.W_U"] = gpt_oss.lm_head.weight.T + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) + + return state_dict From 9d72d1b94e1e4fa21cfe93af17d71cf9c62b6c95 Mon Sep 17 00:00:00 2001 From: Carl Gross Date: Fri, 6 Mar 2026 17:09:12 -0800 Subject: [PATCH 2/6] Add openai/gpt-oss-20b to Colab_Compatibility notebook Register the new model in the incompatible_models list (too large for Colab) and update the model count from 231 to 232. Co-Authored-By: Claude Opus 4.6 --- demos/Colab_Compatibility.ipynb | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/demos/Colab_Compatibility.ipynb b/demos/Colab_Compatibility.ipynb index 71425d1bf..068dd4d24 100644 --- a/demos/Colab_Compatibility.ipynb +++ b/demos/Colab_Compatibility.ipynb @@ -651,7 +651,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "TransformerLens currently supports 231 models out of the box.\n" + "TransformerLens currently supports 232 models out of the box.\n" ] } ], @@ -1086,6 +1086,7 @@ " \"meta-llama/Meta-Llama-3-70B-Instruct\",\n", " \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n", " \"mistralai/Mixtral-8x7B-v0.1\",\n", + " \"openai/gpt-oss-20b\",\n", " \"Qwen/Qwen2.5-32B\",\n", " \"Qwen/Qwen2.5-32B-Instruct\",\n", " \"Qwen/Qwen2.5-72B\",\n", @@ -1097,7 +1098,7 @@ "mark_models_as_tested(incompatible_models)" ], "outputs": [], - "execution_count": 32 + "execution_count": null }, { "cell_type": "code", From c1e85360f4e2c1e4eb1af2a21ac7f74e776a5964 Mon Sep 17 00:00:00 2001 From: Carl Gross Date: Fri, 6 Mar 2026 19:45:22 -0800 Subject: [PATCH 3/6] Fix import sorting in loading_from_pretrained.py Co-Authored-By: Claude Opus 4.6 --- transformer_lens/loading_from_pretrained.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 2603e8e06..c4870e40d 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -28,8 +28,8 @@ convert_bloom_weights, convert_coder_weights, convert_gemma_weights, - convert_gpt_oss_weights, convert_gpt2_weights, + convert_gpt_oss_weights, convert_gptj_weights, convert_llama_weights, convert_mingpt_weights, From ac6be3e878715895acd16115cd2c3b4d29f90ff2 Mon Sep 17 00:00:00 2001 From: Carl Gross Date: Fri, 6 Mar 2026 19:52:28 -0800 Subject: [PATCH 4/6] Apply black formatting to GPT-OSS files Co-Authored-By: Claude Opus 4.6 --- run_gpt_oss.py | 43 ++++++++++++++----- .../components/mlps/gpt_oss_moe.py | 8 +++- .../pretrained/weight_conversions/openai.py | 42 ++++++++++++------ 3 files changed, 68 insertions(+), 25 deletions(-) diff --git a/run_gpt_oss.py b/run_gpt_oss.py index 1590e5111..2c43473bd 100644 --- a/run_gpt_oss.py +++ b/run_gpt_oss.py @@ -40,6 +40,7 @@ def get_model_path(): # Not cached — download print("Model not found in cache. Downloading...") from huggingface_hub import snapshot_download + return Path(snapshot_download("openai/gpt-oss-20b")) @@ -100,8 +101,12 @@ def gt(name): o_w = gt(f"{prefix}.self_attn.o_proj.weight") state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange(q_w, "(n h) m -> n m h", n=cfg.n_heads) - state_dict[f"blocks.{l}.attn._W_K"] = einops.rearrange(k_w, "(n h) m -> n m h", n=cfg.n_key_value_heads) - state_dict[f"blocks.{l}.attn._W_V"] = einops.rearrange(v_w, "(n h) m -> n m h", n=cfg.n_key_value_heads) + state_dict[f"blocks.{l}.attn._W_K"] = einops.rearrange( + k_w, "(n h) m -> n m h", n=cfg.n_key_value_heads + ) + state_dict[f"blocks.{l}.attn._W_V"] = einops.rearrange( + v_w, "(n h) m -> n m h", n=cfg.n_key_value_heads + ) state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange(o_w, "m (n h) -> n h m", n=cfg.n_heads) del q_w, k_w, v_w, o_w @@ -119,8 +124,12 @@ def gt(name): ) else: state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) - state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype) - state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( + cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype + ) + state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros( + cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype + ) o_bias_key = f"{prefix}.self_attn.o_proj.bias" if o_bias_key in wmap: @@ -157,9 +166,13 @@ def gt(name): # gate_up_proj shape: [num_experts, hidden_size, 2*expert_dim] # Even columns -> gate, Odd columns -> up for e in range(cfg.num_experts): - state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = gate_up_proj[e, :, ::2].T.contiguous() + state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = gate_up_proj[ + e, :, ::2 + ].T.contiguous() state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.bias"] = gate_up_bias[e, ::2].contiguous() - state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = gate_up_proj[e, :, 1::2].T.contiguous() + state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = gate_up_proj[ + e, :, 1::2 + ].T.contiguous() state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.bias"] = gate_up_bias[e, 1::2].contiguous() state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.weight"] = down_proj[e].T.contiguous() state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.bias"] = down_bias[e].contiguous() @@ -170,11 +183,20 @@ def gt(name): def main(): import argparse + parser = argparse.ArgumentParser(description="Load GPT-OSS-20B into TransformerLens") - parser.add_argument("--layers", type=int, default=24, - help="Number of layers to load (default: 24, use fewer to save memory)") - parser.add_argument("--prompt", type=str, default=None, - help="Custom prompt to test (default: built-in test prompts)") + parser.add_argument( + "--layers", + type=int, + default=24, + help="Number of layers to load (default: 24, use fewer to save memory)", + ) + parser.add_argument( + "--prompt", + type=str, + default=None, + help="Custom prompt to test (default: built-in test prompts)", + ) args = parser.parse_args() print("=" * 60) @@ -182,6 +204,7 @@ def main(): print("=" * 60) import psutil + ram = psutil.virtual_memory() print(f"\nPyTorch: {torch.__version__}") print(f"MPS: {torch.backends.mps.is_available()}") diff --git a/transformer_lens/components/mlps/gpt_oss_moe.py b/transformer_lens/components/mlps/gpt_oss_moe.py index 401d7ac1f..bfc575280 100644 --- a/transformer_lens/components/mlps/gpt_oss_moe.py +++ b/transformer_lens/components/mlps/gpt_oss_moe.py @@ -80,7 +80,9 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): self.experts = nn.ModuleList([GptOssExpert(self.cfg) for _ in range(self.num_experts)]) # GPT-OSS router has bias (unlike Mixtral) - self.W_gate = nn.Linear(self.cfg.d_model, self.cfg.num_experts, bias=True, dtype=self.cfg.dtype) + self.W_gate = nn.Linear( + self.cfg.d_model, self.cfg.num_experts, bias=True, dtype=self.cfg.dtype + ) self.hook_expert_weights = HookPoint() self.hook_expert_indices = HookPoint() @@ -116,7 +118,9 @@ def forward( continue current_state = x[top_x] - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, expert_idx, None] + current_hidden_states = ( + expert_layer(current_state) * routing_weights[top_x, expert_idx, None] + ) results.index_add_(0, top_x, current_hidden_states.to(x.dtype)) return results.reshape(batch, pos, d_model) diff --git a/transformer_lens/pretrained/weight_conversions/openai.py b/transformer_lens/pretrained/weight_conversions/openai.py index 90f4e6f4e..17b9b40dd 100644 --- a/transformer_lens/pretrained/weight_conversions/openai.py +++ b/transformer_lens/pretrained/weight_conversions/openai.py @@ -31,8 +31,12 @@ def convert_gpt_oss_weights(gpt_oss, cfg: HookedTransformerConfig): # Attention W_Q = einops.rearrange(layer.self_attn.q_proj.weight, "(n h) m -> n m h", n=cfg.n_heads) - W_K = einops.rearrange(layer.self_attn.k_proj.weight, "(n h) m -> n m h", n=cfg.n_key_value_heads) - W_V = einops.rearrange(layer.self_attn.v_proj.weight, "(n h) m -> n m h", n=cfg.n_key_value_heads) + W_K = einops.rearrange( + layer.self_attn.k_proj.weight, "(n h) m -> n m h", n=cfg.n_key_value_heads + ) + W_V = einops.rearrange( + layer.self_attn.v_proj.weight, "(n h) m -> n m h", n=cfg.n_key_value_heads + ) state_dict[f"blocks.{l}.attn.W_Q"] = W_Q state_dict[f"blocks.{l}.attn._W_K"] = W_K state_dict[f"blocks.{l}.attn._W_V"] = W_V @@ -48,9 +52,15 @@ def convert_gpt_oss_weights(gpt_oss, cfg: HookedTransformerConfig): layer.self_attn.v_proj.bias, "(n h) -> n h", n=cfg.n_key_value_heads ) else: - state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) - state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype) - state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( + cfg.n_heads, cfg.d_head, dtype=cfg.dtype + ) + state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( + cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype + ) + state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros( + cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype + ) W_O = einops.rearrange(layer.self_attn.o_proj.weight, "m (n h) -> n h m", n=cfg.n_heads) state_dict[f"blocks.{l}.attn.W_O"] = W_O @@ -69,18 +79,24 @@ def convert_gpt_oss_weights(gpt_oss, cfg: HookedTransformerConfig): # gate_up_proj: (num_experts, hidden_size, 2*expert_dim) - interleaved gate/up # down_proj: (num_experts, expert_dim, hidden_size) experts = layer.mlp.experts - gate_up_proj = experts.gate_up_proj # (num_experts, hidden_size, 2*expert_dim) - gate_up_bias = experts.gate_up_proj_bias # (num_experts, 2*expert_dim) - down_proj = experts.down_proj # (num_experts, expert_dim, hidden_size) - down_bias = experts.down_proj_bias # (num_experts, hidden_size) + gate_up_proj = experts.gate_up_proj # (num_experts, hidden_size, 2*expert_dim) + gate_up_bias = experts.gate_up_proj_bias # (num_experts, 2*expert_dim) + down_proj = experts.down_proj # (num_experts, expert_dim, hidden_size) + down_bias = experts.down_proj_bias # (num_experts, hidden_size) for e in range(cfg.num_experts): # Split interleaved gate_up_proj into separate gate and up (in) projections # Even columns → gate path, Odd columns → up/in path - state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = gate_up_proj[e, :, ::2].T.contiguous() - state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.bias"] = gate_up_bias[e, ::2].contiguous() - - state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = gate_up_proj[e, :, 1::2].T.contiguous() + state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = gate_up_proj[ + e, :, ::2 + ].T.contiguous() + state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.bias"] = gate_up_bias[ + e, ::2 + ].contiguous() + + state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = gate_up_proj[ + e, :, 1::2 + ].T.contiguous() state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.bias"] = gate_up_bias[e, 1::2].contiguous() state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.weight"] = down_proj[e].T.contiguous() From 62ba2240ae90044244fde43c56dbf4289df7c52e Mon Sep 17 00:00:00 2001 From: Carl Gross Date: Sat, 7 Mar 2026 07:39:56 -0800 Subject: [PATCH 5/6] Restore execution_count in Colab_Compatibility notebook Co-Authored-By: Claude Opus 4.6 --- demos/Colab_Compatibility.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demos/Colab_Compatibility.ipynb b/demos/Colab_Compatibility.ipynb index 068dd4d24..46b42c2ba 100644 --- a/demos/Colab_Compatibility.ipynb +++ b/demos/Colab_Compatibility.ipynb @@ -1098,7 +1098,7 @@ "mark_models_as_tested(incompatible_models)" ], "outputs": [], - "execution_count": null + "execution_count": 32 }, { "cell_type": "code", From b7b38837d0b38cbb360719b02c5259ab971fbb75 Mon Sep 17 00:00:00 2001 From: Carl Gross Date: Sat, 7 Mar 2026 08:42:16 -0800 Subject: [PATCH 6/6] Replace run_gpt_oss.py with demos/GPT_OSS_Demo.ipynb Convert the standalone script to a Jupyter notebook to match the existing demo format. Adds sections for inference, caching, expert routing analysis, logit lens, attention patterns, and activation patching. Co-Authored-By: Claude Opus 4.6 --- demos/GPT_OSS_Demo.ipynb | 512 +++++++++++++++++++++++++++++++++++++++ run_gpt_oss.py | 319 ------------------------ 2 files changed, 512 insertions(+), 319 deletions(-) create mode 100644 demos/GPT_OSS_Demo.ipynb delete mode 100644 run_gpt_oss.py diff --git a/demos/GPT_OSS_Demo.ipynb b/demos/GPT_OSS_Demo.ipynb new file mode 100644 index 000000000..74de6c04c --- /dev/null +++ b/demos/GPT_OSS_Demo.ipynb @@ -0,0 +1,512 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GPT-OSS-20B Demo\n", + "\n", + "This notebook loads OpenAI's [GPT-OSS-20B](https://huggingface.co/openai/gpt-oss-20b) into TransformerLens for mechanistic interpretability.\n", + "\n", + "GPT-OSS-20B is a Mixture of Experts (MoE) model with:\n", + "- 24 layers, d_model=2880, 32 experts, 4 experts per token\n", + "- MXFP4 quantized weights on HuggingFace (dequantized to BF16 during loading)\n", + "- Custom GLU activation and post-top-k softmax routing\n", + "\n", + "**Memory:** The full model is ~40GB in BF16. This notebook loads directly from safetensors, bypassing the HuggingFace model pipeline to keep peak memory manageable. Use `N_LAYERS` to load fewer layers if needed.\n", + "\n", + "**Requirements:** `transformers`, `safetensors`, `einops`, `psutil`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import gc\n", + "import json\n", + "from pathlib import Path\n", + "\n", + "import einops\n", + "import torch\n", + "from safetensors import safe_open\n", + "from transformers import AutoTokenizer\n", + "from transformers.integrations.mxfp4 import convert_moe_packed_tensors\n", + "\n", + "from transformer_lens import HookedTransformer\n", + "from transformer_lens.HookedTransformerConfig import HookedTransformerConfig" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configuration\n", + "\n", + "Set `N_LAYERS` to control how many layers to load. Each layer is ~1.6GB. Use fewer layers to save memory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "N_LAYERS = 24 # Full model. Set to 3-6 for quick testing." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Loading Utilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_model_path():\n", + " \"\"\"Get the cached model path, downloading if necessary.\"\"\"\n", + " cache_path = Path.home() / \".cache/huggingface/hub/models--openai--gpt-oss-20b\"\n", + " snapshots = cache_path / \"snapshots\"\n", + "\n", + " if snapshots.exists():\n", + " snapshot_dirs = list(snapshots.iterdir())\n", + " if snapshot_dirs:\n", + " return snapshot_dirs[0]\n", + "\n", + " print(\"Model not found in cache. Downloading...\")\n", + " from huggingface_hub import snapshot_download\n", + "\n", + " return Path(snapshot_download(\"openai/gpt-oss-20b\"))\n", + "\n", + "\n", + "def create_config(n_layers=24):\n", + " \"\"\"Create TransformerLens config for GPT-OSS-20B.\"\"\"\n", + " return HookedTransformerConfig(\n", + " n_layers=n_layers,\n", + " d_model=2880,\n", + " d_head=64,\n", + " n_heads=64,\n", + " d_mlp=2880,\n", + " n_ctx=4096,\n", + " d_vocab=201088,\n", + " act_fn=\"silu\",\n", + " normalization_type=\"RMS\",\n", + " positional_embedding_type=\"rotary\",\n", + " rotary_base=150000,\n", + " eps=1e-5,\n", + " n_key_value_heads=8,\n", + " gated_mlp=True,\n", + " use_local_attn=False,\n", + " rotary_dim=64,\n", + " num_experts=32,\n", + " experts_per_token=4,\n", + " dtype=torch.bfloat16,\n", + " device=\"cpu\",\n", + " original_architecture=\"GptOssForCausalLM\",\n", + " model_name=\"openai/gpt-oss-20b\",\n", + " )\n", + "\n", + "\n", + "_open_files = {}\n", + "\n", + "\n", + "def _get_tensor(hf_name, wmap, model_path):\n", + " \"\"\"Load a single tensor from the correct safetensors shard.\"\"\"\n", + " st_file = wmap[hf_name]\n", + " filepath = str(model_path / st_file)\n", + " if filepath not in _open_files:\n", + " _open_files[filepath] = safe_open(filepath, framework=\"pt\", device=\"cpu\")\n", + " return _open_files[filepath].get_tensor(hf_name)\n", + "\n", + "\n", + "def load_layer_weights(l, cfg, index, model_path):\n", + " \"\"\"Load and convert weights for one transformer layer from safetensors.\"\"\"\n", + " state_dict = {}\n", + " wmap = index[\"weight_map\"]\n", + " prefix = f\"model.layers.{l}\"\n", + "\n", + " def gt(name):\n", + " return _get_tensor(name, wmap, model_path)\n", + "\n", + " # LayerNorms\n", + " state_dict[f\"blocks.{l}.ln1.w\"] = gt(f\"{prefix}.input_layernorm.weight\")\n", + " state_dict[f\"blocks.{l}.ln2.w\"] = gt(f\"{prefix}.post_attention_layernorm.weight\")\n", + "\n", + " # Attention weights\n", + " q_w = gt(f\"{prefix}.self_attn.q_proj.weight\")\n", + " k_w = gt(f\"{prefix}.self_attn.k_proj.weight\")\n", + " v_w = gt(f\"{prefix}.self_attn.v_proj.weight\")\n", + " o_w = gt(f\"{prefix}.self_attn.o_proj.weight\")\n", + "\n", + " state_dict[f\"blocks.{l}.attn.W_Q\"] = einops.rearrange(q_w, \"(n h) m -> n m h\", n=cfg.n_heads)\n", + " state_dict[f\"blocks.{l}.attn._W_K\"] = einops.rearrange(\n", + " k_w, \"(n h) m -> n m h\", n=cfg.n_key_value_heads\n", + " )\n", + " state_dict[f\"blocks.{l}.attn._W_V\"] = einops.rearrange(\n", + " v_w, \"(n h) m -> n m h\", n=cfg.n_key_value_heads\n", + " )\n", + " state_dict[f\"blocks.{l}.attn.W_O\"] = einops.rearrange(\n", + " o_w, \"m (n h) -> n h m\", n=cfg.n_heads\n", + " )\n", + " del q_w, k_w, v_w, o_w\n", + "\n", + " # Attention biases\n", + " q_bias_key = f\"{prefix}.self_attn.q_proj.bias\"\n", + " if q_bias_key in wmap:\n", + " state_dict[f\"blocks.{l}.attn.b_Q\"] = einops.rearrange(\n", + " gt(q_bias_key), \"(n h) -> n h\", n=cfg.n_heads\n", + " )\n", + " state_dict[f\"blocks.{l}.attn._b_K\"] = einops.rearrange(\n", + " gt(f\"{prefix}.self_attn.k_proj.bias\"), \"(n h) -> n h\", n=cfg.n_key_value_heads\n", + " )\n", + " state_dict[f\"blocks.{l}.attn._b_V\"] = einops.rearrange(\n", + " gt(f\"{prefix}.self_attn.v_proj.bias\"), \"(n h) -> n h\", n=cfg.n_key_value_heads\n", + " )\n", + " else:\n", + " state_dict[f\"blocks.{l}.attn.b_Q\"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)\n", + " state_dict[f\"blocks.{l}.attn._b_K\"] = torch.zeros(\n", + " cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype\n", + " )\n", + " state_dict[f\"blocks.{l}.attn._b_V\"] = torch.zeros(\n", + " cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype\n", + " )\n", + "\n", + " o_bias_key = f\"{prefix}.self_attn.o_proj.bias\"\n", + " if o_bias_key in wmap:\n", + " state_dict[f\"blocks.{l}.attn.b_O\"] = gt(o_bias_key)\n", + " else:\n", + " state_dict[f\"blocks.{l}.attn.b_O\"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)\n", + "\n", + " # Router\n", + " state_dict[f\"blocks.{l}.mlp.W_gate.weight\"] = gt(f\"{prefix}.mlp.router.weight\")\n", + " state_dict[f\"blocks.{l}.mlp.W_gate.bias\"] = gt(f\"{prefix}.mlp.router.bias\")\n", + "\n", + " # Expert weights - dequantize MXFP4 to BF16\n", + " gate_up_blocks = gt(f\"{prefix}.mlp.experts.gate_up_proj_blocks\")\n", + " gate_up_scales = gt(f\"{prefix}.mlp.experts.gate_up_proj_scales\")\n", + " gate_up_bias = gt(f\"{prefix}.mlp.experts.gate_up_proj_bias\")\n", + "\n", + " print(f\" Dequantizing layer {l} gate_up_proj...\", end=\"\", flush=True)\n", + " gate_up_proj = convert_moe_packed_tensors(gate_up_blocks, gate_up_scales)\n", + " del gate_up_blocks, gate_up_scales\n", + " print(\" done\")\n", + "\n", + " down_blocks = gt(f\"{prefix}.mlp.experts.down_proj_blocks\")\n", + " down_scales = gt(f\"{prefix}.mlp.experts.down_proj_scales\")\n", + " down_bias = gt(f\"{prefix}.mlp.experts.down_proj_bias\")\n", + "\n", + " print(f\" Dequantizing layer {l} down_proj...\", end=\"\", flush=True)\n", + " down_proj = convert_moe_packed_tensors(down_blocks, down_scales)\n", + " del down_blocks, down_scales\n", + " print(\" done\")\n", + "\n", + " # Split merged expert tensors into per-expert weights\n", + " # Even columns -> gate, Odd columns -> up\n", + " for e in range(cfg.num_experts):\n", + " state_dict[f\"blocks.{l}.mlp.experts.{e}.W_gate.weight\"] = gate_up_proj[\n", + " e, :, ::2\n", + " ].T.contiguous()\n", + " state_dict[f\"blocks.{l}.mlp.experts.{e}.W_gate.bias\"] = gate_up_bias[\n", + " e, ::2\n", + " ].contiguous()\n", + " state_dict[f\"blocks.{l}.mlp.experts.{e}.W_in.weight\"] = gate_up_proj[\n", + " e, :, 1::2\n", + " ].T.contiguous()\n", + " state_dict[f\"blocks.{l}.mlp.experts.{e}.W_in.bias\"] = gate_up_bias[\n", + " e, 1::2\n", + " ].contiguous()\n", + " state_dict[f\"blocks.{l}.mlp.experts.{e}.W_out.weight\"] = down_proj[e].T.contiguous()\n", + " state_dict[f\"blocks.{l}.mlp.experts.{e}.W_out.bias\"] = down_bias[e].contiguous()\n", + "\n", + " del gate_up_proj, gate_up_bias, down_proj, down_bias\n", + " return state_dict" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_path = get_model_path()\n", + "print(f\"Model path: {model_path}\")\n", + "\n", + "with open(model_path / \"model.safetensors.index.json\") as f:\n", + " index = json.load(f)\n", + "\n", + "cfg = create_config(n_layers=N_LAYERS)\n", + "tokenizer = AutoTokenizer.from_pretrained(str(model_path))\n", + "model = HookedTransformer(cfg, tokenizer, move_to_device=False)\n", + "\n", + "# Load embeddings\n", + "wmap = index[\"weight_map\"]\n", + "model.load_state_dict(\n", + " {\"embed.W_E\": _get_tensor(\"model.embed_tokens.weight\", wmap, model_path)},\n", + " strict=False,\n", + ")\n", + "gc.collect()\n", + "\n", + "# Load layers one at a time\n", + "for l in range(N_LAYERS):\n", + " print(f\"Loading layer {l}/{N_LAYERS-1}...\")\n", + " layer_dict = load_layer_weights(l, cfg, index, model_path)\n", + " for key in list(layer_dict.keys()):\n", + " model.load_state_dict({key: layer_dict[key]}, strict=False)\n", + " del layer_dict[key]\n", + " del layer_dict\n", + " gc.collect()\n", + "\n", + "# Load final LayerNorm and unembed\n", + "model.load_state_dict(\n", + " {\"ln_final.w\": _get_tensor(\"model.norm.weight\", wmap, model_path)},\n", + " strict=False,\n", + ")\n", + "model.load_state_dict(\n", + " {\"unembed.W_U\": _get_tensor(\"lm_head.weight\", wmap, model_path).T},\n", + " strict=False,\n", + ")\n", + "model.load_state_dict(\n", + " {\"unembed.b_U\": torch.zeros(cfg.d_vocab, dtype=cfg.dtype)},\n", + " strict=False,\n", + ")\n", + "gc.collect()\n", + "\n", + "print(f\"\\nModel loaded! Layers: {cfg.n_layers}, Experts: {cfg.num_experts}, d_model: {cfg.d_model}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic Inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prompts = [\n", + " \"The capital of France is\",\n", + " \"2 + 2 =\",\n", + " \"The opposite of hot is\",\n", + "]\n", + "\n", + "for prompt in prompts:\n", + " tokens = model.to_tokens(prompt)\n", + " with torch.no_grad():\n", + " logits = model(tokens)\n", + " probs = torch.softmax(logits[0, -1].float(), dim=-1)\n", + " top5 = probs.topk(5)\n", + " print(f\"Prompt: '{prompt}'\")\n", + " for i in range(5):\n", + " token_str = model.to_string(top5.indices[i])\n", + " print(f\" {token_str!r}: {top5.values[i]:.4f}\")\n", + " print()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Caching and Residual Stream" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tokens = model.to_tokens(\"The cat sat on the mat\")\n", + "with torch.no_grad():\n", + " logits, cache = model.run_with_cache(tokens)\n", + "\n", + "print(\"Cached activation keys (first 10):\")\n", + "for key in list(cache.keys())[:10]:\n", + " print(f\" {key}: {cache[key].shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Expert Routing Analysis\n", + "\n", + "GPT-OSS routes each token to 4 of 32 experts. We can inspect which experts are selected and their weights." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "str_tokens = model.to_str_tokens(\"The cat sat on the mat\")\n", + "\n", + "for layer in range(min(N_LAYERS, 3)):\n", + " expert_weights = cache[f\"blocks.{layer}.mlp.hook_expert_weights\"]\n", + " expert_indices = cache[f\"blocks.{layer}.mlp.hook_expert_indices\"]\n", + "\n", + " print(f\"\\nLayer {layer} expert routing:\")\n", + " for t in range(len(str_tokens)):\n", + " indices = expert_indices[t].tolist()\n", + " weights = [expert_weights[t, idx].item() for idx in indices]\n", + " pairs = \", \".join(f\"E{idx}({w:.2f})\" for idx, w in zip(indices, weights))\n", + " print(f\" {str_tokens[t]:>10s} -> {pairs}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Logit Lens\n", + "\n", + "Apply the unembedding matrix to intermediate residual streams to see how the model's prediction evolves across layers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prompt = \"The capital of France is\"\n", + "tokens = model.to_tokens(prompt)\n", + "with torch.no_grad():\n", + " logits, cache = model.run_with_cache(tokens)\n", + "\n", + "print(f\"Prompt: '{prompt}'\")\n", + "print(f\"{'Layer':>6s} {'Top prediction':>20s} {'Prob':>6s}\")\n", + "print(\"-\" * 38)\n", + "\n", + "for layer in range(N_LAYERS):\n", + " resid = cache[f\"blocks.{layer}.hook_resid_post\"]\n", + " normed = model.ln_final(resid)\n", + " layer_logits = normed @ model.unembed.W_U + model.unembed.b_U\n", + " probs = torch.softmax(layer_logits[0, -1].float(), dim=-1)\n", + " top_idx = probs.argmax().item()\n", + " top_word = model.to_string(torch.tensor(top_idx))\n", + " print(f\"{layer:>6d} {top_word:>20s} {probs[top_idx]:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Attention Patterns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prompt = \"The cat sat on the mat\"\n", + "tokens = model.to_tokens(prompt)\n", + "str_tokens = model.to_str_tokens(prompt)\n", + "\n", + "with torch.no_grad():\n", + " logits, cache = model.run_with_cache(tokens)\n", + "\n", + "# Show attention pattern for layer 0, head 0\n", + "pattern = cache[\"blocks.0.attn.hook_pattern\"][0, 0] # [seq, seq]\n", + "print(\"Layer 0, Head 0 attention pattern:\")\n", + "print(f\"{'':>12s}\", \" \".join(f\"{t:>6s}\" for t in str_tokens))\n", + "for i, src_token in enumerate(str_tokens):\n", + " row = \" \".join(f\"{pattern[i, j]:.3f}\" for j in range(len(str_tokens)))\n", + " print(f\"{src_token:>12s} {row}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Activation Patching\n", + "\n", + "Patch the residual stream from a clean run into a corrupted run to measure causal effects." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "clean_prompt = \"The capital of France is\"\n", + "corrupt_prompt = \"The capital of Germany is\"\n", + "\n", + "clean_tokens = model.to_tokens(clean_prompt)\n", + "corrupt_tokens = model.to_tokens(corrupt_prompt)\n", + "\n", + "# Get clean activations\n", + "captured_clean = {}\n", + "\n", + "def save_clean(tensor, hook):\n", + " captured_clean[hook.name] = tensor.detach().clone()\n", + "\n", + "with torch.no_grad():\n", + " clean_logits = model.run_with_hooks(\n", + " clean_tokens,\n", + " fwd_hooks=[(f\"blocks.{l}.hook_resid_post\", save_clean) for l in range(N_LAYERS)],\n", + " )\n", + "\n", + "with torch.no_grad():\n", + " corrupt_logits = model(corrupt_tokens)\n", + "\n", + "clean_pred = model.to_string(clean_logits[0, -1].argmax())\n", + "corrupt_pred = model.to_string(corrupt_logits[0, -1].argmax())\n", + "print(f\"Clean prediction: '{clean_pred}'\")\n", + "print(f\"Corrupt prediction: '{corrupt_pred}'\")\n", + "\n", + "# Patch each layer's residual stream and measure effect\n", + "print(f\"\\n{'Layer':>6s} {'Patched prediction':>20s} {'Logit diff':>10s}\")\n", + "print(\"-\" * 42)\n", + "\n", + "for layer in range(N_LAYERS):\n", + " hook_name = f\"blocks.{layer}.hook_resid_post\"\n", + "\n", + " def patch_hook(tensor, hook, clean_act=captured_clean[hook_name]):\n", + " return clean_act\n", + "\n", + " with torch.no_grad():\n", + " patched_logits = model.run_with_hooks(\n", + " corrupt_tokens,\n", + " fwd_hooks=[(hook_name, patch_hook)],\n", + " )\n", + "\n", + " patched_pred = model.to_string(patched_logits[0, -1].argmax())\n", + " diff = (patched_logits[0, -1] - corrupt_logits[0, -1]).abs().sum().item()\n", + " print(f\"{layer:>6d} {patched_pred:>20s} {diff:>10.1f}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/run_gpt_oss.py b/run_gpt_oss.py deleted file mode 100644 index 2c43473bd..000000000 --- a/run_gpt_oss.py +++ /dev/null @@ -1,319 +0,0 @@ -"""Load GPT-OSS-20B directly from safetensors into TransformerLens. - -Bypasses the HuggingFace model loading pipeline to avoid doubling memory usage. -The model is ~40GB in BF16 — loading via HF would require ~80GB peak (HF model + state dict). - -Instead, we: -1. Create the TransformerLens model structure (~40GB, filled with empty tensors) -2. Load weights from safetensors one layer at a time -3. Dequantize MXFP4 expert weights on the fly using HF's convert_moe_packed_tensors -4. Copy directly into TL model parameters, freeing temp data immediately - -Peak memory: ~42GB (model + one layer's temp data). Works on a 38.7GB Mac via swap. -""" - -import gc -import json -from pathlib import Path - -import einops -import torch -from safetensors import safe_open -from transformers import AutoTokenizer -from transformers.integrations.mxfp4 import convert_moe_packed_tensors - -from transformer_lens import HookedTransformer -from transformer_lens.HookedTransformerConfig import HookedTransformerConfig - - -def get_model_path(): - """Get the cached model path, downloading if necessary.""" - cache_path = Path.home() / ".cache/huggingface/hub/models--openai--gpt-oss-20b" - snapshots = cache_path / "snapshots" - - if snapshots.exists(): - # Use the first (usually only) snapshot - snapshot_dirs = list(snapshots.iterdir()) - if snapshot_dirs: - return snapshot_dirs[0] - - # Not cached — download - print("Model not found in cache. Downloading...") - from huggingface_hub import snapshot_download - - return Path(snapshot_download("openai/gpt-oss-20b")) - - -def create_config(n_layers=24): - """Create TransformerLens config for GPT-OSS-20B.""" - return HookedTransformerConfig( - n_layers=n_layers, - d_model=2880, - d_head=64, - n_heads=64, - d_mlp=2880, - n_ctx=4096, # Reduced from 131072 to save memory - d_vocab=201088, - act_fn="silu", - normalization_type="RMS", - positional_embedding_type="rotary", - rotary_base=150000, - eps=1e-5, - n_key_value_heads=8, - gated_mlp=True, - use_local_attn=False, - rotary_dim=64, - num_experts=32, - experts_per_token=4, - dtype=torch.bfloat16, - device="cpu", - original_architecture="GptOssForCausalLM", - model_name="openai/gpt-oss-20b", - ) - - -def _get_tensor(hf_name, wmap, model_path, _open_files={}): - """Load a single tensor from the correct safetensors shard.""" - st_file = wmap[hf_name] - filepath = str(model_path / st_file) - if filepath not in _open_files: - _open_files[filepath] = safe_open(filepath, framework="pt", device="cpu") - return _open_files[filepath].get_tensor(hf_name) - - -def load_layer_weights(l, cfg, index, model_path): - """Load and convert weights for one transformer layer from safetensors.""" - state_dict = {} - wmap = index["weight_map"] - prefix = f"model.layers.{l}" - - def gt(name): - return _get_tensor(name, wmap, model_path) - - # LayerNorms - state_dict[f"blocks.{l}.ln1.w"] = gt(f"{prefix}.input_layernorm.weight") - state_dict[f"blocks.{l}.ln2.w"] = gt(f"{prefix}.post_attention_layernorm.weight") - - # Attention weights - q_w = gt(f"{prefix}.self_attn.q_proj.weight") - k_w = gt(f"{prefix}.self_attn.k_proj.weight") - v_w = gt(f"{prefix}.self_attn.v_proj.weight") - o_w = gt(f"{prefix}.self_attn.o_proj.weight") - - state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange(q_w, "(n h) m -> n m h", n=cfg.n_heads) - state_dict[f"blocks.{l}.attn._W_K"] = einops.rearrange( - k_w, "(n h) m -> n m h", n=cfg.n_key_value_heads - ) - state_dict[f"blocks.{l}.attn._W_V"] = einops.rearrange( - v_w, "(n h) m -> n m h", n=cfg.n_key_value_heads - ) - state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange(o_w, "m (n h) -> n h m", n=cfg.n_heads) - del q_w, k_w, v_w, o_w - - # Attention biases - q_bias_key = f"{prefix}.self_attn.q_proj.bias" - if q_bias_key in wmap: - state_dict[f"blocks.{l}.attn.b_Q"] = einops.rearrange( - gt(q_bias_key), "(n h) -> n h", n=cfg.n_heads - ) - state_dict[f"blocks.{l}.attn._b_K"] = einops.rearrange( - gt(f"{prefix}.self_attn.k_proj.bias"), "(n h) -> n h", n=cfg.n_key_value_heads - ) - state_dict[f"blocks.{l}.attn._b_V"] = einops.rearrange( - gt(f"{prefix}.self_attn.v_proj.bias"), "(n h) -> n h", n=cfg.n_key_value_heads - ) - else: - state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) - state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( - cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype - ) - state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros( - cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype - ) - - o_bias_key = f"{prefix}.self_attn.o_proj.bias" - if o_bias_key in wmap: - state_dict[f"blocks.{l}.attn.b_O"] = gt(o_bias_key) - else: - state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) - - # Router - state_dict[f"blocks.{l}.mlp.W_gate.weight"] = gt(f"{prefix}.mlp.router.weight") - state_dict[f"blocks.{l}.mlp.W_gate.bias"] = gt(f"{prefix}.mlp.router.bias") - - # Expert weights — dequantize MXFP4 to BF16 - gate_up_blocks = gt(f"{prefix}.mlp.experts.gate_up_proj_blocks") - gate_up_scales = gt(f"{prefix}.mlp.experts.gate_up_proj_scales") - gate_up_bias = gt(f"{prefix}.mlp.experts.gate_up_proj_bias") - - # Dequantize gate_up_proj: [32, 5760, 90, 16] + [32, 5760, 90] -> [32, 2880, 5760] - print(f" Dequantizing layer {l} gate_up_proj...", end="", flush=True) - gate_up_proj = convert_moe_packed_tensors(gate_up_blocks, gate_up_scales) - del gate_up_blocks, gate_up_scales - print(" done") - - down_blocks = gt(f"{prefix}.mlp.experts.down_proj_blocks") - down_scales = gt(f"{prefix}.mlp.experts.down_proj_scales") - down_bias = gt(f"{prefix}.mlp.experts.down_proj_bias") - - # Dequantize down_proj: [32, 2880, 90, 16] + [32, 2880, 90] -> [32, 2880, 2880] - print(f" Dequantizing layer {l} down_proj...", end="", flush=True) - down_proj = convert_moe_packed_tensors(down_blocks, down_scales) - del down_blocks, down_scales - print(" done") - - # Split merged expert tensors into per-expert weights - # gate_up_proj shape: [num_experts, hidden_size, 2*expert_dim] - # Even columns -> gate, Odd columns -> up - for e in range(cfg.num_experts): - state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = gate_up_proj[ - e, :, ::2 - ].T.contiguous() - state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.bias"] = gate_up_bias[e, ::2].contiguous() - state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = gate_up_proj[ - e, :, 1::2 - ].T.contiguous() - state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.bias"] = gate_up_bias[e, 1::2].contiguous() - state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.weight"] = down_proj[e].T.contiguous() - state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.bias"] = down_bias[e].contiguous() - - del gate_up_proj, gate_up_bias, down_proj, down_bias - return state_dict - - -def main(): - import argparse - - parser = argparse.ArgumentParser(description="Load GPT-OSS-20B into TransformerLens") - parser.add_argument( - "--layers", - type=int, - default=24, - help="Number of layers to load (default: 24, use fewer to save memory)", - ) - parser.add_argument( - "--prompt", - type=str, - default=None, - help="Custom prompt to test (default: built-in test prompts)", - ) - args = parser.parse_args() - - print("=" * 60) - print("GPT-OSS-20B via TransformerLens (Direct SafeTensors)") - print("=" * 60) - - import psutil - - ram = psutil.virtual_memory() - print(f"\nPyTorch: {torch.__version__}") - print(f"MPS: {torch.backends.mps.is_available()}") - print(f"RAM: {ram.total/1e9:.1f}GB total, {ram.available/1e9:.1f}GB available") - - n_layers = args.layers - if n_layers < 24: - print(f"\nLoading first {n_layers} of 24 layers (reduced memory mode)") - est_gb = 2.4 + n_layers * 1.64 - print(f"Estimated memory: ~{est_gb:.0f}GB") - else: - print(f"\nLoading all 24 layers (~42GB, will use swap on <40GB RAM machines)") - - model_path = get_model_path() - print(f"Model path: {model_path}") - - with open(model_path / "model.safetensors.index.json") as f: - index = json.load(f) - - # Create config - cfg = create_config(n_layers=n_layers) - - # Load tokenizer - print("\nLoading tokenizer...") - tokenizer = AutoTokenizer.from_pretrained(str(model_path)) - - # Create TransformerLens model (allocates parameter storage) - print("Creating TransformerLens model structure...") - model = HookedTransformer(cfg, tokenizer, move_to_device=False) - - # Load embeddings - print("\nLoading embeddings...") - embed_file = str(model_path / index["weight_map"]["model.embed_tokens.weight"]) - with safe_open(embed_file, framework="pt", device="cpu") as f: - embed_w = f.get_tensor("model.embed_tokens.weight") - model.load_state_dict({"embed.W_E": embed_w}, strict=False) - del embed_w - gc.collect() - - # Load layers one at a time - for l in range(n_layers): - print(f"\nLoading layer {l}/{n_layers-1}...") - layer_dict = load_layer_weights(l, cfg, index, model_path) - - # Load into model one key at a time to minimize peak memory - keys = list(layer_dict.keys()) - for key in keys: - model.load_state_dict({key: layer_dict[key]}, strict=False) - del layer_dict[key] - del layer_dict - gc.collect() - - ram = psutil.virtual_memory() - print(f" RAM: {ram.used/1e9:.1f}GB used, {ram.available/1e9:.1f}GB available") - - # Load final LayerNorm and unembed - print("\nLoading final layers...") - final_file = str(model_path / index["weight_map"]["model.norm.weight"]) - with safe_open(final_file, framework="pt", device="cpu") as f: - ln_w = f.get_tensor("model.norm.weight") - unembed_w = f.get_tensor("lm_head.weight").T - - model.load_state_dict({"ln_final.w": ln_w}, strict=False) - del ln_w - model.load_state_dict({"unembed.W_U": unembed_w}, strict=False) - del unembed_w - model.load_state_dict({"unembed.b_U": torch.zeros(cfg.d_vocab, dtype=cfg.dtype)}, strict=False) - gc.collect() - - print("\n" + "=" * 60) - print("Model loaded successfully!") - print(f"Architecture: {cfg.original_architecture}") - print(f"Layers: {cfg.n_layers}") - print(f"Experts: {cfg.num_experts}") - print(f"d_model: {cfg.d_model}") - - ram = psutil.virtual_memory() - print(f"RAM: {ram.used/1e9:.1f}GB used, {ram.available/1e9:.1f}GB available") - - # Test inference - if args.prompt: - prompts = [args.prompt] - else: - prompts = [ - "The capital of France is", - "2 + 2 =", - "The opposite of hot is", - ] - - for prompt in prompts: - print(f"\n{'='*60}") - print(f"Prompt: '{prompt}'") - tokens = model.to_tokens(prompt) - with torch.no_grad(): - logits = model(tokens) - pred = model.to_string(logits[0, -1].argmax()) - print(f"Prediction: '{pred}'") - - # Show top 5 predictions - probs = torch.softmax(logits[0, -1].float(), dim=-1) - top5 = probs.topk(5) - print("Top 5:") - for i in range(5): - token_str = model.to_string(top5.indices[i]) - print(f" {token_str!r}: {top5.values[i]:.4f}") - - print(f"\n{'='*60}") - print("Done!") - - -if __name__ == "__main__": - main()