Skip to content

Commit 772b7dc

Browse files
samukwekusamuel.oranyeli
and
samuel.oranyeli
authored
[ENH] Improve performance for polars' pivot_longer (#1402)
* shortcut for .value only * fastpath if others is just a single column and a string dtype * fix parameters for unpivot --------- Co-authored-by: samuel.oranyeli <[email protected]>
1 parent dabccdb commit 772b7dc

File tree

1 file changed

+110
-35
lines changed

1 file changed

+110
-35
lines changed

janitor/polars/pivot_longer.py

+110-35
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ def pivot_longer_spec(
2626
A declarative interface to pivot a Polars Frame
2727
from wide to long form,
2828
where you describe how the data will be unpivoted,
29-
using a DataFrame. This gives you, the user,
29+
using a DataFrame.
30+
31+
It is modeled after tidyr's `pivot_longer_spec`.
32+
33+
This gives you, the user,
3034
more control over the transformation to long form,
3135
using a *spec* DataFrame that describes exactly
3236
how data stored in the column names
@@ -108,41 +112,56 @@ def pivot_longer_spec(
108112
corresponding to columns pivoted from the wide format.
109113
Note that these additional columns should not already exist
110114
in the source DataFrame.
115+
If there are additional columns, the combination of these columns
116+
and the `.value` column must be unique.
111117
112118
Raises:
113119
KeyError: If `.name` or `.value` is missing from the spec's columns.
114-
ValueError: If the labels in `spec['.name']` is not unique.
120+
ValueError: If the labels in the spec's `.name` column are not unique.
115121
116122
Returns:
117123
A polars DataFrame/LazyFrame.
118124
"""
119125
check("spec", spec, [pl.DataFrame])
120-
if ".name" not in spec.columns:
126+
spec_columns = spec.collect_schema().names()
127+
if ".name" not in spec_columns:
121128
raise KeyError(
122129
"Kindly ensure the spec DataFrame has a `.name` column."
123130
)
124-
if ".value" not in spec.columns:
131+
if ".value" not in spec_columns:
125132
raise KeyError(
126133
"Kindly ensure the spec DataFrame has a `.value` column."
127134
)
128-
if spec.select(pl.col(".name").is_duplicated().any()).item():
135+
if spec.get_column(".name").is_duplicated().any():
129136
raise ValueError("The labels in the `.name` column should be unique.")
130-
131-
exclude = set(df.columns).intersection(spec.columns)
137+
df_columns = df.collect_schema().names()
138+
exclude = set(df_columns).intersection(spec_columns)
132139
if exclude:
133140
raise ValueError(
134141
f"Labels {*exclude, } in the spec dataframe already exist "
135142
"as column labels in the source dataframe. "
136143
"Kindly ensure the spec DataFrame's columns "
137144
"are not present in the source DataFrame."
138145
)
146+
139147
index = [
140-
label for label in df.columns if label not in spec.get_column(".name")
148+
label for label in df_columns if label not in spec.get_column(".name")
141149
]
142150
others = [
143-
label for label in spec.columns if label not in {".name", ".value"}
151+
label for label in spec_columns if label not in {".name", ".value"}
144152
]
145-
variable_name = "".join(df.columns + spec.columns)
153+
154+
if (len(others) == 1) & (spec.get_column(others[0]).dtype == pl.String):
155+
# shortcut that avoids the implode/explode approach - and is faster
156+
# if the requirements are met
157+
# inspired by https://github.com/pola-rs/polars/pull/18519#issue-2500860927
158+
return _pivot_longer_dot_value_string(
159+
df=df,
160+
index=index,
161+
spec=spec,
162+
variable_name=others[0],
163+
)
164+
variable_name = "".join(df_columns + spec_columns)
146165
variable_name = f"{variable_name}_"
147166
if others:
148167
dot_value_only = False
@@ -219,7 +238,7 @@ def pivot_longer(
219238
│ 5.9 ┆ 3.0 ┆ 5.1 ┆ 1.8 ┆ virginica │
220239
└──────────────┴─────────────┴──────────────┴─────────────┴───────────┘
221240
222-
Replicate polars' [melt](https://docs.pola.rs/py-polars/html/reference/dataframe/api/polars.DataFrame.melt.html#polars-dataframe-melt):
241+
Replicate polars' [unpivot](https://docs.pola.rs/py-polars/html/reference/dataframe/api/polars.DataFrame.unpivot.html):
223242
>>> df.pivot_longer(index = 'Species').sort(by=pl.all())
224243
shape: (8, 3)
225244
┌───────────┬──────────────┬───────┐
@@ -375,8 +394,8 @@ def pivot_longer(
375394
specification as polars' `str.split` method.
376395
names_pattern: Determines how the column name is broken up.
377396
It can be a regular expression containing matching groups.
378-
It takes the same
379-
specification as polars' `str.extract_groups` method.
397+
It takes the same specification as
398+
polars' `str.extract_groups` method.
380399
names_transform: Use this option to change the types of columns that
381400
have been transformed to rows.
382401
This does not apply to the values' columns.
@@ -440,7 +459,7 @@ def _pivot_longer(
440459
names_pattern=names_pattern,
441460
)
442461

443-
variable_name = "".join(df.columns)
462+
variable_name = "".join(df.collect_schema().names())
444463
variable_name = f"{variable_name}_"
445464
spec = _pivot_longer_create_spec(
446465
column_names=column_names,
@@ -461,8 +480,25 @@ def _pivot_longer(
461480
variable_name=variable_name,
462481
names_transform=names_transform,
463482
)
464-
465-
if {".name", ".value"}.symmetric_difference(spec.columns):
483+
if {".name", ".value"}.symmetric_difference(spec.collect_schema().names()):
484+
# shortcut that avoids the implode/explode approach - and is faster
485+
# if the requirements are met
486+
# inspired by https://github.com/pola-rs/polars/pull/18519#issue-2500860927
487+
data = spec.get_column(variable_name)
488+
others = data.struct.fields
489+
data = data.struct[others[0]]
490+
if (
491+
(len(others) == 1)
492+
& (data.dtype == pl.String)
493+
& (names_transform is None)
494+
):
495+
spec = spec.unnest(variable_name)
496+
return _pivot_longer_dot_value_string(
497+
df=df,
498+
index=index,
499+
spec=spec,
500+
variable_name=others[0],
501+
)
466502
dot_value_only = False
467503
else:
468504
dot_value_only = True
@@ -552,7 +588,7 @@ def _pivot_longer_create_spec(
552588
return spec.select(".name", ".value")
553589
_spec = spec.get_column(variable_name)
554590
_spec = _spec.struct.unnest()
555-
fields = _spec.columns
591+
fields = _spec.collect_schema().names()
556592

557593
if len(set(names_to)) == 1:
558594
expression = pl.concat_str(fields).alias(".value")
@@ -591,7 +627,7 @@ def _pivot_longer_no_dot_value(
591627
# do the operation on a smaller size
592628
# and then blow it up after
593629
# it is usually much faster
594-
# than running on the actual data
630+
# than unpivoting and running the string operations after
595631
outcome = (
596632
df.select(pl.all().implode())
597633
.unpivot(
@@ -606,11 +642,44 @@ def _pivot_longer_no_dot_value(
606642
outcome = outcome.unnest(variable_name)
607643
if names_transform is not None:
608644
outcome = outcome.with_columns(names_transform)
609-
columns = [name for name in outcome.columns if name not in names_to]
645+
columns = [
646+
name
647+
for name in outcome.collect_schema().names()
648+
if name not in names_to
649+
]
610650
outcome = outcome.explode(columns=columns)
611651
return outcome
612652

613653

654+
def _pivot_longer_dot_value_string(
    df: pl.DataFrame | pl.LazyFrame,
    spec: pl.DataFrame,
    index: ColumnNameOrSelector,
    variable_name: str,
) -> pl.DataFrame | pl.LazyFrame:
    """
    Fastpath for `.value` - does not require the implode/explode approach.

    The spec is grouped on `variable_name`, so every distinct label in that
    column collects the source columns (`.name`) and the output headers
    (`.value`) that belong to it. For each label, the source columns are
    packed into a single struct column whose fields are renamed to the
    output headers and which is aliased to the label itself. Unpivoting
    those struct columns then writes the label into `variable_name`,
    and unnesting the struct yields one column per output header.

    Args:
        df: The polars DataFrame/LazyFrame to reshape.
        spec: Specification DataFrame with `.name`, `.value`
            and a `variable_name` column.
        index: Labels of the columns that are kept as-is
            (passed to `select` and to `unpivot(index=...)`).
        variable_name: Name of the spec column holding the labels;
            it is also the name of the resulting variable column.

    Returns:
        A polars DataFrame/LazyFrame in long form.
    """
    # maintain_order=True keeps the groups in order of first appearance
    # in the spec; a plain group_by returns groups in arbitrary order,
    # which would make the row order of the result vary between runs.
    grouped = spec.group_by(variable_name, maintain_order=True).agg(pl.all())
    # one struct column per label: fields are the source columns,
    # renamed to the `.value` headers, and aliased to the label
    expressions = [
        pl.struct(names).struct.rename_fields(names=fields).alias(header)
        for names, fields, header in zip(
            grouped.get_column(".name").to_list(),
            grouped.get_column(".value").to_list(),
            grouped.get_column(variable_name).to_list(),
        )
    ]
    return (
        df.select([*index, *expressions])
        .unpivot(index=index, variable_name=variable_name, value_name=".value")
        .unnest(".value")
    )
681+
682+
614683
def _pivot_longer_dot_value(
615684
df: pl.DataFrame | pl.LazyFrame,
616685
spec: pl.DataFrame,
@@ -621,7 +690,7 @@ def _pivot_longer_dot_value(
621690
) -> pl.DataFrame | pl.LazyFrame:
622691
"""
623692
flip polars Frame to long form,
624-
if names_sep and .value in names_to.
693+
if .value in names_to.
625694
"""
626695
spec = spec.group_by(variable_name)
627696
spec = spec.agg(pl.all())
@@ -634,25 +703,31 @@ def _pivot_longer_dot_value(
634703
expressions.append(expression)
635704
expressions = [*index, *expressions]
636705
spec = spec.get_column(variable_name)
706+
if dot_value_only:
707+
outcome = (
708+
df.select(expressions)
709+
.unpivot(
710+
index=index, variable_name=variable_name, value_name=".value"
711+
)
712+
.select(pl.exclude(variable_name))
713+
.unnest(".value")
714+
)
715+
return outcome
716+
637717
outcome = (
638718
df.select(expressions)
639719
.select(pl.all().implode())
640720
.unpivot(index=index, variable_name=variable_name, value_name=".value")
641721
.with_columns(spec)
642722
)
643723

644-
if dot_value_only:
645-
columns = [
646-
label for label in outcome.columns if label != variable_name
647-
]
648-
outcome = outcome.explode(columns).unnest(".value")
649-
outcome = outcome.select(pl.exclude(variable_name))
650-
return outcome
651724
outcome = outcome.unnest(variable_name)
652725
if names_transform is not None:
653726
outcome = outcome.with_columns(names_transform)
654727
columns = [
655-
label for label in outcome.columns if label not in spec.struct.fields
728+
label
729+
for label in outcome.collect_schema().names()
730+
if label not in spec.struct.fields
656731
]
657732
outcome = outcome.explode(columns)
658733
outcome = outcome.unnest(".value")
@@ -710,17 +785,17 @@ def _data_checks_pivot_longer(
710785
check("values_to", values_to, [str])
711786

712787
if (index is None) and (column_names is None):
713-
column_names = df.columns
788+
column_names = df.collect_schema().names()
714789
index = []
715790
elif (index is None) and (column_names is not None):
716-
column_names = df.select(column_names).columns
717-
index = df.select(pl.exclude(column_names)).columns
791+
column_names = df.select(column_names).collect_schema().names()
792+
index = df.select(pl.exclude(column_names)).collect_schema().names()
718793
elif (index is not None) and (column_names is None):
719-
index = df.select(index).columns
720-
column_names = df.select(pl.exclude(index)).columns
794+
index = df.select(index).collect_schema().names()
795+
column_names = df.select(pl.exclude(index)).collect_schema().names()
721796
else:
722-
index = df.select(index).columns
723-
column_names = df.select(column_names).columns
797+
index = df.select(index).collect_schema().names()
798+
column_names = df.select(column_names).collect_schema().names()
724799

725800
return (
726801
df,

0 commit comments

Comments
 (0)