Skip to content

Commit

Permalink
* ufunc wgridder predict with threading for performance gain
Browse files Browse the repository at this point in the history
* performance tes calibration
* I,Q,U,V imaging
* refactor forward model
* more mixed precision
* many small improvements
  • Loading branch information
Joshuaalbert committed Sep 17, 2024
1 parent f4b4195 commit f07cd80
Show file tree
Hide file tree
Showing 31 changed files with 1,247 additions and 714 deletions.
90 changes: 81 additions & 9 deletions debug/vmap_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,89 @@
import jax.numpy as jnp


def add(x, y):
print(x.shape, y.shape)
return x + y
@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
def add_vmapped(x, y, z):
return x + y + z

@jax.jit
@partial(jax.vmap, in_axes=(0, None))
def cb(x, y):
return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, vectorized=False)

@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
def cb_no_vec(x, y, z):
def add(x, y, z):
assert x.shape == ()
assert y.shape == ()
assert z.shape == ()
return x + y + z

return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z, vectorized=False)


def convert_to_ufunc(f, tile: bool = True):
f = jax.custom_batching.custom_vmap(f)

@f.def_vmap
def rule(axis_size, in_batched, *args):
batched_args = jax.tree.map(
lambda x, b: x if b else jax.lax.broadcast(x, ((axis_size if tile else 1),)), args,
tuple(in_batched))
out = f(*batched_args)
out_batched = jax.tree.map(lambda _: True, out)
return out, out_batched

return f


@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
@convert_to_ufunc
def cb_vec_tiled(x, y, z):
def add(x, y, z):
assert x.shape == (4, 5)
assert y.shape == (4, 5)
assert z.shape == (4, 5)
return x + y + z

return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z, vectorized=True)


@partial(jax.vmap, in_axes=(0, None, None))
@partial(jax.vmap, in_axes=(None, 0, None))
@partial(convert_to_ufunc, tile=False)
def cb_vec_untiled(x, y, z):
def add(x, y, z):
assert x.shape == (4, 1)
assert y.shape == (1, 5)
assert z.shape == (1, 1)
return x + y + z

return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=jnp.broadcast_shapes(x.shape, y.shape), dtype=x.dtype), x,
y, z, vectorized=True)




def cb(x, y, z):
def add(x, y, z):
assert x.shape == (4, 5)
assert y.shape == (4, 5)
assert z.shape == ()
return x + y + z

return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=jnp.broadcast_shapes(x.shape, y.shape), dtype=x.dtype), x,
y, z, vectorized=True)


if __name__ == '__main__':
x = jnp.arange(4, dtype=jnp.float32)
y = jnp.asarray(1.)
print(cb(x, y))
y = jnp.arange(5, dtype=jnp.float32)
z = jnp.array(1, dtype=jnp.float32)

assert add_vmapped(x, y, z).shape == (4, 5)
assert cb_no_vec(x, y, z).shape == (4, 5)
assert cb_vec_tiled(x, y, z).shape == (4, 5)
assert cb_vec_untiled(x, y, z).shape == (4, 5)

assert jax.vmap(jax.vmap(convert_to_ufunc(partial(cb, z=z)), in_axes=(None, 0)), in_axes=(0, None))(x, y).shape == (4, 5)


26 changes: 16 additions & 10 deletions dsa2000_cal/dsa2000_cal/assets/rfi/lte_rfi/dsa_cell_tower.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def plot_acf(self):
mat = loadmat(self.rfi_injection_model())
delays = jnp.asarray(mat['t_acf'][0])
auto_correlation_function = jnp.asarray(mat['acf']).T
plt.plot(delays, auto_correlation_function)
plt.plot(delays * 1e6, np.abs(auto_correlation_function))
plt.xlabel('Delay [us]')
plt.ylabel('ACF')
plt.show()

