Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 89 additions & 1 deletion aeon/transformations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
__maintainer__ = ["MatthewMiddlehurst", "TonyBagnall"]
__all__ = ["BaseTransformer"]

from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import final

import numpy as np
import pandas as pd

from aeon.base import BaseAeonEstimator
from aeon.transformations.collection import BaseCollectionTransformer
from aeon.transformations.series import BaseSeriesTransformer


class BaseTransformer(BaseAeonEstimator):
Expand Down Expand Up @@ -112,3 +115,88 @@ def _check_y(self, y, n_cases=None):
f"Mismatch in number of cases. Number in X = {n_cases} nos in y = "
f"{n_labels}"
)


class InverseTransformerMixin(ABC):
"""Mixin for transformers that support inverse transformation."""

_tags = {
"capability:inverse_transform": True,
}

@final
def inverse_transform(self, X, y=None, axis=1):
"""Inverse transform X and return an inverse transformed version.

Currently it is assumed that only transformers with tags
"input_data_type"="Series", "output_data_type"="Series",
can have an inverse_transform.

State required:
Requires state to be "fitted".

Accesses in self:
_is_fitted : must be True
fitted model attributes (ending in "_") : accessed by _inverse_transform

Parameters
----------
X : Series or Collection, any supported type
Data to fit transform to, of python type as follows:
Series: 2D np.ndarray shape (n_channels, n_timepoints)
Collection: 3D np.ndarray shape (n_cases, n_channels, n_timepoints)
or list of 2D np.ndarray, case i has shape (n_channels, n_timepoints_i)
y : Series, default=None
Additional data, e.g., labels for transformation.
axis : int, default = 1
Axis of time in the input series.
If ``axis == 0``, it is assumed each column is a time series and each row is
a time point. i.e. the shape of the data is ``(n_timepoints,
n_channels)``.
``axis == 1`` indicates the time series are in rows, i.e. the shape of
the data is ``(n_channels, n_timepoints)`.``axis is None`` indicates
that the axis of X is the same as ``self.axis``.

Only relevant for ``aeon.transformations.series`` transformers.

Returns
-------
inverse transformed version of X
of the same type as X
"""
# check whether is fitted
self._check_is_fitted()

# input check and conversion for X/y
if isinstance(self, BaseCollectionTransformer):
X_inner = self._preprocess_collection(X, store_metadata=False)
Xt = self._inverse_transform(X=X_inner, y=y)
return Xt
elif isinstance(self, BaseSeriesTransformer):
self._check_is_fitted()
X = self._preprocess_series(X, axis=axis, store_metadata=False)
Xt = self._inverse_transform(X=X, y=y)
return self._postprocess_series(Xt, axis=axis)

@abstractmethod
def _inverse_transform(self, X, y=None):
"""Inverse transform X and return an inverse transformed version.

private _inverse_transform containing core logic, called from inverse_transform.

Parameters
----------
X : Series or Collection, any supported type
Data to fit transform to, of python type as follows:
Series: 2D np.ndarray shape (n_channels, n_timepoints)
Collection: 3D np.ndarray shape (n_cases, n_channels, n_timepoints)
or list of 2D np.ndarray, case i has shape (n_channels, n_timepoints_i)
y : Series, default=None
Additional data, e.g., labels for transformation.

Returns
-------
inverse transformed version of X
of the same type as X.
"""
...
3 changes: 2 additions & 1 deletion aeon/transformations/collection/_broadcaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@

import numpy as np

from aeon.transformations.base import InverseTransformerMixin
from aeon.transformations.collection.base import BaseCollectionTransformer
from aeon.transformations.series.base import BaseSeriesTransformer
from aeon.utils.validation.collection import get_n_cases


class SeriesToCollectionBroadcaster(BaseCollectionTransformer):
class SeriesToCollectionBroadcaster(BaseCollectionTransformer, InverseTransformerMixin):
"""Broadcast a ``BaseSeriesTransformer`` over a collection of time series.

Uses the ``BaseSeriesTransformer`` passed in the constructor. If the
Expand Down
78 changes: 0 additions & 78 deletions aeon/transformations/collection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,63 +211,6 @@ def fit_transform(self, X, y=None):
self.is_fitted = True
return Xt

@final
def inverse_transform(self, X, y=None):
"""Inverse transform X and return an inverse transformed version.

Currently it is assumed that only transformers with tags
"input_data_type"="Series", "output_data_type"="Series",
can have an inverse_transform.

