Skip to content

Commit 6a71fe2

Browse files
committed
reinit commit
1 parent a7e4c9f commit 6a71fe2

37 files changed

+13952
-0
lines changed

data/README.md

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Data Folder
2+
3+
There are several files you will need. You can download them in zip format
4+
from [here](https://drive.google.com/file/d/0B2hg7DTHpfLsZW44aTRVd2FrbEE/view?usp=sharing).
5+
6+
7+
The files you need are:
8+
9+
- surnames.csv
10+
- trump.csv
11+
- glove.6B.100d.txt
12+
- zhnews.csv
13+
- firstnames.csv
14+
- amazon_train_small.csv
15+

datautils/__init__.py

Whitespace-only changes.
107 Bytes
Binary file not shown.
107 Bytes
Binary file not shown.
8.81 KB
Binary file not shown.
8.18 KB
Binary file not shown.

datautils/vocabulary.py

+255
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
from collections import Counter
2+
3+
import numpy as np
4+
from torch.utils.data import Dataset
5+
import six
6+
7+
import json
8+
9+
10+
class VocabularyException(Exception):
    """Raised when a token or an index cannot be resolved by a Vocabulary."""


class Vocabulary(object):
    """
    An implementation that manages the interface between a token dataset and the
    machine learning algorithm.

    Tokens are assigned consecutive integer indices in the order they are first
    seen. While the vocabulary is unfrozen, indexing it with a token (``vocab[k]``
    or ``vocab.add(k)``) inserts unseen tokens and counts every occurrence. Once
    frozen, indexing is read-only: unseen tokens map to the UNK index when
    ``use_unks`` is set and raise ``VocabularyException`` otherwise.
    """

    def __init__(self, use_unks=False, unk_token="<UNK>",
                 use_mask=False, mask_token="<MASK>", use_start_end=False,
                 start_token="<START>", end_token="<END>"):
        """
        Args:
            use_unks (bool): The vocabulary will output UNK tokens for out of
                vocabulary items.
                [default=False]
            unk_token (str): The token used for unknown tokens.
                If `use_unks` is True, this will be added to the vocabulary.
                [default='<UNK>']
            use_mask (bool): The vocabulary will reserve the 0th index for a mask token.
                This is used to handle variable lengths in sequence models.
                [default=False]
            mask_token (str): The token used for the mask.
                Note: mostly a placeholder; it's unlikely the token will be seen.
                [default='<MASK>']
            use_start_end (bool): The vocabulary will reserve indices for two tokens
                that represent the start and end of a sequence.
                [default=False]
            start_token: The token used to indicate the start of a sequence.
                If `use_start_end` is True, this will be added to the vocabulary.
                [default='<START>']
            end_token: The token used to indicate the end of a sequence
                If `use_start_end` is True, this will be added to the vocabulary.
                [default='<END>']
        """

        self._mapping = {}  # str -> int
        self._flip = {}  # int -> str
        self._counts = Counter()  # int -> int; count occurrences
        self._forced_unks = set()  # indices forced to unk (e.g. if < 5 occurrences)
        self._i = 0  # next index to hand out
        self._frozen = False
        # Both names hold the cutoff given to `freeze`. `_frequency_threshold`
        # is the one that gets serialized; `_frequency_cutoff` is kept because
        # earlier versions of `freeze` wrote to it.
        self._frequency_threshold = -1
        self._frequency_cutoff = -1

        # mask token for use in masked recurrent networks
        # usually need to be the 0th index, hence it is added first
        self.use_mask = use_mask
        self.mask_token = mask_token
        if self.use_mask:
            self.add(self.mask_token)

        # unk token for out of vocabulary tokens
        self.use_unks = use_unks
        self.unk_token = unk_token
        if self.use_unks:
            self.add(self.unk_token)

        # start/end tokens for sequence models
        self.use_start_end = use_start_end
        self.start_token = start_token
        self.end_token = end_token
        if self.use_start_end:
            self.add(self.start_token)
            self.add(self.end_token)

    def iterkeys(self):
        """Yield every token except the unk and mask tokens."""
        for k in self._mapping.keys():
            if k == self.unk_token or k == self.mask_token:
                continue
            yield k

    def keys(self):
        """Return the tokens from `iterkeys` as a list."""
        return list(self.iterkeys())

    def iteritems(self):
        """Yield (token, index) pairs, skipping the unk and mask tokens."""
        for key, value in self._mapping.items():
            if key == self.unk_token or key == self.mask_token:
                continue
            yield key, value

    def items(self):
        """Return the pairs from `iteritems` as a list."""
        return list(self.iteritems())

    def values(self):
        """Return the indices of every token except the unk and mask tokens."""
        return [value for _, value in self.iteritems()]

    def __getitem__(self, k):
        """
        Map token `k` to its index.

        Unfrozen: a known token has its count incremented; an unknown token is
        inserted at the next free index with a count of 1.
        Frozen: nothing is modified. Unknown tokens, and tokens whose index was
        forced to unk by the frequency cutoff, resolve to `unk_index`.

        Raises:
            VocabularyException: frozen, `use_unks` is False and `k` is unknown.
        """
        if self._frozen:
            if k in self._mapping:
                out_index = self._mapping[k]
            elif self.use_unks:
                out_index = self.unk_index
            else:  # case: frozen, don't want unks, raise exception
                raise VocabularyException("Vocabulary is frozen. " +
                                          "Key '{}' not found.".format(k))
            # NOTE: `unk_index` is None when the unk token was never added, so
            # a forced-unk lookup on such a vocabulary yields None.
            if out_index in self._forced_unks:
                out_index = self.unk_index
        elif k in self._mapping:  # case: normal
            out_index = self._mapping[k]
            self._counts[out_index] += 1
        else:  # case: first occurrence of this token
            out_index = self._mapping[k] = self._i
            self._i += 1
            self._flip[out_index] = k
            self._counts[out_index] = 1

        return out_index

    def add(self, k):
        """Alias for ``self[k]``; returns the index of token `k`."""
        return self.__getitem__(k)

    def add_many(self, x):
        """Call `add` on every token of iterable `x`; return the list of indices."""
        return [self.add(k) for k in x]

    def lookup(self, i):
        """
        Return the token stored at index `i`.

        Raises:
            VocabularyException: `i` is not an index of this vocabulary.
        """
        try:
            return self._flip[i]
        except KeyError:
            raise VocabularyException("Key {} not in Vocabulary".format(i))

    def lookup_many(self, x):
        """Lazily yield the token for every index in iterable `x`."""
        for k in x:
            yield self.lookup(k)

    def map(self, sequence, include_start_end=False):
        """
        Lazily yield the index of every item in `sequence`.

        Args:
            sequence (iterable): tokens to convert, each resolved with ``self[item]``.
            include_start_end (bool): when True, `start_index` is yielded before
                and `end_index` after the sequence (either is None when the
                corresponding token is not in the vocabulary).
                [default=False]
        """
        if include_start_end:
            yield self.start_index

        for item in sequence:
            yield self[item]

        if include_start_end:
            yield self.end_index

    def freeze(self, use_unks=False, frequency_cutoff=-1):
        """
        Make the vocabulary read-only.

        Args:
            use_unks (bool): replaces `self.use_unks`; when True the unk token
                is added if it is not present yet.
                [default=False]
            frequency_cutoff (int): when positive, every index counted fewer
                than `frequency_cutoff` times resolves to `unk_index` while
                frozen.
                [default=-1]
        """
        self.use_unks = use_unks
        self._frequency_threshold = frequency_cutoff
        self._frequency_cutoff = frequency_cutoff

        # `add` only inserts on an unfrozen vocabulary, so make sure a repeated
        # `freeze` call is still able to register the unk token.
        self._frozen = False
        if use_unks and self.unk_token not in self:
            self.add(self.unk_token)

        # Recompute from scratch so a previous cutoff leaves nothing behind.
        self._forced_unks = set()
        if frequency_cutoff > 0:
            self._forced_unks.update(
                index for index, count in self._counts.items()
                if count < frequency_cutoff)

        self._frozen = True

    def unfreeze(self):
        """Allow new tokens to be added and counted again."""
        self._frozen = False

    def get_counts(self):
        """Return a dict mapping each counted token to its occurrence count."""
        return {self._flip[i]: count for i, count in self._counts.items()}

    def get_count(self, token=None, index=None):
        """
        Return the occurrence count of one token, given by `token` or by `index`.

        The lookup never modifies the vocabulary. A token or index that is not
        in the vocabulary has a count of 0.

        Returns:
            int, or None when neither or both of `token` and `index` are given
            (a message is printed in the latter case).
        """
        if token is None and index is None:
            return None
        if token is not None and index is not None:
            print("Cannot do two things at once; choose one")
            return None
        if token is not None:
            # Read the mapping directly: going through `self[token]` would
            # count, or even insert, the token on an unfrozen vocabulary.
            index = self._mapping.get(token)
            if index is None:
                return 0
        return self._counts[index]

    @property
    def unk_index(self):
        """Index of the unk token, or None if it is not in the vocabulary."""
        return self._mapping.get(self.unk_token)

    @property
    def mask_index(self):
        """Index of the mask token, or None if it is not in the vocabulary."""
        return self._mapping.get(self.mask_token)

    @property
    def start_index(self):
        """Index of the start token, or None if it is not in the vocabulary."""
        return self._mapping.get(self.start_token)

    @property
    def end_index(self):
        """Index of the end token, or None if it is not in the vocabulary."""
        return self._mapping.get(self.end_token)

    def __contains__(self, k):
        return k in self._mapping

    def __len__(self):
        return len(self._mapping)

    def __repr__(self):
        return "<Vocabulary(size={},frozen={})>".format(len(self), self._frozen)

    def get_serializable_contents(self):
        """
        Creates a dict containing the necessary information to recreate this instance

        The mapping dicts are copies, so changing the returned dict does not
        affect this instance.
        """
        config = {"_mapping": dict(self._mapping),
                  "_flip": dict(self._flip),
                  "_frozen": self._frozen,
                  "_i": self._i,
                  "_counts": list(self._counts.items()),
                  "_frequency_threshold": self._frequency_threshold,
                  "use_unks": self.use_unks,
                  "unk_token": self.unk_token,
                  "use_mask": self.use_mask,
                  "mask_token": self.mask_token,
                  "use_start_end": self.use_start_end,
                  "start_token": self.start_token,
                  "end_token": self.end_token}
        return config

    @classmethod
    def deserialize_from_contents(cls, content):
        """
        Recreate a Vocabulary instance; expect same dict as output in
        `get_serializable_contents`. `content` itself is left unmodified.

        Raises:
            VocabularyException: one of the private state entries is missing.
        """
        content = dict(content)  # work on a copy; the entries below get popped
        try:
            _mapping = content.pop("_mapping")
            _flip = content.pop("_flip")
            _i = content.pop("_i")
            _frozen = content.pop("_frozen")
            _counts = content.pop("_counts")
            _frequency_threshold = content.pop("_frequency_threshold")
        except KeyError:
            raise VocabularyException("unable to deserialize vocabulary")

        # JSON turns the integer keys of `_flip` into strings; int() restores
        # them and leaves keys that already are integers untouched.
        _flip = {int(k): v for k, v in _flip.items()}

        # What is left in `content` are exactly the constructor arguments.
        out = cls(**content)
        out._mapping = dict(_mapping)
        out._flip = _flip
        out._i = _i
        out._counts = Counter(dict(_counts))
        out._frequency_threshold = _frequency_threshold
        out._frequency_cutoff = _frequency_threshold

        if _frozen:
            # Re-freezing with the stored cutoff rebuilds the forced-unk set
            # from the restored counts.
            out.freeze(out.use_unks, _frequency_threshold)

        return out
255+

day_1/.ipynb_checkpoints/3_classify_names_with_MLP-checkpoint.ipynb

+1,319
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)