Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remapping weights script #45

Merged
merged 7 commits into from
Jan 23, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 172 additions & 0 deletions mesh_generation/generate_rof_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright 2025 ACCESS-NRI and contributors. See the top-level COPYRIGHT file for details.
# SPDX-License-Identifier: Apache-2.0

# =========================================================================================
# Generate remapping weights between two ESMF mesh files for remapping a runoff field from
# an unmasked mesh to a masked mesh without losing any water volume. Each field on the
# unmasked mesh is mapped to the nearest ocean cell in the resulting weights.
#
# To run:
# python generate_rof_weights.py --mesh_filename=<input_file> --weights_filename=<output_file>
#
# This script currently supports mesh files in the ESMF unstructured mesh format.
#
# There is not enough memory on the gadi login node to run this; it is simplest to run in
# a terminal through are.nci.org.au
#
# The run command and full github url of the current version of this script is added to the
# metadata of the generated weights file. This is to uniquely identify the script and inputs used
# to generate the weights file. To produce weights files for sharing, ensure you are using a version
# of this script which is committed and pushed to github. For weights files intended for released
# configurations, use the latest version checked in to the main branch of the github repository.
#
# Contact:
# Anton Steketee <[email protected]>
#
# Dependencies:
# esmpy, xarray and scipy
# =========================================================================================


import xarray as xr
import esmpy
from scipy.spatial import KDTree
from copy import copy

from pathlib import Path
import sys
import os
from datetime import datetime

path_root = Path(__file__).parents[1]
sys.path.append(str(path_root))
from scripts_common import get_provenance_metadata, md5sum

# Scratch file the raw regridding weights are written to before being adjusted.
TEMP_WEIGHTS_F = "temp_weights.nc"
# compression settings to use
COMP_ENCODING = dict(complevel=1, compression="zlib")