State required:
Requires state to be "fitted".

Accesses in self:
_is_fitted : must be True
fitted model attributes (ending in "_") : accessed by _inverse_transform

Parameters
----------
X : np.ndarray or list
Data to fit transform to, of valid collection type. Input data,
any number of channels, equal length series of shape ``(
n_cases, n_channels, n_timepoints)`` or list of numpy arrays (number
of channels, series length) of shape ``[n_cases]``, 2D np.array
``(n_channels, n_timepoints_i)``, where ``n_timepoints_i`` is length of
series ``i``. Other types are allowed and converted into one of the above.

Different estimators have different capabilities to handle different
types of input. If ``self.get_tag("capability:multivariate")`` is False,
they cannot handle multivariate series. If ``self.get_tag(
"capability:unequal_length")`` is False, they cannot handle unequal
length input. In both situations, a ``ValueError`` is raised if X has a
characteristic that the estimator does not have the capability to handle.
y : np.ndarray, default=None
1D np.array of float or str, of shape ``(n_cases)`` - class labels
(ground truth) for fitting indices corresponding to instance indices in X.
If None, no labels are used in fitting.

Returns
-------
inverse transformed version of X
of the same type as X
"""
if not self.get_tag("capability:inverse_transform"):
raise NotImplementedError(
f"{type(self)} does not implement inverse_transform"
)

# check whether is fitted
self._check_is_fitted()

# input check and conversion for X/y
X_inner = self._preprocess_collection(X, store_metadata=False)
y_inner = y

Xt = self._inverse_transform(X=X_inner, y=y_inner)

return Xt

def _fit(self, X, y=None):
"""Fit transformer to X and y.

Expand Down Expand Up @@ -331,24 +274,3 @@ def _fit_transform(self, X, y=None):
if not self.get_tag("fit_is_empty"):
self._fit(X, y)
return self._transform(X, y)

def _inverse_transform(self, X, y=None):
"""Inverse transform X and return an inverse transformed version.

private _inverse_transform containing core logic, called from inverse_transform.

Parameters
----------
X : Input data
Data to fit transform to, of valid collection type.
y : Target variable, default=None
Additional data, e.g., labels for transformation

Returns
-------
inverse transformed version of X
of the same type as X.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support inverse_transform"
)
4 changes: 2 additions & 2 deletions aeon/transformations/collection/compose/_identity.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""Identity transformer."""

from aeon.transformations.base import InverseTransformerMixin
from aeon.transformations.collection import BaseCollectionTransformer
from aeon.utils.data_types import COLLECTIONS_DATA_TYPES


class CollectionId(BaseCollectionTransformer):
class CollectionId(BaseCollectionTransformer, InverseTransformerMixin):
"""Identity transformer, returns data unchanged in transform/inverse_transform."""

