Skip to content

Commit

Permalink
Merge pull request #883 from xylar/smooth-salinity-restoring
Browse files Browse the repository at this point in the history
Add smoothing to remapping of SSS restoring
  • Loading branch information
xylar authored Jan 14, 2025
2 parents 13959a3 + a677010 commit 9accc96
Show file tree
Hide file tree
Showing 11 changed files with 486 additions and 124 deletions.
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
import os
import pathlib

import numpy as np
import xarray as xr
from mpas_tools.io import write_netcdf
from pyremap import LatLonGridDescriptor, MpasCellMeshDescriptor, Remapper
from mpas_tools.logging import check_call
from pyremap import MpasCellMeshDescriptor

from compass.io import symlink
from compass.ocean.tests.global_ocean.files_for_e3sm.files_for_e3sm_step import ( # noqa: E501
FilesForE3SMStep,
)
from compass.parallel import run_command


class RemapSeaSurfaceSalinityRestoring(FilesForE3SMStep):
"""
A step for for remapping sea surface salinity (SSS) from the Polar science
center Hydrographic Climatology (PHC) to the current MPAS mesh
A step for for remapping sea surface salinity (SSS) from WOA23 to the
current MPAS mesh
"""
def __init__(self, test_case):
"""
Expand All @@ -26,11 +30,14 @@ def __init__(self, test_case):
"""
super().__init__(test_case,
name='remap_sea_surface_salinity_restoring',
ntasks=512, min_tasks=1)
ntasks=512, min_tasks=128)

self.add_input_file(
filename='woa23_decav_0.25_sss_monthly_extrap.20241101.nc',
target='woa23_decav_0.25_sss_monthly_extrap.20241101.nc',
target='woa23_decav_ne300_sss_monthly_extrap.20250114.nc',
database='initial_condition_database')

self.add_input_file(
target='ne300_20250114.scrip.nc',
database='initial_condition_database')

self.add_output_file(filename='sss.WOA23_monthlyClimatology.nc')
Expand All @@ -40,104 +47,167 @@ def run(self):
Run this step of the test case
"""
super().run()
logger = self.logger
config = self.config
ntasks = self.ntasks

in_filename = 'woa23_decav_0.25_sss_monthly_extrap.20241101.nc'
in_filename = self.inputs[0]
src_scrip_filename = self.inputs[1]

prefix = 'sss.WOA23_monthlyClimatology'
suffix = f'{self.mesh_short_name}.{self.creation_date}'

remapped_filename = f'{prefix}.nc'
out_filename = f'{prefix}.nc'
dest_filename = f'{prefix}.{suffix}.nc'

parallel_executable = config.get('parallel', 'parallel_executable')

mesh_filename = 'restart.nc'
mesh_short_name = self.mesh_short_name
mesh_name = self.mesh_short_name

remap_sss(in_filename, mesh_filename, mesh_short_name,
remapped_filename, logger=logger, mpi_tasks=ntasks,
parallel_executable=parallel_executable)
target_scrip_filename = self._create_target_scrip_file(
mesh_filename, mesh_name)

symlink(
os.path.abspath(remapped_filename),
f'{self.ocean_inputdata_dir}/{dest_filename}')
mapping_filename = \
f'map_ne300_to_{mesh_name}_mbtraave.nc'

stem = pathlib.Path(out_filename).stem
remap_filename = f'{stem}_after_remap.nc'

def remap_sss(in_filename, mesh_filename, mesh_name, out_filename, logger,
mapping_directory='.', method='bilinear', mpi_tasks=1,
parallel_executable=None):
"""
Remap sea surface salinity (SSS) from the Polar science center
Hydrographic Climatology (PHC) to the current MPAS mesh
src_partition_filename = self._partition_scrip_file(
src_scrip_filename)
target_partition_filename = self._partition_scrip_file(
target_scrip_filename)
self._create_weights_tempest(src_partition_filename,
target_partition_filename,
mapping_filename)
self._remap_to_target(in_filename, remap_filename, mapping_filename)

Parameters
----------
in_filename : str
The original PHC sea surface salinity file
self._modify_remapped_sss(remap_filename, out_filename)

