Skip to content

Commit b667634

Browse files
committed
feat(sv_split): sv split in frontend
1 parent 7982847 commit b667634

File tree

5 files changed

+118
-44
lines changed

5 files changed

+118
-44
lines changed

pychunkedgraph/app/segmentation/common.py

Lines changed: 67 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,13 @@
99

1010
import numpy as np
1111
import pandas as pd
12+
import fastremap
1213
from flask import current_app, g, jsonify, make_response, request
1314
from pytz import UTC
1415

1516
from pychunkedgraph import __version__
1617
from pychunkedgraph.app import app_utils
17-
from pychunkedgraph.graph import (
18-
attributes,
19-
cutting,
20-
segmenthistory,
21-
)
18+
from pychunkedgraph.graph import attributes, cutting, segmenthistory, ChunkedGraph
2219
from pychunkedgraph.graph import (
2320
edges as cg_edges,
2421
)
@@ -27,6 +24,7 @@
2724
)
2825
from pychunkedgraph.graph.analysis import pathing
2926
from pychunkedgraph.graph.attributes import OperationLogs
27+
from pychunkedgraph.graph.edits_sv import split_supervoxel
3028
from pychunkedgraph.graph.misc import get_contact_sites
3129
from pychunkedgraph.graph.operation import GraphEditOperation
3230
from pychunkedgraph.graph.utils import basetypes
@@ -393,7 +391,7 @@ def handle_merge(table_id, allow_same_segment_merge=False):
393391
current_app.operation_id = ret.operation_id
394392
if ret.new_root_ids is None:
395393
raise cg_exceptions.InternalServerError(
396-
"Could not merge selected " "supervoxel."
394+
f"{ret.operation_id}: Could not merge selected supervoxels."
397395
)
398396

399397
current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids))
@@ -407,24 +405,10 @@ def handle_merge(table_id, allow_same_segment_merge=False):
407405
### SPLIT ----------------------------------------------------------------------
408406

409407

410-
def handle_split(table_id):
411-
current_app.table_id = table_id
412-
user_id = str(g.auth_user.get("id", current_app.user_id))
413-
414-
data = json.loads(request.data)
415-
is_priority = request.args.get("priority", True, type=str2bool)
416-
remesh = request.args.get("remesh", True, type=str2bool)
417-
mincut = request.args.get("mincut", True, type=str2bool)
418-
408+
def _get_sources_and_sinks(cg: ChunkedGraph, data):
419409
current_app.logger.debug(data)
420-
421-
# Call ChunkedGraph
422-
cg = app_utils.get_cg(table_id, skip_cache=True)
423410
node_idents = []
424-
node_ident_map = {
425-
"sources": 0,
426-
"sinks": 1,
427-
}
411+
node_ident_map = {"sources": 0, "sinks": 1}
428412
coords = []
429413
node_ids = []
430414

@@ -437,18 +421,74 @@ def handle_split(table_id):
437421
node_ids = np.array(node_ids, dtype=np.uint64)
438422
coords = np.array(coords)
439423
node_idents = np.array(node_idents)
424+
425+
start = time.time()
440426
sv_ids = app_utils.handle_supervoxel_id_lookup(cg, coords, node_ids)
427+
current_app.logger.info(f"SV lookup took {time.time() - start}s.")
441428
current_app.logger.debug(
442429
{"node_id": node_ids, "sv_id": sv_ids, "node_ident": node_idents}
443430
)
444431