_tags = {
"X_inner_type": COLLECTIONS_DATA_TYPES,
"fit_is_empty": True,
"capability:inverse_transform": True,
"capability:multivariate": True,
"capability:unequal_length": True,
"capability:missing_values": True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ class BORF(BaseCollectionTransformer):

_tags = {
"X_inner_type": "numpy3D",
"capability:inverse_transform": False,
"capability:missing_values": True,
"capability:multivariate": True,
"capability:multithreading": True,
Expand Down
4 changes: 2 additions & 2 deletions aeon/transformations/series/_boxcox.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from scipy.special import boxcox, inv_boxcox
from scipy.stats import boxcox_llf, distributions, variation

from aeon.transformations.base import InverseTransformerMixin
from aeon.transformations.series.base import BaseSeriesTransformer


Expand Down Expand Up @@ -38,7 +39,7 @@ def _calc_uniform_order_statistic_medians(n):
return v


class BoxCoxTransformer(BaseSeriesTransformer):
class BoxCoxTransformer(BaseSeriesTransformer, InverseTransformerMixin):
r"""Box-Cox power transform.

Box-Cox transformation is a power transformation that is used to
Expand Down Expand Up @@ -106,7 +107,6 @@ class BoxCoxTransformer(BaseSeriesTransformer):
"X_inner_type": "np.ndarray",
"fit_is_empty": False,
"capability:multivariate": False,
"capability:inverse_transform": True,
}

def __init__(self, bounds=None, method="mle", sp=None):
Expand Down
4 changes: 2 additions & 2 deletions aeon/transformations/series/_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

import numpy as np

from aeon.transformations.base import InverseTransformerMixin
from aeon.transformations.series.base import BaseSeriesTransformer


class LogTransformer(BaseSeriesTransformer):
class LogTransformer(BaseSeriesTransformer, InverseTransformerMixin):
"""Natural logarithm transformation.

The Natural logarithm transformation can be used to make the data more normally
Expand Down Expand Up @@ -41,7 +42,6 @@ class LogTransformer(BaseSeriesTransformer):
"X_inner_type": "np.ndarray",
"fit_is_empty": True,
"capability:multivariate": True,
"capability:inverse_transform": True,
}

def __init__(self, offset=0, scale=1):
Expand Down
4 changes: 2 additions & 2 deletions aeon/transformations/series/_scaled_logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

import numpy as np

from aeon.transformations.base import InverseTransformerMixin
from aeon.transformations.series.base import BaseSeriesTransformer


class ScaledLogitSeriesTransformer(BaseSeriesTransformer):
class ScaledLogitSeriesTransformer(BaseSeriesTransformer, InverseTransformerMixin):
r"""Scaled logit transform or Log transform.

If both lower_bound and upper_bound are not None, a scaled logit transform is
Expand Down Expand Up @@ -59,7 +60,6 @@ class ScaledLogitSeriesTransformer(BaseSeriesTransformer):
"X_inner_type": "np.ndarray",
"fit_is_empty": True,
"capability:multivariate": True,
"capability:inverse_transform": True,
}

def __init__(self, lower_bound=None, upper_bound=None):
Expand Down
59 changes: 0 additions & 59 deletions aeon/transformations/series/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,44 +165,6 @@ def fit_transform(self, X, y=None, axis=1):
self.is_fitted = True
return self._postprocess_series(Xt, axis=axis)

@final
def inverse_transform(self, X, y=None, axis=1):
"""Inverse transform X and return an inverse transformed version.

State required:
Requires state to be "fitted".

Parameters
----------
X : Input data
Data to fit transform to, of valid collection type.
y : Target variable, default=None
Additional data, e.g., labels for transformation
axis : int, default = 1
Axis of time in the input series.
If ``axis == 0``, it is assumed each column is a time series and each row is
a time point. i.e. the shape of the data is ``(n_timepoints,
n_channels)``.
``axis == 1`` indicates the time series are in rows, i.e. the shape of
the data is ``(n_channels, n_timepoints)`.``axis is None`` indicates
that the axis of X is the same as ``self.axis``.

Returns
-------
inverse transformed version of X
of the same type as X
"""
if not self.get_tag("capability:inverse_transform"):
raise NotImplementedError(
f"{type(self)} does not implement inverse_transform"
)

# check whether is fitted
self._check_is_fitted()
X = self._preprocess_series(X, axis=axis, store_metadata=False)
Xt = self._inverse_transform(X=X, y=y)
return self._postprocess_series(Xt, axis=axis)

@final
def update(self, X, y=None, update_params=True, axis=1):
"""Update transformer with X, optionally y.
Expand Down Expand Up @@ -287,27 +249,6 @@ def _fit_transform(self, X, y=None):
self._fit(X, y)
return self._transform(X, y)

def _inverse_transform(self, X, y=None):
"""Inverse transform X and return an inverse transformed version.

private _inverse_transform containing core logic, called from inverse_transform.

Parameters
----------
X : Input data
Time series to fit transform to, of valid collection type.
y : Target variable, default=None
Additional data, e.g., labels for transformation

Returns
-------
inverse transformed version of X
of the same type as X.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support inverse_transform"
)

def _update(self, X, y=None, update_params=True):
# standard behaviour: no update takes place, new data is ignored
return self
Expand Down
4 changes: 2 additions & 2 deletions aeon/transformations/series/compose/_identity.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""Identity transformer."""

from aeon.transformations.base import InverseTransformerMixin
from aeon.transformations.series import BaseSeriesTransformer
from aeon.utils.data_types import VALID_SERIES_INNER_TYPES


class SeriesId(BaseSeriesTransformer):
class SeriesId(BaseSeriesTransformer, InverseTransformerMixin):
"""Identity transformer, returns data unchanged in transform/inverse_transform."""

_tags = {
"X_inner_type": VALID_SERIES_INNER_TYPES,
"fit_is_empty": True,
"capability:inverse_transform": True,
"capability:multivariate": True,
"capability:missing_values": True,
}
Expand Down
Loading