Skip to content

Commit

Permalink
* refactor demo for far field
Browse files Browse the repository at this point in the history
* turn back on gravity
  • Loading branch information
Joshuaalbert committed Sep 11, 2024
1 parent d698073 commit 16bbaa2
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 139 deletions.
2 changes: 1 addition & 1 deletion dsa2000_cal/dsa2000_cal/delay_models/far_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def far_field_delay(

# Eq 11.7
delta_T_grav = jnp.sum(delta_T_grav_J) + delta_T_grav_earth # []
delta_T_grav *= 0.
# delta_T_grav *= 0.
# Around delta_T_grav=-0.00016 m * (|baseline|/1km)

# Since we perform analysis in BCRS kinematically non-rotating dynamic frame we need to convert to GCRS TT-compatible
Expand Down
154 changes: 154 additions & 0 deletions dsa2000_cal/dsa2000_cal/delay_models/tests/demo_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import jax
import numpy as np
import pytest
from astropy import time as at, units as au, coordinates as ac
from jax import numpy as jnp
from matplotlib import pyplot as plt
from tomographic_kernel.frames import ENU

from dsa2000_cal.delay_models.far_field import FarFieldDelayEngine


@pytest.mark.parametrize('time', [at.Time("2024-01-01T00:00:00", scale='utc'),
at.Time("2024-04-01T00:00:00", scale='utc'),
at.Time("2024-07-01T00:00:00", scale='utc'),
at.Time("2024-10-01T00:00:00", scale='utc')])
@pytest.mark.parametrize('baseline', [3 * au.km, 10 * au.km])
def test_aberated_plane_of_sky(time: at.Time, baseline: au.Quantity):
# aberation happens when uvw coordinates are assumed to be consistent for all points in the sky, however
# tau = (-?) c * delay = u l + v m + w sqrt(1 - l^2 - m^2) ==> w = tau(l=0, m=0)
# d/dl tau = u + w l / sqrt(1 - l^2 - m^2) ==> u = d/dl tau(l=0, m=0)
# d/dm tau = v + w m / sqrt(1 - l^2 - m^2) ==> v = d/dm tau(l=0, m=0)
# only true for l=m=0.

# Let us see the error in delay for the approximation tau(l,m) = u*l + v*m + w*sqrt(1 - l^2 - m^2)
array_location = ac.EarthLocation.of_site('vla')
antennas = ENU(
east=[0, baseline.to('km').value] * au.km,
north=[0, 0] * au.km,
up=[0, 0] * au.km,
location=array_location,
obstime=time
)
antennas = antennas.transform_to(ac.ITRS(obstime=time)).earth_location

phase_centre = ENU(east=0, north=0, up=1, location=array_location, obstime=time).transform_to(ac.ICRS())

engine = FarFieldDelayEngine(
antennas=antennas,
phase_center=phase_centre,
start_time=time,
end_time=time,
verbose=True
)
uvw = engine.compute_uvw_jax(
times=engine.time_to_jnp(time[None]),
antenna_1=jnp.asarray([0]),
antenna_2=jnp.asarray([1])
)
uvw = uvw * au.m

lvec = mvec = jnp.linspace(-1, 1, 100)
M, L = jnp.meshgrid(lvec, mvec, indexing='ij')
N = jnp.sqrt(1 - L ** 2 - M ** 2)
tau_approx = uvw[0, 0] * L + uvw[0, 1] * M + uvw[0, 2] * N
tau_approx = tau_approx.to('m').value

tau_exact = jax.vmap(
lambda l, m: engine.compute_delay_from_lm_jax(
l=l, m=m,
t1=engine.time_to_jnp(time),
i1=jnp.asarray(0),
i2=jnp.asarray(1))
)(L.ravel(), M.ravel()).reshape(L.shape)

tau_diff = tau_exact - tau_approx

# Plot exact, approx, and difference
fig, axs = plt.subplots(3, 1, figsize=(5, 15), sharex=True, sharey=True, squeeze=False)

im = axs[0, 0].imshow(tau_exact,
origin='lower',
extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()),
interpolation='nearest',
cmap='PuOr'
)
axs[0, 0].set_title('Exact delay')
fig.colorbar(im, ax=axs[0, 0],
label='Light travel dist. (m)')

im = axs[1, 0].imshow(tau_approx,
origin='lower',
extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()),
interpolation='nearest',
cmap='PuOr'
)
axs[1, 0].set_title('Approximated delay')
fig.colorbar(im, ax=axs[1, 0],
label='Light travel dist. (m)')

im = axs[2, 0].imshow(tau_diff,
origin='lower',
extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()),
interpolation='nearest',
cmap='PuOr'
)
axs[2, 0].set_title(f'Difference: {time}')
fig.colorbar(im, ax=axs[2, 0],
label='Light travel dist. (m)')

