diff --git a/chemotools/augmentation/index_shift.py b/chemotools/augmentation/index_shift.py index b628b0d..25342f9 100644 --- a/chemotools/augmentation/index_shift.py +++ b/chemotools/augmentation/index_shift.py @@ -1,6 +1,7 @@ -from typing import Optional +from typing import Literal, Optional import numpy as np +from numpy.polynomial import polynomial as poly from sklearn.base import BaseEstimator, TransformerMixin, OneToOneFeatureMixin from sklearn.utils.validation import check_is_fitted, validate_data @@ -35,8 +36,14 @@ class IndexShift(TransformerMixin, OneToOneFeatureMixin, BaseEstimator): Transform the input data by shifting the spectrum. """ - def __init__(self, shift: int = 0, random_state: Optional[int] = None): + def __init__( + self, + shift: int = 0, + fill_method: Literal["constant", "linear", "quadratic"] = "constant", + random_state: Optional[int] = None, + ): self.shift = shift + self.fill_method = fill_method self.random_state = random_state def fit(self, X: np.ndarray, y=None) -> "IndexShift": @@ -111,10 +118,82 @@ def transform(self, X: np.ndarray, y=None) -> np.ndarray: # Calculate the standard normal variate for i, x in enumerate(X_): - X_[i] = self._shift_spectrum(x) + X_[i] = self._shift_vector(x) return X_.reshape(-1, 1) if X_.ndim == 1 else X_ def _shift_spectrum(self, x) -> np.ndarray: shift_amount = self._rng.integers(-self.shift, self.shift, endpoint=True) return np.roll(x, shift_amount) + + def _shift_vector( + self, + x: np.ndarray, + ) -> np.ndarray: + """ + Shift vector with option to fill missing values. + + Args: + arr: Input numpy array + shift: Number of positions to shift + fill_method: Method to fill missing values + 'constant': fill with first/last value + 'linear': fill using linear regression + 'quadratic': fill using quadratic regression + + Returns: + Shifted numpy array + """ + shift = self._rng.integers(-self.shift, self.shift, endpoint=True) + + result = np.roll(x, shift) + + if self.fill_method == "constant": + if shift > 0: + result[:shift] = x[0] + elif shift < 0: + result[shift:] = x[-1] + + elif self.fill_method == "linear": + if shift > 0: + x_ = np.arange(5) + coeffs = poly.polyfit(x_, x[:5], 1) + + extrapolate_x = np.arange(-shift, 0) + extrapolated_values = poly.polyval(extrapolate_x, coeffs) + + result[:shift] = extrapolated_values + + elif shift < 0: + x_ = np.arange(5) + coeffs = poly.polyfit(x_, x[-5:], 1) + + extrapolate_x = np.arange(len(x_), len(x_) - shift) + extrapolated_values = poly.polyval(extrapolate_x, coeffs) + + result[shift:] = extrapolated_values + + elif self.fill_method == "quadratic": + if shift > 0: + # Use first 3 values for quadratic regression + x_ = np.arange(5) + coeffs = poly.polyfit(x_, x[:5], 2) + + # Extrapolate to fill shifted region + extrapolate_x = np.arange(-shift, 0) + extrapolated_values = poly.polyval(extrapolate_x, coeffs) + + result[:shift] = extrapolated_values + + elif shift < 0: + # Use last 3 values for quadratic regression + x_ = np.arange(5) + coeffs = poly.polyfit(x_, x[-5:], 2) + + # Extrapolate to fill shifted region + extrapolate_x = np.arange(len(x_), len(x_) - shift) + extrapolated_values = poly.polyval(extrapolate_x, coeffs) + + result[shift:] = extrapolated_values + + return result diff --git a/tests/test_functionality.py b/tests/test_functionality.py index af9610c..fde9798 100644 --- a/tests/test_functionality.py +++ b/tests/test_functionality.py @@ -285,18 +285,55 @@ def test_index_selector_with_wavenumbers_and_dataframe(): assert np.allclose(spectrum_corrected.values[0], expected, atol=1e-8) -def test_index_shift(): +def test_index_shift_constant_fill(): # Arrange - spectrum = np.array([[1, 1, 1, 1, 1, 2, 1, 1, 1, 1]]) - spectrum_shift = IndexShift(shift=1, random_state=42) + spectrum = np.array([[5, 4, 3, 2, 1, 2, 1, 2, 3, 4, 5]]) + spectrum_positive_shift = IndexShift(shift=1, fill_method="constant", random_state=44) + spectrum_negative_shift = IndexShift(shift=1, fill_method="constant", random_state=42) # Act - spectrum_corrected = spectrum_shift.fit_transform(spectrum) + spectrum_positive_shifted = spectrum_positive_shift.fit_transform(spectrum) + spectrum_negative_shifted = spectrum_negative_shift.fit_transform(spectrum) # Assert - assert spectrum_corrected[0][4] == 2 + assert spectrum_positive_shifted[0][6] == 2 + assert spectrum_negative_shifted[0][4] == 2 + assert spectrum_positive_shifted[0][0] == 5 + assert spectrum_negative_shifted[0][-1] == 5 +def test_index_shift_linear_fill(): + # Arrange + spectrum = np.array([[5, 4, 3, 2, 1, 2, 1, 2, 3, 4, 5]]) + spectrum_positive_shift = IndexShift(shift=1, fill_method="linear", random_state=44) + spectrum_negative_shift = IndexShift(shift=1, fill_method="linear", random_state=42) + + # Act + spectrum_positive_shifted = spectrum_positive_shift.fit_transform(spectrum) + spectrum_negative_shifted = spectrum_negative_shift.fit_transform(spectrum) + + # Assert + assert spectrum_positive_shifted[0][6] == 2 + assert spectrum_negative_shifted[0][4] == 2 + assert np.isclose(spectrum_positive_shifted[0][0], 6.0, atol=1e-6) + assert np.isclose(spectrum_negative_shifted[0][-1], 6.0, atol=1e-6) + +def test_index_shift_quadratic_fill(): + # Arrange + spectrum = np.array([[5, 4, 3, 2, 1, 2, 1, 4, 9, 16, 25]]) + spectrum_positive_shift = IndexShift(shift=1, fill_method="quadratic", random_state=44) + spectrum_negative_shift = IndexShift(shift=1, fill_method="quadratic", random_state=42) + + # Act + spectrum_positive_shifted = spectrum_positive_shift.fit_transform(spectrum) + spectrum_negative_shifted = spectrum_negative_shift.fit_transform(spectrum) + + # Assert + assert spectrum_positive_shifted[0][6] == 2 + assert spectrum_negative_shifted[0][4] == 2 + assert np.isclose(spectrum_positive_shifted[0][0], 6.0, atol=1e-6) + assert np.isclose(spectrum_negative_shifted[0][-1], 36.0, atol=1e-6) + def test_l1_norm(spectrum): # Arrange norm = 1