forked from coljac/makestamps
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtilemaker.py
201 lines (171 loc) · 6.58 KB
/
tilemaker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import numpy as np
import pandas as pd
import h5py
import astropy
import time
import os
import sys
import astropy.io.fits as pyfits
from astropy.coordinates import SkyCoord
from astropy.wcs import WCS
from astropy.nddata import Cutout2D
from astropy import units as u
# Tile table, loaded once at import time from the CSV file that sits in the
# same directory as this module.  find_tiles() and find_tiles_reverse() read
# its URAMIN/URAMAX/UDECMIN/UDECMAX/TILENAME columns.
_module_dir = os.path.dirname(os.path.realpath(__file__))
tiles = pd.read_csv(_module_dir + "/y3a1tiles.csv")
def pb(current, to, width=40, show=True, message=None, stderr=False):
    """Draw a single-line text progress bar, overwriting the current line.

    The line starts with a carriage return, so repeated calls redraw the
    bar in place.  No trailing newline is written.

    Args:
        current: Number of items completed so far.
        to: Total number of items.
        width: Number of characters between the brackets.
        show: When true, append a " (current/to) " counter after the percent.
        message: Optional text appended after the counter.
        stderr: Write to ``sys.stderr`` instead of ``sys.stdout``.
    """
    total = float(to)
    # An empty job is reported as complete instead of dividing by zero.
    percent = float(current) / total if total else 1.0
    # Keep the bar inside its brackets even when current overshoots the total.
    length = min(max(int(width * percent), 0), width)
    count = " (%d/%d) " % (current, to) if show else ""
    if message:
        count += message
    outstream = sys.stderr if stderr else sys.stdout
    outstream.write("\r[" + "#" * length + " " * (width - length) +
                    "] %0d" % (percent * 100) + "%" + count)
    outstream.flush()
def grab_tile(tilename, bands=None):
    """Placeholder: not implemented; ignores its arguments and returns None."""
    return None
def log_to_file(logfile, logstring):
    """Append ``logstring`` followed by a newline to the file ``logfile``.

    Does nothing when ``logfile`` is falsy (``None`` or an empty string), so
    callers can pass an optional log path straight through.
    """
    if logfile:
        with open(logfile, "a") as handle:
            handle.write(logstring + "\n")
