Skip to content

Commit 84f9d25

Browse files
author
samuel.oranyeli
committed
update for polars row to names
1 parent 83bb7c0 commit 84f9d25

File tree

4 files changed

+93
-90
lines changed

4 files changed

+93
-90
lines changed

janitor/functions/row_to_names.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def _row_to_names_dispatch( # noqa: F811
165165
len_df = len(df_)
166166
arrays = [arr._values for _, arr in df_.items()]
167167
if remove_rows_above and remove_rows:
168-
indexer = np.arange(row_numbers.stop + 1, len_df)
168+
indexer = np.arange(row_numbers.stop, len_df)
169169
elif remove_rows_above:
170170
indexer = np.arange(row_numbers.start, len_df)
171171
elif remove_rows:

janitor/polars/row_to_names.py

+76-69
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
from functools import singledispatch
6+
57
from janitor.utils import check, import_message
68

79
from .polars_flavor import register_dataframe_method
@@ -28,8 +30,6 @@ def row_to_names(
2830
"""
2931
Elevates a row, or rows, to be the column names of a DataFrame.
3032
31-
For a LazyFrame, the user should materialize into a DataFrame before using `row_to_names`.
32-
3333
Examples:
3434
Replace column names with the first row.
3535
@@ -103,8 +103,7 @@ def row_to_names(
103103
104104
Args:
105105
row_numbers: Position of the row(s) containing the variable names.
106-
Note that indexing starts from 0. It can also be a list/slice.
107-
Defaults to 0 (first row).
106+
It can be an integer, list or a slice.
108107
remove_rows: Whether the row(s) should be removed from the DataFrame.
109108
remove_rows_above: Whether the row(s) above the selected row should
110109
be removed from the DataFrame.
@@ -115,85 +114,93 @@ def row_to_names(
115114
A polars DataFrame.
116115
""" # noqa: E501
117116
return _row_to_names(
117+
row_numbers,
118118
df=df,
119-
row_numbers=row_numbers,
120119
remove_rows=remove_rows,
121120
remove_rows_above=remove_rows_above,
122121
separator=separator,
123122
)
124123

125124

125+
@singledispatch
126126
def _row_to_names(
127-
df: pl.DataFrame,
128-
row_numbers: int | list | slice,
129-
remove_rows: bool,
130-
remove_rows_above: bool,
131-
separator: str,
127+
row_numbers, df, remove_rows, remove_rows_above, separator
132128
) -> pl.DataFrame:
133129
"""
134-
Function to convert rows in the DataFrame to column names.
130+
Base function for row_to_names.
135131
"""
136-
check("separator", separator, [str])
137-
if isinstance(row_numbers, int):
138-
row_numbers = slice(row_numbers, row_numbers + 1)
139-
elif isinstance(row_numbers, slice):
140-
if row_numbers.step is not None:
141-
raise ValueError(
142-
"The step argument for slice is not supported in row_to_names."
143-
)
144-
elif isinstance(row_numbers, list):
145-
for entry in row_numbers:
146-
check("entry in the row_numbers argument", entry, [int])
147-
else:
148-
raise TypeError(
149-
"row_numbers should be either an integer, "
150-
"a slice or a list; "
151-
f"instead got type {type(row_numbers).__name__}"
132+
raise TypeError(
133+
"row_numbers should be either an integer, "
134+
"a slice or a list; "
135+
f"instead got type {type(row_numbers).__name__}"
136+
)
137+
138+
139+
@_row_to_names.register(int) # noqa: F811
140+
def _row_to_names_dispatch( # noqa: F811
141+
row_numbers, df, remove_rows, remove_rows_above, separator
142+
):
143+
expression = pl.col("*").cast(pl.String).gather(row_numbers)
144+
expression = pl.struct(expression)
145+
headers = df.select(expression).to_series(0).to_list()[0]
146+
df = df.rename(mapping=headers)
147+
if remove_rows_above and remove_rows:
148+
return df.slice(row_numbers + 1)
149+
elif remove_rows_above:
150+
return df.slice(row_numbers)
151+
elif remove_rows:
152+
expression = pl.int_range(pl.len()).ne(row_numbers)
153+
return df.filter(expression)
154+
return df
155+
156+
157+
@_row_to_names.register(slice) # noqa: F811
158+
def _row_to_names_dispatch( # noqa: F811
159+
row_numbers, df, remove_rows, remove_rows_above, separator
160+
):
161+
if row_numbers.step is not None:
162+
raise ValueError(
163+
"The step argument for slice is not supported in row_to_names."
152164
)
153-
is_a_slice = isinstance(row_numbers, slice)
154-
if is_a_slice:
155-
expression = pl.all().str.concat(delimiter=separator)
156-
expression = pl.struct(expression)
157-
offset = row_numbers.start
158-
length = row_numbers.stop - row_numbers.start
159-
mapping = df.slice(
160-
offset=offset,
161-
length=length,
165+
headers = df.slice(row_numbers.start, row_numbers.stop - row_numbers.start)
166+
headers = headers.cast(pl.String)
167+
expression = pl.all().str.concat(delimiter=separator)
168+
expression = pl.struct(expression)
169+
headers = headers.select(expression).to_series(0).to_list()[0]
170+
df = df.rename(mapping=headers)
171+
if remove_rows_above and remove_rows:
172+
return df.slice(row_numbers.stop)
173+
elif remove_rows_above:
174+
return df.slice(row_numbers.start)
175+
elif remove_rows:
176+
expression = pl.int_range(pl.len()).is_between(
177+
row_numbers.start, row_numbers.stop, closed="left"
162178
)
163-
mapping = mapping.select(expression)
164-
else:
165-
expression = pl.all().gather(row_numbers)
166-
expression = expression.str.concat(delimiter=separator)
167-
expression = pl.struct(expression)
168-
mapping = df.select(expression)
169-
170-
mapping = mapping.to_series(0)[0]
171-
df = df.rename(mapping=mapping)
172-
if remove_rows_above:
173-
if not is_a_slice:
174-
raise ValueError(
175-
"The remove_rows_above argument is applicable "
176-
"only if the row_numbers argument is an integer "
177-
"or a slice."
178-
)
179-
if remove_rows:
180-
return df.slice(offset=row_numbers.stop)
181-
return df.slice(offset=row_numbers.start)
179+
return df.filter(~expression)
180+
return df
182181

183-
if remove_rows:
184-
if is_a_slice:
185-
df = [
186-
df.slice(offset=0, length=row_numbers.start),
187-
df.slice(offset=row_numbers.stop),
188-
]
189-
return pl.concat(df, rechunk=True)
190-
name = "".join(df.columns)
191-
name = f"{name}_"
192-
df = (
193-
df.with_row_index(name=name)
194-
.filter(pl.col(name=name).is_in(row_numbers).not_())
195-
.select(pl.exclude(name))
182+
183+
@_row_to_names.register(list) # noqa: F811
184+
def _row_to_names_dispatch( # noqa: F811
185+
row_numbers, df, remove_rows, remove_rows_above, separator
186+
):
187+
if remove_rows_above:
188+
raise ValueError(
189+
"The remove_rows_above argument is applicable "
190+
"only if the row_numbers argument is an integer "
191+
"or a slice."
196192
)
197-
return df
198193

194+
for entry in row_numbers:
195+
check("entry in the row_numbers argument", entry, [int])
196+
197+
expression = pl.col("*").gather(row_numbers)
198+
headers = df.select(expression).cast(pl.String)
199+
expression = pl.all().str.concat(delimiter=separator)
200+
expression = pl.struct(expression)
201+
headers = headers.select(expression).to_series(0).to_list()[0]
202+
df = df.rename(mapping=headers)
203+
if remove_rows:
204+
expression = pl.int_range(pl.len()).is_in(row_numbers)
205+
return df.filter(~expression)
199206
return df

tests/functions/test_row_to_names.py

+16
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,22 @@ def test_row_to_names_delete_above_slice(dataframe):
7474
assert df.iloc[0, 4] == "Basel"
7575

7676

77+
@pytest.mark.functions
78+
def test_row_to_names_delete_above_delete_rows(dataframe):
79+
"""
80+
Test output for remove_rows=True
81+
and remove_rows_above=True
82+
"""
83+
df = dataframe.row_to_names(
84+
slice(2, 4), remove_rows=True, remove_rows_above=True
85+
)
86+
assert df.iloc[0, 0] == 2
87+
assert df.iloc[0, 1] == 2.456234
88+
assert df.iloc[0, 2] == 2
89+
assert df.iloc[0, 3] == "leopard"
90+
assert df.iloc[0, 4] == "Shanghai"
91+
92+
7793
@pytest.mark.functions
7894
def test_row_to_names_delete_above_is_a_list(dataframe):
7995
"Raise if row_numbers is a list"

tests/polars/functions/test_row_to_names_polars.py

-20
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,6 @@ def df():
1717
)
1818

1919

20-
def test_separator_type(df):
21-
"""
22-
Raise if separator is not a string
23-
"""
24-
with pytest.raises(TypeError, match="separator should be.+"):
25-
df.row_to_names([1, 2], separator=1)
26-
27-
2820
def test_row_numbers_type(df):
2921
"""
3022
Raise if row_numbers is not an int/slice/list
@@ -88,8 +80,6 @@ def test_row_to_names_list(df):
8880

8981
def test_row_to_names_delete_this_row(df):
9082
df = df.row_to_names(2, remove_rows=True)
91-
if isinstance(df, pl.LazyFrame):
92-
df = df.collect()
9383
assert df.to_series(0)[0] == 1.234_523_45
9484
assert df.to_series(1)[0] == 1
9585
assert df.to_series(2)[0] == "rabbit"
@@ -98,8 +88,6 @@ def test_row_to_names_delete_this_row(df):
9888

9989
def test_row_to_names_list_delete_this_row(df):
10090
df = df.row_to_names([2], remove_rows=True)
101-
if isinstance(df, pl.LazyFrame):
102-
df = df.collect()
10391
assert df.to_series(0)[0] == 1.234_523_45
10492
assert df.to_series(1)[0] == 1
10593
assert df.to_series(2)[0] == "rabbit"
@@ -108,8 +96,6 @@ def test_row_to_names_list_delete_this_row(df):
10896

10997
def test_row_to_names_delete_above(df):
11098
df = df.row_to_names(2, remove_rows_above=True)
111-
if isinstance(df, pl.LazyFrame):
112-
df = df.collect()
11399
assert df.to_series(0)[0] == 3.234_612_5
114100
assert df.to_series(1)[0] == 3
115101
assert df.to_series(2)[0] == "lion"
@@ -119,8 +105,6 @@ def test_row_to_names_delete_above(df):
119105
def test_row_to_names_delete_above_list(df):
120106
"Test output if row_numbers is a list"
121107
df = df.row_to_names(slice(2, 4), remove_rows_above=True)
122-
if isinstance(df, pl.LazyFrame):
123-
df = df.collect()
124108
assert df.to_series(0)[0] == 3.234_612_5
125109
assert df.to_series(1)[0] == 3
126110
assert df.to_series(2)[0] == "lion"
@@ -133,8 +117,6 @@ def test_row_to_names_delete_above_delete_rows(df):
133117
and remove_rows_above=True
134118
"""
135119
df = df.row_to_names(slice(2, 4), remove_rows=True, remove_rows_above=True)
136-
if isinstance(df, pl.LazyFrame):
137-
df = df.collect()
138120
assert df.to_series(0)[0] == 2.456234
139121
assert df.to_series(1)[0] == 2
140122
assert df.to_series(2)[0] == "leopard"
@@ -147,8 +129,6 @@ def test_row_to_names_delete_above_delete_rows_scalar(df):
147129
and remove_rows_above=True
148130
"""
149131
df = df.row_to_names(2, remove_rows=True, remove_rows_above=True)
150-
if isinstance(df, pl.LazyFrame):
151-
df = df.collect()
152132
assert df.to_series(0)[0] == 1.23452345
153133
assert df.to_series(1)[0] == 1
154134
assert df.to_series(2)[0] == "rabbit"

0 commit comments

Comments
 (0)