432+
source_ids = sv_ids[node_idents == 0]
433+
sink_ids = sv_ids[node_idents == 1]
434+
source_coords = coords[node_idents == 0]
435+
sink_coords = coords[node_idents == 1]
436+
return (source_ids, sink_ids, source_coords, sink_coords)
437+
438+
439+
def handle_split(table_id):
440+
current_app.table_id = table_id
441+
user_id = str(g.auth_user.get("id", current_app.user_id))
442+
443+
data = json.loads(request.data)
444+
is_priority = request.args.get("priority", True, type=str2bool)
445+
remesh = request.args.get("remesh", True, type=str2bool)
446+
mincut = request.args.get("mincut", True, type=str2bool)
447+
448+
cg = app_utils.get_cg(table_id, skip_cache=True)
449+
sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, data)
445450
try:
446451
ret = cg.remove_edges(
447452
user_id=user_id,
448-
source_ids=sv_ids[node_idents == 0],
449-
sink_ids=sv_ids[node_idents == 1],
450-
source_coords=coords[node_idents == 0],
451-
sink_coords=coords[node_idents == 1],
453+
source_ids=sources,
454+
sink_ids=sinks,
455+
source_coords=source_coords,
456+
sink_coords=sink_coords,
457+
mincut=mincut,
458+
)
459+
except cg_exceptions.SupervoxelSplitRequiredError as e:
460+
current_app.logger.info(e)
461+
sources_remapped = fastremap.remap(
462+
sources,
463+
e.sv_remapping,
464+
preserve_missing_labels=True,
465+
in_place=False,
466+
)
467+
sinks_remapped = fastremap.remap(
468+
sinks,
469+
e.sv_remapping,
470+
preserve_missing_labels=True,
471+
in_place=False,
472+
)
473+
overlap_mask = np.isin(sources_remapped, sinks_remapped)
474+
for sv_to_split in np.unique(sources_remapped[overlap_mask]):
475+
_mask0 = sources_remapped[sources_remapped == sv_to_split]
476+
_mask1 = sinks_remapped[sinks_remapped == sv_to_split]
477+
split_supervoxel(
478+
cg,
479+
sv_to_split,
480+
source_coords[_mask0],
481+
sink_coords[_mask1],
482+
e.operation_id,
483+
)
484+
485+
sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, data)
486+
ret = cg.remove_edges(
487+
user_id=user_id,
488+
source_ids=sources,
489+
sink_ids=sinks,
490+
source_coords=source_coords,
491+
sink_coords=sink_coords,
452492
mincut=mincut,
453493
)
454494
except cg_exceptions.LockingError as e:
@@ -459,7 +499,7 @@ def handle_split(table_id):
459499
current_app.operation_id = ret.operation_id
460500
if ret.new_root_ids is None:
461501
raise cg_exceptions.InternalServerError(
462-
"Could not split selected segment groups."
502+
f"{ret.operation_id}: Could not split selected segment groups."
463503
)
464504

465505
current_app.logger.debug(("after split:", ret.new_root_ids))

pychunkedgraph/graph/cutting.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,18 @@
1-
import collections
21
import fastremap
32
import numpy as np
43
import itertools
5-
import logging
64
import time
75
import graph_tool
86
import graph_tool.flow
97

10-
from typing import Dict
118
from typing import Tuple
12-
from typing import Optional
139
from typing import Sequence
1410
from typing import Iterable
1511

1612
from .utils import flatgraph
17-
from .utils import basetypes
1813
from .utils.generic import get_bounding_box
1914
from .edges import Edges
20-
from .exceptions import PreconditionError
15+
from .exceptions import PreconditionError, SupervoxelSplitRequiredError
2116
from .exceptions import PostconditionError
2217

2318
DEBUG_MODE = False
@@ -116,6 +111,10 @@ def __init__(
116111
self.cross_chunk_edge_remapping,
117112
) = merge_cross_chunk_edges_graph_tool(cg_edges, cg_affs)
118113

