Commit fbad9af

BERT data preparation script
1 parent 63e19aa commit fbad9af

File tree

  neuralmonkey/dataset.py
  scripts/preprocess_bert.py

2 files changed: +86, -3 lines


neuralmonkey/dataset.py

Lines changed: 7 additions & 3 deletions
@@ -389,8 +389,8 @@ def itergen():
     for s_name, (preprocessor, source) in prep_sl.items():
         if source not in iterators:
             raise ValueError(
-                "Source series {} for series-level preprocessor nonexistent: "
-                "Preprocessed series '', source series ''".format(source))
+                "Source series for series-level preprocessor nonexistent: "
+                "Preprocessed series '{}', source series '{}'".format(s_name, source))
         iterators[s_name] = _make_sl_iterator(source, preprocessor)

     # Finally, dataset-level preprocessors.
@@ -443,6 +443,8 @@ def __init__(self,
         Arguments:
             name: The name for the dataset.
             iterators: A series-iterator generator mapping.
+            lazy: If False, load the data from iterators to a list and store
+                the list in memory.
             buffer_size: Use this tuple as a minimum and maximum buffer size
                 for pre-loading data. This should be (a few times) larger than
                 the batch size used for mini-batching. When the buffer size
@@ -638,7 +640,9 @@ def itergen():
                 buf.append(item)

             if self.shuffled:
-                random.shuffle(buf)  # type: ignore
+                lbuf = list(buf)
+                random.shuffle(lbuf)
+                buf = deque(lbuf)

             if not self.batching.drop_remainder:
                 for bucket in buckets:
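
The shuffling change above avoids calling random.shuffle directly on the deque buffer: indexing into a deque is linear in the distance from its ends, so an in-place Fisher-Yates shuffle degrades to roughly quadratic time, and the old "# type: ignore" comment suggests the type checker also objected to the argument type. A minimal standalone sketch of the same pattern (the buffer contents here are placeholder data, not actual dataset items):

    from collections import deque
    import random

    # Hypothetical stand-in for the pre-loading buffer used in dataset.py.
    buf = deque(range(10))

    # Copy to a list, shuffle the list, then rebuild the deque; this keeps
    # the shuffle linear-time instead of repeatedly indexing into the deque.
    lbuf = list(buf)
    random.shuffle(lbuf)
    buf = deque(lbuf)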

scripts/preprocess_bert.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+# Creates training data for the BERT network training
+# (noisified + masked gold predictions) using the input corpus
+# TODO: add support for other NM vocabularies (aside from t2t)
+
+import argparse
+import os
+
+import numpy as np
+
+from neuralmonkey.logging import log as _log
+from neuralmonkey.vocabulary import (
+    Vocabulary, PAD_TOKEN, UNK_TOKEN, from_wordlist)
+
+
+def log(message: str, color: str = "blue") -> None:
+    _log(message, color)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--input_file", type=str, default="/dev/stdin")
+    parser.add_argument("--vocabulary", type=str, required=True)
+    parser.add_argument("--output_prefix", type=str, default=None)
+    parser.add_argument("--mask_token", type=str, default=UNK_TOKEN)
+    parser.add_argument("--coverage", type=float, default=0.15)
+    parser.add_argument("--mask_prob", type=float, default=0.8)
+    parser.add_argument("--replace_prob", type=float, default=0.1)
+    parser.add_argument("--vocab_contains_header", type=bool, default=True)
+    parser.add_argument("--vocab_contains_frequencies",
+                        type=bool, default=True)
+    args = parser.parse_args()
+
+    assert (args.coverage <= 1 and args.coverage >= 0)
+    assert (args.mask_prob <= 1 and args.mask_prob >= 0)
+    assert (args.replace_prob <= 1 and args.replace_prob >= 0)
+
+    log("Loading vocabulary.")
+    vocabulary = from_wordlist(
+        args.vocabulary,
+        contains_header=args.vocab_contains_header,
+        contains_frequencies=args.vocab_contains_frequencies)
+
+    # Sampling probabilities: (keep_prob, mask_prob, replace_prob).
+    mask_prob = args.mask_prob
+    replace_prob = args.replace_prob
+    keep_prob = 1 - mask_prob - replace_prob
+    sample_probs = (keep_prob, mask_prob, replace_prob)
+
+    output_prefix = args.output_prefix
+    if output_prefix is None:
+        output_prefix = args.input_file
+    out_f_noise = "{}.noisy".format(output_prefix)
+    out_f_mask = "{}.mask".format(output_prefix)
+
+    out_noise_h = open(out_f_noise, "w", encoding="utf-8")
+    out_mask_h = open(out_f_mask, "w", encoding="utf-8")
+    log("Processing data.")
+    with open(args.input_file, "r", encoding="utf-8") as input_h:
+        # TODO: performance optimizations
+        for line in input_h:
+            line = line.strip().split(" ")
+            num_samples = int(args.coverage * len(line))
+            sampled_indices = np.random.choice(len(line), num_samples, False)
+
+            output_noisy = list(line)
+            output_masked = [PAD_TOKEN] * len(line)
+            for i in sampled_indices:
+                random_token = np.random.choice(vocabulary.index_to_word[4:])
+                new_token = np.random.choice(
+                    [line[i], args.mask_token, random_token], p=sample_probs)
+                output_noisy[i] = new_token
+                output_masked[i] = line[i]
+            out_noise_h.write(str(" ".join(output_noisy)) + "\n")
+            out_mask_h.write(str(" ".join(output_masked)) + "\n")
+
+
+if __name__ == "__main__":
+    main()
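
For each input line, the script samples roughly --coverage (0.15 by default) of the token positions; a sampled token is kept, replaced by --mask_token, or replaced by a random vocabulary token with probabilities keep/mask/replace, where keep_prob is derived as 1 - mask_prob - replace_prob (0.1/0.8/0.1 with the defaults). The corrupted sentence goes to the .noisy file, while the gold tokens at the sampled positions go to the .mask file, with PAD_TOKEN everywhere else. A hypothetical invocation, assuming a whitespace-tokenized corpus train.txt and a t2t-style wordlist vocab.tsv (both file names are placeholders):

    python3 scripts/preprocess_bert.py \
        --input_file train.txt \
        --vocabulary vocab.tsv \
        --output_prefix train.bert

With these arguments the script would write train.bert.noisy and train.bert.mask.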
