Skip to content

Commit acdfd8b

Browse files
author
FelixAbrahamsson
committed
refactor!: simplify dataset.split and split_dataframes
- split_dataframes no longer accepts stratify_column argument - dataset.split handles group_split automatically - add _stratified_split and _unstratified_split methods to dataset - strata splits are stored in separate files - add some tests related to splitting and move tests from split_dataframe to dataset BREAKING CHANGE: dataset.split now takes a save_directory argument and dataset.group_split no longer exists (use dataset.split)
1 parent 2385c2c commit acdfd8b

File tree

3 files changed

+185
-147
lines changed

3 files changed

+185
-147
lines changed

datastream/dataset.py

Lines changed: 173 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
)
66
from pathlib import Path
77
from functools import lru_cache
8+
import warnings
89
import textwrap
910
import inspect
1011
import numpy as np
@@ -218,15 +219,15 @@ def split(
218219
key_column: str,
219220
proportions: Dict[str, float],
220221
stratify_column: Optional[str] = None,
221-
filepath: Optional[Union[str, Path]] = None,
222+
save_directory: Optional[Union[str, Path]] = None,
222223
frozen: Optional[bool] = False,
223224
seed: Optional[int] = None,
224225
) -> Dict[str, Dataset[T]]:
225226
'''
226227
Split dataset into multiple parts. Optionally you can choose to stratify
227228
on a column in the source dataframe or save the split to a json file.
228229
If you are sure that the split strategy will not change then you can
229-
safely use a seed instead of a filepath.
230+
safely use a seed instead of a save_directory.
230231
231232
Saved splits can continue from the old split and handles:
232233
@@ -252,14 +253,40 @@ def split(
252253
>>> split_datasets['test'][0]
253254
3
254255
'''
255-
if filepath is not None:
256-
filepath = Path(filepath)
257-
258-
if seed is None:
259-
split_dataframes = tools.split_dataframes
256+
if save_directory is not None:
257+
save_directory = Path(save_directory)
258+
save_directory.mkdir(parents=True, exist_ok=True)
259+
260+
if stratify_column is not None:
261+
return self._stratified_split(
262+
key_column=key_column,
263+
proportions=proportions,
264+
stratify_column=stratify_column,
265+
save_directory=save_directory,
266+
seed=seed,
267+
frozen=frozen,
268+
)
260269
else:
261-
split_dataframes = tools.numpy_seed(seed)(tools.split_dataframes)
270+
return self._unstratified_split(
271+
key_column=key_column,
272+
proportions=proportions,
273+
filepath=(
274+
save_directory / 'split.json'
275+
if save_directory is not None else None
276+
),
277+
seed=seed,
278+
frozen=frozen,
279+
)
262280