def make_source_params(self, freqs: au.Quantity, central_freq: au.Quantity | None = None,
Expand All @@ -31,7 +33,7 @@ def make_source_params(self, freqs: au.Quantity, central_freq: au.Quantity | Non
# E=1
mat = loadmat(self.rfi_injection_model())
delays = jnp.asarray(mat['t_acf'][0])
auto_correlation_function = jnp.asarray(mat['acf']).T # [n_delays, 1]
auto_correlation_function = jnp.asarray(mat['acf']).T # [n_delays, E=1]
if np.allclose(np.diff(delays), delays[1] - delays[0], atol=1e-8):
regular_grid = True
else:
Expand All @@ -44,33 +46,37 @@ def make_source_params(self, freqs: au.Quantity, central_freq: au.Quantity | Non

rfi_band_mask = np.logical_and(freqs >= central_freq - bandwidth / 2, freqs <= central_freq + bandwidth / 2)

nominal_spectral_flux_density = (100 * au.Jy) * (1 * au.km) ** 2 * ((55 * au.MHz) / central_freq)
nominal_spectral_flux_density = (100 * au.Jy) * (1 * au.km) ** 2
spectral_flux_density = rfi_band_mask[None].astype(np.float32) * nominal_spectral_flux_density.to(
'W/MHz') # [1, num_chans]
'Jy*m^2') # [E=1, num_chans]
if full_stokes:
spectral_flux_density = 0.5 * au.Quantity(
np.stack(
[
np.stack([spectral_flux_density, 0 * spectral_flux_density], axis=-1),
np.stack([0 * spectral_flux_density, spectral_flux_density], axis=-1)
],
axis=-1
axis=-2
)
)
) # [E=1, num_chans, 2, 2]
auto_correlation_function = auto_correlation_function[:, :, None, None,
None] * spectral_flux_density # [n_delays, E=1, num_chans, 2, 2]
else:
auto_correlation_function = auto_correlation_function[:, :,
None] * spectral_flux_density # [n_delays, E=1, num_chans]

# ENU coords
# Far field limit would be around
far_field_limit = fraunhofer_far_field_limit(diameter=18. * au.km, freq=central_freq)
far_field_limit = fraunhofer_far_field_limit(diameter=2.7 * au.km, freq=central_freq)
print(f"Far field limit: {far_field_limit} at {central_freq}")
position_enu = au.Quantity([[14e3, 0, 80]], unit='m') # [1, 3]
position_enu = au.Quantity([[1e3, 1e3, 120]], unit='m') # [1, 3]

delay_acf = InterpolatedArray(
x=delays, values=auto_correlation_function, axis=0, regular_grid=regular_grid
) # [E]
) # [E=1,chan[2,2]]

return RFIEmitterSourceModelParams(
freqs=freqs,
position_enu=position_enu,
spectral_flux_density=spectral_flux_density,
delay_acf=delay_acf
)
26 changes: 15 additions & 11 deletions dsa2000_cal/dsa2000_cal/assets/rfi/lte_rfi/lwa_cell_tower.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def plot_acf(self):
mat = loadmat(self.rfi_injection_model())
delays = jnp.asarray(mat['t_acf'][0])
auto_correlation_function = jnp.asarray(mat['acf']).T
plt.plot(delays, auto_correlation_function)
plt.plot(delays * 1e6, np.abs(auto_correlation_function))
plt.xlabel('Delay [us]')
plt.ylabel('ACF')
plt.show()

def make_source_params(self, freqs: au.Quantity, central_freq: au.Quantity | None = None,
Expand All @@ -31,7 +33,7 @@ def make_source_params(self, freqs: au.Quantity, central_freq: au.Quantity | Non
# E=1
mat = loadmat(self.rfi_injection_model())
delays = jnp.asarray(mat['t_acf'][0])
auto_correlation_function = jnp.asarray(mat['acf']).T # [n_delays, 1]
auto_correlation_function = jnp.asarray(mat['acf']).T # [n_delays, E=1]
if np.allclose(np.diff(delays), delays[1] - delays[0], atol=1e-8):
regular_grid = True
else:
Expand All @@ -44,35 +46,37 @@ def make_source_params(self, freqs: au.Quantity, central_freq: au.Quantity | Non

rfi_band_mask = np.logical_and(freqs >= central_freq - bandwidth / 2, freqs <= central_freq + bandwidth / 2)

nominal_spectral_flux_density = (100 * au.Jy) * (1 * au.km) ** 2 * ((55 * au.MHz) / central_freq)
nominal_spectral_flux_density = (100 * au.Jy) * (1 * au.km) ** 2
spectral_flux_density = rfi_band_mask[None].astype(np.float32) * nominal_spectral_flux_density.to(
'W/MHz') # [1, num_chans]
'Jy*m^2') # [E=1, num_chans]
if full_stokes:
spectral_flux_density = 0.5 * au.Quantity(
np.stack(
[
np.stack([spectral_flux_density, 0 * spectral_flux_density], axis=-1),
np.stack([0 * spectral_flux_density, spectral_flux_density], axis=-1)
],
axis=-1
axis=-2
)
)
) # [E=1, num_chans, 2, 2]
auto_correlation_function = auto_correlation_function[:, :, None, None,
None] * spectral_flux_density # [n_delays, E=1, num_chans, 2, 2]
else:
auto_correlation_function = auto_correlation_function[:, :,
None] * spectral_flux_density # [n_delays, E=1, num_chans]

# ENU coords
# Far field limit would be around
far_field_limit = fraunhofer_far_field_limit(diameter=2.7 * au.km, freq=central_freq)
print(f"Far field limit: {far_field_limit} at {central_freq}")
position_enu = au.Quantity([[1e3, 1e3, 1e3]], unit='m') # [1, 3]
position_enu = au.Quantity([[1e3, 1e3, 120]], unit='m') # [1, 3]

delay_acf = InterpolatedArray(
x=delays, values=auto_correlation_function, axis=0, regular_grid=regular_grid
) # [E]
) # [E=1,chan[2,2]]

