-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgenerate_words.py
68 lines (56 loc) · 2.23 KB
/
generate_words.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os, sys
import collections
import numpy as np
import re
# Output locations for the generated dataset.
data_folder = 'words_data'
train_x = os.path.join(data_folder, 'train_x.txt')
train_y = os.path.join(data_folder, 'train_y.txt')
val_x = os.path.join(data_folder, 'val_x.txt')
val_y = os.path.join(data_folder, 'val_y.txt')

# Number of consecutive corpus words in each generated sample.
SAMPLE_SIZE = 80
# Default vocabulary size; generate_data() overwrites it with its argument.
VOCAB_SIZE = 10
# NOTE(review): TOP is not referenced anywhere in the visible code -- confirm
# whether it is still needed.
TOP = 5


def _read_text(path):
    """Return the whole contents of the text file at *path*, closing it after."""
    with open(path) as handle:
        return handle.read()


# Comma-separated stop words to exclude from the vocabulary.
stopwords = set(_read_text('stop_words.txt').split(','))
# Every lowercase alphabetic token of two or more letters, in corpus order.
all_words = re.findall(r'[a-z]{2,}', _read_text('pride-and-prejudice.txt').lower())
# Distinct non-stop words. The order is whatever set iteration yields, so the
# "first VOCAB_SIZE words" used as the vocabulary are an arbitrary subset.
words = list({w for w in all_words if w not in stopwords})
def generate_pair(with_counts):
    """Build one (query, answer) sample from a random window of the corpus.

    A window of SAMPLE_SIZE consecutive words is taken from ``all_words``.
    Every word outside the current vocabulary (the first VOCAB_SIZE entries
    of ``words``) is replaced by a word drawn at random from the first half
    of that vocabulary; all occurrences of the same unknown word receive the
    same replacement.

    Args:
        with_counts: When true, each answer entry is ``"<word> <count>"``;
            otherwise it is just the word.

    Returns:
        A ``(query, ans)`` tuple. ``query`` is the list of window words after
        replacement; ``ans`` lists the distinct query words ordered from most
        to least frequent.

    Raises:
        ValueError: If the corpus holds fewer than SAMPLE_SIZE words.
    """
    last_start = len(all_words) - SAMPLE_SIZE
    if last_start < 0:
        raise ValueError(
            "corpus has {} words, fewer than SAMPLE_SIZE={}".format(
                len(all_words), SAMPLE_SIZE))
    # randint's upper bound is exclusive, so + 1 keeps the final window
    # (the one starting at last_start) reachable.
    start = np.random.randint(0, last_start + 1)
    window = all_words[start:start + SAMPLE_SIZE]

    # Build the vocabulary lookup once instead of slicing per word.
    vocab = set(words[:VOCAB_SIZE])
    # One random draw per distinct unknown word, in order of first appearance.
    replacements = {}
    for word in window:
        if word not in vocab and word not in replacements:
            # Integer division: randint expects an integer bound.
            replacements[word] = words[np.random.randint(0, VOCAB_SIZE // 2)]
    query = [replacements.get(word, word) for word in window]

    top = collections.Counter(query).most_common()
    if with_counts:
        ans = ["{} {}".format(word, count) for word, count in top]
    else:
        ans = [word for word, _ in top]
    return query, ans
def _write_split(x_path, y_path, count, with_counts):
    """Write *count* generated samples: queries to *x_path*, answers to *y_path*."""
    with open(x_path, 'w') as fx, open(y_path, 'w') as fy:
        for _ in range(count):
            query, ans = generate_pair(with_counts)
            # One sample per line: space-separated query, comma-separated answer.
            fx.write(' '.join(query) + '\n')
            fy.write(','.join(ans) + '\n')


def generate_data(ntrain, nval, vocab_size, with_counts):
    """Generate the training and validation files under ``data_folder``.

    Args:
        ntrain: Number of samples written to the training files.
        nval: Number of samples written to the validation files.
        vocab_size: Vocabulary size; stored in the module-level VOCAB_SIZE,
            which generate_pair() reads.
        with_counts: Passed through to generate_pair() to select whether
            answers include word counts.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(data_folder, exist_ok=True)

    global VOCAB_SIZE
    VOCAB_SIZE = vocab_size

    _write_split(train_x, train_y, ntrain, with_counts)
    _write_split(val_x, val_y, nval, with_counts)
def main():
    """Parse the command line and generate the dataset.

    Positional arguments:
        [1]: number of samples in training set
        [2]: number of samples in validation set
        [3]: vocabulary size
        [4]: output with (1) or without (0) counts

    Exits with a usage message when fewer than four arguments are given.
    """
    if len(sys.argv) < 5:
        sys.exit(
            "usage: {} NTRAIN NVAL VOCAB_SIZE WITH_COUNTS(0|1)".format(sys.argv[0]))
    ntrain, nval, vocab_size = (int(arg) for arg in sys.argv[1:4])
    # "0"/"1" on the command line -> False/True.
    with_counts = bool(int(sys.argv[4]))
    generate_data(ntrain, nval, vocab_size, with_counts)


if __name__ == "__main__":
    main()