@@ -413,15 +413,28 @@ def concat(datasets: List[Dataset]) -> Dataset[R]:
413
413
'''
414
414
from_concat_mapping = Dataset .create_from_concat_mapping (datasets )
415
415
416
- def get_item (dataframe , index ):
417
- dataset_index , inner_index = from_concat_mapping (index )
418
- return datasets [dataset_index ][inner_index ]
416
+ if any ([dataset .dataframe is None for dataset in datasets ]):
419
417
420
- return Dataset (
421
- dataframe = None , # TODO: concat dataframes?
422
- length = sum (map (len , datasets )),
423
- get_item = get_item ,
424
- )
418
+ def get_item (dataframe , index ):
419
+ dataset_index , inner_index = from_concat_mapping (index )
420
+ return datasets [dataset_index ][inner_index ]
421
+
422
+ return Dataset (
423
+ dataframe = None ,
424
+ length = sum (map (len , datasets )),
425
+ get_item = get_item ,
426
+ )
427
+ else :
428
+
429
+ def get_item (dataframe , index ):
430
+ dataset_index , _ = from_concat_mapping (index )
431
+ return datasets [dataset_index ].get_item (dataframe , index )
432
+
433
+ return Dataset (
434
+ dataframe = pd .concat ([dataset .dataframe for dataset in datasets ]),
435
+ length = sum (map (len , datasets )),
436
+ get_item = get_item ,
437
+ )
425
438
426
439
@staticmethod
427
440
def create_from_combine_mapping (datasets ):
@@ -600,6 +613,28 @@ def test_concat_dataset():
600
613
assert dataset [6 ] == 1
601
614
602
615
616
+ def test_concat_heterogenous_datasets ():
617
+ dataset1 = Dataset .from_dataframe (
618
+ pd .DataFrame (dict (a = [1 ], b = ['a' ])).set_index ('a' ),
619
+ )
620
+ dataset2 = Dataset .from_dataframe (
621
+ pd .DataFrame (dict (a = [1 ], b = [1 ], c = [2 ])).set_index ('a' ),
622
+ )
623
+ dataset = (
624
+ Dataset .concat ([dataset1 , dataset2 ])
625
+ .map (lambda row : row ['b' ])
626
+ )
627
+
628
+ assert list (dataset ) == ['a' , 1 ]
629
+
630
+ dataset_other_functions = Dataset .concat ([
631
+ dataset1 .map (lambda row : row ['b' ]),
632
+ dataset2 .map (lambda row : row ['c' ]),
633
+ ])
634
+
635
+ assert list (dataset_other_functions ) == ['a' , 2 ]
636
+
637
+
603
638
def test_zip_dataset ():
604
639
dataset = Dataset .zip ([
605
640
Dataset .from_subscriptable (list (range (5 ))),
0 commit comments