This is a PyTorch Lightning implementation of the paper "Model-Agnostic Meta-Learning" (MAML). I also plan to implement "How to train your MAML" (MAML++).
MAML (Model-Agnostic Meta-Learning) approaches the Few-Shot Learning (FSL) problem using the "prior knowledge about learning" approach:
- There is a meta learner and a set of task-specific learners. At each forward pass, every task-specific learner clones the parameters of the meta learner.
- Each task-specific learner then updates its parameters through a few steps of gradient descent (usually 4-5), learning from examples in the support set of its task. This is called acquiring fast knowledge.
- Each task-specific learner then makes predictions on the query set of its task, and a per-task loss is computed from the predictions of the task-specific models (θ*).
- These losses are used to update the parameters of the meta learner (θ), which thereby learns task-agnostic knowledge; this is called slow, task-agnostic learning. (A minimal code sketch of these two loops is given after Figure 1.)
Figure 1: MAML optimizes for a representation θ that can quickly adapt to new tasks. Taken from "Model-Agnostic Meta-Learning".
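To make the two loops concrete, here is a minimal, self-contained sketch of second-order MAML on a toy linear classifier with synthetic data. It is an illustration under those assumptions, not the training code used in this repository.

```python
# Minimal MAML sketch (illustrative only; toy model and random data).
import torch
import torch.nn.functional as F

def inner_adapt(params, support_x, support_y, inner_lr=0.01, steps=5):
    """Fast adaptation: a few plain-SGD steps on the support set (fast knowledge)."""
    adapted = list(params)
    for _ in range(steps):
        logits = support_x @ adapted[0] + adapted[1]
        loss = F.cross_entropy(logits, support_y)
        # create_graph=True keeps the graph so the outer loop can
        # backpropagate through the inner updates (second-order MAML).
        grads = torch.autograd.grad(loss, adapted, create_graph=True)
        adapted = [p - inner_lr * g for p, g in zip(adapted, grads)]
    return adapted

# Meta-parameters theta for a toy 4-way linear classifier on 8-dim features.
theta = [torch.randn(8, 4, requires_grad=True), torch.zeros(4, requires_grad=True)]
meta_opt = torch.optim.Adam(theta, lr=1e-3)

for _ in range(100):                                   # outer (meta) iterations
    meta_opt.zero_grad()
    for _ in range(3):                                 # a meta-batch of tasks
        sx, sy = torch.randn(20, 8), torch.randint(0, 4, (20,))   # support set
        qx, qy = torch.randn(20, 8), torch.randint(0, 4, (20,))   # query set
        adapted = inner_adapt(theta, sx, sy)                       # theta* for this task
        query_loss = F.cross_entropy(qx @ adapted[0] + adapted[1], qy)
        query_loss.backward()       # accumulate meta-gradients into theta (slow learning)
    meta_opt.step()
```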
You must agree to the terms and conditions in the LICENSE file to use this code. If you choose to use the Mini-ImageNet dataset, you must abide by the terms and conditions in the ImageNet LICENSE.
I recommend using the conda package manager; more specifically miniconda3, as it is lightweight and fast to install. If you have an existing miniconda3 installation, please start at step 2:

1. Download and install miniconda3, then activate conda:
   `wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh`
2. Create the conda environment and install the dependencies:
   `make conda-update`
   `conda activate maml_pytorch_lightning`
   `make pip-tools`
3. Add `export PYTHONPATH=.:$PYTHONPATH` to your `~/.bashrc` and run `source ~/.bashrc`.
We use Mini-ImageNet from Antreas Antoniou's GitHub. The download and setup are handled in code.
Note: By downloading and using the Mini-ImageNet dataset, you accept the terms and conditions found in imagenet_license.md.
Dataset
├── Train
│   ├── class_0 → samples for class_0
│   ├── class_1 → samples for class_1
│   ├── ...
│   └── class_N → samples for class_N
├── Val   (same class/sample layout as Train)
└── Test  (same class/sample layout as Train)
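To make this layout concrete, here is a rough sketch of how an N-way, K-shot episode (support + query sets) could be drawn from one such split. The helper name, class names, and sample counts are made up for illustration and are not the repository's actual data-loading code.

```python
# Illustrative episode sampling for N-way, K-shot few-shot learning.
import random

def sample_episode(split, n_way=5, k_support=5, k_query=5):
    """split: dict mapping class name -> list of samples for that class."""
    episode_classes = random.sample(list(split.keys()), n_way)
    support, query = [], []
    for label, cls in enumerate(episode_classes):
        samples = random.sample(split[cls], k_support + k_query)
        support += [(x, label) for x in samples[:k_support]]   # support set
        query += [(x, label) for x in samples[k_support:]]     # query set
    return support, query

# Toy split: 20 classes with 30 dummy samples each.
train_split = {f"class_{i}": [f"img_{i}_{j}" for j in range(30)] for i in range(20)}
support_set, query_set = sample_episode(train_split)
print(len(support_set), len(query_set))  # 25 support and 25 query samples
```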
The code structure is inspired by the FSDL (Full Stack Deep Learning) course.
- `few_shot_image_classification`: This folder contains the `models`, `data` and `lit_models` code, built with PyTorch Lightning.
  a. `models`: Any new model architecture should be defined here.
  b. `data`: Code for downloading, processing, loading and batching any new dataset goes here. We use `pl.LightningDataModule` for organised code.
  c. `lit_models`: Training-related code using PyTorch Lightning lives here. This involves defining `training_step`, optimizers, etc. We use `pl.LightningModule` to harness all the benefits of Lightning. (A skeleton example of both modules follows this list.)
- `training`: Code related to configuring and running experiments sits here.
- `tasks`: Tasks like linting and tests are set up to run during the CI build.
- `requirements`: All dev and prod dependencies are specified in `.in` files and used by `pip-tools`.
- Additional files related to `wandb` sweeps, `CircleCI` setup and linting are also present.
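To illustrate how the `data` and `lit_models` pieces plug into PyTorch Lightning, here is a bare-bones sketch of a `pl.LightningDataModule` / `pl.LightningModule` pair. The class names (`ToyDataModule`, `ToyLitModel`) and the dummy data are placeholders, not the actual classes in this repository.

```python
# Skeleton showing how the data and lit_models pieces fit into PyTorch Lightning.
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class ToyDataModule(pl.LightningDataModule):
    def setup(self, stage=None):
        # Download/prepare data here; dummy tensors are used for illustration.
        self.train_ds = TensorDataset(torch.randn(64, 8), torch.randint(0, 4, (64,)))

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=16)

class ToyLitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 4)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.net(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Usage: pl.Trainer(max_epochs=1).fit(ToyLitModel(), datamodule=ToyDataModule())
```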
- Set up the environment following the instructions above.
- Run
  `python training/run_experiment.py --gpus=-1 --wandb --data_class=MiniImagenet --model_class=ResNetClassifier --num_workers=6 --accelerator=ddp --val_check_interval=5 --batch_size=3 --inner_loop_steps=1 --support_samples=5 --track_grad_norm=2`
- Change `data_class` and `model_class` for new data sources and models respectively.
- Training was run on a `Standard_NC12_Promo` instance. Keep the following params in mind if you encounter a `CUDA OOM ERROR`:
  a. `inner_loop_steps`: Since we backprop through all the inner-model update steps during training, memory grows roughly linearly with this value.
  b. `batch_size`, `support_samples` and `query_samples`: The number of images in a batch is `batch_size * episode_classes * (support_samples + query_samples)`, which can be quite large (a worked example follows this list). Hence, think of `batch_size` as the number of tasks you want to train MAML across.
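As a worked example of that formula, under the assumption of 5-way episodes with 5 query samples per class (these values are assumptions, not necessarily the repository's defaults):

```python
# Worked example of the per-batch image count; the values below are assumptions,
# not the repository's defaults (episode_classes in particular is assumed to be 5).
batch_size = 3        # number of tasks per meta-batch
episode_classes = 5   # N-way classification per task (assumed)
support_samples = 5   # support images per class
query_samples = 5     # query images per class (assumed)

images_per_batch = batch_size * episode_classes * (support_samples + query_samples)
print(images_per_batch)  # 3 * 5 * (5 + 5) = 150 images per batch
```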
The initial version of this MAML implementation suffered from gradient explosion when the inner-loop optimizer was a stateful optimizer such as Adam, which involves exponential moving averages of gradients. Switching to plain gradient descent for inner-loop training resolved the gradient explosion; however, the loss curves clearly show overfitting.
Figure 2: Loss curves which clearly point to overfitting.
Figure 3: Gradient histograms over time. The gradients have gone to zero and the model has overfitted on the training set.
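To illustrate the point about the inner-loop optimizer, here is a hedged sketch contrasting an Adam-style inner step (whose exponential moving averages become extra state the meta-gradient must flow through) with the plain gradient-descent step that stabilized training. Function names and hyperparameters are illustrative, not the repository's actual code.

```python
# Contrast between a stateful (Adam-style) and a stateless (plain SGD) inner step.
import torch

def adam_inner_step(param, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    m = b1 * m + (1 - b1) * grad            # exponential moving average of gradients
    v = b2 * v + (1 - b2) * grad ** 2       # exponential moving average of squared gradients
    m_hat = m / (1 - b1 ** t)               # bias correction
    v_hat = v / (1 - b2 ** t)
    new_param = param - lr * m_hat / (v_hat.sqrt() + eps)
    return new_param, m, v                  # the optimizer state (m, v) is also differentiated through

def sgd_inner_step(param, grad, lr=1e-2):
    # Plain gradient descent: a simple, stateless function of (param, grad).
    return param - lr * grad
```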
- Solve overfitting in MAML training and check for bugs in data preparation, batch norm, train/eval flags, etc.
- Include an Adam-like optimizer in the inner training loop of MAML.
- Include the first-order approximation and other improvements from MAML++.