5
5
)
6
6
from pathlib import Path
7
7
from functools import lru_cache
8
+ import warnings
8
9
import textwrap
9
10
import inspect
10
11
import numpy as np
@@ -218,15 +219,15 @@ def split(
218
219
key_column : str ,
219
220
proportions : Dict [str , float ],
220
221
stratify_column : Optional [str ] = None ,
221
- filepath : Optional [Union [str , Path ]] = None ,
222
+ save_directory : Optional [Union [str , Path ]] = None ,
222
223
frozen : Optional [bool ] = False ,
223
224
seed : Optional [int ] = None ,
224
225
) -> Dict [str , Dataset [T ]]:
225
226
'''
226
227
Split dataset into multiple parts. Optionally you can choose to stratify
227
228
on a column in the source dataframe or save the split to a json file.
228
229
If you are sure that the split strategy will not change then you can
229
- safely use a seed instead of a filepath .
230
+ safely use a seed instead of a save_directory .
230
231
231
232
Saved splits can continue from the old split and handle:
232
233
@@ -252,14 +253,40 @@ def split(
252
253
>>> split_datasets['test'][0]
253
254
3
254
255
'''
255
- if filepath is not None :
256
- filepath = Path (filepath )
257
-
258
- if seed is None :
259
- split_dataframes = tools .split_dataframes
256
+ if save_directory is not None :
257
+ save_directory = Path (save_directory )
258
+ save_directory .mkdir (parents = True , exist_ok = True )
259
+
260
+ if stratify_column is not None :
261
+ return self ._stratified_split (
262
+ key_column = key_column ,
263
+ proportions = proportions ,
264
+ stratify_column = stratify_column ,
265
+ save_directory = save_directory ,
266
+ seed = seed ,
267
+ frozen = frozen ,
268
+ )
260
269
else :
261
- split_dataframes = tools .numpy_seed (seed )(tools .split_dataframes )
270
+ return self ._unstratified_split (
271
+ key_column = key_column ,
272
+ proportions = proportions ,
273
+ filepath = (
274
+ save_directory / 'split.json'
275
+ if save_directory is not None else None
276
+ ),
277
+ seed = seed ,
278
+ frozen = frozen ,
279
+ )
262
280
281
+ def _unstratified_split (
282
+ self ,
283
+ key_column : str ,
284
+ proportions : Dict [str , float ],
285
+ filepath : Optional [Path ] = None ,
286
+ seed : Optional [int ] = None ,
287
+ frozen : Optional [bool ] = False ,
288
+ ):
289
+ split_dataframes = tools .numpy_seed (seed )(tools .split_dataframes )
263
290
return {
264
291
split_name : Dataset (
265
292
dataframe = dataframe ,
@@ -270,63 +297,53 @@ def split(
270
297
self .dataframe ,
271
298
key_column ,
272
299
proportions ,
273
- stratify_column ,
274
- filepath ,
275
- frozen ,
300
+ filepath = filepath ,
301
+ frozen = frozen ,
276
302
).items ()
277
303
}
278
304
279
- def group_split (
305
def _stratified_split(
    self,
    key_column: str,
    proportions: Dict[str, float],
    stratify_column: Optional[str] = None,
    save_directory: Optional[Path] = None,
    seed: Optional[int] = None,
    frozen: Optional[bool] = False,
):
    '''
    Split every stratum on its own and concatenate the parts per split name.

    The dataset is partitioned by the distinct values of
    ``stratify_column``; each partition is split with
    ``_unstratified_split`` using the same ``key_column``, ``proportions``,
    ``seed`` and ``frozen``, and the resulting datasets are joined with
    ``Dataset.concat``.

    When ``save_directory`` is given, each stratum gets its own json file
    in that directory. The file name is derived from a SHA-256 digest of
    ``str(stratum_value)`` so that it is identical between interpreter
    runs. The builtin ``hash`` is salted per process for ``str`` values
    (and maps ``-1`` and ``-2`` to the same number), which made saved
    splits for such strata impossible to find again. A file that already
    exists under the old ``hash``-based name is still picked up.

    Returns a dict mapping every key of ``proportions`` to the
    concatenated dataset for that split.

    Emits a warning when ``key_column`` contains duplicated values.
    '''
    import hashlib

    if (
        stratify_column is not None
        and self.dataframe[key_column].duplicated().any()
    ):
        # mathematically impossible in the general case
        warnings.warn(
            'Trying to do stratified split with non-unique key column'
            ' - cannot guarantee correct splitting of key values.'
        )

    def is_in_stratum(stratum_value):
        # A factory gives every predicate its own binding of the value, so
        # the result does not depend on when ``subset`` calls it.
        return lambda df: df[stratify_column] == stratum_value

    def stratum_filepath(stratum_value):
        if save_directory is None:
            return None
        # Keep using files written under the previous naming scheme.
        legacy_filepath = save_directory / f'{hash(stratum_value)}.json'
        if legacy_filepath.exists():
            return legacy_filepath
        digest = hashlib.sha256(
            str(stratum_value).encode('utf-8')
        ).hexdigest()
        return save_directory / f'{digest}.json'

    strata = {
        stratum_value: self.subset(is_in_stratum(stratum_value))
        for stratum_value in self.dataframe[stratify_column].unique()
    }
    split_strata = [
        stratum._unstratified_split(
            key_column=key_column,
            proportions=proportions,
            filepath=stratum_filepath(stratum_value),
            seed=seed,
            frozen=frozen,
        )
        for stratum_value, stratum in strata.items()
    ]
    return {
        split_name: Dataset.concat(
            [split_stratum[split_name] for split_stratum in split_strata]
        )
        for split_name in proportions
    }
331
348
332
349
def with_columns (
@@ -672,13 +689,14 @@ def test_combine_dataset():
672
689
673
690
674
691
def test_split_dataset ():
692
+ import shutil
675
693
dataset = Dataset .from_dataframe (pd .DataFrame (dict (
676
694
index = np .arange (100 ),
677
695
number = np .random .randn (100 ),
678
696
stratify = np .concatenate ([np .ones (50 ), np .zeros (50 )]),
679
697
))).map (tuple )
680
698
681
- split_file = Path ('test_split_dataset.json ' )
699
+ save_directory = Path ('test_split_dataset' )
682
700
proportions = dict (
683
701
gradient = 0.7 ,
684
702
early_stopping = 0.15 ,
@@ -688,7 +706,7 @@ def test_split_dataset():
688
706
kwargs = dict (
689
707
key_column = 'index' ,
690
708
proportions = proportions ,
691
- filepath = split_file ,
709
+ save_directory = save_directory ,
692
710
stratify_column = 'stratify' ,
693
711
)
694
712
@@ -712,8 +730,7 @@ def test_split_dataset():
712
730
stratify_column = 'stratify' ,
713
731
seed = 800 ,
714
732
)
715
-
716
- split_file .unlink ()
733
+ shutil .rmtree (save_directory )
717
734
718
735
assert split_datasets1 == split_datasets2
719
736
assert split_datasets1 != split_datasets3
@@ -722,45 +739,128 @@ def test_split_dataset():
722
739
723
740
724
741
def test_group_split_dataset():
    '''
    Splitting on a non-unique key column is reproducible.

    Two splits that share a save directory must be equal, two splits with
    the same seed must be equal, and splits with different seeds (or a
    saved split versus a seeded one) must differ.
    '''
    import tempfile

    dataset = Dataset.from_dataframe(pd.DataFrame(dict(
        group=np.arange(100) // 4,
        number=np.random.randn(100),
    ))).map(tuple)

    proportions = dict(
        gradient=0.7,
        early_stopping=0.15,
        compare=0.15,
    )

    # A private temporary directory cannot clash with other tests and is
    # removed even when one of the split calls raises.
    with tempfile.TemporaryDirectory() as directory:
        kwargs = dict(
            key_column='group',
            proportions=proportions,
            save_directory=Path(directory),
        )

        split_datasets1 = dataset.split(**kwargs)
        split_datasets2 = dataset.split(**kwargs)

    split_datasets3 = dataset.split(
        key_column='group',
        proportions=proportions,
        seed=100,
    )
    split_datasets4 = dataset.split(
        key_column='group',
        proportions=proportions,
        seed=100,
    )
    split_datasets5 = dataset.split(
        key_column='group',
        proportions=proportions,
        seed=800,
    )

    assert split_datasets1 == split_datasets2
    assert split_datasets1 != split_datasets3
    assert split_datasets3 == split_datasets4
    assert split_datasets3 != split_datasets5
785
+
786
+
787
def test_missing_stratify_column():
    '''Stratifying on a column that is not in the dataframe raises KeyError.'''
    from pytest import raises

    frame = pd.DataFrame(dict(
        index=np.arange(100),
        number=np.random.randn(100),
    ))
    dataset = Dataset.from_dataframe(frame).map(tuple)

    split_arguments = dict(
        key_column='index',
        proportions=dict(train=0.8, test=0.2),
        stratify_column='should_fail',
    )

    with raises(KeyError):
        dataset.split(**split_arguments)
801
+
802
+
803
def test_split_proportions():
    '''A stratified 80/20 split of 100 rows yields 80 and 20 examples.'''
    frame = pd.DataFrame(dict(
        index=np.arange(100),
        number=np.random.randn(100),
        stratify=np.arange(100) // 10,
    ))
    dataset = Dataset.from_dataframe(frame).map(tuple)

    proportions = dict(train=0.8, test=0.2)
    splits = dataset.split(
        key_column='index',
        proportions=proportions,
        stratify_column='stratify',
    )

    train_length = len(splits['train'])
    assert train_length == 80
    test_length = len(splits['test'])
    assert test_length == 20
818
+
819
+
820
def test_with_columns_split():
    '''A column added with ``with_columns`` is still present after a split.'''
    frame = pd.DataFrame(dict(
        index=np.arange(100),
        number=np.arange(100),
    ))
    dataset = (
        Dataset.from_dataframe(frame)
        .map(tuple)
        .with_columns(split=lambda df: df['index'] * 2)
    )

    proportions = dict(train=0.8, test=0.2)
    splits = dataset.split(key_column='index', proportions=proportions)

    # Position 0 holds ``index`` and position 2 the derived ``split`` column.
    doubled_index = splits['train'][0][0] * 2
    assert doubled_index == splits['train'][0][2]
836
+
837
+
838
def test_split_save_directory():
    '''
    Two splits that share a save directory start with the same examples.

    The first call writes the split into the directory and the second call
    is made with the same arguments against that directory.
    '''
    import tempfile

    dataset = (
        Dataset.from_dataframe(pd.DataFrame(dict(
            index=np.arange(100),
            number=np.random.randn(100),
            stratify=np.arange(100) // 10,
        )))
        .map(tuple)
    )

    # The temporary directory is removed even when an assertion fails, so a
    # failed run cannot leave a stale split behind for the next one.
    with tempfile.TemporaryDirectory() as directory:
        save_directory = Path(directory)

        splits1 = dataset.split(
            key_column='index',
            proportions=dict(train=0.8, test=0.2),
            save_directory=save_directory,
        )

        splits2 = dataset.split(
            key_column='index',
            proportions=dict(train=0.8, test=0.2),
            save_directory=save_directory,
        )

        assert splits1['train'][0] == splits2['train'][0]
        assert splits1['test'][0] == splits2['test'][0]
0 commit comments