Commit 76f567d

knathanieltucker authored and martinwicke committed
add the namignizer model (tensorflow#147)
1 parent dc7791d commit 76f567d

5 files changed

+602
-0
lines changed

namignizer/.gitignore

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
# Remove the pyc files
*.pyc

# Ignore the model and the data
model/
data/

namignizer/README.md

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
# Namignizer

Use a variation of the [PTB](https://www.tensorflow.org/versions/r0.8/tutorials/recurrent/index.html#recurrent-neural-networks) model to recognize and generate names using the [Kaggle Baby Name Database](https://www.kaggle.com/kaggle/us-baby-names).

### API
Namignizer is implemented in TensorFlow r0.8 and uses the Python package `pandas` for some data processing.

#### How to use
Download the data from Kaggle and place it in your data directory (or use the small training data provided). The example data looks like this:

```
Id,Name,Year,Gender,Count
1,Mary,1880,F,7065
2,Anna,1880,F,2604
3,Emma,1880,F,2003
4,Elizabeth,1880,F,1939
5,Minnie,1880,F,1746
6,Margaret,1880,F,1578
7,Ida,1880,F,1472
8,Alice,1880,F,1414
9,Bertha,1880,F,1320
```

But any data with the two columns `Name` and `Count` will work.
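
Only `Name` and `Count` matter because `read_names` in `data_utils.py` lowercases the names and sums the counts per name, ignoring every other column. A minimal standard-library sketch of that aggregation follows; the helper name `aggregate_counts` and the sample rows are hypothetical, and the committed code uses `pandas` instead:

```python
import collections
import csv
import io


def aggregate_counts(csv_text):
    """Sum Count per lowercased Name; any other columns are ignored."""
    totals = collections.Counter()
    for row in csv.DictReader(io.StringIO(csv_text)):
        totals[row["Name"].lower()] += int(row["Count"])
    return dict(totals)


# Hypothetical sample: "Mary" and "mary" collapse into one entry.
sample = "Name,Count\nMary,7065\nAnna,2604\nmary,27\n"
totals = aggregate_counts(sample)
print(totals)
```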

With the data, we can then train the model:

```python
train("data/SmallNames.txt", "model/namignizer", SmallConfig)
```

And you will get the output:

```
Reading Name data in data/SmallNames.txt
Epoch: 1 Learning rate: 1.000
0.090 perplexity: 18.539 speed: 282 lps
...
0.890 perplexity: 1.478 speed: 285 lps
0.990 perplexity: 1.477 speed: 284 lps
Epoch: 13 Train Perplexity: 1.477
```

As a side effect, training writes model checkpoints to the `model` directory. With a checkpoint, you can compute the perplexity the model assigns to any set of names:

```python
namignize(["mary", "ida", "gazorpazorp", "houyhnhnms", "bob"],
          tf.train.latest_checkpoint("model"), SmallConfig)
```
Provide the same config and the same checkpoint directory that you used for training, so that the model you just trained is loaded. You then get a perplexity for each name:

```
Name mary gives us a perplexity of 1.03105580807
Name ida gives us a perplexity of 1.07770049572
Name gazorpazorp gives us a perplexity of 175.940353394
Name houyhnhnms gives us a perplexity of 9.53870773315
Name bob gives us a perplexity of 6.03938627243
```
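
Perplexity can be read as the exponential of the average negative log-probability per character, which is how the PTB tutorial this model is based on reports it. The scoring code is not part of the files shown here, so the sketch below is an assumption about the calculation; the `perplexity` helper and the probabilities are made up for illustration:

```python
import math


def perplexity(char_probs):
    """exp of the mean negative log-probability over the characters."""
    nll = -sum(math.log(p) for p in char_probs)
    return math.exp(nll / len(char_probs))


# A name whose characters the model finds likely scores close to 1;
# a name with very unlikely characters scores far higher.
likely = perplexity([0.9, 0.95, 0.99, 0.97, 0.98])
unlikely = perplexity([0.01, 0.02, 0.005, 0.01])
print(likely, unlikely)
```

Under this reading, a perplexity near 1 (as for `mary`) means the model predicted almost every character with high confidence.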

Finally, you can also generate names using the model:

```python
namignator(tf.train.latest_checkpoint("model"), SmallConfig)
```

Again, provide the same config and the same checkpoint directory, so that the model you just trained is loaded. Each call returns a single generated name. Examples of the output I got when using the provided data:

```
['b', 'e', 'r', 't', 'h', 'a', '`']
['m', 'a', 'r', 'y', '`']
['a', 'n', 'n', 'a', '`']
['m', 'a', 'r', 'y', '`']
['b', 'e', 'r', 't', 'h', 'a', '`']
['a', 'n', 'n', 'a', '`']
['e', 'l', 'i', 'z', 'a', 'b', 'e', 't', 'h', '`']
```

Notice that each name ends with a backtick. This marks the end of the name.
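
The backtick follows from the encoding in `data_utils.py`: letters map to `ord(letter) - 96` and the end-of-name token is 0, so adding 96 back to 0 gives `chr(96)`, which is the backtick. The generation code is not shown in these files, so the `decode` helper below is a hypothetical inverse rather than the commit's implementation:

```python
_EON = 0  # end-of-name token, same value as in data_utils.py


def encode(name):
    """Map 'a'..'z' to 1..26 and append the end-of-name token."""
    return [ord(letter) - 96 for letter in name] + [_EON]


def decode(ids):
    """Hypothetical inverse: add 96 back, so the token 0 becomes '`'."""
    return [chr(i + 96) for i in ids]


print(decode(encode("mary")))
```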

### Contact Info

Feel free to reach out to me at knt(at google) or k.nathaniel.tucker(at gmail)

namignizer/data_utils.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for parsing Kaggle baby names files."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os

import numpy as np
import tensorflow as tf
import pandas as pd

# the default end of name rep will be zero
_EON = 0


def read_names(names_path):
    """read data from downloaded file. See SmallNames.txt for example format
    or go to https://www.kaggle.com/kaggle/us-baby-names for full lists

    Args:
        names_path: path to the csv file similar to the example type
    Returns:
        Dataset: a namedtuple of two elements: deduped names and their associated
            counts. The names contain only 26 chars and are all lower case
    """
    names_data = pd.read_csv(names_path)
    names_data.Name = names_data.Name.str.lower()

    name_data = names_data.groupby(by=["Name"])["Count"].sum()
    name_counts = np.array(name_data.tolist())
    names_deduped = np.array(name_data.index.tolist())

    Dataset = collections.namedtuple('Dataset', ['Name', 'Count'])
    return Dataset(names_deduped, name_counts)


def _letter_to_number(letter):
    """converts letters to numbers between 1 and 26"""
    # ord of lower case 'a' is 97
    return ord(letter) - 96


def namignizer_iterator(names, counts, batch_size, num_steps, epoch_size):
    """Takes a list of names and counts like those output from read_names, and
    makes an iterator yielding a batch_size by num_steps array of random names
    separated by an end of name token. The names are chosen randomly according
    to their counts. The batch may end mid-name

    Args:
        names: a set of lowercase names composed of 26 characters
        counts: a list of the frequency of those names
        batch_size: int
        num_steps: int
        epoch_size: number of batches to yield
    Yields:
        (x, y): a batch_size by num_steps array of ints representing letters, where
            x will be the input and y will be the target
    """
    name_distribution = counts / counts.sum()

    for i in range(epoch_size):
        data = np.zeros(batch_size * num_steps + 1)
        samples = np.random.choice(names, size=batch_size * num_steps // 2,
                                   replace=True, p=name_distribution)

        data_index = 0
        for sample in samples:
            if data_index >= batch_size * num_steps:
                break
            # list() keeps this working on Python 3, where map returns an iterator
            for letter in list(map(_letter_to_number, sample)) + [_EON]:
                if data_index >= batch_size * num_steps:
                    break
                data[data_index] = letter
                data_index += 1

        x = data[:batch_size * num_steps].reshape((batch_size, num_steps))
        y = data[1:batch_size * num_steps + 1].reshape((batch_size, num_steps))

        yield (x, y)


def name_to_batch(name, batch_size, num_steps):
    """ Takes a single name and fills a batch with it

    Args:
        name: lowercase composed of 26 characters
        batch_size: int
        num_steps: int
    Returns:
        x, y: a batch_size by num_steps array of ints representing letters, where
            x will be the input and y will be the target. The array is filled up
            to the length of the string, the rest is filled with zeros
    """
    data = np.zeros(batch_size * num_steps + 1)

    data_index = 0
    # list() keeps this working on Python 3, where map returns an iterator
    for letter in list(map(_letter_to_number, name)) + [_EON]:
        data[data_index] = letter
        data_index += 1

    x = data[:batch_size * num_steps].reshape((batch_size, num_steps))
    y = data[1:batch_size * num_steps + 1].reshape((batch_size, num_steps))

    return x, y
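
Both batch builders in `data_utils.py` share one layout: a flat buffer of `batch_size * num_steps + 1` ids, where the target `y` is the input `x` shifted left by one position. The sketch below mirrors `name_to_batch` with plain lists instead of NumPy; `shifted_batch` is a hypothetical name, and the sketch assumes the encoded name fits in the buffer:

```python
def shifted_batch(name, batch_size, num_steps):
    """Fill a zero buffer with the encoded name, then split it into x and y."""
    total = batch_size * num_steps
    data = [0] * (total + 1)
    ids = [ord(letter) - 96 for letter in name] + [0]  # 0 is the end-of-name token
    data[:len(ids)] = ids  # assumes len(ids) <= total + 1
    flat_x, flat_y = data[:total], data[1:total + 1]
    # Reshape the flat lists into batch_size rows of num_steps columns.
    x = [flat_x[r * num_steps:(r + 1) * num_steps] for r in range(batch_size)]
    y = [flat_y[r * num_steps:(r + 1) * num_steps] for r in range(batch_size)]
    return x, y


x, y = shifted_batch("ida", batch_size=1, num_steps=5)
print(x, y)
```

At every position, `y` holds the character the model should predict after seeing the character in `x`, with the end-of-name token as the target after the last letter.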

namignizer/model.py

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RNN model with embeddings"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


class NamignizerModel(object):
    """The Namignizer model ~ strongly based on PTB"""

    def __init__(self, is_training, config):
        self.batch_size = batch_size = config.batch_size
        self.num_steps = num_steps = config.num_steps
        size = config.hidden_size
        # will always be 27
        vocab_size = config.vocab_size

        # placeholders for inputs
        self._input_data = tf.placeholder(tf.int32, [batch_size, num_steps])
        self._targets = tf.placeholder(tf.int32, [batch_size, num_steps])
        # weights for the loss function
        self._weights = tf.placeholder(tf.float32, [batch_size * num_steps])

        # lstm for our RNN cell (GRU supported too)
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(size, forget_bias=0.0)
        if is_training and config.keep_prob < 1:
            lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
                lstm_cell, output_keep_prob=config.keep_prob)
        cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * config.num_layers)

        self._initial_state = cell.zero_state(batch_size, tf.float32)

        with tf.device("/cpu:0"):
            embedding = tf.get_variable("embedding", [vocab_size, size])
            inputs = tf.nn.embedding_lookup(embedding, self._input_data)

        if is_training and config.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, config.keep_prob)

        outputs = []
        state = self._initial_state
        with tf.variable_scope("RNN"):
            for time_step in range(num_steps):
                if time_step > 0:
                    tf.get_variable_scope().reuse_variables()
                (cell_output, state) = cell(inputs[:, time_step, :], state)
                outputs.append(cell_output)

        output = tf.reshape(tf.concat(1, outputs), [-1, size])
        softmax_w = tf.get_variable("softmax_w", [size, vocab_size])
        softmax_b = tf.get_variable("softmax_b", [vocab_size])
        logits = tf.matmul(output, softmax_w) + softmax_b
        loss = tf.nn.seq2seq.sequence_loss_by_example(
            [logits],
            [tf.reshape(self._targets, [-1])],
            [self._weights])
        self._loss = loss
        self._cost = cost = tf.reduce_sum(loss) / batch_size
        self._final_state = state

        # probabilities of each letter
        self._activations = tf.nn.softmax(logits)

        # ability to save the model
        self.saver = tf.train.Saver(tf.all_variables())

        if not is_training:
            return

        self._lr = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
                                          config.max_grad_norm)
        optimizer = tf.train.GradientDescentOptimizer(self.lr)
        self._train_op = optimizer.apply_gradients(zip(grads, tvars))

    def assign_lr(self, session, lr_value):
        session.run(tf.assign(self.lr, lr_value))

    @property
    def input_data(self):
        return self._input_data

    @property
    def targets(self):
        return self._targets

    @property
    def activations(self):
        return self._activations

    @property
    def weights(self):
        return self._weights

    @property
    def initial_state(self):
        return self._initial_state

    @property
    def cost(self):
        return self._cost

    @property
    def loss(self):
        return self._loss

    @property
    def final_state(self):
        return self._final_state

    @property
    def lr(self):
        return self._lr

    @property
    def train_op(self):
        return self._train_op
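
The training branch of `__init__` passes the gradients through `tf.clip_by_global_norm`, which rescales all of them together so that their combined L2 norm does not exceed `config.max_grad_norm`. A plain-Python sketch of that rule on flat lists follows; the function below is an illustrative stand-in written for this note, not the TensorFlow implementation:

```python
import math


def clip_by_global_norm(grads, clip_norm):
    """Scale every gradient by clip_norm / max(global_norm, clip_norm)."""
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = clip_norm / max(global_norm, clip_norm)
    return [[g * scale for g in grad] for grad in grads], global_norm


# Two toy "gradients" whose combined norm is 5; clipping to 1 scales both by 0.2.
clipped, norm = clip_by_global_norm([[3.0, 0.0], [0.0, 4.0]], clip_norm=1.0)
print(clipped, norm)
```

Because one shared factor is applied to every gradient, the update direction is preserved and only its length is capped; gradients already under the limit pass through unchanged.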
