@@ -223,15 +223,15 @@ def split(
223
223
key_column : str ,
224
224
proportions : Dict [str , float ],
225
225
stratify_column : Optional [str ] = None ,
226
- save_directory : Optional [Union [str , Path ]] = None ,
226
+ filepath : Optional [Union [str , Path ]] = None ,
227
227
frozen : Optional [bool ] = False ,
228
228
seed : Optional [int ] = None ,
229
229
) -> Dict [str , Dataset [T ]]:
230
230
'''
231
231
Split dataset into multiple parts. Optionally you can chose to stratify
232
232
on a column in the source dataframe or save the split to a json file.
233
233
If you are sure that the split strategy will not change then you can
234
- safely use a seed instead of a save_directory .
234
+ safely use a seed instead of a filepath .
235
235
236
236
Saved splits can continue from the old split and handles:
237
237
@@ -257,17 +257,17 @@ def split(
257
257
>>> split_datasets['test'][0]
258
258
3
259
259
'''
260
- if save_directory is not None :
261
- save_directory = Path (save_directory )
262
- save_directory .mkdir (parents = True , exist_ok = True )
260
+ if filepath is not None :
261
+ filepath = Path (filepath )
262
+ filepath . parent .mkdir (parents = True , exist_ok = True )
263
263
264
264
if stratify_column is not None :
265
265
return tools .stratified_split (
266
266
self ,
267
267
key_column = key_column ,
268
268
proportions = proportions ,
269
269
stratify_column = stratify_column ,
270
- save_directory = save_directory ,
270
+ filepath = filepath ,
271
271
seed = seed ,
272
272
frozen = frozen ,
273
273
)
@@ -276,10 +276,7 @@ def split(
276
276
self ,
277
277
key_column = key_column ,
278
278
proportions = proportions ,
279
- filepath = (
280
- save_directory / 'split.json'
281
- if save_directory is not None else None
282
- ),
279
+ filepath = filepath ,
283
280
seed = seed ,
284
281
frozen = frozen ,
285
282
)
@@ -627,14 +624,13 @@ def test_combine_dataset():
627
624
628
625
629
626
def test_split_dataset ():
630
- import shutil
631
627
dataset = Dataset .from_dataframe (pd .DataFrame (dict (
632
628
index = np .arange (100 ),
633
629
number = np .random .randn (100 ),
634
630
stratify = np .concatenate ([np .ones (50 ), np .zeros (50 )]),
635
631
))).map (tuple )
636
632
637
- save_directory = Path ('test_split_dataset' )
633
+ filepath = Path ('test_split_dataset.json ' )
638
634
proportions = dict (
639
635
gradient = 0.7 ,
640
636
early_stopping = 0.15 ,
@@ -644,7 +640,7 @@ def test_split_dataset():
644
640
kwargs = dict (
645
641
key_column = 'index' ,
646
642
proportions = proportions ,
647
- save_directory = save_directory ,
643
+ filepath = filepath ,
648
644
stratify_column = 'stratify' ,
649
645
)
650
646
@@ -668,7 +664,7 @@ def test_split_dataset():
668
664
stratify_column = 'stratify' ,
669
665
seed = 800 ,
670
666
)
671
- shutil . rmtree ( save_directory )
667
+ filepath . unlink ( )
672
668
673
669
assert split_datasets1 == split_datasets2
674
670
assert split_datasets1 != split_datasets3
@@ -677,13 +673,12 @@ def test_split_dataset():
677
673
678
674
679
675
def test_group_split_dataset ():
680
- import shutil
681
676
dataset = Dataset .from_dataframe (pd .DataFrame (dict (
682
677
group = np .arange (100 ) // 4 ,
683
678
number = np .random .randn (100 ),
684
679
))).map (tuple )
685
680
686
- save_directory = Path ('test_split_dataset' )
681
+ filepath = Path ('test_split_dataset.json ' )
687
682
proportions = dict (
688
683
gradient = 0.7 ,
689
684
early_stopping = 0.15 ,
@@ -693,7 +688,7 @@ def test_group_split_dataset():
693
688
kwargs = dict (
694
689
key_column = 'group' ,
695
690
proportions = proportions ,
696
- save_directory = save_directory ,
691
+ filepath = filepath ,
697
692
)
698
693
699
694
split_datasets1 = dataset .split (** kwargs )
@@ -714,7 +709,7 @@ def test_group_split_dataset():
714
709
seed = 800 ,
715
710
)
716
711
717
- shutil . rmtree ( save_directory )
712
+ filepath . unlink ( )
718
713
719
714
assert split_datasets1 == split_datasets2
720
715
assert split_datasets1 != split_datasets3
@@ -773,8 +768,7 @@ def test_with_columns_split():
773
768
assert splits ['train' ][0 ][0 ] * 2 == splits ['train' ][0 ][2 ]
774
769
775
770
776
- def test_split_save_directory ():
777
- import shutil
771
+ def test_split_filepath ():
778
772
779
773
dataset = (
780
774
Dataset .from_dataframe (pd .DataFrame (dict (
@@ -785,20 +779,70 @@ def test_split_save_directory():
785
779
.map (tuple )
786
780
)
787
781
788
- save_directory = Path ('tmp_test_directory ' )
782
+ filepath = Path ('tmp_test_split.json ' )
789
783
splits1 = dataset .split (
790
784
key_column = 'index' ,
791
785
proportions = dict (train = 0.8 , test = 0.2 ),
792
- save_directory = save_directory ,
786
+ filepath = filepath ,
793
787
)
794
788
795
789
splits2 = dataset .split (
796
790
key_column = 'index' ,
797
791
proportions = dict (train = 0.8 , test = 0.2 ),
798
- save_directory = save_directory ,
792
+ filepath = filepath ,
799
793
)
800
794
801
795
assert splits1 ['train' ][0 ] == splits2 ['train' ][0 ]
802
796
assert splits1 ['test' ][0 ] == splits2 ['test' ][0 ]
803
797
804
- shutil .rmtree (save_directory )
798
+ filepath .unlink ()
799
+
800
+
801
+ def test_update_stratified_split ():
802
+
803
+ dataset = (
804
+ Dataset .from_dataframe (pd .DataFrame (dict (
805
+ index = np .arange (100 ),
806
+ number = np .random .randn (100 ),
807
+ stratify1 = np .random .randint (0 , 10 , 100 ),
808
+ stratify2 = np .random .randint (0 , 10 , 100 ),
809
+ )))
810
+ .map (tuple )
811
+ )
812
+
813
+ filepath = Path ('tmp_test_split.json' )
814
+
815
+ splits1 = (
816
+ dataset
817
+ .subset (lambda df : df ['index' ] < 50 )
818
+ .split (
819
+ key_column = 'index' ,
820
+ proportions = dict (train = 0.8 , test = 0.2 ),
821
+ filepath = filepath ,
822
+ stratify_column = 'stratify1' ,
823
+ )
824
+ )
825
+
826
+ splits2 = (
827
+ dataset
828
+ .split (
829
+ key_column = 'index' ,
830
+ proportions = dict (train = 0.8 , test = 0.2 ),
831
+ filepath = filepath ,
832
+ stratify_column = 'stratify2' ,
833
+ )
834
+ )
835
+
836
+ assert (
837
+ splits1 ['train' ].dataframe ['index' ]
838
+ .isin (splits2 ['train' ].dataframe ['index' ])
839
+ .all ()
840
+ )
841
+
842
+ assert (
843
+ splits1 ['compare' ].dataframe ['index' ]
844
+ .isin (splits2 ['compare' ].dataframe ['index' ])
845
+ .all ()
846
+ )
847
+
848
+ filepath .unlink ()
0 commit comments