Open
51 commits
0f29529
update to nanvar to use more stable algorithm if engine is flox
jemmajeffree Jul 18, 2025
1fbf5f8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2025
322f511
[revert] only nanvar test
dcherian Jul 18, 2025
adab8e6
Some mods
dcherian Jul 18, 2025
93cd9b3
Update flox/aggregations.py to neater tuple unpacking
jemmajeffree Jul 21, 2025
2be4f74
Change np.all to all in flox/aggregate_flox.py
jemmajeffree Jul 21, 2025
edb655d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 21, 2025
dd2e4b6
delete some resolved comments
jemmajeffree Jul 21, 2025
936ed1d
Remove answered questions in comments
jemmajeffree Jul 21, 2025
1968870
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 21, 2025
d036ebc
Merge branch 'main' into var_algorithm
jemmajeffree Jul 21, 2025
12bcb0f
Remove more unnecessary comments
jemmajeffree Jul 21, 2025
6f5bece
Merge branch 'var_algorithm' of github.com:jemmajeffree/flox into var…
jemmajeffree Jul 21, 2025
b1f7b5d
Remove _version.py
jemmajeffree Jul 21, 2025
cd9a8b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 21, 2025
27448e4
Add preliminary test for std/var precision
jemmajeffree Jul 31, 2025
10214cc
Merge branch 'var_algorithm' of github.com:jemmajeffree/flox into var…
jemmajeffree Jul 31, 2025
a81b1a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 31, 2025
004fddc
Correct comment
jemmajeffree Jul 31, 2025
4491ce9
fix merge conflicts
jemmajeffree Jul 31, 2025
c3a6d88
Update flox/aggregate_flox.py
jemmajeffree Aug 5, 2025
4dcd7c2
Replace some list comprehension with tuple
jemmajeffree Aug 5, 2025
c101a2b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2025
98e1b4e
Fixes
dcherian Aug 5, 2025
d0d09df
minor edit for neater test reports.
dcherian Aug 5, 2025
1139a9c
Fix another list/tuple comprehension
jemmajeffree Aug 5, 2025
569629c
implement np.full
jemmajeffree Aug 5, 2025
50ad095
Implement np.full and empty chunks in var_chunk
jemmajeffree Aug 6, 2025
f88e231
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2025
77526fd
update comment
jemmajeffree Aug 6, 2025
0f5d587
Fix merge conflict
jemmajeffree Aug 6, 2025
31f30c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2025
3b3369f
edits
dcherian Aug 18, 2025
24fb532
support var, std
dcherian Aug 18, 2025
177b8de
enable property tests
dcherian Aug 18, 2025
7deb84a
cleanup
dcherian Aug 18, 2025
120fbf3
xfail some
dcherian Aug 18, 2025
4541c46
revert some
dcherian Aug 18, 2025
aa4b9b3
fix types
dcherian Aug 18, 2025
d5c59e3
adjust tolerance
dcherian Aug 18, 2025
b721433
disable for cubed
dcherian Aug 20, 2025
4f26ed8
handle nans in check
dcherian Aug 20, 2025
d77c132
Promote var/std to float64 from the beginning (dodgy)
jemmajeffree Aug 20, 2025
3cbe54c
Revert "Promote var/std to float64 from the beginning (dodgy)"
jemmajeffree Aug 26, 2025
d7d772c
Handle combine along multiple dimensions
jemmajeffree Aug 27, 2025
9a51095
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2025
1373318
Comments
jemmajeffree Aug 27, 2025
4f15495
Technicalities regarding multiple dimensions in var combine
jemmajeffree Aug 27, 2025
591997c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2025
bbc0be2
more explicit NaNs in empty groups
jemmajeffree Aug 27, 2025
63d7e96
Better "more explicit NaNs in empty groups"
dcherian Aug 28, 2025
175 changes: 149 additions & 26 deletions flox/aggregations.py
@@ -15,6 +15,8 @@
from . import aggregate_flox, aggregate_npg, xrutils
from . import xrdtypes as dtypes
from .lib import dask_array_type, sparse_array_type
from .multiarray import MultiArray
from .xrutils import notnull

if TYPE_CHECKING:
FuncTuple = tuple[Callable | str, ...]
@@ -161,8 +163,8 @@ def __init__(
self,
name: str,
*,
- numpy: str | None = None,
- chunk: str | FuncTuple | None,
+ numpy: partial | str | None = None,
+ chunk: partial | str | FuncTuple | None,
combine: str | FuncTuple | None,
preprocess: Callable | None = None,
finalize: Callable | None = None,
@@ -343,57 +345,178 @@ def _mean_finalize(sum_, count):
)


# TODO: fix this for complex numbers
- def _var_finalize(sumsq, sum_, count, ddof=0):
- with np.errstate(invalid="ignore", divide="ignore"):
- result = (sumsq - (sum_**2 / count)) / (count - ddof)
- result[count <= ddof] = np.nan
- return result
def var_chunk(
group_idx, array, *, skipna: bool, engine: str, axis=-1, size=None, fill_value=None, dtype=None
):
# Calculate length and sum - important for the adjustment terms to sum squared deviations
array_lens = generic_aggregate(
group_idx,
array,
func="nanlen",
engine=engine,
axis=axis,
size=size,
fill_value=0, # Unpack fill value bc it's currently defined for multiarray
Review comment (Author): Am I understanding correctly that this would overwrite whatever fill_value is passed through when the aggregation is defined? And are we assuming that a different fill_value would never be wanted?

Review comment (Author): If the concern is that None[2] isn't valid, wouldn't it make more sense to have (None, None, None) be the default and keep the unpacking?

Review comment (Collaborator, @dcherian, Aug 19, 2025): Yes, I think the hardcoding is fine here. It's probably fine to just set fill_value=(np.nan,)

dtype=dtype,
)

array_sums = generic_aggregate(
group_idx,
array,
func="nansum" if skipna else "sum",
engine=engine,
axis=axis,
size=size,
fill_value=0, # Unpack fill value bc it's currently defined for multiarray
dtype=dtype,
)

# Calculate sum squared deviations - the main part of variance sum
array_means = array_sums / array_lens

sum_squared_deviations = generic_aggregate(
group_idx,
(array - array_means[..., group_idx]) ** 2,
func="nansum" if skipna else "sum",
engine=engine,
axis=axis,
size=size,
fill_value=0, # Unpack fill value bc it's currently defined for multiarray
dtype=dtype,
)

return MultiArray((sum_squared_deviations, array_sums, array_lens))
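
To illustrate why `var_chunk` accumulates squared deviations from each group's mean, rather than the raw sum of squares used by the older `_var_finalize(sumsq, sum_, count, ddof=0)` formula, here is a minimal pure-Python sketch. The data values and the helper names `naive_var` and `two_pass_var` are illustrative assumptions and are not part of flox.

```python
def naive_var(values):
    """Population variance via (sum of squares - squared sum / n) / n."""
    n = len(values)
    sum_sq = 0.0
    total = 0.0
    for v in values:
        sum_sq += v * v
        total += v
    return (sum_sq - total**2 / n) / n


def two_pass_var(values):
    """Population variance via squared deviations from the mean."""
    n = len(values)
    mean = sum(values) / n
    return sum((v - mean) ** 2 for v in values) / n


# Small spread around a large offset: the exact population variance is 22.5.
data = [1e9 + 4, 1e9 + 7, 1e9 + 13, 1e9 + 16]

naive = naive_var(data)    # suffers from catastrophic cancellation
stable = two_pass_var(data)
print(naive, stable)
```

With these inputs the deviation-based result is exact, while the sum-of-squares form subtracts two numbers near 4e18 and loses the answer entirely.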


def _var_combine(array, axis, keepdims=True):
def clip_last(array, ax, n=1):
"""Return array except the last element along axis
Purely included to tidy up the adj_terms line
"""
assert n > 0, "Clipping nothing off the end isn't implemented"
not_last = [slice(None, None) for i in range(array.ndim)]
not_last[ax] = slice(None, -n)
return array[*not_last]

def clip_first(array, ax, n=1):
"""Return array except the first element along axis
Purely included to tidy up the adj_terms line
"""
not_first = [slice(None, None) for i in range(array.ndim)]
not_first[ax] = slice(n, None)
return array[*not_first]

for ax in axis:
if array.shape[ax] == 1:
continue

sum_deviations, sum_X, sum_len = array.arrays

# Calculate parts needed for cascading combination
cumsum_X = np.cumsum(sum_X, axis=ax)
cumsum_len = np.cumsum(sum_len, axis=ax)

# There will be instances in which one or both chunks being merged are empty
# In which case, the adjustment term should be zero, but will throw a divide-by-zero error
# We're going to add a constant to the bottom of the adjustment term equation on those instances
# and count on the zeros on the top making our adjustment term still zero
zero_denominator = (clip_last(cumsum_len, ax) == 0) | (clip_first(sum_len, ax) == 0)

# Adjustment terms to tweak the sum of squared deviations because not every chunk has the same mean
adj_terms = (
clip_last(cumsum_len, ax) * clip_first(sum_X, ax)
- clip_first(sum_len, ax) * clip_last(cumsum_X, ax)
) ** 2 / (
clip_last(cumsum_len, ax)
* clip_first(sum_len, ax)
* (clip_last(cumsum_len, ax) + clip_first(sum_len, ax))
+ zero_denominator.astype(int)
)

check = adj_terms * zero_denominator
assert np.all(check[notnull(check)] == 0), (
"Instances where we add something to the denominator must come out to zero"
)

array = MultiArray(
(
np.sum(sum_deviations, axis=ax, keepdims=keepdims)
+ np.sum(adj_terms, axis=ax, keepdims=keepdims), # sum of squared deviations
np.sum(sum_X, axis=ax, keepdims=keepdims), # sum of array items
np.sum(sum_len, axis=ax, keepdims=keepdims), # sum of array lengths
)
)
return array
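
The adjustment term in `_var_combine` can be checked on a scalar example. The sketch below merges per-chunk triples (sum of squared deviations, sum, count) one pair at a time, which is what the `cumsum`-based expression does in vectorized form along an axis. The names `chunk_stats` and `merge` are hypothetical, and the explicit empty-chunk branch stands in for the `zero_denominator` trick used above.

```python
import math
from functools import reduce
from statistics import pvariance


def chunk_stats(values):
    """Return (sum of squared deviations, sum, count) for one chunk."""
    n = len(values)
    if n == 0:
        return 0.0, 0.0, 0
    total = float(sum(values))
    mean = total / n
    m2 = sum((v - mean) ** 2 for v in values)
    return m2, total, n


def merge(a, b):
    """Combine two triples; the adjustment accounts for differing chunk means."""
    m2_a, sum_a, n_a = a
    m2_b, sum_b, n_b = b
    if n_a == 0 or n_b == 0:
        adj = 0.0  # nothing to adjust when one side is empty
    else:
        adj = (n_a * sum_b - n_b * sum_a) ** 2 / (n_a * n_b * (n_a + n_b))
    return m2_a + m2_b + adj, sum_a + sum_b, n_a + n_b


chunks = [[1.0, 2.0, 3.0], [], [10.0, 20.0], [4.0]]
m2, total, n = reduce(merge, (chunk_stats(c) for c in chunks))
combined_var = m2 / n
reference = pvariance([v for c in chunks for v in c])
print(combined_var, reference)
```

The merged result agrees with a direct computation over all values, including when one of the chunks is empty.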


def is_var_chunk_reduction(agg: Callable) -> bool:
if isinstance(agg, partial):
agg = agg.func
return agg is blockwise_or_numpy_var or agg is var_chunk
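
The unwrapping in `is_var_chunk_reduction` matters because a `functools.partial` object is not identical to the function it wraps. A small sketch with hypothetical stand-in names (`var_chunk_stub`, `is_stub_reduction`):

```python
from functools import partial


def var_chunk_stub(values, *, skipna):
    """Hypothetical stand-in for a chunk function that takes a keyword flag."""
    return ("var_chunk", skipna)


def is_stub_reduction(agg):
    """Unwrap a partial (if any), then compare the underlying function by identity."""
    if isinstance(agg, partial):
        agg = agg.func
    return agg is var_chunk_stub


wrapped = partial(var_chunk_stub, skipna=True)
print(wrapped is var_chunk_stub)          # False: the partial is a separate object
print(is_stub_reduction(wrapped))         # True
print(is_stub_reduction(var_chunk_stub))  # True
print(is_stub_reduction(sum))             # False
```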


def _var_finalize(multiarray, ddof=0):
den = multiarray.arrays[2] - ddof
# preserve nans for groups with 0 obs; so these values are -ddof
ret = multiarray.arrays[0] / den
ret[den < 0] = np.nan
return ret
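
A short sketch of how the `den < 0` check interacts with `ddof`, using made-up per-group intermediates. The function name `finalize_var` and the `np.errstate` guard are additions for this illustration only; the diff above does not include the guard.

```python
import numpy as np

# Hypothetical per-group intermediates: group 0 is empty,
# group 1 has one observation, group 2 has four.
sum_sq_dev = np.array([0.0, 0.0, 90.0])
counts = np.array([0, 1, 4])


def finalize_var(sum_sq_dev, counts, ddof=0):
    den = counts - ddof
    with np.errstate(invalid="ignore", divide="ignore"):
        out = sum_sq_dev / den  # 0 / 0 already yields NaN
    out[den < 0] = np.nan       # e.g. an empty group with ddof=1 gives 0 / -1
    return out


var0 = finalize_var(sum_sq_dev, counts, ddof=0)  # NaN, 0.0, 22.5
var1 = finalize_var(sum_sq_dev, counts, ddof=1)  # NaN, NaN, 30.0
print(var0, var1)
```

Groups whose count equals `ddof` become NaN through the 0 / 0 division, and groups whose count is below `ddof` are caught by the explicit `den < 0` mask.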


- def _std_finalize(sumsq, sum_, count, ddof=0):
- return np.sqrt(_var_finalize(sumsq, sum_, count, ddof))
+ def _std_finalize(multiarray, ddof=0):
+ return np.sqrt(_var_finalize(multiarray, ddof))


def blockwise_or_numpy_var(*args, skipna: bool, ddof=0, std=False, **kwargs):
res = _var_finalize(var_chunk(*args, skipna=skipna, **kwargs), ddof)
return np.sqrt(res) if std else res


# var, std always promote to float, so we set nan
var = Aggregation(
"var",
- chunk=("sum_of_squares", "sum", "nanlen"),
- combine=("sum", "sum", "sum"),
+ chunk=partial(var_chunk, skipna=False),
+ numpy=partial(blockwise_or_numpy_var, skipna=False),
+ combine=(_var_combine,),
finalize=_var_finalize,
- fill_value=0,
+ fill_value=((0, 0, 0),),
final_fill_value=np.nan,
- dtypes=(None, None, np.intp),
+ dtypes=(None,),
final_dtype=np.floating,
)

nanvar = Aggregation(
"nanvar",
- chunk=("nansum_of_squares", "nansum", "nanlen"),
- combine=("sum", "sum", "sum"),
+ chunk=partial(var_chunk, skipna=True),
+ numpy=partial(blockwise_or_numpy_var, skipna=True),
+ combine=(_var_combine,),
finalize=_var_finalize,
- fill_value=0,
+ fill_value=((0, 0, 0),),
final_fill_value=np.nan,
- dtypes=(None, None, np.intp),
+ dtypes=(None,),
final_dtype=np.floating,
)

std = Aggregation(
"std",
- chunk=("sum_of_squares", "sum", "nanlen"),
- combine=("sum", "sum", "sum"),
+ chunk=partial(var_chunk, skipna=False),
+ numpy=partial(blockwise_or_numpy_var, skipna=False, std=True),
+ combine=(_var_combine,),
finalize=_std_finalize,
- fill_value=0,
+ fill_value=((0, 0, 0),),
final_fill_value=np.nan,
- dtypes=(None, None, np.intp),
+ dtypes=(None,),
final_dtype=np.floating,
)
nanstd = Aggregation(
"nanstd",
- chunk=("nansum_of_squares", "nansum", "nanlen"),
- combine=("sum", "sum", "sum"),
+ chunk=partial(var_chunk, skipna=True),
+ numpy=partial(blockwise_or_numpy_var, skipna=True, std=True),
+ combine=(_var_combine,),
finalize=_std_finalize,
- fill_value=0,
+ fill_value=((0, 0, 0),),
final_fill_value=np.nan,
- dtypes=(None, None, np.intp),
+ dtypes=(None,),
final_dtype=np.floating,
)

8 changes: 7 additions & 1 deletion flox/core.py
@@ -44,6 +44,7 @@
_atleast_1d,
_initialize_aggregation,
generic_aggregate,
is_var_chunk_reduction,
quantile_new_dims_func,
)
from .cache import memoize
@@ -1288,7 +1289,8 @@ def chunk_reduce(
# optimize that out.
previous_reduction: T_Func = ""
for reduction, fv, kw, dt in zip(funcs, fill_values, kwargss, dtypes):
- if empty:
+ # UGLY! but this is because the `var` breaks our design assumptions
+ if empty and not is_var_chunk_reduction(reduction):
result = np.full(shape=final_array_shape, fill_value=fv, like=array)
elif is_nanlen(reduction) and is_nanlen(previous_reduction):
result = results["intermediates"][-1]
@@ -1297,6 +1299,10 @@ def chunk_reduce(
kw_func = dict(size=size, dtype=dt, fill_value=fv)
kw_func.update(kw)

# UGLY! but this is because the `var` breaks our design assumptions
if is_var_chunk_reduction(reduction):
kw_func.update(engine=engine)

if callable(reduction):
# passing a custom reduction for npg to apply per-group is really slow!
# So this `reduction` has to do the groupby-aggregation
97 changes: 97 additions & 0 deletions flox/multiarray.py
@@ -0,0 +1,97 @@
from collections.abc import Callable
from typing import Self

import numpy as np

MULTIARRAY_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}


class MultiArray:
arrays: tuple[np.ndarray, ...]

def __init__(self, arrays):
self.arrays = arrays
assert all(arrays[0].shape == a.shape for a in arrays), "Expect all arrays to have the same shape"

def astype(self, dt, **kwargs) -> Self:
return type(self)(tuple(array.astype(dt, **kwargs) for array in self.arrays))

def reshape(self, shape, **kwargs) -> Self:
return type(self)(tuple(array.reshape(shape, **kwargs) for array in self.arrays))

def squeeze(self, axis=None) -> Self:
return type(self)(tuple(array.squeeze(axis) for array in self.arrays))

def __setitem__(self, key, value) -> None:
assert len(value) == len(self.arrays)
for array, val in zip(self.arrays, value):
array[key] = val

def __array_function__(self, func, types, args, kwargs):
if func not in MULTIARRAY_HANDLED_FUNCTIONS:
return NotImplemented
# Note: this allows subclasses that don't override
# __array_function__ to handle MultiArray objects
# if not all(issubclass(t, MultiArray) for t in types):  # probably not relevant for this code, but may be safer to leave in
# return NotImplemented
return MULTIARRAY_HANDLED_FUNCTIONS[func](*args, **kwargs)

# Shape is needed, seems likely that the other two might be
# Making some strong assumptions here that all the arrays are the same shape, and I don't really like this
@property
def dtype(self) -> np.dtype:
return self.arrays[0].dtype

@property
def shape(self) -> tuple[int, ...]:
return self.arrays[0].shape

@property
def ndim(self) -> int:
return self.arrays[0].ndim

def __getitem__(self, key) -> Self:
return type(self)([array[key] for array in self.arrays])


def implements(numpy_function):
"""Register an __array_function__ implementation for MultiArray objects."""

def decorator(func):
MULTIARRAY_HANDLED_FUNCTIONS[numpy_function] = func
return func

return decorator


@implements(np.expand_dims)
def expand_dims(multiarray, axis) -> MultiArray:
return MultiArray(tuple(np.expand_dims(a, axis) for a in multiarray.arrays))


@implements(np.concatenate)
def concatenate(multiarrays, axis) -> MultiArray:
n_arrays = len(multiarrays[0].arrays)
for ma in multiarrays[1:]:
assert len(ma.arrays) == n_arrays
return MultiArray(
tuple(np.concatenate(tuple(ma.arrays[i] for ma in multiarrays), axis) for i in range(n_arrays))
)


@implements(np.transpose)
def transpose(multiarray, axes) -> MultiArray:
return MultiArray(tuple(np.transpose(a, axes) for a in multiarray.arrays))


@implements(np.squeeze)
def squeeze(multiarray, axis) -> MultiArray:
return MultiArray(tuple(np.squeeze(a, axis) for a in multiarray.arrays))


@implements(np.full)
def full(shape, fill_values, *args, **kwargs) -> MultiArray:
"""All arguments except fill_value are shared by each array in the MultiArray.
Iterate over fill_values to create arrays
"""
return MultiArray(tuple(np.full(shape, fv, *args, **kwargs) for fv in fill_values))
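
The registration pattern in `flox/multiarray.py` follows NumPy's `__array_function__` protocol. Below is a minimal, self-contained sketch of that dispatch using a hypothetical `PairArray` class and a single registered function; it is an illustration of the protocol, not flox's `MultiArray`.

```python
import numpy as np

HANDLED = {}  # maps a NumPy function to its override


def implements(numpy_function):
    """Register an override for a NumPy function."""
    def decorator(func):
        HANDLED[numpy_function] = func
        return func
    return decorator


class PairArray:
    """Hypothetical container holding several same-shaped arrays."""

    def __init__(self, arrays):
        self.arrays = tuple(arrays)

    @property
    def shape(self):
        return self.arrays[0].shape

    def __array_function__(self, func, types, args, kwargs):
        # NumPy calls this instead of its own implementation.
        if func not in HANDLED:
            return NotImplemented
        return HANDLED[func](*args, **kwargs)


@implements(np.concatenate)
def concatenate(pairs, axis=0):
    # Concatenate the i-th component of every container separately.
    n = len(pairs[0].arrays)
    return PairArray(
        np.concatenate([p.arrays[i] for p in pairs], axis=axis) for i in range(n)
    )


a = PairArray((np.zeros((2, 3)), np.ones((2, 3))))
b = PairArray((np.zeros((4, 3)), np.ones((4, 3))))
joined = np.concatenate([a, b], axis=0)
print(type(joined).__name__, joined.shape)  # PairArray (6, 3)
```

Calling a NumPy function that has no registered override on such an object makes NumPy raise a `TypeError`, which is why each function used on the container needs its own `@implements` entry.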
6 changes: 6 additions & 0 deletions flox/xrutils.py
@@ -146,6 +146,9 @@ def is_scalar(value: Any, include_0d: bool = True) -> bool:


def notnull(data):
if isinstance(data, tuple) and len(data) == 3 and data == (0, 0, 0):
# boo: another special case for Var
return True
if not is_duck_array(data):
data = np.asarray(data)

@@ -163,6 +166,9 @@ def notnull(data):


def isnull(data: Any):
if isinstance(data, tuple) and len(data) == 3 and data == (0, 0, 0):
Review comment (Author): Out of curiosity, what are these lines (and the associated lines above) doing? I guess pd.isnull can't cope with three numbers as a single entity? In which case, why is the fill_value the only thing that needs checking?

Review comment (Collaborator): It didn't like the tuple of 0s, so it's a hack.

# boo: another special case for Var
return False
if data is None:
return False
if not is_duck_array(data):
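
One plausible reading of the `(0, 0, 0)` special case in `notnull` and `isnull`: an elementwise null check on a three-element tuple returns three flags, which cannot be used as a single condition. The sketch below shows that behaviour with plain NumPy; treating this as the actual failure mode is an assumption, since the review thread does not spell it out.

```python
import numpy as np

fill_value = (0, 0, 0)  # one fill value per intermediate array

# A generic elementwise null check returns one flag per tuple element...
mask = np.isnan(np.asarray(fill_value, dtype=float))
print(mask.shape)  # (3,)

# ...so using the result as a single condition is ambiguous.
try:
    if mask:
        pass
    ambiguous = False
except ValueError:
    ambiguous = True

print(ambiguous)  # True
```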
3 changes: 1 addition & 2 deletions tests/strategies.py
@@ -108,9 +108,8 @@ def insert_nans(draw: st.DrawFn, array: np.ndarray) -> np.ndarray:
"any",
"all",
] + list(SCIPY_STATS_FUNCS)
- SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"]
Review comment (Collaborator, @dcherian, Aug 18, 2025): This is a real win! We now add var and std to our property test suite (`test_groupby_reduce(data, array, func: str) -> None`).


- func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS])
+ func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS])


@st.composite