Skip to content

Commit 4cb27d5

Browse files
committed
precommit
1 parent 10314c7 commit 4cb27d5

File tree

6 files changed

+113
-222
lines changed

6 files changed

+113
-222
lines changed

README.md

+9-9
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ pip install --upgrade eventlet torch lightning[extra]
1515
pip install -e .
1616
```
1717

18-
The package expects to load models and data by default from
18+
The package expects to load models and data by default from
1919
```bash
2020
{ASTROCLIP_ROOT}
2121
```
@@ -31,7 +31,7 @@ If no environment is specified, the default path at Flatiron will be assumed.
3131

3232
## Pretrained Models
3333

34-
We provide the pretrained AstroCLIP model on the Huggingface model hub for easy access. Additionally, we provide the pretrained single-modal models for galaxy images and spectra as well. Model details, checkpoints, configs and logs are below.
34+
We provide the pretrained AstroCLIP model on the Huggingface model hub for easy access. Additionally, we provide the pretrained single-modal models for galaxy images and spectra as well. Model details, checkpoints, configs and logs are below.
3535

3636
<table>
3737
<tr>
@@ -154,7 +154,7 @@ Below, we include a high-level performance overview of our models on a variety o
154154
<tr>
155155
</table>
156156

157-
We report R-squared metrics on redshift and galaxy property estimation (averaged across all properties) and accuracy on galaxy morphology classification (averaged across all labels). Our models are marked with an asterisk (*).
157+
We report R-squared metrics on redshift and galaxy property estimation (averaged across all properties) and accuracy on galaxy morphology classification (averaged across all labels). Our models are marked with an asterisk (*).
158158

159159
## Data Access
160160

@@ -184,7 +184,7 @@ The directory is organized into south and north surveys, where each survey is sp
184184

185185
AstroCLIP is trained using a two-step process:
186186

187-
1. We pre-train a single-modal galaxy image encoder and a single-modal galaxy spectrum encoder separately.
187+
1. We pre-train a single-modal galaxy image encoder and a single-modal galaxy spectrum encoder separately.
188188
2. We CLIP-align these two encoders on a paired image-spectrum dataset.
189189

190190
### Single-Modal Pretraining
@@ -196,10 +196,10 @@ Model training can be launched with the following command:
196196
```
197197
image_trainer -c astroclip/astrodino/config.yaml
198198
```
199-
We train the model using 20 A100 GPUs (on 5 nodes) for 250k steps which takes roughly 46 hours.
199+
We train the model using 20 A100 GPUs (on 5 nodes) for 250k steps which takes roughly 46 hours.
200200

201201
#### Spectrum Pretraining - Masked Modelling Transformer:
202-
AstroCLIP uses a 1D Transformer to encode galaxy spectra. Pretraining is performed using a masked-modeling objective, whereby the 1D spectrum is split into contiguous, overlapping patches.
202+
AstroCLIP uses a 1D Transformer to encode galaxy spectra. Pretraining is performed using a masked-modeling objective, whereby the 1D spectrum is split into contiguous, overlapping patches.
203203

204204
Model training can be launched with the following command:
205205
```
@@ -213,17 +213,17 @@ Once pretrained, we align the image and spectrum encoder using cross-attention p
213213
```
214214
spectrum_trainer fit -c config/astroclip.yaml
215215
```
216-
We train the model using 4 A100 GPUs (on 1 node) for 25k steps or until the validation loss does not increase for a fixed number of steps. This takes roughly 12 hours.
216+
We train the model using 4 A100 GPUs (on 1 node) for 25k steps or until the validation loss does not increase for a fixed number of steps. This takes roughly 12 hours.
217217

218218
## Downstream Tasks
219219

220220
TODO
221221

