# **Neural Turing Machine** (NTM) &
# **Differentiable Neural Computer** (DNC) with
# **pytorch** & **visdom**
*******


* Sample on-line plotting while training (avg loss) / testing (write/read weights & memory) DNC on the repeat-copy task:
<img src="/assets/dnc_repeat_copy_train.png" width="205"/> <img src="/assets/dnc_repeat_copy_test.png" width="600"/>


* Sample log output while training DNC on the repeat-copy task (we currently use ```WARNING``` as the logging level to suppress the ```INFO``` printouts from visdom):
```bash
[WARNING ] (MainProcess) <===================================>
[WARNING ] (MainProcess) bash$: python -m visdom.server
[WARNING ] (MainProcess) http://localhost:8097/env/daim_17051000
[WARNING ] (MainProcess) <===================================> Agent:
[WARNING ] (MainProcess) <-----------------------------======> Env:
[WARNING ] (MainProcess) Creating {repeat-copy | } w/ Seed: 123
[WARNING ] (MainProcess) Word {length}: {4}
[WARNING ] (MainProcess) Words # {min, max}: {1, 2}
[WARNING ] (MainProcess) Repeats {min, max}: {1, 2}
[WARNING ] (MainProcess) <-----------------------------======> Circuit: {Controller, Accessor}
[WARNING ] (MainProcess) <--------------------------------===> Controller:
[WARNING ] (MainProcess) LSTMController (
  (in_2_hid): LSTMCell(70, 64, bias=1)
)
[WARNING ] (MainProcess) <--------------------------------===> Accessor: {WriteHead, ReadHead, Memory}
[WARNING ] (MainProcess) <-----------------------------------> WriteHeads: {1 heads}
[WARNING ] (MainProcess) DynamicWriteHead (
  (hid_2_key): Linear (64 -> 16)
  (hid_2_beta): Linear (64 -> 1)
  (hid_2_alloc_gate): Linear (64 -> 1)
  (hid_2_write_gate): Linear (64 -> 1)
  (hid_2_erase): Linear (64 -> 16)
  (hid_2_add): Linear (64 -> 16)
)
[WARNING ] (MainProcess) <-----------------------------------> ReadHeads: {4 heads}
[WARNING ] (MainProcess) DynamicReadHead (
  (hid_2_key): Linear (64 -> 64)
  (hid_2_beta): Linear (64 -> 4)
  (hid_2_free_gate): Linear (64 -> 4)
  (hid_2_read_mode): Linear (64 -> 12)
)
[WARNING ] (MainProcess) <-----------------------------------> Memory: {16(batch_size) x 16(mem_hei) x 16(mem_wid)}
[WARNING ] (MainProcess) <-----------------------------======> Circuit: {Overall Architecture}
[WARNING ] (MainProcess) DNCCircuit (
  (controller): LSTMController (
    (in_2_hid): LSTMCell(70, 64, bias=1)
  )
  (accessor): DynamicAccessor (
    (write_heads): DynamicWriteHead (
      (hid_2_key): Linear (64 -> 16)
      (hid_2_beta): Linear (64 -> 1)
      (hid_2_alloc_gate): Linear (64 -> 1)
      (hid_2_write_gate): Linear (64 -> 1)
      (hid_2_erase): Linear (64 -> 16)
      (hid_2_add): Linear (64 -> 16)
    )
    (read_heads): DynamicReadHead (
      (hid_2_key): Linear (64 -> 64)
      (hid_2_beta): Linear (64 -> 4)
      (hid_2_free_gate): Linear (64 -> 4)
      (hid_2_read_mode): Linear (64 -> 12)
    )
  )
  (hid_to_out): Linear (128 -> 5)
)
[WARNING ] (MainProcess) No Pretrained Model. Will Train From Scratch.
[WARNING ] (MainProcess) <===================================> Training ...
[WARNING ] (MainProcess) Reporting @ Step: 500 | Elapsed Time: 30.609361887
[WARNING ] (MainProcess) Training Stats: avg_loss: 0.014866309287
[WARNING ] (MainProcess) Evaluating @ Step: 500
[WARNING ] (MainProcess) Evaluation Took: 1.6457400322
[WARNING ] (MainProcess) Iteration: 500; loss_avg: 0.0140423600748
[WARNING ] (MainProcess) Saving Model @ Step: 500: /home/zhang/ws/17_ws/pytorch-dnc/models/daim_17051000.pth ...
[WARNING ] (MainProcess) Saved Model @ Step: 500: /home/zhang/ws/17_ws/pytorch-dnc/models/daim_17051000.pth.
[WARNING ] (MainProcess) Resume Training @ Step: 500
...
```
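To get the same behaviour in a script of your own, raising the logging level to ```WARNING``` with the standard ```logging``` module is enough to hide visdom's ```INFO``` chatter. A minimal sketch (illustrative, not this repo's exact logger setup):

```python
import logging

# Illustrative only: log at WARNING so visdom's INFO printouts are suppressed,
# while our own WARNING-level messages (like the ones above) still show up.
logging.basicConfig(level=logging.WARNING,
                    format="[%(levelname)-8s] (%(processName)s) %(message)s")
logger = logging.getLogger(__name__)
logger.warning("<===================================> Training ...")
```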
*******


## What is included?
This repo currently contains the following algorithms:

- Neural Turing Machines (NTM) [[1]](https://arxiv.org/abs/1410.5401)
- Differentiable Neural Computers (DNC) [[2]](http://www.nature.com/nature/journal/v538/n7626/full/nature20101.html)

Tasks (a toy repeat-copy sample is sketched right after this list):
- copy
- repeat-copy

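For readers new to these tasks, here is a toy repeat-copy sample generator. It uses one common encoding and the parameter ranges from the log above; the repo's actual input/target format is defined by its ```Env``` classes:

```python
import numpy as np

# Toy repeat-copy sample (illustrative encoding, not necessarily this repo's exact format):
# the input is a short sequence of random binary words plus a repeat count,
# and the target is that sequence repeated the requested number of times.
rng = np.random.RandomState(123)

word_len    = 4                      # Word {length}: {4}
num_words   = rng.randint(1, 3)      # Words # {min, max}: {1, 2}
num_repeats = rng.randint(1, 3)      # Repeats {min, max}: {1, 2}

words  = rng.randint(0, 2, size=(num_words, word_len))    # the sequence to memorise
target = np.tile(words, (num_repeats, 1))                  # ... repeated num_repeats times

print(words)
print(num_repeats)
print(target)
```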
## Code structure & Naming conventions
NOTE: we follow the same code structure as [pytorch-rl](https://github.com/jingweiz/pytorch-rl) so that code can be easily transplanted between the two repos.
* ```./utils/factory.py```
> We suggest users start from ```./utils/factory.py```,
  where all the integrated ```Env```, ```Circuit``` and ```Agent``` classes are collected into ```Dict```'s.
  All of the core classes are implemented in ```./core/```.
  The factory pattern in ```./utils/factory.py``` keeps the code clean:
  no matter which type of ```Circuit``` you want to train,
  or which type of ```Env``` you want to train it on,
  you only need to modify a few parameters in ```./utils/options.py```,
  and ```./main.py``` does the rest (NOTE: ```./main.py``` itself never needs to be modified);
  a minimal sketch of this pattern is given right after this list.
* naming conventions
> To keep the code clean and readable, variables are named with the following suffixes:
> * ```*_vb```: ```torch.autograd.Variable```'s or a list of such objects
> * ```*_ts```: ```torch.Tensor```'s or a list of such objects
> * otherwise: normal python datatypes

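As referenced above, here is a minimal sketch of that factory pattern. The registry names (```EnvDict```, ```CircuitDict```, ```AgentDict```) and the stand-in classes below are purely illustrative; the real registries live in ```./utils/factory.py``` and the real classes in ```./core/```:

```python
# Purely illustrative sketch of the factory pattern used in this repo
# (names and signatures are made up; see ./utils/factory.py for the real thing).

class RepeatCopyEnv:            # stand-in for an Env class implemented in ./core/
    pass

class DNCCircuit:               # stand-in for a Circuit class implemented in ./core/
    pass

class SLAgent:                  # stand-in for an Agent class implemented in ./core/
    def __init__(self, env_prototype, circuit_prototype):
        self.env = env_prototype()
        self.circuit = circuit_prototype()

# registries mapping option strings to classes, analogous to the Dict's in factory.py
EnvDict     = {"repeat-copy": RepeatCopyEnv}
CircuitDict = {"dnc":         DNCCircuit}
AgentDict   = {"sl":          SLAgent}

# what a generic main.py can then do, regardless of which entries were chosen:
env_type, circuit_type, agent_type = "repeat-copy", "dnc", "sl"   # would come from options.py
agent = AgentDict[agent_type](EnvDict[env_type], CircuitDict[circuit_type])
print(type(agent.env).__name__, type(agent.circuit).__name__)
```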
|
## Dependencies
- Python 2.7
- [PyTorch](http://pytorch.org/)
- [Visdom](https://github.com/facebookresearch/visdom)
*******

## How to run:
You only need to modify some parameters in ```./utils/options.py``` to train a new configuration.

* Configure your training in ```./utils/options.py``` (an illustrative sketch of such a configuration follows this list):
> * ```line 12```: add an entry into ```CONFIGS``` to define your training (```agent_type```, ```env_type```, ```game```, ```circuit_type```)
> * ```line 28```: choose the entry you just added
> * ```line 24-25```: fill in your machine/cluster ID (```MACHINE```) and timestamp (```TIMESTAMP```) to define your training signature (```MACHINE_TIMESTAMP```);
  the model file and the log file of this training will be saved under this signature (```./models/MACHINE_TIMESTAMP.pth``` & ```./logs/MACHINE_TIMESTAMP.log``` respectively).
  The visdom visualization will also be displayed under this signature (first start the visdom server by typing ```python -m visdom.server &``` in bash, then open ```http://localhost:8097/env/MACHINE_TIMESTAMP``` in your browser).
> * ```line 28```: to train a model, set ```mode=1``` (training visualization will be under ```http://localhost:8097/env/MACHINE_TIMESTAMP```); to test the model of the current training, simply set ```mode=2``` (testing visualization will be under ```http://localhost:8097/env/MACHINE_TIMESTAMP_test```).
|
* Run:
> ```python main.py```

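For orientation, here is a hypothetical illustration of what the configuration step boils down to. The list layout, field order and most values below are made up for the example (only ```copy```/```repeat-copy```, ```ntm```/```dnc``` and the ```daim_17051000``` signature appear elsewhere in this README); the authoritative definitions are in ```./utils/options.py```:

```python
# Hypothetical sketch of the configuration step (not the actual ./utils/options.py):
# one CONFIGS entry per experiment, selected by index, plus a training signature.

CONFIGS = [
    # [agent_type, env_type,      game, circuit_type]   <- field order is illustrative
    ["sl",         "copy",        "",   "ntm"],          # entry 0
    ["sl",         "repeat-copy", "",   "dnc"],          # entry 1
]
CONFIG = 1                  # "choose the entry you just added"
MODE   = 1                  # 1: train | 2: test

MACHINE   = "daim"          # your machine/cluster ID
TIMESTAMP = "17051000"      # your timestamp
SIGNATURE = MACHINE + "_" + TIMESTAMP            # -> "daim_17051000"

print("./models/" + SIGNATURE + ".pth")          # where the model is saved
print("./logs/"   + SIGNATURE + ".log")          # where the log goes
print("http://localhost:8097/env/" + SIGNATURE)  # visdom env when mode=1
```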
*******


## Implementation Notes:
The difference between ```NTM``` & ```DNC``` is stated as follows in the
```DNC``` [2] paper:
> Comparison with the neural Turing machine. The neural Turing machine (NTM) was
the predecessor to the DNC described in this work. It used a similar
architecture of neural network controller with read–write access to a memory
matrix, but differed in the access mechanism used to interface with the memory.
In the NTM, content-based addressing was combined with location-based addressing
to allow the network to iterate through memory locations in order of their
indices (for example, location n followed by n+1 and so on). This allowed the
network to store and retrieve temporal sequences in contiguous blocks of memory.
However, there were several drawbacks. First, the NTM has no mechanism to ensure
that blocks of allocated memory do not overlap and interfere—a basic problem of
computer memory management. Interference is not an issue for the dynamic memory
allocation used by DNCs, which provides single free locations at a time,
irrespective of index, and therefore does not require contiguous blocks. Second,
the NTM has no way of freeing locations that have already been written to and,
hence, no way of reusing memory when processing long sequences. This problem is
addressed in DNCs by the free gates used for de-allocation. Third, sequential
information is preserved only as long as the NTM continues to iterate through
consecutive locations; as soon as the write head jumps to a different part of
the memory (using content-based addressing) the order of writes before and after
the jump cannot be recovered by the read head. The temporal link matrix used by
DNCs does not suffer from this problem because it tracks the order in which
writes were made.

We therefore made an effort to put the two together in one combined codebase.
The classes implemented have the following hierarchy:
* Agent
  * Env
  * Circuit
    * Controller
    * Accessor
      * WriteHead
      * ReadHead
      * Memory

The part where ```NTM``` & ```DNC``` differ is the ```Accessor```: in the
code, ```NTM``` uses the ```StaticAccessor``` (maybe not the most appropriate
name, but we use it to keep the code consistent) and ```DNC``` uses the
```DynamicAccessor```. Both ```Accessor``` classes use ```_content_focus()```
and ```_location_focus()``` (again, not an ideal name for ```DNC```, but kept
for consistency). ```_content_focus()``` is the same for both classes, whereas
```_location_focus()``` for ```DNC``` is considerably more involved: it
additionally uses ```dynamic allocation``` for write and the ```temporal link```
for read. These focus (or attention) mechanisms are implemented in the
```Head``` classes; each focus outputs a ```weight``` vector for its ```head```
(write/read), and those ```weight``` vectors are then used in ```_access()```
to interact with the ```external memory```.

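To make this concrete, below is a rough PyTorch-style sketch of the content focus, the DNC-flavoured location focuses and a read access. It follows the standard NTM/DNC equations rather than this repo's exact ```Head``` code; shapes are batch-free and all names are illustrative:

```python
import torch
import torch.nn.functional as F

def content_focus(memory, key, beta):
    """Content focus: cosine similarity between the key and every memory row,
    sharpened by the strength beta and normalised into a weight vector."""
    # memory: (N, W), key: (W,), beta: scalar
    sim = F.cosine_similarity(memory, key.unsqueeze(0).expand_as(memory), dim=1)  # (N,)
    return F.softmax(beta * sim, dim=0)                                           # (N,)

def dnc_write_location_focus(content_w, alloc_w, alloc_gate, write_gate):
    """DNC write focus: blend the content weights with dynamic-allocation weights."""
    return write_gate * (alloc_gate * alloc_w + (1.0 - alloc_gate) * content_w)

def dnc_read_location_focus(content_w, link, prev_read_w, read_mode):
    """DNC read focus: mix backward / content / forward modes via the temporal link matrix."""
    forward_w  = link @ prev_read_w        # follow the order in which writes were made
    backward_w = link.t() @ prev_read_w    # ... or walk that order backwards
    return read_mode[0] * backward_w + read_mode[1] * content_w + read_mode[2] * forward_w

def read_access(memory, read_w):
    """_access() for a read head boils down to a weighted sum over the memory rows."""
    return read_w @ memory                 # (W,)

# tiny usage example with random tensors (N locations of width W, as in mem_hei x mem_wid)
N, W = 16, 16
memory = torch.rand(N, W)
w_c = content_focus(memory, key=torch.rand(W), beta=torch.tensor(5.0))
print(read_access(memory, w_c).shape)      # torch.Size([16])
```
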
## A side note:
The structure of ```Env``` might look strange, as this class was originally
designed for ```reinforcement learning``` settings as in
[pytorch-rl](https://github.com/jingweiz/pytorch-rl); here we use it to
provide datasets for ```supervised learning```, so ```reward```,
```action``` and ```terminal``` are always left blank in this repo.
*******


## Repos we referred to during the development of this repo:
* [deepmind/dnc](https://github.com/deepmind/dnc)
* [ypxie/pytorch-NeuCom](https://github.com/ypxie/pytorch-NeuCom)
* [bzcheeseman/pytorch-EMM](https://github.com/bzcheeseman/pytorch-EMM)
* [DoctorTeeth/diffmem](https://github.com/DoctorTeeth/diffmem)
* [kaishengtai/torch-ntm](https://github.com/kaishengtai/torch-ntm)
* [Mostafa-Samir/DNC-tensorflow](https://github.com/Mostafa-Samir/DNC-tensorflow)