-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathspam_filter.py
69 lines (60 loc) · 2.46 KB
/
spam_filter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import os
import joblib
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline
def _read_text_files(folder):
    """Return the text of every regular file directly inside ``folder``."""
    texts = []
    # os.listdir returns names in arbitrary order; sorting keeps the sample
    # order (and therefore anything derived from it) reproducible.
    for filename in sorted(os.listdir(folder)):
        path = os.path.join(folder, filename)
        if not os.path.isfile(path):
            # A stray sub-directory would make open() fail; ignore it.
            continue
        # Decode explicitly instead of relying on the locale default, and
        # substitute undecodable bytes so one malformed e-mail cannot abort
        # the whole load with UnicodeDecodeError.
        with open(path, 'r', encoding='utf-8', errors='replace') as file:
            texts.append(file.read())
    return texts


def load_data(data_dir='data'):
    """Load the training and test e-mail corpora from disk.

    The corpus is read from four folders below ``data_dir``:
    ``train/spam``, ``train/not_spam``, ``test/spam`` and ``test/not_spam``,
    each holding one e-mail per file.

    Args:
        data_dir: Root directory of the corpus. Defaults to ``'data'``
            (relative to the current working directory).

    Returns:
        ``((train_data, train_labels), (test_data, test_labels))``. Each
        ``*_data`` is a list of e-mail texts and each ``*_labels`` is a
        parallel list of ints, 1 for spam and 0 for not spam. Within a
        split, all spam messages come first, followed by the non-spam ones.

    Raises:
        OSError: If one of the four folders is missing or unreadable
            (typically ``FileNotFoundError``).
    """
    splits = []
    for split in ('train', 'test'):
        spam = _read_text_files(os.path.join(data_dir, split, 'spam'))
        not_spam = _read_text_files(os.path.join(data_dir, split, 'not_spam'))
        data = spam + not_spam
        labels = [1] * len(spam) + [0] * len(not_spam)
        splits.append((data, labels))
    train_split, test_split = splits
    return train_split, test_split
def train_model():
    """Fit the spam classifier on the training split and save it to disk.

    Builds a TF-IDF + multinomial naive Bayes pipeline, fits it on the
    training texts and labels returned by ``load_data()``, writes the fitted
    pipeline to ``spam_filter_model.pkl`` in the current working directory,
    and prints a confirmation message. The test split is not used here.
    """
    training_split, _test_split = load_data()
    texts, labels = training_split

    vectorizer = TfidfVectorizer()
    classifier = MultinomialNB()
    pipeline = make_pipeline(vectorizer, classifier)
    pipeline.fit(texts, labels)

    joblib.dump(pipeline, 'spam_filter_model.pkl')
    print("Model trained and saved as spam_filter_model.pkl")
def predict(email_path):
    """Classify one e-mail file and print whether it is spam.

    Args:
        email_path: Path to a text file containing the e-mail to classify.

    Returns:
        None. The verdict is printed to standard output. If the model file
        ``spam_filter_model.pkl`` does not exist in the current working
        directory, a hint to train first is printed and nothing else is done.

    Raises:
        OSError: If ``email_path`` cannot be opened (for example
            ``FileNotFoundError``).
    """
    model_path = 'spam_filter_model.pkl'
    if not os.path.exists(model_path):
        print("Model not found. Please train the model first.")
        return

    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code. Only load model files produced by a trusted run.
    model = joblib.load(model_path)

    # Decode explicitly instead of relying on the locale default, and
    # substitute undecodable bytes so an e-mail in an unexpected encoding is
    # still classified rather than raising UnicodeDecodeError.
    with open(email_path, 'r', encoding='utf-8', errors='replace') as file:
        email_content = file.read()

    # The model expects a sequence of documents, hence the one-element list.
    prediction = model.predict([email_content])
    if prediction[0] == 1:
        print("The email is spam.")
    else:
        print("The email is not spam.")
if __name__ == "__main__":
    # Command-line entry point. ``--train`` fits and saves the model;
    # ``--email_path FILE`` classifies one e-mail. When both are given,
    # training takes precedence; when neither is given, usage help is shown.
    cli = argparse.ArgumentParser(description="Email Spam Filter")
    cli.add_argument(
        '--train',
        action='store_true',
        help="Train the spam filter model",
    )
    cli.add_argument(
        '--email_path',
        type=str,
        help="Path to the email text file",
    )
    options = cli.parse_args()

    if not (options.train or options.email_path):
        # Nothing requested: explain how to use the tool.
        cli.print_help()
    elif options.train:
        train_model()
    else:
        predict(options.email_path)