Skip to content

Commit e65634b

Browse files
authored
Merge pull request #328 from PNNL-CompBio/327-fix-balancing-of-datasets-in-train_test_validate
added balancing flag to all splitting methods
2 parents adea7a5 + 93912fa commit e65634b

File tree

1 file changed

+37
-7
lines changed

1 file changed

+37
-7
lines changed

coderdata/dataset/dataset.py

+37-7
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ def split_train_other(
343343
]='mixed-set',
344344
ratio: tuple[int, int, int]=(8,2),
345345
stratify_by: Optional[str]=None,
346+
balance: bool=False,
346347
random_state: Optional[Union[int,RandomState]]=None,
347348
**kwargs: dict,
348349
) -> TwoWaySplit:
@@ -352,6 +353,7 @@ def split_train_other(
352353
split_type=split_type,
353354
ration=ratio,
354355
stratify_by=stratify_by,
356+
balance=balance,
355357
random_state=random_state,
356358
**kwargs
357359
)
@@ -366,15 +368,16 @@ def split_train_test_validate(
366368
]='mixed-set',
367369
ratio: tuple[int, int, int]=(8,1,1),
368370
stratify_by: Optional[str]=None,
371+
balance: bool=False,
369372
random_state: Optional[Union[int,RandomState]]=None,
370373
**kwargs: dict,
371374
) -> Split:
372-
373375
split = split_train_test_validate(
374376
data=self,
375377
split_type=split_type,
376378
ratio=ratio,
377379
stratify_by=stratify_by,
380+
balance=balance,
378381
random_state=random_state,
379382
**kwargs
380383
)
@@ -389,6 +392,7 @@ def train_test_validate(
389392
]='mixed-set',
390393
ratio: tuple[int, int, int]=(8,1,1),
391394
stratify_by: Optional[str]=None,
395+
balance: bool=False,
392396
random_state: Optional[Union[int,RandomState]]=None,
393397
**kwargs: dict,
394398
) -> Split:
@@ -398,6 +402,7 @@ def train_test_validate(
398402
split_type=split_type,
399403
ratio=ratio,
400404
stratify_by=stratify_by,
405+
balance=balance,
401406
random_state=random_state,
402407
**kwargs
403408
)
@@ -725,6 +730,7 @@ def split_train_other(
725730
]='mixed-set',
726731
ratio: tuple[int, int, int]=(8,2),
727732
stratify_by: Optional[str]=None,
733+
balance: bool=False,
728734
random_state: Optional[Union[int,RandomState]]=None,
729735
**kwargs: dict,
730736
):
@@ -733,8 +739,9 @@ def split_train_other(
733739
split_type,
734740
ratio,
735741
stratify_by,
742+
balance,
736743
random_state,
737-
kwargs=kwargs
744+
**kwargs
738745
)
739746
if stratify_by is not None:
740747
train.experiments = train.experiments[train.experiments['dose_response_metric'] != 'split_class']
@@ -749,10 +756,10 @@ def split_train_test_validate(
749756
]='mixed-set',
750757
ratio: tuple[int, int, int]=(8,1,1),
751758
stratify_by: Optional[str]=None,
759+
balance: bool=False,
752760
random_state: Optional[Union[int,RandomState]]=None,
753761
**kwargs: dict,
754762
) -> Split:
755-
756763
# Type checking split_type
757764
if split_type not in [
758765
'mixed-set', 'drug-blind', 'cancer-blind'
@@ -766,17 +773,19 @@ def split_train_test_validate(
766773
split_type=split_type,
767774
ratio=[ratio[0], ratio[1] + ratio[2]],
768775
stratify_by=stratify_by,
776+
balance=balance,
769777
random_state=random_state,
770-
kwargs=kwargs,
778+
**kwargs,
771779
)
772780

773781
test, val = _split_two_way(
774782
data=other,
775783
split_type=split_type,
776784
ratio=[ratio[1], ratio[2]],
777785
stratify_by=stratify_by,
786+
balance=balance,
778787
random_state=random_state,
779-
kwargs=kwargs,
788+
**kwargs,
780789
)
781790

782791
if stratify_by is not None:
@@ -794,6 +803,7 @@ def train_test_validate(
794803
]='mixed-set',
795804
ratio: tuple[int, int, int]=(8,1,1),
796805
stratify_by: Optional[str]=None,
806+
balance: bool=False,
797807
random_state: Optional[Union[int,RandomState]]=None,
798808
**kwargs: dict,
799809
) -> Split:
@@ -871,8 +881,9 @@ def train_test_validate(
871881
split_type=split_type,
872882
ratio=ratio,
873883
stratify_by=stratify_by,
884+
balance=balance,
874885
random_state=random_state,
875-
kwargs=kwargs,
886+
**kwargs,
876887
)
877888

878889

@@ -976,6 +987,20 @@ def _filter(data: Dataset, split: pd.DataFrame) -> Dataset:
976987

977988
return data_ret
978989

990+
def _balance_data(
991+
data: pd.Dataframe,
992+
random_state: Optional[Union[int,RandomState]]=None,
993+
# oversample: bool=False,
994+
) -> pd.Dataframe:
995+
tmp = deepcopy(data)
996+
counts = tmp.value_counts('split_class')
997+
ret_df = (
998+
tmp
999+
.groupby('split_class')
1000+
.sample(n=min(counts), random_state=random_state)
1001+
)
1002+
return ret_df
1003+
9791004

9801005
def _create_classes(
9811006
data: pd.DataFrame,
@@ -1072,6 +1097,7 @@ def _split_two_way(
10721097
'mixed-set', 'drug-blind', 'cancer-blind'
10731098
]='mixed-set',
10741099
ratio: tuple[int, int, int]=(8,2),
1100+
balance: bool=False,
10751101
stratify_by: Optional[str]=None,
10761102
random_state: Optional[Union[int,RandomState]]=None,
10771103
**kwargs: dict,
@@ -1150,7 +1176,6 @@ def _split_two_way(
11501176
thresh = kwargs.get('thresh', None)
11511177
num_classes = kwargs.get('num_classes', 2)
11521178
quantiles = kwargs.get('quantiles', True)
1153-
11541179
# Type checking split_type
11551180
if split_type not in [
11561181
'mixed-set', 'drug-blind', 'cancer-blind'
@@ -1290,6 +1315,11 @@ def _split_two_way(
12901315
thresh=thresh,
12911316
quantiles=quantiles,
12921317
)
1318+
if balance:
1319+
df_full = _balance_data(
1320+
data=df_full,
1321+
random_state=random_state
1322+
)
12931323
if split_type == 'mixed-set':
12941324
# Using StratifiedShuffleSplit to generate randomized train
12951325
# and 'other' set, since there is no need for grouping.

0 commit comments

Comments
 (0)