Skip to content

Commit

Permalink
* Fix performance of compute beam
Browse files Browse the repository at this point in the history
  • Loading branch information
Joshuaalbert committed Sep 4, 2024
1 parent 0b02345 commit f4bd23b
Show file tree
Hide file tree
Showing 13 changed files with 688 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class MockAntennaModel(AltAzAntennaModel):
def __init__(self):
self.model_name = 'mock_antenna_model'
self._num_theta = 60
self._num_phi = 15
self._num_phi = 40
self._num_freqs = 20

@cached_property
Expand Down
175 changes: 168 additions & 7 deletions dsa2000_cal/dsa2000_cal/common/nearest_neighbours.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
import numpy as np


class GridTree(NamedTuple):
class GridTree2D(NamedTuple):
grid: jax.Array # [num_grids, max_points_per_cell]
points: jax.Array # [n_points, 2]
extent: Tuple[jax.Array, jax.Array, jax.Array, jax.Array] # [4] (min_x, max_x, min_y, max_y)


@dataclasses.dataclass(eq=False)
class ApproximateTreeNN:
class ApproximateTreeNN2D:
"""
Approximate tree for nearest neighbor search on 2D box.
Expand Down Expand Up @@ -62,15 +62,15 @@ def _idx_to_grid(self, grid_idx: jax.Array, n_grid: int) -> Tuple[jax.Array, jax
cell_y = grid_idx // n_grid
return cell_x, cell_y

def build_tree(self, points: jax.Array) -> GridTree:
def build_tree(self, points: jax.Array) -> GridTree2D:
"""
Builds the tree structure given the points in the space [a,b]x[c,d].
Parameters:
points (jax.numpy.ndarray): Array of points with shape (n_points, 2).
Returns:
GridTree: A named tuple containing the grid, grid size, max points per cell, and the original points.
GridTree2D: A named tuple containing the grid, grid size, max points per cell, and the original points.
"""
n_points = points.shape[0]
if n_points == 0:
Expand Down Expand Up @@ -146,14 +146,14 @@ def cond(state):

_, grid, _ = jax.lax.while_loop(cond, body, (0, grid, storage_indices))

return GridTree(grid=grid, points=points, extent=extent)
return GridTree2D(grid=grid, points=points, extent=extent)

def query(self, tree: GridTree, test_point: jax.Array, k: int = 1) -> Tuple[jax.Array, jax.Array]:
def query(self, tree: GridTree2D, test_point: jax.Array, k: int = 1) -> Tuple[jax.Array, jax.Array]:
"""
Queries the tree structure to find the k nearest neighbors to the test point.
Parameters:
tree (GridTree): The tree structure built by the `build_tree` method.
tree (GridTree2D): The tree structure built by the `build_tree` method.
test_point (jax.numpy.ndarray): A point in [a,b]x[c,d] with shape (2,).
k (int): The number of nearest neighbors to find.
Expand All @@ -177,3 +177,164 @@ def query(self, tree: GridTree, test_point: jax.Array, k: int = 1) -> Tuple[jax.

# Return the actual distances and the corresponding indices in the original points array
return top_k_distances, point_indices[top_k_indices_within_cell]


class GridTree3D(NamedTuple):
    """
    Container for the grid structure produced by `ApproximateTreeNN3D.build_tree`.

    Unused slots of `grid` hold -1; every other entry is a row index into `points`.
    """
    # [num_grids, max_points_per_cell] indices into `points`, -1 marks an empty slot
    grid: jax.Array
    # [n_points, 3] the points the tree was built from
    points: jax.Array
    # [6] bounding box of the points, ordered (min_x, max_x, min_y, max_y, min_z, max_z)
    extent: Tuple[
        jax.Array, jax.Array,
        jax.Array, jax.Array,
        jax.Array, jax.Array,
    ]


@dataclasses.dataclass(eq=False)
class ApproximateTreeNN3D:
    """
    Approximate grid-based nearest neighbour search on a 3D box.

    The bounding box of the points is split into a regular grid of shape (n_grid, n_grid, n_grid) where
    n_grid = floor(cbrt(n / average_points_per_cell)) (but at least 1) and n is the number of points.
    Every cell owns a fixed number of point-index slots,

        max_points_per_cell = int(m + kappa * cbrt(m)),    m = n / n_grid**3.

    A query only inspects the single cell that contains the test point, which is what makes the search
    approximate: a true neighbour living in an adjacent cell is only found if it was copied into this cell
    while the tree was built (see `build_tree`).

    The tree build loops once over all points and then runs at most 10 vectorised fill passes.
    A query evaluates `max_points_per_cell` distances followed by a top-k selection.

    Accuracy generally increases with more points in the tree. Accuracy decreases with larger `k` queries due to
    cell edge effects.

    Args:
        average_points_per_cell: Average number of points per cell in the grid.
        kappa: scales the head-room that each cell gets above the mean occupancy (see formula above).
    """
    average_points_per_cell: int = 16
    kappa: float = 5.0

    def _point_to_cell(self, point: jax.Array, n_grid: int,
                       extent: Tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]) -> Tuple[
        jax.Array, jax.Array, jax.Array]:
        """
        Maps a point to the integer coordinates of the grid cell that contains it.

        Parameters:
            point (jax.numpy.ndarray): A point with shape (3,).
            n_grid (int): Number of cells along each axis.
            extent: (min_x, max_x, min_y, max_y, min_z, max_z) of the gridded box.

        Returns:
            (cell_x, cell_y, cell_z), each clipped into [0, n_grid - 1], so points outside the box land in the
            nearest boundary cell.
        """
        x_min, x_max, y_min, y_max, z_min, z_max = extent

        def axis_cell(coord, lower, upper):
            span = upper - lower
            # A degenerate axis (all points share this coordinate, e.g. planar data or a single point) has
            # zero span; dividing by it would give NaN and an implementation-defined integer cast.
            # Using a unit span instead puts every such point into cell 0 along that axis.
            safe_span = jnp.where(span > 0, span, 1)
            return jnp.clip(jnp.floor((coord - lower) / safe_span * n_grid), 0, n_grid - 1).astype(int)

        return (
            axis_cell(point[0], x_min, x_max),
            axis_cell(point[1], y_min, y_max),
            axis_cell(point[2], z_min, z_max),
        )

    def _grid_to_idx(self, cell_x: jax.Array, cell_y: jax.Array, cell_z: jax.Array, n_grid: int) -> jax.Array:
        """Maps (cell_x, cell_y, cell_z) to the flat grid_idx, with x varying fastest and z slowest."""
        return cell_z * n_grid * n_grid + cell_y * n_grid + cell_x

    def _idx_to_grid(self, grid_idx: jax.Array, n_grid: int) -> Tuple[jax.Array, jax.Array, jax.Array]:
        """Maps the flat grid_idx back to (cell_x, cell_y, cell_z); inverse of `_grid_to_idx`."""
        cell_x = grid_idx % n_grid
        cell_y = (grid_idx // n_grid) % n_grid
        cell_z = grid_idx // (n_grid * n_grid)
        return cell_x, cell_y, cell_z

    def build_tree(self, points: jax.Array) -> "GridTree3D":
        """
        Builds the tree structure given the points in the space [a,b]x[c,d]x[e,f].

        Each point index is written into the cell that contains the point. Slots that are still empty afterwards
        are then filled, over at most 10 passes, with indices borrowed from randomly chosen cells in the
        surrounding 3x3x3 neighbourhood, so that a query close to a cell edge can still see points from next door.
        The random draws use a fixed seed, so the build is deterministic.

        Parameters:
            points (jax.numpy.ndarray): Array of points with shape (n_points, 3).

        Returns:
            GridTree3D: A named tuple containing the grid of point indices [n_grid**3, max_points_per_cell]
                (-1 marks a slot that could not be filled), the original points and the extent of the points.

        Raises:
            ValueError: if no points are given, or the computed cell capacity is below 1.
        """
        n_points = points.shape[0]
        if n_points == 0:
            raise ValueError("No points provided to build the tree.")

        target_cells = n_points / self.average_points_per_cell
        n_grid = int(np.cbrt(target_cells))
        # cbrt is a floating point routine: make sure an exact cube is never floored one too low.
        if (n_grid + 1) ** 3 <= target_cells:
            n_grid += 1
        if n_grid < 1:
            warnings.warn("Number of points is too small to meet desired average points per cell.")
            n_grid = 1
        num_cells = n_grid * n_grid * n_grid

        # NOTE(review): the head-room is kappa * cbrt(mean). If kappa is meant as a number of Poisson sigmas the
        # term would be sqrt(mean); left unchanged because it determines the shape of the returned grid.
        mean_occupancy = n_points / num_cells
        max_points_per_cell = int(mean_occupancy + self.kappa * np.cbrt(mean_occupancy))
        if max_points_per_cell < 1:
            raise ValueError("max_points_per_cell must be at least 1.")

        grid = jnp.full((num_cells, max_points_per_cell), -1, dtype=int)
        storage_indices = jnp.zeros(num_cells, dtype=int)  # To track where to store the next point in each grid
        points_min = jnp.min(points, axis=0)
        points_max = jnp.max(points, axis=0)
        extent = (points_min[0], points_max[0], points_min[1], points_max[1], points_min[2], points_max[2])

        def assign_point(i, state):
            grid, storage_indices = state
            cell_x, cell_y, cell_z = self._point_to_cell(points[i], n_grid, extent)
            grid_idx = self._grid_to_idx(cell_x, cell_y, cell_z, n_grid)

            storage_index = storage_indices[grid_idx]
            grid = grid.at[grid_idx, storage_index].set(i)
            # Ring buffer: once a cell is full the write position wraps and the oldest entries get overwritten.
            storage_indices = storage_indices.at[grid_idx].set((storage_index + 1) % max_points_per_cell)
            return grid, storage_indices

        grid, storage_indices = jax.lax.fori_loop(0, n_points, assign_point, (grid, storage_indices))

        # The write positions are not fill counts for cells that wrapped around, so count the filled slots.
        # Slots are filled from the front, hence the first `count` slots of every row are the valid ones.
        storage_indices = jnp.sum(grid != -1, axis=1)

        base_key = jax.random.PRNGKey(42)
        # Integer coordinates of every cell, shaped [num_cells, 1] to broadcast against the slot axis.
        cell_x, cell_y, cell_z = self._idx_to_grid(jnp.arange(num_cells)[:, None], n_grid)

        def body(state):
            i, grid, storage_indices = state
            # Fresh, independent randomness per pass and per draw. Re-using one key would propose the very
            # same neighbour for every slot in every pass, so later passes could hardly fill anything new.
            key_offset, key_slot = jax.random.split(jax.random.fold_in(base_key, i))

            neighbour_inc = jax.random.randint(key_offset, (num_cells, max_points_per_cell, 3),
                                               -1, 2)  # [num_cells, max_points_per_cell, 3]
            neighbour_x = jnp.clip(cell_x + neighbour_inc[:, :, 0], 0, n_grid - 1)
            neighbour_y = jnp.clip(cell_y + neighbour_inc[:, :, 1], 0, n_grid - 1)
            neighbour_z = jnp.clip(cell_z + neighbour_inc[:, :, 2], 0, n_grid - 1)
            neighbour_grid_idx = self._grid_to_idx(neighbour_x, neighbour_y, neighbour_z, n_grid)

            # Pick a slot among the ones that are filled in the *neighbour* cell. An empty neighbour has a
            # count of 0, in which case slot 0 (holding -1) is selected and the candidate is rejected below.
            random_select = jax.random.randint(key_slot, (num_cells, max_points_per_cell), 0,
                                               storage_indices[neighbour_grid_idx])
            random_neighbour = grid[neighbour_grid_idx, random_select]  # [num_cells, max_points_per_cell]

            # Candidate j of a cell is rejected if the cell already holds that index (this covers -1 too).
            already_present = jnp.any(random_neighbour[:, :, None] == grid[:, None, :], axis=-1)
            replace = (grid == -1) & jnp.logical_not(already_present)
            grid = jnp.where(replace, random_neighbour, grid)

            # Descending sort moves the valid (>= 0) indices to the front and the -1 padding to the back.
            grid = jnp.sort(grid, axis=1, descending=True)
            # Two empty slots of one cell may have drawn the same point in this pass; after sorting such
            # copies are adjacent. Keep one and clear the rest so a query never returns a point twice.
            repeated = (grid[:, 1:] == grid[:, :-1]) & (grid[:, 1:] != -1)
            repeated = jnp.concatenate([jnp.zeros((num_cells, 1), dtype=bool), repeated], axis=1)
            grid = jnp.where(repeated, -1, grid)
            grid = jnp.sort(grid, axis=1, descending=True)

            storage_indices = jnp.sum(grid != -1, axis=1)
            return i + 1, grid, storage_indices

        def cond(state):
            i, grid, _ = state
            # Stop as soon as every slot is filled, or after 10 passes.
            return jnp.any(grid == -1) & (i < 10)

        _, grid, _ = jax.lax.while_loop(cond, body, (0, grid, storage_indices))

        return GridTree3D(grid=grid, points=points, extent=extent)

    def query(self, tree: "GridTree3D", test_point: jax.Array, k: int = 1) -> Tuple[jax.Array, jax.Array]:
        """
        Queries the tree structure to find the k nearest neighbors to the test point in 3D.

        Only the points stored in the cell containing `test_point` are considered.

        Parameters:
            tree (GridTree3D): The tree structure built by the `build_tree` method.
            test_point (jax.numpy.ndarray): A point in [a,b]x[c,d]x[e,f] with shape (3,).
            k (int): The number of nearest neighbors to find.

        Returns:
            distances (jax.numpy.ndarray): Distances to the k nearest neighbors, in increasing order. If the cell
                holds fewer than k points the remaining entries are inf.
            indices (jax.numpy.ndarray): Indices of the k nearest neighbors in `tree.points`; -1 where the
                distance is inf.
        """
        num_cells = np.shape(tree.grid)[0]
        # num_cells is n_grid**3; round rather than truncate so floating point error in cbrt cannot yield
        # n_grid - 1.
        n_grid = int(round(float(np.cbrt(num_cells))))
        cell_x, cell_y, cell_z = self._point_to_cell(test_point, n_grid, tree.extent)
        grid_idx = self._grid_to_idx(cell_x, cell_y, cell_z, n_grid)
        point_indices = tree.grid[grid_idx]  # [max_points_per_cell]
        # Empty slots (-1) index the last point here; they are masked out just below.
        points_in_cell = tree.points[point_indices]  # [max_points_per_cell, 3]

        valid_mask = point_indices >= 0

        distances = jnp.linalg.norm(points_in_cell - test_point, axis=1)  # [max_points_per_cell]
        # top_k selects the largest values, so negate the distances; empty slots are pushed to -inf.
        neg_distances = jnp.where(valid_mask, -distances, -jnp.inf)
        top_k_neg_distances, top_k_indices_within_cell = jax.lax.top_k(neg_distances, k)
        top_k_distances = -top_k_neg_distances

        return top_k_distances, point_indices[top_k_indices_within_cell]
Loading

0 comments on commit f4bd23b

Please sign in to comment.