mesh_filename : str
The MPAS mesh
mesh_name : str
The name of the mesh (e.g. oEC60to30wISC), used in the name of the
mapping file
out_filename : str
An output file to write the remapped climatology of SSS to
logger : logging.Logger
A logger for output from the step
mapping_directory : str
The directory where the mapping file should be stored (if it is to be
computed) or where it already exists (if not)
method : {'bilinear', 'neareststod', 'conserve'}, optional
The method of interpolation used, see documentation for
`ESMF_RegridWeightGen` for details.
mpi_tasks : int, optional
The number of MPI tasks to use to compute the mapping file
symlink(
os.path.abspath(out_filename),
f'{self.ocean_inputdata_dir}/{dest_filename}')

parallel_executable : {'srun', 'mpirun'}, optional
The name of the parallel executable to use to launch ESMF tools.
But default, 'mpirun' from the conda environment is used
"""
def _create_target_scrip_file(self, target_mesh_filename, mesh_name):
"""
Create target SCRIP file from MPAS mesh file
"""
logger = self.logger
logger.info('Create target SCRIP file')

logger.info('Creating the source grid descriptor...')
src_descriptor = LatLonGridDescriptor.read(fileName=in_filename)
src_mesh_name = src_descriptor.meshName
config = self.config
section = config['files_for_e3sm']
min_lat = np.deg2rad(section.getfloat('sss_smoothing_min_lat'))
max_dist = section.getfloat('sss_smoothing_max_dist')

scrip_filename = f'{mesh_name}.scrip.nc'

ds_mesh = xr.open_dataset(target_mesh_filename)
lat_cell = ds_mesh.latCell

expand_dist = xr.zeros_like(lat_cell)
mask = lat_cell >= min_lat
# goes from 1 at the North pole to zero at min_lat
alpha = (lat_cell - min_lat) / (0.5 * np.pi - min_lat)
expand_dist[mask] = alpha[mask] * max_dist

ds_out = xr.Dataset()
ds_out['expandDist'] = expand_dist
write_netcdf(ds_out, 'expandDist.nc')

descriptor = MpasCellMeshDescriptor(
fileName=target_mesh_filename,
meshName=mesh_name,
)
descriptor.to_scrip(
scrip_filename,
expandDist=expand_dist
)

logger.info(' Done.')
return scrip_filename

def _partition_scrip_file(self, in_filename):
"""
Partition SCRIP file for parallel mbtempest use
"""
logger = self.logger
logger.info('Partition SCRIP file')

stem = pathlib.Path(in_filename).stem
h5m_filename = f'{stem}.h5m'
part_filename = f'{stem}.p{self.ntasks}.h5m'

# Convert source SCRIP to mbtempest
args = [
'mbconvert', '-B',
in_filename,
h5m_filename,
]
check_call(args, logger)

# Partition source SCRIP
args = [
'mbpart', f'{self.ntasks}',
'-z', 'RCB',
h5m_filename,
part_filename,
]
check_call(args, logger)

logger.info(' Done.')
return part_filename

def _create_weights_tempest(self, src_partition_filename,
target_partition_filename,
mapping_filename):
"""
Create mapping weights file using TempestRemap
"""
logger = self.logger
logger.info('Create weights file')

dst_descriptor = MpasCellMeshDescriptor(mesh_filename, mesh_name)
args = [
'mbtempest', '--type', '5',
'--load', src_partition_filename,
'--load', target_partition_filename,
'--file', mapping_filename,
'--weights', '--gnomonic',
'--boxeps', '1e-9',
]

mapping_filename = \
f'{mapping_directory}/map_{src_mesh_name}_to_{mesh_name}_{method}.nc'
run_command(
args, self.min_cpus_per_task, self.ntasks,
self.openmp_threads, self.config, self.logger
)

logger.info(f'Creating the mapping file {mapping_filename}...')
remapper = Remapper(src_descriptor, dst_descriptor, mapping_filename)
logger.info(' Done.')

remapper.build_mapping_file(method=method, mpiTasks=mpi_tasks,
tempdir=mapping_directory, logger=logger,
esmf_parallel_exec=parallel_executable)
logger.info('done.')
def _remap_to_target(self, in_filename, remap_filename, mapping_filename):
"""
Remap SSS onto MPAS target mesh
"""
logger = self.logger
logger.info('Remap to target')

logger.info('Remapping...')
name, ext = os.path.splitext(out_filename)
remap_filename = f'{name}_after_remap{ext}'
remapper.remap_file(inFileName=in_filename, outFileName=remap_filename,
logger=logger)
# Build command args
args = [
'ncremap',
'-m', mapping_filename,
'--vrb=1',
in_filename, remap_filename,
]
check_call(args, logger)

ds = xr.open_dataset(remap_filename)
logger.info('Removing lat/lon bounds variables...')
drop = [var for var in ds if 'nv' in ds[var].dims]
ds = ds.drop_vars(drop)
logger.info('Renaming dimensions and variables...')
rename = dict(ncol='nCells',
SALT='surfaceSalinityMonthlyClimatologyValue')
ds = ds.rename(rename)
write_netcdf(ds, out_filename)
logger.info(' Done.')

logger.info('done.')
def _modify_remapped_sss(self, remap_filename, out_filename):
"""
Modify remapped SSS
"""
logger = self.logger
ds = xr.open_dataset(remap_filename)
logger.info('Removing lat/lon bounds variables...')
drop = [var for var in ds if 'nv' in ds[var].dims]
ds = ds.drop_vars(drop)
logger.info('Renaming dimensions and variables...')
rename = dict(ncol='nCells',
SALT='surfaceSalinityMonthlyClimatologyValue')
ds = ds.rename(rename)
write_netcdf(ds, out_filename)
10 changes: 10 additions & 0 deletions compass/ocean/tests/global_ocean/global_ocean.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,13 @@ with_ice_shelf_cavities = autodetect

# whether to write out sea-ice partition info for plotting in paraview
plot_seaice_partitions = False

# Config options related to smoothing of sea-surface salinity during remapping.
# The smoothing distance increases linearly from zero at sss_smoothing_min_lat
# to its maximum value at the north pole.
#
# the minimum latitude (degrees) for smoothing
sss_smoothing_min_lat = 70

# the maximum smoothing distance (meters)
sss_smoothing_max_dist = 1000e3
14 changes: 6 additions & 8 deletions compass/ocean/tests/utility/create_salin_restoring/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from compass.ocean.tests.utility.create_salin_restoring.extrap_salin import (
ExtrapSalin,
)
from compass.ocean.tests.utility.create_salin_restoring.salinity_restoring import ( # noqa: E501
Salinity,
)
from compass.ocean.tests.utility.create_salin_restoring.combine import Combine
from compass.ocean.tests.utility.create_salin_restoring.extrap import Extrap
from compass.ocean.tests.utility.create_salin_restoring.remap import Remap
from compass.testcase import TestCase


Expand All @@ -25,5 +22,6 @@ def __init__(self, test_group):
"""
super().__init__(test_group=test_group, name='create_salin_restoring')

