Skip to content

Commit 676f2ef

Browse files
authored
Merge pull request #641 from xylar/fix-consistent-strlen
Set string length in `write_netcdf()`
2 parents ccf9d95 + 6da84fc commit 676f2ef

File tree

3 files changed

+268
-186
lines changed

3 files changed

+268
-186
lines changed

conda_package/mpas_tools/io.py

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
default_engine = None
1414
default_char_dim_name = 'StrLen'
1515
default_fills = netCDF4.default_fillvals
16+
default_nchar = 64
1617

1718

1819
def write_netcdf(
@@ -23,6 +24,7 @@ def write_netcdf(
2324
engine=None,
2425
char_dim_name=None,
2526
logger=None,
27+
nchar=None,
2628
):
2729
"""
2830
Write an xarray.Dataset to a file with NetCDF4 fill values and the given
@@ -31,9 +33,9 @@ def write_netcdf(
3133
3234
Note: the ``NETCDF3_64BIT_DATA`` format is handled as a special case
3335
because xarray output with this format is not performant. First, the file
34-
is written in `NETCDF4` format, which supports larger files and variables.
35-
Then, the `ncks` command is used to convert the file to the
36-
`NETCDF3_64BIT_DATA` format.
36+
is written in ``NETCDF4`` format, which supports larger files and
37+
variables. Then, the ``ncks`` command is used to convert the file to the
38+
``NETCDF3_64BIT_DATA`` format.
3739
3840
Note: All int64 variables are automatically converted to int32 for MPAS
3941
compatibility.
@@ -63,15 +65,19 @@ def write_netcdf(
6365
``mpas_tools.io.default_engine``
6466
6567
char_dim_name : str, optional
66-
The name of the dimension used for character strings, or None to let
67-
xarray figure this out. Default is
68+
The name of the dimension used for character strings. Default is
6869
``mpas_tools.io.default_char_dim_name``, which can be modified but
6970
which defaults to ``'StrLen'``
7071
72+
nchar : int, optional
73+
The number of characters to use for string variables. If None, the
74+
default is ``mpas_tools.io.default_nchar``, which can be modified but
75+
which defaults to 64.
76+
7177
logger : logging.Logger, optional
72-
A logger to write messages to write the output of `ncks` conversion
73-
calls to. If None, `ncks` output is suppressed. This is only
74-
relevant if `format` is 'NETCDF3_64BIT_DATA'
78+
A logger that receives the output of ``ncks`` conversion
79+
calls. If None, ``ncks`` output is suppressed. This is only
80+
relevant if ``format`` is 'NETCDF3_64BIT_DATA'
7581
""" # noqa: E501
7682
if format is None:
7783
format = default_format
@@ -85,31 +91,43 @@ def write_netcdf(
8591
if char_dim_name is None:
8692
char_dim_name = default_char_dim_name
8793

88-
# Convert int64 variables to int32 for MPAS compatibility
89-
for var in list(ds.data_vars.keys()) + list(ds.coords.keys()):
90-
if ds[var].dtype == numpy.int64:
91-
attrs = ds[var].attrs.copy()
92-
ds[var] = ds[var].astype(numpy.int32)
93-
ds[var].attrs = attrs
94+
if nchar is None:
95+
nchar = default_nchar
96+
97+
numpyFillValues = {}
98+
for fillType in fillValues:
99+
# drop string fill values
100+
if not fillType.startswith('S'):
101+
numpyFillValues[numpy.dtype(fillType)] = fillValues[fillType]
94102

95103
encodingDict = {}
96104
variableNames = list(ds.data_vars.keys()) + list(ds.coords.keys())
97105
for variableName in variableNames:
98-
isNumeric = numpy.issubdtype(ds[variableName].dtype, numpy.number)
99-
if isNumeric and numpy.any(numpy.isnan(ds[variableName])):
100-
dtype = ds[variableName].dtype
101-
for fillType in fillValues:
102-
if dtype == numpy.dtype(fillType):
103-
encodingDict[variableName] = {
104-
'_FillValue': fillValues[fillType]
105-
}
106-
break
107-
else:
108-
encodingDict[variableName] = {'_FillValue': None}
109-
110-
isString = numpy.issubdtype(ds[variableName].dtype, numpy.bytes_)
111-
if isString and char_dim_name is not None:
112-
encodingDict[variableName] = {'char_dim_name': char_dim_name}
106+
var = ds[variableName]
107+
encodingDict[variableName] = {}
108+
dtype = var.dtype
109+
110+
# Convert int64 variables to int32 for MPAS compatibility
111+
if dtype == numpy.int64:
112+
encodingDict[variableName]['dtype'] = 'int32'
113+
114+
# add fill values
115+
if dtype in numpyFillValues:
116+
if numpy.any(numpy.isnan(var)):
117+
# only add fill values if they're needed
118+
fill = numpyFillValues[dtype]
119+
else:
120+
fill = None
121+
encodingDict[variableName]['_FillValue'] = fill
122+
123+
isString = numpy.issubdtype(dtype, numpy.bytes_) or numpy.issubdtype(
124+
dtype, numpy.str_
125+
)
126+
if isString:
127+
# set the encoding for string variables
128+
encodingDict[variableName].update(
129+
{'dtype': f'|S{nchar}', 'char_dim_name': char_dim_name}
130+
)
113131

114132
update_history(ds)
115133

conda_package/mpas_tools/mesh/mask.py

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from shapely.strtree import STRtree
1414

1515
from mpas_tools.cime.constants import constants
16-
from mpas_tools.io import write_netcdf
16+
from mpas_tools.io import default_nchar, write_netcdf
1717
from mpas_tools.logging import LoggingContext
1818
from mpas_tools.parallel import create_pool
1919
from mpas_tools.transects import (
@@ -100,7 +100,7 @@ def compute_mpas_region_masks(
100100

101101
# create shapely geometry for lon and lat
102102
points = [shapely.geometry.Point(x, y) for x, y in zip(lon, lat)]
103-
regionNames, masks, properties, nchar = _compute_region_masks(
103+
regionNames, masks, properties = _compute_region_masks(
104104
fcMask,
105105
points,
106106
logger,
@@ -133,7 +133,6 @@ def compute_mpas_region_masks(
133133
ds=dsMasks,
134134
properties=properties,
135135
dim='nRegions',
136-
nchar=nchar,
137136
)
138137

139138
if logger is not None:
@@ -339,17 +338,15 @@ def compute_mpas_transect_masks(
339338
polygons, nPolygons, duplicatePolygons = _get_polygons(
340339
dsMesh, maskType
341340
)
342-
transectNames, masks, properties, nchar, shapes = (
343-
_compute_transect_masks(
344-
fcMask,
345-
polygons,
346-
logger,
347-
pool,
348-
chunkSize,
349-
showProgress,
350-
subdivisionResolution,
351-
earthRadius,
352-
)
341+
transectNames, masks, properties, shapes = _compute_transect_masks(
342+
fcMask,
343+
polygons,
344+
logger,
345+
pool,
346+
chunkSize,
347+
showProgress,
348+
subdivisionResolution,
349+
earthRadius,
353350
)
354351

355352
if logger is not None:
@@ -393,7 +390,6 @@ def compute_mpas_transect_masks(
393390
ds=dsMasks,
394391
properties=properties,
395392
dim='nTransects',
396-
nchar=nchar,
397393
)
398394

399395
if logger is not None:
@@ -723,7 +719,7 @@ def compute_lon_lat_region_masks(
723719

724720
# create shapely geometry for lon and lat
725721
points = [shapely.geometry.Point(x, y) for x, y in zip(Lon, Lat)]
726-
regionNames, masks, properties, nchar = _compute_region_masks(
722+
regionNames, masks, properties = _compute_region_masks(
727723
fcMask,
728724
points,
729725
logger,
@@ -757,7 +753,6 @@ def compute_lon_lat_region_masks(
757753
ds=dsMasks,
758754
properties=properties,
759755
dim='nRegions',
760-
nchar=nchar,
761756
)
762757

763758
if logger is not None:
@@ -959,7 +954,7 @@ def compute_projection_grid_region_masks(
959954
points = [
960955
shapely.geometry.Point(x, y) for x, y in zip(lon.ravel(), lat.ravel())
961956
]
962-
regionNames, masks, properties, nchar = _compute_region_masks(
957+
regionNames, masks, properties = _compute_region_masks(
963958
fcMask,
964959
points,
965960
logger,
@@ -990,7 +985,6 @@ def compute_projection_grid_region_masks(
990985
ds=dsMasks,
991986
properties=properties,
992987
dim='nRegions',
993-
nchar=nchar,
994988
)
995989

996990
if logger is not None:
@@ -1171,10 +1165,11 @@ def _compute_mask_from_shapes(
11711165
return mask
11721166

11731167

1174-
def _add_properties(ds, properties, dim, nchar):
1168+
def _add_properties(ds, properties, dim):
11751169
"""
11761170
Add properties to the dataset from a dictionary of properties
11771171
"""
1172+
nchar = default_nchar
11781173
for name, prop_list in properties.items():
11791174
if name not in ds:
11801175
if isinstance(prop_list[0], str):
@@ -1186,7 +1181,7 @@ def _add_properties(ds, properties, dim, nchar):
11861181
for index, value in enumerate(prop_list):
11871182
ds[name][index] = value
11881183
else:
1189-
ds[name] = ((dim,), properties[prop_list])
1184+
ds[name] = ((dim,), prop_list)
11901185

11911186

11921187
def _get_region_names_and_properties(fc):
@@ -1208,19 +1203,16 @@ def _get_region_names_and_properties(fc):
12081203
propertyNames.add(propertyName)
12091204

12101205
properties = {}
1211-
nchar = 0
12121206
for propertyName in propertyNames:
12131207
properties[propertyName] = []
12141208
for feature in fc.features:
12151209
if propertyName in feature['properties']:
12161210
propertyVal = feature['properties'][propertyName]
12171211
properties[propertyName].append(propertyVal)
1218-
if isinstance(propertyVal, str):
1219-
nchar = max(nchar, len(propertyVal))
12201212
else:
12211213
properties[propertyName].append('')
12221214

1223-
return regionNames, properties, nchar
1215+
return regionNames, properties
12241216

12251217

12261218
def _compute_region_masks(
@@ -1231,7 +1223,7 @@ def _compute_region_masks(
12311223
a set of regions.
12321224
"""
12331225

1234-
regionNames, properties, nchar = _get_region_names_and_properties(fcMask)
1226+
regionNames, properties = _get_region_names_and_properties(fcMask)
12351227

12361228
masks = []
12371229

@@ -1253,11 +1245,9 @@ def _compute_region_masks(
12531245
showProgress=showProgress,
12541246
)
12551247

1256-
nchar = max(nchar, len(name))
1257-
12581248
masks.append(mask)
12591249

1260-
return regionNames, masks, properties, nchar
1250+
return regionNames, masks, properties
12611251

12621252

12631253
def _contains(shapes, points):
@@ -1355,7 +1345,7 @@ def _compute_transect_masks(
13551345
a set of transects.
13561346
"""
13571347

1358-
transectNames, properties, nchar = _get_region_names_and_properties(fcMask)
1348+
transectNames, properties = _get_region_names_and_properties(fcMask)
13591349

13601350
masks = []
13611351
shapes = []
@@ -1405,12 +1395,10 @@ def _compute_transect_masks(
14051395
showProgress=showProgress,
14061396
)
14071397

1408-
nchar = max(nchar, len(name))
1409-
14101398
masks.append(mask)
14111399
shapes.append(shape)
14121400

1413-
return transectNames, masks, properties, nchar, shapes
1401+
return transectNames, masks, properties, shapes
14141402

14151403

14161404
def _intersects(shape, polygons):

0 commit comments

Comments
 (0)