def make_cuts(catalog,
              tile,
              band,
              stamp_size,
              tofile=True,
              data=None,
              headers=None,
              catmeta=None,
              masks=None,
              logfile=None,
              results=None):
    """Cut a square stamp for every catalog object out of one tile image.

    Each object's RA/DEC is converted to pixel coordinates using the WCS
    built from ``tile[1].header``, and a ``stamp_size`` x ``stamp_size``
    ``Cutout2D`` is taken from ``tile[1].data``.  The stamps are then either
    written to ``stamps/<objid>_<band>.fits`` (``tofile=True``) or copied
    into the caller-supplied output arrays (``tofile=False``).

    Output row ``j`` corresponds to the j-th *successful* cutout, not to the
    j-th catalog row; ``catmeta[j]`` records which object it belongs to.

    Args:
        catalog: DataFrame with ``RA`` and ``DEC`` columns; the index values
            are used as object ids.
        tile: Indexable container whose element 1 provides ``.header`` and
            ``.data`` for the image and whose element 2 provides ``.data``
            for the mask (presumably an open FITS HDU list -- confirm).
        band: Single character from ``"grizY"``; its position selects the
            band slot in the output arrays.
        stamp_size: Side length of the cutout.
        tofile: Write each stamp to disk when true, otherwise fill ``data``,
            ``headers`` and ``catmeta`` (and ``masks`` when given).
        data: Output indexed as ``data[j, :, :, band_idx]``; required when
            ``tofile`` is false.
        headers: Output indexed as ``headers[j, band_idx]``; receives the
            header text padded to 9000 characters.  Required when ``tofile``
            is false.
        catmeta: Output indexed as ``catmeta[j]``; receives the object id
            padded to 30 characters.  Required when ``tofile`` is false.
        masks: Optional output indexed as ``masks[j, band_idx]``.  When it is
            not ``None`` the sum of the mask cutout is computed per object.
        logfile: Optional path that progress messages are appended to.
        results: Optional dict to collect results in; a new one is created
            when omitted.

    Returns:
        dict: ``results`` with a ``'bad_objects'`` list holding the ids of
        the objects whose cutout raised ``TypeError``.
    """
    # Honour a caller-supplied results dict instead of discarding it.
    if results is None:
        results = {}
    bad_objects = results.setdefault('bad_objects', [])

    band_idx = "grizY".index(band)
    todo = len(catalog)
    w = WCS(tile[1].header)
    cutouts = []
    mask_sums = []
    filenames = []
    objids = []
    log_to_file(logfile, "Starting cutouts with tile ")

    done = 0  # number of successful cutouts so far
    for objid, ra, dec in catalog[['RA', 'DEC']].itertuples():
        # Use the coordinates of the row being iterated.  Looking them up by
        # a separate positional counter goes out of step with the object ids
        # as soon as one object fails.
        # NOTE(review): pixel coordinates are requested with origin 1, while
        # Cutout2D positions appear to be 0-based -- confirm whether a
        # one-pixel offset is intended.
        x, y = w.all_world2pix(ra, dec, 1)
        mask_sum = 0
        try:
            cutout = Cutout2D(
                tile[1].data, (x, y), (stamp_size, stamp_size), wcs=w)
            if masks is not None:
                mask_cut = Cutout2D(
                    tile[2].data, (x, y), (stamp_size, stamp_size), wcs=w)
                mask_sum = mask_cut.data.sum()
        except TypeError as e:
            # NOTE(review): only TypeError marks an object as bad; any other
            # exception from Cutout2D propagates -- confirm this is intended.
            bad_objects.append(objid)
            print("Error with object " + str(objid))
            print(getattr(e, 'message', repr(e)))
            log_to_file(logfile, "Error with object %s" % (str(objid)))
            time.sleep(10)  # In case system overloaded
            continue

        cutouts.append(cutout)
        if masks is not None:
            mask_sums.append(mask_sum)
        filenames.append("stamps/" + str(objid) + "_" + band + ".fits")
        objids.append(objid)
        done += 1
        if done % 50 == 0:
            log_to_file(logfile,
                        "Done %d cutouts of %d. Still running." % (done, todo))

    log_to_file(logfile, "Done with the cutouts, now to store.")
    for j, cutout in enumerate(cutouts):
        # Each stamp gets the tile header with the reference pixel moved to
        # the cutout's own WCS.
        head = tile[1].header.copy()
        head['CRPIX1'] = cutout.wcs.wcs.crpix[0]
        head['CRPIX2'] = cutout.wcs.wcs.crpix[1]
        if tofile:
            write_cut(cutout, filenames[j], tile, head)
        else:
            rows, cols = cutout.data.shape[0], cutout.data.shape[1]
            if rows != stamp_size or cols != stamp_size:
                # A smaller-than-requested cutout is zero-padded to full
                # size, keeping its pixels in the top-left corner.
                log_to_file(logfile, "Size mismatch: %d, %d (%d)" %
                            (rows, cols, stamp_size))
                zeros = np.zeros((stamp_size, stamp_size))
                zeros[0:rows, 0:cols] = cutout.data
                cutout.data = zeros
            data[j, :, :, band_idx] = cutout.data
            if masks is not None:
                masks[j, band_idx] = mask_sums[j]
            headers[j, band_idx] = head.tostring().ljust(9000, ' ')
            catmeta[j] = str(objids[j]).ljust(30, ' ')
        if j % 50 == 0:
            log_to_file(logfile,
                        "Stored %d cutouts of %d. Not dead yet." % (j, todo))
    log_to_file(logfile, "Complete.")
    return results
def extract_tile(datastore, tile):
    """Placeholder: not implemented; ignores its arguments and returns None."""
    return None
