This is the code repository accompanying the paper
Lange, R. D., Rolnick, D. S., and Kording, K. (2022) "Clustering units in neural networks: upstream vs downstream information." TMLR. https://openreview.net/forum?id=Euf7KofunK
PyTorch models are defined in models/mnist.py and models/cifar10.py. Models are wrapped by a PyTorch Lightning
module, models.LitWrapper, which handles loading a specific model or dataset. Training is done by train.py, which
is called for a range of hyperparameter configurations by train.sh.
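For orientation, a single training run looks roughly like the following. This is a minimal sketch: the keyword arguments passed to LitWrapper are assumptions, not the repository's actual hyperparameter names; see train.py and train.sh for the real interface.

```python
import pytorch_lightning as pl
from models import LitWrapper  # assumed import path for models.LitWrapper

# The keyword arguments below are placeholders -- check train.py for the
# hyperparameter names actually used in this repository.
model = LitWrapper(dataset="mnist", hidden_dim=256, l2_penalty=1e-4)
trainer = pl.Trainer(max_epochs=50)
trainer.fit(model)  # LitWrapper is assumed to provide its own dataloaders
```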
Training needs to be run before moving on to the evaluation step below.
As detailed in the paper, we analyze "modularity" of a set of units (e.g. all units in a layer) by
- computing pairwise similarity scores of units
- clustering units together by maximizing the Q score from Newman (2006).
Step 1 is done by functions in associations.py and step 2 is done by functions in modularity.py.
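The clustering objective is Newman's modularity, Q. As a rough illustration (not the implementation in modularity.py), Q for a symmetric, non-negative similarity matrix and a given cluster assignment can be computed as follows; the similarity matrix itself would come from associations.py or any comparable pairwise measure.

```python
import numpy as np

def newman_q(A: np.ndarray, labels: np.ndarray) -> float:
    """Newman (2006) modularity Q for a symmetric, non-negative similarity
    matrix A and integer cluster labels. Illustrative only; see
    modularity.py for the version used in the paper.
    """
    k = A.sum(axis=1)          # weighted degree of each unit
    two_m = A.sum()            # total edge weight (counted twice)
    same = labels[:, None] == labels[None, :]  # delta(c_i, c_j)
    return float(((A - np.outer(k, k) / two_m) * same).sum() / two_m)
```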
Running eval.py does the following:
- loads a model from a checkpoint
- computes a variety of performance statistics such as validation accuracy, weight norms, etc.
- computes a variety of modularity statistics by calling functions from associations.py and modularity.py
- saves results back into the same checkpoint file
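The "save results back into the checkpoint" step follows a simple read-modify-write pattern. The sketch below is only illustrative: the key name and the statistics shown are placeholders, not the keys eval.py actually writes.

```python
import torch

ckpt_path = "checkpoints/example.ckpt"  # hypothetical path
ckpt = torch.load(ckpt_path, map_location="cpu")
# "eval_results" and its contents are placeholder names, not eval.py's real keys.
ckpt["eval_results"] = {"val_acc": 0.97, "num_clusters": 4}
torch.save(ckpt, ckpt_path)
```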
The file eval.sh is a shell script that demonstrates how we call eval.py for each checkpoint in a directory.
As mentioned above, eval.py loads a checkpoint, computes a variety of statistics including the modules (clusters), and
saves the results back into the checkpoint file. This means that eval.sh needs to be run on a set of checkpoints before
the notebooks can be run to plot the results. The file analysis.py handles loading the statistics computed by
eval.py into a pandas DataFrame.
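eval.sh implements the batch step as a shell loop over checkpoint files. The same logic, sketched in Python under the assumption that eval.py accepts the checkpoint path as its only argument (check eval.sh for the actual invocation), would be:

```python
import subprocess
from pathlib import Path

# Run eval.py on every checkpoint in a directory. The command line below
# (checkpoint path as the only argument) is an assumption -- mirror
# whatever invocation eval.sh actually uses.
ckpt_dir = Path("checkpoints")
for ckpt in sorted(ckpt_dir.glob("*.ckpt")):
    subprocess.run(["python", "eval.py", str(ckpt)], check=True)
```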
The notebook notebooks/analysis_sandbox.ipynb was used to generate most figures in the paper. It primarily calls
analysis.load_data_as_table() to load precomputed information from a set of checkpoints into a DataFrame; the rest of
the notebook slices and plots those results in various ways.
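A condensed version of that workflow looks roughly like the sketch below. The directory argument to load_data_as_table and the column names ("l2_penalty", "num_clusters") are assumptions; inspect df.columns for the fields eval.py actually produced.

```python
import matplotlib.pyplot as plt
import analysis

# Load precomputed statistics from evaluated checkpoints into a DataFrame.
df = analysis.load_data_as_table("checkpoints/")  # argument is an assumption
print(df.columns)

# Example slice-and-plot; the column names here are placeholders.
df.groupby("l2_penalty")["num_clusters"].mean().plot(marker="o")
plt.xlabel("L2 penalty")
plt.ylabel("mean number of clusters")
plt.show()
```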