axs[0, 0].set_ylabel('m')
axs[1, 0].set_ylabel('m')
axs[2, 0].set_ylabel('m')
axs[2, 0].set_xlabel('l')

fig.tight_layout()
plt.show()

freq = 70e6 * au.Hz
difference_deg = tau_diff * freq.to('Hz').value / 299792458.0 * 180 / np.pi

# The difference in delay in radians
fig, axs = plt.subplots(1, 1, figsize=(5, 5), sharex=True, sharey=True, squeeze=False)

im = axs[0, 0].imshow(difference_deg,
origin='lower',
extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()),
interpolation='nearest',
cmap='PuOr'
)
# structure time like "1 Jan, 2024"
axs[0, 0].set_title(
fr'$\Delta \tau(l,m)$ over {baseline.to("km")} | {freq.to("MHz")} | {time.to_datetime().strftime("%d %b, %Y")}')
fig.colorbar(im, ax=axs[0, 0],
label='Phase difference (deg)')

axs[0, 0].set_ylabel('m')
axs[0, 0].set_xlabel('l')

fig.tight_layout()
fig.savefig(f'phase_error_{baseline.to("km").value:.0f}km_{freq.to("MHz").value:.0f}MHz_{time.to_datetime().strftime("%d_%b_%Y")}.png')
plt.show()

# The difference in delay in (m)
fig, axs = plt.subplots(1, 1, figsize=(5, 5), sharex=True, sharey=True, squeeze=False)

im = axs[0, 0].imshow(tau_diff,
origin='lower',
extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()),
interpolation='nearest',
cmap='PuOr'
)
# structure time like "1 Jan, 2024"
axs[0, 0].set_title(
fr'$\Delta \tau(l,m)$ over {baseline.to("km")} | {time.to_datetime().strftime("%d %b, %Y")}')
fig.colorbar(im, ax=axs[0, 0],
label='Delay error (m)')

axs[0, 0].set_ylabel('m')
axs[0, 0].set_xlabel('l')

fig.tight_layout()
fig.savefig(
f'delay_error_{baseline.to("km").value:.0f}km_{time.to_datetime().strftime("%d_%b_%Y")}.png')
plt.show()
125 changes: 0 additions & 125 deletions dsa2000_cal/dsa2000_cal/delay_models/tests/test_far_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from dsa2000_cal.delay_models.far_field import FarFieldDelayEngine



def test_far_field_delay_engine():
time = at.Time("2021-01-01T00:00:00", scale='utc')
array_location = ac.EarthLocation.of_site('vla')
Expand Down Expand Up @@ -60,7 +59,6 @@ def test_far_field_delay_engine():
np.testing.assert_allclose(delay, 1000., atol=0.55)



@pytest.mark.parametrize('with_autocorr', [True, False])
def test_compute_uvw(with_autocorr):
times = at.Time(["2021-01-01T00:00:00"], scale='utc')
Expand Down Expand Up @@ -111,129 +109,6 @@ def test_compute_uvw(with_autocorr):
convention='physical'))



@pytest.mark.parametrize('time', [at.Time("2024-01-01T00:00:00", scale='utc'),
at.Time("2024-04-01T00:00:00", scale='utc'),
at.Time("2024-07-01T00:00:00", scale='utc'),
at.Time("2024-10-01T00:00:00", scale='utc')])
@pytest.mark.parametrize('baseline', [3 * au.km, 10 * au.km])
def test_aberated_plane_of_sky(time: at.Time, baseline: au.Quantity):
# aberation happens when uvw coordinates are assumed to be consistent for all points in the sky, however
# tau = (-?) c * delay = u l + v m + w sqrt(1 - l^2 - m^2) ==> w = tau(l=0, m=0)
# d/dl tau = u + w l / sqrt(1 - l^2 - m^2) ==> u = d/dl tau(l=0, m=0)
# d/dm tau = v + w m / sqrt(1 - l^2 - m^2) ==> v = d/dm tau(l=0, m=0)
# only true for l=m=0.

# Let us see the error in delay for the approximation tau(l,m) = u*l + v*m + w*sqrt(1 - l^2 - m^2)
array_location = ac.EarthLocation.of_site('vla')
antennas = ENU(
east=[0, baseline.to('km').value] * au.km,
north=[0, 0] * au.km,
up=[0, 0] * au.km,
location=array_location,
obstime=time
)
antennas = antennas.transform_to(ac.ITRS(obstime=time)).earth_location

phase_centre = ENU(east=0, north=0, up=1, location=array_location, obstime=time).transform_to(ac.ICRS())