222222
## Acknowledgements
223-
This reposity uses datasets and contrastive augmentations from [Stein, et al. (2022)](https://github.com/georgestein/ssl-legacysurvey/tree/main). The image pretraining is built on top of the [DINOv2](https://github.com/facebookresearch/dinov2/) framework; we also thank Piotr Bojanowski for valuable conversations around image pretraining.
223+
This reposity uses datasets and contrastive augmentations from [Stein, et al. (2022)](https://github.com/georgestein/ssl-legacysurvey/tree/main). The image pretraining is built on top of the [DINOv2](https://github.com/facebookresearch/dinov2/) framework; we also thank Piotr Bojanowski for valuable conversations around image pretraining.
224224

225225
## License
226226
AstroCLIP code and model weights are released under the MIT license. See [LICENSE](https://github.com/PolymathicAI/AstroCLIP/blob/main/LICENSE) for additional details.
227227

228228
## Citations
229-
TODO
229+
TODO

astroclip/data/crossmatch_scripts/README.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
The following scripts are used to generate the datasets used in the paper:
33

44
```python
5-
- `cross_match_data.py`: Finds spectra for objects in the Legacy Survey
5+
- `cross_match_data.py`: Finds spectra for objects in the Legacy Survey
66
data prepared by George Stein (https://github.com/georgestein/ssl-legacysurvey/tree/main)
77

88
- `export_data.py`: Exports the combination of images and spectra into
99
a single HDF5 file.
1010
```
1111

1212
In principle you should not need to run these scripts, as the datasets are already provided by the resulting HuggingFace datasets. However, these scripts are provided for reproducibility purposes.
13-
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,91 @@
1-
import numpy as np
1+
import glob
2+
23
import h5py
4+
import numpy as np
5+
import pandas as pd
36
from astropy.table import Table, join, vstack
4-
from dl import authClient as ac, queryClient as qc
7+
from dl import authClient as ac
8+
from dl import queryClient as qc
59
from sparcl.client import SparclClient
610
from tqdm import tqdm
7-
import pandas as pd
8-
import glob
911

10-
DATA_DIR='/mnt/home/flanusse/ceph'
12+
DATA_DIR = "/mnt/home/flanusse/ceph"
1113

1214
client = SparclClient()
13-
inc = ['specid', 'redshift', 'flux', 'ra', 'dec',
14-
'wavelength', 'spectype', 'specprimary',
15-
'survey', 'program', 'targetid', 'coadd_fiberstatus']
15+
inc = [
16+
"specid",
17+
"redshift",
18+
"flux",
19+
"ra",
20+
"dec",
21+
"wavelength",
22+
"spectype",
23+
"specprimary",
24+
"survey",
25+
"program",
26+
"targetid",
27+
"coadd_fiberstatus",
28+
]
1629

1730

1831
print("Retrieving all objects in the DESI data release...")
1932
query = """
2033
SELECT phot.targetid, phot.brickid, phot.brick_objid, phot.release, zpix.healpix
21-
FROM desi_edr.photometry AS phot
34+
FROM desi_edr.photometry AS phot
2235
INNER JOIN desi_edr.zpix ON phot.targetid = zpix.targetid
2336
WHERE (zpix.coadd_fiberstatus = 0 AND zpix.sv_primary)
2437
"""
25-
cat = qc.query(sql = query, fmt = 'table')
38+
cat = qc.query(sql=query, fmt="table")
2639
print("done")
2740
# Building search key based on brick ids
28-
cat['key'] = ['%d_%d_%d'%(cat['release'][i], cat['brickid'][i], cat['brick_objid'][i]) for i in range(len(cat))]
41+
cat["key"] = [
42+
"%d_%d_%d" % (cat["release"][i], cat["brickid"][i], cat["brick_objid"][i])
43+
for i in range(len(cat))
44+
]
2945

3046
merged_cat = None
3147

3248
# Looping over the downloaded image files
33-
for file in tqdm(glob.glob(DATA_DIR+'/*.h5')):
49+
for file in tqdm(glob.glob(DATA_DIR + "/*.h5")):
3450
try:
3551
with h5py.File(file) as d:
36-
# search key
37-
d_key = np.array(['%d_%d_%d'%(d['release'][i], d['brickid'][i], d['objid'][i]) for i in range(len(d['brickid']))])
38-
t = Table(data=[d['inds'][:], d_key], names=['inds', 'key'])
52+
# search key
53+
d_key = np.array(
54+
[
55+
"%d_%d_%d" % (d["release"][i], d["brickid"][i], d["objid"][i])
56+
for i in range(len(d["brickid"]))
57+
]
58+
)
59+
t = Table(data=[d["inds"][:], d_key], names=["inds", "key"])
3960
except:
4061
continue
41-
file_cat = join(cat, t, keys=['key'])
42-
file_cat['image_file'] = file
43-
file_cat.sort('healpix')
44-
62+
file_cat = join(cat, t, keys=["key"])
63+
file_cat["image_file"] = file
64+
file_cat.sort("healpix")
65+
4566
# Retrieving spectra associated with this file
46-
target_ids = [int(i) for i in file_cat['targetid']]
67+
target_ids = [int(i) for i in file_cat["targetid"]]
4768
records = None
48-
for i in tqdm(range(len(target_ids)//500 + 1)):
49-
start = i*500
50-
end = min((i+1)*500, len(target_ids)-1)
69+
for i in tqdm(range(len(target_ids) // 500 + 1)):
70+
start = i * 500
71+
end = min((i + 1) * 500, len(target_ids) - 1)
5172

52-
res = client.retrieve_by_specid(specid_list = target_ids[start:end],
53-
include = inc,
54-
dataset_list = ['DESI-EDR'])
73+
res = client.retrieve_by_specid(
74+
specid_list=target_ids[start:end], include=inc, dataset_list=["DESI-EDR"]
75+
)
5576
if records is None:
5677
records = Table.from_pandas(pd.DataFrame.from_records(res.records))
5778
else:
5879
r = Table.from_pandas(pd.DataFrame.from_records(res.records))
5980
records = vstack([records, r])
6081

6182
# Merging catalogs
62-
file_cat = join(file_cat, records, keys=['targetid'])
63-
83+
file_cat = join(file_cat, records, keys=["targetid"])
84+
6485
if merged_cat is None:
6586
merged_cat = file_cat
6687
else:
6788
merged_cat = vstack([merged_cat, file_cat])
68-
89+
6990
# Saving the results
70-
merged_cat.to_pandas().to_parquet('matched_catalog.pq')
91+
merged_cat.to_pandas().to_parquet("matched_catalog.pq")
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,46 @@
11
import h5py
2-
from astropy.table import Table, join
32
import numpy as np
43
import pandas as pd
4+
from astropy.table import Table, join
55
from tqdm import tqdm
66

7-
DATA_DIR='/mnt/home/flanusse/ceph'
7+
DATA_DIR = "/mnt/home/flanusse/ceph"
88

99
# Open matched catalog
10-
joint_cat = pd.read_parquet(DATA_DIR+'/matched_catalog.pq').drop_duplicates(subset=["key"])
10+
joint_cat = pd.read_parquet(DATA_DIR + "/matched_catalog.pq").drop_duplicates(
11+
subset=["key"]
12+
)
1113

1214
# Create randomized indices to shuffle the dataset
1315
rng = np.random.default_rng(seed=42)
1416
indices = rng.permutation(len(joint_cat))
1517
joint_cat = joint_cat.iloc[indices]
1618

17-
with h5py.File(DATA_DIR+'/exported_data.h5', 'w') as f:
19+
with h5py.File(DATA_DIR + "/exported_data.h5", "w") as f:
1820
for i in range(10):
19-
print("Processing file %d"%i)
21+
print("Processing file %d" % i)
2022
# Considering only the objects that are in the current file
21-
sub_cat = joint_cat[joint_cat['inds'] // 1000000 == i]
23+
sub_cat = joint_cat[joint_cat["inds"] // 1000000 == i]
2224
images = []
2325
spectra = []
2426
redshifts = []
2527
targetids = []
26-
with h5py.File(DATA_DIR+'/images_npix152_0%02d000000_0%02d000000.h5'%(i,i+1)) as d:
28+
with h5py.File(
29+
DATA_DIR + "/images_npix152_0%02d000000_0%02d000000.h5" % (i, i + 1)
30+
) as d:
2731
for j in tqdm(range(len(sub_cat))):
28-
images.append(np.array(d['images'][sub_cat['inds'].iloc[j] % 1000000]).T.astype('float32'))
29-
spectra.append(np.reshape(sub_cat['flux'].iloc[j], [-1, 1]).astype('float32'))
30-
redshifts.append(sub_cat['redshift'].iloc[j])
31-
targetids.append(sub_cat['targetid'].iloc[j])
32+
images.append(
33+
np.array(d["images"][sub_cat["inds"].iloc[j] % 1000000]).T.astype(
34+
"float32"
35+
)
36+
)
37+
spectra.append(
38+
np.reshape(sub_cat["flux"].iloc[j], [-1, 1]).astype("float32")
39+
)
40+
redshifts.append(sub_cat["redshift"].iloc[j])
41+
targetids.append(sub_cat["targetid"].iloc[j])
3242
f.create_group(str(i))
33-
f[str(i)].create_dataset('images', data=images)
34-
f[str(i)].create_dataset('spectra', data=spectra)
35-
f[str(i)].create_dataset('redshifts', data=redshifts)
36-
f[str(i)].create_dataset('targetids', data=targetids)
43+
f[str(i)].create_dataset("images", data=images)
44+
f[str(i)].create_dataset("spectra", data=spectra)
45+
f[str(i)].create_dataset("redshifts", data=redshifts)
46+
f[str(i)].create_dataset("targetids", data=targetids)

downstream_tasks/similarity_search/README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
## In-Modal and Cross-Modal Retrieval
2-
AstroCLIP enables researchers to easily find similar galaxies to a query galaxy by simply exploiting the cosine similarity between galaxy embeddings in embedding space. Because AstroCLIP's embedding space is shared between both galaxy images and optical spectra, retrieval can be performed for both in-modal and cross-modal similarity searches.
2+
AstroCLIP enables researchers to easily find similar galaxies to a query galaxy by simply exploiting the cosine similarity between galaxy embeddings in embedding space. Because AstroCLIP's embedding space is shared between both galaxy images and optical spectra, retrieval can be performed for both in-modal and cross-modal similarity searches.
33

44
### Embedding the dataset
55
To perform retrieval on the held-out validation set, it is important to first generate AstroCLIP embeddings of the galaxy images and spectra. We provide the already-embedded held-out validation set here:
@@ -12,4 +12,4 @@ python embed_astroclip.py [save_path]
1212
```
1313

1414
### Similarity Search
15-
Once embedded, the ```similarity_search.ipynb``` jupyter notebook contains a brief tutorial that demonstrates the retrieval abilities of the model.
15+
Once embedded, the ```similarity_search.ipynb``` jupyter notebook contains a brief tutorial that demonstrates the retrieval abilities of the model.

downstream_tasks/similarity_search/similarity_search.ipynb

+25-164
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)