|
| 1 | +""" |
| 2 | +@ref: [An Experiment with the Edited Nearest-Neighbor Rule](https://ieeexplore.ieee.org/ielx5/21/4309513/04309523.pdf) |
| 3 | +""" |
| 4 | + |
1 | 5 | from sklearn.model_selection import KFold
|
2 | 6 | from sklearn.neighbors import KNeighborsClassifier
|
3 | 7 | import numpy as np
|
|
9 | 13 | class AENN:
|
10 | 14 | def __init__(self, max_neighbours: int = 5, n_jobs: int = -1):
|
11 | 15 | self.max_neighbours = max_neighbours
|
12 |
| - self.filter = Filter(parameters = {"max_neighbours": self.max_neighbours}) |
| 16 | + self.filter = Filter(parameters={"max_neighbours": self.max_neighbours}) |
13 | 17 | self.n_jobs = n_jobs
|
14 | 18 |
|
15 | 19 | def __call__(self, data: t.Sequence, classes: t.Sequence) -> Filter:
|
16 | 20 | self.isNoise = np.array([False] * len(classes))
|
17 | 21 | for n_neigh in range(1, self.max_neighbours + 1):
|
18 |
| - self.clf = KNeighborsClassifier(n_neighbors = n_neigh, algorithm = 'kd_tree', n_jobs = self.n_jobs) |
| 22 | + self.clf = KNeighborsClassifier(n_neighbors=n_neigh, algorithm='kd_tree', n_jobs=self.n_jobs) |
19 | 23 | for indx in np.argwhere(np.invert(self.isNoise)):
|
20 |
| - self.clf.fit(np.delete(data, indx, axis = 0), np.delete(classes, indx, axis = 0)) |
| 24 | + self.clf.fit(np.delete(data, indx, axis=0), np.delete(classes, indx, axis=0)) |
21 | 25 | pred = self.clf.predict(data[indx])
|
22 | 26 | self.isNoise[indx] = pred != classes[indx]
|
| 27 | + print(f"n_neigh: {n_neigh}, is_noise count:, {sum(self.isNoise)}, total: {len(self.isNoise)}") |
23 | 28 | self.filter.rem_indx = np.argwhere(self.isNoise)
|
24 | 29 | notNoise = np.invert(self.isNoise)
|
25 | 30 | self.filter.set_cleanData(data[notNoise], classes[notNoise])
|
26 | 31 | return self.filter
|
| 32 | + |
| 33 | + |
| 34 | +if __name__ == "__main__": |
| 35 | + from sklearn.datasets import load_iris |
| 36 | + |
| 37 | + data = load_iris() |
| 38 | + rm_feature_id = data.feature_names.index("petal length (cm)") |
| 39 | + features = data.data[:, [idx for idx in range(len(data.feature_names)) if idx != rm_feature_id]] |
| 40 | + labels = data.target |
| 41 | + filter = AENN()(features, labels) |
| 42 | + print(filter.rem_indx.shape, filter.rem_indx[:, 0]) |
0 commit comments