| 1 | +""" |
| 2 | +Title: Text generation with a miniature GPT |
| 3 | +Author: [Apoorv Nandan](https://twitter.com/NandanApoorv) |
| 4 | +Date created: 2020/05/29 |
| 5 | +Last modified: 2020/05/29 |
| 6 | +Description: Implement a miniature version of GPT and train it to generate text. |
| 7 | +Accelerator: GPU |
| 8 | +""" |
| 9 | +""" |
| 10 | +## Introduction |
| 11 | +
|
| 12 | +This example demonstrates how to implement an autoregressive language model |
| 13 | +using a miniature version of the GPT model. |
| 14 | +The model consists of a single Transformer block with causal masking |
| 15 | +in its attention layer. |
| 16 | +We use the text from the IMDB sentiment classification dataset for training |
| 17 | +and generate new movie reviews for a given prompt. |
| 18 | +When using this script with your own dataset, make sure it has at least |
| 19 | +1 million words. |
| 20 | +
|
| 21 | +This example should be run with `tf-nightly>=2.3.0-dev20200531` or |
| 22 | +with TensorFlow 2.3 or higher. |
| 23 | +
|
| 24 | +**References:** |
| 25 | +
|
| 26 | +- [GPT](https://www.semanticscholar.org/paper/Improving-Language-Understanding-by-Generative-Radford/cd18800a0fe0b668a1cc19f2ec95b5003d0a5035) |
| 27 | +- [GPT-2](https://www.semanticscholar.org/paper/Language-Models-are-Unsupervised-Multitask-Learners-Radford-Wu/9405cc0d6169988371b2755e573cc28650d14dfe) |
| 28 | +- [GPT-3](https://arxiv.org/abs/2005.14165) |
| 29 | +""" |
| 30 | +""" |
| 31 | +## Setup |
| 32 | +""" |
| 33 | +# We set the backend to TensorFlow. The code works with |
| 34 | +# both `tensorflow` and `torch`. It does not work with JAX |
| 35 | +# due to the behavior of `jax.numpy.tile` in a jit scope |
| 36 | +# (used in `causal_attention_mask()`: `tile` in JAX does |
| 37 | +# not support a dynamic `reps` argument. |
| 38 | +# You can make the code work in JAX by wrapping the |
| 39 | +# inside of the `causal_attention_mask` function in |
| 40 | +# a decorator to prevent jit compilation: |
| 41 | +# `with jax.ensure_compile_time_eval():`. |
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras_core as keras
from keras_core import layers
from keras_core import ops
from keras_core.layers import TextVectorization
import numpy as np
import string
import random
import tensorflow
import tensorflow.data as tf_data
import tensorflow.strings as tf_strings


"""
## Implement a Transformer block as a layer
"""


def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    """
    Mask the upper half of the dot product matrix in self attention.
    This prevents flow of information from future tokens to the current token.
    1's in the lower triangle, counting from the lower right corner.
    """
    i = ops.arange(n_dest)[:, None]
    j = ops.arange(n_src)
    m = i >= j - n_src + n_dest
    mask = ops.cast(m, dtype)
    mask = ops.reshape(mask, [1, n_dest, n_src])
    mult = ops.concatenate(
        [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])], 0
    )
    return ops.tile(mask, mult)
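

# For intuition: with equal source and destination lengths the mask is simply
# lower-triangular. For example, `causal_attention_mask(1, 3, 3, "bool")`
# evaluates to a tensor of shape `(1, 3, 3)`:
#
#     [[[ True, False, False],
#       [ True,  True, False],
#       [ True,  True,  True]]]
#
# i.e. position i may only attend to positions <= i.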


class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads, embed_dim)
        self.ffn = keras.Sequential(
            [
                layers.Dense(ff_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs):
        input_shape = ops.shape(inputs)
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, "bool")
        attention_output = self.att(inputs, inputs, attention_mask=causal_mask)
        attention_output = self.dropout1(attention_output)
        out1 = self.layernorm1(inputs + attention_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        return self.layernorm2(out1 + ffn_output)
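
"""
The block maps a sequence of embeddings to a sequence of the same shape. As a
quick sanity check you could run something like the snippet below (the `block`
and `out` names are only used for this illustration):

```python
block = TransformerBlock(embed_dim=256, num_heads=2, ff_dim=256)
out = block(ops.ones((2, 80, 256)))
print(out.shape)  # (2, 80, 256)
```

Thanks to the causal mask built in `call()`, each output position only attends
to itself and earlier positions.
"""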


"""
## Implement an embedding layer

Create two separate embedding layers: one for tokens and one for token
positions (indices).
"""


class TokenAndPositionEmbedding(layers.Layer):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        maxlen = ops.shape(x)[-1]
        positions = ops.arange(0, maxlen, 1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions
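
# The layer turns a batch of integer token ids of shape `(batch, maxlen)` into
# embeddings of shape `(batch, maxlen, embed_dim)`; the position embeddings
# have shape `(maxlen, embed_dim)` and are broadcast across the batch
# dimension when added.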


"""
## Implement the miniature GPT model
"""
vocab_size = 20000  # Only consider the top 20k words
maxlen = 80  # Max sequence size
embed_dim = 256  # Embedding size for each token
num_heads = 2  # Number of attention heads
feed_forward_dim = 256  # Hidden layer size in feed forward network inside transformer


def create_model():
    inputs = layers.Input(shape=(maxlen,), dtype="int32")
    embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
    x = embedding_layer(inputs)
    transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim)
    x = transformer_block(x)
    outputs = layers.Dense(vocab_size)(x)
    model = keras.Model(inputs=inputs, outputs=[outputs, x])
    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(
        "adam",
        loss=[loss_fn, None],
    )  # No loss is applied to the second output (the Transformer block activations)
    return model
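
"""
With the settings above, the model takes a batch of `maxlen` (80) token ids and
returns two tensors: next-token logits of shape `(batch, 80, vocab_size)` and
the Transformer block activations of shape `(batch, 80, embed_dim)`. For
example (the `demo_model` name is only used for this illustration):

```python
demo_model = create_model()
logits, activations = demo_model(np.zeros((1, maxlen), dtype="int32"))
print(logits.shape, activations.shape)  # (1, 80, 20000) (1, 80, 256)
```
"""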


"""
## Prepare the data for word-level language modelling

Download the IMDB dataset and combine training and validation sets for a text
generation task.
"""

"""shell
curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
tar -xf aclImdb_v1.tar.gz
"""


batch_size = 128

# The dataset contains each review in a separate text file
# The text files are present in four different folders
# Create a list of all files
filenames = []
directories = [
    "aclImdb/train/pos",
    "aclImdb/train/neg",
    "aclImdb/test/pos",
    "aclImdb/test/neg",
]
for dir in directories:
    for f in os.listdir(dir):
        filenames.append(os.path.join(dir, f))

print(f"{len(filenames)} files")

# Create a dataset from text files
random.shuffle(filenames)
text_ds = tf_data.TextLineDataset(filenames)
text_ds = text_ds.shuffle(buffer_size=256)
text_ds = text_ds.batch(batch_size)


def custom_standardization(input_string):
    """Remove html line-break tags and handle punctuation"""
    lowercased = tf_strings.lower(input_string)
    stripped_html = tf_strings.regex_replace(lowercased, "<br />", " ")
    return tf_strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1")
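

# For example, this turns the string "Great movie!<br />Loved it." into
# "great movie ! loved it ." -- lowercased, with the HTML break removed and
# each punctuation mark split off as its own token.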


# Create a vectorization layer and adapt it to the text
vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=vocab_size - 1,
    output_mode="int",
    output_sequence_length=maxlen + 1,
)
vectorize_layer.adapt(text_ds)
vocab = vectorize_layer.get_vocabulary()  # To get words back from token indices
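
# Index 0 of the vocabulary is reserved for padding and index 1 for
# out-of-vocabulary tokens ("[UNK]"); we rely on that below when mapping
# unknown prompt words to index 1.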


def prepare_lm_inputs_labels(text):
    """
    Shift word sequences by 1 position so that the target for position (i) is
    the word at position (i+1). The model will use all words up to position (i)
    to predict the next word.
    """
    text = tensorflow.expand_dims(text, -1)
    tokenized_sentences = vectorize_layer(text)
    x = tokenized_sentences[:, :-1]
    y = tokenized_sentences[:, 1:]
    return x, y
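

# Concretely, if a review tokenizes to the 81 ids [t0, t1, ..., t80], then
# x = [t0, ..., t79] and y = [t1, ..., t80], so y[i] is the "next word"
# target for the prefix x[:i + 1].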


text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf_data.AUTOTUNE)
text_ds = text_ds.prefetch(tf_data.AUTOTUNE)


"""
## Implement a Keras callback for generating text
"""


class TextGenerator(keras.callbacks.Callback):
    """A callback to generate text from a trained model.
    1. Feed some starting prompt to the model
    2. Predict probabilities for the next token
    3. Sample the next token and add it to the next input

    Arguments:
        max_tokens: Integer, the number of tokens to be generated after prompt.
        start_tokens: List of integers, the token indices for the starting prompt.
        index_to_word: List of strings, obtained from the TextVectorization layer.
        top_k: Integer, sample from the `top_k` token predictions.
        print_every: Integer, print after this many epochs.
    """

    def __init__(
        self, max_tokens, start_tokens, index_to_word, top_k=10, print_every=1
    ):
        self.max_tokens = max_tokens
        self.start_tokens = start_tokens
        self.index_to_word = index_to_word
        self.print_every = print_every
        self.k = top_k

    def sample_from(self, logits):
        # Keep only the top-k logits, renormalize them with a softmax, and
        # sample one of the corresponding token indices.
        logits, indices = ops.top_k(logits, k=self.k, sorted=True)
        indices = np.asarray(indices).astype("int32")
        preds = keras.activations.softmax(ops.expand_dims(logits, 0))[0]
        preds = np.asarray(preds).astype("float32")
        return np.random.choice(indices, p=preds)

    def detokenize(self, number):
        return self.index_to_word[number]

    def on_epoch_end(self, epoch, logs=None):
        start_tokens = [_ for _ in self.start_tokens]
        if (epoch + 1) % self.print_every != 0:
            return
        num_tokens_generated = 0
        tokens_generated = []
        while num_tokens_generated <= self.max_tokens:
            # Pad (or truncate) the current sequence to exactly `maxlen` tokens
            # and remember which position holds the last real token.
            pad_len = maxlen - len(start_tokens)
            sample_index = len(start_tokens) - 1
            if pad_len < 0:
                x = start_tokens[:maxlen]
                sample_index = maxlen - 1
            elif pad_len > 0:
                x = start_tokens + [0] * pad_len
            else:
                x = start_tokens
            x = np.array([x])
            y, _ = self.model.predict(x)
            # Sample the next token from the logits at the last real position.
            sample_token = self.sample_from(y[0][sample_index])
            tokens_generated.append(sample_token)
            start_tokens.append(sample_token)
            num_tokens_generated = len(tokens_generated)
        txt = " ".join(
            [self.detokenize(_) for _ in self.start_tokens + tokens_generated]
        )
        print(f"generated text:\n{txt}\n")
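

# To see what `sample_from()` does, consider top-k sampling with k=2 over the
# logits [2.0, 1.0, 0.2, ...]: the two largest logits (2.0 and 1.0) are kept
# and softmaxed to probabilities of roughly 0.73 and 0.27, and one of the two
# corresponding token ids is drawn at random with those probabilities.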


# Tokenize starting prompt
word_to_index = {}
for index, word in enumerate(vocab):
    word_to_index[word] = index

start_prompt = "this movie is"
start_tokens = [word_to_index.get(_, 1) for _ in start_prompt.split()]
num_tokens_generated = 40
text_gen_callback = TextGenerator(num_tokens_generated, start_tokens, vocab)


"""
## Train the model

Note: This code should preferably be run on GPU.
"""

model = create_model()

model.fit(text_ds, verbose=2, epochs=25, callbacks=[text_gen_callback])
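
"""
The `TextGenerator` callback prints a sample at the end of every epoch. If you
want more samples from the trained model afterwards, one simple option (a
sketch reusing the callback's own logic) is to call its `on_epoch_end` method
directly; the callback still holds a reference to the model after `fit()`:

```python
text_gen_callback.on_epoch_end(epoch=0)
```

For finer control, you could copy the sampling loop out of `on_epoch_end()`
into a standalone function and vary `start_prompt` or `top_k`.
"""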