Skip to content

Commit 2a2416a

Browse files
authored
Merge pull request #45 from mwcraig/marker-to-catalog-migration
Implement catalog part of AIDA
2 parents 51d7582 + b5b1472 commit 2a2416a

File tree

3 files changed

+536
-265
lines changed

3 files changed

+536
-265
lines changed

src/astro_image_display_api/dummy_viewer.py

Lines changed: 186 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import os
2+
from collections import defaultdict
3+
from copy import deepcopy
24
from dataclasses import dataclass, field
35
from pathlib import Path
46
from typing import Any
@@ -13,6 +15,13 @@
1315

1416
from .interface_definition import ImageViewerInterface
1517

18+
@dataclass
19+
class CatalogInfo:
20+
"""
21+
Class to hold information about a catalog.
22+
"""
23+
style: dict[str, Any] = field(default_factory=dict)
24+
data: Table | None = None
1625

1726
@dataclass
1827
class ImageViewer:
@@ -27,27 +36,65 @@ class ImageViewer:
2736
image_height: int = 0
2837
zoom_level: float = 1
2938
_cursor: str = ImageViewerInterface.ALLOWED_CURSOR_LOCATIONS[0]
30-
marker: Any = "marker"
3139
_cuts: BaseInterval | tuple[float, float] = AsymmetricPercentileInterval(upper_percentile=95)
3240
_stretch: BaseStretch = LinearStretch
3341
# viewer: Any
3442

3543
# Allowed locations for cursor display
3644
ALLOWED_CURSOR_LOCATIONS: tuple = ImageViewerInterface.ALLOWED_CURSOR_LOCATIONS
3745

38-
# List of marker names that are for internal use only
39-
RESERVED_MARKER_SET_NAMES: tuple = ImageViewerInterface.RESERVED_MARKER_SET_NAMES
40-
41-
# Default marker name for marking via API
42-
DEFAULT_MARKER_NAME: str = ImageViewerInterface.DEFAULT_MARKER_NAME
43-
4446
# some internal variable for keeping track of viewer state
45-
_interactive_marker_name: str = ""
46-
_previous_marker: Any = ""
47-
_markers: dict[str, Table] = field(default_factory=dict)
4847
_wcs: WCS | None = None
4948
_center: tuple[float, float] = (0.0, 0.0)
5049

50+
def __post_init__(self):
51+
# This is a dictionary of marker sets. The keys are the names of the
52+
# marker sets, and the values are the tables containing the markers.
53+
self._catalogs = defaultdict(CatalogInfo)
54+
self._catalogs[None].data = None
55+
self._catalogs[None].style = self._default_catalog_style.copy()
56+
57+
def _user_catalog_labels(self) -> list[str]:
58+
"""
59+
Get the user-defined catalog labels.
60+
"""
61+
return [label for label in self._catalogs if label is not None]
62+
63+
def _resolve_catalog_label(self, catalog_label: str | None) -> str:
64+
"""
65+
Figure out the catalog label if the user did not specify one. This
66+
is needed so that the user gets what they expect in the simple case
67+
where there is only one catalog loaded. In that case the user may
68+
or may not have actually specified a catalog label.
69+
"""
70+
user_keys = self._user_catalog_labels()
71+
if catalog_label is None:
72+
match len(user_keys):
73+
case 0:
74+
# No user-defined catalog labels, so return the default label.
75+
catalog_label = None
76+
case 1:
77+
# The user must have loaded a catalog, so return that instead of
78+
# the default label, which live in the key None.
79+
catalog_label = user_keys[0]
80+
case _:
81+
raise ValueError(
82+
"Multiple catalog styles defined. Please specify a catalog_label to get the style."
83+
)
84+
85+
return catalog_label
86+
87+
@property
88+
def _default_catalog_style(self) -> dict[str, Any]:
89+
"""
90+
The default style for the catalog markers.
91+
"""
92+
return {
93+
"shape": "circle",
94+
"color": "red",
95+
"size": 5,
96+
}
97+
5198
def get_stretch(self) -> BaseStretch:
5299
return self._stretch
53100

@@ -79,6 +126,62 @@ def cursor(self, value: str) -> None:
79126

80127
# The methods, grouped loosely by purpose
81128

