Skip to content

Commit fa65d8a

Browse files
author
Lukasz Kaiser
committed
Adding Neural GPU code.
1 parent 15f82d2 commit fa65d8a

File tree

5 files changed

+936
-0
lines changed

5 files changed

+936
-0
lines changed

neural_gpu/BUILD

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Library target wrapping data_utils.py.
py_library(
    name = "data_utils",
    srcs = ["data_utils.py"],
    deps = [
        "//file/colossus/public:cns",
        "//third_party/py/numpy",
        "//third_party/py/tensorflow",
    ],
)
12+
13+
# Library target wrapping neural_gpu.py; it builds on :data_utils.
py_library(
    name = "neural_gpu",
    srcs = ["neural_gpu.py"],
    deps = [
        ":data_utils",
        "//third_party/py/numpy",
        "//third_party/py/tensorflow",
    ],
)
24+
25+
# Binary target built from neural_gpu_trainer.py; it links the :neural_gpu
# library.
py_binary(
    name = "neural_gpu_trainer",
    srcs = ["neural_gpu_trainer.py"],
    launcher = "//devtools/python/launcher",
    malloc = "//tcmalloc:tcmalloc_or_debug",
    deps = [
        ":neural_gpu",
        "//file/colossus/public:cns",
        "//net/proto2/python/public:use_fast_cpp_protos",
        "//third_party/py/Tkinter",
        "//third_party/py/matplotlib",
        "//third_party/py/numpy",
        "//third_party/py/tensorflow",
    ],
)

