Commit 3bc3544

Convert "text generation with miniature gpt" to Keras Core (#676)
* copy text_generation_with_miniature_gpt.py from keras io
* make it work with keras_core
* add notes for JAX backend
* mv text_generation_with_miniature_gpt to backend agnostic folder
1 parent db9d734 commit 3bc3544

1 file changed: +314 -0 lines changed
@@ -0,0 +1,314 @@
"""
Title: Text generation with a miniature GPT
Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)
Date created: 2020/05/29
Last modified: 2020/05/29
Description: Implement a miniature version of GPT and train it to generate text.
Accelerator: GPU
"""
"""
## Introduction

This example demonstrates how to implement an autoregressive language model
using a miniature version of the GPT model.
The model consists of a single Transformer block with causal masking
in its attention layer.
We use the text from the IMDB sentiment classification dataset for training
and generate new movie reviews for a given prompt.
When using this script with your own dataset, make sure it has at least
1 million words.

This example should be run with `tf-nightly>=2.3.0-dev20200531` or
with TensorFlow 2.3 or higher.

**References:**

- [GPT](https://www.semanticscholar.org/paper/Improving-Language-Understanding-by-Generative-Radford/cd18800a0fe0b668a1cc19f2ec95b5003d0a5035)
- [GPT-2](https://www.semanticscholar.org/paper/Language-Models-are-Unsupervised-Multitask-Learners-Radford-Wu/9405cc0d6169988371b2755e573cc28650d14dfe)
- [GPT-3](https://arxiv.org/abs/2005.14165)
"""
"""
## Setup
"""
# We set the backend to TensorFlow. The code works with
# both `tensorflow` and `torch`. It does not work with JAX
# due to the behavior of `jax.numpy.tile` in a jit scope
# (used in `causal_attention_mask()`): `tile` in JAX does
# not support a dynamic `reps` argument.
# You can make the code work in JAX by wrapping the
# body of the `causal_attention_mask` function in a
# `with jax.ensure_compile_time_eval():` block, which
# keeps the mask computation out of jit compilation.
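#
# A minimal sketch of that workaround (untested here; assumes a JAX install):
#
#     import jax
#
#     def causal_attention_mask(batch_size, n_dest, n_src, dtype):
#         with jax.ensure_compile_time_eval():
#             ...  # build and tile the mask eagerly, at trace time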
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras_core as keras
from keras_core import layers
from keras_core import ops
from keras_core.layers import TextVectorization
import numpy as np
import string
import random
import tensorflow as tf
import tensorflow.data as tf_data
import tensorflow.strings as tf_strings


"""
## Implement a Transformer block as a layer
"""


def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    """
    Mask the upper half of the dot product matrix in self-attention.
    This prevents the flow of information from future tokens to the current
    token: 1's in the lower triangle, counting from the lower right corner.
    """
    i = ops.arange(n_dest)[:, None]
    j = ops.arange(n_src)
    m = i >= j - n_src + n_dest
    mask = ops.cast(m, dtype)
    mask = ops.reshape(mask, [1, n_dest, n_src])
    mult = ops.concatenate(
        [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])], 0
    )
    return ops.tile(mask, mult)
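
# For example, causal_attention_mask(1, 3, 3, "int32") produces
# [[[1, 0, 0],
#   [1, 1, 0],
#   [1, 1, 1]]]
# so each position can only attend to itself and earlier positions.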


class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads, embed_dim)
        self.ffn = keras.Sequential(
            [
                layers.Dense(ff_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs):
        input_shape = ops.shape(inputs)
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, "bool")
        attention_output = self.att(inputs, inputs, attention_mask=causal_mask)
        attention_output = self.dropout1(attention_output)
        out1 = self.layernorm1(inputs + attention_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        return self.layernorm2(out1 + ffn_output)
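
# Illustrative sanity check (assumes the hyperparameters defined below): the
# block is shape-preserving.
#
#     block = TransformerBlock(embed_dim=256, num_heads=2, ff_dim=256)
#     out = block(ops.ones((2, 80, 256)))  # out.shape == (2, 80, 256)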


"""
## Implement an embedding layer

Create two separate embedding layers: one for tokens and one for token index
(positions).
"""


class TokenAndPositionEmbedding(layers.Layer):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        maxlen = ops.shape(x)[-1]
        positions = ops.arange(0, maxlen, 1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions
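
# Illustrative usage: integer token ids of shape (batch, maxlen) come out as
# embeddings of shape (batch, maxlen, embed_dim), with the position embeddings
# broadcast across the batch.
#
#     emb = TokenAndPositionEmbedding(maxlen=80, vocab_size=20000, embed_dim=256)
#     emb(ops.zeros((2, 80), dtype="int32"))  # shape (2, 80, 256)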


"""
## Implement the miniature GPT model
"""
vocab_size = 20000  # Only consider the top 20k words
maxlen = 80  # Max sequence size
embed_dim = 256  # Embedding size for each token
num_heads = 2  # Number of attention heads
feed_forward_dim = 256  # Hidden layer size in feed forward network inside transformer


def create_model():
    inputs = layers.Input(shape=(maxlen,), dtype="int32")
    embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
    x = embedding_layer(inputs)
    transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim)
    x = transformer_block(x)
    outputs = layers.Dense(vocab_size)(x)
    model = keras.Model(inputs=inputs, outputs=[outputs, x])
    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(
        "adam",
        loss=[loss_fn, None],
    )  # No loss is computed on the second output (the transformer activations)
    return model
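
# Illustrative: the compiled model maps (batch, maxlen) token ids to two
# outputs, (batch, maxlen, vocab_size) next-token logits and the
# (batch, maxlen, embed_dim) transformer activations; only the logits
# contribute to the loss.
#
#     logits, activations = create_model()(ops.zeros((2, maxlen), dtype="int32"))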


"""
## Prepare the data for word-level language modelling

Download the IMDB dataset and combine the training and test sets for a text
generation task.
"""

"""shell
164+
curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
165+
tar -xf aclImdb_v1.tar.gz
166+
"""


batch_size = 128

# The dataset contains each review in a separate text file
# The text files are present in four different folders
# Create a list of all files
filenames = []
directories = [
    "aclImdb/train/pos",
    "aclImdb/train/neg",
    "aclImdb/test/pos",
    "aclImdb/test/neg",
]
for directory in directories:
    for f in os.listdir(directory):
        filenames.append(os.path.join(directory, f))

print(f"{len(filenames)} files")

# Create a dataset from text files
random.shuffle(filenames)
text_ds = tf_data.TextLineDataset(filenames)
text_ds = text_ds.shuffle(buffer_size=256)
text_ds = text_ds.batch(batch_size)


def custom_standardization(input_string):
    """Remove HTML line-break tags and handle punctuation"""
    lowercased = tf_strings.lower(input_string)
    stripped_html = tf_strings.regex_replace(lowercased, "<br />", " ")
    return tf_strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1")
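
# For example, "Great<br />movie!" becomes "great movie !": lowercased, the
# HTML break replaced by a space, and the punctuation split into its own token.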


# Create a vectorization layer and adapt it to the text
vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size - 1,
    output_mode="int",
    output_sequence_length=maxlen + 1,
)
vectorize_layer.adapt(text_ds)
vocab = vectorize_layer.get_vocabulary()  # To get words back from token indices
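# Note: with the default settings, index 0 of the vocabulary is reserved for
# padding ("") and index 1 for out-of-vocabulary tokens ("[UNK]").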


def prepare_lm_inputs_labels(text):
    """
    Shift word sequences by 1 position so that the target for position (i) is
    the word at position (i+1). The model will use all words up to position (i)
    to predict the next word.
    """
    text = tf.expand_dims(text, -1)
    tokenized_sentences = vectorize_layer(text)
    x = tokenized_sentences[:, :-1]
    y = tokenized_sentences[:, 1:]
    return x, y
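
# Illustrative: a review tokenized to [t0, t1, ..., t80] yields
# x = [t0, ..., t79] and y = [t1, ..., t80], so y[i] is the next-word target
# for the prefix ending at x[i].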


text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf_data.AUTOTUNE)
text_ds = text_ds.prefetch(tf_data.AUTOTUNE)


"""
## Implement a Keras callback for generating text
"""


class TextGenerator(keras.callbacks.Callback):
    """A callback to generate text from a trained model.
    1. Feed some starting prompt to the model
    2. Predict probabilities for the next token
    3. Sample the next token and add it to the next input

    Arguments:
        max_tokens: Integer, the number of tokens to be generated after the prompt.
        start_tokens: List of integers, the token indices for the starting prompt.
        index_to_word: List of strings, obtained from the TextVectorization layer.
        top_k: Integer, sample from the `top_k` token predictions.
        print_every: Integer, print after this many epochs.
    """

    def __init__(
        self, max_tokens, start_tokens, index_to_word, top_k=10, print_every=1
    ):
        self.max_tokens = max_tokens
        self.start_tokens = start_tokens
        self.index_to_word = index_to_word
        self.print_every = print_every
        self.k = top_k

    def sample_from(self, logits):
        logits, indices = ops.top_k(logits, k=self.k, sorted=True)
        indices = np.asarray(indices).astype("int32")
        preds = keras.activations.softmax(ops.expand_dims(logits, 0))[0]
        preds = np.asarray(preds).astype("float32")
        return np.random.choice(indices, p=preds)
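
    # For example, with top_k=10 only the ten highest-scoring logits are kept;
    # they are renormalized with a softmax and the next token is drawn from that
    # reduced distribution, avoiding samples from the long tail.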

    def detokenize(self, number):
        return self.index_to_word[number]

    def on_epoch_end(self, epoch, logs=None):
        start_tokens = list(self.start_tokens)
        if (epoch + 1) % self.print_every != 0:
            return
        num_tokens_generated = 0
        tokens_generated = []
        while num_tokens_generated <= self.max_tokens:
            pad_len = maxlen - len(start_tokens)
            sample_index = len(start_tokens) - 1
            if pad_len < 0:
                # Keep the most recent `maxlen` tokens as context
                x = start_tokens[-maxlen:]
                sample_index = maxlen - 1
            elif pad_len > 0:
                x = start_tokens + [0] * pad_len
            else:
                x = start_tokens
            x = np.array([x])
            y, _ = self.model.predict(x)
            sample_token = self.sample_from(y[0][sample_index])
            tokens_generated.append(sample_token)
            start_tokens.append(sample_token)
            num_tokens_generated = len(tokens_generated)
        txt = " ".join(
            [self.detokenize(token) for token in self.start_tokens + tokens_generated]
        )
        print(f"generated text:\n{txt}\n")


# Tokenize starting prompt
word_to_index = {}
for index, word in enumerate(vocab):
    word_to_index[word] = index

start_prompt = "this movie is"
start_tokens = [word_to_index.get(word, 1) for word in start_prompt.split()]
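# Prompt words missing from the vocabulary fall back to index 1, the
# [UNK] (out-of-vocabulary) token.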
num_tokens_generated = 40
text_gen_callback = TextGenerator(num_tokens_generated, start_tokens, vocab)


"""
## Train the model

Note: This code should preferably be run on a GPU.
"""

model = create_model()

model.fit(text_ds, verbose=2, epochs=25, callbacks=[text_gen_callback])
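
# After training, text can also be generated outside the callback. A minimal
# sketch reusing `model`, `maxlen`, `start_tokens`, and `vocab` from above
# (greedy argmax decoding instead of top-k sampling, for simplicity):
#
#     tokens = list(start_tokens)
#     for _ in range(40):
#         pad_len = max(0, maxlen - len(tokens))
#         x = np.array([tokens[-maxlen:] + [0] * pad_len])
#         logits, _ = model.predict(x, verbose=0)
#         next_id = int(np.argmax(logits[0][min(len(tokens), maxlen) - 1]))
#         tokens.append(next_id)
#     print(" ".join(vocab[t] for t in tokens))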
