Commit 39d0206

add stack method (#24)

1 parent 7c44322 · commit 39d0206

3 files changed: +61 −6 lines changed

coreforecast/lag_transforms.py (+21 −5)

@@ -20,14 +20,16 @@
 
 import abc
 import copy
-from typing import Callable, Optional
+from typing import Callable, Optional, Sequence
 
 import numpy as np
 
 from .grouped_array import GroupedArray
 
 
 class _BaseLagTransform(abc.ABC):
+    stats_: np.ndarray
+
     @abc.abstractmethod
     def transform(self, ga: GroupedArray) -> np.ndarray:
         """Apply the transformation by group.
@@ -53,6 +55,20 @@ def update(self, ga: GroupedArray) -> np.ndarray:
     def take(self, _idxs: np.ndarray) -> "_BaseLagTransform":
         return self
 
+    @staticmethod
+    def stack(transforms: Sequence["_BaseLagTransform"]) -> "_BaseLagTransform":
+        first_tfm = transforms[0]
+        if not hasattr(first_tfm, "stats_"):
+            # transform doesn't save state, we can return any of them
+            return first_tfm
+        out = copy.deepcopy(first_tfm)
+        if first_tfm.stats_.ndim == 1:
+            concat_fn = np.hstack
+        else:
+            concat_fn = np.vstack
+        out.stats_ = concat_fn([tfm.stats_ for tfm in transforms])
+        return out
+
 
 class Lag(_BaseLagTransform):
     """Simple lag operator
@@ -396,15 +412,15 @@ def __init__(self, lag: int, alpha: float):
 
     def transform(self, ga: GroupedArray) -> np.ndarray:
         out = ga._exponentially_weighted_transform("Mean", self.lag, self.alpha)
-        self.ewm_ = out[ga.indptr[1:] - 1]
+        self.stats_ = out[ga.indptr[1:] - 1]
         return out
 
     def update(self, ga: GroupedArray) -> np.ndarray:
         x = ga._index_from_end(self.lag - 1)
-        self.ewm_ = self.alpha * x + (1 - self.alpha) * self.ewm_
-        return self.ewm_
+        self.stats_ = self.alpha * x + (1 - self.alpha) * self.stats_
+        return self.stats_
 
     def take(self, idxs: np.ndarray) -> "ExponentiallyWeightedMean":
         out = copy.deepcopy(self)
-        out.ewm_ = out.ewm_[idxs].copy()
+        out.stats_ = out.stats_[idxs].copy()
         return out
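A minimal usage sketch of the new method (not part of this commit; the fitted stats_ values below are invented for illustration, while the class, its constructor arguments, and stack come from the diff above):

    import numpy as np

    from coreforecast.lag_transforms import ExponentiallyWeightedMean

    # two transforms "fitted" on disjoint partitions of the groups
    tfm_a = ExponentiallyWeightedMean(lag=1, alpha=0.5)
    tfm_a.stats_ = np.array([0.1, 0.2])  # last EWM value per group in partition A
    tfm_b = ExponentiallyWeightedMean(lag=1, alpha=0.5)
    tfm_b.stats_ = np.array([0.3])  # single group in partition B

    # stack concatenates the per-group state, so the result behaves as if
    # it had been fitted on all three groups at once
    combined = ExponentiallyWeightedMean.stack([tfm_a, tfm_b])
    print(combined.stats_)  # [0.1 0.2 0.3]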

coreforecast/scalers.py (+34 −1)

@@ -1,6 +1,6 @@
 import copy
 import ctypes
-from typing import List, Optional
+from typing import List, Optional, Sequence
 
 import numpy as np
 
@@ -205,6 +205,12 @@ def take(self, idxs: np.ndarray) -> "_BaseLocalScaler":
         out.stats_ = self.stats_[idxs].copy()
         return out
 
+    @staticmethod
+    def stack(scalers: Sequence["_BaseLocalScaler"]) -> "_BaseLocalScaler":
+        out = copy.deepcopy(scalers[0])
+        out.stats_ = np.vstack([sc.stats_ for sc in scalers])
+        return out
+
 
 class LocalMinMaxScaler(_BaseLocalScaler):
     """Scale each group to the [0, 1] interval"""
@@ -354,6 +360,12 @@ def take(self, idxs: np.ndarray) -> "Difference":
         out.tails_ = self.tails_[idxs].copy()
         return out
 
+    @staticmethod
+    def stack(scalers: Sequence["Difference"]) -> "Difference":
+        out = Difference(scalers[0].d)
+        out.tails_ = np.hstack([sc.tails_ for sc in scalers])
+        return out
+
 
 class AutoDifferences:
     """Find and apply the optimal number of differences to each group.
@@ -447,6 +459,16 @@ def take(self, idxs: np.ndarray) -> "AutoDifferences":
         out.tails_ = [tail[idxs].copy() for tail in self.tails_]
         return out
 
+    @staticmethod
+    def stack(scalers: Sequence["AutoDifferences"]) -> "AutoDifferences":
+        out = copy.deepcopy(scalers[0])
+        out.diffs_ = np.hstack([sc.diffs_ for sc in scalers])
+        out.tails_ = [
+            np.hstack([sc.tails_[i] for sc in scalers])
+            for i in range(len(scalers[0].tails_))
+        ]
+        return out
+
 
 class AutoSeasonalDifferences(AutoDifferences):
     """Find and apply the optimal number of seasonal differences to each group.
@@ -604,3 +626,14 @@ def take(self, idxs: np.ndarray) -> "AutoSeasonalityAndDifferences":
         out.diffs_ = [diffs[idxs].copy() for diffs in self.diffs_]
         out.tails_ = [tail[idxs].copy() for tail in self.tails_]
         return out
+
+    @staticmethod
+    def stack(
+        scalers: Sequence["AutoSeasonalityAndDifferences"],
+    ) -> "AutoSeasonalityAndDifferences":
+        out = AutoSeasonalityAndDifferences(
+            scalers[0].max_season_length, scalers[0].max_diffs, scalers[0].n_seasons
+        )
+        out.diffs_ = [np.hstack([sc.diffs_[i] for sc in scalers]) for i in range(out.max_diffs)]
+        out.tails_ = [np.hstack([sc.tails_[i] for sc in scalers]) for i in range(out.max_diffs)]
+        return out
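A similar sketch for the scaler side (not part of this commit; the tails_ values are invented, while Difference, its d argument, and stack come from the diff above):

    import numpy as np

    from coreforecast.scalers import Difference

    # two first-difference scalers "fitted" on separate partitions of the groups
    sc_a = Difference(1)
    sc_a.tails_ = np.array([10.0, 20.0])  # last value kept per group in partition A
    sc_b = Difference(1)
    sc_b.tails_ = np.array([30.0])  # partition B

    # stacking concatenates the saved tails so inverse-transforming works
    # across all groups from both partitions
    combined = Difference.stack([sc_a, sc_b])
    print(combined.tails_)  # [10. 20. 30.]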

tests/test_lag_transforms.py (+6 −0)

@@ -116,6 +116,12 @@ def test_correctness(data, comb, dtype):
     wres = transform(data, indptr, True, lag - 1, wf, *args)
     cres = cobj.update(ga)
     np.testing.assert_allclose(wres, cres, atol=atol, rtol=rtol)
+    # stack
+    combined = cobj.stack([cobj, cobj])
+    if hasattr(cobj, "stats_"):
+        assert combined.stats_.shape[0] == 2 * cobj.stats_.shape[0]
+    else:
+        assert combined is cobj
 
 
 @pytest.mark.parametrize("window_type", ["rolling", "seasonal_rolling", "expanding"])
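The else branch covers stateless transforms, which never set stats_. A sketch of that case (assuming Lag takes the lag as its single constructor argument, which this diff doesn't show):

    from coreforecast.lag_transforms import Lag

    lags = [Lag(1), Lag(1)]
    # Lag keeps no fitted state, so stack has nothing to concatenate and
    # simply returns the first transform unchanged
    assert Lag.stack(lags) is lags[0]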
