|
| 1 | +# anomaly_detection.py |
| 2 | + |
| 3 | +import pandas as pd |
| 4 | +import numpy as np |
| 5 | +from sklearn.ensemble import IsolationForest |
| 6 | +from sklearn.svm import OneClassSVM |
| 7 | +from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score |
| 8 | +from sklearn.preprocessing import StandardScaler |
| 9 | + |
class AnomalyDetector:
    """Wrap a scikit-learn outlier detector behind a small fit/predict API.

    Attributes:
        algorithm: Name of the selected algorithm ('isolation_forest' or
            'one_class_svm').
        contamination: Expected proportion of outliers, forwarded to the
            underlying estimator.
        model: The scikit-learn estimator built by ``select_algorithm``.
    """

    def __init__(self, algorithm, contamination=0.1):
        """Store the configuration and build the underlying estimator.

        Args:
            algorithm: 'isolation_forest' or 'one_class_svm'.
            contamination: Expected proportion of outliers in the data.

        Raises:
            ValueError: If ``algorithm`` is not a supported name.
        """
        self.algorithm = algorithm
        self.contamination = contamination
        self.model = self.select_algorithm(algorithm)

    def select_algorithm(self, algorithm):
        """Return a new, unfitted estimator for ``algorithm``.

        Args:
            algorithm: 'isolation_forest' or 'one_class_svm'.

        Returns:
            An ``IsolationForest`` or ``OneClassSVM`` instance.

        Raises:
            ValueError: If ``algorithm`` is not a supported name.
        """
        if algorithm == 'isolation_forest':
            return IsolationForest(contamination=self.contamination)
        if algorithm == 'one_class_svm':
            # Honour the configured contamination instead of a hard-coded
            # nu=0.1. A string setting such as 'auto' is only meaningful to
            # IsolationForest, so keep the previous 0.1 in that case.
            nu = 0.1 if isinstance(self.contamination, str) else self.contamination
            return OneClassSVM(kernel='rbf', gamma=0.1, nu=nu)
        raise ValueError('Invalid algorithm. Supported algorithms are isolation_forest and one_class_svm.')

    def fit(self, X):
        """Fit the underlying estimator on the feature matrix ``X``."""
        self.model.fit(X)

    def predict(self, X):
        """Return the underlying estimator's predictions for ``X``."""
        return self.model.predict(X)

    def evaluate(self, y_true, y_pred):
        """Score predictions against ground-truth labels.

        NOTE(review): scikit-learn outlier detectors predict +1 (inlier) and
        -1 (outlier). The scores below are only meaningful if ``y_true`` uses
        the same encoding - confirm against the data's 'label' column.

        Args:
            y_true: Ground-truth labels.
            y_pred: Predicted labels.

        Returns:
            A tuple ``(accuracy, precision, recall, f1)``.
        """
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred)
        recall = recall_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred)
        return accuracy, precision, recall, f1

    def detect_anomalies(self, data):
        """Fit on ``data``, predict on the same rows, and print the metrics.

        Args:
            data: DataFrame with a 'label' column; every other column is
                used as a feature.

        Returns:
            The predictions for every row of ``data``.
        """
        X = data.drop(['label'], axis=1)
        y = data['label']
        self.fit(X)
        y_pred = self.predict(X)
        accuracy, precision, recall, f1 = self.evaluate(y, y_pred)
        print(f'Accuracy: {accuracy:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}')
        return y_pred

    def visualize_anomalies(self, data, y_pred):
        """Scatter-plot the first two feature columns, coloured by prediction.

        Args:
            data: Feature matrix with at least two columns; a DataFrame or
                any array-like is accepted.
            y_pred: One prediction per row, used as the point colour.
        """
        # Imported here rather than at module level, so matplotlib is only
        # needed when plotting.
        import matplotlib.pyplot as plt

        # Positional [:, i] indexing fails on a DataFrame; convert first so
        # both DataFrames and plain arrays work.
        points = np.asarray(data)
        plt.scatter(points[:, 0], points[:, 1], c=y_pred)
        plt.xlabel('Feature 1')
        plt.ylabel('Feature 2')
        plt.title('Anomaly Detection')
        plt.show()
| 53 | + |
# Example usage. Guarded so that importing this module does not try to read
# 'anomaly_data.csv' or open a plot window; it runs only as a script.
if __name__ == '__main__':
    data = pd.read_csv('anomaly_data.csv')
    anomaly_detector = AnomalyDetector(algorithm='isolation_forest')
    y_pred = anomaly_detector.detect_anomalies(data)
    # Pass a plain NumPy array: the plot indexes its input positionally
    # ([:, 0] and [:, 1]), which a DataFrame does not support.
    features = data.drop(['label'], axis=1).to_numpy()
    anomaly_detector.visualize_anomalies(features, y_pred)
0 commit comments