Skip to content

Commit 383fad1

Browse files
authored
[ENH] Add support for extension arrays to expand_grid (pyjanitor-devs#1122)
1 parent 8327ec8 commit 383fad1

File tree

5 files changed

+33
-38
lines changed

5 files changed

+33
-38
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
- [ENH] `min_max_scale` drop `old_min` and `old_max` to fit sklearn's method API. Issue #1068 @Zeroto521
1111
- [ENH] Add `jointly` option for `min_max_scale` support to transform each column values or entire values. Default transform each column, similar behavior to `sklearn.preprocessing.MinMaxScaler`. (Issue #1067, PR #1112, PR #1123) @Zeroto521
1212
- [INF] Require pyspark minimal version is v3.2.0 to cut duplicates codes. Issue #1110 @Zeroto521
13+
- [ENH] Added support for extension arrays in `expand_grid`. Issue #1121 @samukweku
1314

1415

1516
## [v0.23.1] - 2022-05-03

examples/notebooks/expand_grid.ipynb

+3-5
Original file line numberDiff line numberDiff line change
@@ -1108,11 +1108,9 @@
11081108
}
11091109
],
11101110
"metadata": {
1111-
"interpreter": {
1112-
"hash": "98b0a9b7b4eaaa670588a142fd0a9b87eaafe866f1db4228be72b4211d12040f"
1113-
},
11141111
"kernelspec": {
1115-
"display_name": "Python 3.8.10 64-bit ('base': conda)",
1112+
"display_name": "Python 3.9.10 ('base')",
1113+
"language": "python",
11161114
"name": "python3"
11171115
},
11181116
"language_info": {
@@ -1125,7 +1123,7 @@
11251123
"name": "python",
11261124
"nbconvert_exporter": "python",
11271125
"pygments_lexer": "ipython3",
1128-
"version": "3.9.9"
1126+
"version": "3.9.10"
11291127
},
11301128
"orig_nbformat": 2
11311129
},

janitor/functions/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def _computations_expand_grid(others: dict) -> pd.DataFrame:
188188
grid = ((*left, right) for left, right in grid)
189189
contents = {}
190190
for key, value, grid_index in grid:
191-
contents = {**contents, **_expand_grid(value, grid_index, key)}
191+
contents.update(_expand_grid(value, grid_index, key))
192192
return pd.DataFrame(contents, copy=False)
193193

194194

janitor/utils.py

+9-32
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import numpy as np
1111
import pandas as pd
12-
from pandas.core.construction import extract_array
1312

1413

1514
def check(varname: str, value, expected_types: list):
@@ -57,7 +56,6 @@ def _expand_grid(value, grid_index, key):
5756
def _sub_expand_grid(value, grid_index, key): # noqa: F811
5857
"""
5958
Expands the numpy array based on `grid_index`.
60-
6159
Returns a dictionary.
6260
"""
6361

@@ -73,43 +71,37 @@ def _sub_expand_grid(value, grid_index, key): # noqa: F811
7371
if value.ndim == 1:
7472
return {(key, 0): value}
7573

76-
return {(key, num): arr for num, arr in enumerate(value.T)}
74+
return {(key, num): value[:, num] for num in range(value.shape[-1])}
7775

7876

77+
@_expand_grid.register(pd.api.extensions.ExtensionArray)
7978
@_expand_grid.register(pd.arrays.PandasArray)
8079
def _sub_expand_grid(value, grid_index, key): # noqa: F811
8180
"""
8281
Expands the pandas array based on `grid_index`.
83-
8482
Returns a dictionary.
8583
"""
8684

87-
value = value[grid_index]
88-
89-
return {(key, 0): value}
85+
return {(key, 0): value[grid_index]}
9086

9187

88+
@_expand_grid.register(pd.Index)
9289
@_expand_grid.register(pd.Series)
9390
def _sub_expand_grid(value, grid_index, key): # noqa: F811
9491
"""
95-
Expands the Series based on `grid_index`.
96-
92+
Expands the pd.Series/pd.Index based on `grid_index`.
9793
Returns a dictionary.
9894
"""
9995

100-
name = value.name
101-
if not name:
102-
name = 0
103-
value = extract_array(value, extract_numpy=True)[grid_index]
96+
name = value.name or 0
10497

105-
return {(key, name): value}
98+
return {(key, name): value._values[grid_index]}
10699

107100

108101
@_expand_grid.register(pd.DataFrame)
109102
def _sub_expand_grid(value, grid_index, key): # noqa: F811
110103
"""
111104
Expands the DataFrame based on `grid_index`.
112-
113105
Returns a dictionary.
114106
"""
115107

@@ -120,16 +112,14 @@ def _sub_expand_grid(value, grid_index, key): # noqa: F811
120112
value = value.set_axis(columns, axis="columns")
121113

122114
return {
123-
(key, name): extract_array(val, extract_numpy=True)[grid_index]
124-
for name, val in value.items()
115+
(key, name): val._values[grid_index] for name, val in value.items()
125116
}
126117

127118

128119
@_expand_grid.register(pd.MultiIndex)
129120
def _sub_expand_grid(value, grid_index, key): # noqa: F811
130121
"""
131122
Expands the MultiIndex based on `grid_index`.
132-
133123
Returns a dictionary.
134124
"""
135125

@@ -138,27 +128,14 @@ def _sub_expand_grid(value, grid_index, key): # noqa: F811
138128
for n in range(value.nlevels):
139129
arr = value.get_level_values(n)
140130
name = arr.name
141-
arr = extract_array(arr, extract_numpy=True)[grid_index]
131+
arr = arr._values[grid_index]
142132
if not name:
143133
name = num
144134
num += 1
145135
contents[(key, name)] = arr
146136
return contents
147137

148138

149-
@_expand_grid.register(pd.Index)
150-
def _sub_expand_grid(value, grid_index, key): # noqa: F811
151-
"""
152-
Expands the Index based on `grid_index`.
153-
154-
Returns a dictionary.
155-
"""
156-
name = value.name
157-
if not name:
158-
name = 0
159-
return {(key, name): extract_array(value, extract_numpy=True)[grid_index]}
160-
161-
162139
def import_message(
163140
submodule: str,
164141
package: str,

tests/functions/test_expand_grid.py

+19
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
categoricaldf_strategy,
99
)
1010
from janitor.functions import expand_grid
11+
from functools import reduce
1112

1213

1314
@given(df=df_strategy())
@@ -324,3 +325,21 @@ def test_series_name(df):
324325
[["city", "A"], ["cities", 0]]
325326
)
326327
assert_frame_equal(result, expected)
328+
329+
330+
def test_extension_array():
331+
"""Test output on an extension array"""
332+
others = dict(
333+
id=pd.Categorical(
334+
values=(2, 1, 1, 2, 1), categories=(1, 2, 3), ordered=True
335+
),
336+
year=(2018, 2018, 2019, 2020, 2020),
337+
gender=pd.Categorical(("female", "male", "male", "female", "male")),
338+
)
339+
340+
expected = expand_grid(others=others).droplevel(axis=1, level=-1)
341+
others = [pd.Series(val).rename(key) for key, val in others.items()]
342+
343+
func = lambda x, y: pd.merge(x, y, how="cross") # noqa: E731
344+
actual = reduce(func, others)
345+
assert_frame_equal(expected, actual)

0 commit comments

Comments
 (0)