@@ -343,6 +343,7 @@ def split_train_other(
343343 ]= 'mixed-set' ,
344344 ratio : tuple [int , int , int ]= (8 ,2 ),
345345 stratify_by : Optional [str ]= None ,
346+ balance : bool = False ,
346347 random_state : Optional [Union [int ,RandomState ]]= None ,
347348 ** kwargs : dict ,
348349 ) -> TwoWaySplit :
@@ -352,6 +353,7 @@ def split_train_other(
352353 split_type = split_type ,
353354 ration = ratio ,
354355 stratify_by = stratify_by ,
356+ balance = balance ,
355357 random_state = random_state ,
356358 ** kwargs
357359 )
@@ -366,15 +368,16 @@ def split_train_test_validate(
366368 ]= 'mixed-set' ,
367369 ratio : tuple [int , int , int ]= (8 ,1 ,1 ),
368370 stratify_by : Optional [str ]= None ,
371+ balance : bool = False ,
369372 random_state : Optional [Union [int ,RandomState ]]= None ,
370373 ** kwargs : dict ,
371374 ) -> Split :
372-
373375 split = split_train_test_validate (
374376 data = self ,
375377 split_type = split_type ,
376378 ratio = ratio ,
377379 stratify_by = stratify_by ,
380+ balance = balance ,
378381 random_state = random_state ,
379382 ** kwargs
380383 )
@@ -389,6 +392,7 @@ def train_test_validate(
389392 ]= 'mixed-set' ,
390393 ratio : tuple [int , int , int ]= (8 ,1 ,1 ),
391394 stratify_by : Optional [str ]= None ,
395+ balance : bool = False ,
392396 random_state : Optional [Union [int ,RandomState ]]= None ,
393397 ** kwargs : dict ,
394398 ) -> Split :
@@ -398,6 +402,7 @@ def train_test_validate(
398402 split_type = split_type ,
399403 ratio = ratio ,
400404 stratify_by = stratify_by ,
405+ balance = balance ,
401406 random_state = random_state ,
402407 ** kwargs
403408 )
@@ -725,6 +730,7 @@ def split_train_other(
725730 ]= 'mixed-set' ,
726731 ratio : tuple [int , int , int ]= (8 ,2 ),
727732 stratify_by : Optional [str ]= None ,
733+ balance : bool = False ,
728734 random_state : Optional [Union [int ,RandomState ]]= None ,
729735 ** kwargs : dict ,
730736 ):
@@ -733,8 +739,9 @@ def split_train_other(
733739 split_type ,
734740 ratio ,
735741 stratify_by ,
742+ balance ,
736743 random_state ,
737- kwargs = kwargs
744+ ** kwargs
738745 )
739746 if stratify_by is not None :
740747 train .experiments = train .experiments [train .experiments ['dose_response_metric' ] != 'split_class' ]
@@ -749,10 +756,10 @@ def split_train_test_validate(
749756 ]= 'mixed-set' ,
750757 ratio : tuple [int , int , int ]= (8 ,1 ,1 ),
751758 stratify_by : Optional [str ]= None ,
759+ balance : bool = False ,
752760 random_state : Optional [Union [int ,RandomState ]]= None ,
753761 ** kwargs : dict ,
754762 ) -> Split :
755-
756763 # Type checking split_type
757764 if split_type not in [
758765 'mixed-set' , 'drug-blind' , 'cancer-blind'
@@ -766,17 +773,19 @@ def split_train_test_validate(
766773 split_type = split_type ,
767774 ratio = [ratio [0 ], ratio [1 ] + ratio [2 ]],
768775 stratify_by = stratify_by ,
776+ balance = balance ,
769777 random_state = random_state ,
770- kwargs = kwargs ,
778+ ** kwargs ,
771779 )
772780
773781 test , val = _split_two_way (
774782 data = other ,
775783 split_type = split_type ,
776784 ratio = [ratio [1 ], ratio [2 ]],
777785 stratify_by = stratify_by ,
786+ balance = balance ,
778787 random_state = random_state ,
779- kwargs = kwargs ,
788+ ** kwargs ,
780789 )
781790
782791 if stratify_by is not None :
@@ -794,6 +803,7 @@ def train_test_validate(
794803 ]= 'mixed-set' ,
795804 ratio : tuple [int , int , int ]= (8 ,1 ,1 ),
796805 stratify_by : Optional [str ]= None ,
806+ balance : bool = False ,
797807 random_state : Optional [Union [int ,RandomState ]]= None ,
798808 ** kwargs : dict ,
799809 ) -> Split :
@@ -871,8 +881,9 @@ def train_test_validate(
871881 split_type = split_type ,
872882 ratio = ratio ,
873883 stratify_by = stratify_by ,
884+ balance = balance ,
874885 random_state = random_state ,
875- kwargs = kwargs ,
886+ ** kwargs ,
876887 )
877888
878889
@@ -976,6 +987,20 @@ def _filter(data: Dataset, split: pd.DataFrame) -> Dataset:
976987
977988 return data_ret
978989
990+ def _balance_data (
991+ data : pd .Dataframe ,
992+ random_state : Optional [Union [int ,RandomState ]]= None ,
993+ # oversample: bool=False,
994+ ) -> pd .Dataframe :
995+ tmp = deepcopy (data )
996+ counts = tmp .value_counts ('split_class' )
997+ ret_df = (
998+ tmp
999+ .groupby ('split_class' )
1000+ .sample (n = min (counts ), random_state = random_state )
1001+ )
1002+ return ret_df
1003+
9791004
9801005def _create_classes (
9811006 data : pd .DataFrame ,
@@ -1072,6 +1097,7 @@ def _split_two_way(
10721097 'mixed-set' , 'drug-blind' , 'cancer-blind'
10731098 ]= 'mixed-set' ,
10741099 ratio : tuple [int , int , int ]= (8 ,2 ),
1100+ balance : bool = False ,
10751101 stratify_by : Optional [str ]= None ,
10761102 random_state : Optional [Union [int ,RandomState ]]= None ,
10771103 ** kwargs : dict ,
@@ -1150,7 +1176,6 @@ def _split_two_way(
11501176 thresh = kwargs .get ('thresh' , None )
11511177 num_classes = kwargs .get ('num_classes' , 2 )
11521178 quantiles = kwargs .get ('quantiles' , True )
1153-
11541179 # Type checking split_type
11551180 if split_type not in [
11561181 'mixed-set' , 'drug-blind' , 'cancer-blind'
@@ -1290,6 +1315,11 @@ def _split_two_way(
12901315 thresh = thresh ,
12911316 quantiles = quantiles ,
12921317 )
1318+ if balance :
1319+ df_full = _balance_data (
1320+ data = df_full ,
1321+ random_state = random_state
1322+ )
12931323 if split_type == 'mixed-set' :
12941324 # Using StratifiedShuffleSplit to generate randomized train
12951325 # and 'other' set, since there is no need for grouping.
0 commit comments