Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: fix type hints of train_test_split #161

Closed
wants to merge 3 commits into from
Closed
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 43 additions & 8 deletions functime/cross_validation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
from typing import Mapping, Optional, Tuple
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, overload

import numpy as np
import polars as pl

if TYPE_CHECKING:
from typing import Callable, Literal, Mapping, Optional, Tuple


def train_test_split(
test_size: int, eager: bool = False
) -> Tuple[pl.LazyFrame, pl.LazyFrame]:
test_size: int, *, eager: bool = False
) -> Callable[
[pl.LazyFrame | pl.DataFrame],
Tuple[pl.LazyFrame, pl.LazyFrame] | Tuple[pl.DataFrame, pl.DataFrame],
]:
"""Return a time-ordered train set and test set given `test_size`.

Parameters
Expand All @@ -18,12 +27,37 @@ def train_test_split(

Returns
-------
splitter : Callable[pl.LazyFrame, Tuple[pl.LazyFrame, pl.LazyFrame]]
Function that takes a panel LazyFrame and returns tuple of train / test LazyFrames.
splitter : Callable[[pl.LazyFrame | pl.DataFrame], Tuple[pl.LazyFrame, pl.LazyFrame] | Tuple[pl.DataFrame, pl.DataFrame]]
baggiponte marked this conversation as resolved.
Show resolved Hide resolved
Function that takes a panel DataFrame or LazyFrame and returns a tuple of train/test frames: DataFrames if `eager` is True, otherwise LazyFrames.
"""

def split(X: pl.LazyFrame) -> pl.LazyFrame:
X = X.lazy() # Defensive
@overload
def splitter(
X: pl.LazyFrame | pl.DataFrame,
test_size: int,
*,
eager: Literal[False],
) -> Tuple[pl.LazyFrame, pl.LazyFrame]:
...

@overload
def splitter(
X: pl.LazyFrame | pl.DataFrame,
test_size: int,
*,
eager: Literal[True],
) -> Tuple[pl.DataFrame, pl.DataFrame]:
...

def splitter(
X: pl.LazyFrame | pl.DataFrame,
test_size: int,
*,
eager: bool,
) -> Tuple[pl.LazyFrame, pl.LazyFrame] | Tuple[pl.DataFrame, pl.DataFrame]:
if isinstance(X, pl.DataFrame):
X = X.lazy() # Defensive

entity_col = X.columns[0]
train_split = (
X.group_by(entity_col)
Expand All @@ -37,9 +71,10 @@ def split(X: pl.LazyFrame) -> pl.LazyFrame:
)
if eager:
train_split, test_split = pl.collect_all([train_split, test_split])
return train_split, test_split
return train_split, test_split

return split
return partial(splitter, **{"test_size": test_size, "eager": eager})


def _window_split(
Expand Down