neural_gpu/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# NeuralGPU
2+
Code for the Neural GPU model as described
3+
in [arXiv:1511.08228](http://arxiv.org/abs/1511.08228).
4+

neural_gpu/data_utils.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
"""Convolutional Gated Recurrent Networks for Algorithm Learning."""
2+
3+
import math
4+
import random
5+
import sys
6+
import time
7+
8+
import google3
9+
10+
import numpy as np
11+
import tensorflow as tf
12+
13+
from google3.third_party.tensorflow.python.platform import gfile
14+
15+
FLAGS = tf.app.flags.FLAGS
16+
17+
# Lengths that sequences are padded to; pad() returns the first one that fits.
bins = [8, 16, 32, 64, 128]

# Task names for which train/test example buckets are created at import time.
all_tasks = [
    "sort", "id", "rev", "incr", "left", "right", "left-shift", "add",
    "right-shift", "bmul", "dup", "badd", "qadd",
]

# Length returned by pad() when the input is longer than every entry of bins.
forward_max = 128

# When non-empty, print_out() also appends every message to this file.
log_filename = ""
22+
23+
24+
def pad(l):
  """Return the first entry of bins that is >= l, or forward_max if none is."""
  return next((size for size in bins if size >= l), forward_max)
28+
29+
30+
# Example storage: train_set[task][n] and test_set[task][n] are lists of
# [input, target] pairs whose input has length n. 10000 empty buckets are
# pre-allocated per task so that init_data() can append by length directly.
# range() is used instead of xrange() so the module also loads on Python 3.
train_set = {}
test_set = {}
for some_task in all_tasks:
  train_set[some_task] = [[] for _ in range(10000)]
  test_set[some_task] = [[] for _ in range(10000)]
38+
39+
40+
def add(n1, n2, base=10):
  """Add two numbers represented as little-endian digit lists.

  Args:
    n1: list of digits of the first number, least significant digit first.
    n2: list of digits of the second number, least significant digit first.
    base: numeric base of the digits.

  Returns:
    The sum as a new little-endian digit list without trailing (most
    significant) zeros; [0] if the sum is zero. The inputs are not modified.
  """
  # One extra position leaves room for a final carry.
  k = max(len(n1), len(n2)) + 1
  d1 = n1 + [0] * (k - len(n1))
  d2 = n2 + [0] * (k - len(n2))
  res = []
  carry = 0
  for a, b in zip(d1, d2):
    digit = a + b + carry
    if digit < base:
      carry = 0
    else:
      digit -= base
      carry = 1
    res.append(digit)
  # Drop most-significant zeros in place (no list copy per removed digit).
  while res and res[-1] == 0:
    res.pop()
  return res or [0]
58+
59+
60+
def init_data(task, length, nbr_cases, nclass):
  """Generate random examples for a task and store them in the global sets.

  For each case one example is appended to train_set[task] and another,
  independently drawn, to test_set[task]. Examples are [input, target] pairs
  stored in the bucket indexed by the input length.

  Args:
    task: task name. "add", "badd", "qadd" and "bmul" produce arithmetic
      pairs, "dup" produces duplication pairs, every other name is resolved
      by spec() below (unknown names terminate the program).
    length: maximal total input length of the generated examples.
    nbr_cases: number of train examples (and of test examples) to generate.
    nclass: number of symbol classes; random symbols are drawn from
      1 .. nclass - 1, and 0 is used as padding.
  """
  arithmetic_tasks = ["add", "badd", "qadd", "bmul"]
  addition_tasks = ["add", "badd", "qadd"]

  def rand_pair(l, task):
    """Random data pair for an arithmetic task. Total length is <= l."""
    # Floor division keeps k an int on Python 3 as well.
    k = (l - 1) // 2
    base = 10
    if task[0] == "b": base = 2
    if task[0] == "q": base = 4
    d1 = [np.random.randint(base) for _ in range(k)]
    d2 = [np.random.randint(base) for _ in range(k)]
    if task in addition_tasks:
      res = add(d1, d2, base)
    elif task in ["bmul"]:
      d1n = sum([d * (base ** i) for i, d in enumerate(d1)])
      d2n = sum([d * (base ** i) for i, d in enumerate(d2)])
      # bin() yields "0b..." most significant bit first; strip the prefix and
      # reverse to get little-endian digits.
      res = [int(x) for x in reversed(bin(d1n * d2n)[2:])]
    else:
      sys.exit()
    sep = [11] if task in addition_tasks else [12]
    # Digits are shifted by one so that 0 stays free for padding.
    inp = [d + 1 for d in d1] + sep + [d + 1 for d in d2]
    return inp, [r + 1 for r in res]

  def rand_dup_pair(l):
    """Random data pair for duplication task. Total length is <= l."""
    k = l // 2
    x = [np.random.randint(nclass - 1) + 1 for _ in range(k)]
    inp = x + [0] * (l - k)
    res = x + x + [0] * (l - 2 * k)
    return inp, res

  def spec(inp):
    """Return the target given the input for the non-arithmetic tasks."""
    if task == "sort":
      return sorted(inp)
    elif task == "id":
      return inp
    elif task == "rev":
      return list(reversed(inp))
    elif task == "incr":
      carry = 1
      res = []
      for symbol in inp:
        if symbol + carry < nclass:
          res.append(symbol + carry)
          carry = 0
        else:
          res.append(1)
          carry = 1
      return res
    elif task == "left":
      return [inp[0]]
    elif task == "right":
      return [inp[-1]]
    elif task == "left-shift":
      # Cyclic rotation: position p receives inp[p - 1], position 0 wraps.
      return inp[-1:] + inp[:-1]
    elif task == "right-shift":
      # Cyclic rotation: position p receives inp[p + 1]. The last position
      # wraps to inp[0]; plain inp[p + 1] indexing ran past the end there.
      return inp[1:] + inp[:1]
    else:
      print_out("Unknown spec for task " + str(task))
      sys.exit()

  def rand_symbols(count):
    """Draw count random symbols from 1 .. nclass - 1."""
    return [np.random.randint(nclass - 1) + 1 for _ in range(count)]

  l = length
  cur_time = time.time()
  total_time = 0.0
  for case in range(nbr_cases):
    total_time += time.time() - cur_time
    cur_time = time.time()
    if l > 10000 and case % 100 == 1:
      print_out(" avg gen time %.4f s" % (total_time / float(case)))
    # In every branch the train example is drawn before the test example.
    if task in arithmetic_tasks:
      i, t = rand_pair(l, task)
      train_set[task][len(i)].append([i, t])
      i, t = rand_pair(l, task)
      test_set[task][len(i)].append([i, t])
    elif task == "dup":
      i, t = rand_dup_pair(l)
      train_set[task][len(i)].append([i, t])
      i, t = rand_dup_pair(l)
      test_set[task][len(i)].append([i, t])
    else:
      inp = rand_symbols(l)
      target = spec(inp)
      train_set[task][l].append([inp, target])
      inp = rand_symbols(l)
      target = spec(inp)
      test_set[task][l].append([inp, target])
147+
148+
149+
def get_batch(max_length, batch_size, do_train, task, offset=None, preset=None):
  """Get a batch of data, training or testing.

  Args:
    max_length: requested input length. Without a preset, the largest
      non-empty bucket of at most this length is used.
    batch_size: number of examples in the batch.
    do_train: if true, draw from train_set, otherwise from test_set.
    task: task name selecting the example set.
    offset: if not None, example b of the batch is taken from position
      offset + b of the bucket whenever that position exists; otherwise a
      randomly chosen example is used.
    preset: if not None, an [input, target] pair that is used for every
      example of the batch; its input must have length max_length.

  Returns:
    A pair (inputs, targets). Each is a list with one int32 array per padded
    position, and every array holds batch_size symbols. Sequences are padded
    with 0 up to pad(length).
  """
  length = max_length
  if preset is None:
    cur_set = train_set[task] if do_train else test_set[task]
    # Fall back to the closest shorter length that has examples.
    while not cur_set[length]:
      length -= 1
  pad_length = pad(length)

  inputs = []
  targets = []
  for b in range(batch_size):
    if preset is None:
      # The random draw happens even when offset picks the element, so the
      # sequence of random numbers consumed does not depend on offset.
      elem = random.choice(cur_set[length])
      if offset is not None and offset + b < len(cur_set[length]):
        elem = cur_set[length][offset + b]
    else:
      elem = preset
    inp, target = elem[0], elem[1]
    assert len(inp) == length
    # A non-positive repeat count yields an empty list, i.e. no padding.
    inputs.append(inp + [0] * (pad_length - len(inp)))
    targets.append(target + [0] * (pad_length - len(target)))

  # Transpose from per-example rows to per-position arrays.
  res_input = [np.array([row[pos] for row in inputs], dtype=np.int32)
               for pos in range(pad_length)]
  res_target = [np.array([row[pos] for row in targets], dtype=np.int32)
                for pos in range(pad_length)]
  return res_input, res_target
181+
182+
183+
def print_out(s, newline=True):
  """Print a message out and log it to file.

  The message is always written to stdout, which is flushed afterwards. If
  the module-level log_filename is non-empty the message is first appended
  to that file; a failure to do so is reported on stdout and otherwise
  ignored.

  Args:
    s: message string.
    newline: whether a trailing newline is added to the message.
  """
  text = s + ("\n" if newline else "")
  if log_filename:
    try:
      with gfile.GFile(log_filename, mode="a") as f:
        f.write(text)
    # Logging is best effort, so any ordinary error is tolerated; catching
    # Exception (not everything) lets KeyboardInterrupt/SystemExit through.
    except Exception:  # pylint: disable=broad-except
      sys.stdout.write("Error appending to %s\n" % log_filename)
  sys.stdout.write(text)
  sys.stdout.flush()
194+
195+
196+
def decode(output):
  """Return a list with the argmax along axis 1 of every element of output."""
  decoded = []
  for step_output in output:
    decoded.append(np.argmax(step_output, axis=1))
  return decoded
198+
199+
200+
def accuracy(inpt, output, target, batch_size, nprint):
  """Calculate output accuracy given target.

  Args:
    inpt: input symbols, indexed as inpt[position][batch_index].
    output: model outputs, one element per position; decode() reduces each
      element to predicted symbols with an argmax over axis 1.
    target: target symbols, indexed as target[position][batch_index]. Only
      positions with a target symbol > 0 are counted.
    batch_size: number of sequences in the batch.
    nprint: number of examples to print; must be at most batch_size.

  Returns:
    A tuple (errors, total, seq_errors): the number of counted positions with
    a wrong prediction, the number of counted positions, and the number of
    sequences containing at least one wrong prediction.
  """
  assert nprint < batch_size + 1

  def task_print(inp, output, target):
    """Print one example, cut at the first target symbol that is not > 0."""
    stop_bound = 0
    print_len = 0
    while print_len < len(target) and target[print_len] > stop_bound:
      print_len += 1
    # Symbols are printed minus one; zero input symbols are skipped.
    print_out(" i: " + " ".join([str(i - 1) for i in inp if i > 0]))
    print_out(" o: " + " ".join([str(o - 1) for o in output[:print_len]]))
    print_out(" t: " + " ".join([str(t - 1) for t in target[:print_len]]))

  decoded_target = target
  decoded_output = decode(output)

  def print_example(b):
    """Print input, output and target of batch element b."""
    positions = range(len(decoded_target))
    task_print([inpt[l][b] for l in range(len(inpt))],
               [decoded_output[l][b] for l in positions],
               [decoded_target[l][b] for l in positions])

  total = 0
  errors = 0
  seq = [0] * batch_size  # seq[b] == 1 iff sequence b has an error.
  for l in range(len(decoded_output)):
    for b in range(batch_size):
      if decoded_target[l][b] > 0:
        total += 1
        if decoded_output[l][b] != decoded_target[l][b]:
          seq[b] = 1
          errors += 1

  # First print erroneous sequences, in batch order.
  e = 0  # Previous error index
  for _ in range(min(nprint, sum(seq))):
    while seq[e] == 0:
      e += 1
    print_example(e)
    e += 1
  # NOTE(review): the filler count uses the symbol-level error count, not
  # sum(seq); kept as is -- confirm whether nprint - sum(seq) was intended.
  for b in range(nprint - errors):
    print_example(b)
  return errors, total, sum(seq)
238+
239+
240+
def safe_exp(x):
  """Return exp(x), capped at 10000; exp is not evaluated unless x < 100."""
  cap = 10000
  if x < 100:
    value = math.exp(x)
    if value <= cap:
      return value
  return cap

0 commit comments

Comments
 (0)