
Commit bceaf6a

Update README.md and added configuration files
1 parent 9668d5b commit bceaf6a


120 files changed, +4136 −130 lines


Diff for: .gitignore

+1 −1

@@ -1,5 +1,5 @@
 __pycache__
 .ipynb_checkpoints
 cache
-wikitext-103
 .idea/
+results

Diff for: README.md

+124
# MicroNet: Team MIT-HAN-Lab

## News
Our work has been accepted by PMLR!

Hanrui and Zhongxia gave a [talk](https://slideslive.com/38922007/competition-track-day-13) (starting at 33:35) on the challenge at NeurIPS 2019 in Vancouver.

<img src="neurips_micronet_challenge/talk_photo.png" alt="drawing" width="240"/>

## Introduction
This codebase provides the code, configurations, and commands for our PMLR submission describing our work in the NeurIPS 2019 MicroNet Challenge on the [WikiText-103 Language Modeling task](https://micronet-challenge.github.io/index.html). The information for our submission to the NeurIPS 2019 MicroNet Challenge can be found [here](https://github.com/mit-han-lab/neurips-micronet/tree/master/neurips_micronet_challenge).

Team members: Zhongxia Yan, Hanrui Wang, Demi Guo, and Song Han.

Our work implements or improves upon the following methods, integrating them to create an efficient language model for the WikiText-103 task:
* [Transformer-XL](https://arxiv.org/abs/1901.02860)
* [Adaptive embedding and softmax](http://arxiv.org/abs/1809.10853)
* [Non-parametric cache](http://arxiv.org/abs/1612.04426)
* [Hebbian softmax](http://arxiv.org/abs/1803.10049)
* [Knowledge distillation with teacher annealing](https://arxiv.org/abs/1907.04829)
* Pruning: we use the [Distiller implementation](https://nervanasystems.github.io/distiller/algo_pruning.html#automated-gradual-pruner-agp) of Automated Gradual Pruning
* Quantization: we use the [Distiller implementation](https://nervanasystems.github.io/distiller/algo_quantization.html) of quantization-aware training with a symmetric range-based linear quantizer

## MicroNet
The MicroNet challenge website can be found [here](https://micronet-challenge.github.io/). Our best model, discussed below, achieves a MicroNet score of `0.0387`. Note that this is better than our score of `0.0475` on the MicroNet website: our previous evaluation contained an error that miscalculated the number of math operations.

## Pipeline
We show our pipeline with incremental performance results. Each row in each column is an ablation of our best configuration [`quantize_prune0.358_distill_8.3M_cache2000_hebbian_step175000_cache3000_bits9`](results/quantize_prune0.358_distill_8.3M_cache2000_hebbian_step175000_cache3000_bits9/config.yaml). The left column does not use compression techniques, while the right column does. From top to bottom, each stack displays the progression of techniques. Each row displays the associated metrics: parameters (top left), operations (top right), validation perplexity (bottom left), and estimated processing time (bottom right). Metrics are shown when they change from the previous row, with green for desirable changes and red for undesirable ones. Red rows represent core LM techniques, blue rows represent compression techniques, and gray rows represent cache search; joined rows represent joint training.

<p align="center"><img align="center" src="figures/pipeline.svg" width="500"/></p>

## Installation
We run our code on Python 3.6.8 and PyTorch 1.1.0+. We set up our environment using a mixture of Conda and Pip, though in theory Conda shouldn't be necessary. Our code has submodules, so make sure to use `--recursive` while cloning.
```bash
git clone --recursive https://github.com/mit-han-lab/neurips-micronet.git
cd neurips-micronet

# If you need to install conda first, follow the instructions from https://docs.conda.io/en/latest/miniconda.html
conda create -n micronet python=3.6
conda activate micronet

# Install distiller (pruning and quantization) requirements
pip install -r distiller/requirements.txt

# For mixed precision training with https://github.com/NVIDIA/apex
# This is not necessary, but some experiments may benefit from the larger batch sizes that mixed precision training allows
# Depending on which CUDA version your PyTorch uses, you may have to change the CUDA_HOME environment
# variable in the command below
cd apex && CUDA_HOME=/usr/local/cuda-10.0 pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

pip install enlighten future gitpython==3.1.2
```
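
If you want to sanity-check the environment before training, a quick check like the one below should confirm the PyTorch version and CUDA visibility. This is our suggestion, not part of the original setup steps.
```bash
# Print the installed PyTorch version and whether a CUDA device is visible
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```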

## Dataset
Our code in [`setup_data.ipynb`](setup_data.ipynb) directly downloads and preprocesses the [WikiText-103 dataset](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/). Just run all the cells!
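
If you prefer to run the notebook non-interactively, a command along these lines should work, assuming Jupyter's `nbconvert` is available in your environment; this is our suggestion rather than part of the original instructions.
```bash
# Execute every cell of setup_data.ipynb from the command line; the downloaded
# and preprocessed data are the side effects we care about.
jupyter nbconvert --to notebook --execute --output setup_data_executed.ipynb setup_data.ipynb
```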

## Models and Configurations
All of our configurations are already in the [`results/`](results/) directory. For example, our best configuration evaluated by the MicroNet criteria is [`results/quantize_prune0.358_distill_8.3M_cache2000_hebbian_step175000_cache3000_bits9`](results/quantize_prune0.358_distill_8.3M_cache2000_hebbian_step175000_cache3000_bits9/config.yaml). You can download our trained models from [here](https://www.dropbox.com/sh/8b37zkfvuyog4tu/AAB1wH9GgQgVgO1b7Lh0Pap4a?dl=0) into your `results/` directory. The necessary files to evaluate a particular configuration are
```bash
results/<configuration_name>/config.yaml  # readable hyperparameters that we use
results/<configuration_name>/cache_step<searched_checkpoint_step>_n<search_cache_size>.yaml  # locally searched cache parameters for the given checkpoint step and local search cache size
results/<configuration_name>/models/model-<step>.pth  # the step number, saved weights of the network, and saved weights of any optimizer
```

Configuration names are mostly intuitive. We use `attn129` and `attn257` to denote `C = 129` and `C = 257`, respectively; otherwise the default is `C = 97`. For quantization configurations such as `quantize_prune0.358_distill_8.3M_cache2000_hebbian_step175000_cache3000_bits9`, the first `cache<size>` refers to the training cache size, the second `cache<size>` refers to the local search cache size, and `step<number>` refers to the checkpoint step of the model before quantization.
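
As a worked example, here is our reading of the best configuration's name; interpreting `8.3M` as the approximate parameter count of the distilled student is our assumption rather than something stated above.
```bash
# quantize_prune0.358_distill_8.3M_cache2000_hebbian_step175000_cache3000_bits9
#   quantize     -> quantization is applied
#   prune0.358   -> pruned to sparsity 0.358
#   distill_8.3M -> knowledge distillation (student of roughly 8.3M parameters, our assumption)
#   cache2000    -> training cache size of 2000
#   hebbian      -> Hebbian softmax enabled
#   step175000   -> checkpoint step of the model before quantization
#   cache3000    -> local search cache size of 3000
#   bits9        -> quantized to 9 bits
```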

### Evaluation
To evaluate our trained model, make sure to download it as mentioned above, then go to [`micronet_evaluate.ipynb`](micronet_evaluate.ipynb) and substitute in the configuration name. This gives the validation and test perplexities as well as the number of parameters and math operations after pruning and quantization.

### Running the Pipeline
In general, we run the pipeline from the desired configuration directory (e.g. [`results/quantize_prune0.358_distill_8.3M_cache2000_hebbian_step175000_cache3000_bits9/`](results/quantize_prune0.358_distill_8.3M_cache2000_hebbian_step175000_cache3000_bits9/)).

#### Training the Core LM with or without Distillation
You can refer to [`results/`](results/) for examples of configurations for training with or without adaptive softmax, a training cache, Hebbian softmax, or distillation. You may also use this to train a teacher model. To start training, use an existing configuration directory or create a new one, then run
```bash
cd <configuration_directory>
# Make sure your directory has the correct config.yaml file

CUDA_VISIBLE_DEVICES=<device> python ../../main.py .
```
We recommend setting the `train_batch` hyperparameter in `config.yaml` to the largest value that fits in memory. If you'd like to use mixed precision training with [Apex](https://github.com/NVIDIA/apex) Amp, make sure to install Apex as described in the Installation section, then pass `opt_level=O1` as an argument to the command, as shown below. Note that the cache or Hebbian softmax sometimes become unstable with mixed precision training; in that case you can try to debug the code or just use full precision.
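
For example, a mixed precision run would look something like this; the trailing `key=value` style mirrors the other commands in this README, and the exact handling of `opt_level` is defined in `main.py`.
```bash
cd <configuration_directory>
# Same training command as above, with Apex Amp O1 mixed precision enabled
CUDA_VISIBLE_DEVICES=<device> python ../../main.py . opt_level=O1
```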

#### Generating Soft Labels for Distillation
If you'd like to use distillation and you have already trained a teacher model, you can generate the top 30 soft labels for the training set tokens by running this from the teacher's `<configuration_directory>`
```bash
cd <configuration_directory>
CUDA_VISIBLE_DEVICES=<device> python ../../gen_soft_labels.py .
```
Note that this takes around 40 GB of disk space and may take several hours.

#### Pruning
After you train a model for a configuration `<configuration_directory>` with the instructions above, you can run
```bash
cd <configuration_directory>
CUDA_VISIBLE_DEVICES=<device> python ../../setup_prune.py .
```
This automatically creates a new configuration directory `<prune_configuration_directory>` for you. Just follow the printed instructions to run pruning. Note that by default this uses the pruning configuration for sparsity `0.358`. You can also use the pruning configurations for sparsity `0.239` or `0.477` by replacing the `distiller_prune.yaml` in your `<prune_configuration_directory>` with [`distiller_prune0.239.yaml`](distiller_prune0.239.yaml) or [`distiller_prune0.477.yaml`](distiller_prune0.477.yaml), as in the example below.
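
For instance, to switch to the `0.477` sparsity schedule, a copy along these lines should work from the repository root, assuming the alternative YAML files live there as the links above suggest.
```bash
# Overwrite the default pruning schedule with the 0.477-sparsity one
cp distiller_prune0.477.yaml <prune_configuration_directory>/distiller_prune.yaml
```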

#### Cache Local Search
You may run local search on either a pruned model or an unpruned (but trained) model. This generates a new `<configuration_directory>/cache_step<searched_checkpoint_step>_n<search_cache_size>.yaml` with your searched cache parameters.
```bash
cd <configuration_directory>
CUDA_VISIBLE_DEVICES=<device> python ../../cache_search.py . n_cache=<search_cache_size>
```
By default this runs the search on the checkpoint trained for the largest number of steps. If you want to run local search on an arbitrary saved checkpoint, add the argument `step=<trained_step_that_you_want_to_search>`.
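
For example, searching a cache of size 3000 on the checkpoint from step 175000 (the values appearing in our best configuration's name) would look like this.
```bash
cd <configuration_directory>
# Local cache search for a specific checkpoint step
CUDA_VISIBLE_DEVICES=<device> python ../../cache_search.py . n_cache=3000 step=175000
```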

#### Quantization
You may quantize a trained model, a pruned model, or a locally searched model.
```bash
cd <configuration_directory>
CUDA_VISIBLE_DEVICES=<device> python ../../setup_quantize.py . bits=<bits_to_quantize_to>
```
By default this quantizes the checkpoint with the largest step number, using the `n_cache` value in `<configuration_directory>/config.yaml`. If you'd like to quantize with a different cache size (possibly with locally searched parameters), add the argument `n_cache=<cache_size>`. If you'd like to quantize a different step, add the argument `step=<trained_step_that_you_want_to_quantize>`; by default this will use your searched cache parameters if you previously ran local search at that step.
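
Putting the options together, quantizing the step-175000 checkpoint to 9 bits with a cache size of 3000 (mirroring our best configuration's name) would look like this.
```bash
cd <configuration_directory>
# 9-bit quantization of a specific checkpoint with an explicit cache size
CUDA_VISIBLE_DEVICES=<device> python ../../setup_quantize.py . bits=9 n_cache=3000 step=175000
```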

## Common Questions / Errors
> AssertionError: Training already exists

Usually this happens when you try to run training again after a previous training run crashed. This is because we have a guard against accidentally running multiple trainings at the same time. Just remove the `<configuration_directory>/is_training` guard file and you should be fine.
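
Concretely, with the guard file named above:
```bash
# Remove the stale guard file left behind by the crashed run
rm <configuration_directory>/is_training
```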

> Why is the perplexity `nan`?

To prevent obscenely large perplexities at the beginning of training, we set perplexities greater than `e ** 5` (roughly 148) to `nan`.

## FAQ
If you have any further questions about our submission, please don't hesitate to reach out to us through GitHub Issues :)

Thanks!

Diff for: cache/sorted_vocab.npy

46 MB
Binary file not shown.

Diff for: cache_analysis.ipynb

+11 −23

@@ -5,11 +5,19 @@
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
-    "end_time": "2020-05-14T03:54:21.752556Z",
-    "start_time": "2020-05-14T03:54:08.639491Z"
+    "end_time": "2020-05-14T23:29:14.676749Z",
+    "start_time": "2020-05-14T23:28:59.194229Z"
    }
   },
-  "outputs": [],
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "The history saving thread hit an unexpected error (DatabaseError('database disk image is malformed',)).History will not be written to the database.\n"
+    ]
+   }
+  ],
   "source": [
    "from u import *\n",
    "from model import Transformer\n",
@@ -21,26 +29,6 @@
    "vocab = (Cache / 'sorted_vocab.npy').load()"
   ]
  },
- {
-  "cell_type": "code",
-  "execution_count": 7,
-  "metadata": {
-   "ExecuteTime": {
-    "end_time": "2020-05-10T20:41:36.058609Z",
-    "start_time": "2020-05-10T20:41:29.462290Z"
-   }
-  },
-  "outputs": [],
-  "source": [
-   "for dir in Res.ls()[0]:\n",
-   "    conf = (dir / 'config.yaml').load()\n",
-   "    if 4096 % conf['n_seq'] == 0:\n",
-   "        conf['eval_chunk'] = 4096\n",
-   "# for k in 'hebbian', 'hebbian_T', 'hebbian_gamma', 'step', 'out_res':\n",
-   "#     conf.pop(k, None)\n",
-   "    (dir / 'config.yaml').save(conf)"
-  ]
- },
 {
  "cell_type": "code",
  "execution_count": 2,

Diff for: data.py

+1 −1

@@ -49,7 +49,7 @@ def __iter__(self):
             yield batch_i.astype(np.int64)

     def __len__(self):
-        return len(range(0, self.span_i, c.eval_chunk))
+        return len(range(0, self.span_i, self.c.eval_chunk))

 class DistillationSampleIterator:
     def __init__(self, c, batch_size):

Diff for: figures/pipeline.svg

+2
