Commit ee973b3

Fixed a bug where the 'split_class' column was not removed from dataset.experiments when mixed-set splits were generated with stratification
1 parent 8fadc34 commit ee973b3

1 file changed: +3 -0 lines changed

coderdata/dataset/dataset.py

Lines changed: 3 additions & 0 deletions
@@ -936,13 +936,16 @@ def train_test_validate(
             sss_1.split(X=df_full, y=df_full['split_class'])
         )
         df_train = df_full.iloc[idx_train]
+        df_train = df_train.drop(labels=['split_class'], axis=1)
         df_other = df_full.iloc[idx_other]
         # Splitting 'other' further into test and validate
         idx_test, idx_val = next(
             sss_2.split(X=df_other, y=df_other['split_class'])
         )
         df_test = df_other.iloc[idx_test]
+        df_test = df_test.drop(labels=['split_class'], axis=1)
         df_val = df_other.iloc[idx_val]
+        df_val = df_val.drop(labels=['split_class'], axis=1)
 
         # using StratifiedGroupKSplit for the stratified drug-/sample-
         # blind splits.
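
The pattern in this hunk can be reproduced outside the project as a minimal sketch. The function name `stratified_three_way_split`, the toy DataFrame, the split fractions, and the seed below are assumptions made for illustration; only the `StratifiedShuffleSplit` usage and the `drop(labels=['split_class'], axis=1)` calls mirror the diff. The point is that the helper column must stay on `df_other` until the second split has used it, and is dropped from each returned frame afterwards.

```python
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit


def stratified_three_way_split(df_full, label_col="split_class", seed=0):
    """Hypothetical helper: split into train/test/validate, stratified
    on `label_col`, and remove `label_col` from every returned frame."""
    # Fractions are assumptions for this sketch, not the project's defaults.
    sss_1 = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
    sss_2 = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=seed)

    idx_train, idx_other = next(sss_1.split(X=df_full, y=df_full[label_col]))
    df_train = df_full.iloc[idx_train].drop(labels=[label_col], axis=1)

    # Keep the helper column on 'other': the second split still needs it.
    df_other = df_full.iloc[idx_other]
    idx_test, idx_val = next(sss_2.split(X=df_other, y=df_other[label_col]))
    df_test = df_other.iloc[idx_test].drop(labels=[label_col], axis=1)
    df_val = df_other.iloc[idx_val].drop(labels=[label_col], axis=1)
    return df_train, df_test, df_val


# Toy data: 40 rows, two balanced classes in the helper column.
df = pd.DataFrame({
    "value": range(40),
    "split_class": ["low"] * 20 + ["high"] * 20,
})
train, test, val = stratified_three_way_split(df)
print([list(part.columns) for part in (train, test, val)])
```

Because `DataFrame.drop` returns a new frame by default, the input `df` keeps its `split_class` column; only the three returned splits are cleaned.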