11# Licensed under a 3-clause BSD style license - see LICENSE.rst
22
33from copy import deepcopy
4- from dataclasses import dataclass
4+ from dataclasses import dataclass , field
55import warnings
66
77from astropy .modeling import fitting , models
1010from scipy .interpolate import UnivariateSpline
1111import numpy as np
1212
13- __all__ = ['Trace' , 'FlatTrace' , 'ArrayTrace' , 'KosmosTrace' ]
13+ __all__ = ['BaseTrace' , ' Trace' , 'FlatTrace' , 'ArrayTrace' , 'KosmosTrace' ]
1414
1515
16- @dataclass
17- class Trace :
16+ @dataclass ( frozen = True )
17+ class BaseTrace :
1818 """
19- Basic tracing class that by default traces the middle of the image.
20-
21- Parameters
22- ----------
23- image : `~astropy.nddata.CCDData`
24- Image to be traced
25-
26- Properties
27- ----------
28- shape : tuple
29- Shape of the array describing the trace
19+ A dataclass common to all Trace objects.
3020 """
3121 image : CCDData
22+ _trace_pos : (float , np .ndarray ) = field (repr = False )
23+ _trace : np .ndarray = field (repr = False )
3224
3325 def __post_init__ (self ):
34- self .trace_pos = self .image .shape [0 ] / 2
35- self .trace = np .ones_like (self .image [0 ]) * self .trace_pos
26+ # this class only exists to catch __post_init__ calls in its
27+ # subclasses, so that super().__post_init__ calls work correctly.
28+ pass
3629
3730 def __getitem__ (self , i ):
3831 return self .trace [i ]
3932
40- @property
41- def shape (self ):
42- return self .trace .shape
43-
44- def shift (self , delta ):
45- """
46- Shift the trace by delta pixels perpendicular to the axis being traced
47-
48- Parameters
49- ----------
50- delta : float
51- Shift to be applied to the trace
52- """
53- # act on self.trace.data to ignore the mask and then re-mask when calling _bound_trace
54- self .trace = np .asarray (self .trace .data ) + delta
55- self ._bound_trace ()
56-
5733 def _bound_trace (self ):
5834 """
5935 Mask trace positions that are outside the upper/lower bounds of the image.
6036 """
6137 ny = self .image .shape [0 ]
62- self . trace = np .ma .masked_outside (self .trace , 0 , ny - 1 )
38+ object . __setattr__ ( self , '_trace' , np .ma .masked_outside (self ._trace , 0 , ny - 1 ) )
6339
6440 def __add__ (self , delta ):
6541 """
@@ -77,9 +53,60 @@ def __sub__(self, delta):
7753 """
7854 return self .__add__ (- delta )
7955
56+ def shift (self , delta ):
57+ """
58+ Shift the trace by delta pixels perpendicular to the axis being traced
59+
60+ Parameters
61+ ----------
62+ delta : float
63+ Shift to be applied to the trace
64+ """
65+ # act on self._trace.data to ignore the mask and then re-mask when calling _bound_trace
66+ object .__setattr__ (self , '_trace' , np .asarray (self ._trace .data ) + delta )
67+ object .__setattr__ (self , '_trace_pos' , self ._trace_pos + delta )
68+ self ._bound_trace ()
69+
70+ @property
71+ def shape (self ):
72+ return self ._trace .shape
73+
74+ @property
75+ def trace (self ):
76+ return self ._trace
8077
81- @dataclass
82- class FlatTrace (Trace ):
78+ @property
79+ def trace_pos (self ):
80+ return self ._trace_pos
81+
82+ @staticmethod
83+ def _default_trace_attrs (image ):
84+ """
85+ Compute a default trace position and trace array using only
86+ the image dimensions.
87+ """
88+ trace_pos = image .shape [0 ] / 2
89+ trace = np .ones_like (image [0 ]) * trace_pos
90+ return trace_pos , trace
91+
92+
93+ @dataclass (init = False , frozen = True )
94+ class Trace (BaseTrace ):
95+ """
96+ Basic tracing class that by default traces the middle of the image.
97+
98+ Parameters
99+ ----------
100+ image : `~astropy.nddata.CCDData`
101+ Image to be traced
102+ """
103+ def __init__ (self , image ):
104+ trace_pos , trace = self ._default_trace_attrs (image )
105+ super ().__init__ (image , trace_pos , trace )
106+
107+
108+ @dataclass (init = False , frozen = True )
109+ class FlatTrace (BaseTrace ):
83110 """
84111 Trace that is constant along the axis being traced
85112
@@ -92,10 +119,11 @@ class FlatTrace(Trace):
92119 trace_pos : float
93120 Position of the trace
94121 """
95- trace_pos : float
96122
97- def __post_init__ (self ):
98- self .set_position (self .trace_pos )
123+ def __init__ (self , image , trace_pos ):
124+ _ , trace = self ._default_trace_attrs (image )
125+ super ().__init__ (image , trace_pos , trace )
126+ self .set_position (trace_pos )
99127
100128 def set_position (self , trace_pos ):
101129 """
@@ -106,13 +134,13 @@ def set_position(self, trace_pos):
106134 trace_pos : float
107135 Position of the trace
108136 """
109- self . trace_pos = trace_pos
110- self . trace = np .ones_like (self .image [0 ]) * self . trace_pos
137+ object . __setattr__ ( self , '_trace_pos' , trace_pos )
138+ object . __setattr__ ( self , '_trace' , np .ones_like (self .image [0 ]) * trace_pos )
111139 self ._bound_trace ()
112140
113141
114- @dataclass
115- class ArrayTrace (Trace ):
142+ @dataclass ( init = False , frozen = True )
143+ class ArrayTrace (BaseTrace ):
116144 """
117145 Define a trace given an array of trace positions
118146
@@ -121,25 +149,27 @@ class ArrayTrace(Trace):
121149 trace : `numpy.ndarray`
122150 Array containing trace positions
123151 """
124- trace : np .ndarray
152+ def __init__ (self , image , trace ):
153+ trace_pos , _ = self ._default_trace_attrs (image )
154+ super ().__init__ (image , trace_pos , trace )
125155
126- def __post_init__ (self ):
127156 nx = self .image .shape [1 ]
128- nt = len (self . trace )
157+ nt = len (trace )
129158 if nt != nx :
130159 if nt > nx :
131160 # truncate trace to fit image
132- self . trace = self . trace [0 :nx ]
161+ trace = trace [0 :nx ]
133162 else :
134163 # assume trace starts at beginning of image and pad out trace to fit.
135164 # padding will be the last value of the trace, but will be masked out.
136- padding = np .ma .MaskedArray (np .ones (nx - nt ) * self .trace [- 1 ], mask = True )
137- self .trace = np .ma .hstack ([self .trace , padding ])
165+ padding = np .ma .MaskedArray (np .ones (nx - nt ) * trace [- 1 ], mask = True )
166+ trace = np .ma .hstack ([trace , padding ])
167+ object .__setattr__ (self , '_trace' , trace )
138168 self ._bound_trace ()
139169
140170
141- @dataclass
142- class KosmosTrace (Trace ):
171+ @dataclass ( init = False , frozen = True )
172+ class KosmosTrace (BaseTrace ):
143173 """
144174 Trace the spectrum aperture in an image.
145175
@@ -192,14 +222,25 @@ class KosmosTrace(Trace):
192222 4) add other interpolation modes besides spline, maybe via
193223 specutils.manipulation methods?
194224 """
195- bins : int = 20
196- guess : float = None
197- window : int = None
198- peak_method : str = 'gaussian'
225+ bins : int
226+ guess : float
227+ window : int
228+ peak_method : str
199229 _crossdisp_axis = 0
200230 _disp_axis = 1
201231
202- def __post_init__ (self ):
232+ def _process_init_kwargs (self , ** kwargs ):
233+ for attr , value in kwargs .items ():
234+ object .__setattr__ (self , attr , value )
235+
236+ def __init__ (self , image , bins = 20 , guess = None , window = None , peak_method = 'gaussian' ):
237+ # This method will assign the user supplied value (or default) to the attrs:
238+ self ._process_init_kwargs (
239+ bins = bins , guess = guess , window = window , peak_method = peak_method
240+ )
241+ trace_pos , trace = self ._default_trace_attrs (image )
242+ super ().__init__ (image , trace_pos , trace )
243+
203244 # handle multiple image types and mask uncaught invalid values
204245 if isinstance (self .image , NDData ):
205246 img = np .ma .masked_invalid (np .ma .masked_array (self .image .data ,
@@ -223,7 +264,7 @@ def __post_init__(self):
223264
224265 if not isinstance (self .bins , int ):
225266 warnings .warn ('TRACE: Converting bins to int' )
226- self . bins = int (self .bins )
267+ object . __setattr__ ( self , ' bins' , int (self .bins ) )
227268
228269 if self .bins < 4 :
229270 raise ValueError ('bins must be >= 4' )
@@ -240,7 +281,7 @@ def __post_init__(self):
240281 "length of the image's spatial direction" )
241282 elif self .window is not None and not isinstance (self .window , int ):
242283 warnings .warn ('TRACE: Converting window to int' )
243- self . window = int (self .window )
284+ object . __setattr__ ( self , ' window' , int (self .window ) )
244285
245286 # set max peak location by user choice or wavelength with max avg flux
246287 ztot = img .sum (axis = self ._disp_axis ) / img .shape [self ._disp_axis ]
@@ -343,4 +384,4 @@ def __post_init__(self):
343384 warnings .warn ("TRACE ERROR: No valid points found in trace" )
344385 trace_y = np .tile (np .nan , len (x_bins ))
345386
346- self . trace = np .ma .masked_invalid (trace_y )
387+ object . __setattr__ ( self , '_trace' , np .ma .masked_invalid (trace_y ) )
0 commit comments