Skip to content

Commit 4f4fcb9

Browse files
shoyer authored and Xarray-Beam authors committed
Fix split_variables with different dimensions per variable
PiperOrigin-RevId: 536400546
1 parent 71f848d commit 4f4fcb9

File tree

4 files changed

+33
-3
lines changed

4 files changed

+33
-3
lines changed

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343
setuptools.setup(
4444
name='xarray-beam',
45-
version='0.6.1',
45+
version='0.6.2',
4646
license='Apache 2.0',
4747
author='Google LLC',
4848
author_email='[email protected]',

xarray_beam/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,4 @@
4545
DatasetToZarr,
4646
)
4747

48-
__version__ = '0.6.1'
48+
__version__ = '0.6.2'

xarray_beam/_src/rechunk.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,10 @@ def split_variables(
460460
# TODO(shoyer): add support for partial splitting, into explicitly provided
461461
# sets of variables
462462
for var_name in dataset:
463-
yield key.replace(vars={var_name}), dataset[[var_name]]
463+
new_dataset = dataset[[var_name]]
464+
offsets = {k: v for k, v in key.offsets.items() if k in new_dataset.dims}
465+
new_key = core.Key(offsets, vars={var_name})
466+
yield new_key, new_dataset
464467

465468

466469
@dataclasses.dataclass

xarray_beam/_src/rechunk_test.py

+27
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,33 @@ def test_consolidate_and_split_variables_only(self):
257257
actual = consolidated | xbeam.SplitVariables()
258258
self.assertIdenticalChunks(actual, split)
259259

260+
def test_split_variables_with_different_dims(self):
261+
inputs = [
262+
(
263+
xbeam.Key({'x': 0, 'y': 0}, vars=None),
264+
xarray.Dataset({
265+
'foo': ('x', np.array([1, 2])),
266+
'bar': (('x', 'y'), np.array([[1, 2, 3], [4, 5, 6]])),
267+
}),
268+
),
269+
]
270+
expected = [
271+
(
272+
xbeam.Key({'x': 0}, vars={'foo'}),
273+
xarray.Dataset({
274+
'foo': ('x', np.array([1, 2])),
275+
}),
276+
),
277+
(
278+
xbeam.Key({'x': 0, 'y': 0}, vars={'bar'}),
279+
xarray.Dataset({
280+
'bar': (('x', 'y'), np.array([[1, 2, 3], [4, 5, 6]])),
281+
}),
282+
),
283+
]
284+
actual = inputs | xbeam.SplitVariables()
285+
self.assertIdenticalChunks(actual, expected)
286+
260287
def test_consolidate_chunks_not_fully_shared_dims(self):
261288
inputs = [
262289
(

0 commit comments

Comments
 (0)