Commit 668d55c

initial commit


43 files changed, +2468 -0 lines changed

.gitignore

+4
*.pyc
./dnc/*.pyc
./dnc/agents/*.pyc
./utils/*.pyc

README.md

+199
# **Neural Turing Machine** (NTM) &
# **Differentiable Neural Computer** (DNC) with
# **pytorch** & **visdom**
*******

* Sample on-line plotting while training (avg loss) / testing (write/read weights & memory) DNC on the repeat-copy task:
<img src="/assets/dnc_repeat_copy_train.png" width="205"/> <img src="/assets/dnc_repeat_copy_test.png" width="600"/>


* Sample logging output while training DNC on the repeat-copy task (we currently use ```WARNING``` as the logging level to suppress the ```INFO``` printouts from visdom):
```bash
[WARNING ] (MainProcess) <===================================>
[WARNING ] (MainProcess) bash$: python -m visdom.server
[WARNING ] (MainProcess) http://localhost:8097/env/daim_17051000
[WARNING ] (MainProcess) <===================================> Agent:
[WARNING ] (MainProcess) <-----------------------------======> Env:
[WARNING ] (MainProcess) Creating {repeat-copy | } w/ Seed: 123
[WARNING ] (MainProcess) Word {length}: {4}
[WARNING ] (MainProcess) Words # {min, max}: {1, 2}
[WARNING ] (MainProcess) Repeats {min, max}: {1, 2}
[WARNING ] (MainProcess) <-----------------------------======> Circuit: {Controller, Accessor}
[WARNING ] (MainProcess) <--------------------------------===> Controller:
[WARNING ] (MainProcess) LSTMController (
  (in_2_hid): LSTMCell(70, 64, bias=1)
)
[WARNING ] (MainProcess) <--------------------------------===> Accessor: {WriteHead, ReadHead, Memory}
[WARNING ] (MainProcess) <-----------------------------------> WriteHeads: {1 heads}
[WARNING ] (MainProcess) DynamicWriteHead (
  (hid_2_key): Linear (64 -> 16)
  (hid_2_beta): Linear (64 -> 1)
  (hid_2_alloc_gate): Linear (64 -> 1)
  (hid_2_write_gate): Linear (64 -> 1)
  (hid_2_erase): Linear (64 -> 16)
  (hid_2_add): Linear (64 -> 16)
)
[WARNING ] (MainProcess) <-----------------------------------> ReadHeads: {4 heads}
[WARNING ] (MainProcess) DynamicReadHead (
  (hid_2_key): Linear (64 -> 64)
  (hid_2_beta): Linear (64 -> 4)
  (hid_2_free_gate): Linear (64 -> 4)
  (hid_2_read_mode): Linear (64 -> 12)
)
[WARNING ] (MainProcess) <-----------------------------------> Memory: {16(batch_size) x 16(mem_hei) x 16(mem_wid)}
[WARNING ] (MainProcess) <-----------------------------======> Circuit: {Overall Architecture}
[WARNING ] (MainProcess) DNCCircuit (
  (controller): LSTMController (
    (in_2_hid): LSTMCell(70, 64, bias=1)
  )
  (accessor): DynamicAccessor (
    (write_heads): DynamicWriteHead (
      (hid_2_key): Linear (64 -> 16)
      (hid_2_beta): Linear (64 -> 1)
      (hid_2_alloc_gate): Linear (64 -> 1)
      (hid_2_write_gate): Linear (64 -> 1)
      (hid_2_erase): Linear (64 -> 16)
      (hid_2_add): Linear (64 -> 16)
    )
    (read_heads): DynamicReadHead (
      (hid_2_key): Linear (64 -> 64)
      (hid_2_beta): Linear (64 -> 4)
      (hid_2_free_gate): Linear (64 -> 4)
      (hid_2_read_mode): Linear (64 -> 12)
    )
  )
  (hid_to_out): Linear (128 -> 5)
)
[WARNING ] (MainProcess) No Pretrained Model. Will Train From Scratch.
[WARNING ] (MainProcess) <===================================> Training ...
[WARNING ] (MainProcess) Reporting @ Step: 500 | Elapsed Time: 30.609361887
[WARNING ] (MainProcess) Training Stats: avg_loss: 0.014866309287
[WARNING ] (MainProcess) Evaluating @ Step: 500
[WARNING ] (MainProcess) Evaluation Took: 1.6457400322
[WARNING ] (MainProcess) Iteration: 500; loss_avg: 0.0140423600748
[WARNING ] (MainProcess) Saving Model @ Step: 500: /home/zhang/ws/17_ws/pytorch-dnc/models/daim_17051000.pth ...
[WARNING ] (MainProcess) Saved Model @ Step: 500: /home/zhang/ws/17_ws/pytorch-dnc/models/daim_17051000.pth.
[WARNING ] (MainProcess) Resume Training @ Step: 500
...
```
*******


## What is included?
This repo currently contains the following algorithms:

- Neural Turing Machines (NTM) [[1]](https://arxiv.org/abs/1410.5401)
- Differentiable Neural Computers (DNC) [[2]](http://www.nature.com/nature/journal/v538/n7626/full/nature20101.html)

Tasks:
- copy
- repeat-copy

## Code structure & Naming conventions
NOTE: we follow the exact code structure of [pytorch-rl](https://github.com/jingweiz/pytorch-rl) so as to keep the code easily transplantable.
* ```./utils/factory.py```
> We suggest that users start from ```./utils/factory.py```,
where all of the integrated ```Env```, ```Circuit``` and ```Agent``` classes are listed in ```Dict```'s.
All of the core classes are implemented in ```./core/```.
The factory pattern in ```./utils/factory.py``` keeps the code clean:
no matter which type of ```Circuit``` you want to train,
or which type of ```Env``` you want to train on,
all you need to do is modify some parameters in ```./utils/options.py```;
```./main.py``` then does the rest (NOTE: ```./main.py``` never needs to be modified).
* namings
> To keep the code clean and readable, we name variables using the following pattern (see the small sketch after this list):
> * ```*_vb```: ```torch.autograd.Variable```'s or a list of such objects
> * ```*_ts```: ```torch.Tensor```'s or a list of such objects
> * otherwise: normal python datatypes


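A minimal sketch of this naming convention (the sizes below are only illustrative, taken from the repeat-copy example above; this is not code from the repo):

```python
import torch
from torch.autograd import Variable

batch_size, word_len = 16, 4                  # hypothetical sizes, for illustration only
input_ts = torch.zeros(batch_size, word_len)  # *_ts : a plain torch.Tensor
input_vb = Variable(input_ts)                 # *_vb : a torch.autograd.Variable wrapping it
num_read_heads = 4                            # otherwise : a normal python datatype
```
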
## Dependencies
- Python 2.7
- [PyTorch](http://pytorch.org/)
- [Visdom](https://github.com/facebookresearch/visdom)
*******


## How to run:
You only need to modify some parameters in ```./utils/options.py``` to train a new configuration.

* Configure your training in ```./utils/options.py``` (a hypothetical example entry is sketched after this section):
> * ```line 12```: add an entry into ```CONFIGS``` to define your training (```agent_type```, ```env_type```, ```game```, ```circuit_type```)
> * ```line 28```: choose the entry you just added
> * ```line 24-25```: fill in your machine/cluster ID (```MACHINE```) and timestamp (```TIMESTAMP```) to define your training signature (```MACHINE_TIMESTAMP```);
the model file and the log file of this training will be saved under this signature (```./models/MACHINE_TIMESTAMP.pth``` & ```./logs/MACHINE_TIMESTAMP.log``` respectively).
The visdom visualization will also be displayed under this signature (first start the visdom server by typing ```python -m visdom.server &``` in bash, then open ```http://localhost:8097/env/MACHINE_TIMESTAMP``` in your browser)
> * ```line 28```: to train a model, set ```mode=1``` (training visualization will be under ```http://localhost:8097/env/MACHINE_TIMESTAMP```); to test the model of the current training, simply set ```mode=2``` (testing visualization will be under ```http://localhost:8097/env/MACHINE_TIMESTAMP_test```).

* Run:
> ```python main.py```
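As a purely hypothetical illustration of the configuration step above (```./utils/options.py``` is not shown in this commit view, so every field name and value below is an assumption, not the repo's code):

```python
# hypothetical sketch only -- the real CONFIGS lives in ./utils/options.py and may differ
CONFIGS = [
    # agent_type, env_type,      game, circuit_type
    ["sl",        "copy",        "",   "ntm"],   # entry 0
    ["sl",        "repeat-copy", "",   "dnc"],   # entry 1
]
CONFIG = 1                          # choose the entry you just added
MODE = 1                            # 1 -> train, 2 -> test
MACHINE = "daim"                    # your machine/cluster ID
TIMESTAMP = "17051000"              # timestamp of this run
REFS = MACHINE + "_" + TIMESTAMP    # training signature: MACHINE_TIMESTAMP
```
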
*******


## Implementation Notes:
The difference between ```NTM``` & ```DNC``` is stated as follows in the
```DNC```[2] paper:
> Comparison with the neural Turing machine. The neural Turing machine (NTM) was
the predecessor to the DNC described in this work. It used a similar
architecture of neural network controller with read–write access to a memory
matrix, but differed in the access mechanism used to interface with the memory.
In the NTM, content-based addressing was combined with location-based addressing
to allow the network to iterate through memory locations in order of their
indices (for example, location n followed by n+1 and so on). This allowed the
network to store and retrieve temporal sequences in contiguous blocks of memory.
However, there were several drawbacks. First, the NTM has no mechanism to ensure
that blocks of allocated memory do not overlap and interfere—a basic problem of
computer memory management. Interference is not an issue for the dynamic memory
allocation used by DNCs, which provides single free locations at a time,
irrespective of index, and therefore does not require contiguous blocks. Second,
the NTM has no way of freeing locations that have already been written to and,
hence, no way of reusing memory when processing long sequences. This problem is
addressed in DNCs by the free gates used for de-allocation. Third, sequential
information is preserved only as long as the NTM continues to iterate through
consecutive locations; as soon as the write head jumps to a different part of
the memory (using content-based addressing) the order of writes before and after
the jump cannot be recovered by the read head. The temporal link matrix used by
DNCs does not suffer from this problem because it tracks the order in which
writes were made.

We thus made an effort to put the two together in one combined codebase.
The classes implemented have the following hierarchy:
* Agent
  * Env
  * Circuit
    * Controller
    * Accessor
      * WriteHead
      * ReadHead
      * Memory

The part where ```NTM``` & ```DNC``` differ is the ```Accessor```: in the
code, ```NTM``` uses the ```StaticAccessor``` (perhaps not the most fitting name, but
we use it to keep the code consistent) and ```DNC``` uses the
```DynamicAccessor```. Both ```Accessor``` classes use ```_content_focus()```
and ```_location_focus()``` (again, perhaps not the most fitting name for ```DNC```, but we
use it to keep the code consistent). The ```_content_focus()``` is the
same for both classes, but the ```_location_focus()``` for ```DNC``` is much
more involved, as it additionally uses ```dynamic allocation``` for write and
the ```temporal link``` for read. These focus (or attention) mechanisms
are implemented in the ```Head``` classes; each focus outputs a ```weight```
vector for its ```head``` (write/read). These ```weight``` vectors are then used in
```_access()``` to interact with the ```external memory```.

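As a rough sketch of what the shared ```_content_focus()``` computes (this is the standard content-based addressing from the NTM/DNC papers, written in current PyTorch; it is not the repo's exact code): the key emitted by each head is compared to every memory row by cosine similarity, sharpened by the key strength beta, and normalized with a softmax over memory locations.

```python
import torch
import torch.nn.functional as F

def content_focus(key, beta, memory, eps=1e-6):
    """Sketch of content-based addressing (not the repo's exact implementation).
    key:    [batch_size, num_heads, mem_wid]  keys emitted by the heads
    beta:   [batch_size, num_heads, 1]        key strengths (>= 0)
    memory: [batch_size, mem_hei, mem_wid]    external memory
    returns content weights of shape [batch_size, num_heads, mem_hei]
    """
    # cosine similarity between every key and every memory row
    key_norm = key / (key.norm(2, dim=2, keepdim=True) + eps)
    mem_norm = memory / (memory.norm(2, dim=2, keepdim=True) + eps)
    similarity = torch.bmm(key_norm, mem_norm.transpose(1, 2))  # [batch, heads, mem_hei]
    # sharpen by beta, then normalize over memory locations
    return F.softmax(beta.expand_as(similarity) * similarity, dim=2)
```
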
## A side note:
The structure of ```Env``` might look strange, as this class was originally
designed for ```reinforcement learning``` settings, as in
[pytorch-rl](https://github.com/jingweiz/pytorch-rl); here we use it for
providing datasets for ```supervised learning```, so ```reward```,
```action``` and ```terminal``` are always left blank in this repo.
*******


## Repos we referred to during the development of this repo:
* [deepmind/dnc](https://github.com/deepmind/dnc)
* [ypxie/pytorch-NeuCom](https://github.com/ypxie/pytorch-NeuCom)
* [bzcheeseman/pytorch-EMM](https://github.com/bzcheeseman/pytorch-EMM)
* [DoctorTeeth/diffmem](https://github.com/DoctorTeeth/diffmem)
* [kaishengtai/torch-ntm](https://github.com/kaishengtai/torch-ntm)
* [Mostafa-Samir/DNC-tensorflow](https://github.com/Mostafa-Samir/DNC-tensorflow)

assets/dnc_repeat_copy_test.png

15.3 KB

assets/dnc_repeat_copy_train.png

20.7 KB

core/__init__.py

Whitespace-only changes.

core/accessor.py

+63
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch.nn as nn


class Accessor(nn.Module):
    def __init__(self, args):
        super(Accessor, self).__init__()
        # logging
        self.logger = args.logger
        # params
        self.use_cuda = args.use_cuda
        self.dtype = args.dtype

        # params
        self.batch_size = args.batch_size
        self.hidden_dim = args.hidden_dim
        self.num_write_heads = args.num_write_heads
        self.num_read_heads = args.num_read_heads
        self.mem_hei = args.mem_hei
        self.mem_wid = args.mem_wid
        self.clip_value = args.clip_value

        # functional components
        self.write_head_params = args.write_head_params
        self.read_head_params = args.read_head_params
        self.memory_params = args.memory_params

        # fill in the missing values
        # write_heads
        self.write_head_params.num_heads = self.num_write_heads
        self.write_head_params.batch_size = self.batch_size
        self.write_head_params.hidden_dim = self.hidden_dim
        self.write_head_params.mem_hei = self.mem_hei
        self.write_head_params.mem_wid = self.mem_wid
        # read_heads
        self.read_head_params.num_heads = self.num_read_heads
        self.read_head_params.batch_size = self.batch_size
        self.read_head_params.hidden_dim = self.hidden_dim
        self.read_head_params.mem_hei = self.mem_hei
        self.read_head_params.mem_wid = self.mem_wid
        # memory
        self.memory_params.batch_size = self.batch_size
        self.memory_params.clip_value = self.clip_value
        self.memory_params.mem_hei = self.mem_hei
        self.memory_params.mem_wid = self.mem_wid

    def _init_weights(self):
        raise NotImplementedError("not implemented in base class")

    def _reset_states(self):
        raise NotImplementedError("not implemented in base class")

    def _reset(self):  # NOTE: should be called at each child's __init__
        raise NotImplementedError("not implemented in base class")

    def visual(self):
        self.write_heads.visual()
        self.read_heads.visual()
        self.memory.visual()

    def forward(self, lstm_hidden_vb):
        raise NotImplementedError("not implemented in base class")
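For reference, a hypothetical sketch of the kind of ```args``` object this base class expects, inferred purely from the attributes it reads above. The real one is assembled in ```./utils/options.py```, which is not shown in this commit view; the sizes mirror the printed architecture above, and ```clip_value``` is a made-up number.

```python
# hypothetical args container, inferred from the attributes read in Accessor.__init__ above
import logging
import torch

class Params(object):
    pass  # simple attribute bag, for illustration only

args = Params()
args.logger = logging.getLogger("accessor")
args.use_cuda = False
args.dtype = torch.FloatTensor                 # .type(self.dtype) is called in the subclasses
args.batch_size, args.hidden_dim = 16, 64      # sizes matching the printed architecture above
args.num_write_heads, args.num_read_heads = 1, 4
args.mem_hei, args.mem_wid = 16, 16
args.clip_value = 20.                          # assumed value, for illustration only
args.write_head_params = Params()              # filled in further by Accessor.__init__
args.read_head_params = Params()
args.memory_params = Params()
```
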

core/accessors/__init__.py

Whitespace-only changes.

core/accessors/dynamic_accessor.py

+69
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
from torch.autograd import Variable

from core.accessor import Accessor
from core.heads.dynamic_write_head import DynamicWriteHead as WriteHead
from core.heads.dynamic_read_head import DynamicReadHead as ReadHead
from core.memory import External2DMemory as ExternalMemory


class DynamicAccessor(Accessor):
    def __init__(self, args):
        super(DynamicAccessor, self).__init__(args)
        # logging
        self.logger = args.logger
        # params
        self.use_cuda = args.use_cuda
        self.dtype = args.dtype
        # dynamic-accessor-specific params
        self.read_head_params.num_read_modes = self.write_head_params.num_heads * 2 + 1

        self.logger.warning("<--------------------------------===> Accessor: {WriteHead, ReadHead, Memory}")

        # functional components
        self.usage_vb = None   # for dynamic allocation, init in _reset
        self.link_vb = None    # for temporal link, init in _reset
        self.preced_vb = None  # for temporal link, init in _reset
        self.write_heads = WriteHead(self.write_head_params)
        self.read_heads = ReadHead(self.read_head_params)
        self.memory = ExternalMemory(self.memory_params)

        self._reset()

    def _init_weights(self):
        pass

    def _reset_states(self):
        # reset the usage (for dynamic allocation) & link (for temporal link)
        self.usage_vb = Variable(self.usage_ts).type(self.dtype)
        self.link_vb = Variable(self.link_ts).type(self.dtype)
        self.preced_vb = Variable(self.preced_ts).type(self.dtype)
        # we reset the write/read weights of heads
        self.write_heads._reset_states()
        self.read_heads._reset_states()
        # we also reset the memory to bias value
        self.memory._reset_states()

    def _reset(self):  # NOTE: should be called at __init__
        self._init_weights()
        self.type(self.dtype)  # put on gpu if possible
        # reset internal states
        self.usage_ts = torch.zeros(self.batch_size, self.mem_hei)
        self.link_ts = torch.zeros(self.batch_size, self.write_head_params.num_heads, self.mem_hei, self.mem_hei)
        self.preced_ts = torch.zeros(self.batch_size, self.write_head_params.num_heads, self.mem_hei)
        self._reset_states()

    def forward(self, hidden_vb):
        # 1. first we update the usage using the read/write weights from {t-1}
        self.usage_vb = self.write_heads._update_usage(self.usage_vb)
        self.usage_vb = self.read_heads._update_usage(hidden_vb, self.usage_vb)
        # 2. then write to memory_{t-1} to get memory_{t}
        self.memory.memory_vb = self.write_heads.forward(hidden_vb, self.memory.memory_vb, self.usage_vb)
        # 3. then we update the temporal link
        self.link_vb, self.preced_vb = self.write_heads._temporal_link(self.link_vb, self.preced_vb)
        # 4. then read from memory_{t} to get read_vec_{t}
        read_vec_vb = self.read_heads.forward(hidden_vb, self.memory.memory_vb, self.link_vb, self.write_head_params.num_heads)
        return read_vec_vb
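The ```usage_vb``` updated in step 1 above feeds the write head's dynamic allocation. The repo's ```DynamicWriteHead``` (under ```core/heads/```, not shown in this commit view) is expected to turn usage into allocation weights roughly as described in the DNC paper: sort locations by usage and give the least-used locations the largest share. A self-contained sketch of that weighting in current PyTorch (not the repo's code):

```python
import torch

def allocation_weighting(usage_vb):
    """Sketch of the DNC allocation weighting (not the repo's exact implementation).
    usage_vb: [batch_size, mem_hei], entries in [0, 1]; returns weights of the same shape."""
    # sort memory locations from least used to most used
    sorted_usage, free_list = torch.sort(usage_vb, dim=1, descending=False)
    # product of the usages of all locations that are more free than the current one
    shifted_prod = torch.cumprod(sorted_usage, dim=1)
    shifted_prod = torch.cat([torch.ones_like(shifted_prod[:, :1]),
                              shifted_prod[:, :-1]], dim=1)
    sorted_alloc = (1. - sorted_usage) * shifted_prod
    # scatter the weights back into the original location order
    return torch.zeros_like(usage_vb).scatter_(1, free_list, sorted_alloc)
```

The write weight is then typically an interpolation between this allocation weighting and the content weighting, gated by the ```alloc_gate``` and ```write_gate``` visible in the printed ```DynamicWriteHead``` above.
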

core/accessors/static_accessor.py

+51
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch.nn as nn
from torch.autograd import Variable

from core.accessor import Accessor
from core.heads.static_write_head import StaticWriteHead as WriteHead
from core.heads.static_read_head import StaticReadHead as ReadHead
from core.memory import External2DMemory as ExternalMemory


class StaticAccessor(Accessor):
    def __init__(self, args):
        super(StaticAccessor, self).__init__(args)
        # logging
        self.logger = args.logger
        # params
        self.use_cuda = args.use_cuda
        self.dtype = args.dtype

        self.logger.warning("<--------------------------------===> Accessor: {WriteHead, ReadHead, Memory}")

        # functional components
        self.write_heads = WriteHead(self.write_head_params)
        self.read_heads = ReadHead(self.read_head_params)
        self.memory = ExternalMemory(self.memory_params)

        self._reset()

    def _init_weights(self):
        pass

    def _reset_states(self):
        # we reset the write/read weights of heads
        self.write_heads._reset_states()
        self.read_heads._reset_states()
        # we also reset the memory to bias value
        self.memory._reset_states()

    def _reset(self):  # NOTE: should be called at __init__
        self._init_weights()
        self.type(self.dtype)  # put on gpu if possible
        # reset internal states
        self._reset_states()

    def forward(self, hidden_vb):
        # 1. first write to memory_{t-1} to get memory_{t}
        self.memory.memory_vb = self.write_heads.forward(hidden_vb, self.memory.memory_vb)
        # 2. then read from memory_{t} to get read_vec_{t}
        read_vec_vb = self.read_heads.forward(hidden_vb, self.memory.memory_vb)
        return read_vec_vb
