
Commit d19c792

Implement beam search for CopyNet (#18)
1 parent 399d27c commit d19c792

17 files changed: +800 −132 lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
@@ -102,3 +102,7 @@ venv.bak/
 
 # mypy
 .mypy_cache/
+
+# scratch files
+scratch*
+tmp*

Makefile

Lines changed: 6 additions & 1 deletion
@@ -1,3 +1,4 @@
+debug = 0
 test = modules
 COVERAGE := $(addprefix --cov=, $(test))
 PYTHONPATH = allennlp
@@ -12,7 +13,11 @@ EXPERIMENTS := $(wildcard $(EXPERIMENTDIR)/**/*.json)
 
 .PHONY : train
 train :
+ifeq ($(debug),0)
 	./scripts/train.sh
+else
+	CUDA_LAUNCH_BLOCKING=1 ./scripts/train.sh
+endif
 
 # Need this to force targets to build, even when the target file exists.
 .PHONY : phony-target
@@ -47,7 +52,7 @@ lint :
 .PHONY : unit-test
 unit-test :
 	@echo "Unit tests: pytest"
-ifeq ($(suffix $(test)),.py)
+ifneq ($(findstring test,$(test)),)
 	PYTHONPATH=$(PYTHONPATH) python -m pytest -v --color=yes $(test)
 else
 	PYTHONPATH=$(PYTHONPATH) python -m pytest -v --cov-config .coveragerc $(COVERAGE) --color=yes $(test)

README.md

Lines changed: 4 additions & 2 deletions
@@ -20,16 +20,18 @@ After AllenNLP is installed, you can define your own experiments with an AllenNL
 
 ## Models implemented
 
-- (WIP) [CopyNet](https://arxiv.org/abs/1603.06393): A sequence-to-sequence model that incorporates a copying mechanism, which enables the model to copy tokens from the source sentence into the target sentence even if they are not part of the target vocabulary. This architecture has shown promising results on machine translation and semantic parsing tasks.
+- [CopyNet](https://arxiv.org/abs/1603.06393): A sequence-to-sequence model that incorporates a copying mechanism, which enables the model to copy tokens from the source sentence into the target sentence even if they are not part of the target vocabulary. This architecture has shown promising results on machine translation and semantic parsing tasks.
 
 ## Datasets
 
+- Greetings: A simple made-up dataset of greetings (the source sentences) and replies (the target sentences). The greetings are things like "Hi, my name is Jon Snow" and the replies are in the format "Nice to meet you, Jon Snow!". This is completely artificial and is just meant to show the usefulness of the copy mechanism in CopyNet.
 - [NL2Bash](http://arxiv.org/abs/1802.08979): A challenging dataset that consists of bash one-liners along with corresponding expert descriptions. The goal is to translate the natural language descriptions into the bash commands.
 
 ## Experiments
 
+- Greetings dataset with CopyNet: run `make experiments/greetings/copynet.json` to train.
 - (WIP) [NL2Bash with CopyNet](./experiments/nl2bash/copynet.json): run `make experiments/nl2bash/copynet.json` to train.
 
 ## TODO
 
-- Implement beam search for CopyNet
+- Implement custom metrics for NL2Bash.
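
The TODO item crossed off above is the feature this commit delivers: instead of greedily taking the single best token at each decoding step, the decoder now keeps several candidate output sequences alive and extends them in parallel. For orientation only, here is a minimal, generic beam-search sketch; it is not the code added in this commit, and `log_prob_fn`, `start_token`, and `end_token` are hypothetical names used purely for illustration.

```python
from typing import Callable, List, Tuple


def beam_search(start_token: int,
                end_token: int,
                log_prob_fn: Callable[[List[int]], List[Tuple[int, float]]],
                beam_size: int = 3,
                max_steps: int = 20) -> List[Tuple[List[int], float]]:
    # Each live hypothesis is a (token sequence, cumulative log-probability) pair.
    beam: List[Tuple[List[int], float]] = [([start_token], 0.0)]
    finished: List[Tuple[List[int], float]] = []
    for _ in range(max_steps):
        candidates: List[Tuple[List[int], float]] = []
        for prefix, score in beam:
            # log_prob_fn scores every candidate next token given the prefix.
            for token, log_prob in log_prob_fn(prefix):
                candidates.append((prefix + [token], score + log_prob))
        # Keep only the `beam_size` highest-scoring extensions.
        candidates.sort(key=lambda cand: cand[1], reverse=True)
        beam = []
        for prefix, score in candidates[:beam_size]:
            if prefix[-1] == end_token:
                finished.append((prefix, score))
            else:
                beam.append((prefix, score))
        if not beam:
            break
    return sorted(finished or beam, key=lambda cand: cand[1], reverse=True)
```

The `beam_size` and `max_decoding_steps` settings in the experiment configs below play the roles of `beam_size` and `max_steps` here.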

data/greetings.tar.gz

62.1 KB
Binary file not shown.

data/names.tar.gz

254 KB
Binary file not shown.

experiments/greetings/copynet.json

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+{
+  "dataset_reader": {
+    "target_namespace": "target_tokens",
+    "type": "copynet",
+    "source_token_indexers": {
+      "tokens": {
+        "type": "single_id",
+        "namespace": "source_tokens"
+      },
+      "token_characters": {
+        "type": "characters"
+      }
+    },
+    "target_token_indexers": {
+      "tokens": {
+        "namespace": "target_tokens"
+      }
+    }
+  },
+  "vocabulary": {
+    "min_count": {
+      "source_tokens": 4,
+      "target_tokens": 4
+    },
+    "tokens_to_add": {
+      "target_tokens": ["@COPY@"]
+    }
+  },
+  "train_data_path": "data/greetings/train.tsv",
+  "validation_data_path": "data/greetings/validation.tsv",
+  "model": {
+    "type": "copynet",
+    "source_embedder": {
+      "tokens": {
+        "type": "embedding",
+        "vocab_namespace": "source_tokens",
+        "embedding_dim": 25,
+        "trainable": true
+      },
+      "token_characters": {
+        "type": "character_encoding",
+        "embedding": {
+          "embedding_dim": 10
+        },
+        "encoder": {
+          "type": "lstm",
+          "input_size": 10,
+          "hidden_size": 10,
+          "num_layers": 2,
+          "dropout": 0,
+          "bidirectional": true
+        }
+      }
+    },
+    "encoder": {
+      "type": "lstm",
+      "input_size": 45,
+      "hidden_size": 100,
+      "num_layers": 2,
+      "dropout": 0,
+      "bidirectional": true
+    },
+    "attention": {
+      "type": "bilinear",
+      "vector_dim": 200,
+      "matrix_dim": 200
+    },
+    "target_embedding_dim": 10,
+    "beam_size": 3,
+    "max_decoding_steps": 20
+  },
+  "iterator": {
+    "type": "bucket",
+    "padding_noise": 0.0,
+    "batch_size" : 32,
+    "sorting_keys": [["source_tokens", "num_tokens"]]
+  },
+  "trainer": {
+    "optimizer": {
+      "type": "sgd",
+      "lr": 0.015
+    },
+    "learning_rate_scheduler": {
+      "type": "cosine",
+      "t_initial": 5,
+      "t_mul": 1.5,
+      "eta_mul": 0.9
+    },
+    "num_epochs": 80,
+    "cuda_device": 0,
+    "should_log_learning_rate": true,
+    "should_log_parameter_statistics": false
+  }
+}
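
One thing worth noting about this config is that the dimensions have to line up: the 25-dimensional word embedding is concatenated with the output of the bidirectional character encoder (2 × 10) to give the main encoder's `input_size` of 45, and the bidirectional main encoder (2 × 100) produces the 200-dimensional vectors expected by the bilinear attention. A small sanity-check sketch, with the numbers copied from the JSON above:

```python
# Illustration only: the arithmetic mirrors how AllenNLP composes embedders and encoders.

token_embedding_dim = 25      # "tokens" embedding_dim
char_embedding_dim = 10       # "token_characters" embedding_dim
char_lstm_hidden_size = 10    # character encoder hidden_size

# The character LSTM consumes the character embeddings directly.
assert char_embedding_dim == 10  # matches the character encoder's input_size

# A bidirectional LSTM encoder yields 2 * hidden_size features per token.
char_encoding_dim = 2 * char_lstm_hidden_size

# The source_embedder concatenates the word and character representations,
# which is why the main encoder's input_size is 45.
assert token_embedding_dim + char_encoding_dim == 45

# The main encoder is also bidirectional, so its output dimension is
# 2 * 100 = 200, matching the bilinear attention's vector_dim and matrix_dim.
assert 2 * 100 == 200
```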

experiments/nl2bash/copynet.json

Lines changed: 6 additions & 0 deletions
@@ -1,5 +1,6 @@
 {
   "dataset_reader": {
+    "target_namespace": "target_tokens",
     "type": "nl2bash",
     "source_token_indexers": {
       "tokens": {
@@ -20,6 +21,9 @@
     "min_count": {
       "source_tokens": 4,
       "target_tokens": 4
+    },
+    "tokens_to_add": {
+      "target_tokens": ["@COPY@"]
     }
   },
   "train_data_path": "data/nl2bash/train.tsv",
@@ -62,6 +66,8 @@
       "matrix_dim": 200
     },
     "target_embedding_dim": 10,
+    "beam_size": 5,
+    "max_decoding_steps": 50
   },
   "iterator": {
     "type": "bucket",

modules/data/dataset_readers/copynet.py

Lines changed: 77 additions & 30 deletions
@@ -4,20 +4,23 @@
 import numpy as np
 from overrides import overrides
 
+from allennlp.common.checks import ConfigurationError
+from allennlp.common.file_utils import cached_path
 from allennlp.common.util import START_SYMBOL, END_SYMBOL
 from allennlp.data.dataset_readers.dataset_reader import DatasetReader
-from allennlp.data.dataset_readers.seq2seq import Seq2SeqDatasetReader
 from allennlp.data.fields import TextField, ArrayField
 from allennlp.data.instance import Instance
-from allennlp.data.tokenizers import Token, Tokenizer
-from allennlp.data.token_indexers import TokenIndexer
+from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer
+from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
+
+from modules.data.fields import CopyMapField
 
 
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 
 
 @DatasetReader.register("copynet")
-class CopyNetDatasetReader(Seq2SeqDatasetReader):
+class CopyNetDatasetReader(DatasetReader):
     """
     Read a tsv file containing paired sequences, and create a dataset suitable for a
     ``CopyNet`` model, or any model with a matching API.
@@ -28,6 +31,9 @@ class CopyNetDatasetReader(Seq2SeqDatasetReader):
 
     Parameters
    ----------
+    target_namespace : ``str``, required
+        The vocab namespace for the targets. This needs to be passed to the dataset reader
+        in order to construct the CopyMapField.
     source_tokenizer : ``Tokenizer``, optional
         Tokenizer to use to split the input sequences into words or other kinds of tokens. Defaults
         to ``WordTokenizer()``.
@@ -43,20 +49,32 @@ class CopyNetDatasetReader(Seq2SeqDatasetReader):
     """
 
     def __init__(self,
+                 target_namespace: str,
                  source_tokenizer: Tokenizer = None,
                  target_tokenizer: Tokenizer = None,
                  source_token_indexers: Dict[str, TokenIndexer] = None,
                  target_token_indexers: Dict[str, TokenIndexer] = None,
                  lazy: bool = False) -> None:
-        # The only reason we override __init__ is so that we can ensure `source_add_start_token`
-        # is True. This is because the CopyNet model always assumes the start token
-        # will be part of the source sentence.
-        super().__init__(source_tokenizer=source_tokenizer,
-                         target_tokenizer=target_tokenizer,
-                         source_token_indexers=source_token_indexers,
-                         target_token_indexers=target_token_indexers,
-                         source_add_start_token=True,
-                         lazy=lazy)
+        super().__init__(lazy)
+        self._target_namespace = target_namespace
+        self._source_tokenizer = source_tokenizer or WordTokenizer()
+        self._target_tokenizer = target_tokenizer or self._source_tokenizer
+        self._source_token_indexers = source_token_indexers or {"tokens": SingleIdTokenIndexer()}
+        self._target_token_indexers = target_token_indexers or self._source_token_indexers
+
+    @overrides
+    def _read(self, file_path):
+        with open(cached_path(file_path), "r") as data_file:
+            logger.info("Reading instances from lines in file at: %s", file_path)
+            for line_num, line in enumerate(data_file):
+                line = line.strip("\n")
+                if not line:
+                    continue
+                line_parts = line.split('\t')
+                if len(line_parts) != 2:
+                    raise ConfigurationError("Invalid line format: %s (line number %d)" % (line, line_num + 1))
+                source_sequence, target_sequence = line_parts
+                yield self.text_to_instance(source_sequence, target_sequence)
 
 
     def _preprocess_source(self, source_string: str) -> str:  # pylint: disable=no-self-use
@@ -72,6 +90,27 @@ def _preprocess_target(self, target_string: str) -> str:  # pylint: disable=no-s
         """
         return target_string
 
+    @staticmethod
+    def _create_copy_indicator_array(tokenized_source: List[Token],
+                                     tokenized_target: List[Token]) -> np.array:
+        copy_indicator_array: List[List[int]] = []
+        for target_token in tokenized_target[1:-1]:
+            source_index_list: List[int] = [int(target_token.text.lower() == source_token.text.lower())
+                                            for source_token in tokenized_source[1:-1]]
+            copy_indicator_array.append(source_index_list)
+        copy_indicator_array.insert(0, [0] * len(tokenized_source[1:-1]))
+        copy_indicator_array.append([0] * len(tokenized_source[1:-1]))
+        return np.array(copy_indicator_array)
+
+    @staticmethod
+    def _create_source_duplicates_array(tokenized_source: List[Token]) -> np.array:
+        out_array: List[List[int]] = []
+        for token in tokenized_source[1:-1]:
+            array_slice: List[int] = [int(token.text.lower() == other.text.lower())
+                                      for other in tokenized_source[1:-1]]
+            out_array.append(array_slice)
+        return np.array(out_array)
+
     @overrides
     def text_to_instance(self, source_string: str, target_string: str = None) -> Instance:  # type: ignore
         # pylint: disable=arguments-differ
@@ -80,6 +119,24 @@ def text_to_instance(self, source_string: str, target_string: str = None) -> Ins
         tokenized_source.insert(0, Token(START_SYMBOL))
         tokenized_source.append(Token(END_SYMBOL))
         source_field = TextField(tokenized_source, self._source_token_indexers)
+
+        # For each token in the source sentence, we store a sparse array containing
+        # indicators for each other source token that matches. This gives us
+        # a matrix of shape `(source_length, source_length)` where the (i,j)th entry
+        # is a 1 if the ith token matches the jth token.
+        source_duplicates_array = self._create_source_duplicates_array(tokenized_source)
+        source_duplicates_field = ArrayField(source_duplicates_array)
+
+        # For each token in the source sentence, we keep track of the matching token
+        # in the target sentence (which will be the OOV symbol if there is no match).
+        target_pointer_field = CopyMapField(tokenized_source[1:-1], self._target_namespace)
+
+        fields_dict = {
+            "source_tokens": source_field,
+            "source_duplicates": source_duplicates_field,
+            "target_pointers": target_pointer_field,
+        }
+
         if target_string is not None:
             target_string = self._preprocess_target(target_string)
             tokenized_target = self._target_tokenizer.tokenize(target_string)
@@ -89,22 +146,12 @@ def text_to_instance(self, source_string: str, target_string: str = None) -> Ins
 
             # For each token in the target sentence, we keep track of the index
            # of every token in the source sentence that matches.
-            source_index_array: List[List[int]] = []
-            for tgt_tok in tokenized_target[1:-1]:
-                source_index_list: List[int] = []
-                for src_tok in tokenized_source[1:-1]:
-                    if tgt_tok.text == src_tok.text:
-                        source_index_list.append(1)
-                    else:
-                        source_index_list.append(0)
-                source_index_array.append(source_index_list)
-            source_index_array.insert(0, [0] * len(tokenized_source[1:-1]))
-            source_index_array.append([0] * len(tokenized_source[1:-1]))
-            source_index_field = ArrayField(np.array(source_index_array))
+            copy_indicator_array = self._create_copy_indicator_array(tokenized_source,
                                                                      tokenized_target)
             # shape: (target_length, source_length)
+            copy_indicator_field = ArrayField(copy_indicator_array)
+
+            fields_dict["target_tokens"] = target_field
+            fields_dict["copy_indicators"] = copy_indicator_field
 
-            return Instance({"source_tokens": source_field,
-                             "target_tokens": target_field,
-                             "source_indices": source_index_field})
-        else:
-            return Instance({'source_tokens': source_field})
+        return Instance(fields_dict)
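
To make the two new static helpers concrete, here is a small worked example in the style of the Greetings data, re-implemented with plain strings instead of AllenNLP `Token` objects; it illustrates the logic only and is not the module code itself.

```python
import numpy as np


def copy_indicator_array(source, target):
    # One row per target token (the first and last rows correspond to the
    # START/END placeholders and are all zeros), one column per source token
    # excluding START/END; entry (i, j) is 1 when the tokens match, ignoring case.
    inner_source = source[1:-1]
    rows = [[int(t.lower() == s.lower()) for s in inner_source] for t in target[1:-1]]
    return np.array([[0] * len(inner_source)] + rows + [[0] * len(inner_source)])


def source_duplicates_array(source):
    # Entry (i, j) is 1 when source tokens i and j (excluding START/END) are the same word.
    inner_source = source[1:-1]
    return np.array([[int(a.lower() == b.lower()) for b in inner_source]
                     for a in inner_source])


source = ["@start@", "hi", "my", "name", "is", "jon", "snow", "@end@"]
target = ["@start@", "nice", "to", "meet", "you", "jon", "snow", "!", "@end@"]

print(copy_indicator_array(source, target).shape)  # (9, 6): 9 target rows, 6 source columns
print(source_duplicates_array(source))             # 6x6; the identity here, since no word repeats
```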

modules/data/dataset_readers/nl2bash.py

Lines changed: 3 additions & 1 deletion
@@ -95,14 +95,16 @@ class NL2BashDatasetReader(CopyNetDatasetReader):
     prompt_finder = re.compile(r"^(\$|#)\s?")
 
     def __init__(self,
+                 target_namespace: str,
                  source_tokenizer: Tokenizer = None,
                  target_tokenizer: Tokenizer = None,
                  source_token_indexers: Dict[str, TokenIndexer] = None,
                  target_token_indexers: Dict[str, TokenIndexer] = None,
                  lazy: bool = False) -> None:
         source_tokenizer = source_tokenizer or WordTokenizer(word_splitter=NL2BashWordSplitter())
         target_tokenizer = target_tokenizer or source_tokenizer
-        super().__init__(source_tokenizer=source_tokenizer,
+        super().__init__(target_namespace,
+                         source_tokenizer=source_tokenizer,
                          target_tokenizer=target_tokenizer,
                          source_token_indexers=source_token_indexers,
                          target_token_indexers=target_token_indexers,

modules/data/fields/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from modules.data.fields.copy_map_field import CopyMapField

modules/data/fields/copy_map_field.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+from typing import Dict, List, Optional
+
+from overrides import overrides
+import torch
+
+from allennlp.common.util import pad_sequence_to_length
+from allennlp.data import Vocabulary
+from allennlp.data.tokenizers import Token
+from allennlp.data import Field
+
+
+class CopyMapField(Field[torch.Tensor]):
+
+    def __init__(self,
+                 source_tokens: List[Token],
+                 target_namespace: str) -> None:
+        self._source_tokens = source_tokens
+        self._target_namespace = target_namespace
+        self._mapping_array: Optional[List[int]] = None
+
+    @overrides
+    def index(self, vocab: Vocabulary):
+        self._mapping_array = [vocab.get_token_index(x.text, self._target_namespace)
+                               for x in self._source_tokens]
+
+    @overrides
+    def get_padding_lengths(self) -> Dict[str, int]:
+        return {"num_tokens": len(self._source_tokens)}
+
+    @overrides
+    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
+        desired_length = padding_lengths["num_tokens"]
+        padded_tokens = pad_sequence_to_length(self._mapping_array, desired_length)
+        tensor = torch.LongTensor(padded_tokens)
+        return tensor
+
+    @overrides
+    def empty_field(self) -> 'CopyMapField':
+        return CopyMapField([], self._target_namespace)
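
For context, a rough usage sketch of the new field. It assumes the AllenNLP 0.x `Vocabulary` API (`add_token_to_namespace`, `get_token_index`) already used elsewhere in this repo, and the vocabulary contents are made up purely for illustration.

```python
from allennlp.data import Vocabulary
from allennlp.data.tokenizers import Token

from modules.data.fields import CopyMapField

# Build a small target-side vocabulary.
vocab = Vocabulary()
for word in ["nice", "to", "meet", "you", "jon", "snow"]:
    vocab.add_token_to_namespace(word, namespace="target_tokens")

# Map each source token to its id in the *target* namespace; source words that
# are not in the target vocabulary resolve to the OOV id.
field = CopyMapField([Token(w) for w in ["hi", "my", "name", "is", "jon", "snow"]],
                     target_namespace="target_tokens")
field.index(vocab)
tensor = field.as_tensor(field.get_padding_lengths())
print(tensor)  # a LongTensor with one target-vocabulary id per source token
```

This mapping is what lets the decoder translate "copy the token at source position j" into a concrete id in the target vocabulary.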
