Skip to content

Commit aa5a819

Browse files
committed
precommit
1 parent 3147116 commit aa5a819

File tree

8 files changed

+217
-130
lines changed

8 files changed

+217
-130
lines changed

docs/whats_new.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
:mod:`What's New`
22
-----------------
33

4+
v1.2.0 (September 13, 2023)
5+
===========================
6+
* Improvements to interpolation
7+
8+
49
v1.1.4 (January 27, 2023)
510
=========================
611
* fixed docs to run fully

extract_model/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import extract_model.accessor # noqa: F401
1010

1111
from .extract_model import sel2d, sel2dcf, select, selZ # noqa: F401
12-
from .utils import filter, order, sub_bbox, sub_grid, guess_model_type # noqa: F401
1312
from .preprocessing import preprocess
13+
from .utils import filter, guess_model_type, order, sub_bbox, sub_grid # noqa: F401
1414

1515

1616
try:

extract_model/extract_model.py

Lines changed: 134 additions & 85 deletions
Large diffs are not rendered by default.

extract_model/preprocessing.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
"""Preprocessing-related functions for model output."""
22

33

4+
from typing import Optional
5+
46
import numpy as np
57
import xarray as xr
6-
from typing import Optional
7-
from .utils import order, guess_model_type
8+
89
from extract_model.model_type import ModelType
910

11+
from .utils import guess_model_type, order
1012

11-
def preprocess_roms(ds, grid=None,):
13+
14+
def preprocess_roms(
15+
ds,
16+
grid=None,
17+
):
1218
"""Preprocess ROMS model output for use with cf-xarray.
1319
1420
Also fixes any other known issues with model output.
@@ -24,7 +30,7 @@ def preprocess_roms(ds, grid=None,):
2430
-------
2531
Same Dataset but with some metadata added and/or altered.
2632
"""
27-
33+
2834
rename = {}
2935
if "eta_u" in ds.dims:
3036
rename["eta_u"] = "eta_rho"
@@ -36,7 +42,6 @@ def preprocess_roms(ds, grid=None,):
3642
rename["eta_psi"] = "eta_v"
3743
ds = ds.rename(rename)
3844

39-
4045
# add axes attributes for dimensions
4146
dims = [dim for dim in ds.dims if dim.startswith("s_")]
4247
for dim in dims:
@@ -125,44 +130,48 @@ def preprocess_roms(ds, grid=None,):
125130
# }
126131

127132
ds.coords["z_rho"] = order(ds["z_rho"])
128-
ds.coords["z_rho_u"] = grid.interp(ds.z_rho.chunk({ds.z_rho.cf["X"].name: -1}), "X")
133+
ds.coords["z_rho_u"] = grid.interp(
134+
ds.z_rho.chunk({ds.z_rho.cf["X"].name: -1}), "X"
135+
)
129136
ds.coords["z_rho_u"].attrs = {
130137
"long_name": "depth of U-points on vertical RHO grid",
131138
"time": "ocean_time",
132139
"field": "z_rho_u, scalar, series",
133140
"units": "m",
134141
}
135142

136-
ds.coords["z_rho_v"] = grid.interp(ds.z_rho.chunk({ds.z_rho.cf["Y"].name: -1}), "Y")
143+
ds.coords["z_rho_v"] = grid.interp(
144+
ds.z_rho.chunk({ds.z_rho.cf["Y"].name: -1}), "Y"
145+
)
137146
ds.coords["z_rho_v"].attrs = {
138147
"long_name": "depth of V-points on vertical RHO grid",
139148
"time": "ocean_time",
140149
"field": "z_rho_v, scalar, series",
141150
"units": "m",
142151
}
143152

144-
ds.coords["z_rho_psi"] = grid.interp(ds.z_rho_u.chunk({ds.z_rho_u.cf["Y"].name: -1}), "Y")
153+
ds.coords["z_rho_psi"] = grid.interp(
154+
ds.z_rho_u.chunk({ds.z_rho_u.cf["Y"].name: -1}), "Y"
155+
)
145156
ds.coords["z_rho_psi"].attrs = {
146157
"long_name": "depth of PSI-points on vertical RHO grid",
147158
"time": "ocean_time",
148159
"field": "z_rho_psi, scalar, series",
149160
"units": "m",
150161
}
151-
152-
# will use this to update coordinate encoding
153-
name_dict.update({"filler1": "z_rho_u", "filler2": "z_rho_v", "filler3": "z_rho_psi"})#, "None": "z_w_u", "None": "z_w_v", "None": "z_w_psi"})
154-
155162

163+
# will use this to update coordinate encoding
164+
name_dict.update(
165+
{"filler1": "z_rho_u", "filler2": "z_rho_v", "filler3": "z_rho_psi"}
166+
) # , "None": "z_w_u", "None": "z_w_v", "None": "z_w_psi"})
156167

