Skip to content
7 changes: 4 additions & 3 deletions asv_bench/benchmarks/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
DataFrame,
Index,
Series,
StringDtype,
)
from pandas.arrays import StringArray

Expand Down Expand Up @@ -290,10 +291,10 @@ def setup(self):
self.series_arr_nan = np.concatenate([self.series_arr, np.array([NA] * 1000)])

def time_string_array_construction(self):
StringArray(self.series_arr)
StringArray(self.series_arr, dtype=StringDtype())

def time_string_array_with_nan_construction(self):
StringArray(self.series_arr_nan)
StringArray(self.series_arr_nan, dtype=StringDtype())

def peakmem_stringarray_construction(self):
StringArray(self.series_arr)
StringArray(self.series_arr, dtype=StringDtype())
102 changes: 50 additions & 52 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def construct_array_type(self) -> type_t[BaseStringArray]:
elif self.storage == "pyarrow" and self._na_value is libmissing.NA:
return ArrowStringArray
elif self.storage == "python":
return StringArrayNumpySemantics
return StringArray
else:
return ArrowStringArray

Expand Down Expand Up @@ -487,8 +487,10 @@ def _str_map_str_or_object(
)
# error: "BaseStringArray" has no attribute "_from_pyarrow_array"
return self._from_pyarrow_array(result) # type: ignore[attr-defined]
# error: Too many arguments for "BaseStringArray"
return type(self)(result) # type: ignore[call-arg]
else:
# StringArray
# error: Too many arguments for "BaseStringArray"
return type(self)(result, dtype=self.dtype) # type: ignore[call-arg]

else:
# This is when the result type is object. We reach this when
Expand Down Expand Up @@ -578,6 +580,8 @@ class StringArray(BaseStringArray, NumpyExtensionArray): # type: ignore[misc]
nan-likes(``None``, ``np.nan``) for the ``values`` parameter
in addition to strings and :attr:`pandas.NA`

dtype : StringDtype
Dtype for the array.
copy : bool, default False
Whether to copy the array of data.

Expand Down Expand Up @@ -632,36 +636,56 @@ class StringArray(BaseStringArray, NumpyExtensionArray): # type: ignore[misc]

# undo the NumpyExtensionArray hack
_typ = "extension"
_storage = "python"
_na_value: libmissing.NAType | float = libmissing.NA

def __init__(self, values, copy: bool = False) -> None:
def __init__(
self, values, *, dtype: StringDtype | None = None, copy: bool = False
) -> None:
if dtype is None:
dtype = StringDtype()
values = extract_array(values)

super().__init__(values, copy=copy)
if not isinstance(values, type(self)):
self._validate()
self._validate(dtype)
NDArrayBacked.__init__(
self,
self._ndarray,
StringDtype(storage=self._storage, na_value=self._na_value),
dtype,
)

def _validate(self) -> None:
def _validate(self, dtype: StringDtype) -> None:
"""Validate that we only store NA or strings."""
if len(self._ndarray) and not lib.is_string_array(self._ndarray, skipna=True):
raise ValueError("StringArray requires a sequence of strings or pandas.NA")
if self._ndarray.dtype != "object":
raise ValueError(
"StringArray requires a sequence of strings or pandas.NA. Got "
f"'{self._ndarray.dtype}' dtype instead."
)
# Check to see if need to convert Na values to pd.NA
if self._ndarray.ndim > 2:
# Ravel if ndims > 2 b/c no cythonized version available
lib.convert_nans_to_NA(self._ndarray.ravel("K"))

if dtype._na_value is libmissing.NA:
if len(self._ndarray) and not lib.is_string_array(
self._ndarray, skipna=True
):
raise ValueError(
"StringArray requires a sequence of strings or pandas.NA"
)
if self._ndarray.dtype != "object":
raise ValueError(
"StringArray requires a sequence of strings or pandas.NA. Got "
f"'{self._ndarray.dtype}' dtype instead."
)
# Check to see if need to convert Na values to pd.NA
if self._ndarray.ndim > 2:
# Ravel if ndims > 2 b/c no cythonized version available
lib.convert_nans_to_NA(self._ndarray.ravel("K"))
else:
lib.convert_nans_to_NA(self._ndarray)
else:
lib.convert_nans_to_NA(self._ndarray)
# Validate that we only store NaN or strings.
if len(self._ndarray) and not lib.is_string_array(
self._ndarray, skipna=True
):
raise ValueError("StringArray requires a sequence of strings or NaN")
if self._ndarray.dtype != "object":
    # The second literal must be an f-string so the offending dtype is
    # interpolated into the message rather than printed as a literal
    # "{self._ndarray.dtype}" placeholder (matches the pandas.NA branch).
    raise ValueError(
        "StringArray requires a sequence of strings "
        f"or NaN. Got '{self._ndarray.dtype}' dtype instead."
    )
# TODO validate or force NA/None to NaN

