|
1 |
| -# fewshot-egnn |
| 1 | +# fewshot-egnn |
| 2 | + |
| 3 | +PyTorch implementation of the following paper: |
| 4 | + |
| 5 | + "Edge-labeling Graph Neural Network for Few-shot Learning", CVPR 2019 [arXiv link] |
| 6 | + |
| 7 | +# Platform |
| 8 | +- pytorch 0.4.1, python 3 |
| 9 | + |
| 10 | +## Setting |
| 11 | +- In ```data.py```, replace the dataset root directory with your own: |
| 12 | + |
| 13 | + root_dir = '/mnt/hdd/jmkim/maml_pytorch/asset/data/miniImagenet/' |
| 14 | + |
| 15 | +- For resnet experiment, download the pre-trained 64-way cls models from the following link: |
| 16 | + https://drive.google.com/open?id=1pic_LWnRUP1IaGJLvujF-0k9WtSHPW_Y |
| 17 | + |
| 18 | + and place it under ./asset/pre-trained/ |
| 19 | + |
| 20 | +## Supervised few-shot classification |
| 21 | +``` |
| 22 | +# miniImagenet, 5way 1shot, non-transductive |
| 23 | +$ python3 train.py --dataset mini --num_ways 5 --num_shots 1 --transductive False |
| 24 | +# miniImagenet, 5way 1shot, transductive |
| 25 | +$ python3 train.py --dataset mini --num_ways 5 --num_shots 1 --transductive True |
| 26 | +
|
| 27 | +# miniImagenet, 5way 5shot, non-transductive |
| 28 | +$ python3 trainer.py --dataset mini --num_ways 5 --num_shots 5 --trainsductive False |
| 29 | +# miniImagenet, 5way 5shot, transductive |
| 30 | +$ python3 trainer.py --dataset mini --num_ways 5 --num_shots 5 --trainsductive True |
| 31 | +# miniImagenet, 10way 5shot, transductive |
| 32 | +$ python3 trainer.py --dataset mini --num_ways 10 --num_shots 5 --meta_batch_size 20 --trainsductive True |
| 33 | +
|
| 34 | +# tieredImagenet, 5way 1shot, non-transductive |
| 35 | +$ python3 train.py --dataset tiered --num_ways 5 --num_shots 1 --meta_batch_size 100 --transductive False |
| 36 | +# miniImagenet, 5way 1shot, transductive |
| 37 | +$ python3 train.py --dataset tiered --num_ways 5 --num_shots 1 --meta_batch_size 100 --transductive True |
| 38 | +
|
| 39 | +# tieredImagenet, 5way 5shot, non-transductive |
| 40 | +$ python3 trainer.py --dataset tiered --num_ways 5 --num_shots 5 --trainsductive False |
| 41 | +# miniImagenet, 5way 5shot, transductive |
| 42 | +$ python3 trainer.py --dataset tiered --num_ways 5 --num_shots 5 --trainsductive True |
| 43 | +
|
| 44 | +``` |
| 45 | + |
| 46 | +## Semi-supervsied |
| 47 | +``` |
| 48 | +# miniImagenet, 5way 5shot, 20% labeled, transductive |
| 49 | +$ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --num_unlabeled 4 --transductive True |
| 50 | +
|
| 51 | +``` |
| 52 | + |
| 53 | +### Adapt Metric_NN, while Enc_NN is updated only in outer-loop |
| 54 | +``` |
| 55 | +# 5-way 5-shot, initilized with 5-way 5-shot pre-trained model (enc_nn + metric_nn) |
| 56 | +$ python3 trainer.py --config asset/config/mini-gnn-maml-N5S5-N5S5init-joint.ini --reinit 1 |
| 57 | +``` |
| 58 | + |
| 59 | +## Training (resnet-18, resnet-50) |
| 60 | +### Adapt Metric_NN, while Enc_NN is fixed |
| 61 | +``` |
| 62 | +# 5-way 5-shot, initilized with 64-way cls pre-trained enc_nn model (metric_nn is trained from scratch!) |
| 63 | +$ python3 trainer.py --config asset/config/mini-resnet18-gnn-maml-N5S5-64wayinit-scratch.ini --reinit 1 |
| 64 | +
|
| 65 | +# TODO: 5-way 5-shot, initilized with 5-way 5-shot pre-trained model (enc_nn + metric_nn) |
| 66 | +
|
| 67 | +``` |
| 68 | +### Adapt Metric_NN, while Enc_NN is updated only in outer-loop |
| 69 | +``` |
| 70 | +# 5-way 5-shot, initilized with 64-way cls pre-trained enc_nn model (metric_nn is trained from scratch!) |
| 71 | +$ python3 trainer.py --config asset/config/mini-resnet18-gnn-maml-N5S5-64wayinit-scratch-joint.ini --reinit 1 |
| 72 | +
|
| 73 | +# TODO: 5-way 5-shot, initilized with 5-way 5-shot pre-trained model (enc_nn + metric_nn) |
| 74 | +``` |
| 75 | + |
| 76 | +## Result |
| 77 | +- MiniImagenet, 5-way, 4convblock |
| 78 | + |
| 79 | +| Model | | | | |5-way Acc.| | | | | |
| 80 | +|-------------------------------------|----|---|-----|---|----------|--|--|-----|-----| |
| 81 | +| | |1-shot| || 2-shot|| |5-shot | | |
| 82 | +| |train|val|test|train|val|test|train|val|test| |
| 83 | +| MAML | -|-|48.70 | -|-|-| -|-|63.11 | |
| 84 | +| GNN | -|-|50.33 | -|-|-| -|-|66.41 | |
| 85 | +| MAML (our implementation) | 51.29|45.24|44.58 | 63.93|52.57|52.55| 74.50|60.99|61.97 | |
| 86 | +| GNN (our implementation) | 62.80|47.62|44.64 | 76.40|54.48|51.41| 82.00|60.37|60.45 | |
| 87 | +| GNN + MAML (N5S1init) | -|-|- | 69.64|52.59|50.63| 73.58|57.86|56.78 | |
| 88 | +| GNN + MAML (N5S2init) | -|-|- | 74.25|54.36|51.05| -|-|- | |
| 89 | +| GNN + MAML (N5S5init) | -|-|- | -|-|- | 81.83|60.13|59.23 | |
| 90 | + |
0 commit comments