This repository contains a PyTorch implementation of the paper "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks" by Jonathan Frankle and Michael Carbin. It can be easily adapted to any model or dataset.
Install the requirements:

`pip3 install -r requirements.txt`

Then run the code, for example:

`python3 main.py --prune_type=lt --arch_type=fc1 --dataset=mnist --prune_percent=10 --prune_iterations=35`
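With `--prune_percent=10` and many `--prune_iterations`, the network becomes progressively sparser. The sketch below gives a rough guide to how sparse it gets. It assumes that each cycle removes `prune_percent`% of the weights that survived the previous cycle, which is the iterative magnitude pruning described in the paper. It does not model whether the first of the `--prune_iterations` cycles trains the dense network before any pruning happens; that is a detail of `main.py`. `remaining_fraction` is a hypothetical helper and is not part of the repository.

```python
def remaining_fraction(prune_percent: float, rounds: int) -> float:
    """Fraction of weights still unpruned after `rounds` pruning cycles.

    Assumes each cycle removes `prune_percent` % of the weights that
    survived the previous cycle, so the survivors shrink geometrically.
    """
    if not 0 <= prune_percent < 100:
        raise ValueError("prune_percent must be in [0, 100)")
    return (1.0 - prune_percent / 100.0) ** rounds


if __name__ == "__main__":
    # Show how quickly a 10 %-per-cycle schedule sparsifies the network.
    for rounds in (0, 1, 5, 10, 20, 34, 35):
        pct = 100 * remaining_fraction(10, rounds)
        print(f"after {rounds:2d} cycles: {pct:6.2f}% of weights remain")
```

Each cycle removes a fixed fraction of the survivors rather than of the original weights. Later cycles therefore remove fewer weights in absolute terms, which is why many cycles are needed to reach high sparsity.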
The command-line options are:

- `--prune_type`: Type of pruning. Options: `lt` (Lottery Ticket Hypothesis), `reinit` (random reinitialization). Default: `lt`.
- `--arch_type`: Type of architecture. Options: `fc1` (simple fully connected network), `lenet5` (LeNet5), `AlexNet` (AlexNet), `resnet18` (ResNet18), `vgg16` (VGG16). Default: `fc1`.
- `--dataset`: Choice of dataset. Options: `mnist`, `fashionmnist`, `cifar10`, `cifar100`. Default: `mnist`.
- `--prune_percent`: Percentage of weights to be pruned after each cycle. Default: `10`.
- `--prune_iterations`: Number of pruning cycles to perform. Default: `35`.
- `--lr`: Learning rate. Default: `1.2e-3`.
- `--batch_size`: Batch size. Default: `60`.
- `--end_iter`: Number of epochs. Default: `100`.
- `--print_freq`: Frequency for printing accuracy and loss. Default: `1`.
- `--valid_freq`: Frequency of validation. Default: `1`.
- `--gpu`: Which GPU the program should use. Default: `0`.
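A single pruning round on one weight array shows how the two `--prune_type` options differ. This is a NumPy sketch, not the repository's code, and `prune_round` is a hypothetical helper. The threshold is a percentile of the surviving weight magnitudes, following the magnitude pruning the paper describes. The `reinit` option is approximated here by drawing fresh Gaussian weights with the same standard deviation as the original initialization. That distribution is an assumption; the repository may re-run its own initializer instead.

```python
import numpy as np


def prune_round(weights, mask, init_weights, prune_percent, prune_type="lt", rng=None):
    """One magnitude-pruning round for a single weight array (illustrative sketch).

    weights:      weights after training in this cycle
    mask:         0/1 array, 1 where a weight is still alive
    init_weights: the weights the network was initialized with (before any training)
    Returns (weights to start the next cycle from, updated mask).
    """
    # Threshold = prune_percent-th percentile of the magnitudes of surviving weights only.
    alive = np.abs(weights[mask == 1])
    threshold = np.percentile(alive, prune_percent)
    # Drop surviving weights below the threshold; already-pruned weights stay pruned.
    new_mask = np.where(np.abs(weights) < threshold, 0, mask)

    if prune_type == "lt":
        # Lottery ticket: rewind survivors to their original initial values.
        restart = init_weights
    elif prune_type == "reinit":
        # Random-reinitialization baseline (assumed Gaussian, same std as the original init).
        rng = rng if rng is not None else np.random.default_rng()
        restart = rng.normal(0.0, init_weights.std(), size=weights.shape)
    else:
        raise ValueError(f"unknown prune_type: {prune_type!r}")

    return restart * new_mask, new_mask


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    init = rng.normal(0.0, 0.1, size=1000)
    trained = init + rng.normal(0.0, 0.05, size=1000)  # stand-in for a training run
    mask = np.ones_like(init)
    w_next, mask = prune_round(trained, mask, init, prune_percent=10)
    print(f"kept {int(mask.sum())} of {mask.size} weights")
```

In the full procedure, the returned weights are trained again and the round repeats for every pruning cycle. The paper reports that subnetworks rewound to their original initialization (`lt`) train better than randomly reinitialized ones (`reinit`) at the same sparsity.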
- Adding a new architecture:
  - For example, suppose you want to add an architecture named `new_model` that is compatible with the `mnist` dataset.
  - Go to the `/archs/mnist/` directory and create a file named `new_model.py`.
  - Paste your PyTorch-compatible model into `new_model.py`.
  - IMPORTANT: Make sure the input size, number of classes, number of channels and batch size in your `new_model.py` match the dataset you are adding it for (in this case, `mnist`).
  - Open `main.py`, go to line 36 and look for the comment `# Data Loader`. Find your dataset (in this case, `mnist`) and add `new_model` to the end of the line `from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet`, so that it reads `from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet, new_model`.
  - Go to line 82 and add `elif args.arch_type == "new_model": model = new_model.new_model_name().to(device)`. Here, `new_model_name()` is the name of the model you defined inside `new_model.py`.
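For illustration, `new_model.py` might contain something like the following minimal sketch. It is a small hypothetical CNN for 1x28x28 MNIST images and 10 classes. The class name `new_model_name` is the placeholder used in the `main.py` line above. The layer sizes are arbitrary choices and are not taken from the repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class new_model_name(nn.Module):
    """Hypothetical small CNN for MNIST: 1x28x28 grayscale inputs, 10 classes."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)   # 1 input channel (grayscale)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc = nn.Linear(32 * 7 * 7, num_classes)              # 28 -> 14 -> 7 after two 2x2 pools

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # -> 16 x 14 x 14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # -> 32 x 7 x 7
        x = torch.flatten(x, 1)                     # keep the batch dimension
        return self.fc(x)
```

A quick shape check with the default batch size catches input-size or class-count mismatches before a long run: `new_model_name()(torch.zeros(60, 1, 28, 28))` should produce a tensor of shape `(60, 10)`.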
- Adding a new dataset:
  - For example, suppose you want to add a dataset named `new_dataset` that is compatible with the `fc1` architecture.
  - Go to `/archs` and create a directory named `new_dataset`.
  - Go to `/archs/new_dataset/` and add a file named `fc1.py`, or copy it from an existing dataset folder.
  - IMPORTANT: Make sure the input size, number of classes, number of channels and batch size in your `fc1.py` match the dataset you are adding (in this case, `new_dataset`).
  - Open `main.py`, go to line 58 and add a branch `elif args.dataset == "new_dataset":` whose body is `traindataset = datasets.new_dataset('../data', train=True, download=True, transform=transform)`, `testdataset = datasets.new_dataset('../data', train=False, transform=transform)` and `from archs.new_dataset import fc1`.
  - Note that, as of now, you can only add datasets that are natively available in PyTorch (`torchvision.datasets`).
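As a sketch, here is one way to write `archs/new_dataset/fc1.py` so that its input layer is derived from the dataset's image shape, which makes the IMPORTANT check above easier to satisfy. By default it assumes `new_dataset` has 3-channel 32x32 images and 10 classes (CIFAR-10-like). The 300-100 hidden layers follow the LeNet-300-100 network the paper uses for MNIST and may differ from the repository's own `fc1.py`.

```python
import torch
import torch.nn as nn


class fc1(nn.Module):
    """Fully connected 300-100 network with a configurable input shape (sketch)."""

    def __init__(self, in_channels: int = 3, image_size: int = 32, num_classes: int = 10):
        super().__init__()
        # Flattened input size must match the dataset's images exactly.
        in_features = in_channels * image_size * image_size
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 300),
            nn.ReLU(inplace=True),
            nn.Linear(300, 100),
            nn.ReLU(inplace=True),
            nn.Linear(100, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten (N, C, H, W) images to (N, C*H*W) before the linear layers.
        return self.classifier(torch.flatten(x, 1))
```

For a 1-channel 28x28 dataset, you would construct it as `fc1(in_channels=1, image_size=28, num_classes=10)` instead.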
- Plotting combined graphs:
  - Go to `combine_plots.py` and add or remove the datasets/archs whose combined plots you want to generate. This assumes you have already run `main.py` for those datasets/archs and produced the weights.
  - Run `python3 combine_plots.py`.
  - Go to `/plots/lt/combined_plots/` to see the graphs.
Kindly raise an issue if you have any problem with the instructions.
Datasets and architectures that have already been tested:

| Dataset | fc1 | LeNet5 | AlexNet | VGG16 | ResNet18 |
|---|---|---|---|---|---|
| MNIST | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| CIFAR10 | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| FashionMNIST | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| CIFAR100 | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
Lottery-Ticket-Hypothesis-in-Pytorch
├── archs
│ ├── cifar10
│ │ ├── AlexNet.py
│ │ ├── densenet.py
│ │ ├── fc1.py
│ │ ├── LeNet5.py
│ │ ├── resnet.py
│ │ └── vgg.py
│ ├── cifar100
│ │ ├── AlexNet.py
│ │ ├── fc1.py
│ │ ├── LeNet5.py
│ │ ├── resnet.py
│ │ └── vgg.py
│ └── mnist
│ ├── AlexNet.py
│ ├── fc1.py
│ ├── LeNet5.py
│ ├── resnet.py
│ └── vgg.py
├── combine_plots.py
├── dumps
├── main.py
├── plots
├── README.md
├── requirements.txt
├── saves
└── utils.py
Parts of the code were borrowed from ktkth5.
Open a new issue or submit a pull request if you are facing any difficulty with the codebase or want to contribute to it.
