@@ -804,8 +804,7 @@ def test_update_stratified_split():
804
804
Dataset .from_dataframe (pd .DataFrame (dict (
805
805
index = np .arange (100 ),
806
806
number = np .random .randn (100 ),
807
- stratify1 = np .random .randint (0 , 10 , 100 ),
808
- stratify2 = np .random .randint (0 , 10 , 100 ),
807
+ stratify = np .random .randint (0 , 10 , 100 ),
809
808
)))
810
809
.map (tuple )
811
810
)
@@ -819,7 +818,7 @@ def test_update_stratified_split():
819
818
key_column = 'index' ,
820
819
proportions = dict (train = 0.8 , test = 0.2 ),
821
820
filepath = filepath ,
822
- stratify_column = 'stratify1 ' ,
821
+ stratify_column = 'stratify ' ,
823
822
)
824
823
)
825
824
@@ -829,7 +828,7 @@ def test_update_stratified_split():
829
828
key_column = 'index' ,
830
829
proportions = dict (train = 0.8 , test = 0.2 ),
831
830
filepath = filepath ,
832
- stratify_column = 'stratify2 ' ,
831
+ stratify_column = 'stratify ' ,
833
832
)
834
833
)
835
834
@@ -840,8 +839,8 @@ def test_update_stratified_split():
840
839
)
841
840
842
841
assert (
843
- splits1 ['compare ' ].dataframe ['index' ]
844
- .isin (splits2 ['compare ' ].dataframe ['index' ])
842
+ splits1 ['test ' ].dataframe ['index' ]
843
+ .isin (splits2 ['test ' ].dataframe ['index' ])
845
844
.all ()
846
845
)
847
846
0 commit comments