Skip to content

Commit 2a77f99

Browse files
authored
Create anomaly_detection.py
1 parent 99fbb5e commit 2a77f99

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# anomaly_detection.py
2+
3+
import pandas as pd
4+
import numpy as np
5+
from sklearn.ensemble import IsolationForest
6+
from sklearn.svm import OneClassSVM
7+
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
8+
from sklearn.preprocessing import StandardScaler
9+
10+
class AnomalyDetector:
    """Thin wrapper around a scikit-learn outlier-detection estimator.

    The estimator is chosen by name when the detector is constructed and is
    stored on ``self.model``; ``fit`` and ``predict`` delegate to it.

    Args:
        algorithm: ``'isolation_forest'`` or ``'one_class_svm'``.
        contamination: Expected share of anomalies in the data. Forwarded to
            ``IsolationForest`` as ``contamination`` and to ``OneClassSVM``
            as ``nu``.

    Raises:
        ValueError: If ``algorithm`` is not one of the supported names.
    """

    # ``nu`` used for OneClassSVM when ``contamination`` is not numeric.
    _DEFAULT_NU = 0.1

    def __init__(self, algorithm, contamination=0.1):
        self.algorithm = algorithm
        self.contamination = contamination
        self.model = self.select_algorithm(algorithm)

    def select_algorithm(self, algorithm):
        """Build and return the estimator registered under ``algorithm``.

        Raises:
            ValueError: If ``algorithm`` is not a supported name.
        """
        if algorithm == 'isolation_forest':
            return IsolationForest(contamination=self.contamination)
        elif algorithm == 'one_class_svm':
            # Honour the configured contamination instead of a fixed 0.1.
            # String settings such as 'auto' are only meaningful for
            # IsolationForest, so they fall back to the previous default.
            nu = self.contamination
            if isinstance(nu, str):
                nu = self._DEFAULT_NU
            return OneClassSVM(kernel='rbf', gamma=0.1, nu=nu)
        else:
            raise ValueError('Invalid algorithm. Supported algorithms are isolation_forest and one_class_svm.')

    def fit(self, X):
        """Fit the underlying estimator on the feature matrix ``X``."""
        self.model.fit(X)

    def predict(self, X):
        """Return the underlying estimator's predictions for ``X``.

        Both supported scikit-learn estimators report +1 for inliers and -1
        for outliers.
        """
        return self.model.predict(X)

    def evaluate(self, y_true, y_pred):
        """Score predictions against ground truth.

        Precision, recall and F1 are computed with scikit-learn's default
        positive label.

        NOTE(review): ``y_true`` has to use the same encoding as ``y_pred``
        (the estimators emit +1/-1) for these scores to be meaningful --
        confirm how the ground-truth labels are encoded.

        Returns:
            Tuple ``(accuracy, precision, recall, f1)``.
        """
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred)
        recall = recall_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred)
        return accuracy, precision, recall, f1

    def detect_anomalies(self, data):
        """Fit on ``data``, predict on the same rows, and print the scores.

        Args:
            data: DataFrame holding the feature columns plus a ``'label'``
                column that is used as ground truth.

        Returns:
            The predictions for every row of ``data``.
        """
        X = data.drop(['label'], axis=1)
        y = data['label']
        self.fit(X)
        y_pred = self.predict(X)
        accuracy, precision, recall, f1 = self.evaluate(y, y_pred)
        print(f'Accuracy: {accuracy:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}')
        return y_pred

    def visualize_anomalies(self, data, y_pred):
        """Scatter-plot the first two feature columns, coloured by ``y_pred``.

        Args:
            data: 2-D feature data (DataFrame or array) with at least two
                columns.
            y_pred: One prediction per row, used as the point colour.
        """
        # Imported lazily so the class stays usable without matplotlib.
        import matplotlib.pyplot as plt

        # A DataFrame does not support ``data[:, 0]``; converting to an array
        # gives positional column access for DataFrame and ndarray alike.
        points = np.asarray(data)
        plt.scatter(points[:, 0], points[:, 1], c=y_pred)
        plt.xlabel('Feature 1')
        plt.ylabel('Feature 2')
        plt.title('Anomaly Detection')
        plt.show()
53+
54+
# Example usage: load a labelled CSV, fit an isolation forest on it and print
# the scores, then pass the feature columns and predictions to the plot helper.
data = pd.read_csv('anomaly_data.csv')
anomaly_detector = AnomalyDetector('isolation_forest')
y_pred = anomaly_detector.detect_anomalies(data)
anomaly_detector.visualize_anomalies(data.drop(columns=['label']), y_pred)

0 commit comments

Comments
 (0)