Skip to content

Commit

Permalink
* Update flows
Browse files Browse the repository at this point in the history
  • Loading branch information
Joshuaalbert committed Aug 22, 2024
1 parent ddb0d2d commit b7d69b1
Show file tree
Hide file tree
Showing 8 changed files with 295 additions and 2 deletions.
30 changes: 30 additions & 0 deletions .github/ISSUE_TEMPLATE/bug_report.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: bug
assignees: ''

---

**Describe the bug**
Some background.

**Expected behavior**
What did you expect?

**Observed behavior**
What did you actually see?

**Minimal Verifiable Complete Example**
Ideally you provide some short code that demonstrates the bug. When this is really impossible, give some exact step-by-step instructions.

```python
#code goes here
```

**Screenshots**
If applicable, add screenshots to help explain your problem.

**JAXNS version**
Output of `pip freeze`:
20 changes: 20 additions & 0 deletions .github/ISSUE_TEMPLATE/feature_request.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
---
name: Feature request
about: Use this to request a feature
title: ''
labels: enhancement
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.
42 changes: 42 additions & 0 deletions .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Python package

on:
pull_request:
branches: [ "main", "develop" ]
push:
branches: [ "main" ]

jobs:
build:

runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [ "3.11" ]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
cd dsa2000_cal
pip install -r requirements.txt
pip install -r requirements-tests.txt
pip install .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
70 changes: 70 additions & 0 deletions debug/wterm_poly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import jax.numpy as jnp
import numpy
import pylab as plt

def wterm(l,m,w):
n = jnp.square(1 - jnp.square(l) - jnp.square(m))
return jnp.exp(-2j * jnp.pi * w * n)/n

def break_up_ffts():
import numpy as np
import matplotlib.pyplot as plt

# Function to shift the FFT of a sub-image
def shift_fft(sub_fft, shift):
n, m = sub_fft.shape
u = np.fft.fftfreq(n).reshape(-1, 1)
v = np.fft.fftfreq(m).reshape(1, -1)
shift_matrix = np.exp(-2j * np.pi * (shift[0] * u + shift[1] * v))
return sub_fft * shift_matrix

# Create a sample image of size [2n, 2n]
n = 4
full_image = np.random.random((2 * n, 2 * n))

# Divide the image into four sub-images
I1 = full_image[0:n, 0:n]
I2 = full_image[0:n, n:2 * n]
I3 = full_image[n:2 * n, 0:n]
I4 = full_image[n:2 * n, n:2 * n]

# Compute the FFT of each sub-image
F1 = np.fft.fft2(I1)
F2 = np.fft.fft2(I2)
F3 = np.fft.fft2(I3)
F4 = np.fft.fft2(I4)

# Adjust the FFTs with shifts
F2_shifted = shift_fft(F2, (0, n))
F3_shifted = shift_fft(F3, (n, 0))
F4_shifted = shift_fft(F4, (n, n))

# Combine the FFTs into a full FFT array
combined_fft = np.zeros((2 * n, 2 * n), dtype=complex)
combined_fft[0:n, 0:n] = F1
combined_fft[0:n, n:2 * n] = F2_shifted
combined_fft[n:2 * n, 0:n] = F3_shifted
combined_fft[n:2 * n, n:2 * n] = F4_shifted

# Compute the FFT of the full image
full_fft = np.fft.fft2(full_image)

# Compare the results
print("Are the two FFTs identical? ", np.allclose(full_fft, combined_fft))

# For visualization purposes
plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.title("FFT of Full Image")
plt.imshow(np.log(np.abs(np.fft.fftshift(full_fft)) + 1), cmap='gray')

plt.subplot(1, 2, 2)
plt.title("Combined FFT of Sub-Images")
plt.imshow(np.log(np.abs(np.fft.fftshift(combined_fft)) + 1), cmap='gray')

plt.show()


if __name__ == '__main__':
break_up_ffts()
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import dataclasses

import haiku as hk
import jax
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp
from jaxns import Model
from jaxns.framework import ops

from dsa2000_cal.calibration.probabilistic_models.gain_prior_models import AbstractGainPriorModel
from dsa2000_cal.calibration.probabilistic_models.probabilistic_model import AbstractProbabilisticModel, \
ProbabilisticModelInstance
from dsa2000_cal.common.jax_utils import promote_pytree
from dsa2000_cal.delay_models.far_field import VisibilityCoords
from dsa2000_cal.measurement_sets.measurement_set import VisibilityData
from dsa2000_cal.visibility_model.rime_model import RIMEModel

tfpd = tfp.distributions


@dataclasses.dataclass(eq=False)
class HorizonRFIModel(AbstractProbabilisticModel):
rime_model: RIMEModel
gain_prior_model: AbstractGainPriorModel