129+
def get_catalog_style(self, catalog_label=None) -> dict[str, Any]:
130+
"""
131+
Get the style for the catalog.
132+
133+
Parameters
134+
----------
135+
catalog_label : str, optional
136+
The label of the catalog. Default is ``None``.
137+
138+
Returns
139+
-------
140+
dict
141+
The style for the catalog.
142+
"""
143+
catalog_label = self._resolve_catalog_label(catalog_label)
144+
145+
style = self._catalogs[catalog_label].style.copy()
146+
style["catalog_label"] = catalog_label
147+
return style
148+
149+
def set_catalog_style(
150+
self,
151+
catalog_label: str | None = None,
152+
shape: str = "circle",
153+
color: str = "red",
154+
size: float = 5,
155+
**kwargs
156+
) -> None:
157+
"""
158+
Set the style for the catalog.
159+
160+
Parameters
161+
----------
162+
catalog_label : str, optional
163+
The label of the catalog.
164+
shape : str, optional
165+
The shape of the markers.
166+
color : str, optional
167+
The color of the markers.
168+
size : float, optional
169+
The size of the markers.
170+
**kwargs
171+
Additional keyword arguments to pass to the marker style.
172+
"""
173+
catalog_label = self._resolve_catalog_label(catalog_label)
174+
175+
if self._catalogs[catalog_label].data is None:
176+
raise ValueError("Must load a catalog before setting a catalog style.")
177+
178+
self._catalogs[catalog_label].style = dict(
179+
shape=shape,
180+
color=color,
181+
size=size,
182+
**kwargs
183+
)
184+
82185
# Methods for loading data
83186
def load_image(self, file: str | os.PathLike | ArrayLike | NDData) -> None:
84187
"""
@@ -175,142 +278,108 @@ def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None:
175278
p.write_text("This is a dummy file. The viewer does not save anything.")
176279

177280
# Marker-related methods
178-
def add_markers(self, table: Table, x_colname: str = 'x', y_colname: str = 'y',
281+
def load_catalog(self, table: Table, x_colname: str = 'x', y_colname: str = 'y',
179282
skycoord_colname: str = 'coord', use_skycoord: bool = False,
180-
marker_name: str | None = None) -> None:
181-
"""
182-
Add markers to the image.
183-
184-
Parameters
185-
----------
186-
table : `astropy.table.Table`
187-
The table containing the marker positions.
188-
x_colname : str, optional
189-
The name of the column containing the x positions. Default
190-
is ``'x'``.
191-
y_colname : str, optional
192-
The name of the column containing the y positions. Default
193-
is ``'y'``.
194-
skycoord_colname : str, optional
195-
The name of the column containing the sky coordinates. If
196-
given, the ``use_skycoord`` parameter is ignored. Default
197-
is ``'coord'``.
198-
use_skycoord : bool, optional
199-
If `True`, the ``skycoord_colname`` column will be used to
200-
get the marker positions. Default is `False`.
201-
marker_name : str, optional
202-
The name of the marker set to use. If not given, a unique
203-
name will be generated.
204-
"""
283+
catalog_label: str | None = None,
284+
catalog_style: dict | None = None) -> None:
205285
try:
206286
coords = table[skycoord_colname]
207287
except KeyError:
208288
coords = None
209289

210-
if use_skycoord:
211-
if self._wcs is not None:
290+
try:
291+
xy = (table[x_colname], table[y_colname])
292+
except KeyError:
293+
xy = None
294+
295+
to_add = deepcopy(table)
296+
if xy is None:
297+
if self._wcs is not None and coords is not None:
212298
x, y = self._wcs.world_to_pixel(coords)
299+
to_add[x_colname] = x
300+
to_add[y_colname] = y
213301
else:
214-
raise ValueError("WCS is not set. Cannot convert to pixel coordinates.")
302+
to_add[x_colname] = to_add[y_colname] = None
303+
304+
if coords is None:
305+
if use_skycoord and self._wcs is None:
306+
raise ValueError("Cannot use sky coordinates without a SkyCoord column or WCS.")
307+
elif xy is not None and self._wcs is not None:
308+
# If we have xy coordinates, convert them to sky coordinates
309+
coords = self._wcs.pixel_to_world(xy[0], xy[1])
310+
to_add[skycoord_colname] = coords
311+
else:
312+
to_add[skycoord_colname] = None
313+
314+
catalog_label = self._resolve_catalog_label(catalog_label)
315+
316+
# Either set new data or append to existing data
317+
if (
318+
catalog_label in self._catalogs
319+
and self._catalogs[catalog_label].data is not None
320+
):
321+
# If the catalog already exists, we append to it
322+
old_table = self._catalogs[catalog_label].data
323+
self._catalogs[catalog_label].data = vstack([old_table, to_add])
215324
else:
216-
x = table[x_colname]
217-
y = table[y_colname]
218-
219-
if not coords and self._wcs is not None:
220-
coords = self._wcs.pixel_to_world(x, y)
221-
222-
if marker_name in self.RESERVED_MARKER_SET_NAMES:
223-
raise ValueError(f"Marker name {marker_name} not allowed.")
325+
# If the catalog does not exist, we create a new one
326+
self._catalogs[catalog_label].data = to_add
224327

225-
marker_name = marker_name if marker_name else self.DEFAULT_MARKER_NAME
328+
# Ensure a catalog always has a style
329+
if catalog_style is None:
330+
if not self._catalogs[catalog_label].style:
331+
catalog_style = self._default_catalog_style.copy()
226332

227-
to_add = Table(
228-
dict(
229-
x=x,
230-
y=y,
231-
coord=coords if coords else [None] * len(x),
232-
)
233-
)
234-
to_add["marker name"] = marker_name
333+
self._catalogs[catalog_label].style = catalog_style
235334

236-
if marker_name in self._markers:
237-
marker_table = self._markers[marker_name]
238-
self._markers[marker_name] = vstack([marker_table, to_add])
239-
else:
240-
self._markers[marker_name] = to_add
335+
load_catalog.__doc__ = ImageViewerInterface.load_catalog.__doc__
241336

242-
def reset_markers(self) -> None:
243-
"""
244-
Remove all markers from the image.
245-
"""
246-
self._markers = {}
247-
248-
def remove_markers(self, marker_name: str | list[str] | None = None) -> None:
337+
def remove_catalog(self, catalog_label: str | None = None) -> None:
249338
"""
250339
Remove markers from the image.
251340
252341
Parameters
253342
----------
254343
marker_name : str, optional
255-
The name of the marker set to remove. If the value is ``"all"``,
344+
The name of the marker set to remove. If the value is ``"*"``,
256345
then all markers will be removed.
257346
"""
258-
if isinstance(marker_name, str):
259-
if marker_name in self._markers:
260-
del self._markers[marker_name]
261-
elif marker_name == "all":
262-
self._markers = {}
263-
else:
264-
raise ValueError(f"Marker name {marker_name} not found.")
265-
elif isinstance(marker_name, list):
266-
for name in marker_name:
267-
if name in self._markers:
268-
del self._markers[name]
269-
else:
270-
raise ValueError(f"Marker name {name} not found.")
271-
272-
def get_markers(self, x_colname: str = 'x', y_colname: str = 'y',
273-
skycoord_colname: str = 'coord',
274-
marker_name: str | list[str] | None = None) -> Table:
275-
"""
276-
Get the marker positions.
347+
if isinstance(catalog_label, list):
348+
raise ValueError(
349+
"Cannot remove multiple catalogs from a list. Please specify "
350+
"a single catalog label or use '*' to remove all catalogs."
351+
)
352+
elif catalog_label == "*":
353+
# If the user wants to remove all catalogs, we reset the
354+
# catalogs dictionary to an empty one.
355+
self._catalogs = defaultdict(CatalogInfo)
356+
return
277357

