Commit 0fe47cd

fix(edits/migration): pass edge layer to helper function; handle stale source nodes using atomic edges; improve latency with lru cache
1 parent 6c90cc1 commit 0fe47cd

File tree

7 files changed (+111 additions, -63 deletions):
  pychunkedgraph/graph/edges/__init__.py
  pychunkedgraph/graph/utils/generic.py
  pychunkedgraph/ingest/cluster.py
  pychunkedgraph/ingest/upgrade/atomic_layer.py
  pychunkedgraph/ingest/upgrade/parent_layer.py
  pychunkedgraph/ingest/upgrade/utils.py
  pychunkedgraph/ingest/utils.py

pychunkedgraph/graph/edges/__init__.py

Lines changed: 76 additions & 22 deletions
@@ -12,6 +12,7 @@
 import tensorstore as ts
 import zstandard as zstd
 from graph_tool import Graph
+from cachetools import LRUCache

 from pychunkedgraph.graph import types
 from pychunkedgraph.graph.chunks.utils import (
@@ -21,6 +22,7 @@
 from pychunkedgraph.graph.utils import basetypes

 from ..utils import basetypes
+from ..utils.generic import get_parents_at_timestamp


 _edge_type_fileds = ("in_chunk", "between_chunk", "cross_chunk")
@@ -39,6 +41,7 @@
     ]
 )
 ZSTD_EDGE_COMPRESSION = 17
+PARENTS_CACHE = LRUCache(256 * 1024)


 class Edges:
@@ -341,7 +344,72 @@ def _filter(node):
             chunks_map[node_b] = np.concatenate(chunks_map[node_b])
         return int(mlayer), _filter(node_a), _filter(node_b)

-    def _get_new_edge(edge, parent_ts, padding):
+    def _populate_parents_cache(children: np.ndarray):
+        global PARENTS_CACHE
+
+        not_cached = []
+        for child in children:
+            try:
+                # reset lru index, these will be needed soon
+                _ = PARENTS_CACHE[child]
+            except KeyError:
+                not_cached.append(child)
+
+        all_parents = cg.get_parents(not_cached, current=False)
+        for child, parents in zip(not_cached, all_parents):
+            PARENTS_CACHE[child] = {}
+            for parent, ts in parents:
+                PARENTS_CACHE[child][ts] = parent
+
+    def _get_parents_b(edges, parent_ts, layer):
+        """
+        Attempts to find new partner side nodes.
+        Gets new partners at `parent_ts` using supervoxels.
+        Searches for new partners that may have any edges to `edges[:,0]`.
+        """
+        children_b = cg.get_children(edges[:, 1], flatten=True)
+        _populate_parents_cache(children_b)
+        _parents_b, missing = get_parents_at_timestamp(
+            children_b, PARENTS_CACHE, time_stamp=parent_ts, unique=True
+        )
+        # handle cache miss cases
+        _parents_b_missing = np.unique(cg.get_parents(missing, time_stamp=parent_ts))
+        parents_b = np.concatenate([_parents_b, _parents_b_missing])
+
+        parents_a = edges[:, 0]
+        stale_a = get_stale_nodes(cg, parents_a, parent_ts=parent_ts)
+        if stale_a.size == parents_a.size:
+            # this is applicable only for v2 to v3 migration
+            # handle cases when source nodes in `edges[:,0]` are stale
+            atomic_edges_d = cg.get_atomic_cross_edges(stale_a)
+            partners = [types.empty_1d]
+            for _edges_d in atomic_edges_d.values():
+                _edges = _edges_d.get(layer, types.empty_2d)
+                partners.append(_edges[:, 1])
+            partners = np.concatenate(partners)
+            return np.unique(cg.get_parents(partners, time_stamp=parent_ts))
+
+        _cx_edges_d = cg.get_cross_chunk_edges(parents_b, time_stamp=parent_ts)
+        _parents_b = []
+        for _node, _edges_d in _cx_edges_d.items():
+            for _edges in _edges_d.values():
+                _mask = np.isin(_edges[:, 1], parents_a)
+                if np.any(_mask):
+                    _parents_b.append(_node)
+        return np.array(_parents_b, dtype=basetypes.NODE_ID)
+
+    def _get_parents_b_with_chunk_mask(
+        l2ids_b: np.ndarray, parents_b: np.ndarray, max_ts: datetime.datetime, edge
+    ):
+        chunks_old = cg.get_chunk_ids_from_node_ids(l2ids_b)
+        chunks_new = cg.get_chunk_ids_from_node_ids(parents_b)
+        chunk_mask = np.isin(chunks_new, chunks_old)
+        parents_b = parents_b[chunk_mask]
+        _stale_nodes = get_stale_nodes(cg, parents_b, parent_ts=max_ts)
+        assert _stale_nodes.size == 0, f"{edge}, {_stale_nodes}, {parent_ts}"
+        return parents_b
+
+    def _get_new_edge(edge, edge_layer, parent_ts, padding):
         """
         Attempts to find new edge(s) for the stale `edge`.
         * Find L2 IDs on opposite sides of the face in L2 chunks along the face.
@@ -353,11 +421,11 @@ def _get_new_edge(edge, parent_ts, padding):
         if l2ids_a.size == 0 or l2ids_b.size == 0:
             return types.empty_2d.copy()

-        _edges = []
         max_node_ts = max(nodes_ts_map[node_a], nodes_ts_map[node_b])
         _edges_d = cg.get_cross_chunk_edges(
             node_ids=l2ids_a, time_stamp=max_node_ts, raw_only=True
         )
+        _edges = []
         for v in _edges_d.values():
             if edge_layer in v:
                 _edges.append(v[edge_layer])
@@ -369,27 +437,13 @@ def _get_new_edge(edge, parent_ts, padding):

         mask = np.isin(_edges[:, 1], l2ids_b)
         if np.any(mask):
-            parents_a = _edges[mask][:, 0]
-            children_b = cg.get_children(_edges[mask][:, 1], flatten=True)
-            parents_b = np.unique(cg.get_parents(children_b, time_stamp=parent_ts))
-            _cx_edges_d = cg.get_cross_chunk_edges(parents_b, time_stamp=parent_ts)
-            parents_b = []
-            for _node, _edges_d in _cx_edges_d.items():
-                for _edges in _edges_d.values():
-                    _mask = np.isin(_edges[:, 1], parents_a)
-                    if np.any(_mask):
-                        parents_b.append(_node)
-            parents_b = np.array(parents_b, dtype=basetypes.NODE_ID)
+            parents_b = _get_parents_b(_edges[mask], parent_ts, edge_layer)
         else:
             # if none of `l2ids_b` were found in edges, `l2ids_a` already have new edges
             # so get the new identities of `l2ids_b` by using chunk mask
-            parents_b = _edges[:, 1]
-            chunks_old = cg.get_chunk_ids_from_node_ids(l2ids_b)
-            chunks_new = cg.get_chunk_ids_from_node_ids(parents_b)
-            chunk_mask = np.isin(chunks_new, chunks_old)
-            parents_b = parents_b[chunk_mask]
-            _stale_nodes = get_stale_nodes(cg, parents_b, parent_ts=max_node_ts)
-            assert _stale_nodes.size == 0, f"{edge}, {_stale_nodes}, {parent_ts}"
+            parents_b = _get_parents_b_with_chunk_mask(
+                l2ids_b, _edges[:, 1], max_node_ts, edge
+            )

         parents_b = np.unique(
             cg.get_roots(parents_b, stop_layer=mlayer, ceil=False, time_stamp=parent_ts)
@@ -402,7 +456,7 @@ def _get_new_edge(edge, parent_ts, padding):
     for edge_layer, _edge in zip(edge_layers, stale_edges):
         max_chebyshev_distance = int(environ.get("MAX_CHEBYSHEV_DISTANCE", 3))
         for pad in range(0, max_chebyshev_distance):
-            _new_edges = _get_new_edge(_edge, parent_ts, padding=pad)
+            _new_edges = _get_new_edge(_edge, edge_layer, parent_ts, padding=pad)
             if _new_edges.size:
                 break
             logging.info(f"{_edge}, expanding search with padding {pad+1}.")
@@ -446,7 +500,7 @@ def get_latest_edges_wrapper(
             stale_edge_layers,
             parent_ts=parent_ts,
         )
-        logging.debug(f"{stale_edges} -> {latest_edges}; {parent_ts}")
+        logging.debug(f"{stale_edges} -> {latest_edges[:,1].tolist()}; {parent_ts}")
         _new_cx_edges.append(latest_edges)
     new_cx_edges_d[layer] = np.concatenate(_new_cx_edges)
     nodes.append(np.unique(new_cx_edges_d[layer]))
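
Note: the caching added here is a read-through pattern: touch keys that are already cached so their LRU slots stay fresh, then batch-fetch only the misses. A minimal, self-contained sketch of the same pattern (fetch_parents is a hypothetical stand-in for the batched cg.get_parents(..., current=False) lookup):

from cachetools import LRUCache

PARENTS_CACHE = LRUCache(maxsize=256 * 1024)

def fetch_parents(nodes):
    # hypothetical stand-in for one batched backend lookup that returns a
    # (parent, timestamp) history per node
    return {n: [(n + 1000, 0)] for n in nodes}

def populate(nodes):
    not_cached = []
    for node in nodes:
        try:
            _ = PARENTS_CACHE[node]  # a hit refreshes the entry's LRU position
        except KeyError:
            not_cached.append(node)
    # one batched fetch for the misses only
    for node, history in fetch_parents(not_cached).items():
        PARENTS_CACHE[node] = {ts: parent for parent, ts in history}

populate([1, 2, 3])
assert all(n in PARENTS_CACHE for n in (1, 2, 3))

Keeping the cache module-global lets repeated lookups within a worker reuse parent histories, which is presumably the latency win the commit message refers to.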

pychunkedgraph/graph/utils/generic.py

Lines changed: 20 additions & 5 deletions
@@ -3,7 +3,6 @@
 TODO categorize properly
 """

-
 import datetime
 from typing import Dict
 from typing import Iterable
@@ -173,14 +172,30 @@ def mask_nodes_by_bounding_box(
     adapt_layers = layers - 2
     adapt_layers[adapt_layers < 0] = 0
     fanout = meta.graph_config.FANOUT
-    bounding_box_layer = (
-        bounding_box[None] / (fanout ** adapt_layers)[:, None, None]
-    )
+    bounding_box_layer = bounding_box[None] / (fanout**adapt_layers)[:, None, None]
     bound_check = np.array(
         [
             np.all(chunk_coordinates < bounding_box_layer[:, 1], axis=1),
             np.all(chunk_coordinates + 1 > bounding_box_layer[:, 0], axis=1),
         ]
     ).T

-    return np.all(bound_check, axis=1)
+    return np.all(bound_check, axis=1)
+
+
+def get_parents_at_timestamp(nodes, parents_ts_map, time_stamp, unique: bool = False):
+    """
+    Search for the first parent with ts <= `time_stamp`.
+    `parents_ts_map[node]` is a map of ts:parent with sorted timestamps (desc).
+    """
+    skipped_nodes = []
+    parents = set() if unique else []
+    for node in nodes:
+        try:
+            for ts, parent in parents_ts_map[node].items():
+                if time_stamp >= ts:
+                    parents.add(parent) if unique else parents.append(parent)
+                    break
+        except KeyError:
+            skipped_nodes.append(node)
+    return list(parents), skipped_nodes
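
Note: a small usage sketch of the new helper; integer timestamps stand in for the datetime objects real callers pass, and the per-node map must be ordered newest-first, as the docstring requires:

# node 10 has three parents over time; node 99 is not cached
parents_ts_map = {10: {30: 111, 20: 110, 0: 100}}  # ts -> parent, descending ts
parents, skipped = get_parents_at_timestamp([10, 99], parents_ts_map, time_stamp=25)
assert parents == [110]  # first parent with ts <= 25
assert skipped == [99]   # cache misses; resolved with a direct lookup upstream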

pychunkedgraph/ingest/cluster.py

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@
 from typing import Callable, Dict, Iterable, Tuple, Sequence

 import numpy as np
-from rq import Queue as RQueue
+from rq import Queue as RQueue, Retry


 from .utils import chunk_id_str, get_chunks_not_done, randomize_grid_points
@@ -209,6 +209,7 @@ def _queue_tasks(imanager: IngestionManager, chunk_fn: Callable, coords: Iterabl
                 timeout=environ.get("L2JOB_TIMEOUT", "3m"),
                 result_ttl=0,
                 job_id=chunk_id_str(2, chunk_coord),
+                retry=Retry(int(environ.get("RETRY_COUNT", 1))),
             )
         )
     q.enqueue_many(job_datas)
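
Note: both queueing changes in this commit lean on RQ's built-in retry support. A minimal sketch, assuming a local Redis and an illustrative queue name:

from redis import Redis
from rq import Queue, Retry

q = Queue("ingest", connection=Redis())
q.enqueue(
    print,               # any importable callable works as a job
    "hello",
    retry=Retry(max=3),  # re-enqueue a failed job up to 3 times
)
# an optional backoff schedule (in seconds) is also supported:
# retry=Retry(max=3, interval=[10, 30, 60])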

pychunkedgraph/ingest/upgrade/atomic_layer.py

Lines changed: 2 additions & 15 deletions
@@ -10,26 +10,13 @@
 from pychunkedgraph.graph import ChunkedGraph, types
 from pychunkedgraph.graph.attributes import Connectivity, Hierarchy
 from pychunkedgraph.graph.utils import serializers
+from pychunkedgraph.graph.utils.generic import get_parents_at_timestamp

 from .utils import fix_corrupt_nodes, get_end_timestamps, get_parent_timestamps

 CHILDREN = {}


-def _get_parents_at_timestamp(nodes, parents_ts_map, time_stamp):
-    """
-    Search for the first parent with ts <= `time_stamp`.
-    `parents_ts_map[node]` is a map of ts:parent with sorted timestamps (desc).
-    """
-    parents = []
-    for node in nodes:
-        for ts, parent in parents_ts_map[node].items():
-            if time_stamp >= ts:
-                parents.append(parent)
-                break
-    return parents
-
-
 def update_cross_edges(
     cg: ChunkedGraph,
     node,
@@ -59,7 +46,7 @@ def update_cross_edges(
             break

     val_dict = {}
-    parents = _get_parents_at_timestamp(partners, parents_ts_map, ts)
+    parents, _ = get_parents_at_timestamp(partners, parents_ts_map, ts)
     edge_parents_d = dict(zip(partners, parents))
     for layer, layer_edges in cx_edges_d.items():
         layer_edges = fastremap.remap(

pychunkedgraph/ingest/upgrade/parent_layer.py

Lines changed: 7 additions & 17 deletions
@@ -10,9 +10,8 @@
 import numpy as np
 from tqdm import tqdm

-from pychunkedgraph.graph import ChunkedGraph
+from pychunkedgraph.graph import ChunkedGraph, edges
 from pychunkedgraph.graph.attributes import Connectivity, Hierarchy
-from pychunkedgraph.graph.edges import get_latest_edges_wrapper
 from pychunkedgraph.graph.utils import serializers
 from pychunkedgraph.graph.types import empty_2d
 from pychunkedgraph.utils.general import chunked
@@ -105,7 +104,6 @@ def _populate_cx_edges_with_timestamps(
         row_id = serializers.serialize_uint64(node)
         val_dict = {Hierarchy.StaleTimeStamp: 0}
         rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp=node_end_ts))
-
     cg.client.write(rows)

@@ -119,7 +117,7 @@ def update_cross_edges(cg: ChunkedGraph, layer, node, node_ts) -> list:
     for ts, cx_edges_d in CX_EDGES[node].items():
         if ts < node_ts:
             continue
-        cx_edges_d, edge_nodes = get_latest_edges_wrapper(cg, cx_edges_d, parent_ts=ts)
+        cx_edges_d, edge_nodes = edges.get_latest_edges_wrapper(cg, cx_edges_d, parent_ts=ts)
         if edge_nodes.size == 0:
             continue

@@ -138,13 +136,7 @@ def update_cross_edges(cg: ChunkedGraph, layer, node, node_ts) -> list:
     return rows


-def _update_cross_edges_helper_thread(args):
-    cg, layer, node, node_ts = args
-    return update_cross_edges(cg, layer, node, node_ts)
-
-
 def _update_cross_edges_helper(args):
-    rows = []
     clean_task = os.environ.get("CLEAN_CHUNKS", "false") == "clean"
     cg_info, layer, nodes, nodes_ts = args
     cg = ChunkedGraph(**cg_info)
@@ -167,12 +159,9 @@ def _update_cross_edges_helper(args):
         fix_corrupt_nodes(cg, corrupt_nodes, CHILDREN)
         return

-    with ThreadPoolExecutor(max_workers=4) as executor:
-        futures = [
-            executor.submit(_update_cross_edges_helper_thread, task) for task in tasks
-        ]
-        for future in tqdm(as_completed(futures), total=len(futures)):
-            rows.extend(future.result())
+    rows = []
+    for task in tasks:
+        rows.extend(update_cross_edges(*task))
     cg.client.write(rows)

@@ -204,12 +193,13 @@ def update_chunk(

     if debug:
         rows = []
+        logging.info(f"processing {len(nodes)} nodes with 1 worker.")
         for node, node_ts in zip(nodes, nodes_ts):
            rows.extend(update_cross_edges(cg, layer, node, node_ts))
        logging.info(f"total elaspsed time: {time.time() - start}")
        return

-    task_size = int(math.ceil(len(nodes) / mp.cpu_count() / 2))
+    task_size = int(math.ceil(len(nodes) / mp.cpu_count()))
     chunked_nodes = chunked(nodes, task_size)
     chunked_nodes_ts = chunked(nodes_ts, task_size)
     cg_info = cg.get_serialized_info()
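
Note: with the thread pool removed, each task now runs serially in its worker process, so dropping the "/ 2" roughly doubles the task size and yields about one task per core. A sketch of the split, with a stand-in for pychunkedgraph.utils.general.chunked:

import math

def chunked(seq, size):
    # stand-in for pychunkedgraph.utils.general.chunked
    for i in range(0, len(seq), size):
        yield seq[i : i + size]

nodes = list(range(10))
cpu_count = 4  # stand-in for mp.cpu_count()
task_size = int(math.ceil(len(nodes) / cpu_count))  # ceil(10 / 4) == 3
print(list(chunked(nodes, task_size)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] -> roughly one task per core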

pychunkedgraph/ingest/upgrade/utils.py

Lines changed: 2 additions & 2 deletions
@@ -108,8 +108,8 @@ def get_parent_timestamps(

 def fix_corrupt_nodes(cg: ChunkedGraph, nodes: list, children_d: dict):
     """
-    Iteratively removes a node from parent column of its children.
-    Then removes the node iteself, effectively erasing it.
+    For each node: delete it from parent column of its children.
+    Then deletes the node itself, effectively erasing it from hierarchy.
     """
     table = cg.client._table
     batcher = table.mutations_batcher(flush_count=500)

pychunkedgraph/ingest/utils.py

Lines changed: 2 additions & 1 deletion
@@ -9,7 +9,7 @@

 import numpy as np
 import tensorstore as ts
-from rq import Queue, Worker
+from rq import Queue, Retry, Worker
 from rq.worker import WorkerStatus

 from . import IngestConfig
@@ -199,6 +199,7 @@ def queue_layer_helper(parent_layer: int, imanager: IngestionManager, fn):
                 result_ttl=0,
                 job_id=chunk_id_str(parent_layer, chunk_coord),
                 timeout=f"{timeout_scale * int(parent_layer * parent_layer)}m",
+                retry=Retry(int(environ.get("RETRY_COUNT", 1))),
             )
         )
     q.enqueue_many(job_datas)