return RFIEmitterSourceModelParams(
freqs=freqs,
position_enu=position_enu,
spectral_flux_density=spectral_flux_density,
delay_acf=delay_acf
)


Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ def test_lte_rfi_source_factory():
model = MockCellTower(seed='test')
import pylab as plt
source_params = model.make_source_params(freqs=np.linspace(700, 800, 50) * au.MHz)
plt.plot(source_params.freqs, source_params.spectral_flux_density[0])
plt.xlabel('Frequency [MHz]')
plt.ylabel('Luminosity [W/Hz]')
plt.show()
plt.plot(source_params.delay_acf.x, source_params.delay_acf.values[:, 0])
plt.xlabel('Delay [s]')
plt.ylabel('Auto-correlation function')
Expand Down
18 changes: 13 additions & 5 deletions dsa2000_cal/dsa2000_cal/calibration/multi_step_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ class MultiStepLevenbergMarquardt(Generic[X, Y]):
residual_fn: Callable[[X], Y]
num_approx_steps: int = 0
num_iterations: int = 1
more_outputs_than_inputs: bool = False

# Improvement threshold
p_any_improvement: FloatArray = 0.1 # p0 > 0
Expand Down Expand Up @@ -233,7 +232,7 @@ def matvec(v: X) -> X:

return matvec

J_bare = JVPLinearOp(fn=residual_fn, more_outputs_than_inputs=self.more_outputs_than_inputs)
J_bare = JVPLinearOp(fn=residual_fn)

output_dtypes = jax.tree.map(lambda x: x.dtype, state)

Expand Down Expand Up @@ -325,15 +324,24 @@ def body(iteration: int, step: int, state: MultiStepLevenbergMarquardtState, J:
)
return state, diagnostic

diagnostics = []
for iteration in range(self.num_iterations):
@jax.jit
def single_iteration(iteration: jax.Array, state: MultiStepLevenbergMarquardtState):
diagnostics = []

# Does one initial exact step using the current jacobian estimate, followed by inexact steps using the same
# jacobian estimate (which is slightly cheaper).
J = J_bare(state.x)
for step in range(self.num_approx_steps + 1):
state, diagnostic = body(mp_policy.cast_to_index(iteration), mp_policy.cast_to_index(step), state, J)
diagnostics.append(diagnostic)
diagnostics = jax.tree.map(lambda *args: jnp.stack(args), *diagnostics)
diagnostics = jax.tree.map(lambda *args: jnp.stack(args), *diagnostics)
return state, diagnostics

diagnostics = []
for iteration in range(self.num_iterations):
state, diagnostic = single_iteration(mp_policy.cast_to_index(iteration), state)
diagnostics.append(diagnostic)
diagnostics = jax.tree.map(lambda *args: jnp.concatenate(args), *diagnostics)

# Convert back to complex
state = state._replace(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def save_solution(self, solution: Any, file_name: str, times: at.Time, ms: Measu
freqs=ms.meta.freqs,
position_enu=solution.position_enu * au.m,
array_location=ms.meta.array_location,
luminosity=solution.luminosity * (au.W / au.MHz),
luminosity=(solution.luminosity * (au.Jy * au.m ** 2)).to('Jy*km^2'),
delay_acf=solution.delay_acf,
antennas=ms.meta.antennas,
antenna_labels=ms.meta.antenna_names,
Expand Down
4 changes: 2 additions & 2 deletions dsa2000_cal/dsa2000_cal/calibration/tests/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def test_calibration(mock_calibrator_source_models):
calibration = Calibration(
# models to calibrate based on. Each model gets a gain direction in the flux weighted direction.
probabilistic_models=probabilistic_models,
num_iterations=1,
num_approx_steps=0,
num_iterations=2,
num_approx_steps=2,
inplace_subtract=True,
plot_folder='plots',
solution_folder='solutions',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def forward(gains, antenna1, antenna2, vis_per_source):
jax.random.PRNGKey(4), (row, chan), dtype=mp_policy.vis_dtype)

def run(antenna1, antenna2, vis_per_source, data):

def residuals(params):
gains_real, gains_imag = params
gains = mp_policy.cast_to_gain(gains_real + 1j * gains_imag)
Expand Down
7 changes: 0 additions & 7 deletions dsa2000_cal/dsa2000_cal/common/fits_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,6 @@ def _check_image_model(image_model: ImageModel):
if image_model.image.shape[3] != len(image_model.coherencies):
raise ValueError(f"num_coherencies must match image[3] shape, "
f"got {image_model.image.shape[3]} != {len(image_model.coherencies)}")
if image_model.coherencies not in [
['XX', 'XY', 'YX', 'YY'],
['I', 'Q', 'U', 'V'],
['RR', 'RL', 'LR', 'LL'],
['I']
]:
raise ValueError(f"coherencies format {image_model.coherencies} is invalid.")
# Ensure freqs are uniformly spaced
dfreq = np.diff(image_model.freqs.to(au.Hz).value)
if len(dfreq) > 0 and not np.allclose(dfreq, dfreq[0], atol=1e-6):
Expand Down
Loading

0 comments on commit f07cd80

Please sign in to comment.