Skip to content

Commit a08e7e9

Browse files
authored
fix: concat with union categories (#3127)
* fix: concat with union categories
* formatting
1 parent 8f925fe commit a08e7e9

File tree

3 files changed

+6 -4 lines changed

awswrangler/s3/_read.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,8 @@ def _extract_partitions_dtypes_from_table_details(response: "GetTableResponseTyp
116 116   return dtypes
117 117
118 118
119     - def _union(dfs: list[pd.DataFrame], ignore_index: bool) -> pd.DataFrame:
    119 + def _concat_union_categoricals(dfs: list[pd.DataFrame], ignore_index: bool) -> pd.DataFrame:
    120 + """Concatenate dataframes with union of categorical columns."""
120 121   cats: tuple[set[str], ...] = tuple(set(df.select_dtypes(include="category").columns) for df in dfs)
121 122   for col in set.intersection(*cats):
122 123   cat = union_categoricals([df[col] for df in dfs])

awswrangler/s3/_read_parquet.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
33 33   from awswrangler.s3._read import (
34 34   _apply_partition_filter,
35 35   _check_version_id,
   36 + _concat_union_categoricals,
36 37   _extract_partitions_dtypes_from_table_details,
37 38   _get_num_output_blocks,
38 39   _get_path_ignore_suffix,
@@ -264,7 +265,7 @@ def _read_parquet_chunked(
264 265   yield df
265 266   else:
266 267   if next_slice is not None:
267     - df = pd.concat(objs=[next_slice, df], sort=False, copy=False)
    268 + df = _concat_union_categoricals(dfs=[next_slice, df], ignore_index=False)
268 269   while len(df.index) >= chunked:
269 270   yield df.iloc[:chunked, :].copy()
270 271   df = df.iloc[chunked:, :]

awswrangler/s3/_read_text.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -19,10 +19,10 @@
19 19   from awswrangler.s3._read import (
20 20   _apply_partition_filter,
21 21   _check_version_id,
   22 + _concat_union_categoricals,
22 23   _get_num_output_blocks,
23 24   _get_path_ignore_suffix,
24 25   _get_path_root,
25    - _union,
26 26   )
27 27   from awswrangler.s3._read_text_core import _read_text_file, _read_text_files_chunked
28 28   from awswrangler.typing import RaySettings
@@ -70,7 +70,7 @@ def _read_text(
70 70   itertools.repeat(s3_additional_kwargs),
71 71   itertools.repeat(dataset),
72 72   )
73    - return _union(dfs=tables, ignore_index=ignore_index)
   73 + return _concat_union_categoricals(dfs=tables, ignore_index=ignore_index)
74 74
75 75
76 76   def _read_text_format(

0 commit comments

Comments
 (0)