Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 56 additions & 14 deletions pyxlma/plot/xlma_plot_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,68 @@
from matplotlib.collections import PatchCollection


def subset(lon_data, lat_data, alt_data, time_data, chi_data,station_data,
xlim, ylim, zlim, tlim, xchi, stationmin):
def subset(lon_data=None, lat_data=None, alt_data=None, time_data=None, chi_data=None, station_data=None,
xlim=None, ylim=None, zlim=None, tlim=None, xchi=None, stationmin=None):
"""
Generate a subset of x,y,z,t of sources based on maximum
reduced chi squared and given x,y,z,t bounds

Returns: longitude, latitude, altitude, time and boolean arrays
"""
selection = ((alt_data>zlim[0])&(alt_data<zlim[1])&
(lon_data>xlim[0])&(lon_data<xlim[1])&
(lat_data>ylim[0])&(lat_data<ylim[1])&
(time_data>tlim[0])&(time_data<tlim[1])&
(chi_data<=xchi)&(station_data>=stationmin)
)

alt_data = alt_data[selection]
lon_data = lon_data[selection]
lat_data = lat_data[selection]
time_data = time_data[selection]
return lon_data, lat_data, alt_data, time_data, selection
data_shape = None
for data in [lon_data, lat_data, alt_data, time_data, chi_data, station_data]:
if data is not None:
if data_shape is None:
data_shape = data.shape
elif data_shape != data.shape:
raise ValueError("All input arrays must have the same shape.")
if data_shape is None:
raise ValueError("At least one input array must be provided.")
selection = np.ones(data_shape, dtype=bool)
if xlim is not None:
if lon_data is None:
raise ValueError("Longitude data must be provided to filter by xlim")
else:
selection &= ((lon_data>xlim[0])&(lon_data<xlim[1]))
if ylim is not None:
if lat_data is None:
raise ValueError("Latitude data must be provided to filter by ylim")
else:
selection &= ((lat_data>ylim[0])&(lat_data<ylim[1]))
if zlim is not None:
if alt_data is None:
raise ValueError("Altitude data must be provided to filter by zlim")
else:
selection &= ((alt_data>zlim[0])&(alt_data<zlim[1]))
if tlim is not None:
if time_data is None:
raise ValueError("Time data must be provided to filter by tlim")
else:
nsToS = 1e9
time_array = np.array(time_data).astype('datetime64[ns]').astype(float)/nsToS
tlim_array = np.atleast_1d(tlim).astype('datetime64[ns]').astype(float)/nsToS
selection &= ((time_array>tlim_array[0])&(time_array<tlim_array[1]))
if xchi is not None:
if chi_data is None:
raise ValueError("chi squared data must be provided to filter by xchi")
else:
selection &= (chi_data <= xchi)
if stationmin is not None:
if station_data is None:
raise ValueError("Station data must be provided to filter by stationmin")
else:
selection &= (station_data >= stationmin)

things_to_return = []
if lon_data is not None:
things_to_return.append(lon_data[selection])
if lat_data is not None:
things_to_return.append(lat_data[selection])
if alt_data is not None:
things_to_return.append(alt_data[selection])
if time_data is not None:
things_to_return.append(time_data[selection])
return *things_to_return, selection


def color_by_time(time_array, tlim=None):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_plot_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ def test_subset():
'2023-12-24T00:57:07.814960674', '2023-12-24T00:57:07.826344209']).astype(np.datetime64).astype(float))
assert np.sum(selection) == 10

def test_subset_time_mismatch():
lma = xr.open_dataset('tests/truth/lma_netcdf/lma.nc')
time_subset, selection = subset(time_data=lma.event_time.data, tlim=(dt(2023, 12, 24, 0, 57, 0), dt(2023, 12, 24, 0, 57, 10)))
assert np.allclose(time_subset[0:10].astype(float), np.array(['2023-12-24T00:57:01.747284125', '2023-12-24T00:57:01.748099340',
'2023-12-24T00:57:01.748382054', '2023-12-24T00:57:01.749366380',
'2023-12-24T00:57:01.749571321', '2023-12-24T00:57:01.751596868',
'2023-12-24T00:57:01.752419634', '2023-12-24T00:57:01.753047708',
'2023-12-24T00:57:01.754500213', '2023-12-24T00:57:01.757822235']).astype(np.datetime64).astype(float))
assert np.sum(selection) == 2590

def test_color_by_time_datetime_nolimit():
some_datetimes = np.array([dt(2021, 4, 9, 1, 51, 0), dt(2021, 4, 9, 1, 52, 0), dt(2021, 4, 9, 1, 53, 0), dt(2021, 4, 9, 1, 54, 0), dt(2021, 4, 9, 1, 59, 0)])
vmin, vmax, colors = color_by_time(some_datetimes)
Expand Down