6 changes: 6 additions & 0 deletions ntm/controller.py
@@ -37,10 +37,16 @@ def __init__(self, vector_length, hidden_size):
nn.init.uniform_(p, -stdev, stdev)

def forward(self, x, state):
# The LSTM expects input of shape (sequence_length, batch_size, input_size); here that is (1, 1, input_representation_size)
output, state = self.layer(x.unsqueeze(0), state)
# The output has shape (sequence_length, batch_size, hidden_size), i.e. one hidden vector per unrolled step.
# squeeze(0) drops the sequence-length dimension when it is 1, so the caller gets (batch_size, hidden_size).
return output.squeeze(0), state

def get_initial_state(self, batch_size):
# Repeat the same learnt initial state across the batch dimension.
# The batch size is currently 1 to keep training minimal; revisit whether larger batches are worthwhile.
lstm_h = self.lstm_h_state.clone().repeat(1, batch_size, 1)
lstm_c = self.lstm_c_state.clone().repeat(1, batch_size, 1)
return lstm_h, lstm_c
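For reference, the whole controller these comments describe fits in a few lines. The sketch below is a reconstruction under assumptions, not the repository's file: the class name LSTMController and the 0.05 initialisation scale are made up, but the shape handling (unsqueeze to a length-1 sequence, squeeze back, repeat the learnt state across the batch) mirrors the hunk above.

```python
import torch
from torch import nn
from torch.nn import Parameter


class LSTMController(nn.Module):
    """Minimal sketch of an LSTM controller with a learnt initial state."""

    def __init__(self, vector_length, hidden_size):
        super().__init__()
        # Single-layer LSTM; expects tensors shaped (seq_len, batch, input_size)
        self.layer = nn.LSTM(vector_length, hidden_size)
        # Learnt initial hidden and cell states: (num_layers=1, batch=1, hidden)
        self.lstm_h_state = Parameter(torch.randn(1, 1, hidden_size) * 0.05)
        self.lstm_c_state = Parameter(torch.randn(1, 1, hidden_size) * 0.05)

    def forward(self, x, state):
        # x: (batch, input_size) -> add a length-1 sequence dimension
        output, state = self.layer(x.unsqueeze(0), state)
        # Drop the sequence dimension again, returning (batch, hidden_size)
        return output.squeeze(0), state

    def get_initial_state(self, batch_size):
        # Broadcast the learnt state across the batch dimension
        lstm_h = self.lstm_h_state.clone().repeat(1, batch_size, 1)
        lstm_c = self.lstm_c_state.clone().repeat(1, batch_size, 1)
        return lstm_h, lstm_c
```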
7 changes: 7 additions & 0 deletions ntm/head.py
@@ -27,6 +27,10 @@ def get_initial_state(self, batch_size):
return F.softmax(self._initial_state, dim=1).repeat(batch_size, 1)

def get_head_weight(self, x, previous_state, memory_read):
'''
Computes the addressing weight used by both the read and write heads.
'''
k = self.k_layer(x)
beta = F.softplus(self.beta_layer(x))
g = F.sigmoid(self.g_layer(x))
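The hunk is cut off after the interpolation gate, so for context here is roughly how k, beta and g (together with a shift distribution s and a sharpening factor gamma) are combined into a head weight in the standard NTM addressing scheme from Graves et al. (2014). This is a sketch under the assumption that the repository follows that recipe; the function names and the three-way shift layout are illustrative, not taken from this code.

```python
import torch
import torch.nn.functional as F


def circular_shift(w_g, s):
    # s is assumed to hold a distribution over shifts of -1, 0 and +1
    return (s[:, 0:1] * torch.roll(w_g, -1, dims=1)
            + s[:, 1:2] * w_g
            + s[:, 2:3] * torch.roll(w_g, 1, dims=1))


def address(k, beta, g, s, gamma, memory, w_prev):
    # Content addressing: cosine similarity of the key against every memory row,
    # scaled by beta and normalised with softmax.
    similarity = F.cosine_similarity(memory, k.unsqueeze(1), dim=-1)  # (batch, rows)
    w_c = F.softmax(beta * similarity, dim=1)
    # Interpolation: blend the content weighting with the previous weighting.
    w_g = g * w_c + (1 - g) * w_prev
    # Location addressing: convolve with the shift distribution, then sharpen.
    w = circular_shift(w_g, s) ** gamma
    return w / (w.sum(dim=1, keepdim=True) + 1e-16)
```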
@@ -50,8 +54,11 @@ def shift(self, w_g, s):

class ReadHead(Head):
def forward(self, x, previous_state):
# NOTE: memory shape is batch_size x (128 x 20), i.e. 128 rows of width 20
memory_read = self.memory.read()
# Addressing weight over the 128 memory rows (shape 1 x 128)
w = self.get_head_weight(x, previous_state, memory_read)
# unsqueeze gives the weight shape (batch, 1, rows) so a single matmul applies it across the whole batch
return torch.matmul(w.unsqueeze(1), memory_read).squeeze(1), w
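A quick shape check of the batched read above, using toy tensors and the 128 x 20 memory the comments assume (nothing here comes from the repository's classes):

```python
import torch

batch_size, rows, width = 1, 128, 20
memory_read = torch.rand(batch_size, rows, width)          # (1, 128, 20)
w = torch.softmax(torch.rand(batch_size, rows), dim=1)     # (1, 128), sums to 1

# (1, 1, 128) @ (1, 128, 20) -> (1, 1, 20) -> squeeze -> (1, 20)
read_vector = torch.matmul(w.unsqueeze(1), memory_read).squeeze(1)
assert read_vector.shape == (batch_size, width)
```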


2 changes: 1 addition & 1 deletion ntm/memory.py
@@ -12,7 +12,7 @@ def __init__(self, memory_size):
initial_state = torch.ones(memory_size) * 1e-6
self.register_buffer('initial_state', initial_state.data)

# Initial read vector is a learnt parameter
# Initial read vector is a learnt parameter |
self.initial_read = Parameter(torch.randn(1, self._memory_size[1]) * 0.01)

def get_size(self):
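The split above (a registered buffer for the constant 1e-6 initial memory, a Parameter for the learnt initial read vector) decides what the optimiser updates. A toy illustration with made-up sizes, independent of the repository's Memory class:

```python
import torch
from torch import nn
from torch.nn import Parameter


class ToyMemory(nn.Module):
    def __init__(self, memory_size=(128, 20)):
        super().__init__()
        # Buffer: saved and moved with the module, but never optimised.
        self.register_buffer('initial_state', torch.ones(memory_size) * 1e-6)
        # Parameter: trained along with the rest of the model.
        self.initial_read = Parameter(torch.randn(1, memory_size[1]) * 0.01)


m = ToyMemory()
print([name for name, _ in m.named_parameters()])  # ['initial_read']
print([name for name, _ in m.named_buffers()])     # ['initial_state']
```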
11 changes: 11 additions & 0 deletions ntm/ntm.py
@@ -26,13 +26,24 @@ def get_initial_state(self, batch_size=1):
return (read, read_head_state, write_head_state, controller_state)

def forward(self, x, previous_state):

## Whether the controller is an LSTM, a feedforward network, or any other network of choice, it consumes the input for a single time step.
## The controller turns that input into a feature vector, which is used to read from and write to the memory matrix.
## So at each time step we retrieve content from memory and then update it.
## The retrieved content, concatenated with the controller output, feeds a fully connected layer that produces the output for this time step.

previous_read, previous_read_head_state, previous_write_head_state, previous_controller_state = previous_state
controller_input = torch.cat([x, previous_read], dim=1)
## The LSTM controller takes input of shape 1 (seq_length) x 1 (batch_size) x 29 (input_repr_size)
## and returns output of shape seq_length x batch_size x 100 (hidden_size), with the first dimension squeezed out when it is 1.
## Whether a sequence length of 1 per call is the right design choice is still open; compare with other NTM implementations.
controller_output, controller_state = self.controller(controller_input, previous_controller_state)
# Read
read_head_output, read_head_state = self.read_head(controller_output, previous_read_head_state)
# Write
write_head_state = self.write_head(controller_output, previous_write_head_state)
# Combine the controller output with the read vector for the final fully connected layer
fc_input = torch.cat((controller_output, read_head_output), dim=1)
state = (read_head_output, read_head_state, write_head_state, controller_state)
return F.sigmoid(self.fc(fc_input)), state
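To make the one-time-step-per-call contract concrete, this is how a training loop would typically unroll the forward pass over a whole sequence. The run_sequence helper and the constructor arguments in the comment are illustrative assumptions; only the forward / get_initial_state contract is taken from the code above.

```python
import torch

# e.g. ntm = NTM(vector_length=8, hidden_size=100)  # constructor args assumed


def run_sequence(ntm, sequence):
    # sequence: (seq_len, batch_size, input_size)
    state = ntm.get_initial_state(batch_size=sequence.shape[1])
    outputs = []
    for t in range(sequence.shape[0]):
        # One external time step per call; the state threads the memory
        # contents and head weights from step to step.
        out, state = ntm(sequence[t], state)
        outputs.append(out)
    return torch.stack(outputs), state
```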
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
torch
numpy
matplotlib
tensorboard