157168
# fix attrs
158169
# for zname in ["z_rho", "z_w"]:
159170
for zname in [var for var in ds.coords if "z_rho" in var or "z_w" in var]:
160171
if zname in ds:
161172
ds[
162173
zname
163-
].attrs = (
164-
{}
165-
) # coord inherits from one of the vars going into calculation
174+
].attrs = {} # coord inherits from one of the vars going into calculation
166175
ds[zname].attrs["positive"] = "up"
167176
ds[zname].attrs["units"] = "m"
168177
ds[zname] = order(ds[zname])
@@ -224,6 +233,7 @@ def preprocess_roms(ds, grid=None,):
224233
def preprocess_roms_grid(ds):
225234
# use xgcm
226235
from xgcm import Grid
236+
227237
coords = {
228238
"X": {"center": "xi_rho", "inner": "xi_u"},
229239
"Y": {"center": "eta_rho", "inner": "eta_v"},

extract_model/utils.py

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import dask
1010
import numpy as np
1111
import xarray as xr
12+
1213
from sklearn.neighbors import BallTree
14+
1315
from extract_model.grids.triangular_mesh import UnstructuredGridSubset
1416
from extract_model.model_type import ModelType
1517

@@ -482,10 +484,15 @@ def order(da):
482484
)
483485

484486

485-
def tree_query(lon_coords: xr.DataArray, lat_coords: xr.DataArray,
486-
lons_to_find: np.array, lats_to_find: np.array, k: int = 3) -> Tuple[np.array]:
487+
def tree_query(
488+
lon_coords: xr.DataArray,
489+
lat_coords: xr.DataArray,
490+
lons_to_find: np.array,
491+
lats_to_find: np.array,
492+
k: int = 3,
493+
) -> Tuple[np.array]:
487494
"""Set up and query BallTree for k nearest points
488-
495+
489496
Uses haversine for the metric because we are dealing with lon/lat coordinates.
490497
491498
Parameters
@@ -505,38 +512,40 @@ def tree_query(lon_coords: xr.DataArray, lat_coords: xr.DataArray,
505512
-------
506513
Tuple[np.array]
507514
distances, (iys, ixs) 2D indices for coordinates
508-
515+
509516
Notes
510517
-----
511518
Reference: https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.BallTree.html
512519
"""
513-
520+
514521
# create tree
515522
coords = [lon_coords, lat_coords]
516523
X = np.stack([np.ravel(c) for c in coords]).T
517-
tree = BallTree(np.deg2rad(X), metric='haversine')
524+
tree = BallTree(np.deg2rad(X), metric="haversine")
518525

519526
# set up coordinates we want to search for
520527
coords_to_find = [lons_to_find, lats_to_find]
521528
X_to_find = np.stack([np.ravel(c) for c in coords_to_find]).T
522-
529+
523530
# query tree
524531
distances, inds = tree.query(np.deg2rad(X_to_find), k=k)
525-
532+
526533
# convert flat indies to 2D indices
527534
iys, ixs = np.unravel_index(inds, lon_coords.shape)
528-
535+
529536
return distances, (iys, ixs)
530537

531538

532-
def calc_barycentric(x: np.array, y: np.array, xs: np.array, ys: np.array) -> xr.DataArray:
539+
def calc_barycentric(
540+
x: np.array, y: np.array, xs: np.array, ys: np.array
541+
) -> xr.DataArray:
533542
"""Calculate barycentric weights for npts
534-
543+
535544
Parameters
536545
----------
537-
x
546+
x
538547
npts x 1 vector of x locations, can be in lon or projection coordinates.
539-
y
548+
y
540549
npts x 1 vector of y locations, can be in lat or projection coordinates.
541550
xs
542551
npts x 3 array of triangle x vertices with which to calculate the barycentric weights for each of npts
@@ -550,19 +559,32 @@ def calc_barycentric(x: np.array, y: np.array, xs: np.array, ys: np.array) -> xr
550559
"""
551560
# barycentric weights
552561
# npts x 1 (vectors)
553-
L1 = ( (ys[:,1] - ys[:,2])*(x[:] - xs[:,2]) + (xs[:,2] - xs[:,1])*(y[:] - ys[:,2]) )/ \
554-
( (ys[:,1] - ys[:,2])*(xs[:,0] - xs[:,2]) + (xs[:,2] - xs[:,1])*(ys[:,0] - ys[:,2]) )
555-
L2 = ( (ys[:,2] - ys[:,0])*(x[:] - xs[:,2]) + (xs[:,0] - xs[:,2])*(y[:] - ys[:,2]) )/ \
556-
( (ys[:,1] - ys[:,2])*(xs[:,0] - xs[:,2]) + (xs[:,2] - xs[:,1])*(ys[:,0] - ys[:,2]))
562+
L1 = (
563+
(ys[:, 1] - ys[:, 2]) * (x[:] - xs[:, 2])
564+
+ (xs[:, 2] - xs[:, 1]) * (y[:] - ys[:, 2])
565+
) / (
566+
(ys[:, 1] - ys[:, 2]) * (xs[:, 0] - xs[:, 2])
567+
+ (xs[:, 2] - xs[:, 1]) * (ys[:, 0] - ys[:, 2])
568+
)
569+
L2 = (
570+
(ys[:, 2] - ys[:, 0]) * (x[:] - xs[:, 2])
571+
+ (xs[:, 0] - xs[:, 2]) * (y[:] - ys[:, 2])
572+
) / (
573+
(ys[:, 1] - ys[:, 2]) * (xs[:, 0] - xs[:, 2])
574+
+ (xs[:, 2] - xs[:, 1]) * (ys[:, 0] - ys[:, 2])
575+
)
557576
L3 = 1 - L1 - L2
558577

