Update precommit and fix for release (#62)
* Update precommit

* Fix lint complaints

* Add fix for FillValue decoding in xr.decode_cf with numpy 2.0

* Update coverage settings
ghiggi authored Aug 17, 2024
1 parent e87961e commit 88a1db7
Showing 13 changed files with 54 additions and 63 deletions.
3 changes: 3 additions & 0 deletions .coveragerc
@@ -11,6 +11,9 @@ omit =
     gpm/visualization/animation.py
     gpm/utils/pyresample.py
     gpm/utils/collocation.py
+    gpm/utils/gv.py
+    gpm/utils/zonal_stats.py
+    gpm/xradar/*
     gpm/_version.py
 
 [report]
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
@@ -1,7 +1,7 @@
 ---
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.5.0
+    rev: v4.6.0
     hooks:
       - id: trailing-whitespace
       - id: end-of-file-fixer
@@ -12,12 +12,12 @@ repos:
       - id: check-ast
       - id: check-added-large-files
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.3.5
+    rev: v0.5.7
     hooks:
       - id: ruff
        args: [--fix]
   - repo: https://github.com/psf/black
-    rev: 24.3.0
+    rev: 24.8.0
     hooks:
       - id: black
         language_version: python3
@@ -27,18 +27,18 @@ repos:
       - id: blackdoc
         additional_dependencies: ["black[jupyter]"]
   - repo: https://github.com/pre-commit/mirrors-prettier
-    rev: "v3.1.0"
+    rev: "v4.0.0-alpha.8"
     hooks:
       - id: prettier
         types_or: [yaml, html, css, scss, javascript, json] # markdown to avoid conflicts with mdformat
   - repo: https://github.com/codespell-project/codespell
-    rev: v2.2.6
+    rev: v2.3.0
     hooks:
       - id: codespell
         types_or: [python, markdown, rst]
         additional_dependencies: [tomli]
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.15.2
+    rev: v3.17.0
     hooks:
       - id: pyupgrade
   - repo: https://github.com/MarcoGorelli/madforhooks
@@ -58,7 +58,7 @@ repos:
       - id: nbstripout
         args: [--keep-output]
   - repo: https://github.com/nbQA-dev/nbQA
-    rev: 1.8.5
+    rev: 1.8.7
     hooks:
       - id: nbqa-black
       - id: nbqa-ruff
49 changes: 16 additions & 33 deletions gpm/checks.py
@@ -94,7 +94,8 @@ def _has_vertical_dim_dataarray(da, strict):
     vertical_dims = list(get_vertical_dimension(da))
     if not vertical_dims:
         return False
-    if strict and len(da.dims) != 1:  # only the vertical dim
+    only_vertical_dim = len(da.dims) == 1
+    if strict and not only_vertical_dim:  # noqa
         return False
     return True
 
@@ -104,7 +105,8 @@ def _has_frequency_dim_dataarray(da, strict):
     frequency_dims = list(get_frequency_dimension(da))
     if not frequency_dims:
         return False
-    if strict and len(da.dims) != 1:  # only the frequency dimension
+    only_frequency_dim = len(da.dims) == 1
+    if strict and not only_frequency_dim:  # noqa
         return False
     return True
 
@@ -114,29 +116,23 @@ def _has_vertical_dim_dataset(ds, strict):
     has_vertical = np.any(
         [_has_vertical_dim_dataarray(ds[var], strict=strict) for var in get_dataset_variables(ds)],
     ).item()
-    if has_vertical:
-        return True
-    return False
+    return bool(has_vertical)
 
 
 def _has_spatial_dim_dataset(ds, strict):
     """Check if at least one xarray.DataArrays of a xarray.Dataset have at least one spatial dimension."""
     has_spatial = np.any(
         [_has_spatial_dim_dataarray(ds[var], strict=strict) for var in get_dataset_variables(ds)],
     ).item()
-    if has_spatial:
-        return True
-    return False
+    return bool(has_spatial)
 
 
 def _has_frequency_dim_dataset(ds, strict):
     """Check if at least one xarray.DataArrays of a xarray.Dataset have a frequency dimension."""
     has_spatial = np.any(
         [_has_frequency_dim_dataarray(ds[var], strict=strict) for var in get_dataset_variables(ds)],
     ).item()
-    if has_spatial:
-        return True
-    return False
+    return bool(has_spatial)
 
 
 def _check_xarray_conditions(da_condition, ds_condition, xr_obj, strict, squeeze):
@@ -207,9 +203,7 @@ def _is_grid_expected_spatial_dims(spatial_dims):
     is_grid = set(spatial_dims) == set(GRID_SPATIAL_DIMS)
     is_lonlat = set(spatial_dims) == {"latitude", "longitude"}
     is_xy = set(spatial_dims) == {"y", "x"}
-    if is_grid or is_lonlat or is_xy:
-        return True
-    return False
+    return bool(is_grid or is_lonlat or is_xy)
 
 
 def _is_orbit_expected_spatial_dims(spatial_dims):
@@ -223,19 +217,14 @@ def _is_orbit_expected_spatial_dims(spatial_dims):
     # Check if spatial_dims is a non-empty subset of ORBIT_SPATIAL_DIMS
     is_orbit = set(spatial_dims).issubset(ORBIT_SPATIAL_DIMS) and bool(spatial_dims)
     is_xy = set(spatial_dims).issubset({"y", "x"}) and bool(spatial_dims)
-
-    if is_orbit or is_xy:
-        return True
-    return False
+    return bool(is_orbit or is_xy)
 
 
 def _is_expected_spatial_dims(spatial_dims):
     """Check that the spatial_dims are the expected two."""
     is_orbit = _is_orbit_expected_spatial_dims(spatial_dims)
     is_grid = _is_grid_expected_spatial_dims(spatial_dims)
-    if is_orbit or is_grid:
-        return True
-    return False
+    return bool(is_orbit or is_grid)
 
 
 def is_orbit(xr_obj):
@@ -255,9 +244,7 @@ def is_orbit(xr_obj):
     # Check that swath coords exists
     # - Swath objects are determined by 1D (nadir looking) and 2D coordinates
     x_coord, y_coord = _get_swath_dim_coords(xr_obj)
-    if x_coord is not None and y_coord is not None:
-        return True
-    return False
+    return bool(x_coord is not None and y_coord is not None)
 
 
 def is_grid(xr_obj):
@@ -278,9 +265,7 @@ def is_grid(xr_obj):
     # - 1D coordinates: projection coordinates
     # - 2D coordinates: lon/lat coordinates of each pixel
     x_coord, y_coord = _get_proj_dim_coords(xr_obj)
-    if x_coord is not None and y_coord is not None:
-        return True
-    return False
+    return bool(x_coord is not None and y_coord is not None)
 
 
 ####-------------------------------------------------------------------------------------
@@ -298,7 +283,7 @@ def _is_spatial_2d_dataarray(da, strict):
     vertical_dims = get_vertical_dimension(da)
     if vertical_dims:
         return False
-    if strict and len(da.dims) != 2:
+    if strict and len(da.dims) != 2:  # noqa
         return False
 
     return True
@@ -313,7 +298,7 @@ def _is_spatial_3d_dataarray(da, strict):
     vertical_dims = get_vertical_dimension(da)
     if not vertical_dims:
         return False
-    if strict and len(da.dims) != 3:
+    if strict and len(da.dims) != 3:  # noqa
         return False
 
     return True
@@ -329,7 +314,7 @@ def _is_transect_dataarray(da, strict):
     if not vertical_dims:
         return False
 
-    if strict and len(da.dims) != 2:
+    if strict and len(da.dims) != 2:  # noqa
         return False
 
     return True
@@ -341,9 +326,7 @@ def _check_dataarrays_condition(condition, ds, strict):
     all_valid = np.all(
         [condition(ds[var], strict=strict) for var in get_dataset_variables(ds)],
     )
-    if all_valid.item():
-        return True
-    return False
+    return bool(all_valid)
 
 
 def _is_spatial_2d_dataset(ds, strict):
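
Note: the checks.py changes above repeatedly collapse an "if ...: return True / return False" ladder into a single boolean expression. A minimal standalone sketch of the pattern (hypothetical dimension names, not repo code):

import numpy as np

def has_vertical_dim_verbose(dims):
    # Old style: explicit branching to True/False.
    if "range" in dims:
        return True
    return False

def has_vertical_dim_flat(dims):
    # New style: return the boolean expression directly.
    return "range" in dims

# np.any()/np.all() return numpy bool_ scalars, not Python bools; wrapping in
# bool() (as in "return bool(has_vertical)") guarantees a plain Python bool.
assert has_vertical_dim_flat(("cross_track", "range")) is True
assert isinstance(bool(np.any([True, False])), bool)
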
1 change: 1 addition & 0 deletions gpm/dataset/conventions.py
@@ -148,6 +148,7 @@ def finalize_dataset(ds, product, decode_cf, scan_mode, start_time=None, end_tim
 
     ##------------------------------------------------------------------------.
     # Decode dataset
+    # - With numpy > 2.0, the _FillValue attribute must be a numpy scalar so that CF decoding is applied
     # - _FillValue is moved from attrs to encoding !
     if decode_cf:
         ds = apply_cf_decoding(ds)
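
Note: as the comment above says, xr.decode_cf relocates _FillValue from attrs to encoding while masking fill values. A small sketch of that behaviour on made-up data:

import numpy as np
import xarray as xr

ds = xr.Dataset({"v": ("x", np.array([1.0, -9999.0], dtype="float32"))})
ds["v"].attrs["_FillValue"] = np.float32(-9999.0)

decoded = xr.decode_cf(ds)
print("_FillValue" in decoded["v"].attrs)   # False: removed from attrs
print(decoded["v"].encoding["_FillValue"])  # -9999.0: moved to encoding
print(decoded["v"].values)                  # [1. nan]: fill value masked
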
9 changes: 9 additions & 0 deletions gpm/dataset/decoding/cf.py
@@ -27,7 +27,9 @@
 """This module contains functions to CF-decoding the GPM files."""
 import warnings
 
+import numpy as np
 import xarray as xr
+from packaging import version
 
 
 def apply_cf_decoding(ds):
@@ -36,6 +38,13 @@ def apply_cf_decoding(ds):
     For more information on CF-decoding, read:
     https://docs.xarray.dev/en/stable/generated/xarray.decode_cf.html
     """
+    # Take care of numpy 2.0 FillValue CF Decoding issue
+    if version.parse(np.__version__) >= version.parse("2.0.0"):
+        vars_and_coords = list(ds.data_vars) + list(ds.coords)
+        for var in vars_and_coords:
+            if "_FillValue" in ds[var].attrs:
+                ds[var].attrs["_FillValue"] = ds[var].data.dtype.type(ds[var].attrs["_FillValue"])
+
     # Decode with xr.decode_cf
     with warnings.catch_warnings():
         warnings.simplefilter(action="ignore", category=FutureWarning)
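
Note: a self-contained sketch (toy dataset, not GPM data) of the workaround added above. With numpy >= 2.0, a plain Python float _FillValue may no longer match the variable's dtype during CF decoding, so the fix casts it to a numpy scalar of that dtype before calling xr.decode_cf:

import numpy as np
import xarray as xr

ds = xr.Dataset({"precip": ("time", np.array([0.5, -9999.9, 1.2], dtype="float32"))})
ds["precip"].attrs["_FillValue"] = -9999.9  # plain Python float

# Cast the attribute to a numpy scalar matching the variable dtype,
# mirroring ds[var].data.dtype.type(...) in apply_cf_decoding.
fill = ds["precip"].attrs["_FillValue"]
ds["precip"].attrs["_FillValue"] = ds["precip"].data.dtype.type(fill)

decoded = xr.decode_cf(ds)
print(decoded["precip"].values)  # [0.5 nan 1.2]: fill value decoded to NaN
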
2 changes: 1 addition & 1 deletion gpm/dataset/decoding/coordinates.py
@@ -57,7 +57,7 @@ def add_lh_height(ds):
     # Fixed heights for 2HSLH and 2HCSH
     # - FileSpec v7: p.2395, 2463
     # NOTE: In SLH/CSH, the first row of the array correspond to the surface
-    # Instead, for the other GPM RADAR prodcuts, is the last row that correspond to the surface !!!
+    # Instead, for the other GPM RADAR products, is the last row that correspond to the surface !!!
     height = np.linspace(0.25 / 2, 20 - 0.25 / 2, 80) * 1000  # in meters
     ds = ds.assign_coords({"height": ("range", height)})
     ds["height"].attrs["units"] = "m a.s.l"
4 changes: 2 additions & 2 deletions gpm/retrievals/retrieval_1b_radar.py
@@ -285,7 +285,7 @@ def get_dielectric_constant(ds, dielectric_constant=None):
     if value == -9999.9:
         value = 0.9255
     else:
-        ValueError("Expecting a radar product.")
+        raise ValueError("Expecting a radar product.")
     return value
 
 
@@ -306,7 +306,7 @@ def get_radar_wavelength(ds):
     elif "PR" in product:
         eqvWavelength = default_dict["PR"]
     else:
-        ValueError("Expecting a radar product.")
+        raise ValueError("Expecting a radar product.")
     return eqvWavelength
 
 
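
Note: both hunks above fix the same bug. A bare ValueError(...) expression constructs the exception object but never raises it, so execution silently falls through; raise is required. A minimal illustration (hypothetical function, not repo code):

def get_value_buggy(product):
    if "DPR" in product:
        return 0.9255
    ValueError("Expecting a radar product.")  # created, never raised
    return None  # reached silently

def get_value_fixed(product):
    if "DPR" in product:
        return 0.9255
    raise ValueError("Expecting a radar product.")

print(get_value_buggy("1B-Ka"))  # None: the bug goes unnoticed
try:
    get_value_fixed("1B-Ka")
except ValueError as err:
    print(err)  # Expecting a radar product.
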
2 changes: 1 addition & 1 deletion gpm/tests/test_bucket/test_partitioning.py
@@ -1210,7 +1210,7 @@ def test_justified_label_xy(self):
 
     def test_justified_labels_single_level(self):
         """Test labels justification of single-level 2D TilePartitioning."""
-        size = (10, 10)  #
+        size = (10, 10)
         extent = [-180, 180, -90, 90]
         n_levels = 1
         origin = "bottom"
12 changes: 6 additions & 6 deletions gpm/tests/test_io/test_products.py
@@ -176,18 +176,18 @@ def test_get_sensor_satellite_names():
         },
     }
 
-    assert ["SSMIS"] == _get_sensor_satellite_names(info_dict, key="sensor", combine_with=None)
-    assert ["F18"] == _get_sensor_satellite_names(info_dict, key="satellite", combine_with=None)
-    assert ["SSMIS-F18"] == _get_sensor_satellite_names(
+    assert _get_sensor_satellite_names(info_dict, key="sensor", combine_with=None) == ["SSMIS"]
+    assert _get_sensor_satellite_names(info_dict, key="satellite", combine_with=None) == ["F18"]
+    assert _get_sensor_satellite_names(
         info_dict,
         key="satellite",
         combine_with="sensor",
-    )
-    assert ["SSMIS-F18"] == _get_sensor_satellite_names(
+    ) == ["SSMIS-F18"]
+    assert _get_sensor_satellite_names(
         info_dict,
         key="sensor",
         combine_with="satellite",
-    )
+    ) == ["SSMIS-F18"]
 
 
 @pytest.mark.parametrize("full", [True, False])
12 changes: 3 additions & 9 deletions gpm/utils/checks.py
@@ -373,9 +373,7 @@ def has_regular_time(xr_obj):
     """Return True if all timesteps are regular. False otherwise."""
     list_discontinuous_slices = get_slices_non_regular_time(xr_obj)
     n_discontinuous = len(list_discontinuous_slices)
-    if n_discontinuous > 0:
-        return False
-    return True
+    return n_discontinuous == 0
 
 
 ####--------------------------------------------------------------------------.
@@ -648,9 +646,7 @@ def has_contiguous_scans(
         cross_track_dim=cross_track_dim,
     )
     n_discontinuous = len(list_discontinuous_slices)
-    if n_discontinuous > 0:
-        return False
-    return True
+    return n_discontinuous == 0
 
 
 ####--------------------------------------------------------------------------.
@@ -830,9 +826,7 @@ def has_valid_geolocation(
         )
         n_invalid_scan_slices = len(list_invalid_slices)
         return n_invalid_scan_slices == 0
-    if is_grid(xr_obj):
-        return True
-    return False
+    return bool(is_grid(xr_obj))
 
 
 def apply_on_valid_geolocation(function):
2 changes: 1 addition & 1 deletion gpm/utils/slices.py
@@ -296,7 +296,7 @@ def ensure_is_slice(slc):
     elif isinstance(slc, np.ndarray) and slc.size == 1:
         slc = slice(slc.item(), slc.item() + 1)
     else:
-        raise ValueError("Impossibile to convert to a slice object.")
+        raise ValueError("Impossible to convert to a slice object.")
     return slc
 
 
5 changes: 2 additions & 3 deletions gpm/utils/subsetting.py
@@ -35,9 +35,8 @@ def is_1d_non_dimensional_coord(xr_obj, coord):
         return False
     if xr_obj[coord].ndim != 1:
         return False
-    if xr_obj[coord].dims[0] == coord:  # 1D dimension coordinate
-        return False
-    return True
+    is_1d_dim_coord = xr_obj[coord].dims[0] == coord
+    return not is_1d_dim_coord
 
 
 def _get_dim_of_1d_non_dimensional_coord(xr_obj, coord):
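
Note: to clarify the predicate refactored above, a 1D non-dimensional coordinate varies along a dimension whose name differs from the coordinate's own name. A short sketch with invented coordinates:

import numpy as np
import xarray as xr

da = xr.DataArray(
    np.zeros(3),
    dims=("along_track",),
    coords={
        "along_track": [0, 1, 2],                    # dimension coordinate
        "gpm_id": ("along_track", ["a", "b", "c"]),  # 1D non-dimensional coordinate
    },
)
print(da["along_track"].dims[0] == "along_track")  # True -> dimension coordinate
print(da["gpm_id"].dims[0] == "gpm_id")            # False -> non-dimensional
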
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -49,6 +49,8 @@ dependencies = [
     "matplotlib>=3.8.3", # introduce pcolormesh rgb
     "cartopy>=0.22.0",
     "pyproj",
+    "numpy",
+    "pandas",
     "scipy",
     "pycolorbar",
 ]
