Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bounding box argument to level 2 chunk graph endpoint #488

Merged
merged 5 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 35 additions & 43 deletions pychunkedgraph/app/segmentation/common.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,35 @@
# pylint: disable=invalid-name, missing-docstring

import json
import time
import os
import time
from datetime import datetime
from functools import reduce

import numpy as np
from pytz import UTC
import pandas as pd

from flask import current_app, g, jsonify, make_response, request
from pytz import UTC

from pychunkedgraph import __version__
from pychunkedgraph.app import app_utils
from pychunkedgraph.graph import (
attributes,
cutting,
exceptions as cg_exceptions,
segmenthistory,
)
from pychunkedgraph.graph import (
edges as cg_edges,
)
from pychunkedgraph.graph import segmenthistory
from pychunkedgraph.graph.utils import basetypes
from pychunkedgraph.graph import (
exceptions as cg_exceptions,
)
from pychunkedgraph.graph.analysis import pathing
from pychunkedgraph.graph.attributes import OperationLogs
from pychunkedgraph.meshing import mesh_analysis
from pychunkedgraph.graph.misc import get_contact_sites
from pychunkedgraph.graph.operation import GraphEditOperation

from pychunkedgraph.graph.utils import basetypes
from pychunkedgraph.meshing import mesh_analysis

__api_versions__ = [0, 1]
__segmentation_url_prefix__ = os.environ.get("SEGMENTATION_URL_PREFIX", "segmentation")
Expand Down Expand Up @@ -72,6 +75,15 @@ def _parse_timestamp(
)


def _get_bounds_from_request(request):
if "bounds" in request.args:
bounds = request.args["bounds"]
bounding_box = np.array([b.split("-") for b in bounds.split("_")], dtype=int).T
else:
bounding_box = None
return bounding_box


# -------------------
# ------ Applications
# -------------------
Expand All @@ -93,9 +105,7 @@ def handle_info(table_id):
combined_info["verify_mesh"] = cg.meta.custom_data.get("mesh", {}).get(
"verify", False
)
mesh_dir = cg.meta.custom_data.get("mesh", {}).get(
"dir", None
)
mesh_dir = cg.meta.custom_data.get("mesh", {}).get("dir", None)
if mesh_dir is not None:
combined_info["mesh_dir"] = mesh_dir
elif combined_info.get("mesh_dir", None) is not None:
Expand Down Expand Up @@ -216,7 +226,7 @@ def publish_edit(
table_id: str, user_id: str, result: GraphEditOperation.Result, is_priority=True
):
import pickle
from os import getenv

from messagingclient import MessagingClient