114+
# save this representative mapping for supervoxel splitting
115+
# passed along with SupervoxelSplitRequiredError
116+
self.sv_remapping = dict(complete_mapping)
117+
119118
dt = time.time() - time_start
120119
if logger is not None:
121120
logger.debug("Cross edge merging: %.2fms" % (dt * 1000))
@@ -233,9 +232,10 @@ def _augment_mincut_capacity(self):
233232
self.source_graph_ids,
234233
)
235234
except AssertionError:
236-
raise PreconditionError(
235+
raise SupervoxelSplitRequiredError(
237236
"Paths between source or sink points irreparably overlap other labels from other side. "
238-
"Check that labels are correct and consider spreading points out farther."
237+
"Check that labels are correct and consider spreading points out farther.",
238+
self.sv_remapping
239239
)
240240

241241
paths_e_s_no, paths_e_y_no, do_check = flatgraph.remove_overlapping_edges(
@@ -581,11 +581,12 @@ def _sink_and_source_connectivity_sanity_check(self, cut_edge_set):
581581
# but return a flag to return a message to the user
582582
illegal_split = True
583583
else:
584-
raise PreconditionError(
584+
raise SupervoxelSplitRequiredError(
585585
"Failed to find a cut that separated the sources from the sinks. "
586586
"Please try another cut that partitions the sets cleanly if possible. "
587587
"If there is a clear path between all the supervoxels in each set, "
588-
"that helps the mincut algorithm."
588+
"that helps the mincut algorithm.",
589+
self.sv_remapping
589590
)
590591
except IsolatingCutException as e:
591592
if self.split_preview:

pychunkedgraph/graph/exceptions.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,25 @@
33

44
class ChunkedGraphError(Exception):
55
"""Base class for all exceptions raised by the ChunkedGraph"""
6+
67
pass
78

89

910
class LockingError(ChunkedGraphError):
1011
"""Raised when a Bigtable Lock could not be acquired"""
12+
1113
pass
1214

1315

1416
class PreconditionError(ChunkedGraphError):
1517
"""Raised when preconditions for Chunked Graph operations are not met"""
18+
1619
pass
1720

1821

1922
class PostconditionError(ChunkedGraphError):
2023
"""Raised when postconditions for Chunked Graph operations are not met"""
24+
2125
pass
2226

2327

@@ -42,7 +46,7 @@ def __init__(self, message):
4246
self.message = message
4347

4448
def __str__(self):
45-
return f'[{self.status_code}]: {self.message}'
49+
return f"[{self.status_code}]: {self.message}"
4650

4751

4852
class ClientError(ChunkedGraphAPIError):
@@ -51,21 +55,25 @@ class ClientError(ChunkedGraphAPIError):
5155

5256
class BadRequest(ClientError):
5357
"""Exception mapping a ``400 Bad Request`` response."""
58+
5459
status_code = http_client.BAD_REQUEST
5560

5661

5762
class Unauthorized(ClientError):
5863
"""Exception mapping a ``401 Unauthorized`` response."""
64+
5965
status_code = http_client.UNAUTHORIZED
6066

6167

6268
class Forbidden(ClientError):
6369
"""Exception mapping a ``403 Forbidden`` response."""
70+
6471
status_code = http_client.FORBIDDEN
6572

6673

6774
class Conflict(ClientError):
6875
"""Exception mapping a ``409 Conflict`` response."""
76+
6977
status_code = http_client.CONFLICT
7078

7179

@@ -75,9 +83,29 @@ class ServerError(ChunkedGraphAPIError):
7583

7684
class InternalServerError(ServerError):
7785
"""Exception mapping a ``500 Internal Server Error`` response."""
86+
7887
status_code = http_client.INTERNAL_SERVER_ERROR
7988

8089

8190
class GatewayTimeout(ServerError):
8291
"""Exception mapping a ``504 Gateway Timeout`` response."""
92+
8393
status_code = http_client.GATEWAY_TIMEOUT
94+
95+
96+
class SupervoxelSplitRequiredError(ChunkedGraphError):
97+
"""
98+
Raised when supervoxel splitting is necessary.
99+
Edit process should catch this error and retry after supervoxel has been split.
100+
Saves remapping required for detecting which supervoxels need to be split.
101+
"""
102+
103+
def __init__(
104+
self,
105+
message: str,
106+
sv_remapping: dict,
107+
operation_id: int | None = None,
108+
):
109+
super().__init__(message)
110+
self.sv_remapping = sv_remapping
111+
self.operation_id = operation_id

pychunkedgraph/graph/operation.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from .utils import serializers
2727
from .cache import CacheService
2828
from .cutting import run_multicut
29-
from .exceptions import PreconditionError
29+
from .exceptions import PreconditionError, SupervoxelSplitRequiredError
3030
from .exceptions import PostconditionError
3131
from .utils.generic import get_bounding_box as get_bbox
3232
from ..logging.log_db import TimeIt
@@ -451,6 +451,10 @@ def execute(
451451
new_root_ids=new_root_ids,
452452
new_lvl2_ids=new_lvl2_ids,
453453
)
454+
except SupervoxelSplitRequiredError as err:
455+
raise SupervoxelSplitRequiredError(
456+
str(err), err.sv_remapping, operation_id=lock.operation_id
457+
) from err
454458
except PreconditionError as err:
455459
self.cg.cache = None
456460
raise PreconditionError(err) from err
@@ -852,9 +856,10 @@ def __init__(
852856
self.path_augment = path_augment
853857
self.disallow_isolating_cut = disallow_isolating_cut
854858
if np.any(np.in1d(self.sink_ids, self.source_ids)):
855-
raise PreconditionError(
859+
raise SupervoxelSplitRequiredError(
856860
"Supervoxels exist in both sink and source, "
857-
"try placing the points further apart."
861+
"try placing the points further apart.",
862+
None,
858863
)
859864

860865
ids = np.concatenate([self.source_ids, self.sink_ids])

pychunkedgraph/repair/fake_edges.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from os import environ
1010
from typing import Optional
1111

12-
environ["BIGTABLE_PROJECT"] = "<>"
13-
environ["BIGTABLE_INSTANCE"] = "<>"
14-
environ["GOOGLE_APPLICATION_CREDENTIALS"] = "<path>"
12+
# environ["BIGTABLE_PROJECT"] = "<>"
13+
# environ["BIGTABLE_INSTANCE"] = "<>"
14+
# environ["GOOGLE_APPLICATION_CREDENTIALS"] = "<path>"
1515

1616
from pychunkedgraph.graph import edits
1717
from pychunkedgraph.graph import ChunkedGraph

0 commit comments

Comments
 (0)