Skip to content

Commit

Permalink
Rename plot_transect with plot_cross_section (#65)
Browse files Browse the repository at this point in the history
* Rename plot_transect with plot_cross_section

* Define cross_section object and redefine transect

* Fix cross-section issues
  • Loading branch information
ghiggi authored Aug 28, 2024
1 parent 8ddf655 commit 1ddb9ec
Show file tree
Hide file tree
Showing 24 changed files with 527 additions and 381 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repos:
- id: check-ast
- id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.7
rev: v0.6.2
hooks:
- id: ruff
args: [--fix]
Expand Down
43 changes: 33 additions & 10 deletions gpm/accessor/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,14 +239,32 @@ def collocate(
decode_cf=decode_cf,
)

#### Transect utility
#### Transect/Trajectory utility

@auto_wrap_docstring
def extract_at_points(
self,
points,
method="nearest",
new_dim="points",
):
from gpm.utils.manipulations import extract_at_points

return extract_at_points(
self._obj,
points=points,
method=method,
new_dim=new_dim,
)

@auto_wrap_docstring
def extract_transect_between_points(
self,
start_point,
end_point,
steps=100,
method="linear",
new_dim="transect",
):
from gpm.utils.manipulations import extract_transect_between_points

Expand All @@ -256,6 +274,7 @@ def extract_transect_between_points(
end_point=end_point,
steps=steps,
method=method,
new_dim=new_dim,
)

@auto_wrap_docstring
Expand All @@ -266,6 +285,7 @@ def extract_transect_around_point(
distance,
steps=100,
method="linear",
new_dim="transect",
):
from gpm.utils.manipulations import extract_transect_around_point

Expand All @@ -276,20 +296,23 @@ def extract_transect_around_point(
distance=distance,
steps=steps,
method=method,
new_dim=new_dim,
)

@auto_wrap_docstring
def extract_transect_along_trajectory(
def extract_transect_at_points(
self,
points,
method="linear",
new_dim="transect",
):
from gpm.utils.manipulations import extract_transect_along_trajectory
from gpm.utils.manipulations import extract_transect_at_points

return extract_transect_along_trajectory(
return extract_transect_at_points(
self._obj,
points=points,
method=method,
new_dim=new_dim,
)

#### Range subset utility
Expand Down Expand Up @@ -841,7 +864,7 @@ def plot_image(
)

@auto_wrap_docstring
def plot_transect(
def plot_cross_section(
self,
variable,
ax=None,
Expand All @@ -854,9 +877,9 @@ def plot_transect(
cbar_kwargs=None,
**plot_kwargs,
):
from gpm.visualization.cross_section import plot_transect
from gpm.visualization.cross_section import plot_cross_section

return plot_transect(
return plot_cross_section(
self._obj[variable],
ax=ax,
x=x,
Expand Down Expand Up @@ -1053,7 +1076,7 @@ def plot_image(
)

@auto_wrap_docstring
def plot_transect(
def plot_cross_section(
self,
ax=None,
x=None,
Expand All @@ -1065,9 +1088,9 @@ def plot_transect(
cbar_kwargs=None,
**plot_kwargs,
):
from gpm.visualization.cross_section import plot_transect
from gpm.visualization.cross_section import plot_cross_section

return plot_transect(
return plot_cross_section(
self._obj,
ax=ax,
x=x,
Expand Down
101 changes: 81 additions & 20 deletions gpm/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,20 @@ def get_frequency_dimension(xr_obj):

def get_vertical_dimension(xr_obj):
"""Return the name of the available vertical dimension."""
return np.array(VERTICAL_DIMS)[np.isin(VERTICAL_DIMS, list(xr_obj.dims))].tolist()
vertical_dim = np.array(VERTICAL_DIMS)[np.isin(VERTICAL_DIMS, list(xr_obj.dims))].tolist()
if len(vertical_dim) > 1:
raise ValueError(f"Only one vertical dimension is allowed. Got {vertical_dim}.")
return vertical_dim


def get_spatial_dimensions(xr_obj):
"""Return the name of the available spatial dimensions."""
dims = list(xr_obj.dims)
flattened_spatial_dims = list(chain.from_iterable(SPATIAL_DIMS))
return np.array(flattened_spatial_dims)[np.isin(flattened_spatial_dims, dims)].tolist()
spatial_dimensions = np.array(flattened_spatial_dims)[np.isin(flattened_spatial_dims, dims)].tolist()
if len(spatial_dimensions) > 2:
raise ValueError(f"Only two horizontal spatial dimensions are allowed. Got {spatial_dimensions}.")
return spatial_dimensions


def _has_spatial_dim_dataarray(da, strict):
Expand Down Expand Up @@ -230,9 +236,8 @@ def _is_expected_spatial_dims(spatial_dims):
def is_orbit(xr_obj):
"""Check whether the xarray object is a GPM ORBIT.
An ORBIT transect or nadir view is considered ORBIT.
An ORBIT object must have the coordinates available !
An ORBIT cross-section (nadir view) or transect is considered ORBIT.
An ORBIT object must have the coordinates available.
"""
from gpm.dataset.crs import _get_swath_dim_coords

Expand Down Expand Up @@ -304,8 +309,8 @@ def _is_spatial_3d_dataarray(da, strict):
return True


def _is_transect_dataarray(da, strict):
"""Check if the xarray.DataArray is a spatial 3D array."""
def _is_cross_section_dataarray(da, strict):
"""Check if the xarray.DataArray is a cross-section array."""
spatial_dims = list(get_spatial_dimensions(da))
if len(spatial_dims) != 1:
return False
Expand All @@ -320,6 +325,16 @@ def _is_transect_dataarray(da, strict):
return True


def _is_transect_dataarray(da, strict):
"""Check if the xarray.DataArray is a transect array."""
spatial_dims = list(get_spatial_dimensions(da))
if len(spatial_dims) != 1:
return False
if strict and len(da.dims) != 1: # noqa
return False
return True


def _check_dataarrays_condition(condition, ds, strict):
if not ds: # Empty dataset (no variables)
return False
Expand All @@ -339,8 +354,13 @@ def _is_spatial_3d_dataset(ds, strict):
return _check_dataarrays_condition(_is_spatial_3d_dataarray, ds=ds, strict=strict)


def _is_cross_section_dataset(ds, strict):
"""Check if all xarray.DataArrays of a xarray.Dataset are cross-section objects."""
return _check_dataarrays_condition(_is_cross_section_dataarray, ds=ds, strict=strict)


def _is_transect_dataset(ds, strict):
"""Check if all xarray.DataArrays of a xarray.Dataset are transect objects."""
"""Check if all xarray.DataArrays of a xarray.Dataset are transects objects."""
return _check_dataarrays_condition(_is_transect_dataarray, ds=ds, strict=strict)


Expand Down Expand Up @@ -376,13 +396,31 @@ def is_spatial_3d(xr_obj, strict=True, squeeze=True):
)


def is_transect(xr_obj, strict=True, squeeze=True):
"""Check if the xarray.DataArray or xarray.Dataset is a transect object.
def is_cross_section(xr_obj, strict=True, squeeze=True):
"""Check if the xarray.DataArray or xarray.Dataset is a cross-section object.
If ``squeeze=True`` (default), dimensions of size=1 are removed prior testing.
If ``strict=True`` (default), the xarray.DataArray must have just the
vertical dimension and a horizontal dimension.
If ``strict=False`` , the xarray.DataArray can also have additional dimensions.
If ``strict=False`` , the xarray.DataArray can have additional dimensions but only
a single horizontal and vertical dimension.
"""
return _check_xarray_conditions(
_is_cross_section_dataarray,
_is_cross_section_dataset,
xr_obj=xr_obj,
strict=strict,
squeeze=squeeze,
)


def is_transect(xr_obj, strict=True, squeeze=True):
"""Check if the xarray.DataArray or xarray.Dataset is a transect object.
If ``squeeze=True`` (default), dimensions of size=1 are removed prior testing.
If ``strict=True`` (default), the xarray.DataArray must have just an horizontal dimension.
If ``strict=False`` , the xarray.DataArray can have additional dimensions but only a single
horizontal dimension.
"""
return _check_xarray_conditions(
_is_transect_dataarray,
Expand Down Expand Up @@ -449,16 +487,29 @@ def check_is_spatial_3d(xr_obj, strict=True, squeeze=True):
raise ValueError("Expecting a 3D GPM field.")


def check_is_transect(xr_obj, strict=True, squeeze=True):
"""Check if the xarray.DataArray or xarray.Dataset is a transect.
def check_is_cross_section(xr_obj, strict=True, squeeze=True):
"""Check if the xarray.DataArray or xarray.Dataset is a cross-section.
If ``squeeze=True`` (default), dimensions of size=1 are removed prior testing.
If ``strict=True`` (default), the xarray.DataArray must have just the
vertical dimension and a horizontal dimension.
If ``strict=False`` , the xarray.DataArray can also have additional dimensions.
If ``strict=False`` , the xarray.DataArray can also have additional dimensions,
but only a single vertical and horizontal dimension.
"""
if not is_cross_section(xr_obj, strict=strict, squeeze=squeeze):
raise ValueError("Expecting a cross-section extracted from a 3D GPM field.")


def check_is_transect(xr_obj, strict=True, squeeze=True):
"""Check if the xarray.DataArray or xarray.Dataset is a transect.
If ``squeeze=True`` (default), dimensions of size=1 are removed prior testing.
If ``strict=True`` (default), the xarray.DataArray must have just an horizontal dimension.
If ``strict=False`` , the xarray.DataArray can also have additional dimensions,
but only an horizontal dimension.
"""
if not is_transect(xr_obj, strict=strict, squeeze=squeeze):
raise ValueError("Expecting a transect of a 3D GPM field.")
raise ValueError("Expecting a transect object.")


def check_has_vertical_dim(xr_obj, strict=False, squeeze=True):
Expand Down Expand Up @@ -523,16 +574,26 @@ def get_spatial_3d_variables(ds, strict=False, squeeze=True):
return sorted(variables)


def get_transect_variables(ds, strict=False, squeeze=True):
"""Get list of xarray.Dataset trasect variables.
def get_cross_section_variables(ds, strict=False, squeeze=True):
"""Get list of xarray.Dataset cross-section variables.
If ``strict=False`` (default), the potential variables for which a transect can be derived.
If ``strict=True``, the variables that are already provide a transect.
If ``strict=False`` (default), the potential variables for which a strict cross-section can be derived.
If ``strict=True``, the variables that are already a cross-section.
"""
variables = [var for var in get_dataset_variables(ds) if is_transect(ds[var], strict=strict, squeeze=squeeze)]
variables = [var for var in get_dataset_variables(ds) if is_cross_section(ds[var], strict=strict, squeeze=squeeze)]
return sorted(variables)


# def get_transect_variables(ds, strict=False, squeeze=True):
# """Get list of xarray.Dataset transect variables.

# If ``strict=False`` (default), the potential variables for which a strict transect can be derived.
# If ``strict=True``, the variables that are already a transect.
# """
# variables = [var for var in get_dataset_variables(ds) if is_transect(ds[var], strict=strict, squeeze=squeeze)]
# return sorted(variables)


def get_vertical_variables(ds):
"""Get list of xarray.Dataset variables with vertical dimension."""
variables = [var for var in get_dataset_variables(ds) if has_vertical_dim(ds[var], strict=False, squeeze=True)]
Expand Down
3 changes: 2 additions & 1 deletion gpm/dataset/dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,12 @@
}

SPATIAL_DIMS = [
["transect"],
["along_track", "cross_track"],
["lat", "lon"],
["latitude", "longitude"],
["x", "y"], # compatibility with satpy/gpm_geo i.e.
["transect"],
["trajectory"],
["beam"], # when stacking 2D spatial dims
["pixel"], # when stacking 2D spatial dims
]
Expand Down
3 changes: 1 addition & 2 deletions gpm/io/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,7 @@ def run(commands, n_threads=10, progress_bar=True, verbose=True):
"""
from tqdm import tqdm

if n_threads < 1:
n_threads = 1
n_threads = max(n_threads, 1)
n_threads = min(n_threads, 10)
n_cmds = len(commands)

Expand Down
Loading

0 comments on commit 1ddb9ec

Please sign in to comment.