def create_model_instance(self, freqs: jax.Array,
times: jax.Array,
vis_data: VisibilityData,
vis_coords: VisibilityCoords
) -> ProbabilisticModelInstance:
model_data = self.rime_model.get_model_data(
times=times
) # [facets]

# TODO: explore using checkpointing
vis = self.rime_model.predict_visibilities(
model_data=model_data,
visibility_coords=vis_coords
) # [num_cal, num_row, num_chan[, 2, 2]]

# vis = jax.lax.with_sharding_constraint(vis, NamedSharding(mesh, P(None, None, 'chan')))'
# TODO: https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#fsdp-tp-with-shard-map-at-the-top-level

# vis now contains the model visibilities for each calibrator
def prior_model():
gain_prior_model = self.gain_prior_model.build_prior_model(
num_source=self.rime_model.num_facets,
num_ant=self.rime_model.num_antennas,
freqs=freqs,
times=times
)
gains: jax.Array = yield from gain_prior_model() # [num_source, num_ant, num_chan[, 2, 2]]
visibilities = self.rime_model.apply_gains(
gains=gains,
vis=vis,
visibility_coords=vis_coords
) # [num_row, num_chan[, 2, 2]]
return visibilities, gains

def log_likelihood(vis_model: jax.Array, gains: jax.Array):
"""
Compute the log probability of the data given the gains.
Args:
vis_model: [num_rows, num_chan, 4]
gains: [num_source, num_ant, num_chan, 2, 2]
Returns:
log_prob: scalar
"""

vis_variance = 1. / vis_data.weights # Should probably use measurement set SIGMA here
vis_stddev = jnp.sqrt(vis_variance)
obs_dist_real = tfpd.Normal(
*promote_pytree('vis_real', (jnp.real(vis_data.vis), vis_stddev))
)
obs_dist_imag = tfpd.Normal(
*promote_pytree('vis_imag', (jnp.imag(vis_data.vis), vis_stddev))
)
log_prob_real = obs_dist_real.log_prob(jnp.real(vis_model))
log_prob_imag = obs_dist_imag.log_prob(jnp.imag(vis_model)) # [num_rows, num_chan, 4]
log_prob = log_prob_real + log_prob_imag # [num_rows, num_chan, 4]

# Mask out flagged data or zero-weighted data.
mask = jnp.logical_or(vis_data.weights == 0, vis_data.flags)
log_prob = jnp.where(mask, 0., log_prob)
return jnp.sum(log_prob)

model = Model(
prior_model=prior_model,
log_likelihood=log_likelihood
)

def get_init_params():
# Could use model.sample_W() for a random start
return model._W_placeholder()

def forward(params):
def _forward():
# Use jaxns.framework.ops to transform the params into the args for likelihood
return ops.prepare_input(W=params, prior_model=prior_model)

return hk.transform(_forward).apply(
params=model.params, rng=jax.random.PRNGKey(0)
)

def log_prob_joint(params):
def _log_prob_joint():
# Use jaxns.framework.ops to compute the log prob of joint
log_prob_prior = ops.compute_log_prob_prior(
W=params,
prior_model=model.prior_model
)
log_prob_likelihood = ops.compute_log_likelihood(
W=params,
prior_model=model.prior_model,
log_likelihood=model.log_likelihood,
allow_nan=False
)
return log_prob_prior + log_prob_likelihood

return hk.transform(_log_prob_joint).apply(
params=model.params, rng=jax.random.PRNGKey(0)
)

return ProbabilisticModelInstance(
get_init_params_fn=get_init_params,
forward_fn=forward,
log_prob_joint_fn=log_prob_joint
)
2 changes: 1 addition & 1 deletion dsa2000_cal/dsa2000_cal/common/interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def multilinear_interp_2d(x, y, xp, yp, z):
return z_interp


def apply_interp(x: jax.Array, i0: jax.Array, alpha0: jax.Array, i1: jax.Array, alpha1: jax, axis: int = 0):
def apply_interp(x: jax.Array, i0: jax.Array, alpha0: jax.Array, i1: jax.Array, alpha1: jax.Array, axis: int = 0):
"""
Apply interpolation alpha given axis.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ def test_create_sky_model():
assert len(rfi_emitter_sources) == 1
a_team_sources = synthetic_sky_model_producer.create_a_team_sources(a_team_sources=['cas_a'])
a_team_sources[0].plot(save_file='cas_a.png')
assert len(a_team_sources) == 1
assert len(a_team_sources) == 1
1 change: 1 addition & 0 deletions dsa2000_cal/requirements-tests.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
flake8
pytest
pytest-asyncio

0 comments on commit b7d69b1

Please sign in to comment.