-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtraining.cc
152 lines (139 loc) · 4.34 KB
/
training.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#include <iostream>
#include <vector>
#include <iterator>
#include <algorithm>
#include "helper.h"
using namespace std;
// Forward declarations for the functions defined after main() in this file.

// Builds the word -> (spam_count, ham_count) map from the labelled input stream.
map<string, pair<int, int>> read_csv(ifstream& filename);
// Erases every entry whose spam_count + ham_count is below 3.
void cleanInfrequentClassifications(map<string, pair<int, int>>& words);
// Writes one "<word> <count>" line per entry to each stream, then closes both streams.
void outputProbability(map<string, pair<int, int>>& words, ofstream& hamFile, ofstream& spamFile);
/**
 * Entry point: builds the spam/ham word-count model from a labelled input file
 * and writes the two count files.
 *
 * Accepted command lines:
 *   <prog> <input> <spam_out> <ham_out>
 *   <prog> -i <input> -os <spam_out> -oh <ham_out>     (pairs in any order)
 *
 * @return 0 on success, -1 on any argument or file-open error.
 */
int main(int argc, char** argv)
{
    // Only "three positional values" or "three flag/value pairs" are valid.
    // The earlier test reduced to `argc < 4`, so argc == 5 or 6 reached the
    // flag branch and constructed strings from argv[argc] (a null pointer)
    // or beyond.
    if (argc != 4 && argc != 7) {
        cerr << "Inccorect argument count" << endl;
        return -1;
    }

    string inputPath, spamPath, hamPath;
    if (argc == 4) { // No optional arguments provided: positions decide the meaning.
        inputPath = argv[1];
        spamPath = argv[2];
        hamPath = argv[3];
    }
    else {
        const string flags[3] = { argv[1], argv[3], argv[5] };
        const string values[3] = { argv[2], argv[4], argv[6] };

        // Every odd position must be a flag before any flag is interpreted.
        for (int k = 0; k < 3; ++k) {
            if (flags[k][0] != '-') {
                cerr << "Incorrect argument structure" << endl;
                return -1;
            }
        }

        for (int k = 0; k < 3; ++k) {
            // Indexing at [1]/[2] is safe: reading position size() of a
            // std::string yields '\0', and [2] is only read after [1]
            // matched 'o'.
            const string& flag = flags[k];
            if (flag[1] == 'i') {
                inputPath = values[k];
            }
            else if (flag[1] == 'o' && flag[2] == 's') {
                spamPath = values[k];
            }
            else if (flag[1] == 'o' && flag[2] == 'h') {
                hamPath = values[k];
            }
            else {
                cerr << "Unknown argument " << flag << endl;
                return -1;
            }
        }
    }

    // A repeated flag leaves one of the three paths empty.
    if (inputPath.empty()) {
        cerr << "Input argument not satisfied" << endl;
        return -1;
    }
    if (spamPath.empty()) {
        cerr << "Output spam argument not satisfied" << endl;
        return -1;
    }
    if (hamPath.empty()) {
        cerr << "Output ham argument not satisfied" << endl;
        return -1;
    }

    // Open the input first so that an unreadable input does not truncate
    // existing output files.
    ifstream inputFile(inputPath, ifstream::in); // Input File Name
    if (!inputFile.is_open()) {
        cerr << "Cannot open input file " << inputPath << endl;
        return -1;
    }
    ofstream spamFile(spamPath, ofstream::out); // Output Spam File
    ofstream hamFile(hamPath, ofstream::out);   // Output Ham File
    if (!spamFile.is_open() || !hamFile.is_open()) {
        cerr << "Cannot open output file" << endl;
        return -1;
    }

    map<string, pair<int, int>> trained_model = read_csv(inputFile); // <word: <spam_count, ham_count>>
    cleanInfrequentClassifications(trained_model); // Remove classifications where spam_count + ham_count < 3
    removeStopWords(trained_model); // Remove words that likely play no part to the meaning of a message.
    // outputProbability writes pair.first (the spam count) to its first stream
    // argument, so spamFile is deliberately passed first.
    outputProbability(trained_model, spamFile, hamFile);
    return 0;
}
/**
 * Adds the space-separated words of one cleaned message to the tally.
 *
 * @param text   message text with the label already removed
 * @param isSpam true to increment the first (spam) counter, false for the
 *               second (ham) counter
 * @param words  word -> (spam_count, ham_count) map being built
 */