278-
Parameters
279-
----------
280-
x_colname : str, optional
281-
The name of the column containing the x positions. Default
282-
is ``'x'``.
283-
y_colname : str, optional
284-
The name of the column containing the y positions. Default
285-
is ``'y'``.
286-
skycoord_colname : str, optional
287-
The name of the column containing the sky coordinates. Default
288-
is ``'coord'``.
289-
marker_name : str or list of str, optional
290-
The name of the marker set to use. If that value is ``"all"``,
291-
then all markers will be returned.
358+
# Special cases are done, so we can resolve the catalog label
359+
catalog_label = self._resolve_catalog_label(catalog_label)
292360

293-
Returns
294-
-------
295-
table : `astropy.table.Table`
296-
The table containing the marker positions. If no markers match the
297-
``marker_name`` parameter, an empty table is returned.
298-
"""
299-
if isinstance(marker_name, str):
300-
if marker_name == "all":
301-
marker_name = self._markers.keys()
302-
else:
303-
marker_name = [marker_name]
304-
elif marker_name is None:
305-
marker_name = [self.DEFAULT_MARKER_NAME]
361+
try:
362+
del self._catalogs[catalog_label]
363+
except KeyError:
364+
raise ValueError(f"Catalog label {catalog_label} not found.")
365+
366+
def get_catalog(self, x_colname: str = 'x', y_colname: str = 'y',
367+
skycoord_colname: str = 'coord',
368+
catalog_label: str | None = None) -> Table:
369+
# Dostring is copied from the interface definition, so it is not
370+
# duplicated here.
371+
catalog_label = self._resolve_catalog_label(catalog_label)
306372

307-
to_stack = [self._markers[name] for name in marker_name if name in self._markers]
373+
result = self._catalogs[catalog_label].data if catalog_label in self._catalogs else Table(names=["x", "y", "coord"])
308374

309-
result = vstack(to_stack) if to_stack else Table(names=["x", "y", "coord", "marker name"])
310375
result.rename_columns(["x", "y", "coord"], [x_colname, y_colname, skycoord_colname])
311376

312377
return result
378+
get_catalog.__doc__ = ImageViewerInterface.get_catalog.__doc__
313379

380+
def get_catalog_names(self) -> list[str]:
381+
return list(self._user_catalog_labels())
382+
get_catalog_names.__doc__ = ImageViewerInterface.get_catalog_names.__doc__
314383

315384
# Methods that modify the view
316385
def center_on(self, point: tuple | SkyCoord):

0 commit comments

Comments
 (0)