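# Dask_editms.py
#
# Edit the visibilities of a Measurement Set with pyralysis and dask:
# optionally scale them by a gain (alpha_R), apply a phase-center shift
# (Shift, in arcsec), and/or add point sources (addPS), then write the
# result to a new MS (file_ms_output). The pointing can optionally be
# copied from a reference MS (file_ms_ref).
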
import sys
import os
import numpy as np
from pprint import pprint
# import re
# from astropy.io import fits
# from astropy.units import Quantity
import dask.multiprocessing
#dask.config.set(scheduler=dask.multiprocessing.get)
import pyralysis
import pyralysis.io
# from pyralysis.transformers.weighting_schemes import Robust
from pyralysis.units import lambdas_equivalencies
import astropy.units as un
import dask.array as da
from pyralysis.units import array_unit_conversion


def apply_gain_shift(*args, **kwargs):
    # Backwards-compatible alias for apply().
    apply(*args, **kwargs)


def apply(
        file_ms,
        file_ms_output='output_dask.ms',
        alpha_R=None,  # e.g. 1.
        addPS=None,  # e.g. [{'x0': 0., 'y0': 0., 'F': 0.}]
        Shift=None,
        # datacolumn='CORRECTED_DATA',  # DATA
        # datacolumns_output='CORRECTED_DATA',  # DATA
        file_ms_ref=False,
        Verbose=False):
    # file_ms_ref: reference MS whose pointing is copied into the output
    # Shift: shift to apply, as (delta_alpha, delta_dec) offsets in arcsec
    # addPS: point sources to add, as a list of dicts with positions in arcsec
    #        offset from the phase center and fluxes in Jy
    print("applying shift with alpha_R =", alpha_R, "Shift =", Shift)
    print("file_ms:", file_ms)
    print("file_ms_output:", file_ms_output)
    print("building output ms structure by copying from file_ms to file_ms_output")
    print("adding point sources", addPS)

    os.system("rm -rf " + file_ms_output)
    # os.system("rsync -a " + file_ms + "/ " + file_ms_output + "/")
    reader = pyralysis.io.DaskMS(input_name=file_ms)
    input_dataset = reader.read(calculate_psf=False)
    print("done reading")
    field_dataset = input_dataset.field.dataset

    if Shift is not None:
        # convert the arcsec offsets to radians
        delta_x = Shift[0] * np.pi / (180. * 3600.)
        delta_y = Shift[1] * np.pi / (180. * 3600.)
        print("will apply shifts", delta_x, delta_y)

    for ims, ms in enumerate(input_dataset.ms_list):
        print("looping over partitioned ms", ims)  # spwid/field
        column_keys = ms.visibilities.dataset.data_vars.keys()
        if Verbose:
            print("column_keys", column_keys)
        uvw = ms.visibilities.uvw.data
        spw_id = ms.spw_id
        pol_id = ms.polarization_id
        ncorrs = input_dataset.polarization.ncorrs[pol_id]
        nchans = input_dataset.spws.nchans[spw_id]
        print("spw_id", spw_id, "nchans", nchans)
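        # Broadcast the per-row uvw coordinates (metres) to every channel and
        # divide by the channel wavelength to express the baselines in units
        # of wavelength, which is what the phase terms below expect.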
        uvw_broadcast = da.tile(uvw, nchans).reshape((len(uvw), nchans, 3))
        # print("broadcasted uvw values to all channels")
        # print("dask .compute on channel frequencies")
        chans = input_dataset.spws.dataset[spw_id].CHAN_FREQ.data.squeeze(
            axis=0).compute() * un.Hz
        # print("done dask .compute")
        chans_broadcast = chans[np.newaxis, :, np.newaxis]
        # print("broadcasted channels to same dimensions as uvw")
        uvw_lambdas = uvw_broadcast / chans_broadcast.to(un.m, un.spectral())
        # uvw_lambdas = array_unit_conversion(
        #     array=uvw_broadcast,
        #     unit=un.lambdas,
        #     equivalencies=lambdas_equivalencies(restfreq=chans_broadcast))
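        # Strip the astropy units block-by-block so uvw_lambdas becomes a
        # plain float dask array (baseline coordinates in wavelengths).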
        uvw_lambdas = da.map_blocks(lambda x: x.value,
                                    uvw_lambdas,
                                    dtype=np.float64)

        # collect the visibility columns to be edited
        msdatacolumns = []
        for acolumn in column_keys:
            if ("DATA" in acolumn) or ("CORRECTED" in acolumn) or ("MODEL" in acolumn):
                msdatacolumns.append(acolumn)

        if Shift is not None:
            print("applying gain and shift")
            if alpha_R is None:
                # assume a unit gain when only a shift was requested
                alpha_R = 1.
            uus = uvw_lambdas[:, :, 0]
            vvs = uvw_lambdas[:, :, 1]
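            # Fourier shift theorem: multiplying V(u, v) by the phasor
            # exp(2j*pi*(u*delta_x + v*delta_y)) translates the image by
            # (delta_x, delta_y) radians; alpha_R rescales the amplitude.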
            eulerphase = alpha_R * da.exp(
                2j * np.pi * (uus * delta_x + vvs * delta_y)).astype(np.complex64)
            # for acolumn in column_keys:
            for acolumn in msdatacolumns:
                # if "DATA" in acolumn:
                print("shifting column", acolumn)
                ms.visibilities.dataset[acolumn] *= eulerphase[:, :, np.newaxis]
            # if "CORRECTED_DATA" in column_keys:
            #     ms.visibilities.corrected *= eulerphase[:, :, np.newaxis]
            # if "DATA" in column_keys:
            #     ms.visibilities.data *= eulerphase[:, :, np.newaxis]
            # if "MODEL_DATA" in column_keys:
            #     ms.visibilities.model *= eulerphase[:, :, np.newaxis]
        elif alpha_R is not None:
            print("applying gain")
            for acolumn in msdatacolumns:
                print("scaling column", acolumn)
                ms.visibilities.dataset[acolumn] *= alpha_R

        if addPS is not None:
            for iPS, aPS in enumerate(addPS):
                # convert the arcsec offsets to radians
                x0 = aPS['x0'] * np.pi / (180. * 3600.)
                y0 = aPS['y0'] * np.pi / (180. * 3600.)
                Flux = aPS['F']
                print("adding PS: x0", x0, "y0", y0, "F", Flux)
                uus = uvw_lambdas[:, :, 0]
                vvs = uvw_lambdas[:, :, 1]
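                # A point source of flux F at offset (x0, y0) from the phase
                # center contributes F * exp(2j*pi*(u*x0 + v*y0)) to every
                # visibility, added identically to each correlation.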
                VisPS = Flux * da.exp(
                    2j * np.pi * (uus * x0 + vvs * y0)).astype(np.complex64)
                for acolumn in msdatacolumns:
                    ms.visibilities.dataset[acolumn] += VisPS[:, :, np.newaxis]

    # copy the input MS on disk to serve as the output MS skeleton
    if not os.path.isdir(file_ms_output):
        os.system("rsync -a " + file_ms + "/ " + file_ms_output + "/")

    print("PUNCH OUTPUT MS")
    if file_ms_ref:
        print("paste pointing center from reference vis file into output vis file")
        print("loading reference ms")
        ref_reader = pyralysis.io.DaskMS(input_name=file_ms_ref)
        ref_dataset = ref_reader.read(calculate_psf=False)
        field_dataset = ref_dataset.field.dataset
        # if len(field_dataset) == len(input_dataset.field.dataset):
        #     print("ANCHOR ")
        #     input_dataset.field.dataset = field_dataset
        #     print("uncomment above")
        #     ## print("field_dataset[0].REFERENCE_DIR", field_dataset[0].REFERENCE_DIR.compute())
        #     ## print("field_dataset[0].PHASE_DIR", field_dataset[0].PHASE_DIR.compute())
        # else:
        # print("field_dataset", field_dataset)
        # pprint(field_dataset)
        # print("field_dataset.REFERENCE_DIR", field_dataset.REFERENCE_DIR.compute())
        # print("field_dataset.PHASE_DIR", field_dataset.PHASE_DIR.compute())
        # pprint(field_dataset.REFERENCE_DIR)
        # for i, row in enumerate(input_dataset.field.dataset):
        #     print("row", row)
        #     pprint(row)
        # print("input_dataset.field.dataset", input_dataset.field.dataset)
        # print("input_dataset.field.dataset.REFERENCE_DIR", input_dataset.field.dataset.REFERENCE_DIR.compute())
        # print("input_dataset.field.dataset.PHASE_DIR", input_dataset.field.dataset.PHASE_DIR.compute())
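        # Broadcast the first pointing direction of the reference MS to every
        # field (row) of the dataset that will be written out.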
        input_dataset.field.dataset.REFERENCE_DIR[:] = field_dataset.REFERENCE_DIR[0]
        input_dataset.field.dataset.PHASE_DIR[:] = field_dataset.PHASE_DIR[0]

    # Write FIELD TABLE
    print("Write FIELD TABLE")
    # print("Changed REFERENCE_DIR", input_dataset.field.dataset[0].REFERENCE_DIR.compute())
    # print("Changed PHASE_DIR", input_dataset.field.dataset[0].PHASE_DIR.compute())
    reader.write_xarray_ds(dataset=input_dataset.field.dataset,
                           ms_name=file_ms_output,
                           columns=[
                               'REFERENCE_DIR', 'PHASE_DIR',
                               'PhaseDir_Ref', 'RefDir_Ref'
                           ],
                           table_name="FIELD")

    # Write MAIN TABLE
    print("Write MAIN TABLE", msdatacolumns)
    reader.write(dataset=input_dataset,
                 ms_name=file_ms_output,
                 columns=msdatacolumns)

    # Cross-check the pointing of the output MS
    check_reader = pyralysis.io.DaskMS(input_name=file_ms_output)
    check_dataset = check_reader.read(calculate_psf=False)
    field_dataset = check_dataset.field.dataset
    # for i, row in enumerate(field_dataset):
    # print("output REFERENCE_DIR", field_dataset.REFERENCE_DIR.compute())
    # print("output PHASE_DIR", field_dataset.PHASE_DIR.compute())
    return
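

if __name__ == "__main__":
    # Minimal command-line sketch (an assumption, not part of the original
    # script): apply a pure phase-center shift, e.g.
    #   python Dask_editms.py input.ms output.ms 0.5 -0.3
    # where the last two arguments are the RA/Dec offsets in arcsec.
    if len(sys.argv) == 5:
        apply_gain_shift(sys.argv[1],
                         file_ms_output=sys.argv[2],
                         alpha_R=1.,
                         Shift=[float(sys.argv[3]), float(sys.argv[4])])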