[ENH] Implement ESMOTE for imbalanced classification problems #2971
Merged

Changes from all commits (13 commits):
- bcfcd92 Esmote implement for aeon (LinGinQiu)
- 43e44c9 Automatic `pre-commit` fixes (LinGinQiu)
- d87e9f4 Merge branch 'aeon-toolkit:main' into main (LinGinQiu)
- 6433f08 Merge branch 'aeon-toolkit:main' into main (LinGinQiu)
- 81f7063 Merge branch 'aeon-toolkit:main' into main (LinGinQiu)
- 0203c84 Merge branch 'aeon-toolkit:main' into main (LinGinQiu)
- c9106fc Merge branch 'main' into main (TonyBagnall)
- e171a41 Merge branch 'main' into main (TonyBagnall)
- 17b3311 use aeon's distamce in smote, adasyn, esmote (LinGinQiu)
- c936712 delete the unused import lines (LinGinQiu)
- 28e9fb0 add "capability:multithreading": True, in the tags (LinGinQiu)
- d9e8465 add self._n_jobs = check_n_jobs(self.n_jobs) (LinGinQiu)
- 10bb46d 1.rename utils to signle class knn (LinGinQiu)
aeon/transformations/collection/imbalance/__init__.py
@@ -1,7 +1,8 @@
 """Supervised transformers to rebalance colelctions of time series."""

-__all__ = ["ADASYN", "SMOTE", "OHIT"]
+__all__ = ["ADASYN", "SMOTE", "OHIT", "ESMOTE"]

 from aeon.transformations.collection.imbalance._adasyn import ADASYN
+from aeon.transformations.collection.imbalance._esmote import ESMOTE
 from aeon.transformations.collection.imbalance._ohit import OHIT
 from aeon.transformations.collection.imbalance._smote import SMOTE
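For context, a minimal usage sketch (not part of the diff). It assumes the standard aeon `fit`/`transform` calls on `BaseCollectionTransformer`, a univariate, equal-length collection (the only input the tags permit), illustrative parameter values, and that `transform` returns the rebalanced `(X, y)` pair as implemented in `_esmote.py` below.

```python
import numpy as np

from aeon.transformations.collection.imbalance import ESMOTE

# Toy imbalanced collection: 20 cases of class 0 vs 5 of class 1,
# shape (n_cases, n_channels, n_timepoints) = (25, 1, 50).
rng = np.random.default_rng(0)
X = rng.standard_normal((25, 1, 50))
y = np.array([0] * 20 + [1] * 5)

esmote = ESMOTE(n_neighbors=3, distance="twe", random_state=0)
esmote.fit(X, y)
X_res, y_res = esmote.transform(X, y)
print(X_res.shape, np.bincount(y_res))  # expected: (40, 1, 50) [20 20]
```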
aeon/transformations/collection/imbalance/_esmote.py
@@ -0,0 +1,226 @@
from collections import OrderedDict
from typing import Optional, Union

import numpy as np
from sklearn.utils import check_random_state

from aeon.clustering.averaging._ba_utils import _get_alignment_path
from aeon.transformations.collection import BaseCollectionTransformer
from aeon.transformations.collection.imbalance._single_class_knn import Single_Class_KNN
from aeon.utils.validation import check_n_jobs

__all__ = ["ESMOTE"]

class ESMOTE(BaseCollectionTransformer):
    """
    Elastic Synthetic Minority Over-sampling Technique (ESMOTE).

    Synthetic minority samples are generated by combining a time series with one
    of its nearest neighbors along an elastic distance alignment path.

    Parameters
    ----------
    n_neighbors : int, default=5
        The number of nearest neighbors used to define the neighborhood of samples
        used to generate the synthetic time series.
    distance : str or callable, default="twe"
        The distance metric to use for the nearest neighbors search and for the
        alignment path of the synthetic time series.
    distance_params : dict, default=None
        Parameters, if any, to pass to the distance function.
    weights : str or callable, default='uniform'
        Mechanism for weighting a vote, one of: ``'uniform'``, ``'distance'``, or a
        callable function.
    n_jobs : int, default=1
        The number of jobs to run in parallel for the nearest neighbors search.
    random_state : int, RandomState instance or None, default=None
        If `int`, random_state is the seed used by the random number generator;
        If `RandomState` instance, random_state is the random number generator;
        If `None`, the random number generator is the `RandomState` instance used
        by `np.random`.

    See Also
    --------
    ADASYN

    References
    ----------
    .. [1] Chawla et al. SMOTE: synthetic minority over-sampling technique, Journal
       of Artificial Intelligence Research 16(1): 321–357, 2002.
       https://dl.acm.org/doi/10.5555/1622407.1622416
    """

    _tags = {
        "capability:multivariate": False,
        "capability:unequal_length": False,
        "capability:multithreading": True,
        "requires_y": True,
    }

    def __init__(
        self,
        n_neighbors=5,
        distance: Union[str, callable] = "twe",
        distance_params: Optional[dict] = None,
        weights: Union[str, callable] = "uniform",
        n_jobs: int = 1,
        random_state=None,
    ):
        self.random_state = random_state
        self.n_neighbors = n_neighbors
        self.distance = distance
        self.weights = weights
        self.distance_params = distance_params
        self.n_jobs = n_jobs

        self._random_state = None
        self._distance_params = distance_params or {}

        self.nn_ = None
        super().__init__()

    def _fit(self, X, y=None):
        self._random_state = check_random_state(self.random_state)
        self._n_jobs = check_n_jobs(self.n_jobs)
        self.nn_ = Single_Class_KNN(
            n_neighbors=self.n_neighbors + 1,
            distance=self.distance,
            distance_params=self._distance_params,
            weights=self.weights,
            n_jobs=self.n_jobs,
        )

        # generate sampling target by targeting all classes except the majority
        unique, counts = np.unique(y, return_counts=True)
        target_stats = dict(zip(unique, counts))
        n_sample_majority = max(target_stats.values())
        class_majority = max(target_stats, key=target_stats.get)
        sampling_strategy = {
            key: n_sample_majority - value
            for (key, value) in target_stats.items()
            if key != class_majority
        }
        self.sampling_strategy_ = OrderedDict(sorted(sampling_strategy.items()))
        return self

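    # Illustration (not part of the original diff): for y = [0, 0, 0, 0, 1, 1]
    # the class counts are {0: 4, 1: 2}, the majority class is 0, and so
    # sampling_strategy_ becomes OrderedDict({1: 2}), i.e. two synthetic series
    # of class 1 will be generated in `_transform`.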
    def _transform(self, X, y=None):
        X_resampled = [X.copy()]
        y_resampled = [y.copy()]

        # for each minority class label, generate the required number of samples
        for class_sample, n_samples in self.sampling_strategy_.items():
            if n_samples == 0:
                continue
            target_class_indices = np.flatnonzero(y == class_sample)
            X_class = X[target_class_indices]
            y_class = y[target_class_indices]

            self.nn_.fit(X_class, y_class)
            nns = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:]
            X_new, y_new = self._make_samples(
                X_class,
                y.dtype,
                class_sample,
                X_class,
                nns,
                n_samples,
                1.0,
                n_jobs=self.n_jobs,
            )
            X_resampled.append(X_new)
            y_resampled.append(y_new)
        X_synthetic = np.vstack(X_resampled)
        y_synthetic = np.hstack(y_resampled)

        return X_synthetic, y_synthetic

    def _make_samples(
        self, X, y_dtype, y_type, nn_data, nn_num, n_samples, step_size=1.0, n_jobs=1
    ):
        samples_indices = self._random_state.randint(
            low=0, high=nn_num.size, size=n_samples
        )

        steps = (
            step_size
            * self._random_state.uniform(low=0, high=1, size=n_samples)[:, np.newaxis]
        )
        rows = np.floor_divide(samples_indices, nn_num.shape[1])
        cols = np.mod(samples_indices, nn_num.shape[1])
        X_new = np.zeros((len(rows), *X.shape[1:]), dtype=X.dtype)
        for count in range(len(rows)):
            i = rows[count]
            j = cols[count]
            nn_ts = nn_data[nn_num[i, j]]
            X_new[count] = self._generate_sample_use_elastic_distance(
                X[i],
                nn_ts,
                distance=self.distance,
                step=steps[count],
            )

        y_new = np.full(n_samples, fill_value=y_type, dtype=y_dtype)
        return X_new, y_new

    def _generate_sample_use_elastic_distance(
        self,
        curr_ts,
        nn_ts,
        distance,
        step,
        window: Union[float, None] = None,
        g: float = 0.0,
        epsilon: Union[float, None] = None,
        nu: float = 0.001,
        lmbda: float = 1.0,
        independent: bool = True,
        c: float = 1.0,
        descriptor: str = "identity",
        reach: int = 15,
        warp_penalty: float = 1.0,
        transformation_precomputed: bool = False,
        transformed_x: Optional[np.ndarray] = None,
        transformed_y: Optional[np.ndarray] = None,
        return_bias=True,
    ):
        """
        Generate a single synthetic sample using an elastic distance.

        The elastic distance alignment path is used to align the current time
        series with its nearest neighbor. The difference between the current
        series and the aligned neighbor, scaled by ``step``, forms the bias. If
        ``return_bias`` is True (the default) only the bias is returned,
        otherwise the synthetic series ``curr_ts - bias`` is returned.

        Both ``curr_ts`` and ``nn_ts`` are expected to have shape
        (n_channels, n_timepoints).
        """
        new_ts = curr_ts.copy()
        alignment, _ = _get_alignment_path(
            nn_ts,
            curr_ts,
            distance,
            window,
            g,
            epsilon,
            nu,
            lmbda,
            independent,
            c,
            descriptor,
            reach,
            warp_penalty,
            transformation_precomputed,
            transformed_x,
            transformed_y,
        )
        path_list = [[] for _ in range(curr_ts.shape[1])]
        for k, l in alignment:
            path_list[k].append(l)

        empty_of_array = np.zeros_like(curr_ts, dtype=float)  # shape: (c, l)

        for k, l in enumerate(path_list):
            key = self._random_state.choice(l)
            # Compute difference for all channels at this time step
            empty_of_array[:, k] = curr_ts[:, k] - nn_ts[:, key]

        bias = step * empty_of_array
        if return_bias:
            return bias

        new_ts = new_ts - bias
        return new_ts
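To make the bookkeeping in `_make_samples` concrete, here is a small self-contained sketch (not part of the PR) of how the flat random draw is decomposed into a seed case and one of its nearest neighbors; the synthetic series is then formed along the elastic alignment path rather than by plain pointwise interpolation. The neighbor indices below are made up for illustration.

```python
import numpy as np

rng = np.random.RandomState(0)
# Hypothetical output of the k-NN search: for each of 5 minority cases,
# the indices of its 3 nearest same-class neighbors (self excluded).
nn_num = np.array([[1, 2, 3],
                   [0, 2, 4],
                   [0, 1, 3],
                   [2, 4, 0],
                   [1, 3, 2]])
n_samples = 4  # synthetic cases still needed for this class

samples_indices = rng.randint(0, nn_num.size, size=n_samples)
rows = samples_indices // nn_num.shape[1]  # which minority case seeds the sample
cols = samples_indices % nn_num.shape[1]   # which of its neighbors it is paired with
steps = rng.uniform(size=n_samples)        # interpolation weight in [0, 1)

for i, j, s in zip(rows, cols, steps):
    print(f"seed case {i} <-> neighbor {nn_num[i, j]}, step={s:.2f}")
```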
aeon/transformations/collection/imbalance/_single_class_knn.py
@@ -0,0 +1,24 @@
"""Wrapper of KNeighborsTimeSeriesClassifier named Single_Class_KNN. | ||
|
||
It wraps the fit setup to ensure `_fit` is executed even when the dataset | ||
contains only a single class. | ||
""" | ||
|
||
from aeon.classification.distance_based import KNeighborsTimeSeriesClassifier | ||
|
||
__all__ = ["Single_Class_KNN"] | ||
|
||
|
||
class Single_Class_KNN(KNeighborsTimeSeriesClassifier): | ||
""" | ||
KNN classifier for time series data, adapted to work with SMOTE. | ||
|
||
This class is a wrapper around the original KNeighborsTimeSeriesClassifier | ||
to ensure compatibility with the Signal class. | ||
""" | ||
|
||
def _fit_setup(self, X, y): | ||
# KNN can support if all labels are the same so always return False for single | ||
# class problem so the fit will always run | ||
X, y, _ = super()._fit_setup(X, y) | ||
return X, y, False |
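The sketch below (illustration only, not part of the PR) shows why the wrapper matters for ESMOTE: the neighbor search is fit on the minority class alone, then queried with `kneighbors` exactly as in `_transform` above, dropping the first column because each case is its own nearest neighbor. Data and parameter values are made up.

```python
import numpy as np

from aeon.transformations.collection.imbalance._single_class_knn import (
    Single_Class_KNN,
)

rng = np.random.default_rng(0)
X_minority = rng.standard_normal((5, 1, 30))  # 5 minority cases, univariate, length 30
y_minority = np.zeros(5, dtype=int)           # a single class label throughout

knn = Single_Class_KNN(n_neighbors=4, distance="twe")
knn.fit(X_minority, y_minority)               # _fit runs despite the single class
neighbours = knn.kneighbors(X_minority, return_distance=False)[:, 1:]
print(neighbours.shape)  # expected: (5, 3)
```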