Skip to content

Commit

Permalink
Add BAMS pictures
Browse files Browse the repository at this point in the history
  • Loading branch information
ghiggi committed Feb 6, 2025
1 parent 064f05d commit 76fa09d
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 28 deletions.
135 changes: 135 additions & 0 deletions gpm/retrievals/retrieval_1b_c_pmw.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@
# -----------------------------------------------------------------------------.
"""This module contains GPM PMW 1B and 1C products community-based retrievals."""
import numpy as np
import pandas as pd
import xarray as xr

from gpm.checks import check_is_spatial_2d
from gpm.utils.decorators import check_software_availability
from gpm.utils.pmw import (
PMWFrequency,
create_rgb_composite,
Expand Down Expand Up @@ -316,4 +319,136 @@ def retrieve_PESCA(ds, t2m="t2m"):
return da_pesca


@check_software_availability(software="sklearn", conda_package="scikit-learn")
@check_software_availability(software="umap", conda_package="umap-learn")
def retrieve_UMAP_RGB(ds, scaler=None, n_neighbors=10, min_dist=0.01, random_state=None, **kwargs):
"""Create a UMAP RGB composite."""
import umap
from sklearn.preprocessing import MinMaxScaler

# Check dataset has only spatial 2D variables
check_is_spatial_2d(ds)

# Define variables
variables = list(ds.data_vars)

# Convert to dataframe
df = ds.gpm.to_pandas_dataframe(drop_index=False)

# Retrieve dataset coordinates which are present in the dataframe
coordinates = [column for column in list(ds.coords) if column in df]

# Remove rows with non finite values
df_valid = df[np.isfinite(df[variables]).all(axis=1) & (~np.isnan(df[variables])).all(axis=1)]

# Retrieve dataframe with only variables
df_data = df_valid[variables]

# Define scaler
if scaler is not None:
scaler.fit(df_data)
scaled_data = scaler.transform(df_data)
else:
scaled_data = df_data

# Compute 3D embedding
reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=3, random_state=random_state, **kwargs)
embedding = reducer.fit_transform(scaled_data)

# Define RGB scaler
rgb_scaler = MinMaxScaler()
rgb_scaler = rgb_scaler.fit(embedding)

# Scale UMAP embedding between 0 and 1
rgb_data = rgb_scaler.transform(embedding)
rgb_data = np.clip(rgb_data, a_min=0, a_max=1)

# Create RGB dataframe of valid pixels
df_rgb_valid = pd.DataFrame(rgb_data, index=df_data.index, columns=["R", "G", "B"])

# Create original RGB dataframe
df_rgb = df[coordinates]
df_rgb = df_rgb.merge(df_rgb_valid, how="outer", left_index=True, right_index=True)

# Convert back to xarray
ds_rgb = df_rgb.to_xarray()
ds_rgb = ds_rgb.set_coords(coordinates)

# Define RGB DataArray
da_rgb = ds_rgb[["R", "G", "B"]].to_array(dim="rgb")

# Add missing coordinates
missing_coords = {coord: ds[coord] for coord in set(ds.coords) - set(da_rgb.coords)}
da_rgb = da_rgb.assign_coords(missing_coords)

# Return RGB DataArray
return da_rgb

Check warning on line 385 in gpm/retrievals/retrieval_1b_c_pmw.py

View check run for this annotation

CodeScene Delta Analysis / CodeScene Cloud Delta Analysis (main)

❌ New issue: Excess Number of Function Arguments

retrieve_UMAP_RGB has 6 arguments, threshold = 4. This function has too many arguments, indicating a lack of encapsulation. Avoid adding more arguments.


@check_software_availability(software="sklearn", conda_package="scikit-learn")
def retrieve_PCA_RGB(ds, scaler=None):
"""Create a PCA RGB composite."""
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler

# Check dataset has only spatial 2D variables
check_is_spatial_2d(ds)

# Define variables
variables = list(ds.data_vars)

# Convert to dataframe
df = ds.gpm.to_pandas_dataframe(drop_index=False)

# Retrieve dataset coordinates which are present in the dataframe
coordinates = [column for column in list(ds.coords) if column in df]

