From 16bbaa2e64776edcc1212a143eb0fbc30338133a Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Wed, 11 Sep 2024 16:23:25 +0200 Subject: [PATCH] * refactor demo for far field * turn back on gravity --- .../dsa2000_cal/delay_models/far_field.py | 2 +- .../delay_models/tests/demo_tests.py | 154 ++++++++++++++++++ .../delay_models/tests/test_far_field.py | 125 -------------- dsa2000_cal/dsa2000_cal/imaging/imagor.py | 5 +- .../dsa2000_cal/imaging/tests/demo_tests.py | 37 +++-- 5 files changed, 184 insertions(+), 139 deletions(-) create mode 100644 dsa2000_cal/dsa2000_cal/delay_models/tests/demo_tests.py diff --git a/dsa2000_cal/dsa2000_cal/delay_models/far_field.py b/dsa2000_cal/dsa2000_cal/delay_models/far_field.py index 97e25127..a2dbb416 100644 --- a/dsa2000_cal/dsa2000_cal/delay_models/far_field.py +++ b/dsa2000_cal/dsa2000_cal/delay_models/far_field.py @@ -529,7 +529,7 @@ def far_field_delay( # Eq 11.7 delta_T_grav = jnp.sum(delta_T_grav_J) + delta_T_grav_earth # [] - delta_T_grav *= 0. + # delta_T_grav *= 0. # Around delta_T_grav=-0.00016 m * (|baseline|/1km) # Since we perform analysis in BCRS kinematically non-rotating dynamic frame we need to convert to GCRS TT-compatible diff --git a/dsa2000_cal/dsa2000_cal/delay_models/tests/demo_tests.py b/dsa2000_cal/dsa2000_cal/delay_models/tests/demo_tests.py new file mode 100644 index 00000000..371e0e81 --- /dev/null +++ b/dsa2000_cal/dsa2000_cal/delay_models/tests/demo_tests.py @@ -0,0 +1,154 @@ +import jax +import numpy as np +import pytest +from astropy import time as at, units as au, coordinates as ac +from jax import numpy as jnp +from matplotlib import pyplot as plt +from tomographic_kernel.frames import ENU + +from dsa2000_cal.delay_models.far_field import FarFieldDelayEngine + + +@pytest.mark.parametrize('time', [at.Time("2024-01-01T00:00:00", scale='utc'), + at.Time("2024-04-01T00:00:00", scale='utc'), + at.Time("2024-07-01T00:00:00", scale='utc'), + at.Time("2024-10-01T00:00:00", scale='utc')]) +@pytest.mark.parametrize('baseline', [3 * au.km, 10 * au.km]) +def test_aberated_plane_of_sky(time: at.Time, baseline: au.Quantity): + # aberation happens when uvw coordinates are assumed to be consistent for all points in the sky, however + # tau = (-?) c * delay = u l + v m + w sqrt(1 - l^2 - m^2) ==> w = tau(l=0, m=0) + # d/dl tau = u + w l / sqrt(1 - l^2 - m^2) ==> u = d/dl tau(l=0, m=0) + # d/dm tau = v + w m / sqrt(1 - l^2 - m^2) ==> v = d/dm tau(l=0, m=0) + # only true for l=m=0. + + # Let us see the error in delay for the approximation tau(l,m) = u*l + v*m + w*sqrt(1 - l^2 - m^2) + array_location = ac.EarthLocation.of_site('vla') + antennas = ENU( + east=[0, baseline.to('km').value] * au.km, + north=[0, 0] * au.km, + up=[0, 0] * au.km, + location=array_location, + obstime=time + ) + antennas = antennas.transform_to(ac.ITRS(obstime=time)).earth_location + + phase_centre = ENU(east=0, north=0, up=1, location=array_location, obstime=time).transform_to(ac.ICRS()) + + engine = FarFieldDelayEngine( + antennas=antennas, + phase_center=phase_centre, + start_time=time, + end_time=time, + verbose=True + ) + uvw = engine.compute_uvw_jax( + times=engine.time_to_jnp(time[None]), + antenna_1=jnp.asarray([0]), + antenna_2=jnp.asarray([1]) + ) + uvw = uvw * au.m + + lvec = mvec = jnp.linspace(-1, 1, 100) + M, L = jnp.meshgrid(lvec, mvec, indexing='ij') + N = jnp.sqrt(1 - L ** 2 - M ** 2) + tau_approx = uvw[0, 0] * L + uvw[0, 1] * M + uvw[0, 2] * N + tau_approx = tau_approx.to('m').value + + tau_exact = jax.vmap( + lambda l, m: engine.compute_delay_from_lm_jax( + l=l, m=m, + t1=engine.time_to_jnp(time), + i1=jnp.asarray(0), + i2=jnp.asarray(1)) + )(L.ravel(), M.ravel()).reshape(L.shape) + + tau_diff = tau_exact - tau_approx + + # Plot exact, approx, and difference + fig, axs = plt.subplots(3, 1, figsize=(5, 15), sharex=True, sharey=True, squeeze=False) + + im = axs[0, 0].imshow(tau_exact, + origin='lower', + extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()), + interpolation='nearest', + cmap='PuOr' + ) + axs[0, 0].set_title('Exact delay') + fig.colorbar(im, ax=axs[0, 0], + label='Light travel dist. (m)') + + im = axs[1, 0].imshow(tau_approx, + origin='lower', + extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()), + interpolation='nearest', + cmap='PuOr' + ) + axs[1, 0].set_title('Approximated delay') + fig.colorbar(im, ax=axs[1, 0], + label='Light travel dist. (m)') + + im = axs[2, 0].imshow(tau_diff, + origin='lower', + extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()), + interpolation='nearest', + cmap='PuOr' + ) + axs[2, 0].set_title(f'Difference: {time}') + fig.colorbar(im, ax=axs[2, 0], + label='Light travel dist. (m)') + + axs[0, 0].set_ylabel('m') + axs[1, 0].set_ylabel('m') + axs[2, 0].set_ylabel('m') + axs[2, 0].set_xlabel('l') + + fig.tight_layout() + plt.show() + + freq = 70e6 * au.Hz + difference_deg = tau_diff * freq.to('Hz').value / 299792458.0 * 180 / np.pi + + # The difference in delay in radians + fig, axs = plt.subplots(1, 1, figsize=(5, 5), sharex=True, sharey=True, squeeze=False) + + im = axs[0, 0].imshow(difference_deg, + origin='lower', + extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()), + interpolation='nearest', + cmap='PuOr' + ) + # structure time like "1 Jan, 2024" + axs[0, 0].set_title( + fr'$\Delta \tau(l,m)$ over {baseline.to("km")} | {freq.to("MHz")} | {time.to_datetime().strftime("%d %b, %Y")}') + fig.colorbar(im, ax=axs[0, 0], + label='Phase difference (deg)') + + axs[0, 0].set_ylabel('m') + axs[0, 0].set_xlabel('l') + + fig.tight_layout() + fig.savefig(f'phase_error_{baseline.to("km").value:.0f}km_{freq.to("MHz").value:.0f}MHz_{time.to_datetime().strftime("%d_%b_%Y")}.png') + plt.show() + + # The difference in delay in (m) + fig, axs = plt.subplots(1, 1, figsize=(5, 5), sharex=True, sharey=True, squeeze=False) + + im = axs[0, 0].imshow(tau_diff, + origin='lower', + extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()), + interpolation='nearest', + cmap='PuOr' + ) + # structure time like "1 Jan, 2024" + axs[0, 0].set_title( + fr'$\Delta \tau(l,m)$ over {baseline.to("km")} | {time.to_datetime().strftime("%d %b, %Y")}') + fig.colorbar(im, ax=axs[0, 0], + label='Delay error (m)') + + axs[0, 0].set_ylabel('m') + axs[0, 0].set_xlabel('l') + + fig.tight_layout() + fig.savefig( + f'delay_error_{baseline.to("km").value:.0f}km_{time.to_datetime().strftime("%d_%b_%Y")}.png') + plt.show() diff --git a/dsa2000_cal/dsa2000_cal/delay_models/tests/test_far_field.py b/dsa2000_cal/dsa2000_cal/delay_models/tests/test_far_field.py index e8227f53..8cbfbf0f 100644 --- a/dsa2000_cal/dsa2000_cal/delay_models/tests/test_far_field.py +++ b/dsa2000_cal/dsa2000_cal/delay_models/tests/test_far_field.py @@ -12,7 +12,6 @@ from dsa2000_cal.delay_models.far_field import FarFieldDelayEngine - def test_far_field_delay_engine(): time = at.Time("2021-01-01T00:00:00", scale='utc') array_location = ac.EarthLocation.of_site('vla') @@ -60,7 +59,6 @@ def test_far_field_delay_engine(): np.testing.assert_allclose(delay, 1000., atol=0.55) - @pytest.mark.parametrize('with_autocorr', [True, False]) def test_compute_uvw(with_autocorr): times = at.Time(["2021-01-01T00:00:00"], scale='utc') @@ -111,129 +109,6 @@ def test_compute_uvw(with_autocorr): convention='physical')) - -@pytest.mark.parametrize('time', [at.Time("2024-01-01T00:00:00", scale='utc'), - at.Time("2024-04-01T00:00:00", scale='utc'), - at.Time("2024-07-01T00:00:00", scale='utc'), - at.Time("2024-10-01T00:00:00", scale='utc')]) -@pytest.mark.parametrize('baseline', [3 * au.km, 10 * au.km]) -def test_aberated_plane_of_sky(time: at.Time, baseline: au.Quantity): - # aberation happens when uvw coordinates are assumed to be consistent for all points in the sky, however - # tau = (-?) c * delay = u l + v m + w sqrt(1 - l^2 - m^2) ==> w = tau(l=0, m=0) - # d/dl tau = u + w l / sqrt(1 - l^2 - m^2) ==> u = d/dl tau(l=0, m=0) - # d/dm tau = v + w m / sqrt(1 - l^2 - m^2) ==> v = d/dm tau(l=0, m=0) - # only true for l=m=0. - - # Let us see the error in delay for the approximation tau(l,m) = u*l + v*m + w*sqrt(1 - l^2 - m^2) - array_location = ac.EarthLocation.of_site('vla') - antennas = ENU( - east=[0, baseline.to('km').value] * au.km, - north=[0, 0] * au.km, - up=[0, 0] * au.km, - location=array_location, - obstime=time - ) - antennas = antennas.transform_to(ac.ITRS(obstime=time)).earth_location - - phase_centre = ENU(east=0, north=0, up=1, location=array_location, obstime=time).transform_to(ac.ICRS()) - - engine = FarFieldDelayEngine( - antennas=antennas, - phase_center=phase_centre, - start_time=time, - end_time=time, - verbose=True - ) - uvw = engine.compute_uvw_jax( - times=engine.time_to_jnp(time[None]), - antenna_1=jnp.asarray([0]), - antenna_2=jnp.asarray([1]) - ) - uvw = uvw * au.m - - lvec = mvec = jnp.linspace(-1, 1, 100) - M, L = jnp.meshgrid(lvec, mvec, indexing='ij') - N = jnp.sqrt(1 - L ** 2 - M ** 2) - tau_approx = uvw[0, 0] * L + uvw[0, 1] * M + uvw[0, 2] * N - tau_approx = tau_approx.to('m').value - - tau_exact = jax.vmap( - lambda l, m: engine.compute_delay_from_lm_jax( - l=l, m=m, - t1=engine.time_to_jnp(time), - i1=jnp.asarray(0), - i2=jnp.asarray(1)) - )(L.ravel(), M.ravel()).reshape(L.shape) - - tau_diff = tau_exact - tau_approx - - # Plot exact, approx, and difference - fig, axs = plt.subplots(3, 1, figsize=(5, 15), sharex=True, sharey=True, squeeze=False) - - im = axs[0, 0].imshow(tau_exact, - origin='lower', - extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()), - interpolation='nearest', - cmap='PuOr' - ) - axs[0, 0].set_title('Exact delay') - fig.colorbar(im, ax=axs[0, 0], - label='Light travel dist. (m)') - - im = axs[1, 0].imshow(tau_approx, - origin='lower', - extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()), - interpolation='nearest', - cmap='PuOr' - ) - axs[1, 0].set_title('Approximated delay') - fig.colorbar(im, ax=axs[1, 0], - label='Light travel dist. (m)') - - im = axs[2, 0].imshow(tau_diff, - origin='lower', - extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()), - interpolation='nearest', - cmap='PuOr' - ) - axs[2, 0].set_title(f'Difference: {time}') - fig.colorbar(im, ax=axs[2, 0], - label='Light travel dist. (m)') - - axs[0, 0].set_ylabel('m') - axs[1, 0].set_ylabel('m') - axs[2, 0].set_ylabel('m') - axs[2, 0].set_xlabel('l') - - fig.tight_layout() - plt.show() - - freq = 70e6 * au.Hz - difference_deg = tau_diff * freq.to('Hz').value / 299792458.0 * 180 / np.pi - - # The difference in delay in radians - fig, axs = plt.subplots(1, 1, figsize=(5, 5), sharex=True, sharey=True, squeeze=False) - - im = axs[0, 0].imshow(difference_deg, - origin='lower', - extent=(lvec.min(), lvec.max(), mvec.min(), mvec.max()), - interpolation='nearest', - cmap='coolwarm' - ) - # structure time like "1 Jan, 2024" - axs[0, 0].set_title( - fr'$\Delta \tau(l,m)$ over {baseline.to("km")} | {freq.to("MHz")} | {time.to_datetime().strftime("%d %b, %Y")}') - fig.colorbar(im, ax=axs[0, 0], - label='Phase difference (deg)') - - axs[0, 0].set_ylabel('m') - axs[0, 0].set_xlabel('l') - - fig.tight_layout() - plt.show() - - - @pytest.mark.parametrize('baseline', [10 * au.km, 100 * au.km, 1000 * au.km]) def test_resolution_error(baseline: au.Quantity): # aberation happens when uvw coordinates are assumed to be consistent for all points in the sky, however diff --git a/dsa2000_cal/dsa2000_cal/imaging/imagor.py b/dsa2000_cal/dsa2000_cal/imaging/imagor.py index 88fe2d33..2b2c9d49 100644 --- a/dsa2000_cal/dsa2000_cal/imaging/imagor.py +++ b/dsa2000_cal/dsa2000_cal/imaging/imagor.py @@ -38,12 +38,14 @@ class Imagor: plot_folder: str field_of_view: au.Quantity | None = None + baseline_min: au.Quantity = 1. * au.m oversample_factor: float = 5. nthreads: int | None = None epsilon: float = 1e-4 convention: str = 'physical' verbose: bool = False weighting: str = 'natural' + spectral_cube: bool = False seed: int = 42 def __post_init__(self): @@ -207,7 +209,8 @@ def _image_visibilties_jax(self, uvw: jax.Array, vis: jax.Array, weights: jax.Ar # Remove auto-correlations baseline_length = jnp.linalg.norm(uvw, axis=-1) # [num_rows] - flags = jnp.logical_or(flags, baseline_length[:, None, None] < 1.0) # [num_rows, num_chan, coh] + flags = jnp.logical_or(flags, baseline_length[:, None, None] < quantity_to_jnp( + self.baseline_min)) # [num_rows, num_chan, coh] if self.convention == 'engineering': uvw = jnp.negative(uvw) diff --git a/dsa2000_cal/dsa2000_cal/imaging/tests/demo_tests.py b/dsa2000_cal/dsa2000_cal/imaging/tests/demo_tests.py index 975e863f..0fde4d8e 100644 --- a/dsa2000_cal/dsa2000_cal/imaging/tests/demo_tests.py +++ b/dsa2000_cal/dsa2000_cal/imaging/tests/demo_tests.py @@ -3,6 +3,7 @@ import astropy.units as au import numpy as np import pytest +from tomographic_kernel.frames import ENU from dsa2000_cal.assets.content_registry import fill_registries from dsa2000_cal.assets.registries import array_registry @@ -10,18 +11,15 @@ from dsa2000_cal.measurement_sets.measurement_set import MeasurementSetMetaV0, MeasurementSet, VisibilityData -@pytest.fixture(scope='function') -def mock_calibrator_source_models(tmp_path): +def build_calibrator_source_models(array_name, tmp_path, full_stokes, num_chan): fill_registries() - array_name = 'dsa2000W_small' # Load array array = array_registry.get_instance(array_registry.get_match(array_name)) array_location = array.get_array_location() antennas = array.get_antennas() - # -00:36:29.015,58.45.50.398 - phase_tracking = ac.SkyCoord("-00h36m29.015s", "58d45m50.398s", frame='icrs') - phase_tracking = ac.ICRS(phase_tracking.ra, phase_tracking.dec) + obstime = at.Time("2021-01-01T00:00:00", scale='utc') + phase_tracking = zenith = ENU(0, 0, 1, obstime=obstime, location=array_location).transform_to(ac.ICRS()) meta = MeasurementSetMetaV0( array_name=array_name, @@ -29,10 +27,10 @@ def mock_calibrator_source_models(tmp_path): phase_tracking=phase_tracking, channel_width=array.get_channel_width(), integration_time=au.Quantity(1.5, 's'), - coherencies=['XX', 'XY', 'YX', 'YY'], + coherencies=['XX', 'XY', 'YX', 'YY'] if full_stokes else ['I'], pointings=phase_tracking, - times=at.Time("2021-01-01T00:00:00", scale='utc') + 1.5 * np.arange(1) * au.s, - freqs=au.Quantity([700], unit=au.MHz), + times=obstime + 1.5 * np.arange(1) * au.s, + freqs=au.Quantity(np.linspace(700, 2000, num_chan), unit=au.MHz), antennas=antennas, antenna_names=array.get_antenna_names(), antenna_diameters=array.get_antenna_diameter(), @@ -40,7 +38,7 @@ def mock_calibrator_source_models(tmp_path): mount_types='ALT-AZ' ) ms = MeasurementSet.create_measurement_set(str(tmp_path), meta) - gen = ms.create_block_generator(vis=True, weights=False, flags=True) + gen = ms.create_block_generator(vis=True, weights=True, flags=True) gen_response = None while True: try: @@ -56,8 +54,23 @@ def mock_calibrator_source_models(tmp_path): return ms -def test_dirty_imaging(mock_calibrator_source_models): - ms = mock_calibrator_source_models +@pytest.mark.parametrize("full_stokes", [True, False]) +@pytest.mark.parametrize("num_chan", [1, 2]) +def test_dirty_imaging(tmp_path, full_stokes, num_chan): + ms = build_calibrator_source_models('dsa2000W_small', tmp_path, full_stokes, num_chan) + + imagor = Imagor( + plot_folder='plots', + field_of_view=2 * au.deg + ) + image_model = imagor.image(image_name='test_dirty', ms=ms, overwrite=True) + # print(image_model) + + +@pytest.mark.parametrize("full_stokes", [False]) +@pytest.mark.parametrize("num_chan", [40]) +def test_demo(tmp_path, full_stokes, num_chan): + ms = build_calibrator_source_models('dsa2000W', tmp_path, full_stokes, num_chan) imagor = Imagor( plot_folder='plots',