-
-
Notifications
You must be signed in to change notification settings - Fork 18.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: Enable pytables to round-trip with StringDtype #60663
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,6 +38,7 @@ | |
writers as libwriters, | ||
) | ||
from pandas._libs.lib import is_string_array | ||
from pandas._libs.missing import NA | ||
from pandas._libs.tslibs import timezones | ||
from pandas.compat._optional import import_optional_dependency | ||
from pandas.compat.pickle_compat import patch_pickle | ||
|
@@ -86,12 +87,16 @@ | |
PeriodArray, | ||
) | ||
from pandas.core.arrays.datetimes import tz_to_dtype | ||
from pandas.core.arrays.string_ import BaseStringArray | ||
import pandas.core.common as com | ||
from pandas.core.computation.pytables import ( | ||
PyTablesExpr, | ||
maybe_expression, | ||
) | ||
from pandas.core.construction import extract_array | ||
from pandas.core.construction import ( | ||
array as pd_array, | ||
extract_array, | ||
) | ||
from pandas.core.indexes.api import ensure_index | ||
|
||
from pandas.io.common import stringify_path | ||
|
@@ -3023,6 +3028,18 @@ def read_array(self, key: str, start: int | None = None, stop: int | None = None | |
|
||
if isinstance(node, tables.VLArray): | ||
ret = node[0][start:stop] | ||
dtype = getattr(attrs, "value_type", None) | ||
if dtype is not None: | ||
if dtype == "str[python]": | ||
dtype = StringDtype("python", np.nan) | ||
elif dtype == "string[python]": | ||
dtype = StringDtype("python", NA) | ||
elif dtype == "str[pyarrow]": | ||
dtype = StringDtype("pyarrow", np.nan) | ||
else: | ||
assert dtype == "string[pyarrow]" | ||
dtype = StringDtype("pyarrow", NA) | ||
ret = pd_array(ret, dtype=dtype) | ||
else: | ||
dtype = getattr(attrs, "value_type", None) | ||
shape = getattr(attrs, "shape", None) | ||
|
@@ -3262,6 +3279,19 @@ def write_array( | |
elif lib.is_np_dtype(value.dtype, "m"): | ||
self._handle.create_array(self.group, key, value.view("i8")) | ||
getattr(self.group, key)._v_attrs.value_type = "timedelta64" | ||
elif isinstance(value, BaseStringArray): | ||
vlarr = self._handle.create_vlarray(self.group, key, _tables().ObjectAtom()) | ||
vlarr.append(value.to_numpy()) | ||
node = getattr(self.group, key) | ||
if value.dtype == StringDtype("python", np.nan): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need all the branches here or can you just do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as #60663 (comment) |
||
node._v_attrs.value_type = "str[python]" | ||
elif value.dtype == StringDtype("python", NA): | ||
node._v_attrs.value_type = "string[python]" | ||
elif value.dtype == StringDtype("pyarrow", np.nan): | ||
node._v_attrs.value_type = "str[pyarrow]" | ||
else: | ||
assert value.dtype == StringDtype("pyarrow", NA) | ||
node._v_attrs.value_type = "string[pyarrow]" | ||
elif empty_array: | ||
self.write_array_empty(key, value) | ||
else: | ||
|
@@ -3294,7 +3324,11 @@ def read( | |
index = self.read_index("index", start=start, stop=stop) | ||
values = self.read_array("values", start=start, stop=stop) | ||
result = Series(values, index=index, name=self.name, copy=False) | ||
if using_string_dtype() and is_string_array(values, skipna=True): | ||
if ( | ||
using_string_dtype() | ||
and isinstance(values, np.ndarray) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm assuming There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Correct - datetime and string currently. |
||
and is_string_array(values, skipna=True) | ||
): | ||
result = result.astype(StringDtype(na_value=np.nan)) | ||
return result | ||
|
||
|
@@ -3363,7 +3397,11 @@ def read( | |
|
||
columns = items[items.get_indexer(blk_items)] | ||
df = DataFrame(values.T, columns=columns, index=axes[1], copy=False) | ||
if using_string_dtype() and is_string_array(values, skipna=True): | ||
if ( | ||
using_string_dtype() | ||
and isinstance(values, np.ndarray) | ||
and is_string_array(values, skipna=True) | ||
): | ||
df = df.astype(StringDtype(na_value=np.nan)) | ||
dfs.append(df) | ||
|
||
|
@@ -4737,9 +4775,13 @@ def read( | |
df = DataFrame._from_arrays([values], columns=cols_, index=index_) | ||
if not (using_string_dtype() and values.dtype.kind == "O"): | ||
assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype) | ||
if using_string_dtype() and is_string_array( | ||
values, # type: ignore[arg-type] | ||
skipna=True, | ||
if ( | ||
using_string_dtype() | ||
and isinstance(values, np.ndarray) | ||
and is_string_array( | ||
values, # type: ignore[arg-type] | ||
skipna=True, | ||
) | ||
): | ||
df = df.astype(StringDtype(na_value=np.nan)) | ||
frames.append(df) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if there is a better approach here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
StringDtype.construct_from_string
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This wouldn't allow round-tripping if you e.g. write out a Python-backed string with NaN-semantics, and read it in an environment with PyArrow installed.