559-
lam = xr.DataArray(dims=("npts","triangle"), data=np.vstack((L1, L2, L3)).T)
560-
578+
lam = xr.DataArray(dims=("npts", "triangle"), data=np.vstack((L1, L2, L3)).T)
579+
561580
return lam
562581

563582

564583
def interp_with_barycentric(da, ixs, iys, lam):
565-
vector = da.cf.isel(X=xr.DataArray(ixs, dims=("npts","triangle")), Y=xr.DataArray(iys, dims=("npts","triangle")))
584+
vector = da.cf.isel(
585+
X=xr.DataArray(ixs, dims=("npts", "triangle")),
586+
Y=xr.DataArray(iys, dims=("npts", "triangle")),
587+
)
566588
with xr.set_options(keep_attrs=True):
567589
da = xr.dot(vector, lam, dims=("triangle"))
568590

@@ -576,7 +598,7 @@ def interp_with_barycentric(da, ixs, iys, lam):
576598

577599
# add vertical coords into da
578600
da = da.assign_coords({zkey: da_vert})
579-
601+
580602
# add "X" axis to npts
581603
da["npts"] = ("npts", da.npts.values, {"axis": "X"})
582604

tests/grids/test_triangular_mesh.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
import pytest
88
import xarray as xr
99

10-
from extract_model import utils
11-
from extract_model import preprocessing
10+
from extract_model import preprocessing, utils
1211
from extract_model.grids.triangular_mesh import UnstructuredGridSubset
1312

1413

tests/test_accessor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def test_2dsel():
4646
assert np.allclose(da_sel2d_check.squeeze(), da_check)
4747

4848
da_test, kwargs_out = da.em.sel2dcf(
49-
longitude=lon_comp, latitude=lat_comp, return_info=True, #distances_name="distance"
49+
longitude=lon_comp,
50+
latitude=lat_comp,
51+
return_info=True, # distances_name="distance"
5052
)
5153
assert np.allclose(da_sel2d[varname], da_test[varname])
5254
assert np.allclose(kwargs_out_sel2d_acc_check["distances"], kwargs_out["distances"])

tests/test_em.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def test_hor_interp_no_xesmf():
6565
XESMF_AVAILABLE = em.extract_model.XESMF_AVAILABLE
6666
em.extract_model.XESMF_AVAILABLE = False
6767
with pytest.raises(ModuleNotFoundError):
68-
em.select(da, longitude=longitude, latitude=latitude, T=0.5, horizontal_interp=True)
68+
em.select(
69+
da, longitude=longitude, latitude=latitude, T=0.5, horizontal_interp=True
70+
)
6971
em.extract_model.XESMF_AVAILABLE = XESMF_AVAILABLE
7072

7173

@@ -95,7 +97,7 @@ def test_sel2d(model):
9597
inputs = {
9698
da.cf["longitude"].name: lon_comp,
9799
da.cf["latitude"].name: lat_comp,
98-
"return_info": True
100+
"return_info": True,
99101
}
100102
da_sel2d, kwargs_out = em.sel2d(da, **inputs)
101103
da_check = da.cf.isel(X=i, Y=j)
@@ -369,9 +371,7 @@ def test_sel2d_simple_2D():
369371
assert ds_out.lon == 3
370372
assert ds_out.lat == 7
371373

372-
ds_outcf = em.sel2dcf(
373-
ds, longitude=0, latitude=4, mask=mask
374-
)
374+
ds_outcf = em.sel2dcf(ds, longitude=0, latitude=4, mask=mask)
375375
assert ds_out.coords == ds_outcf.coords
376376

377377
# if distance_name=None, no distance returned

0 commit comments

Comments
 (0)