attributes = {
Expand Down Expand Up @@ -454,7 +464,7 @@ def handle_rollback(table_id):
continue
try:
ret = cg.undo_operation(user_id=target_user_id, operation_id=operation_id)
except cg_exceptions.LockingError as e:
except cg_exceptions.LockingError:
raise cg_exceptions.InternalServerError(
"Could not acquire root lock for undo operation."
)
Expand Down Expand Up @@ -506,14 +516,14 @@ def all_user_operations(
user_id = entry[OperationLogs.UserID]

should_check = (
not OperationLogs.Status in entry
OperationLogs.Status not in entry
or entry[OperationLogs.Status] == OperationLogs.StatusCodes.SUCCESS.value
)

split_valid = (
include_partial_splits
or (OperationLogs.AddedEdge in entry)
or (not OperationLogs.RootID in entry)
or (OperationLogs.RootID not in entry)
or (len(entry[OperationLogs.RootID]) > 1)
)
if not split_valid:
Expand Down Expand Up @@ -589,15 +599,11 @@ def handle_leaves(table_id, root_id):
user_id = str(g.auth_user.get("id", current_app.user_id))

stop_layer = int(request.args.get("stop_layer", 1))
bounding_box = None
if "bounds" in request.args:
bounds = request.args["bounds"]
bounding_box = np.array([b.split("-") for b in bounds.split("_")], dtype=int).T

bounding_box = _get_bounds_from_request(request)

cg = app_utils.get_cg(table_id)
if stop_layer > 1:
from pychunkedgraph.graph.types import empty_1d

subgraph = cg.get_subgraph_nodes(
int(root_id),
bbox=bounding_box,
Expand All @@ -621,11 +627,7 @@ def handle_leaves_many(table_id):
current_app.table_id = table_id
user_id = str(g.auth_user.get("id", current_app.user_id))

if "bounds" in request.args:
bounds = request.args["bounds"]
bounding_box = np.array([b.split("-") for b in bounds.split("_")], dtype=int).T
else:
bounding_box = None
bounding_box = _get_bounds_from_request(request)

node_ids = np.array(json.loads(request.data)["node_ids"], dtype=np.uint64)
stop_layer = int(request.args.get("stop_layer", 1))
Expand All @@ -652,11 +654,7 @@ def handle_leaves_from_leave(table_id, atomic_id):
current_app.table_id = table_id
user_id = str(g.auth_user.get("id", current_app.user_id))

if "bounds" in request.args:
bounds = request.args["bounds"]
bounding_box = np.array([b.split("-") for b in bounds.split("_")], dtype=int).T
else:
bounding_box = None
bounding_box = _get_bounds_from_request(request)

# Call ChunkedGraph
cg = app_utils.get_cg(table_id)
Expand All @@ -676,11 +674,7 @@ def handle_subgraph(table_id, root_id):
current_app.table_id = table_id
user_id = str(g.auth_user.get("id", current_app.user_id))

if "bounds" in request.args:
bounds = request.args["bounds"]
bounding_box = np.array([b.split("-") for b in bounds.split("_")], dtype=int).T
else:
bounding_box = None
bounding_box = _get_bounds_from_request(request)

# Call ChunkedGraph
cg = app_utils.get_cg(table_id)
Expand Down Expand Up @@ -820,6 +814,7 @@ def merge_log(table_id, root_id):

def handle_lineage_graph(table_id, root_id=None):
from networkx import node_link_data

from pychunkedgraph.graph.lineage import lineage_graph

current_app.table_id = table_id
Expand Down Expand Up @@ -890,11 +885,7 @@ def handle_contact_sites(table_id, root_id):

timestamp = _parse_timestamp("timestamp", time.time(), return_datetime=True)

if "bounds" in request.args:
bounds = request.args["bounds"]
bounding_box = np.array([b.split("-") for b in bounds.split("_")], dtype=int).T
else:
bounding_box = None
bounding_box = _get_bounds_from_request(request)

# Call ChunkedGraph
cg = app_utils.get_cg(table_id)
Expand Down Expand Up @@ -1042,9 +1033,11 @@ def handle_get_layer2_graph(table_id, node_id):
current_app.table_id = table_id
user_id = str(g.auth_user.get("id", current_app.user_id))

bounding_box = _get_bounds_from_request(request)

cg = app_utils.get_cg(table_id)
print("Finding edge graph...")
edge_graph = pathing.get_lvl2_edge_list(cg, int(node_id))
edge_graph = pathing.get_lvl2_edge_list(cg, int(node_id), bbox=bounding_box)
print("Edge graph found len: {}".format(len(edge_graph)))
return {"edge_graph": edge_graph}

Expand Down Expand Up @@ -1089,7 +1082,6 @@ def handle_root_timestamps(table_id, is_binary):


def operation_details(table_id):
from pychunkedgraph.graph import attributes
from pychunkedgraph.export.operation_logs import parse_attr

current_app.table_id = table_id
Expand Down
31 changes: 18 additions & 13 deletions pychunkedgraph/app/segmentation/v1/routes.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
# pylint: disable=invalid-name, missing-docstring

import io
import csv
import pickle
import io
import json
import pickle

import pandas as pd
import numpy as np
from flask import make_response
from flask import Blueprint, request
from middle_auth_client import auth_requires_permission
from middle_auth_client import auth_requires_admin
from middle_auth_client import auth_required
import pandas as pd
from flask import Blueprint, make_response, request
from middle_auth_client import (
auth_required,
auth_requires_admin,
auth_requires_permission,
)

from pychunkedgraph.app import common as app_common
from pychunkedgraph.app.app_utils import (
jsonify_with_kwargs,
toboolean,
tobinary,
remap_public,
tobinary,
toboolean,
)
from pychunkedgraph.app import common as app_common
from pychunkedgraph.app.segmentation import common
from pychunkedgraph.graph import exceptions as cg_exceptions


bp = Blueprint(
"pcg_segmentation_v1",
__name__,
Expand Down Expand Up @@ -561,7 +561,12 @@ def handle_roots_from_coords(table_id):
def handle_get_lvl2_graph(table_id, node_id):
int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean)
resp = common.handle_get_layer2_graph(table_id, node_id)
return jsonify_with_kwargs(resp, int64_as_str=int64_as_str)
out = jsonify_with_kwargs(resp, int64_as_str=int64_as_str)
if "bounds" in request.args:
out.headers["Used-Bounds"] = True
else:
out.headers["Used-Bounds"] = False
return out


### GET OPERATION DETAILS --------------------------------------------------------
Expand Down
71 changes: 50 additions & 21 deletions pychunkedgraph/graph/analysis/pathing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import typing

import fastremap
import graph_tool
import numpy as np

from pychunkedgraph.graph.utils import flatgraph

from ..subgraph import get_subgraph_nodes


def get_first_shared_parent(
cg, first_node_id: np.uint64, second_node_id: np.uint64, time_stamp=None
Expand All @@ -18,9 +23,7 @@ def get_first_shared_parent(
second_node_parent_ids = set()
cur_first_node_parent = first_node_id
cur_second_node_parent = second_node_id
while (
cur_first_node_parent is not None or cur_second_node_parent is not None
):
while cur_first_node_parent is not None or cur_second_node_parent is not None:
if cur_first_node_parent is not None:
first_node_parent_ids.add(cur_first_node_parent)
if cur_second_node_parent is not None:
Expand Down Expand Up @@ -70,14 +73,40 @@ def get_children_at_layer(
return np.concatenate(children_at_layer)


def get_lvl2_edge_list(cg, node_id: np.uint64):
def get_lvl2_edge_list(
cg,
node_id: np.uint64,
bbox: typing.Optional[typing.Sequence[typing.Sequence[int]]] = None,
):
"""get an edge list of lvl2 ids for a particular node

:param cg: ChunkedGraph object
:param node_id: np.uint64 that you want the edge list for
:param bbox: Optional[Sequence[Sequence[int]]] a bounding box to limit the search
"""

lvl2_ids = get_children_at_layer(cg, node_id, 2)
if bbox is None:
# maybe temporary, this was the old implementation
bdpedigo marked this conversation as resolved.
Show resolved Hide resolved
lvl2_ids = get_children_at_layer(cg, node_id, 2)
else:
lvl2_ids = get_subgraph_nodes(
cg,
node_id,
bbox=bbox,
bbox_is_coordinate=True,
return_layers=[2],
return_flattened=True,
)

edges = _get_edges_for_lvl2_ids(cg, lvl2_ids, induced=True)
return edges


def _get_edges_for_lvl2_ids(cg, lvl2_ids, induced=False):
# protect in case there are no lvl2 ids
if len(lvl2_ids) == 0:
return np.empty((0, 2), dtype=np.uint64)

cce_dict = cg.get_atomic_cross_edges(lvl2_ids)

# Gather all of the supervoxel ids into two lists, we will map them to
Expand All @@ -95,19 +124,15 @@ def get_lvl2_edge_list(cg, node_id: np.uint64):
known_supervoxels_for_lv2_id = cce_dict[lvl2_id][level][:, 0]
unknown_supervoxels_for_lv2_id = cce_dict[lvl2_id][level][:, 1]
known_supervoxels_list.append(known_supervoxels_for_lv2_id)
known_l2_list.append(
np.full(known_supervoxels_for_lv2_id.shape, lvl2_id)
)
known_l2_list.append(np.full(known_supervoxels_for_lv2_id.shape, lvl2_id))
unknown_supervoxel_list.append(unknown_supervoxels_for_lv2_id)

# Create two arrays to map supervoxels for which we know their parents
known_supervoxel_array, unique_indices = np.unique(
np.concatenate(known_supervoxels_list), return_index=True
)
known_l2_array = (np.concatenate(known_l2_list))[unique_indices]
unknown_supervoxel_array = np.unique(
np.concatenate(unknown_supervoxel_list)
)
unknown_supervoxel_array = np.unique(np.concatenate(unknown_supervoxel_list))

# Call get_parents on any supervoxels for which we don't know their parents
supervoxels_to_query_parent = np.setdiff1d(
Expand All @@ -123,10 +148,18 @@ def get_lvl2_edge_list(cg, node_id: np.uint64):
# Map the cross-chunk edges from supervoxels to lvl2 ids
edge_view = edge_array.view()
edge_view.shape = -1
fastremap.remap_from_array_kv(
edge_view, known_supervoxel_array, known_l2_array
)
return np.unique(np.sort(edge_array, axis=1), axis=0)
fastremap.remap_from_array_kv(edge_view, known_supervoxel_array, known_l2_array)

edge_array = np.unique(np.sort(edge_array, axis=1), axis=0)

if induced:
# make this an induced subgraph
# keep only the edges that are between the lvl2 ids asked for
edge_array = edge_array[
np.isin(edge_array[:, 0], lvl2_ids) & np.isin(edge_array[:, 1], lvl2_ids)
]

return edge_array


def find_l2_shortest_path(
Expand Down Expand Up @@ -164,9 +197,7 @@ def find_l2_shortest_path(
)

# Remap the graph-tool ids to lvl2 ids and return the path
vertex_indices = [
weighted_graph.vertex_index[vertex] for vertex in vertex_list
]
vertex_indices = [weighted_graph.vertex_index[vertex] for vertex in vertex_list]
l2_traversal_path = graph_indexed_l2_ids[vertex_indices]
return l2_traversal_path

Expand All @@ -181,9 +212,7 @@ def compute_rough_coordinate_path(cg, l2_ids):
"""
coordinate_path = []
for l2_id in l2_ids:
chunk_center = cg.get_chunk_coordinates(l2_id) + np.array(
[0.5, 0.5, 0.5]
)
chunk_center = cg.get_chunk_coordinates(l2_id) + np.array([0.5, 0.5, 0.5])
coordinate = chunk_center * np.array(
cg.meta.graph_config.CHUNK_SIZE
) + np.array(cg.meta.cv.mip_voxel_offset(0))
Expand Down
Loading