Skip to content

Commit 43473b4

Browse files
author
FelixAbrahamsson
committed
improve!: filepath instead of save_directory
BREAKING CHANGE: save_directory is no longer an argument to dataset.split
1 parent cc1ca05 commit 43473b4

File tree

2 files changed

+70
-29
lines changed

2 files changed

+70
-29
lines changed

datastream/dataset.py

Lines changed: 68 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -223,15 +223,15 @@ def split(
223223
key_column: str,
224224
proportions: Dict[str, float],
225225
stratify_column: Optional[str] = None,
226-
save_directory: Optional[Union[str, Path]] = None,
226+
filepath: Optional[Union[str, Path]] = None,
227227
frozen: Optional[bool] = False,
228228
seed: Optional[int] = None,
229229
) -> Dict[str, Dataset[T]]:
230230
'''
231231
Split dataset into multiple parts. Optionally you can choose to stratify
232232
on a column in the source dataframe or save the split to a json file.
233233
If you are sure that the split strategy will not change then you can
234-
safely use a seed instead of a save_directory.
234+
safely use a seed instead of a filepath.
235235
236236
Saved splits can continue from the old split and handle:
237237
@@ -257,17 +257,17 @@ def split(
257257
>>> split_datasets['test'][0]
258258
3
259259
'''
260-
if save_directory is not None:
261-
save_directory = Path(save_directory)
262-
save_directory.mkdir(parents=True, exist_ok=True)
260+
if filepath is not None:
261+
filepath = Path(filepath)
262+
filepath.parent.mkdir(parents=True, exist_ok=True)
263263

264264
if stratify_column is not None:
265265
return tools.stratified_split(
266266
self,
267267
key_column=key_column,
268268
proportions=proportions,
269269
stratify_column=stratify_column,
270-
save_directory=save_directory,
270+
filepath=filepath,
271271
seed=seed,
272272
frozen=frozen,
273273
)
@@ -276,10 +276,7 @@ def split(
276276
self,
277277
key_column=key_column,
278278
proportions=proportions,
279-
filepath=(
280-
save_directory / 'split.json'
281-
if save_directory is not None else None
282-
),
279+
filepath=filepath,
283280
seed=seed,
284281
frozen=frozen,
285282
)
@@ -627,14 +624,13 @@ def test_combine_dataset():
627624

628625

629626
def test_split_dataset():
630-
import shutil
631627
dataset = Dataset.from_dataframe(pd.DataFrame(dict(
632628
index=np.arange(100),
633629
number=np.random.randn(100),
634630
stratify=np.concatenate([np.ones(50), np.zeros(50)]),
635631
))).map(tuple)
636632

637-
save_directory = Path('test_split_dataset')
633+
filepath = Path('test_split_dataset.json')
638634
proportions = dict(
639635
gradient=0.7,
640636
early_stopping=0.15,
@@ -644,7 +640,7 @@ def test_split_dataset():
644640
kwargs = dict(
645641
key_column='index',
646642
proportions=proportions,
647-
save_directory=save_directory,
643+
filepath=filepath,
648644
stratify_column='stratify',
649645
)
650646

@@ -668,7 +664,7 @@ def test_split_dataset():
668664
stratify_column='stratify',
669665
seed=800,
670666
)
671-
shutil.rmtree(save_directory)
667+
filepath.unlink()
672668

673669
assert split_datasets1 == split_datasets2
674670
assert split_datasets1 != split_datasets3
@@ -677,13 +673,12 @@ def test_split_dataset():
677673

678674

679675
def test_group_split_dataset():
680-
import shutil
681676
dataset = Dataset.from_dataframe(pd.DataFrame(dict(
682677
group=np.arange(100) // 4,
683678
number=np.random.randn(100),
684679
))).map(tuple)
685680

686-
save_directory = Path('test_split_dataset')
681+
filepath = Path('test_split_dataset.json')
687682
proportions = dict(
688683
gradient=0.7,
689684
early_stopping=0.15,
@@ -693,7 +688,7 @@ def test_group_split_dataset():
693688
kwargs = dict(
694689
key_column='group',
695690
proportions=proportions,
696-
save_directory=save_directory,
691+
filepath=filepath,
697692
)
698693

699694
split_datasets1 = dataset.split(**kwargs)
@@ -714,7 +709,7 @@ def test_group_split_dataset():
714709
seed=800,
715710
)
716711

717-
shutil.rmtree(save_directory)
712+
filepath.unlink()
718713

719714
assert split_datasets1 == split_datasets2
720715
assert split_datasets1 != split_datasets3
@@ -773,8 +768,7 @@ def test_with_columns_split():
773768
assert splits['train'][0][0] * 2 == splits['train'][0][2]
774769

775770

776-
def test_split_save_directory():
777-
import shutil
771+
def test_split_filepath():
778772

779773
dataset = (
780774
Dataset.from_dataframe(pd.DataFrame(dict(
@@ -785,20 +779,70 @@ def test_split_save_directory():
785779
.map(tuple)
786780
)
787781

788-
save_directory = Path('tmp_test_directory')
782+
filepath = Path('tmp_test_split.json')
789783
splits1 = dataset.split(
790784
key_column='index',
791785
proportions=dict(train=0.8, test=0.2),
792-
save_directory=save_directory,
786+
filepath=filepath,
793787
)
794788

795789
splits2 = dataset.split(
796790
key_column='index',
797791
proportions=dict(train=0.8, test=0.2),
798-
save_directory=save_directory,
792+
filepath=filepath,
799793
)
800794

801795
assert splits1['train'][0] == splits2['train'][0]
802796
assert splits1['test'][0] == splits2['test'][0]
803797

804-
shutil.rmtree(save_directory)
798+
filepath.unlink()
799+
800+
801+
def test_update_stratified_split():
    '''
    Check that a saved stratified split is continued rather than reshuffled.

    A subset of the dataset is split first and saved to ``filepath``. The
    full dataset is then split against the same file using a different
    stratify column. Every key assigned to a part in the first split must
    end up in the same part in the second split.
    '''
    dataset = (
        Dataset.from_dataframe(pd.DataFrame(dict(
            index=np.arange(100),
            number=np.random.randn(100),
            stratify1=np.random.randint(0, 10, 100),
            stratify2=np.random.randint(0, 10, 100),
        )))
        .map(tuple)
    )

    filepath = Path('tmp_test_split.json')
    proportions = dict(train=0.8, test=0.2)

    try:
        splits1 = (
            dataset
            .subset(lambda df: df['index'] < 50)
            .split(
                key_column='index',
                proportions=proportions,
                filepath=filepath,
                stratify_column='stratify1',
            )
        )

        splits2 = (
            dataset
            .split(
                key_column='index',
                proportions=proportions,
                filepath=filepath,
                stratify_column='stratify2',
            )
        )
    finally:
        # Clean up even when a split raises, so a stale split file cannot
        # leak into other tests that use the same filename.
        if filepath.exists():
            filepath.unlink()

    # Iterate over the configured parts instead of hard-coding names; the
    # original looked up a 'compare' part that is not in the proportions.
    for split_name in proportions:
        assert (
            splits1[split_name].dataframe['index']
            .isin(splits2[split_name].dataframe['index'])
            .all()
        )

datastream/tools/stratified_split.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def stratified_split(
99
key_column: str,
1010
proportions: Dict[str, float],
1111
stratify_column: Optional[str] = None,
12-
save_directory: Optional[Path] = None,
12+
filepath: Optional[Path] = None,
1313
seed: Optional[int] = None,
1414
frozen: Optional[bool] = False,
1515
):
@@ -33,10 +33,7 @@ def stratified_split(
3333
stratum,
3434
key_column=key_column,
3535
proportions=proportions,
36-
filepath=(
37-
save_directory / f'{hash(stratum_value)}.json'
38-
if save_directory is not None else None
39-
),
36+
filepath=filepath,
4037
seed=seed,
4138
frozen=frozen,
4239
)

0 commit comments

Comments
 (0)