Commit d075a44
fix issue #1597
AntonioCarta committed Feb 14, 2024
1 parent 524f70c commit d075a44
Showing 6 changed files with 135 additions and 92 deletions.
1 change: 1 addition & 0 deletions avalanche/benchmarks/scenarios/__init__.py
@@ -8,3 +8,4 @@
from .dataset_scenario import *
from .exmodel_scenario import *
from .online import *
from .validation_scenario import *
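This commit moves `benchmark_with_validation_stream` out of `dataset_scenario.py` and into the new `validation_scenario.py` module, and the wildcard import added above re-exports it from the package. A minimal sketch of the two import paths that should both resolve after this change (assuming the usual wildcard re-export behavior of the package `__init__`):

    # direct import from the new module introduced by this commit
    from avalanche.benchmarks.scenarios.validation_scenario import (
        benchmark_with_validation_stream,
    )

    # equivalent import via the package, enabled by the added
    # `from .validation_scenario import *` line
    from avalanche.benchmarks.scenarios import benchmark_with_validation_stream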
86 changes: 0 additions & 86 deletions avalanche/benchmarks/scenarios/dataset_scenario.py
@@ -14,7 +14,6 @@
import random
from avalanche.benchmarks.utils.data import AvalancheDataset
import torch
from itertools import tee
from typing import (
Callable,
Generator,
@@ -253,94 +252,9 @@ def __iter__(
yield self.split_strategy(new_experience.dataset)


def benchmark_with_validation_stream(
benchmark: CLScenario,
validation_size: Union[int, float] = 0.5,
shuffle: bool = False,
seed: Optional[int] = None,
split_strategy: Optional[
Callable[[AvalancheDataset], Tuple[AvalancheDataset, AvalancheDataset]]
] = None,
) -> CLScenario:
"""Helper to obtain a benchmark with a validation stream.
This helper accepts an existing benchmark instance and returns a version
of it in which the train stream has been split into training and validation
streams.
Each train/validation experience will be created by splitting the original
training experience. Patterns selected for the validation experience will be
removed from the training experience.
The default splitting strategy is a random split as implemented by `split_validation_random`.
If you want to use class balancing you can use `split_validation_class_balanced`, or
use a custom `split_strategy`, as shown in the following example::
validation_size = 0.2
foo = lambda exp: split_validation_class_balanced(validation_size, exp)
bm = benchmark_with_validation_stream(bm, split_strategy=foo)
:param benchmark: The benchmark to split.
:param validation_size: The size of the validation experience, as an int
or a float between 0 and 1. Ignored if `split_strategy` is used.
:param shuffle: If True, patterns will be allocated to the validation
stream randomly. This will use the default PyTorch random number
generator at its current state. Defaults to False. Ignored if
`split_strategy` is used. If False, the first instances will be
allocated to the training dataset, leaving the last ones to the
validation dataset.
:param split_strategy: A function that implements a custom splitting
strategy. The function must accept an AvalancheDataset and return a tuple
containing the new train and validation dataset. By default, the splitting
strategy will split the data according to `validation_size` and `shuffle`.
A good starting point to understand the mechanism is to look at the
implementation of the standard splitting function
:func:`split_validation_random`.
:return: A benchmark instance in which the validation stream has been added.
"""

if split_strategy is None:
if seed is None:
seed = random.randint(0, 1000000)

# functools.partial is a more compact option
# However, MyPy does not understand what a partial is -_-
def random_validation_split_strategy_wrapper(data):
return split_validation_random(validation_size, shuffle, seed, data)

split_strategy = random_validation_split_strategy_wrapper

stream = benchmark.streams["train"]
if isinstance(stream, EagerCLStream): # eager split
train_exps, valid_exps = [], []

exp: DatasetExperience
for exp in stream:
train_data, valid_data = split_strategy(exp.dataset)
train_exps.append(DatasetExperience(dataset=train_data))
valid_exps.append(DatasetExperience(dataset=valid_data))
else: # Lazy splitting (based on a generator)
split_generator = LazyTrainValSplitter(split_strategy, stream)
train_exps = (DatasetExperience(dataset=a) for a, _ in split_generator)
valid_exps = (DatasetExperience(dataset=b) for _, b in split_generator)

train_stream = make_stream(name="train", exps=train_exps)
valid_stream = make_stream(name="valid", exps=valid_exps)
other_streams = benchmark.streams

del other_streams["train"]
return CLScenario(
streams=[train_stream, valid_stream] + list(other_streams.values())
)


__all__ = [
"_split_dataset_by_attribute",
"benchmark_from_datasets",
"DatasetExperience",
"split_validation_random",
"benchmark_with_validation_stream",
]
12 changes: 7 additions & 5 deletions avalanche/benchmarks/scenarios/supervised.py
@@ -31,7 +31,7 @@
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.data_attribute import DataAttribute
from .dataset_scenario import _split_dataset_by_attribute, DatasetExperience
from .. import CLScenario, CLStream, EagerCLStream
from .generic_scenario import CLScenario, CLStream, EagerCLStream


def class_incremental_benchmark(
@@ -399,12 +399,14 @@ def _decorate_stream(obj: CLStream):
new_exp = copy(exp)
curr_cls = exp.dataset.targets.uniques

new_exp.classes_in_this_experience = curr_cls
new_exp.previous_classes = set(prev_cls)
new_exp.classes_seen_so_far = curr_cls.union(prev_cls)
new_exp.classes_in_this_experience = list(curr_cls)
new_exp.previous_classes = list(set(prev_cls))
new_exp.classes_seen_so_far = list(curr_cls.union(prev_cls))
# TODO: future_classes ignores repetitions right now...
# implement and test scenario with repetitions
new_exp.future_classes = all_cls.difference(new_exp.classes_seen_so_far)
new_exp.future_classes = list(
all_cls.difference(new_exp.classes_seen_so_far)
)
new_stream.append(new_exp)

prev_cls = prev_cls.union(curr_cls)
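The substantive change in this hunk is that the class-timeline attributes are now materialized as lists instead of sets. A minimal illustration of the failure mode this avoids (the class ids below are hypothetical, not taken from this diff):

    curr_cls = {3, 1, 2}  # e.g. exp.dataset.targets.uniques, a set of class ids

    # with the raw set, positional access fails:
    # curr_cls[0]  -> TypeError: 'set' object is not subscriptable

    # after this commit the attribute is a list, so indexing and slicing work:
    classes_in_this_experience = list(curr_cls)
    first_class = classes_in_this_experience[0]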
116 changes: 116 additions & 0 deletions avalanche/benchmarks/scenarios/validation_scenario.py
@@ -0,0 +1,116 @@
from typing import (
Callable,
Generator,
Generic,
List,
Sequence,
TypeVar,
Union,
Tuple,
Optional,
Iterable,
Dict,
)

import random
from avalanche.benchmarks.utils.data import AvalancheDataset
from .generic_scenario import EagerCLStream, CLScenario, CLExperience, make_stream
from .dataset_scenario import (
LazyTrainValSplitter,
DatasetExperience,
split_validation_random,
)
from .supervised import with_classes_timeline


def benchmark_with_validation_stream(
benchmark: CLScenario,
validation_size: Union[int, float] = 0.5,
shuffle: bool = False,
seed: Optional[int] = None,
split_strategy: Optional[
Callable[[AvalancheDataset], Tuple[AvalancheDataset, AvalancheDataset]]
] = None,
) -> CLScenario:
"""Helper to obtain a benchmark with a validation stream.
This helper accepts an existing benchmark instance and returns a version
of it in which the train stream has been split into training and validation
streams.
Each train/validation experience will be created by splitting the original
training experience. Patterns selected for the validation experience will be
removed from the training experience.
The default splitting strategy is a random split as implemented by `split_validation_random`.
If you want to use class balancing you can use `split_validation_class_balanced`, or
use a custom `split_strategy`, as shown in the following example::
validation_size = 0.2
foo = lambda exp: split_validation_class_balanced(validation_size, exp)
bm = benchmark_with_validation_stream(bm, split_strategy=foo)
:param benchmark: The benchmark to split.
:param validation_size: The size of the validation experience, as an int
or a float between 0 and 1. Ignored if `split_strategy` is used.
:param shuffle: If True, patterns will be allocated to the validation
stream randomly. This will use the default PyTorch random number
generator at its current state. Defaults to False. Ignored if
`split_strategy` is used. If False, the first instances will be
allocated to the training dataset, leaving the last ones to the
validation dataset.
:param split_strategy: A function that implements a custom splitting
strategy. The function must accept an AvalancheDataset and return a tuple
containing the new train and validation dataset. By default, the splitting
strategy will split the data according to `validation_size` and `shuffle`.
A good starting point to understand the mechanism is to look at the
implementation of the standard splitting function
:func:`split_validation_random`.
:return: A benchmark instance in which the validation stream has been added.
"""

if split_strategy is None:
if seed is None:
seed = random.randint(0, 1000000)

# functools.partial is a more compact option
# However, MyPy does not understand what a partial is -_-
def random_validation_split_strategy_wrapper(data):
return split_validation_random(validation_size, shuffle, seed, data)

split_strategy = random_validation_split_strategy_wrapper

stream = benchmark.streams["train"]
if isinstance(stream, EagerCLStream): # eager split
train_exps, valid_exps = [], []

exp: DatasetExperience
for exp in stream:
train_data, valid_data = split_strategy(exp.dataset)
train_exps.append(DatasetExperience(dataset=train_data))
valid_exps.append(DatasetExperience(dataset=valid_data))
else: # Lazy splitting (based on a generator)
split_generator = LazyTrainValSplitter(split_strategy, stream)
train_exps = (DatasetExperience(dataset=a) for a, _ in split_generator)
valid_exps = (DatasetExperience(dataset=b) for _, b in split_generator)

train_stream = make_stream(name="train", exps=train_exps)
valid_stream = make_stream(name="valid", exps=valid_exps)
other_streams = benchmark.streams

# don't drop the classes timeline, for compatibility with the old API
e0 = next(iter(train_stream))
if hasattr(e0, "dataset") and hasattr(e0.dataset, "targets"):
train_stream = with_classes_timeline(train_stream)
valid_stream = with_classes_timeline(valid_stream)

del other_streams["train"]
return CLScenario(
streams=[train_stream, valid_stream] + list(other_streams.values())
)


__all__ = ["benchmark_with_validation_stream"]
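The `with_classes_timeline` re-decoration just before the final `CLScenario` is assembled is what restores attributes such as `classes_in_this_experience` on the split streams, i.e. the actual fix for #1597. A hedged usage sketch (the dataset variables are placeholders, and the keyword-based signature of `benchmark_from_datasets` is assumed from its use elsewhere in the library):

    from avalanche.benchmarks import benchmark_from_datasets
    from avalanche.benchmarks.scenarios.validation_scenario import (
        benchmark_with_validation_stream,
    )

    # train_sets / test_sets stand in for lists of AvalancheDatasets with targets
    bm = benchmark_from_datasets(train=train_sets, test=test_sets)
    bm = benchmark_with_validation_stream(bm, validation_size=0.2, shuffle=True)

    # the scenario now exposes "train", "valid" and the untouched "test" streams,
    # and each split experience carries the class timeline again
    for exp in bm.streams["train"]:
        assert hasattr(exp, "classes_in_this_experience")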
10 changes: 9 additions & 1 deletion tests/benchmarks/scenarios/test_dataset_scenario.py
@@ -6,12 +6,14 @@

from avalanche.benchmarks import (
benchmark_from_datasets,
benchmark_with_validation_stream,
CLScenario,
CLStream,
split_validation_random,
task_incremental_benchmark,
)
from avalanche.benchmarks.scenarios.validation_scenario import (
benchmark_with_validation_stream,
)
from avalanche.benchmarks.scenarios.dataset_scenario import (
DatasetExperience,
split_validation_class_balanced,
@@ -383,3 +385,9 @@ def test_gen():
mb = get_mbatch(dd, len(dd))
self.assertTrue(torch.equal(test_x, mb[0]))
self.assertTrue(torch.equal(test_y, mb[1]))

def test_regression_1597(self):
# regression test for issue #1597
bm = get_fast_benchmark()
for exp in bm.train_stream:
assert hasattr(exp, "classes_in_this_experience")
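The regression test above only exercises the default train stream. A natural companion check (hypothetical, not part of this commit) would also cover the streams produced by `benchmark_with_validation_stream`, since that is the path this commit re-decorates with the class timeline:

    def test_regression_1597_validation_split(self):  # hypothetical companion test
        bm = benchmark_with_validation_stream(get_fast_benchmark(), validation_size=0.2)
        for stream in (bm.streams["train"], bm.streams["valid"]):
            for exp in stream:
                assert hasattr(exp, "classes_in_this_experience")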
2 changes: 2 additions & 0 deletions tests/training/test_plugins.py
@@ -16,6 +16,8 @@
from avalanche.benchmarks import (
nc_benchmark,
GenericCLScenario,
)
from avalanche.benchmarks.scenarios.validation_scenario import (
benchmark_with_validation_stream,
)
from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader
