Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ Deprecations
Bug Fixes
~~~~~~~~~

- The zarr backend now writes boolean arrays with native ``bool`` dtype instead
of converting them to ``int8``. Zarr supports ``bool`` natively, so the
``BooleanCoder`` (which was designed for NetCDF compatibility) is now skipped
for zarr writes. Existing zarr stores written with the old ``int8`` encoding
are still read correctly. (:issue:`2937`, :pull:`11318`)
By `Evan Lyall <https://github.com/elyall>`_.
- Fix a major performance regression in :py:meth:`Coordinates.to_index` (and
consequently :py:meth:`Dataset.to_dataframe`) caused by converting the cached
code ndarrays into Python lists (:issue:`11305`).
Expand Down
7 changes: 6 additions & 1 deletion xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,12 @@ def encode_zarr_variable(var, needs_copy=True, name=None):
A variable which has been encoded as described above.
"""

var = conventions.encode_cf_variable(var, name=name)
coders = [
c
for c in conventions._default_encode_cf_coders()
if not isinstance(c, coding.variables.BooleanCoder)
]
var = conventions.encode_cf_variable(var, name=name, coders=coders)
var = ensure_dtype_not_object(var, name=name)

# zarr allows unicode, but not variable-length strings, so it's both
Expand Down
32 changes: 21 additions & 11 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,22 @@ def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None:
)


def _default_encode_cf_coders():
    """Build the coder chain that ``encode_cf_variable`` uses by default.

    Every call constructs new coder instances and returns them in a new
    list, so a caller can filter the result (for example to drop one coder)
    without affecting any other caller.

    Returns
    -------
    list
        One freshly constructed instance of each default coder, in the
        order in which they are applied during encoding.
    """
    # Order matters: the coders are applied one after another, with
    # BooleanCoder last.
    coder_factories = (
        CFDatetimeCoder,
        CFTimedeltaCoder,
        variables.CFScaleOffsetCoder,
        variables.CFMaskCoder,
        variables.NativeEnumCoder,
        variables.NonStringCoder,
        variables.DefaultFillvalueCoder,
        variables.BooleanCoder,
    )
    return [make_coder() for make_coder in coder_factories]


def encode_cf_variable(
var: Variable, needs_copy: bool = True, name: T_Name = None
var: Variable, needs_copy: bool = True, name: T_Name = None, coders=None
) -> Variable:
"""
Converts a Variable into a Variable which follows some
Expand All @@ -81,6 +95,8 @@ def encode_cf_variable(
----------
var : Variable
A variable holding un-encoded data.
coders : list of VariableCoder, optional
List of coders to apply. If None, uses the default CF coder chain.

Returns
-------
Expand All @@ -89,16 +105,10 @@ def encode_cf_variable(
"""
ensure_not_multiindex(var, name=name)

for coder in [
CFDatetimeCoder(),
CFTimedeltaCoder(),
variables.CFScaleOffsetCoder(),
variables.CFMaskCoder(),
variables.NativeEnumCoder(),
variables.NonStringCoder(),
variables.DefaultFillvalueCoder(),
variables.BooleanCoder(),
]:
if coders is None:
coders = _default_encode_cf_coders()

for coder in coders:
var = coder.encode(var, name=name)

for attr_name in CF_RELATED_DATA:
Expand Down
54 changes: 54 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2709,6 +2709,60 @@ def roundtrip(
async def test_load_async(self) -> None:
await super().test_load_async()

def test_roundtrip_boolean_dtype(self) -> None:
    """Boolean data is stored as native zarr ``bool`` and survives roundtrips."""
    expected = create_boolean_data()
    assert expected["x"].dtype == "bool"

    no_consolidation = {"consolidated": False}
    with self.create_zarr_target() as store_target:
        self.save(expected, store_target, consolidated=False)

        # Inspect the raw store: the array must be bool on disk, not int8,
        # and must not carry the "dtype" attribute used by the int8 encoding.
        stored = zarr.open_group(store_target, mode="r")["x"]
        assert isinstance(stored, zarr.Array)
        assert stored.dtype == np.dtype("bool")
        assert "dtype" not in stored.attrs

        with self.open(store_target, backend_kwargs=no_consolidation) as actual:
            assert_identical(expected, actual)
            assert actual["x"].dtype == "bool"

            # Writing the reopened dataset again must still give bool.
            with self.roundtrip(actual) as second_pass:
                assert_identical(expected, second_pass)
                assert second_pass["x"].dtype == "bool"

def test_roundtrip_boolean_dtype_legacy_int8(self) -> None:
    """Backward compat: an int8 array tagged attrs['dtype']='bool' opens as bool."""
    expected = create_boolean_data()
    with self.create_zarr_target() as store_target:
        group = zarr.open_group(store_target, mode="w")
        legacy_values = expected["x"].values.astype("i1")
        uses_v3_format = has_zarr_v3 and group.metadata.zarr_format == 3

        # Hand-write the array the way the old int8 encoding laid it out.
        create_kwargs = {
            "shape": legacy_values.shape,
            "dtype": legacy_values.dtype,
            "fill_value": -1,
        }
        if uses_v3_format:
            # Format 3 gets its dimension names at array creation time.
            create_kwargs["dimension_names"] = ("t", "x")
        legacy_array = group.create_array("x", **create_kwargs)

        legacy_array[:] = legacy_values
        legacy_array.attrs["dtype"] = "bool"
        legacy_array.attrs["units"] = "-"
        if not uses_v3_format:
            # Otherwise the dimension names are stored as an attribute.
            legacy_array.attrs["_ARRAY_DIMENSIONS"] = ["t", "x"]

        with self.open(
            store_target, backend_kwargs={"consolidated": False}
        ) as actual:
            assert actual["x"].dtype == "bool"
            np.testing.assert_array_equal(
                actual["x"].values, expected["x"].values
            )

def test_roundtrip_bytes_with_fill_value(self):
pytest.xfail("Broken by Zarr 3.0.7")

Expand Down
25 changes: 25 additions & 0 deletions xarray/tests/test_conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,31 @@ def test_booltype_array(self) -> None:
assert_array_equal(bx.transpose((1, 0)), x.transpose((1, 0)))


class TestEncodeCfVariableCoders:
    """Tests for the ``coders`` argument of ``conventions.encode_cf_variable``."""

    def test_empty_coders_is_identity(self) -> None:
        # With an empty coder list the boolean data must come back unchanged.
        flags = np.array([True, False, True])
        var = Variable(["x"], flags, {"units": "test"})

        encoded = conventions.encode_cf_variable(var, coders=[])

        assert encoded.dtype == bool
        assert_array_equal(encoded.values, var.values)

    def test_custom_coders_excludes_boolean_coder(self) -> None:
        # Dropping BooleanCoder from the default chain keeps bool as bool
        # and leaves no "dtype" attribute behind.
        var = Variable(["x"], np.array([True, False, True]))
        boolean_coder_type = coding.variables.BooleanCoder
        chain_without_bool = [
            coder
            for coder in conventions._default_encode_cf_coders()
            if not isinstance(coder, boolean_coder_type)
        ]

        encoded = conventions.encode_cf_variable(var, coders=chain_without_bool)

        assert encoded.dtype == bool
        assert "dtype" not in encoded.attrs

    def test_default_coders_encodes_bool_to_int8(self) -> None:
        # Without an explicit coder list, bool is encoded as int8 and the
        # original dtype is recorded in the attributes.
        var = Variable(["x"], np.array([True, False, True]))

        encoded = conventions.encode_cf_variable(var)

        assert encoded.dtype == np.int8
        assert encoded.attrs.get("dtype") == "bool"


class TestNativeEndiannessArray:
def test(self) -> None:
x = np.arange(5, dtype=">i8")
Expand Down
Loading