
Commit 18085ff

Add lstm and gradient check
1 parent 1fd8a6b commit 18085ff

File tree

4 files changed (+319, -16 lines)


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@
 *.pyc
 *.gz
 *.pkz
+.DS_Store

README.md

Lines changed: 1 addition & 1 deletion
@@ -73,6 +73,7 @@ Implemented algorithms:
 * Convolutional layer with vectorized img2col and col2img
 * Recurrent neural network
 * Backpropagation through time (BPTT)
+* Long short-term memory
 * Generative Adversarial Networks (GAN)

 * Optimization Algorithms (See implementations in MLP or Regression)
@@ -101,7 +102,6 @@ Implemented algorithms:
 * Prediction by Viterbi

 Work in progress:
-* Long short-term memory
 * Deep Q-Network (Reinforcement learning)

 Feel free to use the code. Please contact me if you have any question: xiecng [at] gmail.com

lstm.py

Lines changed: 299 additions & 0 deletions
@@ -0,0 +1,299 @@
import numpy as np
import requests
import re
# TODO add sentence tokenizer


def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def tanh(x):
    return np.tanh(x)

def dsigmoid(grad_a, act):
    # derivative of sigmoid written in terms of the activation: act * (1 - act)
    return grad_a * (act - np.square(act))

def dtanh(grad_a, act):
    # derivative of tanh written in terms of the activation: 1 - act^2
    return grad_a * (1 - np.square(act))

def softmax(x):
    eps = 1e-20
    out = np.exp(x - np.max(x, axis=1).reshape(-1, 1))
    return out / (np.sum(out, axis=1).reshape(-1, 1) + eps)

def cross_entropy(pred, y):
    return -(np.multiply(y, np.log(pred + 1e-20))).sum()

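# A sketch of the recurrences the class below implements (notation added for
# readability; not part of the committed file):
#   f_t    = sigmoid(x_t @ w_f + h_{t-1} @ u_f + b_f)   forget gate
#   i_t    = sigmoid(x_t @ w_i + h_{t-1} @ u_i + b_i)   input gate
#   o_t    = sigmoid(x_t @ w_o + h_{t-1} @ u_o + b_o)   output gate
#   cbar_t = tanh(x_t @ w_c + h_{t-1} @ u_c + b_c)      candidate cell state
#   c_t    = f_t * c_{t-1} + i_t * cbar_t
#   h_t    = o_t * tanh(c_t)
# followed by the output layer softmax(h_t @ u_v + b_v).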
class LSTM(object):
    def __init__(self, n_input, n_hidden, n_label, n_t):
        self.loss = cross_entropy
        self.n_hidden, self.n_label = n_hidden, n_label
        self.lr, self.batch_size, self.epochs = 1, 32, 200
        self.eps = 1e-20
        self.n_t = n_t

        # input-to-hidden weights, hidden-to-hidden weights, and biases for
        # the forget, input, candidate, and output gates; weights are scaled
        # down by their fan-in
        self.w_f, self.w_i, self.w_c, self.w_o = [np.random.randn(n_input, self.n_hidden) / n_input for _ in range(4)]
        self.u_f, self.u_i, self.u_c, self.u_o = [np.random.randn(self.n_hidden, self.n_hidden) / self.n_hidden for _ in range(4)]
        self.b_f, self.b_i, self.b_c, self.b_o = [np.random.randn(1, self.n_hidden) for _ in range(4)]
        self.u_v, self.b_v = np.random.randn(self.n_hidden, self.n_label) / self.n_hidden, np.random.randn(1, self.n_label)

        self.param_list = [
            self.w_f, self.w_i, self.w_c, self.w_o,
            self.u_f, self.u_i, self.u_c, self.u_o, self.u_v,
            self.b_f, self.b_i, self.b_c, self.b_o, self.b_v
        ]
        # first- and second-moment accumulators for adam, one per parameter
        self.mom_list = [np.zeros_like(param) for param in self.param_list]
        self.cache_list = [np.zeros_like(param) for param in self.param_list]

    def fit(self, x, label):
        b_size = self.batch_size
        n_t, n_data, n_input = x.shape
        # one-hot encode the integer labels
        y = np.zeros((n_t * n_data, self.n_label))
        y[np.arange(n_t * n_data), label.flatten()] = 1
        y = y.reshape((n_t, n_data, self.n_label))
        constant = np.ones((1, self.batch_size * n_t))

        for epoch in range(self.epochs):
            permut = np.random.permutation(n_data // b_size * b_size).reshape(-1, b_size)
            for b_idx in range(permut.shape[0]):
                # time steps are stacked along the first axis: rows
                # [t * b_size, (t + 1) * b_size) hold step t of the batch
                x_batch = x[:, permut[b_idx, :]].reshape(n_t * b_size, n_input)
                y_batch = y[:, permut[b_idx, :]].reshape(n_t * b_size, self.n_label)
                h, f, i, c, o, c_bar, grad_f, grad_i, grad_o, grad_c, grad_c_bar = [
                    np.zeros((n_t * b_size, self.n_hidden)) for _ in range(11)
                ]

                # forward pass
                for t in range(n_t):
                    t_idx = np.arange(t * b_size, (t + 1) * b_size)
                    # at t == 0 the previous rows are still all-zero, so
                    # reading them yields the zero initial state
                    t_idx_prev = t_idx - b_size if t > 0 else t_idx

                    xt_batch, ht_prev = x_batch[t_idx], h[t_idx_prev]

                    f[t_idx] = sigmoid(xt_batch @ self.w_f + ht_prev @ self.u_f + self.b_f)
                    i[t_idx] = sigmoid(xt_batch @ self.w_i + ht_prev @ self.u_i + self.b_i)
                    o[t_idx] = sigmoid(xt_batch @ self.w_o + ht_prev @ self.u_o + self.b_o)
                    c_bar[t_idx] = tanh(xt_batch @ self.w_c + ht_prev @ self.u_c + self.b_c)
                    c[t_idx] = f[t_idx] * c[t_idx_prev] + i[t_idx] * c_bar[t_idx]
                    h[t_idx] = o[t_idx] * tanh(c[t_idx])

                c_prev = np.zeros(c.shape)
                c_prev[b_size:, :] = c[:-b_size, :]
                h_prev = np.zeros(h.shape)
                h_prev[b_size:, :] = h[:-b_size, :]

                # back propagation through time
                grad_v = softmax(h @ self.u_v + self.b_v) - y_batch
                grad_h = grad_v @ self.u_v.T

                for t in reversed(range(0, n_t)):
                    t_idx = np.arange(t * b_size, (t + 1) * b_size)
                    if t < n_t - 1:
                        # gradient flowing back from step t + 1 through the gates
                        grad_h[t_idx] += (
                            dsigmoid(grad_f[t_idx + b_size], f[t_idx + b_size]) @ self.u_f.T +
                            dsigmoid(grad_i[t_idx + b_size], i[t_idx + b_size]) @ self.u_i.T +
                            dsigmoid(grad_o[t_idx + b_size], o[t_idx + b_size]) @ self.u_o.T +
                            dtanh(grad_c_bar[t_idx + b_size], c_bar[t_idx + b_size]) @ self.u_c.T
                        )
                    grad_c[t_idx] = o[t_idx] * grad_h[t_idx] * (1 - np.square(np.tanh(c[t_idx])))
                    if t < n_t - 1:
                        grad_c[t_idx] += f[t_idx + b_size] * grad_c[t_idx + b_size]
                    grad_f[t_idx] = grad_c[t_idx] * c_prev[t_idx]
                    grad_i[t_idx] = grad_c[t_idx] * c_bar[t_idx]
                    grad_o[t_idx] = grad_h[t_idx] * tanh(c[t_idx])
                    grad_c_bar[t_idx] = grad_c[t_idx] * i[t_idx]

                self.adam(
                    grad_list=[
                        x_batch.T @ dsigmoid(grad_f, f), x_batch.T @ dsigmoid(grad_i, i), x_batch.T @ dtanh(grad_c_bar, c_bar), x_batch.T @ dsigmoid(grad_o, o),
                        h_prev.T @ dsigmoid(grad_f, f), h_prev.T @ dsigmoid(grad_i, i), h_prev.T @ dtanh(grad_c_bar, c_bar), h_prev.T @ dsigmoid(grad_o, o), h.T @ grad_v,
                        constant @ dsigmoid(grad_f, f), constant @ dsigmoid(grad_i, i), constant @ dtanh(grad_c_bar, c_bar), constant @ dsigmoid(grad_o, o), constant @ grad_v
                    ]
                )
                self.regularization()
            if hasattr(self, 'ix_to_word'):
                print(self.sample(np.random.randint(n_input), np.random.randn(1, self.n_hidden), np.random.randn(1, self.n_hidden), n_t * 4))
            print(self.loss(self.predict(x).reshape(n_t * n_data, self.n_label), y.reshape(n_t * n_data, self.n_label)))

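    # gradient_check repeats the forward pass and BPTT from fit() on a single
    # batch, then perturbs one entry (index) of each parameter by +/- eps and
    # compares the analytic gradient against the centered finite difference
    # (loss(p + eps) - loss(p - eps)) / (2 * eps). The printed value is that
    # discrepancy divided by eps^2: it stays bounded when the analytic
    # gradient is correct and blows up roughly like 1 / eps^2 when it is not.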
    def gradient_check(self, x, label):
        n_t, n_data, n_input = x.shape
        y = np.zeros((n_t * n_data, self.n_label))
        y[np.arange(n_t * n_data), label.flatten()] = 1
        x_batch = x.reshape(n_t * n_data, n_input)
        h, f, i, c, o, c_bar, grad_f, grad_i, grad_o, grad_c, grad_c_bar = [
            np.zeros((n_t * n_data, self.n_hidden)) for _ in range(11)
        ]

        constant = np.ones((1, n_data * n_t))
        # forward pass
        for t in range(n_t):
            t_idx = np.arange(t * n_data, (t + 1) * n_data)
            t_idx_prev = t_idx - n_data if t > 0 else t_idx

            xt_batch, ht_prev = x_batch[t_idx], h[t_idx_prev]
            f[t_idx] = sigmoid(xt_batch @ self.w_f + ht_prev @ self.u_f + self.b_f)
            i[t_idx] = sigmoid(xt_batch @ self.w_i + ht_prev @ self.u_i + self.b_i)
            o[t_idx] = sigmoid(xt_batch @ self.w_o + ht_prev @ self.u_o + self.b_o)
            c_bar[t_idx] = tanh(xt_batch @ self.w_c + ht_prev @ self.u_c + self.b_c)
            c[t_idx] = f[t_idx] * c[t_idx_prev] + i[t_idx] * c_bar[t_idx]
            h[t_idx] = o[t_idx] * tanh(c[t_idx])

        c_prev = np.zeros(c.shape)
        c_prev[n_data:, :] = c[:-n_data, :]
        h_prev = np.zeros(h.shape)
        h_prev[n_data:, :] = h[:-n_data, :]

        # back propagation through time
        grad_v = softmax(h @ self.u_v + self.b_v) - y
        grad_h = grad_v @ self.u_v.T

        for t in reversed(range(0, n_t)):
            t_idx = np.arange(t * n_data, (t + 1) * n_data)
            if t < n_t - 1:
                grad_h[t_idx] += (
                    dsigmoid(grad_f[t_idx + n_data], f[t_idx + n_data]) @ self.u_f.T +
                    dsigmoid(grad_i[t_idx + n_data], i[t_idx + n_data]) @ self.u_i.T +
                    dsigmoid(grad_o[t_idx + n_data], o[t_idx + n_data]) @ self.u_o.T +
                    dtanh(grad_c_bar[t_idx + n_data], c_bar[t_idx + n_data]) @ self.u_c.T
                )
            grad_c[t_idx] = o[t_idx] * grad_h[t_idx] * (1 - np.square(np.tanh(c[t_idx])))
            if t < n_t - 1:
                grad_c[t_idx] += f[t_idx + n_data] * grad_c[t_idx + n_data]
            grad_f[t_idx] = grad_c[t_idx] * c_prev[t_idx]
            grad_i[t_idx] = grad_c[t_idx] * c_bar[t_idx]
            grad_o[t_idx] = grad_h[t_idx] * tanh(c[t_idx])
            grad_c_bar[t_idx] = grad_c[t_idx] * i[t_idx]

        # perturb one scalar entry of each parameter in turn
        index = (0, 1)
        eps = 1e-4
        for j, grad in enumerate([
            x_batch.T @ dsigmoid(grad_f, f), x_batch.T @ dsigmoid(grad_i, i), x_batch.T @ dtanh(grad_c_bar, c_bar), x_batch.T @ dsigmoid(grad_o, o),
            h_prev.T @ dsigmoid(grad_f, f), h_prev.T @ dsigmoid(grad_i, i), h_prev.T @ dtanh(grad_c_bar, c_bar), h_prev.T @ dsigmoid(grad_o, o), h.T @ grad_v,
            constant @ dsigmoid(grad_f, f), constant @ dsigmoid(grad_i, i), constant @ dtanh(grad_c_bar, c_bar), constant @ dsigmoid(grad_o, o), constant @ grad_v
        ]):
            params_a = [param.copy() for param in self.param_list]
            params_b = [param.copy() for param in self.param_list]
            params_a[j][index] += eps
            params_b[j][index] -= eps

            w_f_a, w_i_a, w_c_a, w_o_a, u_f_a, u_i_a, u_c_a, u_o_a, u_v_a, b_f_a, b_i_a, b_c_a, b_o_a, b_v_a = params_a
            w_f_b, w_i_b, w_c_b, w_o_b, u_f_b, u_i_b, u_c_b, u_o_b, u_v_b, b_f_b, b_i_b, b_c_b, b_o_b, b_v_b = params_b
            h_a, f_a, i_a, c_a, o_a, c_bar_a, h_b, f_b, i_b, c_b, o_b, c_bar_b = [
                np.zeros((n_t * n_data, self.n_hidden)) for _ in range(12)
            ]

            # two extra forward passes with the perturbed parameter sets
            for t in range(n_t):
                t_idx = np.arange(t * n_data, (t + 1) * n_data)
                t_idx_prev = t_idx - n_data if t > 0 else t_idx

                xt_batch, ht_prev_a, ht_prev_b = x_batch[t_idx], h_a[t_idx_prev], h_b[t_idx_prev]
                f_a[t_idx] = sigmoid(xt_batch @ w_f_a + ht_prev_a @ u_f_a + b_f_a)
                i_a[t_idx] = sigmoid(xt_batch @ w_i_a + ht_prev_a @ u_i_a + b_i_a)
                o_a[t_idx] = sigmoid(xt_batch @ w_o_a + ht_prev_a @ u_o_a + b_o_a)
                c_bar_a[t_idx] = tanh(xt_batch @ w_c_a + ht_prev_a @ u_c_a + b_c_a)
                c_a[t_idx] = f_a[t_idx] * c_a[t_idx_prev] + i_a[t_idx] * c_bar_a[t_idx]
                h_a[t_idx] = o_a[t_idx] * tanh(c_a[t_idx])

                f_b[t_idx] = sigmoid(xt_batch @ w_f_b + ht_prev_b @ u_f_b + b_f_b)
                i_b[t_idx] = sigmoid(xt_batch @ w_i_b + ht_prev_b @ u_i_b + b_i_b)
                o_b[t_idx] = sigmoid(xt_batch @ w_o_b + ht_prev_b @ u_o_b + b_o_b)
                c_bar_b[t_idx] = tanh(xt_batch @ w_c_b + ht_prev_b @ u_c_b + b_c_b)
                c_b[t_idx] = f_b[t_idx] * c_b[t_idx_prev] + i_b[t_idx] * c_bar_b[t_idx]
                h_b[t_idx] = o_b[t_idx] * tanh(c_b[t_idx])

            pred_a = cross_entropy(softmax(h_a @ u_v_a + b_v_a), y)
            pred_b = cross_entropy(softmax(h_b @ u_v_b + b_v_b), y)
            print('gradient_check', j, ((pred_a - pred_b) / eps / 2 - grad[index]) / eps / eps)

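    # fit() updates parameters with adam below; sgd takes the same grad_list
    # and can be swapped in as a plain gradient-descent alternative.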
    def sgd(self, grad_list):
        alpha = self.lr / self.batch_size / self.n_t
        for params, grads in zip(self.param_list, grad_list):
            params -= alpha * grads

    def adam(self, grad_list):
        beta1 = 0.9
        beta2 = 0.999
        alpha = self.lr / self.batch_size / self.n_t
        for params, grads, mom, cache in zip(
            self.param_list, grad_list, self.mom_list, self.cache_list
        ):
            # in-place exponential moving averages:
            #   mom   <- beta1 * mom   + (1 - beta1) * grads
            #   cache <- beta2 * cache + (1 - beta2) * grads^2
            mom += (beta1 - 1) * mom + (1 - beta1) * grads
            cache += (beta2 - 1) * cache + (1 - beta2) * np.square(grads)
            # note: no bias correction, unlike textbook Adam
            params -= alpha * mom / (np.sqrt(cache) + self.eps)

    def regularization(self):
        # simple L2 weight decay applied directly to the parameters
        lbd = 1e-5
        for params in self.param_list:
            params -= lbd * params

    def predict(self, x):
        n_t, n_data, n_input = x.shape
        h, f, i, c, o = [np.zeros((n_t * n_data, self.n_hidden)) for _ in range(5)]
        # forward pass
        for t in range(n_t):
            t_idx = np.arange(t * n_data, (t + 1) * n_data)
            t_idx_prev = t_idx - n_data if t > 0 else t_idx
            f[t_idx] = sigmoid(x[t] @ self.w_f + h[t_idx_prev] @ self.u_f + self.b_f)
            i[t_idx] = sigmoid(x[t] @ self.w_i + h[t_idx_prev] @ self.u_i + self.b_i)
            o[t_idx] = sigmoid(x[t] @ self.w_o + h[t_idx_prev] @ self.u_o + self.b_o)
            c[t_idx] = f[t_idx] * c[t_idx_prev] + i[t_idx] * tanh(x[t] @ self.w_c + h[t_idx_prev] @ self.u_c + self.b_c)
            h[t_idx] = o[t_idx] * tanh(c[t_idx])
        return softmax(h @ self.u_v + self.b_v).reshape(n_t, n_data, self.n_label)

    def sample(self, x_idx, h, c, seq_length):
        # generate a sequence by repeatedly feeding back the sampled token
        n_input = self.w_f.shape[0]
        seq = [x_idx]
        for t in range(seq_length):
            x = np.zeros((1, n_input))
            x[0, seq[-1]] = 1

            f = sigmoid(x @ self.w_f + h @ self.u_f + self.b_f)
            i = sigmoid(x @ self.w_i + h @ self.u_i + self.b_i)
            o = sigmoid(x @ self.w_o + h @ self.u_o + self.b_o)
            c = f * c + i * tanh(x @ self.w_c + h @ self.u_c + self.b_c)
            h = o * tanh(c)
            y = softmax(h @ self.u_v + self.b_v)
            seq.append(np.random.choice(range(n_input), p=y.flatten()))
        return ''.join(np.vectorize(self.ix_to_word.get)(np.array(seq)).tolist())


def text_generation(use_word=True):
    # Alice's Adventures in Wonderland from Project Gutenberg
    text = requests.get('http://www.gutenberg.org/cache/epub/11/pg11.txt').text
    if use_word:
        text = [word + ' ' for word in re.sub("[^a-zA-Z]", " ", text).lower().split()]

    words = sorted(list(set(text)))
    text_size, vocab_size = len(text), len(words)

    print(f'text has {text_size} characters, {vocab_size} unique.')
    word_to_ix = {word: i for i, word in enumerate(words)}
    ix_to_word = {i: word for i, word in enumerate(words)}

    seq_length = 25
    # one-hot encode the text and cut it into sequences of length seq_length;
    # the target is the input shifted by one position
    indices = np.vectorize(word_to_ix.get)(np.array(list(text)))
    data = np.zeros((text_size, vocab_size))
    data[np.arange(text_size), indices] = 1
    n_text = (text_size - 1) // seq_length
    x = data[:n_text * seq_length].reshape(n_text, seq_length, vocab_size).transpose(1, 0, 2)
    y = indices[1: n_text * seq_length + 1].reshape(n_text, seq_length).T

    test_ratio = 0.2
    test_split = np.random.uniform(0, 1, x.shape[1])
    train_x, test_x = x[:, test_split >= test_ratio, :], x[:, test_split < test_ratio, :]
    train_y, test_y = y[:, test_split >= test_ratio], y[:, test_split < test_ratio]

    lstm = LSTM(vocab_size, 500, vocab_size, seq_length)
    lstm.ix_to_word = ix_to_word
    lstm.gradient_check(train_x[:, np.arange(32), :], train_y[:, np.arange(32)])
    lstm.fit(train_x, train_y)
    # these report the fraction of correctly predicted tokens, i.e. accuracy
    print('train accuracy', (np.argmax(lstm.predict(train_x), axis=2) == train_y).sum() / (train_y.shape[0] * train_y.shape[1]))
    print('test accuracy', (np.argmax(lstm.predict(test_x), axis=2) == test_y).sum() / (test_y.shape[0] * test_y.shape[1]))

def main():
    text_generation(use_word=False)


if __name__ == "__main__":
    main()
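A minimal sketch for trying the class without the Gutenberg download (not part of the commit): it builds a synthetic copy task, runs the gradient check on a small slice, and trains for a few epochs. It assumes the file above is importable as lstm and uses only the signatures it defines.

import numpy as np
from lstm import LSTM

n_t, n_data, n_input = 5, 64, 8
labels = np.random.randint(0, n_input, size=(n_t, n_data))  # integer targets, shape (n_t, n_data)
x = np.zeros((n_t, n_data, n_input))                        # one-hot inputs, shape (n_t, n_data, n_input)
for t in range(n_t):
    x[t, np.arange(n_data), labels[t]] = 1

lstm = LSTM(n_input, 16, n_input, n_t)  # n_input, n_hidden, n_label, n_t
lstm.epochs = 5                         # the default of 200 is more than a toy run needs
lstm.gradient_check(x[:, :4], labels[:, :4])
lstm.fit(x, labels)
print('accuracy', (np.argmax(lstm.predict(x), axis=2) == labels).mean())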

0 commit comments