Skip to content

Commit 7ea9de8

Browse files
committed
Merge branch 'endian_convert'
2 parents 0e920f0 + 465b459 commit 7ea9de8

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

xray/backends/netCDF4_.py

+35-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111
from .common import AbstractWritableDataStore
1212
from .netcdf3 import encode_nc3_variable, maybe_convert_to_char_array
1313

14+
# Translation table from numpy's ``dtype.byteorder`` codes to the endianness
# names used by netCDF4.  Both '=' (native) and '|' (not applicable) resolve
# to 'native'.
_endian_lookup = {
    '<': 'little',
    '>': 'big',
    '=': 'native',
    '|': 'native',
}
20+
1421

1522
class NetCDF4ArrayWrapper(NDArrayMixin):
1623
def __init__(self, array):
@@ -83,6 +90,27 @@ def _ensure_fill_value_valid(data, attributes):
8390
attributes['_FillValue'] = np.string_(attributes['_FillValue'])
8491

8592

93+
def _force_native_endianness(var):
    """Return ``var`` with its data in native byte order.

    ``var.dtype.byteorder`` can take the following values:

        =    native
        <    little-endian
        >    big-endian
        |    not applicable

    When the byte order is given explicitly ('<' or '>'), the values are
    converted to the native-order equivalent of the same dtype, a new
    ``Variable`` is built from them, and any 'endian' entry is dropped from
    its encoding.  Variables that are already native (or for which byte
    order is not applicable) are returned unchanged.

    Raises
    ------
    NotImplementedError
        If the variable's encoding still requests an 'endian' value other
        than 'native' after the conversion step.
    """
    if var.dtype.byteorder not in ('=', '|'):
        # Endianness was specified explicitly: convert to the native type.
        data = var.values.astype(var.dtype.newbyteorder('='))
        var = Variable(var.dims, data, var.attrs, var.encoding)
        # The data is native now, so a stale 'endian' request must go.
        var.encoding.pop('endian', None)
    # Compare by value, not identity: an equal string held in a different
    # object (e.g. one read back from a file) must still count as native.
    if var.encoding.get('endian', 'native') != 'native':
        raise NotImplementedError("Attempt to write non-native endian type, "
                                  "this is not supported by the netCDF4 python "
                                  "library.")
    return var
112+
113+
86114
class NetCDF4DataStore(AbstractWritableDataStore):
87115
"""Store for reading and writing data via the Python-NetCDF4 library.
88116
@@ -152,6 +180,9 @@ def set_attribute(self, key, value):
152180

153181
def set_variable(self, name, variable):
154182
attrs = variable.attrs.copy()
183+
184+
variable = _force_native_endianness(variable)
185+
155186
if self.format == 'NETCDF4':
156187
variable, datatype = _nc4_values_and_dtype(variable)
157188
else:
@@ -167,6 +198,8 @@ def set_variable(self, name, variable):
167198
fill_value = None
168199

169200
encoding = variable.encoding
201+
data = variable.values
202+
170203
nc4_var = self.ds.createVariable(
171204
varname=name,
172205
datatype=datatype,
@@ -177,11 +210,11 @@ def set_variable(self, name, variable):
177210
fletcher32=encoding.get('fletcher32', False),
178211
contiguous=encoding.get('contiguous', False),
179212
chunksizes=encoding.get('chunksizes'),
180-
endian=encoding.get('endian', 'native'),
213+
endian='native',
181214
least_significant_digit=encoding.get('least_significant_digit'),
182215
fill_value=fill_value)
183216
nc4_var.set_auto_maskandscale(False)
184-
nc4_var[:] = variable.values
217+
nc4_var[:] = data
185218
for k, v in iteritems(attrs):
186219
# set attributes one-by-one since netCDF4<1.0.10 can't handle
187220
# OrderedDict as the input to setncatts

xray/test/test_backends.py

+18
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,24 @@ def test_variable_len_strings(self):
444444
with open_dataset(tmp_file, **kwargs) as actual:
445445
self.assertDatasetIdentical(expected, actual)
446446

447+
def test_roundtrip_endian(self):
    """Round-trip explicit-endian data and reject a non-native request.

    Variables with big-endian, little-endian and native integer dtypes must
    survive a write/read cycle, while an explicit ``endian='big'`` encoding
    must raise ``NotImplementedError``.
    """
    # The builtin ``float`` is used for 'w'; the ``np.float`` alias it
    # replaces was removed from recent NumPy releases.
    ds = Dataset({'x': np.arange(3, 10, dtype='>i2'),
                  'y': np.arange(3, 20, dtype='<i4'),
                  'z': np.arange(3, 30, dtype='=i8'),
                  'w': ('x', np.arange(3, 10, dtype=float))})

    with self.roundtrip(ds) as actual:
        # Technically these datasets are slightly different: one holds
        # mixed-endian data (ds), the other should be all native endian
        # (actual).  assertDatasetIdentical should still pass though.
        self.assertDatasetIdentical(ds, actual)

    # Explicitly requesting a non-native byte order is not supported.
    ds['z'].encoding['endian'] = 'big'
    with self.assertRaises(NotImplementedError):
        with self.roundtrip(ds):
            pass
464+
447465
def test_roundtrip_character_array(self):
448466
with create_tmp_file() as tmp_file:
449467
values = np.array([['a', 'b', 'c'], ['d', 'e', 'f']], dtype='S')

0 commit comments

Comments
 (0)