feat: Add float support in train_test_split (#168)
FBruzzesi authored Feb 22, 2024
1 parent 9c2f7fb commit 7ae519f
Showing 2 changed files with 86 additions and 11 deletions.
42 changes: 35 additions & 7 deletions functime/cross_validation.py
@@ -1,38 +1,66 @@
-from typing import Mapping, Optional, Tuple
+from typing import Mapping, Optional, Tuple, Union
 
 import numpy as np
 import polars as pl
 
 
 def train_test_split(
-    test_size: int, eager: bool = False
+    test_size: Union[int, float] = 0.25, eager: bool = False
 ) -> Tuple[pl.LazyFrame, pl.LazyFrame]:
     """Return a time-ordered train set and test set given `test_size`.
 
     Parameters
     ----------
-    test_size : int
-        Number of test samples.
-    eager : bool
+    test_size : int | float, default=0.25
+        Number or fraction of test samples.
+    eager : bool, default=False
         If True, evaluate immediately and returns tuple of train-test `DataFrame`.
 
     Returns
     -------
     splitter : Callable[pl.LazyFrame, Tuple[pl.LazyFrame, pl.LazyFrame]]
         Function that takes a panel LazyFrame and returns tuple of train / test LazyFrames.
     """
+    if isinstance(test_size, float):
+        if test_size < 0 or test_size > 1:
+            raise ValueError("`test_size` must be between 0 and 1")
+    elif isinstance(test_size, int):
+        if test_size < 0:
+            raise ValueError("`test_size` must be greater than 0")
+    else:
+        raise TypeError("`test_size` must be int or float")
+
     def split(X: pl.LazyFrame) -> pl.LazyFrame:
         X = X.lazy()  # Defensive
         entity_col = X.columns[0]
 
+        max_size = (
+            X.group_by(entity_col)
+            .agg(pl.count())
+            .select(pl.min("count"))
+            .collect()
+            .item()
+        )
+        if isinstance(test_size, int) and test_size > max_size:
+            raise ValueError(
+                "`test_size` must be less than the number of samples of the smallest entity"
+            )
+
+        train_length = (
+            pl.count() - test_size
+            if isinstance(test_size, int)
+            else (pl.count() * (1 - test_size)).cast(int)
+        )
+        test_length = pl.count() - train_length
+
         train_split = (
             X.group_by(entity_col)
-            .agg(pl.all().slice(0, pl.count() - test_size))
+            .agg(pl.all().slice(offset=0, length=train_length))
             .explode(pl.all().exclude(entity_col))
         )
         test_split = (
             X.group_by(entity_col)
-            .agg(pl.all().slice(-1 * test_size, test_size))
+            .agg(pl.all().slice(offset=train_length, length=test_length))
             .explode(pl.all().exclude(entity_col))
         )
         if eager:
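For orientation, a minimal usage sketch of the new signature (not part of the commit; the toy panel, its column names, and the sizes are illustrative, assuming functime's convention that the entity column comes first, followed by the time column):

import polars as pl

from functime.cross_validation import train_test_split

# Hypothetical panel with two entities of 8 rows each, ordered by time within each entity.
y = pl.DataFrame(
    {
        "entity": ["a"] * 8 + ["b"] * 8,
        "time": list(range(8)) * 2,
        "value": list(range(16)),
    }
).lazy()

# New in this commit: a float reserves the last 25% of each entity's rows for testing.
y_train, y_test = train_test_split(test_size=0.25, eager=True)(y)

# An int still means "the last N rows per entity", as before.
y_train, y_test = train_test_split(test_size=3, eager=True)(y)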
55 changes: 51 additions & 4 deletions tests/test_cross_validation.py
@@ -1,3 +1,5 @@
+from contextlib import nullcontext as does_not_raise
+
 import polars as pl
 import pytest
 
@@ -23,7 +25,7 @@ def step_size(request):
     return request.param
 
 
-def test_train_test_split(test_size, pl_y, benchmark):
+def test_train_test_split_int_size(test_size, pl_y, benchmark):
     def _split(y):
         y_train, y_test = train_test_split(test_size)(y)
         return pl.collect_all([y_train, y_test])
@@ -44,12 +46,57 @@ def _split(y):
         pl.col(time_col).count()
     )
     assert (
-        (ts_lengths.select("time") - train_lengths.select("time")) == test_size
-    ).select(pl.all().all())[0, 0]
+        ((ts_lengths.select("time") - train_lengths.select("time")) == test_size)
+        .to_series()
+        .all()
+    )
 
     # Check test window lengths
     test_lengths = y_test.group_by(entity_col).agg(pl.col(time_col).count())
-    assert (test_lengths.select("time") == test_size).select(pl.all().all())[0, 0]
+    assert (test_lengths.select("time") == test_size).to_series().all()
+
+
+@pytest.mark.parametrize(
+    "float_test_size,context",
+    [
+        (0.1, does_not_raise()),
+        (0.5, does_not_raise()),
+        (1.1, pytest.raises(ValueError)),
+        (-0.1, pytest.raises(ValueError)),
+    ],
+)
+def test_train_test_split_float_size(pl_y, float_test_size, context):
+    with context as exc_info:
+        y_train, y_test = train_test_split(float_test_size)(pl_y)
+
+    if exc_info:
+        assert "`test_size` must be between 0 and 1" in str(exc_info.value)
+
+    else:
+        entity_col, time_col = pl_y.columns[:2]
+        assert y_train.columns == y_test.columns
+
+        # Check train window lengths
+        ts_lengths = (
+            pl_y.group_by(entity_col, maintain_order=True)
+            .agg(pl.col(time_col).count())
+            .collect()
+        )
+
+        test_lengths = (
+            y_test.group_by(entity_col, maintain_order=True)
+            .agg(pl.col(time_col).count())
+            .collect()
+        )
+
+        assert (
+            (
+                (test_lengths.select("time") / ts_lengths.select("time"))
+                == float_test_size
+            )
+            .to_series()
+            .all()
+        )
 
 
 def test_expanding_window_split(test_size, n_splits, step_size, pl_y, benchmark):
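As a side note on the arithmetic the float path relies on (an illustrative check, not part of the commit): per entity, the train length is the integer-truncated product of the row count and `1 - test_size`, and the test split simply receives the remaining rows, so the observed test fraction equals `test_size` exactly only when the entity length times `test_size` is a whole number. The ratio assertion at the end of `test_train_test_split_float_size` relies on that being the case for its fixture.

from typing import Tuple


def split_lengths(n_rows: int, test_size: float) -> Tuple[int, int]:
    # Plain-Python mirror of the polars expressions in the diff:
    #   train_length = (pl.count() * (1 - test_size)).cast(int)
    #   test_length  = pl.count() - train_length
    train_length = int(n_rows * (1 - test_size))
    test_length = n_rows - train_length
    return train_length, test_length


assert split_lengths(100, 0.25) == (75, 25)  # exact: 100 * 0.25 is a whole number
assert split_lengths(10, 0.33) == (6, 4)     # truncation: the test split gets the leftover rows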