self.add_step(Salinity(test_case=self))
self.add_step(ExtrapSalin(test_case=self))
self.add_step(Combine(test_case=self))
self.add_step(Extrap(test_case=self))
self.add_step(Remap(test_case=self))
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from compass.step import Step


class Salinity(Step):
class Combine(Step):
"""
A step for combining January through December sea surface salinity
data into a single file for salinity restoring in G-cases.
Expand All @@ -14,15 +14,14 @@ class Salinity(Step):

def __init__(self, test_case):
"""
Create a new step
Create the step
Parameters
----------
test_case : compass.ocean.tests.utility.create_salin_restoring.
CreateSalinRestoring
The test case this step belongs to
"""
super().__init__(test_case, name='salinity_restoring', ntasks=1,
test_case : compass.ocean.tests.utility.create_salin_restoring.CreateSalinRestoring
The test case this step belongs to
""" # noqa: E501
super().__init__(test_case, name='combine', ntasks=1,
min_tasks=1)
self.add_output_file(filename='woa_surface_salinity_monthly.nc')

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# config options related creating salinity restoring
[salinity_restoring]

# target resolution (NExxx)
resolution = 300
method = bilinear

# threshold for masks below which interpolated variables are not renormalized
renorm_thresh = 1e-3

# the target and minimum number of MPI tasks to use in remapping
ntasks = 1280
min_tasks = 256
Loading

0 comments on commit 9accc96

Please sign in to comment.