Skip to content

Commit

Permalink
improve end point handling in data augmentation index shift
Browse files Browse the repository at this point in the history
  • Loading branch information
paucablop committed Jan 26, 2025
1 parent 028bb27 commit f005cb3
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 8 deletions.
85 changes: 82 additions & 3 deletions chemotools/augmentation/index_shift.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional
from typing import Literal, Optional

import numpy as np
from numpy.polynomial import polynomial as poly
from sklearn.base import BaseEstimator, TransformerMixin, OneToOneFeatureMixin
from sklearn.utils.validation import check_is_fitted, validate_data

Expand Down Expand Up @@ -35,8 +36,14 @@ class IndexShift(TransformerMixin, OneToOneFeatureMixin, BaseEstimator):
Transform the input data by shifting the spectrum.
"""

def __init__(self, shift: int = 0, random_state: Optional[int] = None):
def __init__(
self,
shift: int = 0,
fill_method: Literal["constant", "linear", "quadratic"] = "constant",
random_state: Optional[int] = None,
):
self.shift = shift
self.fill_method = fill_method
self.random_state = random_state

def fit(self, X: np.ndarray, y=None) -> "IndexShift":
Expand Down Expand Up @@ -111,10 +118,82 @@ def transform(self, X: np.ndarray, y=None) -> np.ndarray:

# Calculate the standard normal variate
for i, x in enumerate(X_):
X_[i] = self._shift_spectrum(x)
X_[i] = self._shift_vector(x)

return X_.reshape(-1, 1) if X_.ndim == 1 else X_

def _shift_spectrum(self, x) -> np.ndarray:
shift_amount = self._rng.integers(-self.shift, self.shift, endpoint=True)
return np.roll(x, shift_amount)

def _shift_vector(
self,
x: np.ndarray,
) -> np.ndarray:
"""
Shift vector with option to fill missing values.
Args:
arr: Input numpy array
shift: Number of positions to shift
fill_method: Method to fill missing values
'constant': fill with first/last value
'linear': fill using linear regression
'quadratic': fill using quadratic regression
Returns:
Shifted numpy array
"""
shift = self._rng.integers(-self.shift, self.shift, endpoint=True)

result = np.roll(x, shift)

if self.fill_method == "constant":
if shift > 0:
result[:shift] = x[0]
elif shift < 0:
result[shift:] = x[-1]

elif self.fill_method == "linear":
if shift > 0:
x_ = np.arange(5)
coeffs = poly.polyfit(x_, x[:5], 1)

extrapolate_x = np.arange(-shift, 0)
extrapolated_values = poly.polyval(extrapolate_x, coeffs)

result[:shift] = extrapolated_values

elif shift < 0:
x_ = np.arange(5)
coeffs = poly.polyfit(x_, x[-5:], 1)

extrapolate_x = np.arange(len(x_), len(x_) - shift)
extrapolated_values = poly.polyval(extrapolate_x, coeffs)

result[shift:] = extrapolated_values

elif self.fill_method == "quadratic":
if shift > 0:
# Use first 3 values for quadratic regression
x_ = np.arange(5)
coeffs = poly.polyfit(x_, x[:5], 2)

# Extrapolate to fill shifted region
extrapolate_x = np.arange(-shift, 0)
extrapolated_values = poly.polyval(extrapolate_x, coeffs)

result[:shift] = extrapolated_values

elif shift < 0:
# Use last 3 values for quadratic regression
x_ = np.arange(5)
coeffs = poly.polyfit(x_, x[-5:], 2)

# Extrapolate to fill shifted region
extrapolate_x = np.arange(len(x_), len(x_) - shift)
extrapolated_values = poly.polyval(extrapolate_x, coeffs)

result[shift:] = extrapolated_values

return result
47 changes: 42 additions & 5 deletions tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,18 +285,55 @@ def test_index_selector_with_wavenumbers_and_dataframe():
assert np.allclose(spectrum_corrected.values[0], expected, atol=1e-8)


def test_index_shift():
def test_index_shift_constant_fill():
# Arrange
spectrum = np.array([[1, 1, 1, 1, 1, 2, 1, 1, 1, 1]])
spectrum_shift = IndexShift(shift=1, random_state=42)
spectrum = np.array([[5, 4, 3, 2, 1, 2, 1, 2, 3, 4, 5]])
spectrum_positive_shift = IndexShift(shift=1, fill_method="constant", random_state=44)
spectrum_negative_shift = IndexShift(shift=1, fill_method="constant", random_state=42)

# Act
spectrum_corrected = spectrum_shift.fit_transform(spectrum)
spectrum_positive_shifted = spectrum_positive_shift.fit_transform(spectrum)
spectrum_negative_shifted = spectrum_negative_shift.fit_transform(spectrum)

# Assert
assert spectrum_corrected[0][4] == 2
assert spectrum_positive_shifted[0][6] == 2
assert spectrum_negative_shifted[0][4] == 2
assert spectrum_positive_shifted[0][0] == 5
assert spectrum_negative_shifted[0][-1] == 5


def test_index_shift_linear_fill():
# Arrange
spectrum = np.array([[5, 4, 3, 2, 1, 2, 1, 2, 3, 4, 5]])
spectrum_positive_shift = IndexShift(shift=1, fill_method="linear", random_state=44)
spectrum_negative_shift = IndexShift(shift=1, fill_method="linear", random_state=42)

# Act
spectrum_positive_shifted = spectrum_positive_shift.fit_transform(spectrum)
spectrum_negative_shifted = spectrum_negative_shift.fit_transform(spectrum)

# Assert
assert spectrum_positive_shifted[0][6] == 2
assert spectrum_negative_shifted[0][4] == 2
assert np.isclose(spectrum_positive_shifted[0][0], 6.0, atol=1e-6)
assert np.isclose(spectrum_negative_shifted[0][-1], 6.0, atol=1e-6)

def test_index_shift_quadratic_fill():
# Arrange
spectrum = np.array([[5, 4, 3, 2, 1, 2, 1, 4, 9, 16, 25]])
spectrum_positive_shift = IndexShift(shift=1, fill_method="quadratic", random_state=44)
spectrum_negative_shift = IndexShift(shift=1, fill_method="quadratic", random_state=42)

# Act
spectrum_positive_shifted = spectrum_positive_shift.fit_transform(spectrum)
spectrum_negative_shifted = spectrum_negative_shift.fit_transform(spectrum)

# Assert
assert spectrum_positive_shifted[0][6] == 2
assert spectrum_negative_shifted[0][4] == 2
assert np.isclose(spectrum_positive_shifted[0][0], 6.0, atol=1e-6)
assert np.isclose(spectrum_negative_shifted[0][-1], 36.0, atol=1e-6)

def test_l1_norm(spectrum):
# Arrange
norm = 1
Expand Down

0 comments on commit f005cb3

Please sign in to comment.