1
1
import os
2
+ from collections import defaultdict
3
+ from copy import deepcopy
2
4
from dataclasses import dataclass , field
3
5
from pathlib import Path
4
6
from typing import Any
13
15
14
16
from .interface_definition import ImageViewerInterface
15
17
18
+ @dataclass
19
+ class CatalogInfo :
20
+ """
21
+ Class to hold information about a catalog.
22
+ """
23
+ style : dict [str , Any ] = field (default_factory = dict )
24
+ data : Table | None = None
16
25
17
26
@dataclass
18
27
class ImageViewer :
@@ -27,27 +36,65 @@ class ImageViewer:
27
36
image_height : int = 0
28
37
zoom_level : float = 1
29
38
_cursor : str = ImageViewerInterface .ALLOWED_CURSOR_LOCATIONS [0 ]
30
- marker : Any = "marker"
31
39
_cuts : BaseInterval | tuple [float , float ] = AsymmetricPercentileInterval (upper_percentile = 95 )
32
40
_stretch : BaseStretch = LinearStretch
33
41
# viewer: Any
34
42
35
43
# Allowed locations for cursor display
36
44
ALLOWED_CURSOR_LOCATIONS : tuple = ImageViewerInterface .ALLOWED_CURSOR_LOCATIONS
37
45
38
- # List of marker names that are for internal use only
39
- RESERVED_MARKER_SET_NAMES : tuple = ImageViewerInterface .RESERVED_MARKER_SET_NAMES
40
-
41
- # Default marker name for marking via API
42
- DEFAULT_MARKER_NAME : str = ImageViewerInterface .DEFAULT_MARKER_NAME
43
-
44
46
# some internal variable for keeping track of viewer state
45
- _interactive_marker_name : str = ""
46
- _previous_marker : Any = ""
47
- _markers : dict [str , Table ] = field (default_factory = dict )
48
47
_wcs : WCS | None = None
49
48
_center : tuple [float , float ] = (0.0 , 0.0 )
50
49
50
+ def __post_init__ (self ):
51
+ # This is a dictionary of marker sets. The keys are the names of the
52
+ # marker sets, and the values are the tables containing the markers.
53
+ self ._catalogs = defaultdict (CatalogInfo )
54
+ self ._catalogs [None ].data = None
55
+ self ._catalogs [None ].style = self ._default_catalog_style .copy ()
56
+
57
+ def _user_catalog_labels (self ) -> list [str ]:
58
+ """
59
+ Get the user-defined catalog labels.
60
+ """
61
+ return [label for label in self ._catalogs if label is not None ]
62
+
63
+ def _resolve_catalog_label (self , catalog_label : str | None ) -> str :
64
+ """
65
+ Figure out the catalog label if the user did not specify one. This
66
+ is needed so that the user gets what they expect in the simple case
67
+ where there is only one catalog loaded. In that case the user may
68
+ or may not have actually specified a catalog label.
69
+ """
70
+ user_keys = self ._user_catalog_labels ()
71
+ if catalog_label is None :
72
+ match len (user_keys ):
73
+ case 0 :
74
+ # No user-defined catalog labels, so return the default label.
75
+ catalog_label = None
76
+ case 1 :
77
+ # The user must have loaded a catalog, so return that instead of
78
+ # the default label, which live in the key None.
79
+ catalog_label = user_keys [0 ]
80
+ case _:
81
+ raise ValueError (
82
+ "Multiple catalog styles defined. Please specify a catalog_label to get the style."
83
+ )
84
+
85
+ return catalog_label
86
+
87
+ @property
88
+ def _default_catalog_style (self ) -> dict [str , Any ]:
89
+ """
90
+ The default style for the catalog markers.
91
+ """
92
+ return {
93
+ "shape" : "circle" ,
94
+ "color" : "red" ,
95
+ "size" : 5 ,
96
+ }
97
+
51
98
def get_stretch (self ) -> BaseStretch :
52
99
return self ._stretch
53
100
@@ -79,6 +126,62 @@ def cursor(self, value: str) -> None:
79
126
80
127
# The methods, grouped loosely by purpose
81
128
129
+ def get_catalog_style (self , catalog_label = None ) -> dict [str , Any ]:
130
+ """
131
+ Get the style for the catalog.
132
+
133
+ Parameters
134
+ ----------
135
+ catalog_label : str, optional
136
+ The label of the catalog. Default is ``None``.
137
+
138
+ Returns
139
+ -------
140
+ dict
141
+ The style for the catalog.
142
+ """
143
+ catalog_label = self ._resolve_catalog_label (catalog_label )
144
+
145
+ style = self ._catalogs [catalog_label ].style .copy ()
146
+ style ["catalog_label" ] = catalog_label
147
+ return style
148
+
149
+ def set_catalog_style (
150
+ self ,
151
+ catalog_label : str | None = None ,
152
+ shape : str = "circle" ,
153
+ color : str = "red" ,
154
+ size : float = 5 ,
155
+ ** kwargs
156
+ ) -> None :
157
+ """
158
+ Set the style for the catalog.
159
+
160
+ Parameters
161
+ ----------
162
+ catalog_label : str, optional
163
+ The label of the catalog.
164
+ shape : str, optional
165
+ The shape of the markers.
166
+ color : str, optional
167
+ The color of the markers.
168
+ size : float, optional
169
+ The size of the markers.
170
+ **kwargs
171
+ Additional keyword arguments to pass to the marker style.
172
+ """
173
+ catalog_label = self ._resolve_catalog_label (catalog_label )
174
+
175
+ if self ._catalogs [catalog_label ].data is None :
176
+ raise ValueError ("Must load a catalog before setting a catalog style." )
177
+
178
+ self ._catalogs [catalog_label ].style = dict (
179
+ shape = shape ,
180
+ color = color ,
181
+ size = size ,
182
+ ** kwargs
183
+ )
184
+
82
185
# Methods for loading data
83
186
def load_image (self , file : str | os .PathLike | ArrayLike | NDData ) -> None :
84
187
"""
@@ -175,142 +278,108 @@ def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None:
175
278
p .write_text ("This is a dummy file. The viewer does not save anything." )
176
279
177
280
# Marker-related methods
178
- def add_markers (self , table : Table , x_colname : str = 'x' , y_colname : str = 'y' ,
281
+ def load_catalog (self , table : Table , x_colname : str = 'x' , y_colname : str = 'y' ,
179
282
skycoord_colname : str = 'coord' , use_skycoord : bool = False ,
180
- marker_name : str | None = None ) -> None :
181
- """
182
- Add markers to the image.
183
-
184
- Parameters
185
- ----------
186
- table : `astropy.table.Table`
187
- The table containing the marker positions.
188
- x_colname : str, optional
189
- The name of the column containing the x positions. Default
190
- is ``'x'``.
191
- y_colname : str, optional
192
- The name of the column containing the y positions. Default
193
- is ``'y'``.
194
- skycoord_colname : str, optional
195
- The name of the column containing the sky coordinates. If
196
- given, the ``use_skycoord`` parameter is ignored. Default
197
- is ``'coord'``.
198
- use_skycoord : bool, optional
199
- If `True`, the ``skycoord_colname`` column will be used to
200
- get the marker positions. Default is `False`.
201
- marker_name : str, optional
202
- The name of the marker set to use. If not given, a unique
203
- name will be generated.
204
- """
283
+ catalog_label : str | None = None ,
284
+ catalog_style : dict | None = None ) -> None :
205
285
try :
206
286
coords = table [skycoord_colname ]
207
287
except KeyError :
208
288
coords = None
209
289
210
- if use_skycoord :
211
- if self ._wcs is not None :
290
+ try :
291
+ xy = (table [x_colname ], table [y_colname ])
292
+ except KeyError :
293
+ xy = None
294
+
295
+ to_add = deepcopy (table )
296
+ if xy is None :
297
+ if self ._wcs is not None and coords is not None :
212
298
x , y = self ._wcs .world_to_pixel (coords )
299
+ to_add [x_colname ] = x
300
+ to_add [y_colname ] = y
213
301
else :
214
- raise ValueError ("WCS is not set. Cannot convert to pixel coordinates." )
302
+ to_add [x_colname ] = to_add [y_colname ] = None
303
+
304
+ if coords is None :
305
+ if use_skycoord and self ._wcs is None :
306
+ raise ValueError ("Cannot use sky coordinates without a SkyCoord column or WCS." )
307
+ elif xy is not None and self ._wcs is not None :
308
+ # If we have xy coordinates, convert them to sky coordinates
309
+ coords = self ._wcs .pixel_to_world (xy [0 ], xy [1 ])
310
+ to_add [skycoord_colname ] = coords
311
+ else :
312
+ to_add [skycoord_colname ] = None
313
+
314
+ catalog_label = self ._resolve_catalog_label (catalog_label )
315
+
316
+ # Either set new data or append to existing data
317
+ if (
318
+ catalog_label in self ._catalogs
319
+ and self ._catalogs [catalog_label ].data is not None
320
+ ):
321
+ # If the catalog already exists, we append to it
322
+ old_table = self ._catalogs [catalog_label ].data
323
+ self ._catalogs [catalog_label ].data = vstack ([old_table , to_add ])
215
324
else :
216
- x = table [x_colname ]
217
- y = table [y_colname ]
218
-
219
- if not coords and self ._wcs is not None :
220
- coords = self ._wcs .pixel_to_world (x , y )
221
-
222
- if marker_name in self .RESERVED_MARKER_SET_NAMES :
223
- raise ValueError (f"Marker name { marker_name } not allowed." )
325
+ # If the catalog does not exist, we create a new one
326
+ self ._catalogs [catalog_label ].data = to_add
224
327
225
- marker_name = marker_name if marker_name else self .DEFAULT_MARKER_NAME
328
+ # Ensure a catalog always has a style
329
+ if catalog_style is None :
330
+ if not self ._catalogs [catalog_label ].style :
331
+ catalog_style = self ._default_catalog_style .copy ()
226
332
227
- to_add = Table (
228
- dict (
229
- x = x ,
230
- y = y ,
231
- coord = coords if coords else [None ] * len (x ),
232
- )
233
- )
234
- to_add ["marker name" ] = marker_name
333
+ self ._catalogs [catalog_label ].style = catalog_style
235
334
236
- if marker_name in self ._markers :
237
- marker_table = self ._markers [marker_name ]
238
- self ._markers [marker_name ] = vstack ([marker_table , to_add ])
239
- else :
240
- self ._markers [marker_name ] = to_add
335
+ load_catalog .__doc__ = ImageViewerInterface .load_catalog .__doc__
241
336
242
- def reset_markers (self ) -> None :
243
- """
244
- Remove all markers from the image.
245
- """
246
- self ._markers = {}
247
-
248
- def remove_markers (self , marker_name : str | list [str ] | None = None ) -> None :
337
+ def remove_catalog (self , catalog_label : str | None = None ) -> None :
249
338
"""
250
339
Remove markers from the image.
251
340
252
341
Parameters
253
342
----------
254
343
marker_name : str, optional
255
- The name of the marker set to remove. If the value is ``"all "``,
344
+ The name of the marker set to remove. If the value is ``"* "``,
256
345
then all markers will be removed.
257
346
"""
258
- if isinstance (marker_name , str ):
259
- if marker_name in self ._markers :
260
- del self ._markers [marker_name ]
261
- elif marker_name == "all" :
262
- self ._markers = {}
263
- else :
264
- raise ValueError (f"Marker name { marker_name } not found." )
265
- elif isinstance (marker_name , list ):
266
- for name in marker_name :
267
- if name in self ._markers :
268
- del self ._markers [name ]
269
- else :
270
- raise ValueError (f"Marker name { name } not found." )
271
-
272
- def get_markers (self , x_colname : str = 'x' , y_colname : str = 'y' ,
273
- skycoord_colname : str = 'coord' ,
274
- marker_name : str | list [str ] | None = None ) -> Table :
275
- """
276
- Get the marker positions.
347
+ if isinstance (catalog_label , list ):
348
+ raise ValueError (
349
+ "Cannot remove multiple catalogs from a list. Please specify "
350
+ "a single catalog label or use '*' to remove all catalogs."
351
+ )
352
+ elif catalog_label == "*" :
353
+ # If the user wants to remove all catalogs, we reset the
354
+ # catalogs dictionary to an empty one.
355
+ self ._catalogs = defaultdict (CatalogInfo )
356
+ return
277
357
278
- Parameters
279
- ----------
280
- x_colname : str, optional
281
- The name of the column containing the x positions. Default
282
- is ``'x'``.
283
- y_colname : str, optional
284
- The name of the column containing the y positions. Default
285
- is ``'y'``.
286
- skycoord_colname : str, optional
287
- The name of the column containing the sky coordinates. Default
288
- is ``'coord'``.
289
- marker_name : str or list of str, optional
290
- The name of the marker set to use. If that value is ``"all"``,
291
- then all markers will be returned.
358
+ # Special cases are done, so we can resolve the catalog label
359
+ catalog_label = self ._resolve_catalog_label (catalog_label )
292
360
293
- Returns
294
- -------
295
- table : `astropy.table.Table`
296
- The table containing the marker positions. If no markers match the
297
- ``marker_name`` parameter, an empty table is returned.
298
- """
299
- if isinstance (marker_name , str ):
300
- if marker_name == "all" :
301
- marker_name = self ._markers .keys ()
302
- else :
303
- marker_name = [marker_name ]
304
- elif marker_name is None :
305
- marker_name = [self .DEFAULT_MARKER_NAME ]
361
+ try :
362
+ del self ._catalogs [catalog_label ]
363
+ except KeyError :
364
+ raise ValueError (f"Catalog label { catalog_label } not found." )
365
+
366
+ def get_catalog (self , x_colname : str = 'x' , y_colname : str = 'y' ,
367
+ skycoord_colname : str = 'coord' ,
368
+ catalog_label : str | None = None ) -> Table :
369
+ # Dostring is copied from the interface definition, so it is not
370
+ # duplicated here.
371
+ catalog_label = self ._resolve_catalog_label (catalog_label )
306
372
307
- to_stack = [ self ._markers [ name ] for name in marker_name if name in self . _markers ]
373
+ result = self ._catalogs [ catalog_label ]. data if catalog_label in self . _catalogs else Table ( names = [ "x" , "y" , "coord" ])
308
374
309
- result = vstack (to_stack ) if to_stack else Table (names = ["x" , "y" , "coord" , "marker name" ])
310
375
result .rename_columns (["x" , "y" , "coord" ], [x_colname , y_colname , skycoord_colname ])
311
376
312
377
return result
378
+ get_catalog .__doc__ = ImageViewerInterface .get_catalog .__doc__
313
379
380
+ def get_catalog_names (self ) -> list [str ]:
381
+ return list (self ._user_catalog_labels ())
382
+ get_catalog_names .__doc__ = ImageViewerInterface .get_catalog_names .__doc__
314
383
315
384
# Methods that modify the view
316
385
def center_on (self , point : tuple | SkyCoord ):
0 commit comments