Skip to content

chore: @requires for backend method constraints #2371

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
8d6e33d
chore(DRAFT): Plan out `requires` typing
dangotbanned Apr 11, 2025
bb1882c
feat: Support `@requires(min_version=...)`
dangotbanned Apr 11, 2025
a586721
test: Check error message composition
dangotbanned Apr 11, 2025
aeedd3f
typo
dangotbanned Apr 11, 2025
e0d6e99
refactor: `@requires` in `_polars`
dangotbanned Apr 11, 2025
a7ede1e
refactor: `@requires` in `_arrow`
dangotbanned Apr 11, 2025
c4fce51
refactor: `@requires` in `_duckdb`
dangotbanned Apr 11, 2025
6feeec5
Thanks `pre-commit`, thats useful πŸ™ƒ
dangotbanned Apr 11, 2025
ff8bce7
test: remove unused mock method
dangotbanned Apr 11, 2025
3c3599a
feat: Support `@requires(min_version=.., hint=...)`
dangotbanned Apr 11, 2025
c0db32f
fix: Use `(1, 0, 0)`
dangotbanned Apr 11, 2025
8eacf5d
docs(DRAFT): Add very basic doc
dangotbanned Apr 11, 2025
9468295
Merge branch 'main' into requires-context
dangotbanned Apr 12, 2025
4f02f06
chore(ruff): Add complexity ignore back 😞
dangotbanned Apr 12, 2025
adab2fd
Merge branch 'main' into requires-context
dangotbanned Apr 12, 2025
434ebf6
Merge branch 'main' into requires-context
dangotbanned Apr 12, 2025
651cf2f
feat: Add `@requires.backend_version` constructor
dangotbanned Apr 13, 2025
19adbba
Merge branch 'main' into requires-context
dangotbanned Apr 13, 2025
5f98191
Merge remote-tracking branch 'upstream/main' into requires-context
dangotbanned Apr 16, 2025
0ac9595
chore: Coverage for alternative spelling
dangotbanned Apr 16, 2025
0378291
Merge branch 'main' into requires-context
dangotbanned Apr 17, 2025
6e40e91
refactor: `@requires(min_version=` -> `@requires.backend_version(`
dangotbanned Apr 17, 2025
d3cedc5
docs: Add doctest, remove `__init__`
dangotbanned Apr 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 13 additions & 42 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from narwhals.utils import import_dtypes_module
from narwhals.utils import is_list_of
from narwhals.utils import not_implemented
from narwhals.utils import requires
from narwhals.utils import validate_backend_version

if TYPE_CHECKING:
Expand Down Expand Up @@ -863,46 +864,34 @@ def cum_count(self: Self, *, reverse: bool) -> Self:
dtypes = import_dtypes_module(self._version)
return (~self.is_null()).cast(dtypes.UInt32()).cum_sum(reverse=reverse)

def cum_min(self: Self, *, reverse: bool) -> Self:
if self._backend_version < (13, 0, 0):
msg = "cum_min method is not supported for pyarrow < 13.0.0"
raise NotImplementedError(msg)
@requires.backend_version((13,))
def cum_min(self, *, reverse: bool) -> Self:
result = (
pc.cumulative_min(self.native, skip_nulls=True)
if not reverse
else pc.cumulative_min(self.native[::-1], skip_nulls=True)[::-1]
)
return self._with_native(result)

def cum_max(self: Self, *, reverse: bool) -> Self:
if self._backend_version < (13, 0, 0):
msg = "cum_max method is not supported for pyarrow < 13.0.0"
raise NotImplementedError(msg)
@requires.backend_version((13,))
def cum_max(self, *, reverse: bool) -> Self:
result = (
pc.cumulative_max(self.native, skip_nulls=True)
if not reverse
else pc.cumulative_max(self.native[::-1], skip_nulls=True)[::-1]
)
return self._with_native(result)

def cum_prod(self: Self, *, reverse: bool) -> Self:
if self._backend_version < (13, 0, 0):
msg = "cum_max method is not supported for pyarrow < 13.0.0"
raise NotImplementedError(msg)
@requires.backend_version((13,))
def cum_prod(self, *, reverse: bool) -> Self:
result = (
pc.cumulative_prod(self.native, skip_nulls=True)
if not reverse
else pc.cumulative_prod(self.native[::-1], skip_nulls=True)[::-1]
)
return self._with_native(result)

def rolling_sum(
self: Self,
window_size: int,
*,
min_samples: int,
center: bool,
) -> Self:
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
min_samples = min_samples if min_samples is not None else window_size
padded_series, offset = pad_series(self, window_size=window_size, center=center)

Expand All @@ -926,13 +915,7 @@ def rolling_sum(
)
return result[offset:]

def rolling_mean(
self: Self,
window_size: int,
*,
min_samples: int,
center: bool,
) -> Self:
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
min_samples = min_samples if min_samples is not None else window_size
padded_series, offset = pad_series(self, window_size=window_size, center=center)

