diff --git a/ranzen/torch/data.py b/ranzen/torch/data.py index 438bf5f5..6df7b86e 100644 --- a/ranzen/torch/data.py +++ b/ranzen/torch/data.py @@ -2,14 +2,14 @@ from abc import abstractmethod from enum import Enum, auto import math -from typing import Iterator, Sequence, Sized +from typing import Generic, Iterator, Sequence, Sized, TypeVar, cast, overload +from attr import dataclass import numpy as np import torch from torch import Tensor -from torch.utils.data import Dataset, Sampler -from torch.utils.data.dataset import Subset, random_split -from typing_extensions import Final +from torch.utils.data import Sampler +from typing_extensions import Final, Literal, Protocol, runtime_checkable from ranzen import implements from ranzen.misc import str_to_enum @@ -17,16 +17,99 @@ __all__ = [ "GreedyCoreSetSampler", "SequentialBatchSampler", + "SizedDataset", "StratifiedBatchSampler", + "Subset", + "TrainTestSplit", "TrainingMode", "prop_random_split", + "prop_stratified_split", ] +T_co = TypeVar("T_co", covariant=True) + +@runtime_checkable +class SizedDataset(Protocol[T_co]): + def __getitem__(self, index: int) -> T_co: + ... + + def __len__(self) -> int: + ... + + +class Subset(SizedDataset[T_co]): + r""" + Subset of a dataset at specified indices. + """ + dataset: SizedDataset[T_co] + indices: Sequence[int] + + def __init__(self, dataset: SizedDataset[T_co], indices: Sequence[int]) -> None: + """ + :param dataset: The whole Dataset. + :param indices: Indices in the whole set selected for subset. 
+ """ + self.dataset = dataset + self.indices = indices + + @implements(SizedDataset) + def __getitem__(self, idx: int) -> T_co: + return self.dataset[self.indices[idx]] + + @implements(SizedDataset) + def __len__(self) -> int: + return len(self.indices) + + +D = TypeVar("D", bound=SizedDataset) + + +@overload def prop_random_split( - dataset: Dataset, *, props: Sequence[float] | float, seed: int | None = None -) -> list[Subset]: - """Splits a dataset based on proportions rather than on absolute sizes.""" + dataset: D, + *, + props: Sequence[float] | float, + as_indices: Literal[False] = ..., + seed: int | None = ..., +) -> list[Subset[D]]: + ... + + +@overload +def prop_random_split( + dataset: SizedDataset, + *, + props: Sequence[float] | float, + as_indices: Literal[True] = ..., + seed: int | None = ..., +) -> list[int]: + ... + + +def prop_random_split( + dataset: D, + *, + props: Sequence[float] | float, + as_indices: bool = False, + seed: int | None = None, +) -> list[Subset[D]] | list[int]: + """Splits a dataset based on proportions rather than on absolute sizes. + + :param dataset: Dataset to split. + :param props: The fractional size of each subset into which to randomly split the data. + Elements must be non-negative and sum to 1 or less; if less then the size of the final + split will be computed by complement. + + :param as_indices: If ``True`` the raw indices are returned instead of subsets constructed + from them. + + :param seed: Seed for the PRNG used to determine the random splits. + + :returns: Random subsets (or their indices if ``as_indices``) of the requested proportions. + + :raises ValueError: If the dataset does not have a ``__len__`` method or ``sum(props) > 1``. + """ if not hasattr(dataset, "__len__"): raise ValueError( "Split proportions can only be computed for datasets with __len__ defined." 
@@ -41,7 +124,84 @@ def prop_random_split( if sum_ < 1: section_sizes.append(len_ - sum(section_sizes)) generator = torch.default_generator if seed is None else torch.Generator().manual_seed(seed) - return random_split(dataset, section_sizes, generator=generator) + indices = torch.randperm(sum(section_sizes), generator=generator).tolist() + splits = [] + for offset, length in zip(np.cumsum(section_sizes), section_sizes): + split = indices[offset - length : offset] + if not as_indices: + split = Subset(dataset, indices=split) + splits.append(split) + return splits + + +S = TypeVar("S") + + +@dataclass(frozen=True) +class TrainTestSplit(Generic[S]): + + train: S + test: S + + def __iter__(self) -> Iterator[S]: + yield from (self.train, self.test) + + +def prop_stratified_split( + labels: Tensor, + *, + default_train_prop: float, + train_props: dict[int, float] | None = None, + seed: int | None = None, +) -> TrainTestSplit[list[int]]: + """Splits the data into train/test sets conditional on super- and sub-class labels. + + :param labels: Tensor encoding the label associated with each sample. + :param default_train_prop: Proportion of samples for a given to sample for + the training set for those y-s combinations not specified in ``train_props``. + + :param train_props: Proportion of each group to sample for the training set. + If ``None`` then the function reduces to a simple random split of the data. + + :param seed: PRNG seed to use for sampling. + + :returns: Train-test split. + + :raises ValueError: If a value in ``train_props`` is not in the range [0, 1] or if a key is not + present in ``group_ids``. 
+ """ + # Initialise the random-number generator + generator = torch.default_generator if seed is None else torch.Generator().manual_seed(seed) + groups, label_counts = labels.unique(return_counts=True) + train_props_all = dict.fromkeys(groups.tolist(), default_train_prop) + + if train_props is not None: + for label, train_prop in train_props.items(): + if not 0 <= train_prop <= 1: + raise ValueError( + "All splitting proportions specified in 'train_props' must be in the " + "range [0, 1]." + ) + if label not in groups: + raise ValueError(f"No samples belonging to the group in 'group_ids'.") + train_props_all[label] = train_prop + + # Shuffle the samples before sampling + perm_inds = torch.randperm(len(labels), generator=generator) + labels_perm = labels[perm_inds] + + sort_inds = labels_perm.sort(dim=0, stable=True).indices + thresholds = cast( + Tensor, (torch.as_tensor(tuple(train_props_all.values())) * label_counts).round().long() + ) + thresholds = torch.stack([thresholds, label_counts], dim=-1) + thresholds[1:] += label_counts.cumsum(0)[:-1].unsqueeze(-1) + + train_test_inds = sort_inds.tensor_split(thresholds.flatten()[:-1], dim=0) + train_inds = perm_inds[torch.cat(train_test_inds[0::2])].tolist() + test_inds = perm_inds[torch.cat(train_test_inds[1::2])].tolist() + + return TrainTestSplit(train=train_inds, test=test_inds) class TrainingMode(Enum): diff --git a/tests/torch_data_test.py b/tests/torch_data_test.py index 276f57c6..95385ed9 100644 --- a/tests/torch_data_test.py +++ b/tests/torch_data_test.py @@ -5,6 +5,7 @@ from torch.utils.data import TensorDataset from ranzen.torch import prop_random_split +from ranzen.torch.data import prop_stratified_split @pytest.fixture(scope="module") @@ -12,19 +13,51 @@ def dummy_ds() -> TensorDataset: # type: ignore[no-any-unimported] return TensorDataset(torch.randn(100)) +@pytest.mark.parametrize("as_indices", [False, True]) @pytest.mark.parametrize("props", [0.5, [-0.2, 0.5], [0.1, 0.3, 0.4], [0.5, 0.6]]) def 
test_prop_random_split( - dummy_ds: TensorDataset, props: float | list[float] + dummy_ds: TensorDataset, props: float | list[float], as_indices: bool ) -> None: # type: ignore[no-any-unimported] sum_ = props if isinstance(props, float) else sum(props) props_ls = [props] if isinstance(props, float) else props if sum_ > 1 or any(not (0 <= prop <= 1) for prop in props_ls): with pytest.raises(ValueError): - splits = prop_random_split(dataset=dummy_ds, props=props) + splits = prop_random_split(dataset=dummy_ds, props=props, as_indices=as_indices) # type: ignore else: - splits = prop_random_split(dataset=dummy_ds, props=props) + splits = prop_random_split(dataset=dummy_ds, props=props, as_indices=as_indices) # type: ignore sizes = [len(split) for split in splits] sum_sizes = sum(sizes) assert len(splits) == (len(props_ls) + 1) assert sum_sizes == len(dummy_ds) assert sizes[-1] == (len(dummy_ds) - (round(sum_ * len(dummy_ds)))) + + +def test_prop_stratified_split(): + labels = torch.randint(low=0, high=4, size=(50,)) + train_inds, test_inds = prop_stratified_split(labels=labels, default_train_prop=0.5) + labels_tr = labels[train_inds] + labels_te = labels[test_inds] + counts_tr = labels_tr.unique(return_counts=True)[1] + counts_te = labels_te.unique(return_counts=True)[1] + assert torch.isclose(counts_tr, counts_te, atol=1).all() + + train_props = {0: 0.25, 1: 0.45} + train_inds, test_inds = prop_stratified_split( + labels=labels, + default_train_prop=0.5, + train_props=train_props, + ) + labels_tr = labels[train_inds] + labels_te = labels[test_inds] + + for label, train_prop in train_props.items(): + train_m = labels_tr == label + test_m = labels_te == label + all_m = labels == label + + n_train = train_m.count_nonzero().item() + n_test = test_m.count_nonzero().item() + n_all = all_m.count_nonzero().item() + + assert n_train == pytest.approx(train_prop * n_all, abs=1) + assert n_test == pytest.approx((1 - train_prop) * n_all, abs=1)