Skip to content

Commit 70e7a7b

Browse files
samukwekusamuel.oranyeli
and
samuel.oranyeli
authored
[fix] remove restrictions for mutate/summarise (#1452)
remove restrictions for mutate/summarise Co-authored-by: samuel.oranyeli <[email protected]>
1 parent 02a4599 commit 70e7a7b

File tree

4 files changed

+18
-28
lines changed

4 files changed

+18
-28
lines changed

janitor/functions/mutate.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@ def mutate(
3333
mutate creates new columns that are functions of existing columns.
3434
It can also modify columns (if the name is the same as an existing column).
3535
36-
The argument provided to *args* should be either a dictionary, a tuple or a callable.
36+
The argument provided to *args* should be either
37+
a dictionary, a callable or a tuple; however,
38+
anything can be passed, as long as it can
39+
be aligned with the original DataFrame.
40+
3741
3842
- **dictionary argument**:
3943
If the argument is a dictionary,
@@ -193,10 +197,6 @@ def mutate(
193197

194198
@singledispatch
195199
def _mutator(arg, df, by):
196-
if not callable(arg):
197-
raise NotImplementedError(
198-
f"janitor.mutate is not supported for {type(arg)}"
199-
)
200200
if by is None:
201201
val = df
202202
else:
@@ -212,7 +212,7 @@ def _mutator(arg, df, by):
212212
df[column] = outcome[column]
213213
return df
214214
raise TypeError(
215-
"The output from a callable should be a named Series or a DataFrame"
215+
"The output from the mutation should be a named Series or a DataFrame"
216216
)
217217

218218

janitor/functions/summarise.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pandas as pd
99
import pandas_flavor as pf
1010
from pandas.api.types import is_scalar
11+
from pandas.core.common import apply_if_callable
1112
from pandas.core.groupby.generic import DataFrameGroupBy
1213

1314
from janitor.functions.select import get_index_labels
@@ -33,7 +34,10 @@ def summarise(
3334
the output will have a single row
3435
summarising all observations in the input.
3536
36-
The argument provided to *args* should be either a dictionary or a tuple.
37+
The argument provided to *args* should be either
38+
a dictionary, a callable or a tuple; however,
39+
anything can be passed, as long as it fits
40+
within pandas' aggregation semantics.
3741
3842
- **dictionary argument**:
3943
If the argument is a dictionary,
@@ -187,15 +191,11 @@ def summarise(
187191
values = map(is_scalar, dictionary.values())
188192
if all(values):
189193
return pd.Series(dictionary)
190-
return pd.concat(dictionary, axis=1, sort=False, copy=False)
194+
return pd.concat(dictionary, axis="columns", sort=False, copy=False)
191195

192196

193197
@singledispatch
194198
def _mutator(arg, df, by):
195-
if not callable(arg):
196-
raise NotImplementedError(
197-
f"janitor.summarise is not supported for {type(arg)}"
198-
)
199199
if by is None:
200200
val = df
201201
else:
@@ -205,9 +205,8 @@ def _mutator(arg, df, by):
205205
if not outcome.name:
206206
raise ValueError("Ensure the pandas Series object has a name")
207207
return {outcome.name: outcome}
208-
# assumption: should return a DataFrame
209-
outcome = {key: outcome[key] for key in outcome}
210-
return outcome
208+
# assumption: a mapping - DataFrame/dictionary/...
209+
return {**outcome}
211210

212211

213212
@_mutator.register(dict)
@@ -247,7 +246,7 @@ def _process_maybe_callable(func: callable, obj):
247246
try:
248247
column = obj.agg(func)
249248
except: # noqa: E722
250-
column = func(obj)
249+
column = apply_if_callable(maybe_callable=func, obj=obj)
251250
return column
252251

253252

tests/functions/test_mutate.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_mutate_callable(df_mutate):
4848
"Raise if output of callable is not a pandas Series/DataFrame"
4949
with pytest.raises(
5050
TypeError,
51-
match="The output from a callable should be a named Series or a DataFrame",
51+
match="The output from the mutation should be a named Series or a DataFrame",
5252
):
5353
df_mutate.mutate(lambda df: np.sum(df["avg_run"]))
5454

@@ -64,7 +64,8 @@ def test_mutate_wrong_arg(df_mutate):
6464
Raise if wrong arg is provided
6565
"""
6666
with pytest.raises(
67-
NotImplementedError, match="janitor.mutate is not supported for.+"
67+
TypeError,
68+
match="The output from the mutation should be a named Series or a DataFrame",
6869
):
6970
df_mutate.mutate(1)
7071

tests/functions/test_summarise.py

-10
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,6 @@ def df_summarise():
1313
return pd.DataFrame(data)
1414

1515

16-
def test_summarise_wrong_arg(df_summarise):
17-
"""
18-
Raise if wrong arg is provided
19-
"""
20-
with pytest.raises(
21-
NotImplementedError, match="janitor.summarise is not supported for.+"
22-
):
23-
df_summarise.summarise(1)
24-
25-
2616
def test_mutate_callable_series_unnamed(df_summarise):
2717
"""Test output for callable"""
2818
with pytest.raises(

0 commit comments

Comments
 (0)