def drof_remapping_weights(mesh_filename, weights_filename, global_attrs=None):
    """Generate runoff remapping weights that move all runoff onto ocean cells.

    We need to generate remapping weights for use in the mediator, such that
    the overall volume of runoff is conserved and no runoff is mapped onto
    land cells. Inside the mediator the grid doesn't change, as we run the
    mediator with the ocean grid (the DROF component does the remapping from
    the JRA grid to the mediator grid). Therefore the same mesh file is used
    for the input and output mesh; however, this same routine would work for
    differing input and output meshes.

    Args:
        mesh_filename: Path to the ESMF mesh file describing the model grid.
        weights_filename: Path of the (netcdf) weights file to write.
        global_attrs: Optional dict of extra global attributes merged into
            the attributes of the output file (e.g. provenance history).

    Returns:
        True once the weights file has been written.
    """
    model_mesh = esmpy.Mesh(
        filename=mesh_filename,
        filetype=esmpy.FileFormat.ESMFMESH,
    )

    med_in_fld = esmpy.Field(model_mesh, meshloc=esmpy.MeshLoc.ELEMENT)
    med_out_fld = esmpy.Field(model_mesh, meshloc=esmpy.MeshLoc.ELEMENT)

    # Remove any stale temp file from an earlier run; a missing file is fine.
    try:
        os.remove(TEMP_WEIGHTS_F)
    except FileNotFoundError:
        pass

    # Generate remapping weights and write to file.
    esmpy.Regrid(
        med_in_fld,
        med_out_fld,
        filename=TEMP_WEIGHTS_F,
        regrid_method=esmpy.RegridMethod.CONSERVE,
        # unmapped_action=esmpy.UnmappedAction.ERROR, #ignore errors about some destination cells not having source cells,
    )

    # From https://earthsystemmodeling.org/docs/release/ESMF_5_2_0rp3/ESMF_refdoc/node3.html :
    #
    #   "The indices and weights generated by ESMF_FieldRegridStore() are
    #   stored in the output file as variables col, row and S. Where col and
    #   row are the indices to the source and the destination grid cells.
    #   These are a one-dimension array with length defined by dimension n_s.
    #   S is the weight which is multiplied by the source value indicated by
    #   col and then summed with the destination value indicated by row to
    #   build the final interpolated value of the destination."
    #
    # Per the above note, we want to adjust all row values so they are ocean
    # cells. When we do this, we want to adjust S, the weight, to account for
    # the difference in area.

    # Both datasets are closed before the temp file is deleted below.
    with xr.open_dataset(TEMP_WEIGHTS_F) as weights_ds, xr.open_dataset(
        mesh_filename
    ) as mod_mesh_ds:
        # Find index for all ocean cells
        mask_i = mod_mesh_ds.elementCount.where(
            mod_mesh_ds.elementMask, drop=True
        ).astype("int")

        # Make a KDTree from the ocean cells
        mask_tree = KDTree(mod_mesh_ds.centerCoords.isel(elementCount=mask_i))

        # Using the KDTree, look up the nearest ocean cell to every
        # destination grid cell in our weights file. Note our weights are
        # indexed from 1 (i.e. Fortran style) but xarray starts from 0 (i.e.
        # python style), so subtract one from our destination grid cell
        # indices.
        #
        # NOTE(review): "nearest" here is the Euclidean distance between the
        # centerCoords values (geographic cell centres), not the distance on
        # the surface of a sphere, so the search does not wrap across the
        # join between the eastern and western edges of the grid. Reviewers
        # suggested a haversine distance to map correctly across the join.
        _, ii = mask_tree.query(
            mod_mesh_ds.centerCoords.isel(elementCount=(weights_ds.row - 1)),
            workers=-1,
        )

        new_row = mask_i[ii] + 1

        # Get the mesh element areas and adjust:
        # n.b. per CMEPS we are using the internally calculated areas, not
        # the user provided ones.
        med_out_fld.get_area()
        area = copy(med_out_fld.data)
        old_area = area[weights_ds.row - 1]
        new_area = area[new_row - 1]

        weights_ds["row"] = xr.DataArray(data=new_row, dims="n_s")

        # Scale by the area ratio so the remapped volume is unchanged.
        weights_ds["S"] = weights_ds.S * old_area / new_area

        # add global attributes
        weights_ds.attrs = {
            "gridType": "unstructured mesh",
            # Minute-resolution timestamp, e.g. "2025-01-23 10:15".
            "timeGenerated": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "created_by": f"{os.environ.get('USER')}",
            "inputFile": f"{mesh_filename} (md5 hash: {md5sum(mesh_filename)})",
        }

        # add git info to history
        if global_attrs:
            weights_ds.attrs |= global_attrs

        # save (compressed)
        encoding = {var_name: COMP_ENCODING for var_name in weights_ds.data_vars}
        weights_ds.to_netcdf(weights_filename, encoding=encoding)

    os.remove(TEMP_WEIGHTS_F)

    return True


def main():
    """Parse the command line and write the runoff remapping weights file.

    Reads ``--mesh_filename`` and ``--weights_filename``, records the run
    command and script provenance as a ``history`` global attribute, and
    calls ``drof_remapping_weights``.
    """
    # Imported locally because the module-level import only happens under the
    # ``__main__`` guard; without this, main() fails with a NameError when it
    # is called from a module that imported this file.
    import argparse

    parser = argparse.ArgumentParser(
        description="Create an remapping weights to transfer runoff from the unmasked to masked cells in an ESMF mesh file."
    )

    parser.add_argument(
        "--mesh_filename",
        type=str,
        required=True,
        help="The path to the mesh file specifying the model grid.",
    )
    parser.add_argument(
        "--weights_filename",
        type=str,
        required=True,
        help="The path to the weights file to output (netcdf).",
    )

    args = parser.parse_args()
    mesh_filename = os.path.abspath(args.mesh_filename)
    weights_filename = os.path.abspath(args.weights_filename)

    this_file = os.path.normpath(__file__)

    # Add some info about how the file was generated. The option names must
    # match the ones defined above so the recorded command can be re-run.
    runcmd = f"python3 {os.path.basename(this_file)} --mesh_filename={mesh_filename} --weights_filename={weights_filename} "

    global_attrs = {"history": get_provenance_metadata(this_file, runcmd)}

    drof_remapping_weights(mesh_filename, weights_filename, global_attrs)


# Script entry point: only runs when this file is executed directly.
if __name__ == "__main__":
    # argparse is imported here rather than at the top of the file, so the
    # name is only bound at module level when running as a script.
    import argparse

    main()
Loading