XLNet support and overhaul/cleanup of BERT support (#845)
* Rename namespaces to suppress warnings.

* Revert "Rename namespaces to suppress warnings."

This reverts commit 0cf7b23.

* Initial working-ish attempt.

* Intermediate check-in...

* More partial progress.

* Another pass...

* Fix sep/cls handling, cleanup.

* Further cleanup.

* Keyword name fix.

* Another flag fix.

* Pull debug print.

* Line length cleanup.

* WiC fix.

* Two task setup bugs.

* BoolQ typo

* Improved segment handling.

* Delete unused is_pair_task, other cleanup/fixes.

* Fix deleted path from merge.

* Fix cache path.

* Address (spurious?) tokenization warning.

* Select pool_type automatically to match model.

h/t Haokun Liu

* Config updates.

* Path fix

* Fix XLNet UNK handling.

* Internal temporary MNLI alternate.

* Revert "Internal temporary MNLI alternate."

This reverts commit 455792a.

* Add helper fn tests

* Finish merge

* Remove unused argument.

* Possible ReCoRD bug fix

* Cleanup

* Fix merge issues.

* Revert "Remove unused argument."

This reverts commit 96a7c37.

* Assorted responses to Alex's comments.

* Further ReCoRD fix.

* @iftenney's comments.

* Fix/simplify segment logic.

* @W4ngatang's comments

* Cleanup.

* Cleanup

* Fix issues with alternative embeddings_mode settings, max_layer.

* More mix cleanup.

* Masking fix.

* Address (most of) @iftenney's comments

* Tidying.

* Misc cleanup.

* Comment.
sleepinyourhat authored Aug 7, 2019
1 parent 23ad1a7 commit a1e9abf
Showing 38 changed files with 940 additions and 577 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -7,3 +7,4 @@ user_config.sh
.idea
.ipynb_checkpoints/
perluniprops/
.DS_Store
6 changes: 3 additions & 3 deletions README.md
@@ -10,7 +10,7 @@ A few things you might want to know about `jiant`:
- `jiant` is configuration-driven. You can run an enormous variety of experiments by simply writing configuration files. Of course, if you need to add any major new features, you can also easily edit or extend the code.
- `jiant` contains implementations of strong baselines for the [GLUE](https://gluebenchmark.com) and [SuperGLUE](https://super.gluebenchmark.com/) benchmarks, and it's the recommended starting point for work on these benchmarks.
- `jiant` was developed at [the 2018 JSALT Workshop](https://www.clsp.jhu.edu/workshops/18-workshop/) by [the General-Purpose Sentence Representation Learning](https://jsalt18-sentence-repl.github.io/) team and is maintained by [the NYU Machine Learning for Language Lab](https://wp.nyu.edu/ml2/people/), with help from [many outside collaborators](https://github.com/nyu-mll/jiant/graphs/contributors) (especially Google AI Language's [Ian Tenney](https://ai.google/research/people/IanTenney)).
- `jiant` is built on [PyTorch](https://pytorch.org). It also uses many components from [AllenNLP](https://github.com/allenai/allennlp) and the HuggingFace PyTorch [implementations](https://github.com/huggingface/pytorch-pretrained-BERT) of BERT and GPT.
- `jiant` is built on [PyTorch](https://pytorch.org). It also uses many components from [AllenNLP](https://github.com/allenai/allennlp) and the HuggingFace PyTorch [implementations](https://github.com/huggingface/pytorch-transformers) of GPT, BERT, and XLNet.
- The name `jiant` doesn't mean much. The 'j' stands for JSALT. That's all the acronym we have.

## Getting Started
@@ -84,10 +84,10 @@ This package is released under the [MIT License](LICENSE.md). The material in th

## Acknowledgments

- Part of the development of `jiant` took place at the 2018 Frederick Jelinek Memorial Summer Workshop on Speech and Language Technologies, and was supported by Johns Hopkins University with unrestricted gifts from Amazon, Facebook, Google, Microsoft and Mitsubishi Electric Research Laboratories.
- This work was made possible in part by a donation to NYU from Eric and Wendy Schmidt made
by recommendation of the Schmidt Futures program.
- We gratefully acknowledge the support of NVIDIA Corporation with the donation of a Titan V GPU used at NYU in this work.
- Developer Alex Wang is supported by the National Science Foundation Graduate Research Fellowship Program under Grant
No. DGE 1342536. Any opinions, findings, and conclusions or recommendations expressed in this
material are those of the author(s) and do not necessarily reflect the views of the National Science
14 changes: 10 additions & 4 deletions cola_inference.py
@@ -47,10 +47,10 @@

from jiant.models import build_model
from jiant.preprocess import build_indexers, build_tasks
from jiant.tasks.tasks import process_sentence, sentence_to_text_field
from jiant.tasks.tasks import tokenize_and_truncate, sentence_to_text_field
from jiant.utils import config
from jiant.utils.data_loaders import load_tsv
from jiant.utils.utils import check_arg_name, load_model_state
from jiant.utils.utils import check_arg_name, load_model_state, select_pool_type

log.basicConfig(format="%(asctime)s: %(message)s", datefmt="%m/%d %I:%M:%S %p", level=log.INFO)

@@ -121,6 +121,7 @@ def main(cl_arguments):
cl_args = handle_arguments(cl_arguments)
args = config.params_from_file(cl_args.config_file, cl_args.overrides)
check_arg_name(args)

assert args.target_tasks == "cola", "Currently only supporting CoLA. ({})".format(
args.target_tasks
)
@@ -138,6 +139,11 @@ def main(cl_arguments):
)
args.cuda = -1

if args.tokenizer == "auto":
args.tokenizer = tokenizers.select_tokenizer(args)
if args.pool_type == "auto":
args.pool_type = select_pool_type(args)

# Prepare data #
_, target_tasks, vocab, word_embs = build_tasks(args)
tasks = sorted(set(target_tasks), key=lambda x: x.name)
@@ -185,7 +191,7 @@ def run_repl(model, vocab, indexers, task, args):
if input_string == "QUIT":
break

tokens = process_sentence(
tokens = tokenize_and_truncate(
tokenizer_name=task.tokenizer_name, sent=input_string, max_seq_len=args.max_seq_len
)
print("TOKENS:", " ".join("[{}]".format(tok) for tok in tokens))
@@ -282,7 +288,7 @@ def load_cola_data(input_path, task, input_format, max_seq_len):
with open(input_path, "r") as f_in:
sentences = f_in.readlines()
tokens = [
process_sentence(
tokenize_and_truncate(
tokenizer_name=task.tokenizer_name, sent=sentence, max_seq_len=max_seq_len
)
for sentence in sentences
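For quick reference, a minimal sketch of calling the renamed helper on its own, assuming only the keyword signature visible in the diff above; the model name, sentence, and max_seq_len are illustrative placeholders, not values from this PR:

```python
# Hypothetical standalone usage of tokenize_and_truncate (formerly process_sentence).
# The tokenizer name, sentence, and max_seq_len are illustrative, not from this PR.
from jiant.tasks.tasks import tokenize_and_truncate

tokens = tokenize_and_truncate(
    tokenizer_name="bert-base-uncased",
    sent="Who framed Roger Rabbit?",
    max_seq_len=40,
)
print(" ".join("[{}]".format(tok) for tok in tokens))
```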
2 changes: 0 additions & 2 deletions config/ccg_bert.conf
@@ -4,7 +4,6 @@ include "defaults.conf"
pretrain_tasks = ccg
target_tasks = ccg
input_module = bert-base-uncased
tokenizer = ${input_module}
do_target_task_training = 0
transfer_paradigm = finetune

@@ -16,7 +15,6 @@ skip_embs = 1

// BERT-specific setup
classifier = log_reg // following BERT paper
pool_type = first

dropout = 0.1 // following BERT paper
optimizer = bert_adam
2 changes: 0 additions & 2 deletions config/copa_bert.conf
@@ -19,10 +19,8 @@ do_full_eval = 1

// Typical BERT base setup
input_module = bert-base-uncased
tokenizer = bert-base-uncased
transfer_paradigm = finetune
classifier = log_reg
pool_type = first
optimizer = bert_adam
lr = 0.00001
sent_enc = none
116 changes: 60 additions & 56 deletions config/defaults.conf
@@ -90,7 +90,7 @@ transfer_paradigm = "frozen" // How to use pretrained model parameters during ta
// "frozen" will train the downstream models on fixed
// representations from the encoder model.
// "finetune" will update the parameters of the encoders models as
// well as the downstream models.
// well as the downstream models. (This disables d_proj.)
load_target_train_checkpoint = none // If not "none", load the specified model_state checkpoint
// file when starting do_target_task_training.
// Supports * wildcards.
@@ -140,6 +140,9 @@ batch_size = 32 // Training batch size.
optimizer = adam // Optimizer. All valid AllenNLP options are available, including 'sgd'.
// Use 'bert_adam' for reproducing BERT experiments.
// 'adam' uses the newer AMSGrad variant.
// Warning: bert_adam is designed for cases where the number of epochs is known
// in advance, so it may not behave reasonably unless max_epochs is set to a
// reasonable positive value.
lr = 0.0001 // Initial learning rate.
min_lr = 0.000001 // Minimum learning rate. Training will stop when our explicit LR decay lowers
// the LR below this point or if any other stopping criterion applies.
@@ -221,42 +224,41 @@ max_targ_word_v_size = 20000 // Maximum target word vocab size for seq2seq task

// Input Handling //

input_module = "" // The word embedding or contextual word representation layer.
// Currently supported options:
// - scratch: Word embeddings trained from scratch.
// - glove: Leaded GloVe word embeddings. Typically used with
// tokenizer = MosesTokenizer. Note that this is not quite identical to the
// Stanford tokenizer used to train GloVe.
// - fastText: Leaded GloVe word embeddings. Use with
// tokenizer = MosesTokenizer.
// - elmo: AllenNLP's ELMo contextualized word vector model hidden states. Use
// with tokenizer = MosesTokenizer.
// - elmo-chars-only: The dynamic CNN-based word embedding layer of AllenNLP's
// ELMo, but not ELMo's LSTM layer hidden states. Use with
// tokenizer = MosesTokenizer.
// - bert-base-uncased, etc.: Any BERT model specifier that is valid for
// pytorch-pretrained-bert may be specified here. Use with
// tokenizer = ${input_module}
// We support the newer bert-large-uncased-whole-word-masking and
// bert-large-cased-whole-word-masking cased models, but they require
// the git development version of pytorch-pretrained-bert. To use these
// models, follow the instructions under 'From source' here:
// https://github.com/huggingface/pytorch-pretrained-BERT
// Most of these options use MosesTokenizer tokenization, but
// BERT and GPT need more specific tokenization (tokenizer config
// parameter should be equal to input_module for BERT, and should be
// equal to 'OpenAI.BPE' if input_module = gpt).
// For ELMo, BERT, and GPT, there are additional config parameters below.

tokenizer = "MosesTokenizer" // The name of the tokenizer, passed to the Task constructor for
// appropriate handling during data loading. Currently supported
// options:
// - "": Split the input data on whitespace.
// - MosesTokenizer: Our standard word tokenizer. (Support for
// other NLTK tokenizers is pending.)
// - bert-uncased-base, etc.: Use the tokenizer supplied with
// pytorch-pretrained-bert that corresponds to that BERT model.
// - OpenAI.BPE: The tokenizer supplied with OpenAI GPT.
input_module = "" // The word embedding or contextual word representation layer.
// Currently supported options:
// - scratch: Word embeddings trained from scratch.
// - glove: Loaded GloVe word embeddings. Typically used with
// tokenizer = MosesTokenizer. Note that this is not quite identical to
// the Stanford tokenizer used to train GloVe.
// - fastText: Loaded fastText word embeddings. Use with
// tokenizer = MosesTokenizer.
// - elmo: AllenNLP's ELMo contextualized word vector model hidden states. Use
// with tokenizer = MosesTokenizer.
// - elmo-chars-only: The dynamic CNN-based word embedding layer of AllenNLP's
// ELMo, but not ELMo's LSTM layer hidden states. Use with
// tokenizer = MosesTokenizer.
// - gpt: The OpenAI GPT language model encoder.
// Use with tokenizer = OpenAI.BPE.
// - bert-base-uncased, etc.: Any BERT model specifier that is valid for
// pytorch-pretrained-bert may be specified here. Use with
// tokenizer = ${input_module}
// We support the newer bert-large-uncased-whole-word-masking and
// bert-large-cased-whole-word-masking cased models, but they require
// the git development version of pytorch-pretrained-bert. To use these
// models, follow the instructions under 'From source' here:
// https://github.com/huggingface/pytorch-pretrained-BERT

tokenizer = auto // The name of the tokenizer, passed to the Task constructor for
// appropriate handling during data loading. Currently supported
// options:
// - auto: Select the tokenizer that matches the model specified in
// input_module above. Usually a safe default.
// - "": Split the input data on whitespace.
// - MosesTokenizer: Our standard word tokenizer. (Support for
// other NLTK tokenizers is pending.)
// - bert-base-uncased, etc.: Use the tokenizer supplied with
// pytorch-pretrained-bert that corresponds to that BERT model.
// - OpenAI.BPE: The tokenizer supplied with OpenAI GPT.

word_embs_file = ${WORD_EMBS_FILE} // Path to embeddings file, used with glove and fastText.
d_word = 300 // Dimension of word embeddings, used with scratch, glove, or fastText.
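To make the new `auto` default concrete, here is a minimal sketch of the selection rule it implies; the actual helper (`select_tokenizer`, called in `cola_inference.py` above) lives in jiant's tokenizer utilities and may apply additional checks, so treat this as an illustration only:

```python
# Illustrative only: a simplified version of the "tokenizer = auto" rule described above.
# jiant's real select_tokenizer helper may handle more model families and edge cases.
def select_tokenizer(args):
    if args.input_module.startswith(("bert-", "xlnet-")):
        # Pretrained transformer models use the tokenizer that ships with them.
        return args.input_module
    if args.input_module == "gpt":
        return "OpenAI.BPE"
    # GloVe, fastText, ELMo, and scratch embeddings default to Moses tokenization.
    return "MosesTokenizer"
```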
@@ -282,22 +284,21 @@ openai_embeddings_mode = "none" // How to handle the embedding layer of the Ope
// "mix" uses ELMo-style scalar mixing (with
// learned weights) across all layers.

bert_embeddings_mode = "none" // How to handle the embedding layer of the
// BERT model:
// "none" or "top" returns only top-layer activation,
// "cat" returns top-layer concatenated with
// lexical layer,
// "only" returns only lexical layer,
// "mix" uses ELMo-style scalar mixing (with
// learned weights) across all layers.
bert_max_layer = -1 // Maximum layer to return from BERT encoder. Layer 0 is
// wordpiece embeddings.
// bert_embeddings_mode will behave as if the BERT encoder
// is truncated at this layer, so 'top' will return this
// layer, and 'mix' will return a mix of all layers up to
// and including this layer.
// Set to -1 to use all layers.
// Used for probing experiments.
pytorch_transformers_output_mode = "none" // How to handle the embedding layer of the
// BERT/XLNet model:
// "none" or "top" returns only top-layer activation,
// "cat" returns top-layer concatenated with
// lexical layer,
// "only" returns only lexical layer,
// "mix" uses ELMo-style scalar mixing (with learned
// weights) across all layers.
pytorch_transformers_max_layer = -1 // Maximum layer to return from BERT etc. encoder. Layer 0 is
// wordpiece embeddings. pytorch_transformers_output_mode
// will behave as if the encoder is truncated at this layer, so 'top'
// will return this layer, and 'mix' will return a mix of all
// layers up to and including this layer.
// Set to -1 to use all layers.
// Used for probing experiments.

force_include_wsj_vocabulary = 0 // Set if using PTB parsing (grammar induction) task. Makes sure
// to include WSJ vocabulary.
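The "mix" option above refers to ELMo-style scalar mixing. As a rough illustration of that computation (learned softmax weights over the layers plus a global scale), written as a standalone module rather than jiant's actual scalar-mix implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScalarMixSketch(nn.Module):
    """Illustrative ELMo-style scalar mix over encoder layers (the "mix" mode above)."""

    def __init__(self, num_layers):
        super().__init__()
        self.weights = nn.Parameter(torch.zeros(num_layers))  # one learned scalar per layer
        self.gamma = nn.Parameter(torch.ones(1))               # learned global scale

    def forward(self, layer_activations):
        # layer_activations: list of [batch, seq_len, d_model] tensors, one per layer.
        norm_weights = F.softmax(self.weights, dim=0)
        mixed = sum(w * h for w, h in zip(norm_weights, layer_activations))
        return self.gamma * mixed
```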
@@ -320,7 +321,7 @@ n_layers_enc = 2 // Number of layers for a 'rnn' sent_enc.
skip_embs = 1 // If true, concatenate the sent_enc's input (ELMo/GPT/BERT output or
// embeddings) with the sent_enc's output.
sep_embs_for_skip = 0 // Whether the skip embedding uses the same embedder object as the original
//embedding (before skip).
// embedding (before skip).
// Only makes a difference if we are using ELMo weights, where it allows
// the four tuned ELMo scalars to vary separately for each target task.
n_layers_highway = 0 // Number of highway layers between the embedding layer and the sent_enc layer. [Deprecated.]
@@ -364,8 +365,11 @@ pair_attn = 1 // If true, use attn in sentence-pair classification/regression t
d_hid_attn = 512 // Post-attention LSTM state size.
shared_pair_attn = 0 // If true, share pair_attn parameters across all tasks that use it.
d_proj = 512 // Size of task-specific linear projection applied before pooling.
pool_type = "max" // Type of pooling to reduce sequences of vectors into a single vector.
// Options: "max", "mean", "first", "final"
// Disabled when fine-tuning pytorch_transformers models.
pool_type = "auto" // Type of pooling to reduce sequences of vectors into a single vector.
// Options: "auto", "max", "mean", "first", "final"
// "auto" uses "first" for plain BERT (with no sent_enc), "final" for plain
// XLNet and GPT, and "max" in all other settings.
span_classifier_loss_fn = "softmax" // Classifier loss function. Used only in some tasks (notably
// span-related tasks), not mlp/fancy_mlp. Currently supports
// sigmoid and softmax.
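A minimal sketch of the "auto" rule documented in the pool_type comment above, written as a helper along the lines of `select_pool_type` imported in `cola_inference.py`; the real implementation may consult additional settings:

```python
# Illustrative only: simplified version of the "pool_type = auto" rule documented above.
def select_pool_type(args):
    if args.sent_enc == "none":
        if args.input_module.startswith("bert-"):
            return "first"  # plain BERT summarizes the sequence at the first ([CLS]) position
        if args.input_module.startswith("xlnet-") or args.input_module == "gpt":
            return "final"  # plain XLNet and GPT summarize at the final position
    return "max"
```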
2 changes: 0 additions & 2 deletions config/examples/copa_bert.conf
@@ -19,10 +19,8 @@ do_full_eval = 1

// Typical BERT base setup
input_module = bert-base-uncased
tokenizer = bert-base-uncased
transfer_paradigm = finetune
classifier = log_reg
pool_type = first
optimizer = bert_adam
lr = 0.00001
sent_enc = none
4 changes: 1 addition & 3 deletions config/examples/stilts_example.conf
@@ -18,8 +18,7 @@ batch_size = 24
write_preds = "val,test"

//BERT-specific parameters
bert_embeddings_mode = "top"
pool_type = "first"
pytorch_transformers_output_mode = "top"
sep_embs_for_skip = 1
sent_enc = "none"
classifier = log_reg // following BERT paper
Expand All @@ -34,6 +33,5 @@ patience = 20
max_vals = 10000
transfer_paradigm = "finetune"

tokenizer = "bert-base-uncased"
input_module = "bert-base-uncased"

5 changes: 2 additions & 3 deletions config/superglue-bert.conf
@@ -7,11 +7,10 @@ exp_name = "bert-large-cased"
// Data and preprocessing settings
max_seq_len = 256 // Mainly needed for MultiRC, to avoid over-truncating
// But not 512 as that is really hard to fit in memory.
tokenizer = "bert-large-cased"

// Model settings
input_module = "bert-large-cased"
bert_embeddings_mode = "top"
pool_type = "first"
pytorch_transformers_output_mode = "top"
pair_attn = 0 // shouldn't be needed but JIC
s2s = {
attention = none
6 changes: 6 additions & 0 deletions environment.yml
@@ -32,3 +32,9 @@ dependencies:
- ftfy==5.4.1
- spacy==2.0.11

# Warning: jiant currently depends on *both* pytorch_pretrained_bert > 0.6 _and_
# pytorch_transformers > 1.0. These are the same package, though the name changed between
# these two versions. AllenNLP requires 0.6 to support the BertAdam optimizer, and jiant
# directly requires 1.0 to support XLNet and WWM-BERT.
# This AllenNLP issue is relevant: https://github.com/allenai/allennlp/issues/3067
- pytorch-transformers==1.0.0
4 changes: 2 additions & 2 deletions gcp/config/jiant_paths.sh
@@ -12,8 +12,8 @@ export JIANT_PROJECT_PREFIX="$HOME/exp"

# pre-downloaded ELMo models
export ELMO_SRC_DIR="/nfs/jiant/share/elmo"
# cache for BERT models
export PYTORCH_PRETRAINED_BERT_CACHE="/nfs/jiant/share/bert_cache"
# cache for BERT etc. models
export PYTORCH_PRETRAINED_BERT_CACHE="/nfs/jiant/share/pytorch_transformers_cache"
# word embeddings
export WORD_EMBS_FILE="/nfs/jiant/share/wiki-news-300d-1M.vec"

1 change: 0 additions & 1 deletion gcp/kubernetes/run_batch.sh
@@ -99,4 +99,3 @@ jsonnet -S -o "${YAML_FILE}" \
##
# Create the Kubernetes pod; this will actually launch the job.
kubectl ${KUBECTL_MODE} -f "${YAML_FILE}"
