Skip to content

Commit

Permalink
ENH(string dtype): Make str.decode return str dtype (#60709)
Browse files Browse the repository at this point in the history
* TST(string dtype): Make str.decode return str dtype

* Test fixups

* pytables fixup

* Simplify

* whatsnew

* fix implementation
  • Loading branch information
rhshadrach authored Jan 29, 2025
1 parent c430c61 commit c36da3f
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 18 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Other enhancements
- The semantics for the ``copy`` keyword in ``__array__`` methods (i.e. called
when using ``np.array()`` or ``np.asarray()`` on pandas objects) has been
updated to work correctly with NumPy >= 2 (:issue:`57739`)
- :meth:`Series.str.decode` result now has ``StringDtype`` when ``future.infer_string`` is True (:issue:`60709`)
- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`)
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`)
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
Expand Down
10 changes: 7 additions & 3 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import numpy as np

from pandas._config import get_option

from pandas._libs import lib
from pandas._typing import (
AlignJoin,
Expand Down Expand Up @@ -400,7 +402,9 @@ def cons_row(x):
# This is a mess.
_dtype: DtypeObj | str | None = dtype
vdtype = getattr(result, "dtype", None)
if self._is_string:
if _dtype is not None:
pass
elif self._is_string:
if is_bool_dtype(vdtype):
_dtype = result.dtype
elif returns_string:
Expand Down Expand Up @@ -2141,9 +2145,9 @@ def decode(self, encoding, errors: str = "strict"):
decoder = codecs.getdecoder(encoding)
f = lambda x: decoder(x, errors)[0]
arr = self._data.array
# assert isinstance(arr, (StringArray,))
result = arr._str_map(f)
return self._wrap_result(result)
dtype = "str" if get_option("future.infer_string") else None
return self._wrap_result(result, dtype=dtype)

@forbid_nonstring_types(["bytes"])
def encode(self, encoding, errors: str = "strict"):
Expand Down
4 changes: 3 additions & 1 deletion pandas/io/pytables.py
Original file line number Diff line number Diff line change
Expand Up @@ -5233,7 +5233,9 @@ def _unconvert_string_array(
dtype = f"U{itemsize}"

if isinstance(data[0], bytes):
data = Series(data, copy=False).str.decode(encoding, errors=errors)._values
ser = Series(data, copy=False).str.decode(encoding, errors=errors)
data = ser.to_numpy()
data.flags.writeable = True
else:
data = data.astype(dtype, copy=False).astype(object, copy=False)

Expand Down
6 changes: 6 additions & 0 deletions pandas/io/sas/sas7bdat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

import numpy as np

from pandas._config import get_option

from pandas._libs.byteswap import (
read_double_with_byteswap,
read_float_with_byteswap,
Expand Down Expand Up @@ -699,6 +701,7 @@ def _chunk_to_dataframe(self) -> DataFrame:
rslt = {}

js, jb = 0, 0
infer_string = get_option("future.infer_string")
for j in range(self.column_count):
name = self.column_names[j]

Expand All @@ -715,6 +718,9 @@ def _chunk_to_dataframe(self) -> DataFrame:
rslt[name] = pd.Series(self._string_chunk[js, :], index=ix, copy=False)
if self.convert_text and (self.encoding is not None):
rslt[name] = self._decode_string(rslt[name].str)
if infer_string:
rslt[name] = rslt[name].astype("str")

js += 1
else:
self.close()
Expand Down
16 changes: 6 additions & 10 deletions pandas/tests/io/sas/test_sas7bdat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas.compat._constants import (
IS64,
WASM,
Expand All @@ -20,10 +18,6 @@

from pandas.io.sas.sas7bdat import SAS7BDATReader

pytestmark = pytest.mark.xfail(
using_string_dtype(), reason="TODO(infer_string)", strict=False
)


@pytest.fixture
def dirpath(datapath):
Expand Down Expand Up @@ -246,11 +240,13 @@ def test_zero_variables(datapath):
pd.read_sas(fname)


def test_zero_rows(datapath):
@pytest.mark.parametrize("encoding", [None, "utf8"])
def test_zero_rows(datapath, encoding):
# GH 18198
fname = datapath("io", "sas", "data", "zero_rows.sas7bdat")
result = pd.read_sas(fname)
expected = pd.DataFrame([{"char_field": "a", "num_field": 1.0}]).iloc[:0]
result = pd.read_sas(fname, encoding=encoding)
str_value = b"a" if encoding is None else "a"
expected = pd.DataFrame([{"char_field": str_value, "num_field": 1.0}]).iloc[:0]
tm.assert_frame_equal(result, expected)


Expand Down Expand Up @@ -409,7 +405,7 @@ def test_0x40_control_byte(datapath):
fname = datapath("io", "sas", "data", "0x40controlbyte.sas7bdat")
df = pd.read_sas(fname, encoding="ascii")
fname = datapath("io", "sas", "data", "0x40controlbyte.csv")
df0 = pd.read_csv(fname, dtype="object")
df0 = pd.read_csv(fname, dtype="str")
tm.assert_frame_equal(df, df0)


Expand Down
9 changes: 5 additions & 4 deletions pandas/tests/strings/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def test_repeat_with_null(any_string_dtype, arg, repeat):

def test_empty_str_methods(any_string_dtype):
empty_str = empty = Series(dtype=any_string_dtype)
empty_inferred_str = Series(dtype="str")
if is_object_or_nan_string_dtype(any_string_dtype):
empty_int = Series(dtype="int64")
empty_bool = Series(dtype=bool)
Expand Down Expand Up @@ -154,7 +155,7 @@ def test_empty_str_methods(any_string_dtype):
tm.assert_series_equal(empty_str, empty.str.rstrip())
tm.assert_series_equal(empty_str, empty.str.wrap(42))
tm.assert_series_equal(empty_str, empty.str.get(0))
tm.assert_series_equal(empty_object, empty_bytes.str.decode("ascii"))
tm.assert_series_equal(empty_inferred_str, empty_bytes.str.decode("ascii"))
tm.assert_series_equal(empty_bytes, empty.str.encode("ascii"))
# ismethods should always return boolean (GH 29624)
tm.assert_series_equal(empty_bool, empty.str.isalnum())
Expand Down Expand Up @@ -566,7 +567,7 @@ def test_string_slice_out_of_bounds(any_string_dtype):
def test_encode_decode(any_string_dtype):
ser = Series(["a", "b", "a\xe4"], dtype=any_string_dtype).str.encode("utf-8")
result = ser.str.decode("utf-8")
expected = ser.map(lambda x: x.decode("utf-8")).astype(object)
expected = Series(["a", "b", "a\xe4"], dtype="str")
tm.assert_series_equal(result, expected)


Expand Down Expand Up @@ -596,7 +597,7 @@ def test_decode_errors_kwarg():
ser.str.decode("cp1252")

result = ser.str.decode("cp1252", "ignore")
expected = ser.map(lambda x: x.decode("cp1252", "ignore")).astype(object)
expected = ser.map(lambda x: x.decode("cp1252", "ignore")).astype("str")
tm.assert_series_equal(result, expected)


Expand Down Expand Up @@ -751,5 +752,5 @@ def test_get_with_dict_label():
def test_series_str_decode():
# GH 22613
result = Series([b"x", b"y"]).str.decode(encoding="UTF-8", errors="strict")
expected = Series(["x", "y"], dtype="object")
expected = Series(["x", "y"], dtype="str")
tm.assert_series_equal(result, expected)

0 comments on commit c36da3f

Please sign in to comment.