diff --git a/gpm/retrievals/retrieval_1b_c_pmw.py b/gpm/retrievals/retrieval_1b_c_pmw.py
index 000e1d5..1acfbf7 100644
--- a/gpm/retrievals/retrieval_1b_c_pmw.py
+++ b/gpm/retrievals/retrieval_1b_c_pmw.py
@@ -26,8 +26,11 @@
 # -----------------------------------------------------------------------------.
 """This module contains GPM PMW 1B and 1C products community-based retrievals."""
 import numpy as np
+import pandas as pd
 import xarray as xr
 
+from gpm.checks import check_is_spatial_2d
+from gpm.utils.decorators import check_software_availability
 from gpm.utils.pmw import (
     PMWFrequency,
     create_rgb_composite,
@@ -316,4 +319,136 @@ def retrieve_PESCA(ds, t2m="t2m"):
     return da_pesca
 
 
+@check_software_availability(software="sklearn", conda_package="scikit-learn")
+@check_software_availability(software="umap", conda_package="umap-learn")
+def retrieve_UMAP_RGB(ds, scaler=None, n_neighbors=10, min_dist=0.01, random_state=None, **kwargs):
+    """Create a UMAP RGB composite.
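+
+    Examples
+    --------
+    Minimal usage sketch, assuming ``ds`` is a spatial 2D PMW dataset and that
+    ``umap-learn`` and ``scikit-learn`` are installed:
+
+    >>> from sklearn.preprocessing import MinMaxScaler
+    >>> da_rgb = retrieve_UMAP_RGB(ds, scaler=MinMaxScaler(), n_neighbors=10, min_dist=0.01)
+    >>> da_rgb  # DataArray with a new "rgb" dimension of size 3
+    """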
+    import umap
+    from sklearn.preprocessing import MinMaxScaler
+
+    # Check dataset has only spatial 2D variables
+    check_is_spatial_2d(ds)
+
+    # Define variables
+    variables = list(ds.data_vars)
+
+    # Convert to dataframe
+    df = ds.gpm.to_pandas_dataframe(drop_index=False)
+
+    # Retrieve dataset coordinates that are present in the dataframe
+    coordinates = [column for column in list(ds.coords) if column in df]
+
+    # Remove rows with non-finite values (np.isfinite already excludes NaN)
+    df_valid = df[np.isfinite(df[variables]).all(axis=1)]
+
+    # Retrieve dataframe with only variables
+    df_data = df_valid[variables]
+
+    # Define scaler
+    if scaler is not None:
+        scaler.fit(df_data)
+        scaled_data = scaler.transform(df_data)
+    else:
+        scaled_data = df_data
+
+    # Compute 3D embedding
+    reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=3, random_state=random_state, **kwargs)
+    embedding = reducer.fit_transform(scaled_data)
+
+    # Define RGB scaler
+    rgb_scaler = MinMaxScaler()
+    rgb_scaler = rgb_scaler.fit(embedding)
+
+    # Scale UMAP embedding between 0 and 1
+    rgb_data = rgb_scaler.transform(embedding)
+    rgb_data = np.clip(rgb_data, a_min=0, a_max=1)
+
+    # Create RGB dataframe of valid pixels
+    df_rgb_valid = pd.DataFrame(rgb_data, index=df_data.index, columns=["R", "G", "B"])
+
+    # Create original RGB dataframe
+    df_rgb = df[coordinates]
+    df_rgb = df_rgb.merge(df_rgb_valid, how="outer", left_index=True, right_index=True)
+
+    # Convert back to xarray
+    ds_rgb = df_rgb.to_xarray()
+    ds_rgb = ds_rgb.set_coords(coordinates)
+
+    # Define RGB DataArray
+    da_rgb = ds_rgb[["R", "G", "B"]].to_array(dim="rgb")
+
+    # Add missing coordinates
+    missing_coords = {coord: ds[coord] for coord in set(ds.coords) - set(da_rgb.coords)}
+    da_rgb = da_rgb.assign_coords(missing_coords)
+
+    # Return RGB DataArray
+    return da_rgb
+
+
+@check_software_availability(software="sklearn", conda_package="scikit-learn")
+def retrieve_PCA_RGB(ds, scaler=None):
+    """Create a PCA RGB composite.
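+
+    Examples
+    --------
+    Minimal usage sketch, assuming ``ds`` is a spatial 2D PMW dataset; any
+    scikit-learn scaler with ``fit``/``transform`` methods can be passed:
+
+    >>> from sklearn.preprocessing import StandardScaler
+    >>> da_rgb = retrieve_PCA_RGB(ds, scaler=StandardScaler())
+    >>> da_rgb  # DataArray with a new "rgb" dimension of size 3
+    """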
"""This module contains functions decorators checking GPM-API object type.""" import functools +import importlib +from functools import wraps from gpm.checks import check_has_along_track_dim as _check_has_along_track_dim from gpm.checks import check_has_cross_track_dim as _check_has_cross_track_dim @@ -104,3 +106,30 @@ def wrapper(*args, **kwargs): return function(*args, **kwargs) return wrapper + + +def check_software_availability(software, conda_package): + """A decorator to ensure that a software package is installed. + + Parameters + ---------- + software : str + The package name as recognized by Python's import system. + conda_package : str + The package name as recognized by conda-forge. + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if not importlib.util.find_spec(software): + raise ImportError( + f"The '{software}' package is required but not found.\n" + "Please install it using conda:\n" + f" conda install -c conda-forge {conda_package}", + ) + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/gpm/utils/manipulations.py b/gpm/utils/manipulations.py index 5293142..b1d430d 100644 --- a/gpm/utils/manipulations.py +++ b/gpm/utils/manipulations.py @@ -25,7 +25,6 @@ # -----------------------------------------------------------------------------. """This module contains functions for manipulating GPM-API Datasets.""" -import importlib import numpy as np import xarray as xr @@ -39,7 +38,7 @@ has_vertical_dim, is_grid, ) -from gpm.utils.decorators import check_is_gpm_object +from gpm.utils.decorators import check_is_gpm_object, check_software_availability from gpm.utils.geospatial import get_geodesic_line, get_great_circle_arc_endpoints from gpm.utils.xarray import ( check_variable_availabilty, @@ -1221,6 +1220,7 @@ def extract_transect_at_points(xr_obj, points, method="linear", new_dim="transec @check_is_gpm_object +@check_software_availability(software="sklearn", conda_package="scikit-learn") def extract_transect_between_points(xr_obj, start_point, end_point, steps=100, method="linear", new_dim="transect"): """Extract an interpolated transect between two points on a sphere. @@ -1256,11 +1256,6 @@ def extract_transect_between_points(xr_obj, start_point, end_point, steps=100, m :py:class:`gpm.utils.manipulations.extract_transect_around_point`. """ - if importlib.util.find_spec("sklearn") is None: - raise ImportError( - "The 'sklearn' package required to extract cross-sections is not installed. \n" - "Please install it using the following command: conda install -c conda-forge scikit-learn", - ) # Get the points along the geodesic line points = get_geodesic_line(start_point=start_point, end_point=end_point, steps=steps) diff --git a/gpm/utils/pyresample.py b/gpm/utils/pyresample.py index 7246752..8d71ef8 100644 --- a/gpm/utils/pyresample.py +++ b/gpm/utils/pyresample.py @@ -31,25 +31,21 @@ import numpy as np import xarray as xr +from gpm.utils.decorators import check_software_availability + +@check_software_availability(software="pyresample", conda_package="pyresample") def remap(src_ds, dst_ds, radius_of_influence=20000, fill_value=np.nan): """Remap dataset to another one using nearest-neighbour. The spatial non-dimensional coordinates of the source dataset are not remapped. ! The output dataset has the spatial coordinates of the destination dataset ! 
""" + from pyresample.future.resamplers.nearest import KDTreeNearestXarrayResampler + from gpm.checks import get_spatial_dimensions from gpm.dataset.crs import _get_crs_coordinates, _get_proj_dim_coords, _get_swath_dim_coords, set_dataset_crs - try: - from pyresample.future.resamplers.nearest import KDTreeNearestXarrayResampler - except ImportError: - raise ImportError( - "The 'pyresample' package is required but not found. " - "Please install it using the following command: " - "conda install -c conda-forge pyresample", - ) - # Retrieve source and destination area src_area = src_ds.gpm.pyresample_area dst_area = dst_ds.gpm.pyresample_area @@ -143,17 +139,11 @@ def remap(src_ds, dst_ds, radius_of_influence=20000, fill_value=np.nan): return ds +@check_software_availability(software="pyresample", conda_package="pyresample") def get_pyresample_area(xr_obj): """It returns the corresponding pyresample area.""" - try: - import pyresample # noqa - from gpm.dataset.crs import get_pyresample_area as _get_pyresample_area - except ImportError: - raise ImportError( - "The 'pyresample' package is required but not found. " - "Please install it using the following command: " - "conda install -c conda-forge pyresample", - ) + import pyresample # noqa + from gpm.dataset.crs import get_pyresample_area as _get_pyresample_area # Ensure correct dimension order for Swath if "cross_track" in xr_obj.dims: