-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmla.py
66 lines (50 loc) · 1.87 KB
/
mla.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""MLA - machine-learning-Algorithm
Usage:
mla.py train <dataset-dir> <model-file> [--vocab-size=<vocab-size>]
mla.py ask <model-file> <question>
mla.py (-h | --help)
Arguments:
<dataset-dir> Directory with dataset.
<model-file> Serialized model file.
<question> Text to be classified.
Options:
--vocab-size=<vocab-size> Vocabulary size. [default: 10000]
-h --help Show this screen.
"""
from docopt import docopt
##################################################################################
import os
from sklearn.metrics import classification_report
from mla import Model_RF, Dataset
def train_model(dataset_dir, model_file, vocab_size):
    """Train a model on a dataset directory, store it, and print test metrics.

    Args:
        dataset_dir: Base directory; its ``train`` and ``test`` sub-paths are
            handed to ``Dataset``.
        model_file: Path the trained model is serialized to.
        vocab_size: Vocabulary size forwarded to ``Model_RF``.
    """
    print(f'Training model from directory {dataset_dir}')
    print(f'Vocabulary size: {vocab_size}')

    dataset = Dataset(
        os.path.join(dataset_dir, 'train'),
        os.path.join(dataset_dir, 'test'),
    )
    train_inputs, train_labels = dataset.get_train_set()

    classifier = Model_RF(vocab_size=vocab_size)
    classifier.train(train_inputs, train_labels)

    # The model is written out before the evaluation step runs.
    print(f'Storing model to {model_file}')
    classifier.serialize(model_file)

    test_inputs, test_labels = dataset.get_test_set()
    predicted_labels = classifier.predict(test_inputs)
    print(classification_report(test_labels, predicted_labels))
def ask_model(model_file, question):
    """Load a serialized model and print its prediction for one question.

    Args:
        model_file: Path passed to ``Model_RF.deserialize``.
        question: Text to classify; it is sent to the model as a one-item list.
    """
    print(f'Asking model {model_file} about "{question}"')
    classifier = Model_RF.deserialize(model_file)
    batch = [question]
    probabilities = classifier.predict_proba(batch)
    # Only one question was submitted, so print just the first result.
    print(probabilities[0])
######################################################################
def main():
    """Parse the command line and dispatch to the requested sub-command.

    The module docstring is the docopt usage specification. ``train`` runs
    ``train_model`` and ``ask`` runs ``ask_model``.

    Raises:
        SystemExit: If ``--vocab-size`` cannot be converted to an integer.
    """
    arguments = docopt(__doc__)
    if arguments['train']:
        raw_vocab_size = arguments['--vocab-size']
        try:
            vocab_size = int(raw_vocab_size)
        except (TypeError, ValueError):
            # Give the user a readable CLI error instead of a traceback
            # when the option value is not a number (e.g. --vocab-size=abc).
            raise SystemExit(
                f'--vocab-size must be an integer, got {raw_vocab_size!r}'
            ) from None
        train_model(
            arguments['<dataset-dir>'],
            arguments['<model-file>'],
            vocab_size,
        )
    elif arguments['ask']:
        ask_model(
            arguments['<model-file>'],
            arguments['<question>'],
        )


# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()