engine = FarFieldDelayEngine(
antennas=antennas,
phase_center=phase_centre,
start_time=time,
end_time=time,
verbose=True
)
uvw = engine.compute_uvw_jax(
times=engine.time_to_jnp(time[None]),
antenna_1=jnp.asarray([0]),
antenna_2=jnp.asarray([1])
)
uvw = uvw * au.m

lvec = mvec = jnp.linspace(-1, 1, 100)
M, L = jnp.meshgrid(lvec, mvec, indexing='ij')
N = jnp.sqrt(1 - L ** 2 - M ** 2)
tau_approx = uvw[0, 0] * L + uvw[0, 1] * M + uvw[0, 2] * N
tau_approx = tau_approx.to('m').value

tau_exact = jax.vmap(
lambda l, m: engine.compute_delay_from_lm_jax(
l=l, m=m,
t1=engine.time_to_jnp(time),
i1=jnp.asarray(0),
i2=jnp.asarray(1))
)(L.ravel(), M.ravel()).reshape(L.shape)

tau_diff = tau_exact - tau_approx

# Plot exact, approx, and difference
fig, axs = plt.subplots(3, 1, figsize=(5, 15), sharex=True, sharey=True, squeeze=False)

im = axs[0, 0].imshow(tau_exact,
origin='lower',
extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()),
interpolation='nearest',
cmap='PuOr'
)
axs[0, 0].set_title('Exact delay')
fig.colorbar(im, ax=axs[0, 0],
label='Light travel dist. (m)')

im = axs[1, 0].imshow(tau_approx,
origin='lower',
extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()),
interpolation='nearest',
cmap='PuOr'
)
axs[1, 0].set_title('Approximated delay')
fig.colorbar(im, ax=axs[1, 0],
label='Light travel dist. (m)')

im = axs[2, 0].imshow(tau_diff,
origin='lower',
extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()),
interpolation='nearest',
cmap='PuOr'
)
axs[2, 0].set_title(f'Difference: {time}')
fig.colorbar(im, ax=axs[2, 0],
label='Light travel dist. (m)')

axs[0, 0].set_ylabel('m')
axs[1, 0].set_ylabel('m')
axs[2, 0].set_ylabel('m')
axs[2, 0].set_xlabel('l')

fig.tight_layout()
plt.show()

freq = 70e6 * au.Hz
difference_deg = tau_diff * freq.to('Hz').value / 299792458.0 * 180 / np.pi

# The difference in delay in radians
fig, axs = plt.subplots(1, 1, figsize=(5, 5), sharex=True, sharey=True, squeeze=False)

im = axs[0, 0].imshow(difference_deg,
origin='lower',
extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()),
interpolation='nearest',
cmap='coolwarm'
)
# structure time like "1 Jan, 2024"
axs[0, 0].set_title(
fr'$\Delta \tau(l,m)$ over {baseline.to("km")} | {freq.to("MHz")} | {time.to_datetime().strftime("%d %b, %Y")}')
fig.colorbar(im, ax=axs[0, 0],
label='Phase difference (deg)')

axs[0, 0].set_ylabel('m')
axs[0, 0].set_xlabel('l')

fig.tight_layout()
plt.show()



@pytest.mark.parametrize('baseline', [10 * au.km, 100 * au.km, 1000 * au.km])
def test_resolution_error(baseline: au.Quantity):
# aberation happens when uvw coordinates are assumed to be consistent for all points in the sky, however
Expand Down
5 changes: 4 additions & 1 deletion dsa2000_cal/dsa2000_cal/imaging/imagor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ class Imagor:
plot_folder: str

field_of_view: au.Quantity | None = None
baseline_min: au.Quantity = 1. * au.m
oversample_factor: float = 5.
nthreads: int | None = None
epsilon: float = 1e-4
convention: str = 'physical'
verbose: bool = False
weighting: str = 'natural'
spectral_cube: bool = False
seed: int = 42

def __post_init__(self):
Expand Down Expand Up @@ -207,7 +209,8 @@ def _image_visibilties_jax(self, uvw: jax.Array, vis: jax.Array, weights: jax.Ar

# Remove auto-correlations
baseline_length = jnp.linalg.norm(uvw, axis=-1) # [num_rows]
flags = jnp.logical_or(flags, baseline_length[:, None, None] < 1.0) # [num_rows, num_chan, coh]
flags = jnp.logical_or(flags, baseline_length[:, None, None] < quantity_to_jnp(
self.baseline_min)) # [num_rows, num_chan, coh]

if self.convention == 'engineering':
uvw = jnp.negative(uvw)
Expand Down
Loading

0 comments on commit 16bbaa2

Please sign in to comment.