Skip to content

Commit

Permalink
Merge branch 'main' into release
Browse files Browse the repository at this point in the history
  • Loading branch information
Myles Bartlett committed Mar 16, 2022
2 parents 9435518 + af18439 commit be93dfa
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 11 deletions.
176 changes: 168 additions & 8 deletions ranzen/torch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,114 @@
from abc import abstractmethod
from enum import Enum, auto
import math
from typing import Iterator, Sequence, Sized
from typing import Generic, Iterator, Sequence, Sized, TypeVar, cast, overload

from attr import dataclass
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset, Sampler
from torch.utils.data.dataset import Subset, random_split
from typing_extensions import Final
from torch.utils.data import Sampler
from typing_extensions import Final, Literal, Protocol, runtime_checkable

from ranzen import implements
from ranzen.misc import str_to_enum

__all__ = [
"GreedyCoreSetSampler",
"SequentialBatchSampler",
"SizedDataset",
"StratifiedBatchSampler",
"Subset",
"TrainTestSplit",
"TrainingMode",
"prop_random_split",
"prop_stratified_split",
]

T_co = TypeVar("T_co", covariant=True)


@runtime_checkable
class SizedDataset(Protocol[T_co]):
def __getitem__(self, index: int) -> T_co:
...

def __len__(self) -> int:
...


class Subset(SizedDataset[T_co]):
r"""
Subset of a dataset at specified indices.
"""
dataset: SizedDataset[T_co]
indices: Sequence[int]

def __init__(self, dataset: SizedDataset[T_co], indices: Sequence[int]) -> None:
"""
:param dataset: The whole Dataset.
:param indices: Indices in the whole set selected for subset.
"""
self.dataset = dataset
self.indices = indices

@implements(SizedDataset)
def __getitem__(self, idx: int) -> T_co:
return self.dataset[self.indices[idx]]

@implements(SizedDataset)
def __len__(self) -> int:
return len(self.indices)


D = TypeVar("D", bound=SizedDataset)


@overload
def prop_random_split(
dataset: Dataset, *, props: Sequence[float] | float, seed: int | None = None
) -> list[Subset]:
"""Splits a dataset based on proportions rather than on absolute sizes."""
dataset: D,
*,
props: Sequence[float] | float,
as_indices: Literal[False] = ...,
seed: int | None = ...,
) -> list[Subset[D]]:
...


@overload
def prop_random_split(
dataset: SizedDataset,
*,
props: Sequence[float] | float,
as_indices: Literal[True] = ...,
seed: int | None = ...,
) -> list[int]:
...


