Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 71 additions & 4 deletions src/ragged/_spec_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

import numbers
from collections.abc import Iterable
from typing import Any, cast

import awkward as ak
Expand Down Expand Up @@ -257,11 +258,77 @@ def roll(

https://data-apis.org/array-api/latest/API_specification/generated/array_api.roll.html
"""
ak_x = x._impl # pylint: disable=W0212
if not isinstance(ak_x, (array, ak.Array)):
msg = f"x must be a ragged array or an Awkward Array, got {type(ak_x)}"
raise TypeError(msg)

x # noqa: B018, pylint: disable=W0104
shift # noqa: B018, pylint: disable=W0104
axis # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 121") # noqa: EM101
if axis is None:
flat = cast(ak.Array, ak.flatten(ak_x, axis=None))
n = len(flat)
if n == 0:
return array(ak_x)
if isinstance(shift, int):
s = shift % n
else:
msg = f"shift must be int or tuple of ints, got {type(shift)}"
raise TypeError(msg)
rolled_flat = cast(
ak.Array, ak.concatenate([flat[-s:], flat[:-s]]) if s else flat
)
lengths = ak.num(ak_x, axis=-1)
restored = ak.unflatten(rolled_flat, lengths)
return array(restored)

if isinstance(axis, int):
axis_tuple: tuple[int, ...] = (axis,)
elif isinstance(axis, tuple):
if not all(isinstance(a, int) for a in axis):
msg = f"axis must be int or tuple of ints, got {type(axis)}"
raise TypeError(msg)
axis_tuple = tuple(axis)
else:
msg = f"axis must be int, None, or tuple of ints, got {type(axis)}" # type: ignore[unreachable]
raise TypeError(msg)

if isinstance(shift, int):
shift_tuple: tuple[int, ...] = (shift,) * len(axis_tuple)
elif isinstance(shift, tuple):
if not all(isinstance(s, int) for s in shift):
msg = f"shift must be int or tuple of ints, got {type(shift)}"
raise TypeError(msg)
shift_tuple = tuple(shift)
if len(shift_tuple) != len(axis_tuple):
msg = f"shift and axis must have the same length, got shift={shift_tuple} and axis={axis_tuple}"
raise ValueError(msg)
else:
msg = f"shift must be int or tuple of ints, got {type(shift)}" # type: ignore[unreachable]
raise TypeError(msg)

ndim = ak_x.ndim
axis_tuple = tuple(a + ndim if a < 0 else a for a in axis_tuple)

def recursive_roll(
obj: Any, shift_val: int, target_axis: int, depth: int = 0
) -> Any:
if not isinstance(obj, Iterable) or isinstance(obj, (str, bytes)):
return obj

if depth == target_axis:
items = list(obj)
n = len(items)
if n == 0:
return []
s = shift_val % n
return items[-s:] + items[:-s] if s else items

return [recursive_roll(o, shift_val, target_axis, depth + 1) for o in obj]

result = ak_x
for ax, sh in zip(axis_tuple, shift_tuple):
result = recursive_roll(result, sh, ax)

return array(result)


def squeeze(x: array, /, axis: int | tuple[int, ...]) -> array:
Expand Down
104 changes: 104 additions & 0 deletions tests/test_spec_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,107 @@ def test_stack_invalid_axis_raises():
ragged.stack([x, x], axis=2)
with pytest.raises(ValueError, match="axis=-3 is out of bounds"):
ragged.stack([x, x], axis=-3)


def test_roll_basic_1d():
a = ragged.array([1, 2, 3, 4])
assert ragged.roll(a, 1, axis=0).tolist() == [4, 1, 2, 3]
assert ragged.roll(a, -1, axis=0).tolist() == [2, 3, 4, 1]
assert ragged.roll(a, 0, axis=0).tolist() == [1, 2, 3, 4]


def test_roll_axis_none_flatten_restore():
a = ragged.array([[1, 2], [3, 4, 5]])
rolled = ragged.roll(a, 2, axis=None)
assert ak.to_list(rolled) == [[4, 5], [1, 2, 3]]


def test_roll_multi_axis_tuple_shift():
a = ragged.array(
[
[[1, 2], [3, 4]],
[[5, 6], [7, 8]],
]
)
rolled = ragged.roll(a, (1, -1), axis=(0, 1))
expected = [[[7, 8], [5, 6]], [[3, 4], [1, 2]]]
assert ak.to_list(rolled) == expected


def test_roll_scalar_shift_with_tuple_axis():
a = ragged.array(
[
[[1, 2], [3, 4]],
[[5, 6], [7, 8]],
]
)
rolled = ragged.roll(a, 1, axis=(0, 1))
expected = [[[7, 8], [5, 6]], [[3, 4], [1, 2]]]
assert ak.to_list(rolled) == expected


def test_roll_negative_axis():
a = ragged.array([[1, 2], [3, 4]])
assert ak.to_list(ragged.roll(a, 1, axis=-1)) == [[2, 1], [4, 3]]


def test_roll_large_shift():
a = ragged.array([1, 2, 3])
assert ak.to_list(ragged.roll(a, 10, axis=0)) == [3, 1, 2] # 10 % 3 == 1


def test_roll_empty_inner_lists():
a = ragged.array([[], [1, 2], []])
rolled = ragged.roll(a, 1, axis=0)
assert all(isinstance(lst, list) for lst in ak.to_list(rolled))


def test_roll_preserves_dtype_and_shape():
a = ragged.array([[1.0, 2.0], [3.0, 4.0]])
out = ragged.roll(a, 1, axis=0)
assert out.dtype == a.dtype
assert out.shape == a.shape


def test_roll_axis_none_ragged():
a = ragged.array([[1], [2, 3], [4]])
rolled = ragged.roll(a, 2, axis=None)
assert ak.to_list(rolled) == [[3], [4, 1], [2]]


def test_roll_invalid_axis_type():
a = ragged.array([1, 2, 3])
with pytest.raises(TypeError):
ragged.roll(a, 1, axis=1.5) # type: ignore[arg-type]
with pytest.raises(TypeError):
ragged.roll(a, 1, axis=(0, "1")) # type: ignore[arg-type]


def test_roll_invalid_shift_type():
a = ragged.array([1, 2, 3])
with pytest.raises(TypeError):
ragged.roll(a, 2.0, axis=0) # type: ignore[arg-type]
with pytest.raises(TypeError):
ragged.roll(a, (1, "2"), axis=(0, 0)) # type: ignore[arg-type]


def test_roll_shift_axis_length_mismatch():
a = ragged.array([[1, 2], [3, 4]])
with pytest.raises(ValueError, match="shift and axis must have the same length"):
ragged.roll(a, (1, 2), axis=(0,))
with pytest.raises(ValueError, match="shift and axis must have the same length"):
ragged.roll(a, (1,), axis=(0, 1))


def test_roll_invalid_shift_type_message():
a = ragged.array([1, 2, 3])
# Passing a float shift
with pytest.raises(
TypeError, match=r"shift must be int or tuple of ints, got <class 'float'>"
):
ragged.roll(a, 2.0, axis=0) # type: ignore[arg-type]
# Passing a tuple with a non-int element
with pytest.raises(
TypeError, match=r"shift must be int or tuple of ints, got <class 'tuple'>"
):
ragged.roll(a, (1, "2"), axis=(0, 1)) # type: ignore[arg-type]
Loading