Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
kthyng committed Oct 25, 2024
1 parent a2c14a1 commit 1fd135a
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions extract_model/extract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,9 @@ def select(
* False: 2D array of points with 1 dimension the lons and the other dimension the lats.
* True: lons/lats as unstructured coordinate pairs (in xESMF language, LocStream).
locstreamT: boolean, optional
If False, interpolate in time dimension independently of horizontal points. If True, use advanced indexing/interpolation in xarray to interpolate times to each horizontal locstream point. If this is True, locstream must be True.
If False, interpolate in time dimension independently of horizontal points. If True, use advanced indexing/interpolation in xarray to interpolate times to each horizontal locstream point.
locstreamZ: boolean, optional
If False, interpolate in depth dimension independently of horizontal points. If True, use advanced indexing after depth interpolation select depths to match each horizontal locstream point. If this is True, locstream must be True and locstreamT must be True.
If False, interpolate in depth dimension independently of horizontal points. If True, use advanced indexing after depth interpolation select depths to match each horizontal locstream point.
new_dim : str
This is the name of the new dimension created if we are interpolating to a new set of points that are not a grid.
weights: xESMF netCDF file path, DataArray, optional
Expand Down Expand Up @@ -360,14 +360,15 @@ def select(
"Use extrap=True to extrapolate."
)

if locstreamT:
if not locstream:
raise ValueError("if `locstreamT` is True, `locstream` must also be True.")
if locstreamZ:
if not locstream or not locstreamT:
raise ValueError(
"if `locstreamZ` is True, `locstream` and `locstreamT` must also be True."
)
# these are only true if interpolating in those directions too — need to fix them
# if locstreamT:
# if not locstream:
# raise ValueError("if `locstreamT` is True, `locstream` must also be True.")
# if locstreamZ:
# if not locstream or not locstreamT:
# raise ValueError(
# "if `locstreamZ` is True, `locstream` and `locstreamT` must also be True."
# )

# Perform interpolation
if horizontal_interp:
Expand Down Expand Up @@ -443,13 +444,12 @@ def select(
xs, ys = proj(xs, ys)
x, y = proj(longitude, latitude)

# import pdb; pdb.set_trace()
# lam = calc_barycentric(x, y, xs.reshape((10,9,3)), ys.reshape((10,9,3)))
lam = calc_barycentric(x.flatten(), y.flatten(), xs, ys)
# lam = calc_barycentric(x, y, xs, ys)
# interp_coords are the coords and indices that went into the interpolation
da, interp_coords = interp_with_barycentric(da, ixs, iys, lam)
# import pdb; pdb.set_trace()

# if not locstream:
# FIGURE OUT HOW TO RECONSTITUTE INTO GRID HERE
kwargs_out["interp_coords"] = interp_coords
Expand Down Expand Up @@ -665,6 +665,7 @@ def pt_in_itriangle_proj(ix, iy):

# advanced indexing to select all assuming coherent time series
# make sure len of each dimension matches

if locstreamZ:

dims_to_index = [da.cf["T"].name]
Expand Down Expand Up @@ -809,7 +810,6 @@ def sel2d(
mask = mask.load()

# Assume mask is 2D — but not true for wetting/drying
# import pdb; pdb.set_trace()
# find indices representing mask
eta, xi = np.where(mask.values)

Expand Down Expand Up @@ -898,6 +898,10 @@ def sel2d(

else:

# make sure the mask matches
msg = f"Mask {mask.name} dimensions do not match horizontal var {var.name} dimensions. mask dims: {mask.dims}, var dims: {var.dims}"
assert len(set(mask.dims) - set(var.dims)) == 0, msg

# currently lons, lats 1D only

# if no mask, assume user just wants 1 nearest point to each input lons/lats pair
Expand All @@ -907,7 +911,7 @@ def sel2d(
# if user inputs mask, use it to only return the nearest point that is active
# so, find nearest 30 points to have options
else:
k = 30
k = 50

distances, (iys, ixs) = tree_query(var[lonname], var[latname], lons, lats, k=k)

Expand All @@ -916,7 +920,7 @@ def sel2d(
raise ValueError("all found values are masked!")

if mask is not None:
isorted_mask = np.argsort(-mask.values[iys, ixs], axis=-1)
isorted_mask = np.argsort(-mask.values[iys, ixs], axis=-1, kind="mergesort")
# sort the ixs and iys according to this sorting so that if there are unmasked indices,
# they are leftmost also, and we will use the leftmost values.
ixs_brought_along = np.take_along_axis(ixs, isorted_mask, axis=1)
Expand Down

0 comments on commit 1fd135a

Please sign in to comment.