Skip to content

Commit d27e206

Browse files
committed
BUG fixed bugs
1 parent 96682e0 commit d27e206

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed
File renamed without changes.

textcat/train.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ class TokenStatistics():
1212
''' Helper class that stores the tf, idf and number of documents
1313
for a token. '''
1414
def __init__(self):
15-
self.tf_dict = {} # Key, value = category, tf
16-
self.num_docs_with_token = 0 # Number of documents containing token
17-
self.idf = 0 # idf
15+
self.tf_dict = defaultdict(lambda: 0) # Key, value = category, tf
16+
self.num_docs_with_token = 0 # Number of documents iwth token
17+
self.idf = 0 # idf
1818

1919

2020
class InvertedIndex():
@@ -44,7 +44,7 @@ def compute_tfidfs(self, train_labels_filename):
4444
self.inverted_index[token].tf_dict[category] += 1
4545

4646
for token in set(token_list):
47-
self.inverted_index[token].doc_count += 1
47+
self.inverted_index[token].num_docs_with_token += 1
4848

4949
self.num_documents += 1
5050

@@ -83,12 +83,18 @@ def save(self, filename):
8383

8484
def tokenize(file_path):
8585
''' Takes article path, and returns list of tokens. '''
86-
return ['foobar']
86+
tokens = []
87+
88+
with open(file_path, 'r') as f:
89+
for line in f:
90+
tokens += line.split()
91+
92+
return tokens
8793

8894

8995
if __name__ == '__main__':
90-
train_labels_filename = input('Train labels file:')
91-
model_filename = input('Train labels file:')
96+
train_labels_filename = input('Train labels file:\t')
97+
model_filename = input('Model checkpoint file:\t')
9298

9399
print('Training text categorizer...')
94100

0 commit comments

Comments
 (0)