@@ -384,7 +384,16 @@ def get_item(dataframe, index):
384
384
+ '' .join ([random .choice (string .ascii_lowercase ) for _ in range (8 )])
385
385
)
386
386
387
- new_dataframe = pd .concat ([dataset .dataframe for dataset in datasets ])
387
+ dataframes = [dataset .dataframe for dataset in datasets ]
388
+ for dataframe in dataframes :
389
+ for col in dataframe .columns :
390
+ if (
391
+ dataframe [col ].dtype == int
392
+ and any ([col not in other .columns for other in dataframes ])
393
+ ):
394
+ dataframe [col ] = dataframe [col ].astype (object )
395
+
396
+ new_dataframe = pd .concat (dataframes )
388
397
new_dataframe [dataset_column ] = [
389
398
from_concat_mapping (index )[0 ]
390
399
for index in range (len (new_dataframe ))
@@ -860,3 +869,19 @@ def test_update_stratified_split():
860
869
)
861
870
862
871
filepath .unlink ()
872
+
873
+
874
+ def test_concat_missing_columns ():
875
+ dataset1 = Dataset .from_dataframe (
876
+ pd .DataFrame (dict (a = [1 , 2 , 3 ], b = ['a' , 'b' , 'c' ]))
877
+ )
878
+ dataset2 = Dataset .from_dataframe (
879
+ pd .DataFrame (dict (c = [True , False ], d = [[1 , 2 ], [3 , 4 ]]))
880
+ )
881
+ concatenated = Dataset .concat ([dataset1 , dataset2 ])
882
+
883
+ assert type (concatenated [0 ]['a' ]) == int
884
+ assert type (concatenated [- 1 ]['a' ]) == float
885
+ assert type (concatenated [0 ]['b' ]) == str
886
+ assert type (concatenated [- 1 ]['c' ]) == bool
887
+ assert type (concatenated [- 1 ]['d' ]) == list
0 commit comments