Skip to content

Commit eda7e9c

Browse files
author
FelixAbrahamsson
committed
fix: concat broke dataset if int columns were missing
1 parent 60b9c42 commit eda7e9c

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

datastream/dataset.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,16 @@ def get_item(dataframe, index):
384384
+ ''.join([random.choice(string.ascii_lowercase) for _ in range(8)])
385385
)
386386

387-
new_dataframe = pd.concat([dataset.dataframe for dataset in datasets])
387+
dataframes = [dataset.dataframe for dataset in datasets]
388+
for dataframe in dataframes:
389+
for col in dataframe.columns:
390+
if (
391+
dataframe[col].dtype == int
392+
and any([col not in other.columns for other in dataframes])
393+
):
394+
dataframe[col] = dataframe[col].astype(object)
395+
396+
new_dataframe = pd.concat(dataframes)
388397
new_dataframe[dataset_column] = [
389398
from_concat_mapping(index)[0]
390399
for index in range(len(new_dataframe))
@@ -860,3 +869,19 @@ def test_update_stratified_split():
860869
)
861870

862871
filepath.unlink()
872+
873+
874+
def test_concat_missing_columns():
    """Check element types after concatenating datasets with disjoint columns.

    The first dataset only has columns ``a`` (ints) and ``b`` (strings); the
    second only has ``c`` (bools) and ``d`` (lists). After ``Dataset.concat``
    the values that were present in a source dataset must come back with
    their original Python type, while ``a`` read from a row of the second
    dataset is expected to be a ``float`` (presumably the NaN fill value
    pandas uses for the missing cell).
    """
    with_ints = Dataset.from_dataframe(
        pd.DataFrame({'a': [1, 2, 3], 'b': ['a', 'b', 'c']})
    )
    without_ints = Dataset.from_dataframe(
        pd.DataFrame({'c': [True, False], 'd': [[1, 2], [3, 4]]})
    )
    merged = Dataset.concat([with_ints, without_ints])

    # (row index, column name, exact type expected for that cell).
    # Exact type identity is checked on purpose: isinstance would also accept
    # subclasses such as numpy scalar types.
    expectations = [
        (0, 'a', int),
        (-1, 'a', float),
        (0, 'b', str),
        (-1, 'c', bool),
        (-1, 'd', list),
    ]
    for row, column, expected in expectations:
        assert type(merged[row][column]) is expected

0 commit comments

Comments
 (0)