def _validate_scalar(self, value):
# used by NDArrayBackedExtensionIndex.insert
Expand Down Expand Up @@ -729,8 +753,8 @@ def _cast_pointwise_result(self, values) -> ArrayLike:
@classmethod
def _empty(cls, shape, dtype) -> StringArray:
values = np.empty(shape, dtype=object)
values[:] = libmissing.NA
return cls(values).astype(dtype, copy=False)
values[:] = dtype.na_value
return cls(values, dtype=dtype).astype(dtype, copy=False)

def __arrow_array__(self, type=None):
"""
Expand Down Expand Up @@ -930,7 +954,7 @@ def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArra
if self._hasna:
na_mask = cast("npt.NDArray[np.bool_]", isna(ndarray))
if np.all(na_mask):
return type(self)(ndarray)
return type(self)(ndarray, dtype=self.dtype)
if skipna:
if name == "cumsum":
ndarray = np.where(na_mask, "", ndarray)
Expand Down Expand Up @@ -964,7 +988,7 @@ def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArra
# Argument 2 to "where" has incompatible type "NAType | float"
np_result = np.where(na_mask, self.dtype.na_value, np_result) # type: ignore[arg-type]

result = type(self)(np_result)
result = type(self)(np_result, dtype=self.dtype)
return result

def _wrap_reduction_result(self, axis: AxisInt | None, result) -> Any:
Expand Down Expand Up @@ -1043,7 +1067,7 @@ def _cmp_method(self, other, op):
and other.dtype.na_value is libmissing.NA
):
# NA has priority of NaN semantics
return NotImplemented
return op(self.astype(other.dtype, copy=False), other)

if isinstance(other, ArrowExtensionArray):
if isinstance(other, BaseStringArray):
Expand Down Expand Up @@ -1093,29 +1117,3 @@ def _cmp_method(self, other, op):
return res_arr

_arith_method = _cmp_method


class StringArrayNumpySemantics(StringArray):
_storage = "python"
_na_value = np.nan

def _validate(self) -> None:
"""Validate that we only store NaN or strings."""
if len(self._ndarray) and not lib.is_string_array(self._ndarray, skipna=True):
raise ValueError(
"StringArrayNumpySemantics requires a sequence of strings or NaN"
)
if self._ndarray.dtype != "object":
raise ValueError(
"StringArrayNumpySemantics requires a sequence of strings or NaN. Got "
f"'{self._ndarray.dtype}' dtype instead."
)
# TODO validate or force NA/None to NaN

@classmethod
def _from_sequence(
cls, scalars, *, dtype: Dtype | None = None, copy: bool = False
) -> Self:
if dtype is None:
dtype = StringDtype(storage="python", na_value=np.nan)
return super()._from_sequence(scalars, dtype=dtype, copy=copy)
40 changes: 20 additions & 20 deletions pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import pandas as pd
import pandas._testing as tm
from pandas.core.arrays.string_ import StringArrayNumpySemantics
from pandas.core.arrays.string_arrow import (
ArrowStringArray,
)
Expand Down Expand Up @@ -115,7 +114,7 @@ def test_repr(dtype):
arr_name = "ArrowStringArray"
expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: str"
elif dtype.storage == "python" and dtype.na_value is np.nan:
arr_name = "StringArrayNumpySemantics"
arr_name = "StringArray"
expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: str"
else:
arr_name = "StringArray"
Expand Down Expand Up @@ -433,44 +432,45 @@ def test_comparison_methods_list(comparison_op, dtype):
def test_constructor_raises(cls):
if cls is pd.arrays.StringArray:
msg = "StringArray requires a sequence of strings or pandas.NA"
elif cls is StringArrayNumpySemantics:
msg = "StringArrayNumpySemantics requires a sequence of strings or NaN"
kwargs = {"dtype": pd.StringDtype()}
else:
msg = "Unsupported type '<class 'numpy.ndarray'>' for ArrowExtensionArray"
kwargs = {}

with pytest.raises(ValueError, match=msg):
cls(np.array(["a", "b"], dtype="S1"))
cls(np.array(["a", "b"], dtype="S1"), **kwargs)

with pytest.raises(ValueError, match=msg):
cls(np.array([]))
cls(np.array([]), **kwargs)

if cls is pd.arrays.StringArray or cls is StringArrayNumpySemantics:
if cls is pd.arrays.StringArray:
# GH#45057 np.nan and None do NOT raise, as they are considered valid NAs
# for string dtype
cls(np.array(["a", np.nan], dtype=object))
cls(np.array(["a", None], dtype=object))
cls(np.array(["a", np.nan], dtype=object), **kwargs)
cls(np.array(["a", None], dtype=object), **kwargs)
else:
with pytest.raises(ValueError, match=msg):
cls(np.array(["a", np.nan], dtype=object))
cls(np.array(["a", np.nan], dtype=object), **kwargs)
with pytest.raises(ValueError, match=msg):
cls(np.array(["a", None], dtype=object))
cls(np.array(["a", None], dtype=object), **kwargs)

with pytest.raises(ValueError, match=msg):
cls(np.array(["a", pd.NaT], dtype=object))
cls(np.array(["a", pd.NaT], dtype=object), **kwargs)

with pytest.raises(ValueError, match=msg):
cls(np.array(["a", np.datetime64("NaT", "ns")], dtype=object))
cls(np.array(["a", np.datetime64("NaT", "ns")], dtype=object), **kwargs)

with pytest.raises(ValueError, match=msg):
cls(np.array(["a", np.timedelta64("NaT", "ns")], dtype=object))
cls(np.array(["a", np.timedelta64("NaT", "ns")], dtype=object), **kwargs)


@pytest.mark.parametrize("na", [np.nan, np.float64("nan"), float("nan"), None, pd.NA])
def test_constructor_nan_like(na):
expected = pd.arrays.StringArray(np.array(["a", pd.NA]))
tm.assert_extension_array_equal(
pd.arrays.StringArray(np.array(["a", na], dtype="object")), expected
expected = pd.arrays.StringArray(np.array(["a", pd.NA]), dtype=pd.StringDtype())
result = pd.arrays.StringArray(
np.array(["a", na], dtype="object"), dtype=pd.StringDtype()
)
tm.assert_extension_array_equal(result, expected)


@pytest.mark.parametrize("copy", [True, False])
Expand All @@ -487,10 +487,10 @@ def test_from_sequence_no_mutate(copy, cls, dtype):
expected = cls(
pa.array(na_arr, type=pa.string(), from_pandas=True), dtype=dtype
)
elif cls is StringArrayNumpySemantics:
expected = cls(nan_arr)
elif dtype.na_value is np.nan:
expected = cls(nan_arr, dtype=dtype)
else:
expected = cls(na_arr)
expected = cls(na_arr, dtype=dtype)

tm.assert_extension_array_equal(result, expected)
tm.assert_numpy_array_equal(nan_arr, expected_input)
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/base/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
NumpyExtensionArray,
PeriodArray,
SparseArray,
StringArray,
TimedeltaArray,
)
from pandas.core.arrays.string_ import StringArrayNumpySemantics
from pandas.core.arrays.string_arrow import ArrowStringArray


Expand Down Expand Up @@ -222,7 +222,7 @@ def test_iter_box_period(self):
)
def test_values_consistent(arr, expected_type, dtype, using_infer_string):
if using_infer_string and dtype == "object":
expected_type = ArrowStringArray if HAS_PYARROW else StringArrayNumpySemantics
expected_type = ArrowStringArray if HAS_PYARROW else StringArray
l_values = Series(arr)._values
r_values = pd.Index(arr)._values
assert type(l_values) is expected_type
Expand Down
7 changes: 6 additions & 1 deletion pandas/tests/extension/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,13 @@ def __getitem__(self, item):
def test_ellipsis_index():
# GH#42430 1D slices over extension types turn into N-dimensional slices
# over ExtensionArrays
dtype = pd.StringDtype()
df = pd.DataFrame(
{"col1": CapturingStringArray(np.array(["hello", "world"], dtype=object))}
{
"col1": CapturingStringArray(
np.array(["hello", "world"], dtype=object), dtype=dtype
)
}
)
_ = df.iloc[:1]

Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/io/parser/test_upcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
BooleanArray,
FloatingArray,
IntegerArray,
StringArray,
)


Expand Down Expand Up @@ -95,7 +94,7 @@ def test_maybe_upcast_object(val, string_storage):

if string_storage == "python":
exp_val = "c" if val == "c" else NA
expected = StringArray(np.array(["a", "b", exp_val], dtype=np.object_))
expected = pd.array(["a", "b", exp_val], dtype=pd.StringDtype())
else:
exp_val = "c" if val == "c" else None
expected = ArrowStringArray(pa.array(["a", "b", exp_val]))
Expand Down
11 changes: 3 additions & 8 deletions pandas/tests/io/test_orc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import pandas as pd
from pandas import read_orc
import pandas._testing as tm
from pandas.core.arrays import StringArray

pytest.importorskip("pyarrow.orc")

Expand Down Expand Up @@ -368,13 +367,9 @@ def test_orc_dtype_backend_numpy_nullable():

expected = pd.DataFrame(
{
"string": StringArray(np.array(["a", "b", "c"], dtype=np.object_)),
"string_with_nan": StringArray(
np.array(["a", pd.NA, "c"], dtype=np.object_)
),
"string_with_none": StringArray(
np.array(["a", pd.NA, "c"], dtype=np.object_)
),
"string": pd.array(["a", "b", "c"], dtype=pd.StringDtype()),
"string_with_nan": pd.array(["a", pd.NA, "c"], dtype=pd.StringDtype()),
"string_with_none": pd.array(["a", pd.NA, "c"], dtype=pd.StringDtype()),
"int": pd.Series([1, 2, 3], dtype="Int64"),
"int_with_nan": pd.Series([1, pd.NA, 3], dtype="Int64"),
"na_only": pd.Series([pd.NA, pd.NA, pd.NA], dtype="Int64"),
Expand Down
Loading