|
| 1 | +# Copyright 2025 ACCESS-NRI and contributors. See the top-level COPYRIGHT file for details. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +# ========================================================================================= |
| 5 | +# Generate an remapping weights between two ESMF mesh files for remapping a runoff field |
| 6 | +# an unmasked mesh to a masked mesh without losing any water volume. Each field on the |
| 7 | +# unmasked mesh is mapped to the nearest ocean cell in the resulting weights. |
| 8 | +# |
| 9 | +# To run: |
| 10 | +# python generate_rof_weights.py --mesh_filename=<input_file> --weights_filename=<output_file> |
| 11 | +# |
| 12 | +# This script currently supports mesh files in the ESMF unstructed mesh format. |
| 13 | +# |
| 14 | +# There is not enough memory on the gadi login node to run this, its simplest to run in |
| 15 | +# a terminal through are.nci.org.au |
| 16 | +# |
| 17 | +# The run command and full github url of the current version of this script is added to the |
| 18 | +# metadata of the generated weights file. This is to uniquely identify the script and inputs used |
| 19 | +# to generate the mesh file. To produce weights files for sharing, ensure you are using a version |
| 20 | +# of this script which is committed and pushed to github. For mesh files intended for released |
| 21 | +# configurations, use the latest version checked in to the main branch of the github repository. |
| 22 | +# |
| 23 | +# Contact: |
| 24 | +# Anton Steketee <[email protected]> |
| 25 | +# |
| 26 | +# Dependencies: |
| 27 | +# esmpy, xarray and scipy |
| 28 | +# ========================================================================================= |
| 29 | + |
| 30 | + |
| 31 | +import xarray as xr |
| 32 | +import esmpy |
| 33 | +from sklearn.neighbors import BallTree |
| 34 | +from numpy import deg2rad |
| 35 | +from copy import copy |
| 36 | + |
| 37 | +from pathlib import Path |
| 38 | +import sys |
| 39 | +import os |
| 40 | +from datetime import datetime |
| 41 | + |
| 42 | +path_root = Path(__file__).parents[1] |
| 43 | +sys.path.append(str(path_root)) |
| 44 | +from scripts_common import get_provenance_metadata, md5sum |
| 45 | + |
| 46 | +TEMP_WEIGHTS_F = "temp_weights.nc" |
| 47 | +COMP_ENCODING = {"complevel": 1, "compression": "zlib"} # compression settings to use |
| 48 | + |
| 49 | + |
| 50 | +def drof_remapping_weights(mesh_filename, weights_filename, global_attrs=None): |
| 51 | + # We need to generate remapping weights for use in the mediator, such that the overall volume of runoff is conserved and no runoff is mapped onto land cells. Inside the mediator, the grid doesn't change as we run the mediator with the ocean grid (the DROF component does the remapping from JRA grid to mediator grid). There we use the same _mesh_file for the input and output mesh, however this same routine would work for differing input and output meshes |
| 52 | + |
| 53 | + model_mesh = esmpy.Mesh( |
| 54 | + filename=mesh_filename, |
| 55 | + filetype=esmpy.FileFormat.ESMFMESH, |
| 56 | + ) |
| 57 | + |
| 58 | + med_in_fld = esmpy.Field(model_mesh, meshloc=esmpy.MeshLoc.ELEMENT) |
| 59 | + |
| 60 | + med_out_fld = esmpy.Field(model_mesh, meshloc=esmpy.MeshLoc.ELEMENT) |
| 61 | + |
| 62 | + try: |
| 63 | + os.remove(TEMP_WEIGHTS_F) # rm old temp file |
| 64 | + except OSError: |
| 65 | + pass |
| 66 | + |
| 67 | + # Generate remapping weights and write to file. |
| 68 | + esmpy.Regrid( |
| 69 | + med_in_fld, |
| 70 | + med_out_fld, |
| 71 | + filename=TEMP_WEIGHTS_F, |
| 72 | + regrid_method=esmpy.RegridMethod.CONSERVE, |
| 73 | + # unmapped_action=esmpy.UnmappedAction.ERROR, #ignore errors about some destination cells not having source cells, |
| 74 | + ) |
| 75 | + |
| 76 | + """ |
| 77 | + From https://earthsystemmodeling.org/docs/release/ESMF_5_2_0rp3/ESMF_refdoc/node3.html : |
| 78 | +
|
| 79 | + " The indices and weights generated by ESMF_FieldRegridStore() are stored in the output file as variables col, row and S. Where col and row are the indices to the source and the destination grid cells. These are a one-dimension array with length defined by dimension n_s. S is the weight which is multiplied by the source value indicated by col and then summed with the destination value indicated by row to build the final interpolated value of the destination. |
| 80 | +
|
| 81 | + Per the above note, we want to adjust all row values, so they are ocean cells. When we do this, we want to adjust S, the weight to account for the difference in area. |
| 82 | + """ |
| 83 | + |
| 84 | + weights_ds = xr.open_dataset(TEMP_WEIGHTS_F) |
| 85 | + |
| 86 | + mod_mesh_ds = xr.open_dataset(mesh_filename) |
| 87 | + |
| 88 | + # Find index for all ocean cells |
| 89 | + mask_i = mod_mesh_ds.elementCount.where(mod_mesh_ds.elementMask, drop=True).astype( |
| 90 | + "int" |
| 91 | + ) |
| 92 | + |
| 93 | + # Haversine distances expect lat first, lon second, so index coordDim backwards |
| 94 | + center_coords_rad = deg2rad(mod_mesh_ds.centerCoords.isel(coordDim=[1, 0])) |
| 95 | + |
| 96 | + # Make a BallTree from the ocean cells |
| 97 | + mask_tree = BallTree( |
| 98 | + center_coords_rad.isel(elementCount=mask_i), metric="haversine" |
| 99 | + ) |
| 100 | + |
| 101 | + # Using the Tree, look up the nearest ocean cell to every destination grid cell in our weights file. Note our weights are indexed from 1 (i.e. Fortran style) but xarray starts from 0 (i.e. python style), so subract one from our destination grid cell indices. |
| 102 | + |
| 103 | + ii = mask_tree.query( |
| 104 | + center_coords_rad.isel(elementCount=(weights_ds.row - 1)), return_distance=False |
| 105 | + ) |
| 106 | + |
| 107 | + new_row = mask_i[ii[:, 0]] + 1 |
| 108 | + |
| 109 | + # Get the mesh element areas and adjust: |
| 110 | + # n.b. per CMEPS we are using the internally calculated areas, not the user provided ones. |
| 111 | + med_out_fld.get_area() |
| 112 | + area = copy(med_out_fld.data) |
| 113 | + old_area = area[weights_ds.row - 1] |
| 114 | + new_area = area[new_row - 1] |
| 115 | + |
| 116 | + weights_ds["row"] = xr.DataArray(data=new_row, dims="n_s") |
| 117 | + |
| 118 | + weights_ds["S"] = weights_ds.S * old_area / new_area |
| 119 | + |
| 120 | + # add global attributes |
| 121 | + weights_ds.attrs = { |
| 122 | + "gridType": "unstructured mesh", |
| 123 | + "inputFile": f"{mesh_filename} (md5 hash: {md5sum(mesh_filename)})", |
| 124 | + } |
| 125 | + |
| 126 | + # add git info to history |
| 127 | + if global_attrs: |
| 128 | + weights_ds.attrs |= global_attrs |
| 129 | + |
| 130 | + # save (compressed) |
| 131 | + encoding = {} |
| 132 | + for iVar in weights_ds.data_vars: |
| 133 | + encoding[iVar] = COMP_ENCODING |
| 134 | + weights_ds.to_netcdf(weights_filename, encoding=encoding) |
| 135 | + |
| 136 | + os.remove(TEMP_WEIGHTS_F) |
| 137 | + |
| 138 | + return True |
| 139 | + |
| 140 | + |
| 141 | +def main(): |
| 142 | + parser = argparse.ArgumentParser( |
| 143 | + description="Create an remapping weights to transfer runoff from unmasked mesh to masked mesh using ESMF mesh file." |
| 144 | + ) |
| 145 | + |
| 146 | + parser.add_argument( |
| 147 | + "--mesh_filename", |
| 148 | + type=str, |
| 149 | + required=True, |
| 150 | + help="The path to the mesh file specifying the model grid and land mask.", |
| 151 | + ) |
| 152 | + parser.add_argument( |
| 153 | + "--weights_filename", |
| 154 | + type=str, |
| 155 | + required=True, |
| 156 | + help="The path to the weights file to output (netcdf).", |
| 157 | + ) |
| 158 | + |
| 159 | + args = parser.parse_args() |
| 160 | + mesh_filename = os.path.abspath(args.mesh_filename) |
| 161 | + weights_filename = os.path.abspath(args.weights_filename) |
| 162 | + |
| 163 | + this_file = os.path.normpath(__file__) |
| 164 | + |
| 165 | + # Add some info about how the file was generated |
| 166 | + runcmd = f"python3 {os.path.basename(this_file)} --mesh-filename={mesh_filename} --weights_filename={weights_filename} " |
| 167 | + |
| 168 | + global_attrs = {"history": get_provenance_metadata(this_file, runcmd)} |
| 169 | + |
| 170 | + drof_remapping_weights(mesh_filename, weights_filename, global_attrs) |
| 171 | + |
| 172 | + return True |
| 173 | + |
| 174 | + |
| 175 | +if __name__ == "__main__": |
| 176 | + import argparse |
| 177 | + |
| 178 | + main() |
0 commit comments