# Remove rows with non finite values
df_valid = df[np.isfinite(df[variables]).all(axis=1) & (~np.isnan(df[variables])).all(axis=1)]

# Retrieve dataframe with only variables
df_data = df_valid[variables]

# Define scaler
if scaler is not None:
scaler.fit(df_data)
scaled_data = scaler.transform(df_data)
else:
scaled_data = df_data

# Compute 3D embedding
pca = PCA(n_components=3)
pca.fit(scaled_data)
embedding = pca.transform(scaled_data)

# Define RGB scaler
rgb_scaler = MinMaxScaler()
rgb_scaler = rgb_scaler.fit(embedding)

# Scale UMAP embedding between 0 and 1
rgb_data = rgb_scaler.transform(embedding)
rgb_data = np.clip(rgb_data, a_min=0, a_max=1)

# Create RGB dataframe of valid pixels
df_rgb_valid = pd.DataFrame(rgb_data, index=df_data.index, columns=["R", "G", "B"])

# Create original RGB dataframe
df_rgb = df[coordinates]
df_rgb = df_rgb.merge(df_rgb_valid, how="outer", left_index=True, right_index=True)

# Convert back to xarray
ds_rgb = df_rgb.to_xarray()
ds_rgb = ds_rgb.set_coords(coordinates)

# Define RGB DataArray
da_rgb = ds_rgb[["R", "G", "B"]].to_array(dim="rgb")

# Add missing coordinates
missing_coords = {coord: ds[coord] for coord in set(ds.coords) - set(da_rgb.coords)}
da_rgb = da_rgb.assign_coords(missing_coords)

# Return RGB DataArray
return da_rgb


####----------------------------------------------------------------------------------------.
27 changes: 24 additions & 3 deletions gpm/tests/test_utils/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,17 @@
import pytest
import xarray as xr

from gpm.utils import decorators
from gpm.utils.decorators import (
check_has_along_track_dimension,
check_has_cross_track_dimension,
check_software_availability,
)


def test_check_has_cross_track_dimension() -> None:
"""Test check_has_cross_track_dimension decorator."""

@decorators.check_has_cross_track_dimension
@check_has_cross_track_dimension
def identity(xr_obj: xr.Dataset | xr.DataArray) -> xr.Dataset | xr.DataArray:
return xr_obj

Expand All @@ -53,7 +57,7 @@ def identity(xr_obj: xr.Dataset | xr.DataArray) -> xr.Dataset | xr.DataArray:
def test_check_has_along_track_dimension() -> None:
"""Test check_has_along_track_dimension decorator."""

@decorators.check_has_along_track_dimension
@check_has_along_track_dimension
def identity(xr_obj: xr.Dataset | xr.DataArray) -> xr.Dataset | xr.DataArray:
return xr_obj

Expand All @@ -65,3 +69,20 @@ def identity(xr_obj: xr.Dataset | xr.DataArray) -> xr.Dataset | xr.DataArray:
da = xr.DataArray(np.arange(10))
with pytest.raises(ValueError):
identity(da)


def test_check_software_availability_decorator():
"""Test check_software_availability_decorator raise ImportError."""

@check_software_availability(software="dummy_package", conda_package="dummy_package")
def dummy_function(a, b=1):
return a, b

with pytest.raises(ImportError):
dummy_function()

@check_software_availability(software="numpy", conda_package="numpy")
def dummy_function(a, b=1):
return a, b

assert dummy_function(2, b=3) == (2, 3)
29 changes: 29 additions & 0 deletions gpm/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
# -----------------------------------------------------------------------------.
"""This module contains functions decorators checking GPM-API object type."""
import functools
import importlib
from functools import wraps

from gpm.checks import check_has_along_track_dim as _check_has_along_track_dim
from gpm.checks import check_has_cross_track_dim as _check_has_cross_track_dim
Expand Down Expand Up @@ -104,3 +106,30 @@ def wrapper(*args, **kwargs):
return function(*args, **kwargs)

return wrapper


def check_software_availability(software, conda_package):
"""A decorator to ensure that a software package is installed.
Parameters
----------
software : str
The package name as recognized by Python's import system.
conda_package : str
The package name as recognized by conda-forge.
"""