def prop_random_split(
dataset: D,
*,
props: Sequence[float] | float,
as_indices: bool = False,
seed: int | None = None,
) -> list[Subset[D]] | list[int]:
"""Splits a dataset based on proportions rather than on absolute sizes
:param dataset: Dataset to split.
:param props: The fractional size of each subset into which to randomly split the data.
Elements must be non-negative and sum to 1 or less; if less then the size of the final
split will be computed by complement.
:param as_indices: If ``True`` the raw indices are returned instead of subsets constructed
from them.
:param seed: The PRNG used for determining the random splits.
:returns: Random subsets of the data of the requested proportions.
:raises ValueError: If the dataset does not have a ``__len__`` method or sum(props) > 1.
"""
if not hasattr(dataset, "__len__"):
raise ValueError(
"Split proportions can only be computed for datasets with __len__ defined."
Expand All @@ -41,7 +124,84 @@ def prop_random_split(
if sum_ < 1:
section_sizes.append(len_ - sum(section_sizes))
generator = torch.default_generator if seed is None else torch.Generator().manual_seed(seed)
return random_split(dataset, section_sizes, generator=generator)
indices = torch.randperm(sum(section_sizes), generator=generator).tolist()
splits = []
for offset, length in zip(np.cumsum(section_sizes), section_sizes):
split = indices[offset - length : offset]
if not as_indices:
split = Subset(dataset, indices=split)
splits.append(split)
return splits


S = TypeVar("S")


@dataclass(frozen=True)
class TrainTestSplit(Generic[S]):

train: S
test: S

def __iter__(self) -> Iterator[S]:
yield from (self.train, self.test)


def prop_stratified_split(
labels: Tensor,
*,
default_train_prop: float,
train_props: dict[int, float] | None = None,
seed: int | None = None,
) -> TrainTestSplit[list[int]]:
"""Splits the data into train/test sets conditional on super- and sub-class labels.
:param labels: Tensor encoding the label associated with each sample.
:param default_train_prop: Proportion of samples for a given to sample for
the training set for those y-s combinations not specified in ``train_props``.
:param train_props: Proportion of each group to sample for the training set.
If ``None`` then the function reduces to a simple random split of the data.
:param seed: PRNG seed to use for sampling.
:returns: Train-test split.
:raises ValueError: If a value in ``train_props`` is not in the range [0, 1] or if a key is not
present in ``group_ids``.
"""
# Initialise the random-number generator
generator = torch.default_generator if seed is None else torch.Generator().manual_seed(seed)
groups, label_counts = labels.unique(return_counts=True)
train_props_all = dict.fromkeys(groups.tolist(), default_train_prop)

if train_props is not None:
for label, train_prop in train_props.items():
if not 0 <= train_prop <= 1:
raise ValueError(
"All splitting proportions specified in 'train_props' must be in the "
"range [0, 1]."
)
if label not in groups:
raise ValueError(f"No samples belonging to the group in 'group_ids'.")
train_props_all[label] = train_prop

# Shuffle the samples before sampling
perm_inds = torch.randperm(len(labels), generator=generator)
labels_perm = labels[perm_inds]

sort_inds = labels_perm.sort(dim=0, stable=True).indices
thresholds = cast(
Tensor, (torch.as_tensor(tuple(train_props_all.values())) * label_counts).round().long()
)
thresholds = torch.stack([thresholds, label_counts], dim=-1)
thresholds[1:] += label_counts.cumsum(0)[:-1].unsqueeze(-1)

train_test_inds = sort_inds.tensor_split(thresholds.flatten()[:-1], dim=0)
train_inds = perm_inds[torch.cat(train_test_inds[0::2])].tolist()
test_inds = perm_inds[torch.cat(train_test_inds[1::2])].tolist()

return TrainTestSplit(train=train_inds, test=test_inds)


class TrainingMode(Enum):
Expand Down
39 changes: 36 additions & 3 deletions tests/torch_data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,59 @@
from torch.utils.data import TensorDataset

from ranzen.torch import prop_random_split
from ranzen.torch.data import prop_stratified_split


@pytest.fixture(scope="module")
def dummy_ds() -> TensorDataset: # type: ignore[no-any-unimported]
return TensorDataset(torch.randn(100))


@pytest.mark.parametrize("as_indices", [False, True])
@pytest.mark.parametrize("props", [0.5, [-0.2, 0.5], [0.1, 0.3, 0.4], [0.5, 0.6]])
def test_prop_random_split(
dummy_ds: TensorDataset, props: float | list[float]
dummy_ds: TensorDataset, props: float | list[float], as_indices: bool
) -> None: # type: ignore[no-any-unimported]
sum_ = props if isinstance(props, float) else sum(props)
props_ls = [props] if isinstance(props, float) else props
if sum_ > 1 or any(not (0 <= prop <= 1) for prop in props_ls):
with pytest.raises(ValueError):
splits = prop_random_split(dataset=dummy_ds, props=props)
splits = prop_random_split(dataset=dummy_ds, props=props, as_indices=as_indices) # type: ignore
else:
splits = prop_random_split(dataset=dummy_ds, props=props)
splits = prop_random_split(dataset=dummy_ds, props=props, as_indices=as_indices) # type: ignore
sizes = [len(split) for split in splits]
sum_sizes = sum(sizes)
assert len(splits) == (len(props_ls) + 1)
assert sum_sizes == len(dummy_ds)
assert sizes[-1] == (len(dummy_ds) - (round(sum_ * len(dummy_ds))))


def test_prop_stratified_split():
labels = torch.randint(low=0, high=4, size=(50,))
train_inds, test_inds = prop_stratified_split(labels=labels, default_train_prop=0.5)
labels_tr = labels[train_inds]
labels_te = labels[test_inds]
counts_tr = labels_tr.unique(return_counts=True)[1]
counts_te = labels_te.unique(return_counts=True)[1]
assert torch.isclose(counts_tr, counts_te, atol=1).all()

train_props = {0: 0.25, 1: 0.45}
train_inds, test_inds = prop_stratified_split(
labels=labels,
default_train_prop=0.5,
train_props=train_props,
)
labels_tr = labels[train_inds]
labels_te = labels[test_inds]

for label, train_prop in train_props.items():
train_m = labels_tr == label
test_m = labels_te == label
all_m = labels == label

n_train = train_m.count_nonzero().item()
n_test = test_m.count_nonzero().item()
n_all = all_m.count_nonzero().item()

assert n_train == pytest.approx(train_prop * n_all, abs=1)
assert n_test == pytest.approx((1 - train_prop) * n_all, abs=1)

0 comments on commit be93dfa

Please sign in to comment.