Skip to content

Commit 294c962

Browse files
author
FelixAbrahamsson
committed
test: updating stratified split preserves old split
1 parent 43473b4 commit 294c962

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

datastream/dataset.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -804,8 +804,7 @@ def test_update_stratified_split():
804804
Dataset.from_dataframe(pd.DataFrame(dict(
805805
index=np.arange(100),
806806
number=np.random.randn(100),
807-
stratify1=np.random.randint(0, 10, 100),
808-
stratify2=np.random.randint(0, 10, 100),
807+
stratify=np.random.randint(0, 10, 100),
809808
)))
810809
.map(tuple)
811810
)
@@ -819,7 +818,7 @@ def test_update_stratified_split():
819818
key_column='index',
820819
proportions=dict(train=0.8, test=0.2),
821820
filepath=filepath,
822-
stratify_column='stratify1',
821+
stratify_column='stratify',
823822
)
824823
)
825824

@@ -829,7 +828,7 @@ def test_update_stratified_split():
829828
key_column='index',
830829
proportions=dict(train=0.8, test=0.2),
831830
filepath=filepath,
832-
stratify_column='stratify2',
831+
stratify_column='stratify',
833832
)
834833
)
835834

@@ -840,8 +839,8 @@ def test_update_stratified_split():
840839
)
841840

842841
assert (
843-
splits1['compare'].dataframe['index']
844-
.isin(splits2['compare'].dataframe['index'])
842+
splits1['test'].dataframe['index']
843+
.isin(splits2['test'].dataframe['index'])
845844
.all()
846845
)
847846

0 commit comments

Comments
 (0)