static void tallyLineWords(const string& text, bool isSpam, map<string, pair<int, int>>& words)
{
    string word;
    // Running one position past the end treats end-of-line as a separator, so
    // the final word is counted and nothing carries over into the next line.
    for (string::size_type pos = 0; pos <= text.size(); ++pos) {
        if (pos < text.size() && text[pos] != ' ') {
            word += text[pos];
            continue;
        }
        if (word.empty()) continue; // consecutive or leading spaces
        pair<int, int>& counts = words[word]; // value-initialised to (0, 0) on first sight
        if (isSpam) ++counts.first;
        else ++counts.second;
        word.clear();
    }
}

/**
 * Reads labelled messages line by line and counts how often each word occurs
 * in spam and in ham messages.
 *
 * A line is classified by its first character only ('s'/'S' -> spam,
 * 'h'/'H' -> ham); any other line is skipped. The label is assumed to be
 * "spam," (5 characters) or "ham," (4 characters) and is stripped before the
 * text is passed through cleanLine() and toLower().
 *
 * @param filename input stream to read; it is closed before returning
 * @return map of word -> (spam_count, ham_count); empty when the stream is
 *         not readable (an error is printed to cerr in that case)
 */
map<string, pair<int, int>> read_csv(ifstream& filename)
{
    map<string, pair<int, int>> words;
    if (!filename.good()) {
        cerr << " Cannot read file ! " << endl;
        filename.close();
        return words;
    }

    string line;
    while (getline(filename, line)) {
        // Strip every leading comma. The earlier index loop advanced while
        // erasing and therefore removed only every other comma.
        line.erase(0, line.find_first_not_of(','));
        if (line.empty()) continue;

        const char label = line[0];
        const bool isSpam = (label == 's' || label == 'S');
        const bool isHam = (label == 'h' || label == 'H');
        if (!isSpam && !isHam) continue;

        line.erase(0, isSpam ? 5 : 4); // Remove "spam," / "ham," from line
        cleanLine(line);
        toLower(line);
        tallyLineWords(line, isSpam, words);
    }
    filename.close();
    return words;
}
/**
 * Removes rarely seen words from the tally.
 *
 * An entry is erased when the sum of its two counters is below 3; all other
 * entries are left untouched.
 *
 * @param words word -> (spam_count, ham_count) map, modified in place
 */
void cleanInfrequentClassifications(map<string, pair<int, int>>& words)
{
    for (auto entry = words.begin(); entry != words.end(); ) {
        const int total = entry->second.first + entry->second.second;
        if (total >= 3) {
            ++entry;
        }
        else {
            // erase() hands back the iterator following the removed element.
            entry = words.erase(entry);
        }
    }
}
/**
 * Writes the word counts to two files, one "<word> <count>" line per entry,
 * then closes both streams.
 *
 * NOTE: the parameter names are misleading. The stream named hamFile receives
 * pair.first and the stream named spamFile receives pair.second. read_csv
 * stores the spam count in first and the ham count in second, and main passes
 * its spam stream as the first stream argument, so each file ends up with the
 * matching counts. Keep the call site and this function in step if either is
 * changed.
 *
 * @param words    word -> (first, second) counters; entries are written in
 *                 the map's key order
 * @param hamFile  stream that receives "<word> <first>" lines
 * @param spamFile stream that receives "<word> <second>" lines
 */
void outputProbability(map<string, pair<int, int>>& words, ofstream& hamFile, ofstream& spamFile)
{
    // '\n' instead of endl: both streams are closed below, so flushing after
    // every single entry only costs time.
    for (const auto& entry : words) {
        hamFile << entry.first << " " << entry.second.first << '\n';
        spamFile << entry.first << " " << entry.second.second << '\n';
    }
    hamFile.close();
    spamFile.close();
}