Skip to content

Commit cf7f4ce

Browse files
authored
Merge pull request #221 from lincc-frameworks/result_nesting
Enable Inference of Nested Structures to Reduce outputs
2 parents 6b7a45f + a6b134a commit cf7f4ce

File tree

2 files changed

+92
-2
lines changed

2 files changed

+92
-2
lines changed

src/nested_pandas/nestedframe/core.py

Lines changed: 31 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -845,7 +845,7 @@ def sort_values(
845845
return None
846846
return new_df
847847

848-
def reduce(self, func, *args, **kwargs) -> NestedFrame: # type: ignore[override]
848+
def reduce(self, func, *args, infer_nesting=True, **kwargs) -> NestedFrame: # type: ignore[override]
849849
"""
850850
Takes a function and applies it to each top-level row of the NestedFrame.
851851
@@ -862,6 +862,12 @@ def reduce(self, func, *args, **kwargs) -> NestedFrame: # type: ignore[override
862862
args : positional arguments
863863
Positional arguments to pass to the function, the first *args should be the names of the
864864
columns to apply the function to.
865+
infer_nesting : bool, default True
866+
If True, the function will pack output columns into nested
867+
structures based on column names adhering to a nested naming
868+
scheme. E.g. "nested.b" and "nested.c" will be packed into a column
869+
called "nested" with columns "b" and "c". If False, all outputs
870+
will be returned as base columns.
865871
kwargs : keyword arguments, optional
866872
Keyword arguments to pass to the function.
867873
@@ -915,7 +921,30 @@ def reduce(self, func, *args, **kwargs) -> NestedFrame: # type: ignore[override
915921
iterators.append(self[layer].array.iter_field_lists(col))
916922

917923
results = [func(*cols, *extra_args, **kwargs) for cols in zip(*iterators)]
918-
return NestedFrame(results, index=self.index)
924+
results_nf = NestedFrame(results, index=self.index)
925+
926+
if infer_nesting:
927+
# find potential nested structures from columns
928+
nested_cols = list(
929+
np.unique(
930+
[
931+
column.split(".", 1)[0]
932+
for column in results_nf.columns
933+
if isinstance(column, str) and "." in column
934+
]
935+
)
936+
)
937+
938+
# pack results into nested structures
939+
for layer in nested_cols:
940+
layer_cols = [col for col in results_nf.columns if col.startswith(f"{layer}.")]
941+
rename_df = results_nf[layer_cols].rename(columns=lambda x: x.split(".", 1)[1])
942+
nested_col = pack_lists(rename_df, name=layer)
943+
results_nf = results_nf[
944+
[col for col in results_nf.columns if not col.startswith(f"{layer}.")]
945+
].join(nested_col)
946+
947+
return results_nf
919948

920949
def to_parquet(self, path, by_layer=False, **kwargs) -> None:
921950
"""Creates parquet file(s) with the data of a NestedFrame, either

tests/nested_pandas/nestedframe/test_nestedframe.py

Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1022,6 +1022,67 @@ def cols_allclose(col1, col2):
10221022
)
10231023

10241024

1025+
def test_reduce_infer_nesting():
1026+
"""Test that nesting inference works in reduce"""
1027+
1028+
ndf = generate_data(3, 20, seed=1)
1029+
1030+
# Test simple case
1031+
def complex_output(flux):
1032+
return {
1033+
"max_flux": np.max(flux),
1034+
"lc.flux_quantiles": np.quantile(flux, [0.1, 0.2, 0.3, 0.4, 0.5]),
1035+
}
1036+
1037+
result = ndf.reduce(complex_output, "nested.flux")
1038+
assert list(result.columns) == ["max_flux", "lc"]
1039+
assert list(result.lc.nest.fields) == ["flux_quantiles"]
1040+
1041+
# Test multi-column nested output
1042+
def complex_output(flux):
1043+
return {
1044+
"max_flux": np.max(flux),
1045+
"lc.flux_quantiles": np.quantile(flux, [0.1, 0.2, 0.3, 0.4, 0.5]),
1046+
"lc.labels": [0.1, 0.2, 0.3, 0.4, 0.5],
1047+
}
1048+
1049+
result = ndf.reduce(complex_output, "nested.flux")
1050+
assert list(result.columns) == ["max_flux", "lc"]
1051+
assert list(result.lc.nest.fields) == ["flux_quantiles", "labels"]
1052+
1053+
# Test integer names
1054+
def complex_output(flux):
1055+
return np.max(flux), np.quantile(flux, [0.1, 0.2, 0.3, 0.4, 0.5]), [0.1, 0.2, 0.3, 0.4, 0.5]
1056+
1057+
result = ndf.reduce(complex_output, "nested.flux")
1058+
assert list(result.columns) == [0, 1, 2]
1059+
1060+
# Test multiple nested structures output
1061+
def complex_output(flux):
1062+
return {
1063+
"max_flux": np.max(flux),
1064+
"lc.flux_quantiles": np.quantile(flux, [0.1, 0.2, 0.3, 0.4, 0.5]),
1065+
"lc.labels": [0.1, 0.2, 0.3, 0.4, 0.5],
1066+
"meta.colors": ["green", "red", "blue"],
1067+
}
1068+
1069+
result = ndf.reduce(complex_output, "nested.flux")
1070+
assert list(result.columns) == ["max_flux", "lc", "meta"]
1071+
assert list(result.lc.nest.fields) == ["flux_quantiles", "labels"]
1072+
assert list(result.meta.nest.fields) == ["colors"]
1073+
1074+
# Test only nested structure output
1075+
def complex_output(flux):
1076+
return {
1077+
"lc.flux_quantiles": np.quantile(flux, [0.1, 0.2, 0.3, 0.4, 0.5]),
1078+
"lc.labels": [0.1, 0.2, 0.3, 0.4, 0.5],
1079+
}
1080+
1081+
result = ndf.reduce(complex_output, "nested.flux")
1082+
assert list(result.columns) == ["lc"]
1083+
assert list(result.lc.nest.fields) == ["flux_quantiles", "labels"]
1084+
1085+
10251086
def test_scientific_notation():
10261087
"""
10271088
Test that NestedFrame.query handles constants that are written in scientific notation.

0 commit comments

Comments
 (0)