Skip to content

Commit

Permalink
* Fix some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Joshuaalbert committed Sep 5, 2024
1 parent d14c125 commit b5eb179
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 431 deletions.
156 changes: 1 addition & 155 deletions dsa2000_cal/dsa2000_cal/common/nearest_neighbours.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,160 +187,6 @@ class GridTree3D(NamedTuple):
jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array] # [6] (min_x, max_x, min_y, max_y, min_z, max_z)


@dataclasses.dataclass(eq=False)
class ApproximateTreeNN3D:
"""
Approximate tree for nearest neighbor search on 3D box.
A tree structure is used to find the k nearest neighbors to a given point in 3D space, by constructing a grid of
shape (n_grid, n_grid, n_grid) where,
n_grid = int(cbrt(n / average_points_per_cell)), where n is the number of points.
The memory usage goes as O(n * ( 3 + kappa / cbrt(average_points_per_cell))).
The tree build time goes as O(n).
The tree query time goes as O((average_points_per_cell + kappa * cbrt(average_points_per_cell)) + k log k)
Accuracy generally increases with more points in the tree. Accuracy decreases with larger `k` queries due to cell
edge effects.
Args:
average_points_per_cell: Average number of points per cell in the grid.
kappa: how many sigmas above the expected number of points per cell to allow.
"""
average_points_per_cell: int = 16
kappa: float = 5.0

def _point_to_cell(self, point: jax.Array, n_grid: int,
extent: Tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]) -> Tuple[
jax.Array, jax.Array, jax.Array]:
x_min, x_max, y_min, y_max, z_min, z_max = extent
cell_x = jnp.clip(jnp.floor((point[0] - x_min) / (x_max - x_min) * n_grid), 0, n_grid - 1).astype(int)
cell_y = jnp.clip(jnp.floor((point[1] - y_min) / (y_max - y_min) * n_grid), 0, n_grid - 1).astype(int)
cell_z = jnp.clip(jnp.floor((point[2] - z_min) / (z_max - z_min) * n_grid), 0, n_grid - 1).astype(int)
return cell_x, cell_y, cell_z

def _grid_to_idx(self, cell_x: jax.Array, cell_y: jax.Array, cell_z: jax.Array, n_grid: int) -> jax.Array:
"""Maps (cell_x, cell_y, cell_z) to grid_idx."""
return cell_z * n_grid * n_grid + cell_y * n_grid + cell_x

