Skip to content

Commit

Permalink
* add serialiser for parametric RFI emitter
Browse files Browse the repository at this point in the history
  • Loading branch information
Joshuaalbert committed Sep 17, 2024
1 parent f07cd80 commit e46449c
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 38 deletions.
22 changes: 12 additions & 10 deletions dsa2000_cal/dsa2000_cal/calibration/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,17 +191,18 @@ def calibrate(self, ms: MeasurementSet) -> MeasurementSet:
time_idx=mp_policy.cast_to_index(visibility_coords.time_idx)
) # [num_row, ...]

solutions, residual, state, diagnostics = block_until_ready(
self._solve_jax(
key=solve_key,
freqs=tree_device_put(freqs_jax, mesh, ('chan',)),
times=tree_device_put(times_jax, mesh, ()),
init_state=last_state, # already shard as prior output
vis_data=tree_device_put(vis_data, mesh, (None, 'chan')),
vis_coords=tree_device_put(visibility_coords, mesh, ()),
num_iterations=num_iterations
with jax.profiler.trace("/tmp/profiler", create_perfetto_link=True):
solutions, residual, state, diagnostics = block_until_ready(
self._solve_jax(
key=solve_key,
freqs=tree_device_put(freqs_jax, mesh, ('chan',)),
times=tree_device_put(times_jax, mesh, ()),
init_state=last_state, # already shard as prior output
vis_data=tree_device_put(vis_data, mesh, (None, 'chan')),
vis_coords=tree_device_put(visibility_coords, mesh, ()),
num_iterations=num_iterations
)
)
)

cadence_idx += 1
# Update metrics
Expand Down Expand Up @@ -352,6 +353,7 @@ def residual_fn(params: List[Any]) -> Any:
state, diagnostics = solver.solve(state)

# Predict at full resolution.
@jax.jit
@partial(
multi_vmap,
in_mapping="[c],[t],[t,r],[t,r,c]",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dsa2000_cal.common.serialise_utils import SerialisableBaseModel
from dsa2000_cal.delay_models.far_field import VisibilityCoords
from dsa2000_cal.measurement_sets.measurement_set import VisibilityData, MeasurementSet
from dsa2000_cal.visibility_model.source_models.rfi.parametric_rfi_emitter import ParametricDelayACF
from dsa2000_cal.visibility_model.source_models.rfi.rfi_emitter_source_model import RFIEmitterModelData, \
RFIEmitterPredict

Expand Down Expand Up @@ -107,7 +108,6 @@ def save_solution(self, solution: Any, file_name: str, times: at.Time, ms: Measu
freqs=ms.meta.freqs,
position_enu=solution.position_enu * au.m,
array_location=ms.meta.array_location,
luminosity=(solution.luminosity * (au.Jy * au.m ** 2)).to('Jy*km^2'),
delay_acf=solution.delay_acf,
antennas=ms.meta.antennas,
antenna_labels=ms.meta.antenna_names,
Expand All @@ -127,8 +127,7 @@ class RFIEmitterSolutions(SerialisableBaseModel):
freqs: au.Quantity # [num_chans]
position_enu: au.Quantity # [E, 3]
array_location: ac.EarthLocation
luminosity: au.Quantity # [E, num_chans[,2,2]]
delay_acf: InterpolatedArray # [E]
delay_acf: InterpolatedArray | ParametricDelayACF # [E]

antennas: ac.EarthLocation # [ant]
antenna_labels: List[str] # [ant]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from dsa2000_cal.common.types import mp_policy
from dsa2000_cal.gain_models.beam_gain_model import BeamGainModel
from dsa2000_cal.geodesics.geodesic_model import GeodesicModel
from dsa2000_cal.visibility_model.source_models.rfi.rfi_emitter_source_model import RFIEmitterModelData
from dsa2000_cal.visibility_model.source_models.rfi.parametric_rfi_emitter import ParametricDelayACF
from dsa2000_cal.visibility_model.source_models.rfi.rfi_emitter_source_model import RFIEmitterModelData

tfpd = tfp.distributions

Expand Down Expand Up @@ -101,7 +101,7 @@ def get_source_enu(self):
source_positions_enu = jnp.stack([east, north, up], axis=-1) # [E, 3]
return source_positions_enu

def get_acf(self):
def get_acf(self, freqs: jax.Array):
max_delay = 1e-5 # seconds
delay_acf_x = jnp.linspace(0., max_delay, self.acf_resolution)
delay_acf_x = jnp.concatenate([-delay_acf_x[::-1], delay_acf_x[1:]])
Expand All @@ -122,15 +122,7 @@ def get_acf(self):
delay_acf_values = jax.lax.complex(delay_acf_values_real, delay_acf_values_imag)
delay_acf_values /= delay_acf_values[0:1, :] # normalise central value to 1
delay_acf_values = jnp.concatenate([delay_acf_values[::-1], delay_acf_values[1:]], axis=0)
delay_acf = InterpolatedArray(
x=delay_acf_x,
values=delay_acf_values,
axis=0,
regular_grid=True
) # [ E]
return delay_acf

def get_spectral_power(self, freqs: jax.Array):
if self.full_stokes:
luminosity = yield Prior(
tfpd.Uniform(
Expand All @@ -139,7 +131,8 @@ def get_spectral_power(self, freqs: jax.Array):
),
name='luminosity'
).parametrised()
luminosity = jnp.tile(luminosity[:, None, :, :], (1, len(freqs), 1, 1)) # [num_source, num_chan, 2, 2]
luminosity = jnp.tile(luminosity[:, None, :, :], (1, len(freqs), 1, 1)) # [e, num_chan, 2, 2]
delay_acf_values = delay_acf_values[:, :, None, None, None] * luminosity # [num_delays, e, num_chan, 2, 2]
else:
luminosity = yield Prior(
tfpd.Uniform(
Expand All @@ -148,14 +141,22 @@ def get_spectral_power(self, freqs: jax.Array):
),
name='luminosity'
).parametrised()
luminosity = jnp.tile(luminosity[:, None], (1, len(freqs))) # [num_source, num_chan]
return luminosity
luminosity = jnp.tile(luminosity[:, None], (1, len(freqs))) # [e, num_chan]

delay_acf_values = delay_acf_values[:, :, None] * luminosity # [num_delays, e, num_chan]

delay_acf = InterpolatedArray(
x=delay_acf_x,
values=delay_acf_values,
axis=0,
regular_grid=True
) # [ E]
return delay_acf

def build_prior_model(self, freqs: jax.Array, times: jax.Array) -> PriorModelType:
def prior_model():
source_positions_enu = yield from self.get_source_enu()
delay_acf = yield from self.get_acf()
luminosity = yield from self.get_spectral_power(freqs)
delay_acf = yield from self.get_acf(freqs)

geodesics = self.geodesic_model.compute_near_field_geodesics(
times=times,
Expand All @@ -169,17 +170,13 @@ def prior_model():
return RFIEmitterModelData(
freqs=freqs,
position_enu=source_positions_enu,
luminosity=luminosity,
delay_acf=delay_acf,
gains=gains
)

return prior_model





@dataclasses.dataclass(eq=False)
class ParametricRFIHorizonEmitter(FullyParameterisedRFIHorizonEmitter):
"""
Expand All @@ -197,7 +194,11 @@ def __post_init__(self):
super().__post_init__()
self.fwhm_high = self.channel_width

def get_acf(self):
def get_acf(self, freqs: jax.Array):
chan_width = quantity_to_jnp(self.channel_width)
# chan_width = freqs[1] - freqs[0]
chan_lower = freqs - chan_width / 2
chan_upper = freqs + chan_width / 2
ones = jnp.ones((self.num_emitters,), dtype=mp_policy.freq_dtype)
mu = yield Prior(
tfpd.Uniform(
Expand Down Expand Up @@ -242,8 +243,8 @@ def get_acf(self):
else:
spectral_power = yield Prior(
tfpd.Uniform(
low=quantity_to_jnp(self.min_channel_power/self.channel_width, 'Jy*m^2/Hz') * ones,
high=quantity_to_jnp(self.max_channel_power/self.channel_width, 'Jy*m^2/Hz') * ones
low=quantity_to_jnp(self.min_channel_power / self.channel_width, 'Jy*m^2/Hz') * ones,
high=quantity_to_jnp(self.max_channel_power / self.channel_width, 'Jy*m^2/Hz') * ones
),
name='spectral_power'
).parametrised()
Expand All @@ -252,6 +253,8 @@ def get_acf(self):
mu=mu,
fwhp=fwhp,
spectral_power=spectral_power,
channel_lower=chan_lower,
channel_upper=chan_upper,
resolution=self.acf_resolution,
convention=self.convention
)
29 changes: 29 additions & 0 deletions dsa2000_cal/dsa2000_cal/common/serialise_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tomographic_kernel.frames import ENU

from dsa2000_cal.common.interp_utils import InterpolatedArray
from dsa2000_cal.visibility_model.source_models.rfi.parametric_rfi_emitter import ParametricDelayACF

C = TypeVar('C')

Expand Down Expand Up @@ -92,6 +93,18 @@ def deserialise_interpolated_array(obj):
)


def deserialise_parametric_delay_acf(obj):
return ParametricDelayACF(
mu=np.asarray(deserialise_ndarray(obj["mu"])),
fwhp=np.asarray(deserialise_ndarray(obj["fwhp"])),
spectral_power=np.asarray(deserialise_ndarray(obj["spectral_power"])),
channel_lower=np.asarray(deserialise_ndarray(obj["channel_lower"])),
channel_upper=np.asarray(deserialise_ndarray(obj["channel_upper"])),
resolution=obj["resolution"],
convention=obj["convention"]
)


class SerialisableBaseModel(BaseModel):
"""
A pydantic BaseModel that can be serialised and deserialised using pickle, working well with Ray.
Expand Down Expand Up @@ -154,6 +167,16 @@ class Config:
"values": np.asarray(x.values),
"axis": x.axis,
"regular_grid": x.regular_grid
},
ParametricDelayACF: lambda x: {
"type": 'dsa2000_cal.visibility_model.source_models.rfi_parametric_rfi_emitter.ParametricDelayACF',
"mu": np.asarray(x.mu),
"fwhp": np.asarray(x.fwhp),
"spectral_power": np.asarray(x.spectral_power),
"channel_lower": np.asarray(x.channel_lower),
"channel_upper": np.asarray(x.channel_upper),
"resolution": x.resolution,
"convention": x.convention
}
}

Expand Down Expand Up @@ -222,6 +245,12 @@ def parse_obj(cls: Type[C], obj: Dict[str, Any]) -> C:
obj[name] = deserialise_interpolated_array(obj[name])
continue

# Deserialise ParametricDelayACF
elif field.type_ is ParametricDelayACF and isinstance(obj.get(name), dict) and obj[name].get(
"type") == 'dsa2000_cal.visibility_model.source_models.rfi_parametric_rfi_emitter.ParametricDelayACF':
obj[name] = deserialise_parametric_delay_acf(obj[name])
continue

# Deserialise nested models
elif inspect.isclass(field.type_) and issubclass(field.type_, BaseModel):
obj[name] = field.type_.parse_obj(obj[name])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import ujson
from tomographic_kernel.frames import ENU

from dsa2000_cal.common.serialise_utils import SerialisableBaseModel
from dsa2000_cal.common.interp_utils import InterpolatedArray
from dsa2000_cal.common.serialise_utils import SerialisableBaseModel
from dsa2000_cal.visibility_model.source_models.rfi.parametric_rfi_emitter import ParametricDelayACF


class MockModelInt(SerialisableBaseModel):
Expand Down Expand Up @@ -385,3 +386,35 @@ class Model(SerialisableBaseModel):
np.testing.assert_allclose(deserialised.x.values, original.x.values)
np.testing.assert_allclose(deserialised.x.axis, original.x.axis)
np.testing.assert_allclose(deserialised.x.regular_grid, original.x.regular_grid)


def test_parametric_delay_acf():
acf = ParametricDelayACF(
mu=np.asarray([1.0]),
fwhp=np.asarray([1.0]),
spectral_power=np.asarray([1.0]),
channel_lower=np.asarray([1.0]),
channel_upper=np.asarray([1.0]),
resolution=1,
convention='physical'
)

class Model(SerialisableBaseModel):
acf: ParametricDelayACF

original = Model(acf=acf)

# Serialise the object to JSON
serialised = original.json(indent=2)
print(serialised)

# Deserialise the object from JSON
deserialised = Model.parse_raw(serialised)

np.testing.assert_allclose(deserialised.acf.mu, original.acf.mu)
np.testing.assert_allclose(deserialised.acf.fwhp, original.acf.fwhp)
np.testing.assert_allclose(deserialised.acf.spectral_power, original.acf.spectral_power)
np.testing.assert_allclose(deserialised.acf.channel_lower, original.acf.channel_lower)
np.testing.assert_allclose(deserialised.acf.channel_upper, original.acf.channel_upper)
np.testing.assert_allclose(deserialised.acf.resolution, original.acf.resolution)
assert deserialised.acf.convention == original.acf.convention
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ class ParametricDelayACF:
mu: FloatArray # [E]
fwhp: FloatArray # [E]
spectral_power: FloatArray # [E[,2,2]] in Jy*m^2/Hz
channel_lower: jax.Array # [chan]
channel_upper: jax.Array # [chan]
channel_lower: FloatArray # [chan]
channel_upper: FloatArray # [chan]
resolution: int = 32 # Should be chosen so that channel width / resolution ~ PFB kernel resolution
convention: str = 'physical' # Doesn't matter for the ACF

Expand Down

0 comments on commit e46449c

Please sign in to comment.