-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ddb0d2d
commit b7d69b1
Showing
8 changed files
with
295 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
--- | ||
name: Bug report | ||
about: Create a report to help us improve | ||
title: '' | ||
labels: bug | ||
assignees: '' | ||
|
||
--- | ||
|
||
**Describe the bug** | ||
Some background. | ||
|
||
**Expected behavior** | ||
What did you expect? | ||
|
||
**Observed behavior** | ||
What did you actually see? | ||
|
||
**Minimal Verifiable Complete Example** | ||
Ideally you provide some short code that demonstrates the bug. When this is really impossible, give some exact step-by-step instructions. | ||
|
||
```python | ||
#code goes here | ||
``` | ||
|
||
**Screenshots** | ||
If applicable, add screenshots to help explain your problem. | ||
|
||
**JAXNS version** | ||
Output of `pip freeze`: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
--- | ||
name: Feature request | ||
about: Use this to request a feature | ||
title: '' | ||
labels: enhancement | ||
assignees: '' | ||
|
||
--- | ||
|
||
**Is your feature request related to a problem? Please describe.** | ||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] | ||
|
||
**Describe the solution you'd like** | ||
A clear and concise description of what you want to happen. | ||
|
||
**Describe alternatives you've considered** | ||
A clear and concise description of any alternative solutions or features you've considered. | ||
|
||
**Additional context** | ||
Add any other context or screenshots about the feature request here. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions | ||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions | ||
|
||
name: Python package | ||
|
||
on: | ||
pull_request: | ||
branches: [ "main", "develop" ] | ||
push: | ||
branches: [ "main" ] | ||
|
||
jobs: | ||
build: | ||
|
||
runs-on: ubuntu-latest | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
python-version: [ "3.11" ] | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
cd dsa2000_cal | ||
pip install -r requirements.txt | ||
pip install -r requirements-tests.txt | ||
pip install . | ||
- name: Lint with flake8 | ||
run: | | ||
# stop the build if there are Python syntax errors or undefined names | ||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics | ||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide | ||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics | ||
- name: Test with pytest | ||
run: | | ||
pytest |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import jax.numpy as jnp | ||
import numpy | ||
import pylab as plt | ||
|
||
def wterm(l,m,w): | ||
n = jnp.square(1 - jnp.square(l) - jnp.square(m)) | ||
return jnp.exp(-2j * jnp.pi * w * n)/n | ||
|
||
def break_up_ffts(): | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
# Function to shift the FFT of a sub-image | ||
def shift_fft(sub_fft, shift): | ||
n, m = sub_fft.shape | ||
u = np.fft.fftfreq(n).reshape(-1, 1) | ||
v = np.fft.fftfreq(m).reshape(1, -1) | ||
shift_matrix = np.exp(-2j * np.pi * (shift[0] * u + shift[1] * v)) | ||
return sub_fft * shift_matrix | ||
|
||
# Create a sample image of size [2n, 2n] | ||
n = 4 | ||
full_image = np.random.random((2 * n, 2 * n)) | ||
|
||
# Divide the image into four sub-images | ||
I1 = full_image[0:n, 0:n] | ||
I2 = full_image[0:n, n:2 * n] | ||
I3 = full_image[n:2 * n, 0:n] | ||
I4 = full_image[n:2 * n, n:2 * n] | ||
|
||
# Compute the FFT of each sub-image | ||
F1 = np.fft.fft2(I1) | ||
F2 = np.fft.fft2(I2) | ||
F3 = np.fft.fft2(I3) | ||
F4 = np.fft.fft2(I4) | ||
|
||
# Adjust the FFTs with shifts | ||
F2_shifted = shift_fft(F2, (0, n)) | ||
F3_shifted = shift_fft(F3, (n, 0)) | ||
F4_shifted = shift_fft(F4, (n, n)) | ||
|
||
# Combine the FFTs into a full FFT array | ||
combined_fft = np.zeros((2 * n, 2 * n), dtype=complex) | ||
combined_fft[0:n, 0:n] = F1 | ||
combined_fft[0:n, n:2 * n] = F2_shifted | ||
combined_fft[n:2 * n, 0:n] = F3_shifted | ||
combined_fft[n:2 * n, n:2 * n] = F4_shifted | ||
|
||
# Compute the FFT of the full image | ||
full_fft = np.fft.fft2(full_image) | ||
|
||
# Compare the results | ||
print("Are the two FFTs identical? ", np.allclose(full_fft, combined_fft)) | ||
|
||
# For visualization purposes | ||
plt.figure(figsize=(10, 5)) | ||
|
||
plt.subplot(1, 2, 1) | ||
plt.title("FFT of Full Image") | ||
plt.imshow(np.log(np.abs(np.fft.fftshift(full_fft)) + 1), cmap='gray') | ||
|
||
plt.subplot(1, 2, 2) | ||
plt.title("Combined FFT of Sub-Images") | ||
plt.imshow(np.log(np.abs(np.fft.fftshift(combined_fft)) + 1), cmap='gray') | ||
|
||
plt.show() | ||
|
||
|
||
if __name__ == '__main__': | ||
break_up_ffts() |
130 changes: 130 additions & 0 deletions
130
dsa2000_cal/dsa2000_cal/calibration/probabilistic_models/horizon_rfi_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import dataclasses | ||
|
||
import haiku as hk | ||
import jax | ||
import tensorflow_probability.substrates.jax as tfp | ||
from jax import numpy as jnp | ||
from jaxns import Model | ||
from jaxns.framework import ops | ||
|
||
from dsa2000_cal.calibration.probabilistic_models.gain_prior_models import AbstractGainPriorModel | ||
from dsa2000_cal.calibration.probabilistic_models.probabilistic_model import AbstractProbabilisticModel, \ | ||
ProbabilisticModelInstance | ||
from dsa2000_cal.common.jax_utils import promote_pytree | ||
from dsa2000_cal.delay_models.far_field import VisibilityCoords | ||
from dsa2000_cal.measurement_sets.measurement_set import VisibilityData | ||
from dsa2000_cal.visibility_model.rime_model import RIMEModel | ||
|
||
tfpd = tfp.distributions | ||
|
||
|
||
@dataclasses.dataclass(eq=False) | ||
class HorizonRFIModel(AbstractProbabilisticModel): | ||
rime_model: RIMEModel | ||
gain_prior_model: AbstractGainPriorModel | ||
|
||
def create_model_instance(self, freqs: jax.Array, | ||
times: jax.Array, | ||
vis_data: VisibilityData, | ||
vis_coords: VisibilityCoords | ||
) -> ProbabilisticModelInstance: | ||
model_data = self.rime_model.get_model_data( | ||
times=times | ||
) # [facets] | ||
|
||
# TODO: explore using checkpointing | ||
vis = self.rime_model.predict_visibilities( | ||
model_data=model_data, | ||
visibility_coords=vis_coords | ||
) # [num_cal, num_row, num_chan[, 2, 2]] | ||
|
||
# vis = jax.lax.with_sharding_constraint(vis, NamedSharding(mesh, P(None, None, 'chan')))' | ||
# TODO: https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#fsdp-tp-with-shard-map-at-the-top-level | ||
|
||
# vis now contains the model visibilities for each calibrator | ||
def prior_model(): | ||
gain_prior_model = self.gain_prior_model.build_prior_model( | ||
num_source=self.rime_model.num_facets, | ||
num_ant=self.rime_model.num_antennas, | ||
freqs=freqs, | ||
times=times | ||
) | ||
gains: jax.Array = yield from gain_prior_model() # [num_source, num_ant, num_chan[, 2, 2]] | ||
visibilities = self.rime_model.apply_gains( | ||
gains=gains, | ||
vis=vis, | ||
visibility_coords=vis_coords | ||
) # [num_row, num_chan[, 2, 2]] | ||
return visibilities, gains | ||
|
||
def log_likelihood(vis_model: jax.Array, gains: jax.Array): | ||
""" | ||
Compute the log probability of the data given the gains. | ||
Args: | ||
vis_model: [num_rows, num_chan, 4] | ||
gains: [num_source, num_ant, num_chan, 2, 2] | ||
Returns: | ||
log_prob: scalar | ||
""" | ||
|
||
vis_variance = 1. / vis_data.weights # Should probably use measurement set SIGMA here | ||
vis_stddev = jnp.sqrt(vis_variance) | ||
obs_dist_real = tfpd.Normal( | ||
*promote_pytree('vis_real', (jnp.real(vis_data.vis), vis_stddev)) | ||
) | ||
obs_dist_imag = tfpd.Normal( | ||
*promote_pytree('vis_imag', (jnp.imag(vis_data.vis), vis_stddev)) | ||
) | ||
log_prob_real = obs_dist_real.log_prob(jnp.real(vis_model)) | ||
log_prob_imag = obs_dist_imag.log_prob(jnp.imag(vis_model)) # [num_rows, num_chan, 4] | ||
log_prob = log_prob_real + log_prob_imag # [num_rows, num_chan, 4] | ||
|
||
# Mask out flagged data or zero-weighted data. | ||
mask = jnp.logical_or(vis_data.weights == 0, vis_data.flags) | ||
log_prob = jnp.where(mask, 0., log_prob) | ||
return jnp.sum(log_prob) | ||
|
||
model = Model( | ||
prior_model=prior_model, | ||
log_likelihood=log_likelihood | ||
) | ||
|
||
def get_init_params(): | ||
# Could use model.sample_W() for a random start | ||
return model._W_placeholder() | ||
|
||
def forward(params): | ||
def _forward(): | ||
# Use jaxns.framework.ops to transform the params into the args for likelihood | ||
return ops.prepare_input(W=params, prior_model=prior_model) | ||
|
||
return hk.transform(_forward).apply( | ||
params=model.params, rng=jax.random.PRNGKey(0) | ||
) | ||
|
||
def log_prob_joint(params): | ||
def _log_prob_joint(): | ||
# Use jaxns.framework.ops to compute the log prob of joint | ||
log_prob_prior = ops.compute_log_prob_prior( | ||
W=params, | ||
prior_model=model.prior_model | ||
) | ||
log_prob_likelihood = ops.compute_log_likelihood( | ||
W=params, | ||
prior_model=model.prior_model, | ||
log_likelihood=model.log_likelihood, | ||
allow_nan=False | ||
) | ||
return log_prob_prior + log_prob_likelihood | ||
|
||
return hk.transform(_log_prob_joint).apply( | ||
params=model.params, rng=jax.random.PRNGKey(0) | ||
) | ||
|
||
return ProbabilisticModelInstance( | ||
get_init_params_fn=get_init_params, | ||
forward_fn=forward, | ||
log_prob_joint_fn=log_prob_joint | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
flake8 | ||
pytest | ||
pytest-asyncio |