def _idx_to_grid(self, grid_idx: jax.Array, n_grid: int) -> Tuple[jax.Array, jax.Array, jax.Array]:
"""Maps grid_idx back to (cell_x, cell_y, cell_z)."""
cell_x = grid_idx % n_grid
cell_y = (grid_idx // n_grid) % n_grid
cell_z = grid_idx // (n_grid * n_grid)
return cell_x, cell_y, cell_z

def build_tree(self, points: jax.Array) -> GridTree3D:
"""
Builds the tree structure given the points in the space [a,b]x[c,d]x[e,f].
Parameters:
points (jax.numpy.ndarray): Array of points with shape (n_points, 3).
Returns:
GridTree3D: A named tuple containing the grid, grid size, max points per cell, and the original points.
"""
n_points = points.shape[0]
if n_points == 0:
raise ValueError("No points provided to build the tree.")
n_grid = int(np.cbrt(n_points / self.average_points_per_cell))
if n_grid < 1:
warnings.warn("Number of points is too small to meet desired average points per cell.")
n_grid = 1
num_cells = n_grid * n_grid * n_grid

max_points_per_cell = int(
n_points / num_cells + self.kappa * np.cbrt(n_points / num_cells)
)
if max_points_per_cell < 1:
raise ValueError("max_points_per_cell must be at least 1.")

grid = -1 * jnp.ones((num_cells, max_points_per_cell), dtype=int)
storage_indices = jnp.zeros(num_cells, dtype=int) # To track where to store the next point in each grid
points_min = jnp.min(points, axis=0)
points_max = jnp.max(points, axis=0)
extent = (points_min[0], points_max[0], points_min[1], points_max[1], points_min[2], points_max[2])

def assign_point(i, state):
grid, storage_indices = state
point = points[i]
cell_x, cell_y, cell_z = self._point_to_cell(point, n_grid, extent)
grid_idx = self._grid_to_idx(cell_x, cell_y, cell_z, n_grid)

storage_index = storage_indices[grid_idx]
grid = grid.at[grid_idx, storage_index].set(i)
storage_indices = storage_indices.at[grid_idx].set((storage_index + 1) % max_points_per_cell)

return grid, storage_indices

grid, storage_indices = jax.lax.fori_loop(0, n_points, assign_point, (grid, storage_indices))

# Similar process as in 2D for filling unfilled cells with random points from neighboring cells
def body(state):
i, grid, storage_indices = state
G, P = jnp.meshgrid(jnp.arange(np.shape(grid)[0]), jnp.arange(np.shape(grid)[1]), indexing='ij')
cell_x, cell_y, cell_z = self._idx_to_grid(G, n_grid)
neighbour_inc = jax.random.randint(jax.random.PRNGKey(42), np.shape(G) + (3,),
-1, 2) # [num_cells, max_points_per_cell, 3]
neighbour_x = jnp.clip(cell_x + neighbour_inc[:, :, 0], 0, n_grid - 1)
neighbour_y = jnp.clip(cell_y + neighbour_inc[:, :, 1], 0, n_grid - 1)
neighbour_z = jnp.clip(cell_z + neighbour_inc[:, :, 2], 0, n_grid - 1)
neighbour_grid_idx = self._grid_to_idx(neighbour_x, neighbour_y, neighbour_z, n_grid)
random_select = jax.random.randint(jax.random.PRNGKey(42), np.shape(G), 0,
storage_indices[:, None]) # [num_cells, max_points_per_cell]
random_neighbour = grid[neighbour_grid_idx, random_select]

@partial(jax.vmap, in_axes=(0, 0))
@partial(jax.vmap, in_axes=(0, 0))
def check_cell(i, j):
return jnp.logical_not(jnp.any(grid[i] == random_neighbour[i, j]))

replace = (grid == -1) & check_cell(G, P)
grid = jnp.where(replace, random_neighbour, grid)
grid = jnp.sort(grid, axis=1, descending=True)
storage_indices = jnp.sum(grid != -1, axis=1)
return i + 1, grid, storage_indices

def cond(state):
i, grid, storage_indices = state
return jnp.any(grid == -1) & (i < 10)

_, grid, _ = jax.lax.while_loop(cond, body, (0, grid, storage_indices))

return GridTree3D(grid=grid, points=points, extent=extent)

def query(self, tree: GridTree3D, test_point: jax.Array, k: int = 1) -> Tuple[jax.Array, jax.Array]:
"""
Queries the tree structure to find the k nearest neighbors to the test point in 3D.
Parameters:
tree (GridTree3D): The tree structure built by the `build_tree` method.
test_point (jax.numpy.ndarray): A point in [a,b]x[c,d]x[e,f] with shape (3,).
k (int): The number of nearest neighbors to find.
Returns:
distances (jax.numpy.ndarray): Distances to the k nearest neighbors.
indices (jax.numpy.ndarray): Indices of the k nearest neighbors.
"""
n_grid = int(np.cbrt(np.shape(tree.grid)[0]))
cell_x, cell_y, cell_z = self._point_to_cell(test_point, n_grid, tree.extent)
grid_idx = self._grid_to_idx(cell_x, cell_y, cell_z, n_grid)
point_indices = tree.grid[grid_idx] # [max_points_per_cell]
points_in_cell = tree.points[point_indices] # [max_points_per_cell, 3]

valid_mask = point_indices >= 0

distances = jnp.linalg.norm(points_in_cell - test_point, axis=1) # [max_points_per_cell]
neg_distances = jnp.where(valid_mask, -distances, -jnp.inf)
top_k_neg_distances, top_k_indices_within_cell = jax.lax.top_k(neg_distances, k)
top_k_distances = -top_k_neg_distances

return top_k_distances, point_indices[top_k_indices_within_cell]


def kd_tree_nn(points: jax.Array, test_points: jax.Array, k: int = 1) -> Tuple[jax.Array, jax.Array]:
"""
Uses a KD-tree to find the k nearest neighbors to a test point in 3D space.
Expand Down Expand Up @@ -391,7 +237,7 @@ def _kd_tree_nn_host(points: jax.Array, test_points: jax.Array, k: int) -> Tuple
k = int(k)
tree = KDTree(points, compact_nodes=False, balanced_tree=False)
if k == 1:
distances, indices = tree.query(test_points, k=[1]) # unsqueeze k
distances, indices = tree.query(test_points, k=[1]) # unsqueeze k
else:
distances, indices = tree.query(test_points, k=k)
return distances, indices
105 changes: 1 addition & 104 deletions dsa2000_cal/dsa2000_cal/common/tests/test_nearest_neighbours.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jax import random as random, numpy as jnp
from jax._src.tree_util import Partial

from dsa2000_cal.common.nearest_neighbours import ApproximateTreeNN2D, GridTree2D, ApproximateTreeNN3D, GridTree3D
from dsa2000_cal.common.nearest_neighbours import ApproximateTreeNN2D, GridTree2D


@pytest.fixture
Expand Down Expand Up @@ -108,106 +108,3 @@ def test_build_tree_handles_empty_points_2d(setup_tree_2d):
points = jnp.array([]).reshape(0, 2)
with pytest.raises(ValueError, match="No points provided to build the tree."):
_ = approx_tree.build_tree(points)


@pytest.fixture
def setup_tree_3d():
approx_tree = ApproximateTreeNN3D(average_points_per_cell=16, kappa=5.0)
return approx_tree


def test_build_tree_3d(setup_tree_3d):
approx_tree = setup_tree_3d
n_points = 100
points = random.uniform(random.PRNGKey(0), (n_points, 3))

tree = approx_tree.build_tree(points)

assert isinstance(tree, GridTree3D)
assert tree.grid.shape == (4, 25 + 5 * np.cbrt(25)) # Adjusted for 3D
assert jnp.all(tree.points == points)

tree = jax.jit(approx_tree.build_tree)(points)

assert isinstance(tree, GridTree3D)
assert tree.grid.shape == (4, 25 + 5 * np.cbrt(25)) # Adjusted for 3D
assert jnp.all(tree.points == points)


def test_query_within_single_cell_3d(setup_tree_3d):
approx_tree = setup_tree_3d
points = jnp.array([[0.1, 0.1, 0.1], [0.15, 0.15, 0.15], [0.2, 0.2, 0.2]])

tree = approx_tree.build_tree(points)
test_point = jnp.array([0.12, 0.12, 0.12])
k = 2

distances, indices = approx_tree.query(tree, test_point, k)

assert len(distances) == k
assert len(indices) == k
assert indices[0] in [0, 1, 2]
assert jnp.allclose(jnp.sort(distances), jnp.sort(jnp.linalg.norm(points - test_point, axis=1)[:k]))

distances, indices = jax.jit(Partial(approx_tree.query, k=k))(tree, test_point)

assert len(distances) == k
assert len(indices) == k
assert indices[0] in [0, 1, 2]
assert jnp.allclose(jnp.sort(distances), jnp.sort(jnp.linalg.norm(points - test_point, axis=1)[:k]))


def test_query_no_points_in_cell_3d(setup_tree_3d):
approx_tree = setup_tree_3d
points = jnp.array([[0.8, 0.8, 0.8], [0.9, 0.9, 0.9], [0.85, 0.85, 0.85]])

tree = approx_tree.build_tree(points)
tree = tree._replace(grid=tree.grid.at[0, :].set(-1))
test_point = jnp.array([0.1, 0.1, 0.1])
k = 2

distances, indices = approx_tree.query(tree, test_point, k)

assert distances.size == 2
assert indices.size == 2

np.testing.assert_allclose(distances, jnp.inf)
np.testing.assert_allclose(indices, -1)


def test_query_with_exactly_k_points_3d(setup_tree_3d):
approx_tree = setup_tree_3d
points = jnp.array([[0.1, 0.1, 0.1], [0.15, 0.15, 0.15], [0.2, 0.2, 0.2]])

tree = approx_tree.build_tree(points)
test_point = jnp.array([0.12, 0.12, 0.12])
k = 3

distances, indices = approx_tree.query(tree, test_point, k)

assert len(distances) == k
assert len(indices) == k
assert jnp.all(indices < len(points))


def test_query_nearest_neighbors_on_boundary_3d(setup_tree_3d):
approx_tree = setup_tree_3d
points = jnp.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]])

tree = approx_tree.build_tree(points)
test_point = jnp.array([0.5, 0.5, 0.5])
k = 1

distances, indices = approx_tree.query(tree, test_point, k)

assert len(distances) == k
assert len(indices) == k
assert indices[0] == 1
assert jnp.allclose(distances[0], 0.0)


def test_build_tree_handles_empty_points_3d(setup_tree_3d):
approx_tree = setup_tree_3d
points = jnp.array([]).reshape(0, 3) # Adjusted for 3D
with pytest.raises(ValueError, match="No points provided to build the tree."):
_ = approx_tree.build_tree(points)
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from astropy import coordinates as ac, units as au

from dsa2000_cal.assets.content_registry import fill_registries
from dsa2000_cal.forward_models.synthetic_sky_model.synthetic_sky_model_producer import SyntheticSkyModelProducer


def test_create_sky_model():
fill_registries()
synthetic_sky_model_producer = SyntheticSkyModelProducer(
phase_tracking=ac.ICRS(15 * au.deg, 0 * au.deg),
freqs=au.Quantity([700], unit='MHz'),
field_of_view=4 * au.deg,
seed=42
)
bright_point_sources = synthetic_sky_model_producer.create_sources_outside_fov(num_bright_sources=100,
full_stokes=False)
bright_point_sources.plot(save_file='bright_point_sources.png')
assert bright_point_sources.num_sources == 100
inner_point_sources = synthetic_sky_model_producer.create_sources_inside_fov(num_sources=100, full_stokes=False)
inner_point_sources.plot(save_file='inner_point_sources.png')

(bright_point_sources + inner_point_sources).plot(save_file='all_point_sources.png')

assert inner_point_sources.num_sources == 37 # Should debug
inner_diffuse_sources = synthetic_sky_model_producer.create_diffuse_sources_inside_fov(num_sources=100,
full_stokes=False)
inner_diffuse_sources.plot(save_file='inner_diffuse_sources.png')
assert inner_diffuse_sources.num_sources == 37 # Should debug
rfi_emitter_sources = synthetic_sky_model_producer.create_rfi_emitter_sources(full_stokes=False)
rfi_emitter_sources[0].plot(save_file='rfi_emitter_sources.png')
assert len(rfi_emitter_sources) == 1
a_team_sources = synthetic_sky_model_producer.create_a_team_sources(a_team_sources=['cas_a'])
a_team_sources[0].plot(save_file='cas_a.png')
assert len(a_team_sources) == 1
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import pytest
from astropy import units as au, coordinates as ac

from dsa2000_cal.assets.content_registry import fill_registries
from dsa2000_cal.common.astropy_utils import create_spherical_grid
from dsa2000_cal.forward_models.synthetic_sky_model.synthetic_sky_model_producer import choose_dr, \
SyntheticSkyModelProducer
from dsa2000_cal.forward_models.synthetic_sky_model.synthetic_sky_model_producer import choose_dr


@pytest.mark.parametrize('total_n, expected_n', [
Expand All @@ -27,33 +25,3 @@ def test_choose_dr(total_n, expected_n):
plt.plot(sources.ra, sources.dec, 'o')
plt.show()
assert len(sources) == expected_n


def test_create_sky_model():
fill_registries()
synthetic_sky_model_producer = SyntheticSkyModelProducer(
phase_tracking=ac.ICRS(15 * au.deg, 0 * au.deg),
freqs=au.Quantity([700], unit='MHz'),
field_of_view=4 * au.deg,
seed=42
)
bright_point_sources = synthetic_sky_model_producer.create_sources_outside_fov(num_bright_sources=100,
full_stokes=False)
bright_point_sources.plot(save_file='bright_point_sources.png')
assert bright_point_sources.num_sources == 100
inner_point_sources = synthetic_sky_model_producer.create_sources_inside_fov(num_sources=100, full_stokes=False)
inner_point_sources.plot(save_file='inner_point_sources.png')

(bright_point_sources + inner_point_sources).plot(save_file='all_point_sources.png')

assert inner_point_sources.num_sources == 37 # Should debug
inner_diffuse_sources = synthetic_sky_model_producer.create_diffuse_sources_inside_fov(num_sources=100,
full_stokes=False)
inner_diffuse_sources.plot(save_file='inner_diffuse_sources.png')
assert inner_diffuse_sources.num_sources == 37 # Should debug
rfi_emitter_sources = synthetic_sky_model_producer.create_rfi_emitter_sources(full_stokes=False)
rfi_emitter_sources[0].plot(save_file='rfi_emitter_sources.png')
assert len(rfi_emitter_sources) == 1
a_team_sources = synthetic_sky_model_producer.create_a_team_sources(a_team_sources=['cas_a'])
a_team_sources[0].plot(save_file='cas_a.png')
assert len(a_team_sources) == 1
Loading

0 comments on commit b5eb179

Please sign in to comment.