Skip to content

Commit 4a5da0c

Browse files
committed
Implement GetItem tests
1 parent 47c7af8 commit 4a5da0c

File tree

4 files changed

+140
-30
lines changed

4 files changed

+140
-30
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar:
428428
"""
429429
if isinstance(value, pa.Scalar):
430430
pa_scalar = value
431-
elif isna(value):
431+
elif not is_list_like(value) and isna(value):
432432
pa_scalar = pa.scalar(None, type=pa_type)
433433
else:
434434
# Workaround https://github.com/apache/arrow/issues/37291
@@ -1350,7 +1350,16 @@ def take(
13501350
# TODO(ARROW-9433): Treat negative indices as NULL
13511351
indices_array = pa.array(indices_array, mask=fill_mask)
13521352
result = self._pa_array.take(indices_array)
1353-
if isna(fill_value):
1353+
if is_list_like(fill_value):
1354+
# TODO: this should be hit by ListArray. Ideally we do:
1355+
# pc.replace_with_mask(result, fill_mask, pa.scalar(fill_value))
1356+
# but pyarrow does not yet implement that for list types
1357+
new_values = [
1358+
fill_value if should_fill else x.as_py()
1359+
for x, should_fill in zip(result, fill_mask)
1360+
]
1361+
return type(self)(new_values)
1362+
elif isna(fill_value):
13541363
return type(self)(result)
13551364
# TODO: ArrowNotImplementedError: Function fill_null has no
13561365
# kernel matching input types (array[string], scalar[string])

pandas/core/arrays/list_.py

Lines changed: 99 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,15 @@
1111
ExtensionDtype,
1212
register_extension_dtype,
1313
)
14-
from pandas.core.dtypes.common import is_string_dtype
14+
from pandas.core.dtypes.common import (
15+
is_bool_dtype,
16+
is_integer_dtype,
17+
is_string_dtype,
18+
)
1519
from pandas.core.dtypes.dtypes import ArrowDtype
1620

1721
from pandas.core.arrays.arrow.array import ArrowExtensionArray
22+
from pandas.core.arrays.base import ExtensionArray
1823

1924
if TYPE_CHECKING:
2025
from collections.abc import Sequence
@@ -146,6 +151,15 @@ def __init__(
146151
else:
147152
if value_type is None:
148153
if isinstance(values, (pa.Array, pa.ChunkedArray)):
154+
parent_type = values.type
155+
if not isinstance(parent_type, (pa.ListType, pa.LargeListType)):
156+
# Ideally could cast here, but I don't think pyarrow implements
157+
# many list casts
158+
new_values = [
159+
[x.as_py()] if x.is_valid else None for x in values
160+
]
161+
values = pa.array(new_values, type=pa.large_list(parent_type))
162+
149163
value_type = values.type.value_type
150164
else:
151165
value_type = pa.array(values).type.value_type
@@ -193,19 +207,89 @@ def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False):
193207

194208
return cls(values)
195209

210+
@classmethod
211+
def _box_pa(
212+
cls, value, pa_type: pa.DataType | None = None
213+
) -> pa.Array | pa.ChunkedArray | pa.Scalar:
214+
"""
215+
Box value into a pyarrow Array, ChunkedArray or Scalar.
216+
217+
Parameters
218+
----------
219+
value : any
220+
pa_type : pa.DataType | None
221+
222+
Returns
223+
-------
224+
pa.Array or pa.ChunkedArray or pa.Scalar
225+
"""
226+
if (
227+
isinstance(value, (pa.ListScalar, pa.LargeListScalar))
228+
or isinstance(value, list)
229+
or value is None
230+
):
231+
return cls._box_pa_scalar(value, pa_type)
232+
return cls._box_pa_array(value, pa_type)
233+
196234
def __getitem__(self, item):
197235
# PyArrow does not support NumPy's selection with an equal length
198236
# mask, so let's convert those to integral positions if needed
199-
if isinstance(item, np.ndarray) and item.dtype == bool:
200-
pos = np.array(range(len(item)))
201-
mask = pos[item]
202-
return type(self)(self._pa_array.take(mask))
237+
if isinstance(item, (np.ndarray, ExtensionArray)):
238+
if is_bool_dtype(item.dtype):
239+
mask_len = len(item)
240+
if mask_len != len(self):
241+
raise IndexError(
242+
f"Boolean index has wrong length: {mask_len} "
243+
f"instead of {len(self)}"
244+
)
245+
pos = np.array(range(len(item)))
246+
247+
if isinstance(item, ExtensionArray):
248+
mask = pos[item.fillna(False)]
249+
else:
250+
mask = pos[item]
251+
return type(self)(self._pa_array.take(mask))
252+
elif is_integer_dtype(item.dtype):
253+
if isinstance(item, ExtensionArray) and item.isna().any():
254+
msg = "Cannot index with an integer indexer containing NA values"
255+
raise ValueError(msg)
256+
257+
indexer = pa.array(item)
258+
return type(self)(self._pa_array.take(indexer))
203259
elif isinstance(item, int):
204-
return self._pa_array[item]
260+
value = self._pa_array[item]
261+
if value.is_valid:
262+
return value.as_py()
263+
else:
264+
return self.dtype.na_value
205265
elif isinstance(item, list):
206-
return type(self)(self._pa_array.take(item))
266+
# pyarrow does not yet support taking with an empty list of indices
267+
# https://github.com/apache/arrow/issues/39917
268+
if item:
269+
try:
270+
result = self._pa_array.take(item)
271+
except pa.lib.ArrowInvalid as e:
272+
if "Could not convert <NA>" in str(e):
273+
msg = (
274+
"Cannot index with an integer indexer containing NA values"
275+
)
276+
raise ValueError(msg) from e
277+
raise e
278+
else:
279+
result = pa.array([], type=self._pa_array.type)
280+
281+
return type(self)(result)
282+
283+
try:
284+
result = type(self)(self._pa_array[item])
285+
except TypeError as e:
286+
msg = (
287+
"only integers, slices (`:`), ellipsis (`...`), numpy.newaxis "
288+
"(`None`) and integer or boolean arrays are valid indices"
289+
)
290+
raise IndexError(msg) from e
207291

208-
return type(self)(self._pa_array[item])
292+
return result
209293

210294
def __setitem__(self, key, value) -> None:
211295
msg = "ListArray does not support item assignment via setitem"
@@ -241,7 +325,13 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
241325
return super().astype(dtype, copy)
242326

243327
def __eq__(self, other):
244-
if isinstance(other, (pa.ListScalar, pa.LargeListScalar)):
328+
if isinstance(other, list):
329+
from pandas.arrays import BooleanArray
330+
331+
mask = np.array([False] * len(self))
332+
values = np.array([x.as_py() == other for x in self._pa_array])
333+
return BooleanArray(values, mask)
334+
elif isinstance(other, (pa.ListScalar, pa.LargeListScalar)):
245335
from pandas.arrays import BooleanArray
246336

247337
# TODO: pyarrow.compute does not implement broadcasting equality

pandas/core/generic.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import warnings
2424

2525
import numpy as np
26-
import pyarrow as pa
2726

2827
from pandas._config import config
2928

@@ -150,6 +149,7 @@
150149
)
151150
from pandas.core.array_algos.replace import should_use_regex
152151
from pandas.core.arrays import ExtensionArray
152+
from pandas.core.arrays.list_ import ListDtype
153153
from pandas.core.base import PandasObject
154154
from pandas.core.construction import extract_array
155155
from pandas.core.flags import Flags
@@ -7013,11 +7013,20 @@ def fillna(
70137013
stacklevel=2,
70147014
)
70157015

7016+
holds_list_array = False
7017+
if isinstance(self, ABCSeries) and isinstance(self.dtype, ListDtype):
7018+
holds_list_array = True
7019+
elif isinstance(self, ABCDataFrame) and any(
7020+
isinstance(x, ListDtype) for x in self.dtypes
7021+
):
7022+
holds_list_array = True
7023+
70167024
if isinstance(value, (list, tuple)):
7017-
raise TypeError(
7018-
'"value" parameter must be a scalar or dict, but '
7019-
f'you passed a "{type(value).__name__}"'
7020-
)
7025+
if not holds_list_array:
7026+
raise TypeError(
7027+
'"value" parameter must be a scalar or dict, but '
7028+
f'you passed a "{type(value).__name__}"'
7029+
)
70217030

70227031
# set the default here, so functions examining the signature
70237032
# can detect if something was set (e.g. in groupby) (GH9221)
@@ -7037,8 +7046,9 @@ def fillna(
70377046
value = Series(value)
70387047
value = value.reindex(self.index)
70397048
value = value._values
7040-
elif isinstance(value, pa.ListScalar) or not is_list_like(value):
7041-
# TODO(wayd): maybe is_list_like should return false for ListScalar?
7049+
elif (
7050+
isinstance(value, list) and isinstance(self.dtype, ListDtype)
7051+
) or not is_list_like(value):
70427052
pass
70437053
else:
70447054
raise TypeError(
@@ -7102,7 +7112,7 @@ def fillna(
71027112
else:
71037113
return result
71047114

7105-
elif isinstance(value, pa.ListScalar) or not is_list_like(value):
7115+
elif holds_list_array or not is_list_like(value):
71067116
if axis == 1:
71077117
result = self.T.fillna(value=value, limit=limit).T
71087118
new_data = result._mgr

pandas/tests/extension/list/test_list.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
NDArrayBacked2DTests,
1919
)
2020
from pandas.tests.extension.base.dtype import BaseDtypeTests
21+
from pandas.tests.extension.base.getitem import BaseGetitemTests
2122
from pandas.tests.extension.base.groupby import BaseGroupbyTests
2223
from pandas.tests.extension.base.index import BaseIndexTests
2324
from pandas.tests.extension.base.interface import BaseInterfaceTests
@@ -49,7 +50,7 @@ def dtype():
4950
def data():
5051
"""Length-100 ListArray for semantics test."""
5152
# TODO: make better random data
52-
data = [list("a"), list("ab"), list("abc")] * 33 + [None]
53+
data = [list("a"), list("ab"), list("abc")] * 33 + [list("a")]
5354
return ListArray(data)
5455

5556

@@ -74,7 +75,7 @@ class TestListArray(
7475
BaseCastingTests,
7576
BaseConstructorsTests,
7677
BaseDtypeTests,
77-
# BaseGetitemTests,
78+
BaseGetitemTests,
7879
BaseGroupbyTests,
7980
BaseIndexTests,
8081
BaseInterfaceTests,
@@ -90,12 +91,12 @@ class TestListArray(
9091
BaseSetitemTests,
9192
Dim2CompatTests,
9293
):
93-
# TODO(wayd): The tests here are copied from test_arrow.py
94-
# It appears the TestArrowArray class has different expectations around
95-
# when copies should be made then the base.ExtensionTests
96-
# Assuming intentional, maybe in the long term this should just
97-
# inherit from TestArrowArray
9894
def test_fillna_no_op_returns_copy(self, data):
95+
# TODO(wayd): This test is copied from test_arrow.py
96+
# It appears the TestArrowArray class has different expectations around
97+
# when copies should be made than the base.ExtensionTests
98+
# Assuming intentional, maybe in the long term this should just
99+
# inherit from TestArrowArray
99100
data = data[~data.isna()]
100101

101102
valid = data[0]
@@ -154,10 +155,7 @@ def test_compare_scalar(self, data, comparison_op):
154155
super().test_compare_scalar(data, comparison_op)
155156

156157
def test_compare_array(self, data, comparison_op):
157-
if comparison_op in (operator.eq, operator.ne):
158-
pytest.skip("Series.combine does not properly handle missing values")
159-
160-
super().test_compare_array(data, comparison_op)
158+
pytest.skip("ListArray comparison ops are not implemented")
161159

162160
def test_invert(self, data):
163161
pytest.skip("ListArray does not implement invert")
@@ -229,6 +227,9 @@ def test_unstack(self, data, index, obj):
229227
# result = result.astype(object)
230228
tm.assert_frame_equal(result, expected)
231229

230+
def test_getitem_ellipsis_and_slice(self, data):
231+
pytest.skip("ListArray does not support NumPy style ellipsis slicing nor 2-D")
232+
232233

233234
def test_to_csv(data):
234235
# https://github.com/pandas-dev/pandas/issues/28840

0 commit comments

Comments
 (0)