This repository has been archived by the owner on Aug 1, 2024. It is now read-only.

I-JEPA

Official PyTorch codebase for I-JEPA (the Image-based Joint-Embedding Predictive Architecture), published at CVPR 2023. [arXiv] [JEPAs] [blogpost]

Method

I-JEPA is a method for self-supervised learning. At a high level, I-JEPA predicts the representations of part of an image from the representations of other parts of the same image. Notably, this approach learns semantic image features:

  1. without relying on pre-specified invariances to hand-crafted data transformations, which tend to be biased for particular downstream tasks,
  2. and without having the model fill in pixel-level details, which tends to result in learning less semantically meaningful representations.
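
To make the idea of predicting in representation space concrete, here is a minimal sketch of a latent-space prediction loss. It assumes a mean squared error between predicted and target patch representations; this README does not state the exact loss, so treat that choice, and the hypothetical name `latent_prediction_loss`, as illustrative only. Plain Python lists stand in for tensors.

```python
from typing import List

Vector = List[float]


def latent_prediction_loss(predicted: List[Vector], target: List[Vector]) -> float:
    """Mean squared error between predicted and target patch embeddings.

    Assumption: a squared-error loss; the actual training loss may differ.
    """
    if len(predicted) != len(target):
        raise ValueError("predicted and target must cover the same patches")
    total = 0.0
    count = 0
    for p_vec, t_vec in zip(predicted, target):
        if len(p_vec) != len(t_vec):
            raise ValueError("embedding dimensions must match")
        for p, t in zip(p_vec, t_vec):
            total += (p - t) ** 2
            count += 1
    if count == 0:
        raise ValueError("need at least one embedding value")
    return total / count


# Toy example: two target patches with 3-dimensional embeddings.
target = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
predicted = [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
loss = latent_prediction_loss(predicted, target)
print(f"latent loss: {loss:.4f}")
```

The point of the sketch is that the error is measured between embedding vectors, never between pixels, so the model is not penalized for missing pixel-level detail.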

[Figure: ijepa]

Visualizations

As opposed to generative methods that have a pixel decoder, I-JEPA has a predictor that makes predictions in latent space. The predictor in I-JEPA can be seen as a primitive (and restricted) world model that is able to model spatial uncertainty in a static image from a partially observable context. This world model is semantic in the sense that it predicts high-level information about unseen regions in the image, rather than pixel-level details.

We trained a stochastic decoder that maps the I-JEPA predicted representations back to pixel space as sketches. The model correctly captures positional uncertainty and produces high-level object parts with the correct pose (e.g., dog’s head, wolf’s front legs).

[Figure: ijepa-predictor-sketch]

Caption: Illustrating how the predictor learns to model the semantics of the world. For each image, the portion outside of the blue box is encoded and given to the predictor as context. The predictor outputs a representation for what it expects to be in the region within the blue box. To visualize the prediction, we train a generative model that produces a sketch of the contents represented by the predictor output, and we show a sample output within the blue box. The predictor recognizes the semantics of what parts should be filled in (the top of the dog’s head, the bird’s leg, the wolf’s legs, the other side of the building).

Evaluations

I-JEPA pretraining is also computationally efficient. It does not involve any overhead associated with applying more computationally intensive data augmentations to produce multiple views. Only one view of the image needs to be processed by the target encoder, and only the context blocks need to be processed by the context encoder. Empirically, I-JEPA learns strong off-the-shelf semantic representations without the use of hand-crafted view augmentations.
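
As a rough illustration of why encoding only the context is cheaper, the sketch below splits a patch grid into one target block and the remaining context patches, then counts how many tokens each side contains. The single rectangular block, its position, and the helper names `block_indices` and `split_context_target` are assumptions made for this example; the actual mask sampling lives in src/masks and may differ.

```python
from typing import List, Set, Tuple


def block_indices(grid: int, top: int, left: int, height: int, width: int) -> Set[int]:
    """Flattened patch indices of a rectangular block on a grid x grid layout."""
    if top < 0 or left < 0 or top + height > grid or left + width > grid:
        raise ValueError("block must lie inside the patch grid")
    return {
        row * grid + col
        for row in range(top, top + height)
        for col in range(left, left + width)
    }


def split_context_target(
    grid: int, target_block: Tuple[int, int, int, int]
) -> Tuple[List[int], List[int]]:
    """Return (context, target) patch indices; context is everything outside the block."""
    target = block_indices(grid, *target_block)
    context = [i for i in range(grid * grid) if i not in target]
    return context, sorted(target)


# A 16 x 16 patch grid (224 x 224 image, 14 x 14 patches) with one
# hypothetical 4 x 6 target block whose top-left patch is at row 2, column 3.
context, target = split_context_target(16, (2, 3, 4, 6))
print(f"context tokens: {len(context)}, target tokens: {len(target)}")
```

In this toy setup the context encoder would see 232 of the 256 patch tokens, and no second augmented view of the image is ever created.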

[Figures: 1percenteval, lineareval]

Pretrained models

| arch. | patch size | resolution | epochs | data | download |
|-------|------------|------------|--------|------|----------|
| ViT-H | 14x14 | 224x224 | 300 | ImageNet-1K | full checkpoint, logs, configs |
| ViT-H | 16x16 | 448x448 | 300 | ImageNet-1K | full checkpoint, logs, configs |
| ViT-H | 14x14 | 224x224 | 66 | ImageNet-22K | full checkpoint, logs, configs |
| ViT-g | 16x16 | 224x224 | 44 | ImageNet-22K | full checkpoint, logs, configs |
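
The patch size and resolution columns together determine how many patch tokens a full image produces. The short calculation below works this out for the three distinct combinations in the table; `num_patches` is a hypothetical helper written for this example, and it assumes square images split into non-overlapping square patches.

```python
def num_patches(resolution: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    if resolution % patch_size != 0:
        raise ValueError("resolution must be divisible by patch size")
    side = resolution // patch_size
    return side * side


# (arch., patch size, resolution) combinations taken from the table above.
combos = [("ViT-H", 14, 224), ("ViT-H", 16, 448), ("ViT-g", 16, 224)]
for arch, patch, res in combos:
    side = res // patch
    print(f"{arch}/{patch} @ {res}: {side} x {side} = {num_patches(res, patch)} patches")
```

This gives 256 patches for ViT-H/14 at 224x224, 784 for ViT-H/16 at 448x448, and 196 for ViT-g/16 at 224x224.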

Code Structure

.
├── configs                   # directory in which all experiment '.yaml' configs are stored
├── src                       # the package
│   ├── train.py              #   the I-JEPA training loop
│   ├── helper.py             #   helper functions for init of models & opt/loading checkpoint
│   ├── transforms.py         #   pre-train data transforms
│   ├── datasets              #   datasets, data loaders, ...
│   ├── models                #   model definitions
│   ├── masks                 #   mask collators, masking utilities, ...
│   └── utils                 #   shared utilities
├── main_distributed.py       # entrypoint for launching distributed I-JEPA pretraining on a SLURM cluster
└── main.py                   # entrypoint for launching I-JEPA pretraining locally on your machine

Config files: Note that all experiment parameters are specified in config files (as opposed to command-line arguments). See the configs/ directory for example config files.

Launching I-JEPA pretraining

Single-GPU training

This implementation starts from main.py, which parses the experiment config file and runs pretraining locally on a single-GPU or multi-GPU machine. For example, to run I-JEPA pretraining on GPUs 0, 1, and 2 of a local machine using the config configs/in1k_vith14_ep300.yaml, type the command:

python main.py \
  --fname configs/in1k_vith14_ep300.yaml \
  --devices cuda:0 cuda:1 cuda:2

Note: This example is just used for illustrative purposes, as the ViT-H/14 config should be run on 16 A100 80G GPUs for an effective batch-size of 2048, in order to reproduce our results.
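
For readers unfamiliar with multi-value flags, the sketch below shows one way a launcher could parse the `--fname` and `--devices` arguments used in the command above. Only the two flag names come from this README; the parser layout, the default device, and the idea of one worker per listed device are assumptions for illustration, not a description of main.py itself.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags shown in the example command."""
    parser = argparse.ArgumentParser(description="local pretraining launcher (sketch)")
    parser.add_argument("--fname", type=str, required=True,
                        help="path to the experiment config file")
    # nargs="+" collects one or more space-separated device names into a list.
    parser.add_argument("--devices", type=str, nargs="+", default=["cuda:0"],
                        help="devices to use on the local machine")
    return parser


def parse(argv: Optional[List[str]] = None) -> argparse.Namespace:
    return build_parser().parse_args(argv)


args = parse([
    "--fname", "configs/in1k_vith14_ep300.yaml",
    "--devices", "cuda:0", "cuda:1", "cuda:2",
])
num_workers = len(args.devices)  # assumption: one worker process per device
print(args.fname, args.devices, num_workers)
```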

Multi-GPU training

In the multi-GPU setting, the implementation starts from main_distributed.py, which, in addition to parsing the config file, also allows for specifying details about distributed training. For distributed training, we use the popular open-source submitit tool and provide examples for a SLURM cluster.

For example, to pre-train on 16 A100 80G GPUs using the pre-training experiment config specified in configs/in1k_vith14_ep300.yaml, type the command:

python main_distributed.py \
  --fname configs/in1k_vith14_ep300.yaml \
  --folder $path_to_save_submitit_logs \
  --partition $slurm_partition \
  --nodes 2 --tasks-per-node 8 \
  --time 1000
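
The numbers in this command line up with the 16-GPU, batch-size-2048 setup mentioned earlier. The arithmetic below is a sketch under two assumptions: each SLURM task drives exactly one GPU, and the effective batch size is the per-GPU batch size times the number of GPUs (no gradient accumulation). Both helper names are hypothetical.

```python
def world_size(nodes: int, tasks_per_node: int) -> int:
    """Total number of tasks, assuming one task per GPU."""
    return nodes * tasks_per_node


def per_gpu_batch_size(effective_batch_size: int, num_gpus: int) -> int:
    """Per-GPU batch size needed to reach the effective batch size."""
    if effective_batch_size % num_gpus != 0:
        raise ValueError("effective batch size must divide evenly across GPUs")
    return effective_batch_size // num_gpus


gpus = world_size(nodes=2, tasks_per_node=8)
per_gpu = per_gpu_batch_size(2048, gpus)
print(f"{gpus} GPUs, {per_gpu} samples per GPU per step")
```

Under these assumptions, 2 nodes with 8 tasks each give 16 GPUs, and each GPU processes 128 samples per step.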

Requirements

  • Python 3.8 (or newer)
  • PyTorch 2.0
  • torchvision
  • Other dependencies: pyyaml, numpy, opencv, submitit

License

See the LICENSE file for details about the license under which this code is made available.

Citation

If you find this repository useful in your research, please consider giving it a star ⭐ and a citation:

@article{assran2023self,
  title={Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture},
  author={Assran, Mahmoud and Duval, Quentin and Misra, Ishan and Bojanowski, Piotr and Vincent, Pascal and Rabbat, Michael and LeCun, Yann and Ballas, Nicolas},
  journal={arXiv preprint arXiv:2301.08243},
  year={2023}
}