
Commit ce281cf

ueshin authored and haoyangeng-db committed
[SPARK-52811][PYTHON] Optimize ArrowTableToRowsConversion.convert to improve its performance
### What changes were proposed in this pull request?

Optimizes `ArrowTableToRowsConversion.convert` to improve its performance, similar to apache#51482.

- Calculate `fields` in advance
- Move the per-value conversions into the `columnar_data` creation
- Build `rows` with a list comprehension to avoid expensive `list.append` calls (sketched right after this description)

### Why are the changes needed?

`ArrowTableToRowsConversion.convert` has several sources of performance overhead.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

The existing tests, and manual benchmarks.

```py
def profile(f, *args, _n=10, **kwargs):
    import cProfile
    import pstats
    import gc

    st = None
    for _ in range(5):
        f(*args, **kwargs)

    for _ in range(_n):
        gc.collect()
        with cProfile.Profile() as pr:
            ret = f(*args, **kwargs)
        if st is None:
            st = pstats.Stats(pr)
        else:
            st.add(pstats.Stats(pr))

    st.sort_stats("time", "cumulative").print_stats()
    return ret


from pyspark.sql.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion
from pyspark.sql.types import *

data = [(i if i % 1000 else None, str(i), i) for i in range(1000000)]
schema = (
    StructType()
    .add("i", IntegerType(), nullable=True)
    .add("s", StringType(), nullable=True)
    .add("ii", IntegerType(), nullable=False)
)

def to_arrow():
    return LocalDataToArrowConversion.convert(data, schema, use_large_var_types=False)

def from_arrow(tbl):
    return ArrowTableToRowsConversion.convert(tbl, schema)

tbl = to_arrow()
profile(from_arrow, tbl)
```

- before

  ```
  100983380 function calls in 24.509 seconds
  ```

- after

  ```
  70655910 function calls in 16.947 seconds
  ```

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#51508 from ueshin/issues/SPARK-52811/convert.

Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
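A minimal, self-contained sketch of the pattern the bullets above describe. It is illustrative only: the `pyarrow` calls are real, but plain tuples and identity functions stand in for pyspark's `Row` and the schema-derived converters.

```py
import pyarrow as pa

# Toy table; identity functions stand in for the converters that
# ArrowTableToRowsConversion._create_converter would derive from the schema.
table = pa.table({"i": [1, 2, None], "s": ["a", "b", "c"]})
fields = list(table.column_names)                  # computed once, up front
field_converters = [lambda v: v for _ in fields]

# Before (shape of the removed code): per-cell converter calls inside a
# per-row loop that repeatedly calls list.append.
raw_columns = [column.to_pylist() for column in table.columns]
rows_before = []
for i in range(table.num_rows):
    values = [field_converters[j](raw_columns[j][i]) for j in range(table.num_columns)]
    rows_before.append(tuple(values))              # tuple(...) stands in for _create_row(...)

# After (shape of the added code): convert each column in a single pass,
# then zip the converted columns back into rows with one comprehension.
columnar_data = [
    [conv(v) for v in column.to_pylist()]
    for column, conv in zip(table.columns, field_converters)
]
rows_after = [tuple(cols) for cols in zip(*columnar_data)]

assert rows_before == rows_after                   # same rows, far fewer calls
```

The per-column pass drops the inner `range(table.num_columns)` indexing and the repeated `list.append`, which is where the function-call savings in the benchmark above come from.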
1 parent b627f61 commit ce281cf

File tree

1 file changed: +15 −9 lines changed

python/pyspark/sql/conversion.py

Lines changed: 15 additions & 9 deletions
```diff
@@ -531,14 +531,20 @@ def convert(table: "pa.Table", schema: StructType) -> List[Row]:
 
         assert schema is not None and isinstance(schema, StructType)
 
-        field_converters = [
-            ArrowTableToRowsConversion._create_converter(f.dataType) for f in schema.fields
-        ]
+        fields = schema.fieldNames()
 
-        columnar_data = [column.to_pylist() for column in table.columns]
+        if len(fields) > 0:
+            field_converters = [
+                ArrowTableToRowsConversion._create_converter(f.dataType) for f in schema.fields
+            ]
 
-        rows: List[Row] = []
-        for i in range(0, table.num_rows):
-            values = [field_converters[j](columnar_data[j][i]) for j in range(table.num_columns)]
-            rows.append(_create_row(fields=schema.fieldNames(), values=values))
-        return rows
+            columnar_data = [
+                [conv(v) for v in column.to_pylist()]
+                for column, conv in zip(table.columns, field_converters)
+            ]
+
+            rows = [_create_row(fields, tuple(cols)) for cols in zip(*columnar_data)]
+            assert len(rows) == table.num_rows, f"{len(rows)}, {table.num_rows}"
+            return rows
+        else:
+            return [_create_row(fields, tuple())] * table.num_rows
```
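One detail worth noting in the new code: the explicit zero-column branch is not just a micro-optimization. A rough plain-Python illustration (tuples standing in for `Row`) of why `zip(*columnar_data)` alone would drop every row of a zero-column table:

```py
# With no columns, zip(*columnar_data) yields nothing, so the row comprehension
# alone would return [] even though the table still has rows.
columnar_data = []                 # a table with 0 columns but, say, 3 rows
rows = [tuple(cols) for cols in zip(*columnar_data)]
print(rows)                        # []  -- wrong: 3 empty rows were expected

num_rows = 3                       # stand-in for table.num_rows
rows = [tuple()] * num_rows        # the else-branch strategy: one empty row per input row
print(rows)                        # [(), (), ()]
```

The `assert len(rows) == table.num_rows` in the main branch guards against the same kind of silent row loss.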
