@@ -387,7 +387,8 @@ def entry_point_compute_mpas_transect_masks():
387
387
engine = args .engine )
388
388
389
389
390
- def compute_mpas_flood_fill_mask (dsMesh , fcSeed , logger = None , workers = - 1 ):
390
+ def compute_mpas_flood_fill_mask (dsMesh , fcSeed , daGrow = None , logger = None ,
391
+ workers = - 1 ):
391
392
"""
392
393
Flood fill from the given set of seed points to create a contiguous mask.
393
394
The flood fill operates using cellsOnCell, starting from the cells
@@ -401,6 +402,11 @@ def compute_mpas_flood_fill_mask(dsMesh, fcSeed, logger=None, workers=-1):
401
402
fcSeed : geometric_features.FeatureCollection
402
403
A feature collection containing points at which to start the flood fill
403
404
405
+ daGrow : xarray.DataArray, optional
406
+ A data array of size ``nCells`` with a mask that is 1 anywhere the
407
+ flood fill is allowed to grow. The default is that the mask is all
408
+ ones.
409
+
404
410
logger : logging.Logger, optional
405
411
A logger for the output if not stdout
406
412
@@ -426,17 +432,22 @@ def compute_mpas_flood_fill_mask(dsMesh, fcSeed, logger=None, workers=-1):
426
432
if logger is not None :
427
433
logger .info (' Computing flood fill mask on cells:' )
428
434
429
- mask = _compute_seed_mask (fcSeed , lon , lat , workers )
435
+ seedMask = _compute_seed_mask (fcSeed , lon , lat , workers )
430
436
431
437
cellsOnCell = dsMesh .cellsOnCell .values - 1
432
438
433
- mask = _flood_fill_mask (mask , cellsOnCell )
439
+ if daGrow is not None :
440
+ growMask = daGrow .values
441
+ else :
442
+ growMask = numpy .ones (dsMesh .sizes ['nCells' ])
443
+
444
+ seedMask = _flood_fill_mask (seedMask , growMask , cellsOnCell )
434
445
435
446
if logger is not None :
436
447
logger .info (' Adding masks to dataset...' )
437
448
# create a new data array for the mask
438
449
masksVarName = 'cellSeedMask'
439
- dsMasks [masksVarName ] = (('nCells' ,), numpy .array (mask , dtype = int ))
450
+ dsMasks [masksVarName ] = (('nCells' ,), numpy .array (seedMask , dtype = int ))
440
451
441
452
if logger is not None :
442
453
logger .info (' Done.' )
@@ -1183,30 +1194,31 @@ def _compute_seed_mask(fcSeed, lon, lat, workers):
1183
1194
return mask
1184
1195
1185
1196
1186
- def _flood_fill_mask (mask , cellsOnCell ):
1197
+ def _flood_fill_mask (seedMask , growMask , cellsOnCell ):
1187
1198
"""
1188
1199
Flood fill starting with a mask of seed points
1189
1200
"""
1190
1201
1191
1202
maxNeighbors = cellsOnCell .shape [1 ]
1192
1203
1193
1204
while True :
1194
- neighbors = cellsOnCell [mask == 1 , :]
1205
+ neighbors = cellsOnCell [seedMask == 1 , :]
1195
1206
maskCount = 0
1196
1207
for iNeighbor in range (maxNeighbors ):
1197
1208
indices = neighbors [:, iNeighbor ]
1198
- # we only want to mask valid neighbors and locations that aren't
1199
- # already masked
1209
+ # we only want to mask valid neighbors, locations that aren't
1210
+ # already masked, and locations that we're allowed to flood
1200
1211
indices = indices [indices >= 0 ]
1201
- localMask = mask [indices ] == 0
1212
+ localMask = numpy .logical_and (seedMask [indices ] == 0 ,
1213
+ growMask [indices ] == 1 )
1202
1214
maskCount += numpy .count_nonzero (localMask )
1203
1215
indices = indices [localMask ]
1204
- mask [indices ] = 1
1216
+ seedMask [indices ] = 1
1205
1217
1206
1218
if maskCount == 0 :
1207
1219
break
1208
1220
1209
- return mask
1221
+ return seedMask
1210
1222
1211
1223
1212
1224
def _compute_edge_sign (dsMesh , edgeMask , shape ):
0 commit comments