@@ -343,6 +343,7 @@ def split_train_other(
343
343
]= 'mixed-set' ,
344
344
ratio : tuple [int , int , int ]= (8 ,2 ),
345
345
stratify_by : Optional [str ]= None ,
346
+ balance : bool = False ,
346
347
random_state : Optional [Union [int ,RandomState ]]= None ,
347
348
** kwargs : dict ,
348
349
) -> TwoWaySplit :
@@ -352,6 +353,7 @@ def split_train_other(
352
353
split_type = split_type ,
353
354
ration = ratio ,
354
355
stratify_by = stratify_by ,
356
+ balance = balance ,
355
357
random_state = random_state ,
356
358
** kwargs
357
359
)
@@ -366,15 +368,16 @@ def split_train_test_validate(
366
368
]= 'mixed-set' ,
367
369
ratio : tuple [int , int , int ]= (8 ,1 ,1 ),
368
370
stratify_by : Optional [str ]= None ,
371
+ balance : bool = False ,
369
372
random_state : Optional [Union [int ,RandomState ]]= None ,
370
373
** kwargs : dict ,
371
374
) -> Split :
372
-
373
375
split = split_train_test_validate (
374
376
data = self ,
375
377
split_type = split_type ,
376
378
ratio = ratio ,
377
379
stratify_by = stratify_by ,
380
+ balance = balance ,
378
381
random_state = random_state ,
379
382
** kwargs
380
383
)
@@ -389,6 +392,7 @@ def train_test_validate(
389
392
]= 'mixed-set' ,
390
393
ratio : tuple [int , int , int ]= (8 ,1 ,1 ),
391
394
stratify_by : Optional [str ]= None ,
395
+ balance : bool = False ,
392
396
random_state : Optional [Union [int ,RandomState ]]= None ,
393
397
** kwargs : dict ,
394
398
) -> Split :
@@ -398,6 +402,7 @@ def train_test_validate(
398
402
split_type = split_type ,
399
403
ratio = ratio ,
400
404
stratify_by = stratify_by ,
405
+ balance = balance ,
401
406
random_state = random_state ,
402
407
** kwargs
403
408
)
@@ -725,6 +730,7 @@ def split_train_other(
725
730
]= 'mixed-set' ,
726
731
ratio : tuple [int , int , int ]= (8 ,2 ),
727
732
stratify_by : Optional [str ]= None ,
733
+ balance : bool = False ,
728
734
random_state : Optional [Union [int ,RandomState ]]= None ,
729
735
** kwargs : dict ,
730
736
):
@@ -733,8 +739,9 @@ def split_train_other(
733
739
split_type ,
734
740
ratio ,
735
741
stratify_by ,
742
+ balance ,
736
743
random_state ,
737
- kwargs = kwargs
744
+ ** kwargs
738
745
)
739
746
if stratify_by is not None :
740
747
train .experiments = train .experiments [train .experiments ['dose_response_metric' ] != 'split_class' ]
@@ -749,10 +756,10 @@ def split_train_test_validate(
749
756
]= 'mixed-set' ,
750
757
ratio : tuple [int , int , int ]= (8 ,1 ,1 ),
751
758
stratify_by : Optional [str ]= None ,
759
+ balance : bool = False ,
752
760
random_state : Optional [Union [int ,RandomState ]]= None ,
753
761
** kwargs : dict ,
754
762
) -> Split :
755
-
756
763
# Type checking split_type
757
764
if split_type not in [
758
765
'mixed-set' , 'drug-blind' , 'cancer-blind'
@@ -766,17 +773,19 @@ def split_train_test_validate(
766
773
split_type = split_type ,
767
774
ratio = [ratio [0 ], ratio [1 ] + ratio [2 ]],
768
775
stratify_by = stratify_by ,
776
+ balance = balance ,
769
777
random_state = random_state ,
770
- kwargs = kwargs ,
778
+ ** kwargs ,
771
779
)
772
780
773
781
test , val = _split_two_way (
774
782
data = other ,
775
783
split_type = split_type ,
776
784
ratio = [ratio [1 ], ratio [2 ]],
777
785
stratify_by = stratify_by ,
786
+ balance = balance ,
778
787
random_state = random_state ,
779
- kwargs = kwargs ,
788
+ ** kwargs ,
780
789
)
781
790
782
791
if stratify_by is not None :
@@ -794,6 +803,7 @@ def train_test_validate(
794
803
]= 'mixed-set' ,
795
804
ratio : tuple [int , int , int ]= (8 ,1 ,1 ),
796
805
stratify_by : Optional [str ]= None ,
806
+ balance : bool = False ,
797
807
random_state : Optional [Union [int ,RandomState ]]= None ,
798
808
** kwargs : dict ,
799
809
) -> Split :
@@ -871,8 +881,9 @@ def train_test_validate(
871
881
split_type = split_type ,
872
882
ratio = ratio ,
873
883
stratify_by = stratify_by ,
884
+ balance = balance ,
874
885
random_state = random_state ,
875
- kwargs = kwargs ,
886
+ ** kwargs ,
876
887
)
877
888
878
889
@@ -976,6 +987,20 @@ def _filter(data: Dataset, split: pd.DataFrame) -> Dataset:
976
987
977
988
return data_ret
978
989
990
+ def _balance_data (
991
+ data : pd .Dataframe ,
992
+ random_state : Optional [Union [int ,RandomState ]]= None ,
993
+ # oversample: bool=False,
994
+ ) -> pd .Dataframe :
995
+ tmp = deepcopy (data )
996
+ counts = tmp .value_counts ('split_class' )
997
+ ret_df = (
998
+ tmp
999
+ .groupby ('split_class' )
1000
+ .sample (n = min (counts ), random_state = random_state )
1001
+ )
1002
+ return ret_df
1003
+
979
1004
980
1005
def _create_classes (
981
1006
data : pd .DataFrame ,
@@ -1072,6 +1097,7 @@ def _split_two_way(
1072
1097
'mixed-set' , 'drug-blind' , 'cancer-blind'
1073
1098
]= 'mixed-set' ,
1074
1099
ratio : tuple [int , int , int ]= (8 ,2 ),
1100
+ balance : bool = False ,
1075
1101
stratify_by : Optional [str ]= None ,
1076
1102
random_state : Optional [Union [int ,RandomState ]]= None ,
1077
1103
** kwargs : dict ,
@@ -1150,7 +1176,6 @@ def _split_two_way(
1150
1176
thresh = kwargs .get ('thresh' , None )
1151
1177
num_classes = kwargs .get ('num_classes' , 2 )
1152
1178
quantiles = kwargs .get ('quantiles' , True )
1153
-
1154
1179
# Type checking split_type
1155
1180
if split_type not in [
1156
1181
'mixed-set' , 'drug-blind' , 'cancer-blind'
@@ -1290,6 +1315,11 @@ def _split_two_way(
1290
1315
thresh = thresh ,
1291
1316
quantiles = quantiles ,
1292
1317
)
1318
+ if balance :
1319
+ df_full = _balance_data (
1320
+ data = df_full ,
1321
+ random_state = random_state
1322
+ )
1293
1323
if split_type == 'mixed-set' :
1294
1324
# Using StratifiedShuffleSplit to generate randomized train
1295
1325
# and 'other' set, since there is no need for grouping.
0 commit comments