A simple and efficient PyTorch implementation of Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture (I-JEPA).
The model was pre-trained on 100,000 unlabeled images from the STL-10 dataset. For evaluation, I trained and tested logistic regression on frozen features obtained from 5k train images and evaluated on 8k test images, also from the STL-10 dataset.
Linear probing was used for evaluating on features extracted from encoders using the scikit LogisticRegression model. Image resolution was 96x96.
More detailed evaluation steps and results for STL10 can be found in the notebooks directory.
| Dataset | Approach | Encoder | Emb. dim | Patch size | Num. targets | Batch size | Epochs | Top 1% |
|---|---|---|---|---|---|---|---|---|
| STL10 | I-JEPA | VisionTransformer | 512 | 8 | 4 | 256 | 100 | 77.07 |
All experiments were done using a very small and shallow VisionTransformer (only 11M params) with following parameters:
- embbeding dimension -
512 - depth (number of transformers layers) -
6 - number of heads -
6 - mlp dim -
2 * embedding dimension - patch size -
8 - number of targets -
4
The mask generator is inspired by the original paper, but sligthly simplified.
To setup the code, clone the repository, optionally create a venv and install requirements:
git clone git@github.com:filipbasara0/simple-ijepa.git- create virtual environment:
virtualenv -p python3.10 env - activate virtual environment:
source env/bin/activate - install requirements:
pip install .
STL-10 model was trained with this command:
python run_training.py --fp16_precision --log_every_n_steps 200 --num_epochs 100 --batch_size 256
Once the code is setup, run the following command with optinos listed below:
python run_training.py [args...]⬇️
I-JEPA
options:
-h, --help show this help message and exit
--dataset_path DATASET_PATH
Path where datasets will be saved
--dataset_name {stl10}
Dataset name
-save_model_dir SAVE_MODEL_DIR
Path where models
--num_epochs NUM_EPOCHS
Number of epochs for training
-b BATCH_SIZE, --batch_size BATCH_SIZE
Batch size
-lr LEARNING_RATE, --learning_rate LEARNING_RATE
-wd WEIGHT_DECAY, --weight_decay WEIGHT_DECAY
--fp16_precision Whether to use 16-bit precision for GPU training
--emb_dim EMB_DIM Transofmer embedding dimm
--log_every_n_steps LOG_EVERY_N_STEPS
Log every n steps
--gamma GAMMA Initial EMA coefficient
--update_gamma_after_step UPDATE_GAMMA_AFTER_STEP
Update EMA gamma after this step
--update_gamma_every_n_steps UPDATE_GAMMA_EVERY_N_STEPS
Update EMA gamma after this many steps
--ckpt_path CKPT_PATH
Specify path to ijepa_model.pth to resume training
@misc{assran2023selfsupervisedlearningimagesjointembedding,
title={Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture},
author={Mahmoud Assran and Quentin Duval and Ishan Misra and Piotr Bojanowski and Pascal Vincent and Michael Rabbat and Yann LeCun and Nicolas Ballas},
year={2023},
eprint={2301.08243},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2301.08243},
}