281+
def _unstratified_split(
282+
self,
283+
key_column: str,
284+
proportions: Dict[str, float],
285+
filepath: Optional[Path] = None,
286+
seed: Optional[int] = None,
287+
frozen: Optional[bool] = False,
288+
):
289+
split_dataframes = tools.numpy_seed(seed)(tools.split_dataframes)
263290
return {
264291
split_name: Dataset(
265292
dataframe=dataframe,
@@ -270,63 +297,53 @@ def split(
270297
self.dataframe,
271298
key_column,
272299
proportions,
273-
stratify_column,
274-
filepath,
275-
frozen,
300+
filepath=filepath,
301+
frozen=frozen,
276302
).items()
277303
}
278304

279-
def group_split(
305+
def _stratified_split(
280306
self,
281-
split_column: str,
307+
key_column: str,
282308
proportions: Dict[str, float],
283-
filepath: Optional[Union[str, Path]] = None,
284-
frozen: Optional[bool] = False,
309+
stratify_column: Optional[str] = None,
310+
save_directory: Optional[Path] = None,
285311
seed: Optional[int] = None,
286-
) -> Dict[str, Dataset[T]]:
287-
'''
288-
Similar to :func:`Dataset.split`, but uses a non-unique split column
289-
instead of a unique key column. This is useful for example when you
290-
have a dataset with examples that come from separate sources and you
291-
don't want to have examples from the same source in different splits.
292-
Does not support stratification.
293-
294-
>>> split_file = Path('doctest_split_dataset.json')
295-
>>> split_datasets = (
296-
... Dataset.from_dataframe(pd.DataFrame(dict(
297-
... source=np.arange(100) // 4,
298-
... number=np.random.randn(100),
299-
... )))
300-
... .group_split(
301-
... split_column='source',
302-
... proportions=dict(train=0.8, test=0.2),
303-
... filepath=split_file,
304-
... )
305-
... )
306-
>>> len(split_datasets['train'])
307-
80
308-
>>> split_file.unlink() # clean up after doctest
309-
'''
310-
if filepath is not None:
311-
filepath = Path(filepath)
312-
313-
split_dataframes = tools.group_split_dataframes
314-
if seed is not None:
315-
split_dataframes = tools.numpy_seed(seed)(split_dataframes)
316-
312+
frozen: Optional[bool] = False,
313+
):
314+
if (
315+
stratify_column is not None
316+
and any(self.dataframe[key_column].duplicated())
317+
):
318+
# mathematically impossible in the general case
319+
warnings.warn(
320+
'Trying to do stratified split with non-unique key column'
321+
' - cannot guarantee correct splitting of key values.'
322+
)
323+
strata = {
324+
stratum_value: self.subset(
325+
lambda df: df[stratify_column] == stratum_value
326+
)
327+
for stratum_value in self.dataframe[stratify_column].unique()
328+
}
329+
split_strata = [
330+
stratum._unstratified_split(
331+
key_column=key_column,
332+
proportions=proportions,
333+
filepath=(
334+
save_directory / f'{hash(stratum_value)}.json'
335+
if save_directory is not None else None
336+
),
337+
seed=seed,
338+
frozen=frozen,
339+
)
340+
for stratum_value, stratum in strata.items()
341+
]
317342
return {
318-
split_name: Dataset(
319-
dataframe=dataframe,
320-
length=len(dataframe),
321-
get_item=self.get_item,
343+
split_name: Dataset.concat(
344+
[split_stratum[split_name] for split_stratum in split_strata]
322345
)
323-
for split_name, dataframe in split_dataframes(
324-
self.dataframe,
325-
split_column,
326-
proportions,
327-
filepath,
328-
frozen,
329-
).items()
346+
for split_name in proportions.keys()
330347
}
331348

332349
def with_columns(
@@ -672,13 +689,14 @@ def test_combine_dataset():
672689

673690

674691
def test_split_dataset():
692+
import shutil
675693
dataset = Dataset.from_dataframe(pd.DataFrame(dict(
676694
index=np.arange(100),
677695
number=np.random.randn(100),
678696
stratify=np.concatenate([np.ones(50), np.zeros(50)]),
679697
))).map(tuple)
680698

681-
split_file = Path('test_split_dataset.json')
699+
save_directory = Path('test_split_dataset')
682700
proportions = dict(
683701
gradient=0.7,
684702
early_stopping=0.15,
@@ -688,7 +706,7 @@ def test_split_dataset():
688706
kwargs = dict(
689707
key_column='index',
690708
proportions=proportions,
691-
filepath=split_file,
709+
save_directory=save_directory,
692710
stratify_column='stratify',
693711
)
694712

@@ -712,8 +730,7 @@ def test_split_dataset():
712730
stratify_column='stratify',
713731
seed=800,
714732
)
715-
716-
split_file.unlink()
733+
shutil.rmtree(save_directory)
717734

718735
assert split_datasets1 == split_datasets2
719736
assert split_datasets1 != split_datasets3
@@ -722,45 +739,128 @@ def test_split_dataset():
722739

723740

724741
def test_group_split_dataset():
742+
import shutil
725743
dataset = Dataset.from_dataframe(pd.DataFrame(dict(
726744
group=np.arange(100) // 4,
727745
number=np.random.randn(100),
728746
))).map(tuple)
729747

730-
split_file = Path('test_split_dataset.json')
748+
save_directory = Path('test_split_dataset')
731749
proportions = dict(
732750
gradient=0.7,
733751
early_stopping=0.15,
734752
compare=0.15,
735753
)
736754

737755
kwargs = dict(
738-
split_column='group',
756+
key_column='group',
739757
proportions=proportions,
740-
filepath=split_file,
758+
save_directory=save_directory,
741759
)
742760

743-
split_datasets1 = dataset.group_split(**kwargs)
744-
split_datasets2 = dataset.group_split(**kwargs)
745-
split_datasets3 = dataset.group_split(
746-
split_column='group',
761+
split_datasets1 = dataset.split(**kwargs)
762+
split_datasets2 = dataset.split(**kwargs)
763+
split_datasets3 = dataset.split(
764+
key_column='group',
747765
proportions=proportions,
748766
seed=100,
749767
)
750-
split_datasets4 = dataset.group_split(
751-
split_column='group',
768+
split_datasets4 = dataset.split(
769+
key_column='group',
752770
proportions=proportions,
753771
seed=100,
754772
)
755-
split_datasets5 = dataset.group_split(
756-
split_column='group',
773+
split_datasets5 = dataset.split(
774+
key_column='group',
757775
proportions=proportions,
758776
seed=800,
759777
)
760778

761-
split_file.unlink()
779+
shutil.rmtree(save_directory)
762780

763781
assert split_datasets1 == split_datasets2
764782
assert split_datasets1 != split_datasets3
765783
assert split_datasets3 == split_datasets4
766784
assert split_datasets3 != split_datasets5
785+
786+
787+
def test_missing_stratify_column():
788+
from pytest import raises
789+
790+
dataset = Dataset.from_dataframe(pd.DataFrame(dict(
791+
index=np.arange(100),
792+
number=np.random.randn(100),
793+
))).map(tuple)
794+
795+
with raises(KeyError):
796+
dataset.split(
797+
key_column='index',
798+
proportions=dict(train=0.8, test=0.2),
799+
stratify_column='should_fail',
800+
)
801+
802+
803+
def test_split_proportions():
804+
dataset = Dataset.from_dataframe(pd.DataFrame(dict(
805+
index=np.arange(100),
806+
number=np.random.randn(100),
807+
stratify=np.arange(100) // 10,
808+
))).map(tuple)
809+
810+
splits = dataset.split(
811+
key_column='index',
812+
proportions=dict(train=0.8, test=0.2),
813+
stratify_column='stratify',
814+
)
815+
816+
assert len(splits['train']) == 80
817+
assert len(splits['test']) == 20
818+
819+
820+
def test_with_columns_split():
821+
dataset = (
822+
Dataset.from_dataframe(pd.DataFrame(dict(
823+
index=np.arange(100),
824+
number=np.arange(100),
825+
)))
826+
.map(tuple)
827+
.with_columns(split=lambda df: df['index'] * 2)
828+
)
829+
830+
splits = dataset.split(
831+
key_column='index',
832+
proportions=dict(train=0.8, test=0.2),
833+
)
834+
835+
assert splits['train'][0][0] * 2 == splits['train'][0][2]
836+
837+
838+
def test_split_save_directory():
839+
import shutil
840+
841+
dataset = (
842+
Dataset.from_dataframe(pd.DataFrame(dict(
843+
index=np.arange(100),
844+
number=np.random.randn(100),
845+
stratify=np.arange(100) // 10,
846+
)))
847+
.map(tuple)
848+
)
849+
850+
save_directory = Path('tmp_test_directory')
851+
splits1 = dataset.split(
852+
key_column='index',
853+
proportions=dict(train=0.8, test=0.2),
854+
save_directory=save_directory,
855+
)
856+
857+
splits2 = dataset.split(
858+
key_column='index',
859+
proportions=dict(train=0.8, test=0.2),
860+
save_directory=save_directory,
861+
)
862+
863+
assert splits1['train'][0] == splits2['train'][0]
864+
assert splits1['test'][0] == splits2['test'][0]
865+
866+
shutil.rmtree(save_directory)

datastream/tools/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,4 @@
22
from datastream.tools.starcompose import starcompose
33
from datastream.tools.repeat_map_chain import repeat_map_chain
44
from datastream.tools.numpy_seed import numpy_seed
5-
from datastream.tools.split_dataframes import (
6-
split_dataframes, group_split_dataframes
7-
)
5+
from datastream.tools.split_dataframes import split_dataframes

0 commit comments

Comments
 (0)