-
-
Notifications
You must be signed in to change notification settings - Fork 298
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1598 from AntonioCarta/master
fix issue #1597
- Loading branch information
Showing
7 changed files
with
130 additions
and
90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
from typing import ( | ||
Callable, | ||
Generator, | ||
Generic, | ||
List, | ||
Sequence, | ||
TypeVar, | ||
Union, | ||
Tuple, | ||
Optional, | ||
Iterable, | ||
Dict, | ||
) | ||
|
||
import random | ||
from avalanche.benchmarks.utils.data import AvalancheDataset | ||
from .generic_scenario import EagerCLStream, CLScenario, CLExperience, make_stream | ||
from .dataset_scenario import ( | ||
LazyTrainValSplitter, | ||
DatasetExperience, | ||
split_validation_random, | ||
) | ||
from .supervised import with_classes_timeline | ||
|
||
|
||
def benchmark_with_validation_stream( | ||
benchmark: CLScenario, | ||
validation_size: Union[int, float] = 0.5, | ||
shuffle: bool = False, | ||
seed: Optional[int] = None, | ||
split_strategy: Optional[ | ||
Callable[[AvalancheDataset], Tuple[AvalancheDataset, AvalancheDataset]] | ||
] = None, | ||
) -> CLScenario: | ||
"""Helper to obtain a benchmark with a validation stream. | ||
This generator accepts an existing benchmark instance and returns a version | ||
of it in which the train stream has been split into training and validation | ||
streams. | ||
Each train/validation experience will be by splitting the original training | ||
experiences. Patterns selected for the validation experience will be removed | ||
from the training experiences. | ||
The default splitting strategy is a random split as implemented by `split_validation_random`. | ||
If you want to use class balancing you can use `split_validation_class_balanced`, or | ||
use a custom `split_strategy`, as shown in the following example:: | ||
validation_size = 0.2 | ||
foo = lambda exp: split_dataset_class_balanced(validation_size, exp) | ||
bm = benchmark_with_validation_stream(bm, custom_split_strategy=foo) | ||
:param benchmark: The benchmark to split. | ||
:param validation_size: The size of the validation experience, as an int | ||
or a float between 0 and 1. Ignored if `custom_split_strategy` is used. | ||
:param shuffle: If True, patterns will be allocated to the validation | ||
stream randomly. This will use the default PyTorch random number | ||
generator at its current state. Defaults to False. Ignored if | ||
`custom_split_strategy` is used. If False, the first instances will be | ||
allocated to the training dataset by leaving the last ones to the | ||
validation dataset. | ||
:param split_strategy: A function that implements a custom splitting | ||
strategy. The function must accept an AvalancheDataset and return a tuple | ||
containing the new train and validation dataset. By default, the splitting | ||
strategy will split the data according to `validation_size` and `shuffle`). | ||
A good starting to understand the mechanism is to look at the | ||
implementation of the standard splitting function | ||
:func:`random_validation_split_strategy`. | ||
:return: A benchmark instance in which the validation stream has been added. | ||
""" | ||
|
||
if split_strategy is None: | ||
if seed is None: | ||
seed = random.randint(0, 1000000) | ||
|
||
# functools.partial is a more compact option | ||
# However, MyPy does not understand what a partial is -_- | ||
def random_validation_split_strategy_wrapper(data): | ||
return split_validation_random(validation_size, shuffle, seed, data) | ||
|
||
split_strategy = random_validation_split_strategy_wrapper | ||
else: | ||
split_strategy = split_strategy | ||
|
||
stream = benchmark.streams["train"] | ||
if isinstance(stream, EagerCLStream): # eager split | ||
train_exps, valid_exps = [], [] | ||
|
||
exp: DatasetExperience | ||
for exp in stream: | ||
train_data, valid_data = split_strategy(exp.dataset) | ||
train_exps.append(DatasetExperience(dataset=train_data)) | ||
valid_exps.append(DatasetExperience(dataset=valid_data)) | ||
else: # Lazy splitting (based on a generator) | ||
split_generator = LazyTrainValSplitter(split_strategy, stream) | ||
train_exps = (DatasetExperience(dataset=a) for a, _ in split_generator) | ||
valid_exps = (DatasetExperience(dataset=b) for _, b in split_generator) | ||
|
||
train_stream = make_stream(name="train", exps=train_exps) | ||
valid_stream = make_stream(name="valid", exps=valid_exps) | ||
other_streams = benchmark.streams | ||
|
||
# don't drop classes-timeline for compatibility with old API | ||
e0 = next(iter(train_stream)) | ||
if hasattr(e0, "dataset") and hasattr(e0.dataset, "targets"): | ||
train_stream = with_classes_timeline(train_stream) | ||
valid_stream = with_classes_timeline(valid_stream) | ||
|
||
del other_streams["train"] | ||
return CLScenario( | ||
streams=[train_stream, valid_stream] + list(other_streams.values()) | ||
) | ||
|
||
|
||
__all__ = ["benchmark_with_validation_stream"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters