From f48f59681dbf34d6740cf6c19f45a4b618984f07 Mon Sep 17 00:00:00 2001 From: Yada Pruksachatkun Date: Thu, 23 May 2019 07:30:22 +0700 Subject: [PATCH 1/3] Adding parameters for WSC to superglue configerations (#691) * renamed elmo skip embedding variable * Revert "renamed elmo skip embedding variable" This reverts commit 24ff66e8aeced1f2627919a86f85780de0e1ca66. * adding WSC configerations to match expeirments * adding WSC configerations to match expeirments --- config/superglue-bert.conf | 3 +++ 1 file changed, 3 insertions(+) diff --git a/config/superglue-bert.conf b/config/superglue-bert.conf index 40061d6df..0116261e1 100644 --- a/config/superglue-bert.conf +++ b/config/superglue-bert.conf @@ -42,3 +42,6 @@ do_full_eval = 1 write_preds = "val,test" write_strict_glue_format = 1 +// For WSC +classifier_loss_fn = "softmax" +classifier_span_pooling = "attn" From 6b46b8d199e4abc651bc102e6db3e41a30ab585f Mon Sep 17 00:00:00 2001 From: Yada Pruksachatkun Date: Thu, 23 May 2019 14:59:11 +0700 Subject: [PATCH 2/3] Moving WSC span realignment to loading time (#690) * renamed elmo skip embedding variable * Revert "renamed elmo skip embedding variable" This reverts commit 24ff66e8aeced1f2627919a86f85780de0e1ca66. * making span realignment part of the loading flow * adding clasifier_fn and span pooling to config * reverting unecessary changes * revert * revert * revert reloads * nits * black styling, and changing tests * updated comment * adding to comments * got rid of unecessary space * adding load_data to winograd * black formatting * fix test * remove nit * nit --- config/superglue-bert.conf | 1 - src/tasks/tasks.py | 62 ++++++++++----------- src/trainer.py | 2 +- src/utils/data_loaders.py | 28 ++++++++++ src/utils/retokenize.py | 73 +++++++++++++++++++++++++ tests/test_preprocess_winograd.py | 89 ++++++++++++++++++------------- 6 files changed, 182 insertions(+), 73 deletions(-) diff --git a/config/superglue-bert.conf b/config/superglue-bert.conf index 0116261e1..a3b4db372 100644 --- a/config/superglue-bert.conf +++ b/config/superglue-bert.conf @@ -8,7 +8,6 @@ exp_name = "bert-large-cased" max_seq_len = 256 // Mainly needed for MultiRC, to avoid over-truncating // But not 512 as that is really hard to fit in memory. tokenizer = "bert-large-cased" - // Model settings bert_model_name = "bert-large-cased" bert_embeddings_mode = "top" diff --git a/src/tasks/tasks.py b/src/tasks/tasks.py index 80c006d54..4bdf9368c 100644 --- a/src/tasks/tasks.py +++ b/src/tasks/tasks.py @@ -27,7 +27,13 @@ from ..allennlp_mods.correlation import Correlation from ..allennlp_mods.numeric_field import NumericField from ..utils import utils -from ..utils.data_loaders import get_tag_list, load_diagnostic_tsv, load_tsv, process_sentence +from ..utils.data_loaders import ( + get_tag_list, + load_diagnostic_tsv, + load_span_data, + load_tsv, + process_sentence, +) from ..utils.tokenizers import get_tokenizer from .registry import register_task # global task registry @@ -1995,24 +2001,13 @@ def load_data(self): class SpanClassificationTask(Task): """ - Generic class for span tasks. + Generic class for span tasks. Acts as a classifier, but with multiple targets for each input text. Targets are of the form (span1, span2,..., span_n, label), where the spans are half-open token intervals [i, j). The number of spans is constant across examples. """ - @property - def _tokenizer_suffix(self): - """" - Suffix to make sure we use the correct source files, - based on the given tokenizer. 
- """ - if self.tokenizer_name: - return ".retokenized." + self.tokenizer_name - else: - return "" - def tokenizer_is_supported(self, tokenizer_name): """ Check if the tokenizer is supported for this task. """ # Assume all tokenizers supported; if retokenized data not found @@ -2049,8 +2044,7 @@ def __init__( assert label_file is not None assert files_by_split is not None self._files_by_split = { - split: os.path.join(path, fname) + self._tokenizer_suffix - for split, fname in files_by_split.items() + split: os.path.join(path, fname) for split, fname in files_by_split.items() } self.num_spans = num_spans self.max_seq_len = max_seq_len @@ -2089,15 +2083,6 @@ def _stream_records(self, filename): filename, ) - def load_data(self): - iters_by_split = collections.OrderedDict() - for split, filename in self._files_by_split.items(): - iter = list(self._stream_records(filename)) - iters_by_split[split] = iter - self._iters_by_split = iters_by_split - self.all_labels = list(utils.load_lines(self.label_file)) - self.n_classes = len(self.all_labels) - def get_split_text(self, split: str): """ Get split text as iterable of records. @@ -2139,19 +2124,15 @@ def make_instance(self, record, idx, indexers) -> Type[Instance]: for i in range(self.num_spans): example["span" + str(i + 1) + "s"] = ListField( - [ - self._make_span_field(t["span" + str(i + 1)], text_field, 1) - for t in record["targets"] - ] + [self._make_span_field(record["target"]["span" + str(i + 1)], text_field, 1)] ) - - labels = [utils.wrap_singleton_string(t["label"]) for t in record["targets"]] example["labels"] = ListField( [ MultiLabelField( - label_set, label_namespace=self._label_namespace, skip_indexing=False + [str(record["label"])], + label_namespace=self._label_namespace, + skip_indexing=False, ) - for label_set in labels ] ) return Instance(example) @@ -2536,14 +2517,29 @@ def __init__(self, path, **kw): self._files_by_split = { "train": "train.jsonl", "val": "val.jsonl", - "test": "test_with_labels.jsonl", + "test": "test.jsonl", } self.num_spans = 2 super().__init__( files_by_split=self._files_by_split, label_file="labels.txt", path=path, **kw ) + self.n_classes = 2 self.val_metric = "%s_acc" % self.name + def load_data(self): + iters_by_split = collections.OrderedDict() + for split, filename in self._files_by_split.items(): + if filename.endswith("test.jsonl"): + iters_by_split[split] = load_span_data( + self.tokenizer_name, filename, has_labels=False + ) + else: + iters_by_split[split] = load_span_data(self.tokenizer_name, filename) + self._iters_by_split = iters_by_split + + def get_all_labels(self): + return ["True", "False"] + def update_metrics(self, logits, labels, tagmask=None): logits, labels = logits.detach(), labels.detach() diff --git a/src/trainer.py b/src/trainer.py index 50686b116..b8cd23cc2 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -1080,7 +1080,7 @@ def _save_checkpoint(self, training_state, phase="pretrain", new_best_macro=Fals training_state, os.path.join( self._serialization_dir, - "pretraining_state_{}_epoch_{}{}.th".format(phase, epoch, best_str), + "metric_state_{}_epoch_{}{}.th".format(phase, epoch, best_str), ), ) diff --git a/src/utils/data_loaders.py b/src/utils/data_loaders.py index 964390237..56a53bc38 100644 --- a/src/utils/data_loaders.py +++ b/src/utils/data_loaders.py @@ -11,11 +11,39 @@ from allennlp.data import vocabulary from .tokenizers import get_tokenizer +from .retokenize import realign_spans BERT_CLS_TOK, BERT_SEP_TOK = "[CLS]", "[SEP]" SOS_TOK, EOS_TOK = "", "" +def 
load_span_data(tokenizer_name, file_name, label_fn=None, has_labels=True): + """ + Load a span-related task file in .jsonl format, does re-alignment of spans, and tokenizes the text. + Re-alignment of spans involves transforming the spans so that it matches the text after + tokenization. + For example, given the original text: [Mr., Porter, is, nice] and bert-base-cased tokenization, we get + [Mr, ., Por, ter, is, nice ]. If the original span indices was [0,2], under the new tokenization, + it becomes [0, 3]. + The task file should of be of the following form: + text: str, + label: bool + target: dict that contains the spans + Args: + tokenizer_name: str, + file_name: str, + label_fn: function that expects a row and outputs a transformed row with labels tarnsformed. + Returns: + List of dictionaries of the aligned spans and tokenized text. + """ + rows = pd.read_json(file_name, lines=True) + # realign spans + rows = rows.apply(lambda x: realign_spans(x, tokenizer_name), axis=1) + if has_labels is False: + rows["label"] = False + return list(rows.T.to_dict().values()) + + def load_tsv( tokenizer_name, data_file, diff --git a/src/utils/retokenize.py b/src/utils/retokenize.py index 229a8f1e2..835617c59 100644 --- a/src/utils/retokenize.py +++ b/src/utils/retokenize.py @@ -94,6 +94,79 @@ def _mat_from_spans_sparse(spans: Sequence[Tuple[int, int]], n_chars: int) -> Ma return sparse.csr_matrix((data, (ridxs, cidxs)), shape=(len(spans), n_chars)) +def realign_spans(record, tokenizer_name): + """ + Builds the indices alignment while also tokenizing the input + piece by piece. + Only BERT and Moses tokenization is supported currently. + + Parameters + ----------------------- + record: dict with the below fields + text: str + targets: list of dictionaries + label: bool + span1_index: int, start index of first span + span1_text: str, text of first span + span2_index: int, start index of second span + span2_text: str, text of second span + tokenizer_name: str + + Returns + ------------------------ + record: dict with the below fields: + text: str in tokenized form + targets: dictionary with the below fields + -label: bool + -span_1: (int, int) of token indices + -span1_text: str, the string + -span2: (int, int) of token indices + -span2_text: str, the string + """ + + # find span indices and text + text = record["text"].split() + span1 = record["target"]["span1_index"] + span1_text = record["target"]["span1_text"] + span2 = record["target"]["span2_index"] + span2_text = record["target"]["span2_text"] + + # construct end spans given span text space-tokenized length + span1 = [span1, span1 + len(span1_text.strip().split())] + span2 = [span2, span2 + len(span2_text.strip().split())] + indices = [span1, span2] + + sorted_indices = sorted(indices, key=lambda x: x[0]) + current_tokenization = [] + span_mapping = {} + + # align first span to tokenized text + aligner_fn = get_aligner_fn(tokenizer_name) + _, new_tokens = aligner_fn(" ".join(text[: sorted_indices[0][0]])) + current_tokenization.extend(new_tokens) + new_span1start = len(current_tokenization) + _, span_tokens = aligner_fn(" ".join(text[sorted_indices[0][0] : sorted_indices[0][1]])) + current_tokenization.extend(span_tokens) + new_span1end = len(current_tokenization) + span_mapping[sorted_indices[0][0]] = [new_span1start, new_span1end] + + # re-indexing second span + _, new_tokens = aligner_fn(" ".join(text[sorted_indices[0][1] : sorted_indices[1][0]])) + current_tokenization.extend(new_tokens) + new_span2start = len(current_tokenization) + _, 
span_tokens = aligner_fn(" ".join(text[sorted_indices[1][0] : sorted_indices[1][1]])) + current_tokenization.extend(span_tokens) + new_span2end = len(current_tokenization) + span_mapping[sorted_indices[1][0]] = [new_span2start, new_span2end] + + # save back into record + _, all_text = aligner_fn(" ".join(text)) + record["target"]["span1"] = span_mapping[record["target"]["span1_index"]] + record["target"]["span2"] = span_mapping[record["target"]["span2_index"]] + record["text"] = " ".join(all_text) + return record + + class TokenAligner(object): """Align two similiar tokenizations. diff --git a/tests/test_preprocess_winograd.py b/tests/test_preprocess_winograd.py index 7c686108e..06fb7e815 100644 --- a/tests/test_preprocess_winograd.py +++ b/tests/test_preprocess_winograd.py @@ -4,7 +4,7 @@ import shutil import tempfile import unittest -import scripts.winograd.preprocess_winograd as preprocess_winograd +import src.utils.retokenize as retokenize import json import copy @@ -24,15 +24,13 @@ def setUp(self): json.dumps( { "text": "Members of the House clapped their hands", - "targets": [ - { - "span1_index": 0, - "span1_text": "members", - "span2_index": 5, - "span2_text": "their", - "label": True, - } - ], + "target": { + "span1_index": 0, + "span1_text": "members", + "span2_index": 5, + "span2_text": "their", + "label": True, + }, } ) ) @@ -42,15 +40,13 @@ def setUp(self): json.dumps( { "text": "Mr. Ford told me to tell you to contact him", - "targets": [ - { - "span1_index": 0, - "span1_text": "Mr. Ford", - "span2_index": 9, - "span2_text": "him", - "label": True, - } - ], + "target": { + "span1_index": 0, + "span1_text": "Mr. Ford", + "span2_index": 9, + "span2_text": "him", + "label": True, + }, } ) ) @@ -60,48 +56,65 @@ def setUp(self): json.dumps( { "text": "I told you already, Mr. Ford!", - "targets": [ - { - "span1_index": 4, - "span1_text": "Mr. Ford", - "span2_index": 0, - "span2_text": "I", - "label": False, - } - ], + "target": { + "span1_index": 4, + "span1_text": "Mr. Ford", + "span2_index": 0, + "span2_text": "I", + "label": False, + }, } ) ) jsonfile.write("\n") + jsonfile.write( + json.dumps( + { + "text": "I look at Sarah's dog. 
It was cute.!", + "target": { + "span1_index": 3, + "span1_text": "Sarah's dog.", + "span2_index": 0, + "span2_text": "I", + "label": False, + }, + } + ) + ) def test_bert(self): records = list(pd.read_json(self.path, lines=True).T.to_dict().values()) orig_records = copy.deepcopy(records) - for rec in records: - preprocess_winograd.realign_spans(rec, "bert-large-cased") - print(records[0]) - print(orig_records[0]) + for rec in records[:-1]: + retokenize.realign_spans(rec, "bert-large-cased") + retokenize.realign_spans(records[-1], "MosesTokenizer") assert records[0]["text"] == orig_records[0]["text"] # the two below should be changed by tokenization assert records[1]["text"] != orig_records[1]["text"] assert records[2]["text"] != orig_records[2]["text"] - result_span1 = records[0]["targets"][0]["span1"] - result_span2 = records[0]["targets"][0]["span2"] + result_span1 = records[0]["target"]["span1"] + result_span2 = records[0]["target"]["span2"] assert result_span1 == [0, 1] assert result_span2 == [5, 6] - result_span1 = records[1]["targets"][0]["span1"] - result_span2 = records[1]["targets"][0]["span2"] + result_span1 = records[1]["target"]["span1"] + result_span2 = records[1]["target"]["span2"] assert result_span1 == [0, 3] assert result_span2 == [10, 11] - result_span1 = records[2]["targets"][0]["span1"] - result_span2 = records[2]["targets"][0]["span2"] + result_span1 = records[2]["target"]["span1"] + result_span2 = records[2]["target"]["span2"] assert result_span1 == [5, 9] assert result_span2 == [0, 1] + result_span1 = records[3]["target"]["span1"] + result_span2 = records[3]["target"]["span2"] + + assert result_span1 == [3, 7] + assert result_span2 == [0, 1] + def tearDown(self): shutil.rmtree(self.temp_dir) From b43fba6bce2fafb16a62ddd576f4898adf5d3428 Mon Sep 17 00:00:00 2001 From: Yada Pruksachatkun Date: Thu, 23 May 2019 15:56:31 +0700 Subject: [PATCH 3/3] Black formatting (#693) * renamed elmo skip embedding variable * Revert "renamed elmo skip embedding variable" This reverts commit 24ff66e8aeced1f2627919a86f85780de0e1ca66. 
* making span realignment part of the loading flow

* adding classifier_loss_fn and span pooling to config

* reverting unnecessary changes

* revert

* revert

* revert reloads

* nits

* black styling, and changing tests

* updated comment

* adding to comments

* got rid of unnecessary space

* adding load_data to winograd

* black formatting

* fix test

* remove nit

* nit

* black reformatting

* black
---
 src/tasks/tasks.py                | 6 +-----
 tests/test_preprocess_winograd.py | 2 +-
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/src/tasks/tasks.py b/src/tasks/tasks.py
index 4bdf9368c..e6dc83291 100644
--- a/src/tasks/tasks.py
+++ b/src/tasks/tasks.py
@@ -2514,11 +2514,7 @@ def get_metrics(self, reset=False):
 @register_task("winograd-coreference", rel_path="winograd-coref")
 class WinogradCoreferenceTask(SpanClassificationTask):
     def __init__(self, path, **kw):
-        self._files_by_split = {
-            "train": "train.jsonl",
-            "val": "val.jsonl",
-            "test": "test.jsonl",
-        }
+        self._files_by_split = {"train": "train.jsonl", "val": "val.jsonl", "test": "test.jsonl"}
         self.num_spans = 2
         super().__init__(
             files_by_split=self._files_by_split, label_file="labels.txt", path=path, **kw
diff --git a/tests/test_preprocess_winograd.py b/tests/test_preprocess_winograd.py
index 06fb7e815..9c8e402b2 100644
--- a/tests/test_preprocess_winograd.py
+++ b/tests/test_preprocess_winograd.py
@@ -112,7 +112,7 @@ def test_bert(self):
 
         result_span1 = records[3]["target"]["span1"]
         result_span2 = records[3]["target"]["span2"]
-        
+
         assert result_span1 == [3, 7]
         assert result_span2 == [0, 1]
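
Reviewer note (appended after the series; not part of any patch): the core of PATCH 2/3 is the index arithmetic that maps a half-open span over space-separated tokens onto the retokenized wordpieces. The sketch below is illustrative only and deliberately self-contained: toy_wordpiece and realign_span are stand-ins invented for this note (jiant's realign_spans instead calls get_aligner_fn(tokenizer_name) and walks the sentence span by span), but for the "Mr. Ford" sentence the resulting indices match the values asserted in tests/test_preprocess_winograd.py.

# Illustrative sketch only -- not code from the patch.

def toy_wordpiece(text):
    """Stand-in tokenizer: split a trailing period off each token,
    roughly what a cased BERT wordpiece model does to "Mr."."""
    pieces = []
    for tok in text.split():
        if len(tok) > 1 and tok.endswith("."):
            pieces.extend([tok[:-1], "."])
        else:
            pieces.append(tok)
    return pieces


def realign_span(space_tokens, start, end, tokenize):
    """Map the half-open span [start, end) over space-separated tokens to a
    half-open span over the retokenized pieces."""
    n_before = len(tokenize(" ".join(space_tokens[:start])))
    n_inside = len(tokenize(" ".join(space_tokens[start:end])))
    return [n_before, n_before + n_inside]


tokens = "Mr. Ford told me to tell you to contact him".split()
# "Mr. Ford" starts at space token 0 and covers 2 tokens -> 3 wordpieces.
print(realign_span(tokens, 0, 2, toy_wordpiece))   # [0, 3]
# "him" starts at token 9; the prefix grows by one piece because "Mr." splits.
print(realign_span(tokens, 9, 10, toy_wordpiece))  # [10, 11]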
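
A second note, also outside the patches: a hypothetical sketch of how the loading-time flow from PATCH 2/3 would be driven directly, mirroring WinogradCoreferenceTask.load_data. It is not output from the repo; the import assumes the script is run from the jiant repository root with its dependencies installed, the .jsonl path is an assumption, and the record layout in the comment follows the load_span_data docstring.

# Hypothetical usage sketch; paths and environment are assumptions.
# Rows of the .jsonl file are expected to look like:
#   {"text": "...", "label": true,
#    "target": {"span1_index": 0, "span1_text": "...",
#               "span2_index": 5, "span2_text": "..."}}
from src.utils.data_loaders import load_span_data

records = load_span_data("bert-large-cased", "winograd-coref/val.jsonl")
first = records[0]
# Each returned dict carries the retokenized text plus the half-open
# wordpiece spans computed by realign_spans at load time.
print(first["text"])
print(first["target"]["span1"], first["target"]["span2"], first["label"])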