Skip to content

TST: enable 2D tests for MaskedArrays, fix+test shift #61826

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
masked_reductions,
)
from pandas.core.array_algos.quantile import quantile_with_mask
from pandas.core.array_algos.transforms import shift
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays._utils import to_numpy_dtype_inference
from pandas.core.arrays.base import ExtensionArray
Expand Down Expand Up @@ -361,6 +362,17 @@ def ravel(self, *args, **kwargs) -> Self:
mask = self._mask.ravel(*args, **kwargs)
return type(self)(data, mask)

def shift(self, periods: int = 1, fill_value=None) -> Self:
# NB: shift is always along axis=0
axis = 0
if fill_value is None:
new_data = shift(self._data, periods, axis, 0)
new_mask = shift(self._mask, periods, axis, True)
else:
new_data = shift(self._data, periods, axis, fill_value)
new_mask = shift(self._mask, periods, axis, False)
return type(self)(new_data, new_mask)

@property
def T(self) -> Self:
return self._simple_new(self._data.T, self._mask.T)
Expand Down
10 changes: 10 additions & 0 deletions pandas/tests/extension/base/dim2.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ def skip_if_doesnt_support_2d(self, dtype, request):
# TODO: is there a less hacky way of checking this?
pytest.skip(f"{dtype} does not support 2D.")

def test_shift_2d(self, data):
arr2d = data.repeat(2).reshape(-1, 2)

for n in [1, -2]:
for fill_value in [None, data[0]]:
result = arr2d.shift(n, fill_value=fill_value)
expected_col = data.shift(n, fill_value=fill_value)
tm.assert_extension_array_equal(result[:, 0], expected_col)
tm.assert_extension_array_equal(result[:, 1], expected_col)

def test_transpose(self, data):
arr2d = data.repeat(2).reshape(-1, 2)
shape = arr2d.shape
Expand Down
11 changes: 7 additions & 4 deletions pandas/tests/extension/test_masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,13 @@ def data_for_grouping(dtype):


class TestMaskedArrays(base.ExtensionTests):
@pytest.fixture(autouse=True)
def skip_if_doesnt_support_2d(self, dtype, request):
# Override the fixture so that we run these tests.
assert not dtype._supports_2d
# If dtype._supports_2d is ever changed to True, then this fixture
# override becomes unnecessary.

@pytest.mark.parametrize("na_action", [None, "ignore"])
def test_map(self, data_missing, na_action):
result = data_missing.map(lambda x: x, na_action=na_action)
Expand Down Expand Up @@ -402,7 +409,3 @@ def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):

else:
raise NotImplementedError(f"{op_name} not supported")


class Test2DCompat(base.Dim2CompatTests):
pass
Loading