def extract_stamp(datastore, tile, objid, filename):
    """Placeholder for writing one stored stamp to ``filename``.

    Not implemented.  The previous body called ``write_cut()`` without any
    of its required arguments, so every call already failed (with a
    misleading ``TypeError``); the failure is now explicit.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("extract_stamp is not implemented")
def write_cut(cutout, filename, tile, head):
    """Write ``cutout.data`` with header ``head`` to ``filename`` as FITS.

    An existing file at ``filename`` is overwritten.  ``tile`` is accepted
    but not used.
    """
    hdu = pyfits.PrimaryHDU(header=head, data=cutout.data)
    hdu.writeto(filename, overwrite=True)
def find_tiles_reverse(catalog):
    """Tag each catalog row with the tile whose bounds contain its RA/DEC.

    Loops over the tiles (rather than over the objects) and, for each tile,
    marks every catalog row that falls strictly inside its bounds.  The
    catalog is modified in place: ``TILENAME`` is reset to ``"NONE"`` and
    ``STATUS`` to ``"new"`` before matching, and a row matched by several
    tiles keeps the name of the last one.  A progress bar is drawn per tile.
    """
    N = len(tiles.index)
    catalog['TILENAME'] = "NONE"
    catalog['STATUS'] = "new"
    bounds = tiles[["URAMIN", "URAMAX", "UDECMIN", "UDECMAX", "TILENAME"]]
    for n, row in enumerate(bounds.itertuples(), start=1):
        pb(n, N)
        _, ramin, ramax, decmin, decmax, tilename = row
        dec_ok = (catalog.DEC > decmin) & (catalog.DEC < decmax)
        if ramin > ramax:
            # The tile's RA range wraps around, so it covers both ends.
            ra_ok = (catalog.RA > ramin) | (catalog.RA < ramax)
        else:
            ra_ok = (catalog.RA > ramin) & (catalog.RA < ramax)
        # Assign through the index labels of the matching rows.
        matched = catalog.index[ra_ok & dec_ok]
        catalog.loc[matched, 'TILENAME'] = tilename
def find_tiles(catalog):
    """Find, for each catalog row, a tile whose bounds contain its RA/DEC.

    Loops over the objects and writes the result into the ``TILENAME``
    column of ``catalog`` in place: the name of the first matching tile, or
    ``"NONE"`` when no tile matches.  The most recently matched tile is
    tried first, which avoids a full search when consecutive objects fall
    in the same tile.  A progress bar is drawn per object.

    Tiles with ``URAMIN > URAMAX`` are treated as wrapping around in RA,
    the same way ``find_tiles_reverse`` treats them.
    """
    length = len(catalog)
    catalog['TILENAME'] = "---"
    # Loop-invariant: which tiles have an RA range that wraps around.
    wraps = tiles.URAMIN > tiles.URAMAX
    last_tile = None

    def _inside(tile_row, ra, dec):
        """Return True when (ra, dec) lies strictly inside ``tile_row``."""
        if not (tile_row.UDECMIN < dec and tile_row.UDECMAX > dec):
            return False
        if tile_row.URAMIN > tile_row.URAMAX:
            return tile_row.URAMIN < ra or tile_row.URAMAX > ra
        return tile_row.URAMIN < ra and tile_row.URAMAX > ra

    # enumerate keeps the progress count advancing on every row, including
    # the ones resolved from the cached tile.
    rows = catalog[['RA', 'DEC']].itertuples()
    for n, (idx, ra, dec) in enumerate(rows, start=1):
        pb(n, length)
        if last_tile is not None and _inside(last_tile, ra, dec):
            catalog.loc[idx, 'TILENAME'] = last_tile.TILENAME
            continue
        above_min = tiles.URAMIN < ra
        below_max = tiles.URAMAX > ra
        ra_ok = (above_min & below_max) | (wraps & (above_min | below_max))
        found = tiles[ra_ok & (tiles.UDECMIN < dec) & (tiles.UDECMAX > dec)]
        if len(found) == 0:
            catalog.loc[idx, 'TILENAME'] = "NONE"
        else:
            last_tile = found.iloc[0]
            catalog.loc[idx, 'TILENAME'] = last_tile.TILENAME
if __name__ == "__main__":
    # Augment an object catalog with DES tile names and a status.
    # Usage: tilemaker.py <input cat> <output filename>
    #
    # Check the arguments up front: otherwise a missing output filename only
    # surfaces as an IndexError after the whole catalog has been processed.
    if len(sys.argv) < 3:
        sys.exit("Usage: tilemaker.py <input cat> <output filename>")
    catalog = pd.read_csv(sys.argv[1], index_col="COADD_OBJECT_ID")
    find_tiles_reverse(catalog)
    catalog.to_csv(sys.argv[2])