def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not importlib.util.find_spec(software):
raise ImportError(
f"The '{software}' package is required but not found.\n"
"Please install it using conda:\n"
f" conda install -c conda-forge {conda_package}",
)
return func(*args, **kwargs)

return wrapper

return decorator
9 changes: 2 additions & 7 deletions gpm/utils/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

# -----------------------------------------------------------------------------.
"""This module contains functions for manipulating GPM-API Datasets."""
import importlib

import numpy as np
import xarray as xr
Expand All @@ -39,7 +38,7 @@
has_vertical_dim,
is_grid,
)
from gpm.utils.decorators import check_is_gpm_object
from gpm.utils.decorators import check_is_gpm_object, check_software_availability
from gpm.utils.geospatial import get_geodesic_line, get_great_circle_arc_endpoints
from gpm.utils.xarray import (
check_variable_availabilty,
Expand Down Expand Up @@ -1221,6 +1220,7 @@ def extract_transect_at_points(xr_obj, points, method="linear", new_dim="transec


@check_is_gpm_object
@check_software_availability(software="sklearn", conda_package="scikit-learn")
def extract_transect_between_points(xr_obj, start_point, end_point, steps=100, method="linear", new_dim="transect"):
"""Extract an interpolated transect between two points on a sphere.
Expand Down Expand Up @@ -1256,11 +1256,6 @@ def extract_transect_between_points(xr_obj, start_point, end_point, steps=100, m
:py:class:`gpm.utils.manipulations.extract_transect_around_point`.
"""
if importlib.util.find_spec("sklearn") is None:
raise ImportError(
"The 'sklearn' package required to extract cross-sections is not installed. \n"
"Please install it using the following command: conda install -c conda-forge scikit-learn",
)
# Get the points along the geodesic line
points = get_geodesic_line(start_point=start_point, end_point=end_point, steps=steps)

Expand Down
26 changes: 8 additions & 18 deletions gpm/utils/pyresample.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,21 @@
import numpy as np
import xarray as xr

from gpm.utils.decorators import check_software_availability


@check_software_availability(software="pyresample", conda_package="pyresample")
def remap(src_ds, dst_ds, radius_of_influence=20000, fill_value=np.nan):
"""Remap dataset to another one using nearest-neighbour.
The spatial non-dimensional coordinates of the source dataset are not remapped. !
The output dataset has the spatial coordinates of the destination dataset !
"""
from pyresample.future.resamplers.nearest import KDTreeNearestXarrayResampler

from gpm.checks import get_spatial_dimensions
from gpm.dataset.crs import _get_crs_coordinates, _get_proj_dim_coords, _get_swath_dim_coords, set_dataset_crs

try:
from pyresample.future.resamplers.nearest import KDTreeNearestXarrayResampler
except ImportError:
raise ImportError(
"The 'pyresample' package is required but not found. "
"Please install it using the following command: "
"conda install -c conda-forge pyresample",
)

# Retrieve source and destination area
src_area = src_ds.gpm.pyresample_area
dst_area = dst_ds.gpm.pyresample_area
Expand Down Expand Up @@ -143,17 +139,11 @@ def remap(src_ds, dst_ds, radius_of_influence=20000, fill_value=np.nan):
return ds

Check notice on line 139 in gpm/utils/pyresample.py

View check run for this annotation

CodeScene Delta Analysis / CodeScene Cloud Delta Analysis (main)

✅ Getting better: Complex Method

remap decreases in cyclomatic complexity from 11 to 9, threshold = 9. This function has many conditional statements (e.g. if, for, while), leading to lower code health. Avoid adding more conditionals and code to it without refactoring.


@check_software_availability(software="pyresample", conda_package="pyresample")
def get_pyresample_area(xr_obj):
"""It returns the corresponding pyresample area."""
try:
import pyresample # noqa
from gpm.dataset.crs import get_pyresample_area as _get_pyresample_area
except ImportError:
raise ImportError(
"The 'pyresample' package is required but not found. "
"Please install it using the following command: "
"conda install -c conda-forge pyresample",
)
import pyresample # noqa
from gpm.dataset.crs import get_pyresample_area as _get_pyresample_area

# Ensure correct dimension order for Swath
if "cross_track" in xr_obj.dims:
Expand Down

0 comments on commit 76fa09d

Please sign in to comment.