Expand Down Expand Up @@ -962,12 +945,7 @@ def rolling_mean(
return result[offset:]

def rolling_var(
self: Self,
window_size: int,
*,
min_samples: int,
center: bool,
ddof: int,
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
min_samples = min_samples if min_samples is not None else window_size
padded_series, offset = pad_series(self, window_size=window_size, center=center)
Expand Down Expand Up @@ -1010,12 +988,7 @@ def rolling_var(
return result[offset:]

def rolling_std(
self: Self,
window_size: int,
*,
min_samples: int,
center: bool,
ddof: int,
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
return (
self.rolling_var(
Expand Down Expand Up @@ -1048,16 +1021,14 @@ def rank(self, method: RankMethod, *, descending: bool) -> Self:
result = pc.if_else(null_mask, lit(None, native_series.type), rank)
return self._with_native(result)

@requires.backend_version((13,))
def hist( # noqa: PLR0915
self: Self,
self,
bins: list[float | int] | None,
*,
bin_count: int | None,
include_breakpoint: bool,
) -> ArrowDataFrame:
if self._backend_version < (13,):
msg = f"`Series.hist` requires PyArrow>=13.0.0, found PyArrow version: {self._backend_version}"
raise NotImplementedError(msg)
import numpy as np # ignore-banned-import

from narwhals._arrow.dataframe import ArrowDataFrame
Expand Down
43 changes: 24 additions & 19 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from narwhals._expression_parsing import ExprKind
from narwhals.utils import Implementation
from narwhals.utils import not_implemented
from narwhals.utils import requires

if TYPE_CHECKING:
import duckdb
Expand Down Expand Up @@ -59,7 +60,7 @@ class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "duckdb.Expression"]):
_implementation = Implementation.DUCKDB

def __init__(
self: Self,
self,
call: EvalSeries[DuckDBLazyFrame, duckdb.Expression],
*,
evaluate_output_names: EvalNames[DuckDBLazyFrame],
Expand All @@ -75,7 +76,7 @@ def __init__(
self._window_function: WindowFunction | None = None
self._metadata: ExprMetadata | None = None

def __call__(self: Self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
def __call__(self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
return self._call(df)

def __narwhals_expr__(self) -> None: ...
Expand Down Expand Up @@ -177,7 +178,7 @@ def func(df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:

@classmethod
def from_column_names(
cls: type[Self],
cls,
evaluate_column_names: EvalNames[DuckDBLazyFrame],
/,
*,
Expand All @@ -195,9 +196,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
)

@classmethod
def from_column_indices(
cls: type[Self], *column_indices: int, context: _FullContext
) -> Self:
def from_column_indices(cls, *column_indices: int, context: _FullContext) -> Self:
def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
columns = df.columns
return [col(columns[i]) for i in column_indices]
Expand All @@ -211,10 +210,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
)

def _with_callable(
self: Self,
call: Callable[..., duckdb.Expression],
/,
**expressifiable_args: Self | Any,
self, call: Callable[..., duckdb.Expression], /, **expressifiable_args: Self | Any
) -> Self:
"""Create expression from callable.

Expand Down Expand Up @@ -497,14 +493,8 @@ def null_count(self: Self) -> Self:
lambda _input: FunctionExpression("sum", _input.isnull().cast("int")),
)

def over(
self: Self,
partition_by: Sequence[str],
order_by: Sequence[str] | None,
) -> Self:
if self._backend_version < (1, 3):
msg = "At least version 1.3 of DuckDB is required for `over` operation."
raise NotImplementedError(msg)
@requires.backend_version((1, 3))
def over(self, partition_by: Sequence[str], order_by: Sequence[str] | None) -> Self:
if (window_function := self._window_function) is not None:
assert order_by is not None # noqa: S101

Expand Down Expand Up @@ -549,6 +539,7 @@ def round(self: Self, decimals: int) -> Self:
lambda _input: FunctionExpression("round", _input, lit(decimals))
)

@requires.backend_version((1, 3))
def shift(self, n: int) -> Self:
ensure_type(n, int)

Expand All @@ -562,6 +553,7 @@ def func(window_inputs: WindowInputs) -> duckdb.Expression:

return self._with_window_function(func)

@requires.backend_version((1, 3))
def is_first_distinct(self) -> Self:
def func(window_inputs: WindowInputs) -> duckdb.Expression:
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
Expand All @@ -577,6 +569,7 @@ def func(window_inputs: WindowInputs) -> duckdb.Expression:

return self._with_window_function(func)

@requires.backend_version((1, 3))
def is_last_distinct(self) -> Self:
def func(window_inputs: WindowInputs) -> duckdb.Expression:
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=False)
Expand All @@ -592,6 +585,7 @@ def func(window_inputs: WindowInputs) -> duckdb.Expression:

return self._with_window_function(func)

@requires.backend_version((1, 3))
def diff(self) -> Self:
def func(window_inputs: WindowInputs) -> duckdb.Expression:
order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
Expand All @@ -601,31 +595,37 @@ def func(window_inputs: WindowInputs) -> duckdb.Expression:

return self._with_window_function(func)

@requires.backend_version((1, 3))
def cum_sum(self, *, reverse: bool) -> Self:
return self._with_window_function(
self._cum_window_func(reverse=reverse, func_name="sum")
)

@requires.backend_version((1, 3))
def cum_max(self, *, reverse: bool) -> Self:
return self._with_window_function(
self._cum_window_func(reverse=reverse, func_name="max")
)

@requires.backend_version((1, 3))
def cum_min(self, *, reverse: bool) -> Self:
return self._with_window_function(
self._cum_window_func(reverse=reverse, func_name="min")
)

@requires.backend_version((1, 3))
def cum_count(self, *, reverse: bool) -> Self:
return self._with_window_function(
self._cum_window_func(reverse=reverse, func_name="count")
)

@requires.backend_version((1, 3))
def cum_prod(self, *, reverse: bool) -> Self:
return self._with_window_function(
self._cum_window_func(reverse=reverse, func_name="product")
)

@requires.backend_version((1, 3))
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
return self._with_window_function(
self._rolling_window_func(
Expand All @@ -636,6 +636,7 @@ def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Se
)
)

@requires.backend_version((1, 3))
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
return self._with_window_function(
self._rolling_window_func(
Expand All @@ -646,6 +647,7 @@ def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> S
)
)

@requires.backend_version((1, 3))
def rolling_var(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
Expand All @@ -659,6 +661,7 @@ def rolling_var(
)
)

@requires.backend_version((1, 3))
def rolling_std(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
Expand Down Expand Up @@ -691,13 +694,15 @@ def func(_input: duckdb.Expression) -> duckdb.Expression:

return self._with_callable(func)

def is_unique(self: Self) -> Self:
@requires.backend_version((1, 3))
def is_unique(self) -> Self:
def func(_input: duckdb.Expression) -> duckdb.Expression:
sql = f"count(*) over (partition by {_input})"
return SQLExpression(sql) == lit(1) # type: ignore[no-any-return, unused-ignore]

return self._with_callable(func)

@requires.backend_version((1, 3))
def rank(self, method: RankMethod, *, descending: bool) -> Self:
if self._backend_version < (1, 3):
msg = "At least version 1.3 of DuckDB is required for `rank`."
Expand Down
5 changes: 2 additions & 3 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from narwhals.utils import is_sequence_but_not_str
from narwhals.utils import parse_columns_to_drop
from narwhals.utils import parse_version
from narwhals.utils import requires
from narwhals.utils import validate_backend_version

if TYPE_CHECKING:
Expand Down Expand Up @@ -414,6 +415,7 @@ def unpivot(
)
)

@requires.backend_version((1,))
def pivot(
self,
on: Sequence[str],
Expand All @@ -424,9 +426,6 @@ def pivot(
sort_columns: bool,
separator: str,
) -> Self:
if self._backend_version < (1, 0, 0): # pragma: no cover
msg = "`pivot` is only supported for Polars>=1.0.0"
raise NotImplementedError(msg)
try:
result = self.native.pivot(
on,
Expand Down
13 changes: 4 additions & 9 deletions narwhals/_polars/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from narwhals._polars.utils import extract_native
from narwhals._polars.utils import narwhals_to_native_dtype
from narwhals.utils import Implementation
from narwhals.utils import requires

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -109,24 +110,20 @@ def over(self, partition_by: Sequence[str], order_by: Sequence[str] | None) -> S
native = self.native.over(partition_by or pl.lit(1), order_by=order_by)
return self._with_native(native)

@requires.backend_version((1,))
def rolling_var(
self: Self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
if self._backend_version < (1,): # pragma: no cover
msg = "`rolling_var` not implemented for polars older than 1.0"
raise NotImplementedError(msg)
kwds = self._renamed_min_periods(min_samples)
native = self.native.rolling_var(
window_size=window_size, center=center, ddof=ddof, **kwds
)
return self._with_native(native)

@requires.backend_version((1,))
def rolling_std(
self: Self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
if self._backend_version < (1,): # pragma: no cover
msg = "`rolling_std` not implemented for polars older than 1.0"
raise NotImplementedError(msg)
kwds = self._renamed_min_periods(min_samples)
native = self.native.rolling_std(
window_size=window_size, center=center, ddof=ddof, **kwds
Expand Down Expand Up @@ -158,16 +155,14 @@ def map_batches(
native = self.native.map_batches(function, return_dtype_pl)
return self._with_native(native)

@requires.backend_version((1,))
def replace_strict(
self,
old: Sequence[Any] | Mapping[Any, Any],
new: Sequence[Any],
*,
return_dtype: DType | type[DType] | None,
) -> Self:
if self._backend_version < (1,):
msg = f"`replace_strict` is only available in Polars>=1.0, found version {self._backend_version}"
raise NotImplementedError(msg)
return_dtype_pl = (
narwhals_to_native_dtype(return_dtype, self._version, self._backend_version)
if return_dtype
Expand Down
Loading
Loading