-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathsettings.py
83 lines (70 loc) · 2.69 KB
/
settings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
from os.path import join
# Spellings of an environment value that mean "switched off".
_FALSY_ENV_VALUES = frozenset(('', '0', 'false', 'no', 'off'))


def _env_flag(name, default=False):
    """Read environment variable *name* as an on/off switch.

    Returns *default* when the variable is unset.  Otherwise the value is
    stripped and lower-cased; the spellings in ``_FALSY_ENV_VALUES`` give
    ``False`` and anything else gives ``True``.

    A plain ``bool(os.getenv(...))`` is not enough here because every
    non-empty string is truthy, so ``"0"`` or ``"false"`` would still turn
    the switch on.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() not in _FALSY_ENV_VALUES


# Opt-in GPU switch, controlled by the ENABLE_CUDA_DL4NLP environment variable.
CUDA = _env_flag('ENABLE_CUDA_DL4NLP')

# Root of the working tree: DL4NLPROOT if set, else $HOME, else the current
# directory.
ROOT_DIR = os.getenv('DL4NLPROOT', None)
if ROOT_DIR is None:
    ROOT_DIR = os.getenv('HOME', '.')

# Directories derived from ROOT_DIR.
DATA_DIR = join(ROOT_DIR, 'data')
DAY1_DIR = join(ROOT_DIR, 'day1')
DAY2_DIR = join(ROOT_DIR, 'day2')
ZOO_DIR = join(ROOT_DIR, 'modelzoo')

# Dataset files, all located under DATA_DIR.
GLOVE_FILENAME = join(DATA_DIR, 'glove.6B.100d.txt')
FIRSTNAMES_CSV = join(DATA_DIR, 'firstnames.csv')
SURNAMES_CSV = join(DATA_DIR, 'surnames.csv')
TRUMP_FILENAME = join(DATA_DIR, 'trump.csv')
AMAZON_FILENAME = join(DATA_DIR, 'amazon_train_small.csv')
SNLI_TRAIN_JSON = join(DATA_DIR, 'snli_1.0', 'snli_1.0_train.jsonl')
SNLI_DEV_JSON = join(DATA_DIR, 'snli_1.0', 'snli_1.0_dev.jsonl')
SNLI_TEST_JSON = join(DATA_DIR, 'snli_1.0', 'snli_1.0_test.jsonl')
ZHNEWS_CSV = join(DATA_DIR, 'zhnews.csv')

# Sequence boundary markers.
START_TOKEN = "^"
END_TOKEN = "_"

# Used by PyTorch backends to ignore non-classes in loss computations.
# Best for sequences where not everything that goes into the loss
# computation should actually contribute to the loss.
IGNORE_INDEX_VALUE = -1
class ZOO:
    """Catalogue of pre-trained model artefacts kept under ``ZOO_DIR``.

    Every attribute is a dict with the keys ``filename`` (path of the
    ``.state`` file), ``vocab`` (path of the ``.vocab`` file), ``comments``,
    ``date`` and ``parameters`` (``embedding_size`` and ``hidden_size``).
    """

    def _entry(state_name, vocab_name, comments, embedding_size):
        # Every catalogued model shares the same date and hidden size, so
        # only the varying pieces are passed in.
        return {
            'filename': join(ZOO_DIR, state_name),
            'vocab': join(ZOO_DIR, vocab_name),
            'comments': comments,
            'date': '09-14-2017',
            'parameters': {
                'embedding_size': embedding_size,
                'hidden_size': 64
            }
        }

    charnn_surname_classifer = _entry(
        'charnn_emb16_hid64_surnames_classify.state',
        'surnames_classify.vocab',
        'pre-trained surname classifier',
        16
    )

    charnn_surname_predicter = _entry(
        'charnn_emb16_hid64_surnames_predict.state',
        'surnames_classify.vocab',
        'pre-trained surname sequence prediction (& generation model)',
        16
    )

    charnn_surname_conditioned_predicter = _entry(
        'charnn_emb16_hid64_surnames_conditionally_predict.state',
        'surnames_classify.vocab',
        'pre-trained surname conditioned sequence prediction (& conditioned generation)',
        16
    )

    wordrnn_trump_tweet_predicter = _entry(
        'wordrnn_emb100_hid64_trump_tweets_predict.state',
        'trump_twitter.vocab',
        'pre-trained trump sequence prediction (& generation)',
        100
    )

    # The builder is scaffolding only; remove it so the class exposes
    # nothing but the model entries.
    del _entry