The nnsight package enables interpreting and manipulating the internals of deep learning models. Read our paper!
```bash
pip install nnsight
```

```python
from nnsight import LanguageModel

model = LanguageModel('openai-community/gpt2', device_map='auto', dispatch=True)

with model.trace('The Eiffel Tower is in the city of'):
    # Access and save hidden states
    hidden_states = model.transformer.h[-1].output[0].save()
    # Intervene on activations
    model.transformer.h[0].output[0][:] = 0
    # Get model output
    output = model.output.save()

print(model.tokenizer.decode(output.logits.argmax(dim=-1)[0]))
```

```python
with model.trace("The Eiffel Tower is in the city of"):
    # Access attention output
    attn_output = model.transformer.h[0].attn.output[0].save()
    # Access MLP output
    mlp_output = model.transformer.h[0].mlp.output.save()
    # Access any layer's output (access in execution order)
    layer_output = model.transformer.h[5].output[0].save()
    # Access final logits
    logits = model.lm_head.output.save()
```

Note: GPT-2 transformer layers return tuples where index 0 contains the hidden states.
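For instance, the full tuple and its first element can be inspected side by side (a minimal sketch using the GPT-2 model above):

```python
with model.trace("Hello"):
    full_output = model.transformer.h[0].output  # tuple: (hidden_states, ...)
    hidden_states = full_output[0].save()        # just the hidden states tensor

print(hidden_states.shape)  # (batch, seq, hidden)
```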
with model.trace("Hello"):
# Zero out all activations
model.transformer.h[0].output[0][:] = 0
# Modify specific positions
model.transformer.h[0].output[0][:, -1, :] = 0 # Last token onlywith model.trace("Hello"):
# Add noise to activations
hs = model.transformer.h[-1].mlp.output.clone()
noise = 0.01 * torch.randn(hs.shape)
model.transformer.h[-1].mlp.output = hs + noise
result = model.transformer.h[-1].mlp.output.save()Process multiple inputs in one forward pass. Each invoke runs its code in a separate worker thread:
- Threads execute serially (no race conditions)
- Each thread waits for values via `.output`, `.input`, etc.
- Invokes run in the order they're defined
- Cross-invoke references work because threads run sequentially
- Within an invoke, access modules in execution order only
```python
with model.trace() as tracer:
    # First invoke: worker thread 1
    with tracer.invoke("The Eiffel Tower is in"):
        embeddings = model.transformer.wte.output  # Thread waits here
        output1 = model.lm_head.output.save()
    # Second invoke: worker thread 2 (runs after thread 1 completes)
    with tracer.invoke("_ _ _ _ _ _"):
        model.transformer.wte.output = embeddings  # Uses value from thread 1
        output2 = model.lm_head.output.save()
```

Use `.invoke()` with no arguments to operate on the entire batch:
```python
with model.trace() as tracer:
    with tracer.invoke("Hello"):
        out1 = model.lm_head.output[:, -1].save()
    with tracer.invoke(["World", "Test"]):
        out2 = model.lm_head.output[:, -1].save()
    # No-arg invoke: operates on ALL 3 inputs
    with tracer.invoke():
        out_all = model.lm_head.output[:, -1].save()  # Shape: [3, vocab]
```

Use `.generate()` for autoregressive generation:
with model.generate("The Eiffel Tower is in", max_new_tokens=3) as tracer:
output = model.generator.output.save()
print(model.tokenizer.decode(output[0]))
# "The Eiffel Tower is in the city of Paris"with model.generate("Hello", max_new_tokens=5) as tracer:
logits = list().save()
# Iterate over all generation steps
with tracer.iter[:]:
logits.append(model.lm_head.output[0][-1].argmax(dim=-1))
print(model.tokenizer.batch_decode(logits))with model.generate("Hello", max_new_tokens=5) as tracer:
outputs = list().save()
with tracer.iter[:] as step_idx:
if step_idx == 2:
model.transformer.h[0].output[0][:] = 0 # Only on step 2
outputs.append(model.transformer.h[-1].output[0])Gradients are accessed on tensors (not modules), only inside a with tensor.backward(): context:
with model.trace("Hello"):
hs = model.transformer.h[-1].output[0]
hs.requires_grad_(True)
logits = model.lm_head.output
loss = logits.sum()
with loss.backward():
grad = hs.grad.save()
print(grad.shape)Create persistent model modifications:
```python
import torch

# Create edited model (non-destructive)
with model.edit() as model_edited:
    model.transformer.h[0].output[0][:] = 0

# Original model unchanged
with model.trace("Hello"):
    out1 = model.transformer.h[0].output[0].save()

# Edited model has modification
with model_edited.trace("Hello"):
    out2 = model_edited.transformer.h[0].output[0].save()

assert not torch.all(out1 == 0)
assert torch.all(out2 == 0)
```

Get shapes without running the full model:
with model.scan("Hello"):
dim = model.transformer.h[0].output[0].shape[-1]
print(dim) # 768Automatically cache outputs from modules:
with model.trace("Hello") as tracer:
cache = tracer.cache()
# Access cached values
layer0_out = cache['model.transformer.h.0'].output
print(cache.model.transformer.h[0].output[0].shape)Group multiple traces for efficiency:
```python
with model.session() as session:
    with model.trace("Hello"):
        hs1 = model.transformer.h[0].output[0].save()
    with model.trace("World"):
        model.transformer.h[0].output[0][:] = hs1  # Use value from first trace
        hs2 = model.transformer.h[0].output[0].save()
```

Run on NDIF's remote infrastructure:
```python
from nnsight import CONFIG

CONFIG.set_default_api_key("YOUR_API_KEY")

model = LanguageModel("meta-llama/Meta-Llama-3.1-8B")

with model.trace("Hello", remote=True):
    hidden_states = model.model.layers[-1].output.save()
```

Check available models at nnsight.net/status.
High-performance inference with vLLM:
```python
from nnsight.modeling.vllm import VLLM

model = VLLM("gpt2", tensor_parallel_size=1, dispatch=True)

with model.trace("Hello", temperature=0.0, max_tokens=5) as tracer:
    logits = list().save()
    with tracer.iter[:]:
        logits.append(model.logits.output)
```

Use NNsight for arbitrary PyTorch models:
```python
from nnsight import NNsight
import torch

net = torch.nn.Sequential(
    torch.nn.Linear(5, 10),
    torch.nn.Linear(10, 2)
)
model = NNsight(net)

with model.trace(torch.rand(1, 5)):
    layer1_out = model[0].output.save()
    output = model.output.save()
```

Access intermediate operations inside a module's forward pass. `.source` rewrites the forward method to hook into all operations:
```python
# Discover available operations
print(model.transformer.h[0].attn.source)
# Shows the forward method with operation names like:
#   attention_interface_0 -> 66 attn_output, attn_weights = attention_interface(...)
#   self_c_proj_0 -> 79 attn_output = self.c_proj(attn_output)

# Access operation values
with model.trace("Hello"):
    attn_out = model.transformer.h[0].attn.source.attention_interface_0.output.save()
```

Apply modules out of their normal execution order:
with model.trace("The Eiffel Tower is in the city of"):
# Get intermediate hidden states
hidden_states = model.transformer.h[-1].output[0]
# Apply lm_head to get "logit lens" view
logits = model.lm_head(model.transformer.ln_f(hidden_states))
tokens = logits.argmax(dim=-1).save()
print(model.tokenizer.decode(tokens[0]))NNsight uses deferred execution with thread-based synchronization:
- Code extraction: When you enter a `with model.trace(...)` block, nnsight captures your code (via AST) and immediately exits the block
- Thread execution: Your code runs in a separate worker thread
- Value waiting: When you access `.output`, the thread waits until the model provides that value
- Hook-based injection: The model uses PyTorch hooks to provide values to waiting threads
with model.trace("Hello"):
# Code runs in a worker thread
# Thread WAITS here until layer output is available
hs = model.transformer.h[-1].output[0]
# .save() marks the value to persist after the context exits
hs = hs.save()
# After exiting, hs contains the actual tensor
print(hs.shape) # torch.Size([1, 2, 768])Key insight: Your code runs directly. When you access .output, you get the real tensor - your thread just waits for it to be available.
Important: Within an invoke, you must access modules in execution order. Accessing layer 5's output and then layer 2's output will cause a deadlock: by the time your code asks for layer 2, that layer has already run and its value will never be delivered.
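For example (a minimal sketch of the ordering rule; the layer indices are arbitrary):

```python
with model.trace("Hello"):
    early = model.transformer.h[2].output[0].save()  # OK: layer 2 runs before layer 5
    late = model.transformer.h[5].output[0].save()   # OK: accessed in execution order

# Reversing the two accesses would deadlock: the thread would wait on layer 5 first,
# and by the time it asked for layer 2 that layer's hook would already have fired.
```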
Every module has these special properties. Accessing them causes the worker thread to wait for the value:
| Property | Description |
|---|---|
| `.output` | Module's forward pass output (thread waits) |
| `.input` | First positional argument to the module |
| `.inputs` | All inputs as `(args_tuple, kwargs_dict)` |
Note: .grad is accessed on tensors (not modules), only inside a with tensor.backward(): context.
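A minimal sketch of the input properties, using the GPT-2 model from the examples above (the variable names are ours):

```python
with model.trace("Hello"):
    first_arg = model.transformer.h[0].input.save()  # first positional argument: hidden states
    args, kwargs = model.transformer.h[0].inputs     # all inputs as (args_tuple, kwargs_dict)

print(first_arg.shape)  # (batch, seq, hidden)
```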
Print the model to see its structure:
```python
print(model)
# GPT2LMHeadModel(
#   (transformer): GPT2Model(
#     (h): ModuleList(
#       (0-11): 12 x GPT2Block(
#         (attn): GPT2Attention(...)
#         (mlp): GPT2MLP(...)
#       )
#     )
#   )
#   (lm_head): Linear(...)
# )
```

Find more examples and tutorials at nnsight.net.
If you use nnsight in your research, please cite:
```bibtex
@article{fiottokaufman2024nnsightndifdemocratizingaccess,
  title={NNsight and NDIF: Democratizing Access to Foundation Model Internals},
  author={Jaden Fiotto-Kaufman and Alexander R Loftus and Eric Todd and Jannik Brinkmann and Caden Juang and Koyena Pal and Can Rager and Aaron Mueller and Samuel Marks and Arnab Sen Sharma and Francesca Lucchetti and Michael Ripa and Adam Belfki and Nikhil Prakash and Sumeet Multani and Carla Brodley and Arjun Guha and Jonathan Bell and Byron Wallace and David Bau},
  year={2024},
  eprint={2407.14561},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2407.14561},
}
```