
[ENH] improve performance for polars' pivot_longer #1377

Merged
merged 31 commits into dev from samukweku/polars_pivot_longer_improve on Jul 4, 2024
31 commits
ed3ff67
faster pivot_longer for non dot value
Jun 17, 2024
1568143
fix docs and tests
Jun 18, 2024
b5a89a9
fix docs and tests
Jun 18, 2024
ddfc230
Merge dev into samukweku/polars_pivot_longer_improve
ericmjl Jun 18, 2024
58b2912
Merge dev into samukweku/polars_pivot_longer_improve
ericmjl Jun 18, 2024
c20838b
Merge dev into samukweku/polars_pivot_longer_improve
ericmjl Jun 18, 2024
f2e761c
Merge dev into samukweku/polars_pivot_longer_improve
ericmjl Jun 19, 2024
5278936
fix doc
Jun 20, 2024
e8c3057
fix doc pivot_longer_spec
Jun 20, 2024
7c497cd
fix doc pivot_longer_spec
Jun 20, 2024
9fecc2b
Merge remote-tracking branch 'origin/dev' into samukweku/polars_pivot…
Jun 20, 2024
513fe73
updates
Jun 20, 2024
2399484
updates
Jun 20, 2024
49fc638
updates
Jun 20, 2024
6107948
fix docs
Jun 20, 2024
f2b956b
fix tests
Jun 20, 2024
d849cff
change sort logic for `complete`
Jun 22, 2024
aee2b09
updates to complete
Jun 22, 2024
6a5f66e
restore initial setup for complete
Jun 22, 2024
8ea3f56
remove dead code
Jun 22, 2024
cf350a3
use left join
Jun 23, 2024
8fe093c
update docs for pivot_longer
Jun 26, 2024
8dd1d82
WIP - expand
Jun 27, 2024
83296d1
Delete janitor/polars/expand.py
samukweku Jun 27, 2024
f1fab2e
remove expand
Jun 27, 2024
2b98614
remove expand
Jun 27, 2024
aecc4c2
Merge dev into samukweku/polars_pivot_longer_improve
ericmjl Jun 28, 2024
1fc553e
Merge dev into samukweku/polars_pivot_longer_improve
ericmjl Jun 28, 2024
a028079
Merge dev into samukweku/polars_pivot_longer_improve
ericmjl Jul 3, 2024
a9e344c
Merge dev into samukweku/polars_pivot_longer_improve
ericmjl Jul 3, 2024
08fe245
Merge dev into samukweku/polars_pivot_longer_improve
ericmjl Jul 4, 2024
4 changes: 2 additions & 2 deletions janitor/polars/complete.py
@@ -385,14 +385,14 @@ def _complete(

no_columns_to_fill = set(df.columns) == set(uniques.columns)
if fill_value is None or no_columns_to_fill:
return uniques.join(df, on=uniques.columns, how="full", coalesce=True)
return uniques.join(df, on=uniques.columns, how="left", coalesce=True)
idx = None
columns_to_select = df.columns
if not explicit:
idx = "".join(df.columns)
idx = f"{idx}_"
df = df.with_row_index(name=idx)
df = uniques.join(df, on=uniques.columns, how="full", coalesce=True)
df = uniques.join(df, on=uniques.columns, how="left", coalesce=True)
# exclude columns that were not used
# to generate the combinations
exclude_columns = uniques.columns
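The switch from a full to a left join above is safe because `uniques` holds every combination of the completion columns, so every key in `df` is guaranteed to appear on the left side; the left join then returns the same rows while skipping the bookkeeping for unmatched right-side keys. A minimal sketch of the equivalence, using invented toy data:

import polars as pl

df = pl.DataFrame({"group": [1, 2], "item": ["a", "b"], "value": [10, 20]})
# every group/item combination, as `_complete` would generate it
uniques = pl.DataFrame({"group": [1, 1, 2, 2], "item": ["a", "b", "a", "b"]})

left = uniques.join(df, on=["group", "item"], how="left", coalesce=True)
full = uniques.join(df, on=["group", "item"], how="full", coalesce=True)
# both produce the 4 combinations, with value null where df had no row
assert left.sort(["group", "item"]).equals(full.sort(["group", "item"]))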
444 changes: 235 additions & 209 deletions janitor/polars/pivot_longer.py
@@ -2,16 +2,12 @@

from __future__ import annotations

from collections import defaultdict
from typing import Any, Iterable

from janitor.utils import check, import_message

from .polars_flavor import register_dataframe_method, register_lazyframe_method

try:
import polars as pl
import polars.selectors as cs
from polars.type_aliases import ColumnNameOrSelector
except ImportError:
import_message(
@@ -37,14 +33,14 @@ def pivot_longer_spec(
becomes variables.
It can come in handy for situations where
`janitor.polars.pivot_longer`
[`pivot_longer`][janitor.polars.pivot_longer.pivot_longer]
seems inadequate for the transformation.
!!! info "New in version 0.28.0"
Examples:
>>> import polars as pl
>>> import janitor.polars
>>> from janitor.polars import pivot_longer_spec
>>> df = pl.DataFrame(
... {
... "Sepal.Length": [5.1, 5.9],
@@ -81,18 +77,18 @@ def pivot_longer_spec(
│ Sepal.Width ┆ Width ┆ Sepal │
│ Petal.Width ┆ Width ┆ Petal │
└──────────────┴────────┴───────┘
>>> df.pipe(pivot_longer_spec,spec=spec)
>>> df.pipe(pivot_longer_spec,spec=spec).sort(by=pl.all())
shape: (4, 4)
┌───────────┬────────┬───────┬───────┐
│ Species   ┆ Length ┆ Width ┆ part  │
│ ---       ┆ ---    ┆ ---   ┆ ---   │
│ str       ┆ f64    ┆ f64   ┆ str   │
╞═══════════╪════════╪═══════╪═══════╡
│ setosa    ┆ 5.1    ┆ 3.5   ┆ Sepal │
│ virginica ┆ 5.9    ┆ 3.0   ┆ Sepal │
│ setosa    ┆ 1.4    ┆ 0.2   ┆ Petal │
│ virginica ┆ 5.1    ┆ 1.8   ┆ Petal │
└───────────┴────────┴───────┴───────┘
┌───────────┬───────┬────────┬───────┐
│ Species   ┆ part  ┆ Length ┆ Width │
│ ---       ┆ ---   ┆ ---    ┆ ---   │
│ str       ┆ str   ┆ f64    ┆ f64   │
╞═══════════╪═══════╪════════╪═══════╡
│ setosa    ┆ Petal ┆ 1.4    ┆ 0.2   │
│ setosa    ┆ Sepal ┆ 5.1    ┆ 3.5   │
│ virginica ┆ Petal ┆ 5.1    ┆ 1.8   │
│ virginica ┆ Sepal ┆ 5.9    ┆ 3.0   │
└───────────┴───────┴────────┴───────┘
Args:
df: The source DataFrame to unpivot.
@@ -140,17 +136,29 @@ def pivot_longer_spec(
"Kindly ensure the spec DataFrame's columns "
"are not present in the source DataFrame."
)

if spec.columns[:2] != [".name", ".value"]:
raise ValueError(
"The first two columns of the spec DataFrame "
"should be '.name' and '.value', "
"with '.name' coming before '.value'."
)

index = [
label for label in df.columns if label not in spec.get_column(".name")
]
others = [
label for label in spec.columns if label not in {".name", ".value"}
]
variable_name = "".join(df.columns + spec.columns)
variable_name = f"{variable_name}_"
if others:
dot_value_only = False
expression = pl.struct(others).alias(variable_name)
spec = spec.select(".name", ".value", expression)
else:
dot_value_only = True
expression = pl.cum_count(".value").over(".value").alias(variable_name)
spec = spec.with_columns(expression)
return _pivot_longer_dot_value(
df=df,
index=index,
spec=spec,
variable_name=variable_name,
dot_value_only=dot_value_only,
names_transform=None,
)
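The `pl.cum_count(".value").over(".value")` expression above is what pairs columns when the spec carries nothing besides `.name` and `.value`: each repeated `.value` label receives an ordinal, so the first column of one group lines up with the first column of the other. A small sketch on an invented spec:

import polars as pl

spec = pl.DataFrame(
    {".name": ["x1", "x2", "y1", "y2"], ".value": ["x", "x", "y", "y"]}
)
# ordinal within each .value group: [1, 2, 1, 2],
# so x1 pairs with y1 and x2 pairs with y2
print(spec.with_columns(pl.cum_count(".value").over(".value").alias("pair")))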


@@ -179,8 +187,11 @@ def pivot_longer(
All measured variables are *unpivoted* (and typically duplicated) along the
row axis.
If `names_pattern`, use a valid regular expression pattern containing at least
one capture group, compatible with the [regex crate](https://docs.rs/regex/latest/regex/).
For more granular control on the unpivoting, have a look at
`pivot_longer_spec`.
[`pivot_longer_spec`][janitor.polars.pivot_longer.pivot_longer_spec].
`pivot_longer` can also be applied to a LazyFrame.
@@ -209,61 +220,61 @@ def pivot_longer(
└──────────────┴─────────────┴──────────────┴─────────────┴───────────┘
Replicate polars' [melt](https://docs.pola.rs/py-polars/html/reference/dataframe/api/polars.DataFrame.melt.html#polars-dataframe-melt):
>>> df.pivot_longer(index = 'Species')
>>> df.pivot_longer(index = 'Species').sort(by=pl.all())
shape: (8, 3)
┌───────────┬──────────────┬───────┐
│ Species ┆ variable ┆ value │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ f64 │
╞═══════════╪══════════════╪═══════╡
│ setosa ┆ Petal.Length ┆ 1.4 │
│ setosa ┆ Petal.Width ┆ 0.2 │
│ setosa ┆ Sepal.Length ┆ 5.1 │
│ virginica ┆ Sepal.Length ┆ 5.9 │
│ setosa ┆ Sepal.Width ┆ 3.5 │
│ virginica ┆ Sepal.Width ┆ 3.0 │
│ setosa ┆ Petal.Length ┆ 1.4 │
│ virginica ┆ Petal.Length ┆ 5.1 │
│ setosa ┆ Petal.Width ┆ 0.2 │
│ virginica ┆ Petal.Width ┆ 1.8 │
│ virginica ┆ Sepal.Length ┆ 5.9 │
│ virginica ┆ Sepal.Width ┆ 3.0 │
└───────────┴──────────────┴───────┘
Split the column labels into individual columns:
>>> df.pivot_longer(
... index = 'Species',
... names_to = ('part', 'dimension'),
... names_sep = '.',
... ).select('Species','part','dimension','value')
... ).select('Species','part','dimension','value').sort(by=pl.all())
shape: (8, 4)
┌───────────┬───────┬───────────┬───────┐
│ Species ┆ part ┆ dimension ┆ value │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 │
╞═══════════╪═══════╪═══════════╪═══════╡
│ setosa ┆ Petal ┆ Length ┆ 1.4 │
│ setosa ┆ Petal ┆ Width ┆ 0.2 │
│ setosa ┆ Sepal ┆ Length ┆ 5.1 │
│ virginica ┆ Sepal ┆ Length ┆ 5.9 │
│ setosa ┆ Sepal ┆ Width ┆ 3.5 │
│ virginica ┆ Sepal ┆ Width ┆ 3.0 │
│ setosa ┆ Petal ┆ Length ┆ 1.4 │
│ virginica ┆ Petal ┆ Length ┆ 5.1 │
│ setosa ┆ Petal ┆ Width ┆ 0.2 │
│ virginica ┆ Petal ┆ Width ┆ 1.8 │
│ virginica ┆ Sepal ┆ Length ┆ 5.9 │
│ virginica ┆ Sepal ┆ Width ┆ 3.0 │
└───────────┴───────┴───────────┴───────┘
Retain parts of the column names as headers:
>>> df.pivot_longer(
... index = 'Species',
... names_to = ('part', '.value'),
... names_sep = '.',
... ).select('Species','part','Length','Width')
... ).select('Species','part','Length','Width').sort(by=pl.all())
shape: (4, 4)
┌───────────┬───────┬────────┬───────┐
│ Species ┆ part ┆ Length ┆ Width │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ f64 ┆ f64 │
╞═══════════╪═══════╪════════╪═══════╡
│ setosa ┆ Sepal ┆ 5.1 ┆ 3.5 │
│ virginica ┆ Sepal ┆ 5.9 ┆ 3.0 │
│ setosa ┆ Petal ┆ 1.4 ┆ 0.2 │
│ setosa ┆ Sepal ┆ 5.1 ┆ 3.5 │
│ virginica ┆ Petal ┆ 5.1 ┆ 1.8 │
│ virginica ┆ Sepal ┆ 5.9 ┆ 3.0 │
└───────────┴───────┴────────┴───────┘
Split the column labels based on regex:
@@ -393,7 +404,7 @@ def _pivot_longer(
df: pl.DataFrame | pl.LazyFrame,
index: ColumnNameOrSelector,
column_names: ColumnNameOrSelector,
names_to: list | tuple | str,
names_to: list | tuple | str | None,
values_to: str,
names_sep: str,
names_pattern: str,
@@ -403,6 +414,14 @@ def _pivot_longer(
Unpivots a DataFrame/LazyFrame from wide to long form.
"""

if all((names_pattern is None, names_sep is None)):
return df.melt(
id_vars=index,
value_vars=column_names,
variable_name=names_to,
value_name=values_to,
)

(
df,
index,
@@ -411,7 +430,6 @@ def _pivot_longer(
values_to,
names_sep,
names_pattern,
names_transform,
) = _data_checks_pivot_longer(
df=df,
index=index,
@@ -420,43 +438,53 @@ def _pivot_longer(
values_to=values_to,
names_sep=names_sep,
names_pattern=names_pattern,
names_transform=names_transform,
)

if not column_names:
return df

if all((names_pattern is None, names_sep is None)):
return df.melt(
id_vars=index,
value_vars=column_names,
variable_name=names_to,
value_name=values_to,
)

df = df.select(pl.col(index), pl.col(column_names))
if isinstance(names_to, str):
names_to = [names_to]

variable_name = "".join(df.columns)
variable_name = f"{variable_name}_"
spec = _pivot_longer_create_spec(
column_names=column_names,
names_to=names_to,
names_sep=names_sep,
names_pattern=names_pattern,
values_to=values_to,
names_transform=names_transform,
variable_name=variable_name,
)

return _pivot_longer_dot_value(df=df, spec=spec)
if ".value" not in names_to:
return _pivot_longer_no_dot_value(
df=df,
index=index,
spec=spec,
column_names=column_names,
names_to=names_to,
values_to=values_to,
variable_name=variable_name,
names_transform=names_transform,
)

if {".name", ".value"}.symmetric_difference(spec.columns):
dot_value_only = False
else:
dot_value_only = True
expression = pl.cum_count(".value").over(".value").alias(variable_name)
spec = spec.with_columns(expression)

return _pivot_longer_dot_value(
df=df,
index=index,
spec=spec,
variable_name=variable_name,
dot_value_only=dot_value_only,
names_transform=names_transform,
)


def _pivot_longer_create_spec(
column_names: Iterable,
names_to: Iterable,
column_names: list,
names_to: list,
names_sep: str | None,
names_pattern: str | None,
values_to: str,
names_transform: pl.Expr,
variable_name: str,
) -> pl.DataFrame:
"""
This is where the spec DataFrame is created,
@@ -468,16 +496,16 @@ def _pivot_longer_create_spec(
pl.col(".name")
.str.split(by=names_sep)
.list.to_struct(n_field_strategy="max_width")
.alias("extract")
.alias(variable_name)
)
else:
expression = (
pl.col(".name")
.str.extract_groups(pattern=names_pattern)
.alias("extract")
.alias(variable_name)
)
spec = spec.with_columns(expression)
len_fields = len(spec.get_column("extract").struct.fields)
len_fields = len(spec.get_column(variable_name).struct.fields)
len_names_to = len(names_to)

if len_names_to != len_fields:
@@ -492,7 +520,7 @@ def _pivot_longer_create_spec(
expression = pl.exclude(".name").is_null().any()
expression = pl.any_horizontal(expression)
null_check = (
spec.unnest(columns="extract")
spec.unnest(columns=variable_name)
.filter(expression)
.get_column(".name")
)
@@ -504,112 +532,132 @@ def _pivot_longer_create_spec(
"in the provided regex. Kindly provide a regular expression "
"(with the correct groups) that matches all labels in the columns."
)
if names_to.count(".value") < 2:
expression = pl.col("extract").struct.rename_fields(names=names_to)
spec = spec.with_columns(expression).unnest(columns="extract")
else:
spec = _squash_multiple_dot_value(spec=spec, names_to=names_to)

if ".value" not in names_to:
expression = pl.lit(value=values_to).alias(".value")
spec = spec.with_columns(expression)
spec = spec.get_column(variable_name)
spec = spec.struct.rename_fields(names=names_to)
return spec
if names_to.count(".value") == 1:
spec = spec.with_columns(
pl.col(variable_name).struct.rename_fields(names=names_to)
)
not_dot_value = [name for name in names_to if name != ".value"]
spec = spec.unnest(variable_name)
if not_dot_value:
return spec.select(
".name",
".value",
pl.struct(not_dot_value).alias(variable_name),
)
return spec.select(".name", ".value")
_spec = spec.get_column(variable_name)
_spec = _spec.struct.unnest()
fields = _spec.columns

if len(set(names_to)) == 1:
expression = pl.concat_str(fields).alias(".value")
dot_value = _spec.select(expression)
dot_value = dot_value.to_series(0)
return spec.select(".name", dot_value)
dot_value = [
field for field, label in zip(fields, names_to) if label == ".value"
]
dot_value = pl.concat_str(dot_value).alias(".value")
not_dot_value = [
pl.col(field).alias(label)
for field, label in zip(fields, names_to)
if label != ".value"
]
not_dot_value = pl.struct(not_dot_value).alias(variable_name)
return _spec.select(spec.get_column(".name"), not_dot_value, dot_value)
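For orientation, this is roughly how the spec's struct column is built when `names_pattern` is supplied, sketched here with invented labels: `str.extract_groups` returns one struct per label, and its fields are then renamed to `names_to` before unnesting:

import polars as pl

spec = pl.DataFrame({".name": ["Sepal.Length", "Petal.Width"]})
extracted = spec.with_columns(
    pl.col(".name").str.extract_groups(r"(.+)\.(.+)").alias("extract")
)
renamed = extracted.with_columns(
    pl.col("extract").struct.rename_fields(["part", "dimension"])
)
print(renamed.unnest("extract"))  # columns: .name, part, dimension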


spec = spec.select(
pl.col([".name", ".value"]), pl.exclude([".name", ".value"])
def _pivot_longer_no_dot_value(
df: pl.DataFrame | pl.LazyFrame,
spec: pl.DataFrame,
index: ColumnNameOrSelector,
column_names: ColumnNameOrSelector,
names_to: list | tuple,
values_to: str,
variable_name: str,
names_transform: pl.Expr,
) -> pl.DataFrame | pl.LazyFrame:
"""
flip polars Frame to long form,
if no .value in names_to.
"""
# the implode/explode approach is used here
# for efficiency
# do the operation on a smaller size
# and then blow it up after
# it is usually much faster
# than running on the actual data
outcome = (
df.select(pl.all().implode())
.melt(
id_vars=index,
value_vars=column_names,
variable_name=variable_name,
value_name=values_to,
)
.with_columns(spec)
)

outcome = outcome.unnest(variable_name)
if names_transform is not None:
spec = spec.with_columns(names_transform)
return spec
outcome = outcome.with_columns(names_transform)
columns = [name for name in outcome.columns if name not in names_to]
outcome = outcome.explode(columns=columns)
return outcome
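The implode/explode trick described in the comments can be seen in isolation below, on hypothetical data: after imploding, `melt` runs over one list-valued row per column instead of the full frame height, and `explode` blows the result back up to long form:

import polars as pl

df = pl.DataFrame({"id": [1, 2], "x_1": [3, 4], "x_2": [5, 6]})

out = (
    df.select(pl.all().implode())  # each column collapses to a single list row
    .melt(id_vars="id", value_vars=["x_1", "x_2"])  # melt sees 2 rows, not 2 * height
    .explode(["id", "value"])  # expand the lists back to long form
)
print(out)  # id: [1, 2, 1, 2], variable: x_1/x_1/x_2/x_2, value: [3, 4, 5, 6]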


def _pivot_longer_dot_value(
df: pl.DataFrame | pl.LazyFrame, spec: pl.DataFrame
df: pl.DataFrame | pl.LazyFrame,
spec: pl.DataFrame,
index: ColumnNameOrSelector,
variable_name: str,
dot_value_only: bool,
names_transform: pl.Expr,
) -> pl.DataFrame | pl.LazyFrame:
"""
Reshape DataFrame to long form based on metadata in `spec`.
flip polars Frame to long form,
if names_sep and .value in names_to.
"""
index = [column for column in df.columns if column not in spec[".name"]]
not_dot_value = [
column for column in spec.columns if column not in {".name", ".value"}
]
idx = "".join(spec.columns)
if not_dot_value:
# assign a number to each group (grouped by not_dot_value)
expression = pl.first(idx).over(not_dot_value).rank("dense").sub(1)
spec = spec.with_row_index(name=idx).with_columns(expression)
else:
# use a cumulative count to properly pair the columns
# grouped by .value
expression = pl.cum_count(".value").over(".value").alias(idx)
spec = spec.with_columns(expression)
mapping = defaultdict(list)
for position, column_name, replacement_name in zip(
spec.get_column(name=idx),
spec.get_column(name=".name"),
spec.get_column(name=".value"),
spec = spec.group_by(variable_name)
spec = spec.agg(pl.all())
expressions = []
for names, fields in zip(
spec.get_column(".name").to_list(),
spec.get_column(".value").to_list(),
):
expression = pl.col(column_name).alias(replacement_name)
mapping[position].append(expression)

mapping = (
(
[
*index,
*columns_to_select,
],
pl.lit(position, dtype=pl.UInt32).alias(idx),
)
for position, columns_to_select in mapping.items()
expression = pl.struct(names).struct.rename_fields(names=fields)
expressions.append(expression)
expressions = [*index, *expressions]
spec = spec.get_column(variable_name)
outcome = (
df.select(expressions)
.select(pl.all().implode())
.melt(id_vars=index, variable_name=variable_name, value_name=".value")
.with_columns(spec)
)
df = [
df.select(columns_to_select).with_columns(position)
for columns_to_select, position in mapping
]
# rechunking can be expensive;
# however subsequent operations are faster
# since data is contiguous in memory
df = pl.concat(df, how="diagonal_relaxed", rechunk=True)
expression = pl.cum_count(".value").over(".value").eq(1)
dot_value = spec.filter(expression).select(".value")
columns_to_select = [*index, *dot_value.to_series(0)]
if not_dot_value:
if isinstance(df, pl.LazyFrame):
ranges = df.select(idx).collect().get_column(idx)
else:
ranges = df.get_column(idx)
spec = spec.select(pl.struct(not_dot_value))
_value = spec.columns[0]
expression = pl.cum_count(_value).over(_value).eq(1)
# using a gather approach, instead of a join
# offers more performance - not sure why
# maybe in the join there is another rechunking?
spec = spec.filter(expression).select(pl.col(_value).gather(ranges))
df = df.with_columns(spec).unnest(_value)
columns_to_select.extend(not_dot_value)
return df.select(columns_to_select)


def _squash_multiple_dot_value(
spec: pl.DataFrame, names_to: Iterable
) -> pl.DataFrame:
"""
Combine multiple .values into a single .value column
"""
extract = spec.get_column("extract")
fields = extract.struct.fields
dot_value = [
field for field, label in zip(fields, names_to) if label == ".value"
]
dot_value = pl.concat_str(dot_value).alias(".value")
not_dot_value = [
pl.col(field).alias(label)
for field, label in zip(fields, names_to)
if label != ".value"

if dot_value_only:
columns = [
label for label in outcome.columns if label != variable_name
]
outcome = outcome.explode(columns).unnest(".value")
outcome = outcome.select(pl.exclude(variable_name))
return outcome
outcome = outcome.unnest(variable_name)
if names_transform is not None:
outcome = outcome.with_columns(names_transform)
columns = [
label for label in outcome.columns if label not in spec.struct.fields
]
select_expr = [".name", dot_value]
if not_dot_value:
select_expr.extend(not_dot_value)
outcome = outcome.explode(columns)
outcome = outcome.unnest(".value")

return spec.unnest("extract").select(select_expr)
return outcome
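The struct-based pairing used for the `.value` path can likewise be sketched standalone, with invented column names: columns that belong in the same output row are packed into one struct whose fields are renamed to the `.value` labels, so a single melt plus `unnest` yields the paired long form:

import polars as pl

df = pl.DataFrame(
    {"id": [1, 2], "ht1": [10, 11], "wt1": [5, 6], "ht2": [12, 13], "wt2": [7, 8]}
)

# one struct per pairing, mirroring what the grouped spec produces
pairs = [
    pl.struct("ht1", "wt1").struct.rename_fields(["ht", "wt"]).alias("1"),
    pl.struct("ht2", "wt2").struct.rename_fields(["ht", "wt"]).alias("2"),
]
out = (
    df.select("id", *pairs)
    .select(pl.all().implode())
    .melt(id_vars="id", variable_name="age", value_name=".value")
    .explode(["id", ".value"])
    .unnest(".value")
)
print(out)  # columns: id, age, ht, wt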


def _data_checks_pivot_longer(
@@ -620,7 +668,6 @@ def _data_checks_pivot_longer(
values_to,
names_sep,
names_pattern,
names_transform,
) -> tuple:
"""
This function majorly does type checks on the passed arguments.
@@ -630,57 +677,24 @@ def _data_checks_pivot_longer(
Type annotations are not provided because this function is where type
checking happens.
"""

def _check_type(arg_name: str, arg_value: Any):
"""
Raise if argument is not a valid type
"""

def _check_type_single(entry):
if (
not isinstance(entry, str)
and not cs.is_selector(entry)
and not isinstance(entry, pl.Expr)
):
raise TypeError(
f"The argument passed to the {arg_name} parameter "
"should be a type that is supported in the polars' "
"select function."
)

if isinstance(arg_value, (list, tuple)):
for entry in arg_value:
_check_type_single(entry=entry)
else:
_check_type_single(entry=arg_value)

if (index is None) and (column_names is None):
column_names = df.columns
index = []
elif (index is not None) and (column_names is not None):
_check_type(arg_name="index", arg_value=index)
index = df.select(index).columns
_check_type(arg_name="column_names", arg_value=column_names)
column_names = df.select(column_names).columns

elif (index is None) and (column_names is not None):
_check_type(arg_name="column_names", arg_value=column_names)
column_names = df.select(column_names).columns
index = df.select(pl.exclude(column_names)).columns

elif (index is not None) and (column_names is None):
_check_type(arg_name="index", arg_value=index)
index = df.select(index).columns
column_names = df.select(pl.exclude(index)).columns

check("names_to", names_to, [list, tuple, str])
if isinstance(names_to, (list, tuple)):
if isinstance(names_to, str):
names_to = [names_to]
elif isinstance(names_to, (list, tuple)):
uniques = set()
for word in names_to:
check(f"'{word}' in names_to", word, [str])
if not isinstance(word, str):
raise TypeError(
f"'{word}' in names_to should be a string type; "
f"instead got type {type(word).__name__}"
)
if (word in uniques) and (word != ".value"):
raise ValueError(f"'{word}' is duplicated in names_to.")
uniques.add(word)
else:
raise TypeError(
"names_to should be a string, list, or tuple; "
f"instead got type {type(names_to).__name__}"
)

if names_sep and names_pattern:
raise ValueError(
@@ -690,11 +704,24 @@ def _check_type_single(entry):
if names_sep is not None:
check("names_sep", names_sep, [str])

if names_pattern is not None:
else:
check("names_pattern", names_pattern, [str])

check("values_to", values_to, [str])

if (index is None) and (column_names is None):
column_names = df.columns
index = []
elif (index is None) and (column_names is not None):
column_names = df.select(column_names).columns
index = df.select(pl.exclude(column_names)).columns
elif (index is not None) and (column_names is None):
index = df.select(index).columns
column_names = df.select(pl.exclude(index)).columns
else:
index = df.select(index).columns
column_names = df.select(column_names).columns

return (
df,
index,
@@ -703,5 +730,4 @@ def _check_type_single(entry):
values_to,
names_sep,
names_pattern,
names_transform,
)
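Because the function is registered through both `register_dataframe_method` and `register_lazyframe_method`, the same call should also work lazily. A hedged end-to-end sketch mirroring the doctest data above:

import polars as pl
import janitor.polars  # noqa: F401 -- registers pivot_longer on (Lazy)Frames

lf = pl.LazyFrame(
    {
        "Sepal.Length": [5.1, 5.9],
        "Petal.Length": [1.4, 5.1],
        "Sepal.Width": [3.5, 3.0],
        "Petal.Width": [0.2, 1.8],
        "Species": ["setosa", "virginica"],
    }
)
out = lf.pivot_longer(
    index="Species",
    names_to=("part", ".value"),
    names_sep=".",
).collect()
print(out.sort(by=pl.all()))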
133 changes: 42 additions & 91 deletions tests/polars/functions/test_pivot_longer_polars.py
@@ -19,25 +19,9 @@ def df_checks():
)


def test_type_index(df_checks):
"""Raise TypeError if wrong type is provided for the index."""
msg = "The argument passed to the index parameter "
msg += "should be a type that is supported in the.+"
with pytest.raises(TypeError, match=msg):
df_checks.pivot_longer(index=2007, names_sep="_")


def test_type_column_names(df_checks):
"""Raise TypeError if wrong type is provided for column_names."""
msg = "The argument passed to the column_names parameter "
msg += "should be a type that is supported in the.+"
with pytest.raises(TypeError, match=msg):
df_checks.pivot_longer(column_names=2007, names_sep="_")


def test_type_names_to(df_checks):
"""Raise TypeError if wrong type is provided for names_to."""
msg = "names_to should be one of .+"
msg = "names_to should be a string, list, or tuple.+"
with pytest.raises(TypeError, match=msg):
df_checks.pivot_longer(names_to=2007, names_sep="_")

@@ -90,38 +74,6 @@ def test_values_to_wrong_type(df_checks):
df_checks.pivot_longer(values_to={"salvo"}, names_sep="_")


def test_pivot_index_only(df_checks):
"""Test output if only index is passed."""
result = df_checks.pivot_longer(
index=["famid", "birth"],
names_to="dim",
values_to="num",
)

actual = df_checks.melt(
id_vars=["famid", "birth"], variable_name="dim", value_name="num"
)

assert_frame_equal(result, actual, check_column_order=False)


def test_pivot_column_only(df_checks):
"""Test output if only column_names is passed."""
result = df_checks.pivot_longer(
column_names=["ht1", "ht2"],
names_to="dim",
values_to="num",
)

actual = df_checks.melt(
id_vars=["famid", "birth"],
variable_name="dim",
value_name="num",
)

assert_frame_equal(result, actual, check_column_order=False)


def test_names_to_names_pattern_len(df_checks):
"""
Raise ValueError
@@ -161,12 +113,16 @@ def test_names_pat_str(df_checks):
Test output when names_pattern is a string,
and .value is present.
"""
result = df_checks.pivot_longer(
column_names=cs.starts_with("ht"),
names_to=(".value", "age"),
names_pattern="(.+)(.)",
names_transform=pl.col("age").cast(pl.Int64),
).sort(by=pl.all())
result = (
df_checks.pivot_longer(
index=["famid", "birth"],
names_to=(".value", "age"),
names_pattern="(.+)(.)",
names_transform=pl.col("age").cast(pl.Int64),
)
.select("famid", "birth", "age", "ht")
.sort(by=pl.all())
)

actual = [
{"famid": 1, "birth": 1, "age": 1, "ht": 2.8},
@@ -190,20 +146,7 @@ def test_names_pat_str(df_checks):
]
actual = pl.DataFrame(actual).sort(by=pl.all())

assert_frame_equal(
result, actual, check_dtype=False, check_column_order=False
)


def test_no_column_names(df_checks):
"""
Test output if all the columns
are assigned to the index parameter.
"""
assert_frame_equal(
df_checks.pivot_longer(index=pl.all()),
df_checks,
)
assert_frame_equal(result, actual)


@pytest.fixture
@@ -310,23 +253,31 @@ def test_df():
def test_names_pattern_dot_value(test_df):
"""Test output for names_pattern and .value."""

result = test_df.pivot_longer(
column_names=pl.all(),
names_to=["set", ".value"],
names_pattern="(.+)_(.+)",
).sort(by=["loc", "lat", "long"])
assert_frame_equal(result, actual, check_column_order=False)
result = (
test_df.pivot_longer(
column_names=cs.all(),
names_to=["set", ".value"],
names_pattern="(.+)_(.+)",
)
.sort(by=["loc", "lat", "long"])
.select("set", "loc", "lat", "long")
)
assert_frame_equal(result, actual)


def test_names_sep_dot_value(test_df):
"""Test output for names_sep and .value."""

result = test_df.pivot_longer(
column_names=pl.all(),
names_to=["set", ".value"],
names_sep="_",
).sort(by=["loc", "lat", "long"])
assert_frame_equal(result, actual, check_column_order=False)
result = (
test_df.pivot_longer(
column_names=cs.all(),
names_to=["set", ".value"],
names_sep="_",
)
.sort(by=["loc", "lat", "long"])
.select("set", "loc", "lat", "long")
)
assert_frame_equal(result, actual)


@pytest.fixture
@@ -388,7 +339,7 @@ def test_not_dot_value_sep2(not_dot_value):
"country", variable_name="event", value_name="score"
)

assert_frame_equal(result, actual, check_column_order=False)
assert_frame_equal(result, actual)


def test_not_dot_value_pattern(not_dot_value):
@@ -460,7 +411,7 @@ def test_multiple_dot_value():

actual = pl.DataFrame(actual).sort(by=pl.all())

assert_frame_equal(result, actual, check_column_order=False)
assert_frame_equal(result, actual)


@pytest.fixture
@@ -482,7 +433,7 @@ def test_multiple_dot_value2(single_val):
index="id", names_to=(".value", ".value"), names_pattern="(.)(.)"
)

assert_frame_equal(result, single_val, check_column_order=False)
assert_frame_equal(result, single_val)


actual3 = [
@@ -506,7 +457,7 @@ def test_names_pattern_single_column(single_val):
"id", names_to=".value", names_pattern="(.)."
)

assert_frame_equal(result, actual3, check_column_order=False)
assert_frame_equal(result.sort(by=pl.all()), actual3.sort(by=pl.all()))


def test_names_pattern_single_column_not_dot_value(single_val):
@@ -515,27 +466,27 @@ def test_names_pattern_single_column_not_dot_value(single_val):
"""
result = single_val.pivot_longer(
index="id", column_names="x1", names_to="yA", names_pattern="(.+)"
)
).select("id", "yA", "value")

assert_frame_equal(
result,
single_val.melt(id_vars="id", value_vars="x1", variable_name="yA"),
check_column_order=False,
)


def test_names_pattern_single_column_not_dot_value1(single_val):
"""
Test output if names_to is not '.value'.
"""
result = single_val.select("x1").pivot_longer(
names_to="yA", names_pattern="(.+)"
result = (
single_val.select("x1")
.pivot_longer(names_to="yA", names_pattern="(.+)")
.select("yA", "value")
)

assert_frame_equal(
result,
single_val.select("x1").melt(variable_name="yA"),
check_column_order=False,
)


@@ -592,4 +543,4 @@ def test_names_pattern_nulls_in_data(df_null):

actual = pl.DataFrame(actual).sort(by=pl.all())

assert_frame_equal(result, actual, check_column_order=False)
assert_frame_equal(result, actual)