Skip to content

Commit 8b92f00

Browse files
committed
perf(upgrade): reduce latency for atomic layer chunks
1 parent 01a3646 commit 8b92f00

File tree

1 file changed

+44
-40
lines changed

1 file changed

+44
-40
lines changed

pychunkedgraph/ingest/upgrade/atomic_layer.py

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,43 @@
11
# pylint: disable=invalid-name, missing-docstring, c-extension-no-member
22

33
from collections import defaultdict
4-
from concurrent.futures import ThreadPoolExecutor, as_completed
54
from datetime import datetime, timedelta, timezone
6-
import logging, math, time
5+
import logging, time
76
from copy import copy
87

98
import fastremap
109
import numpy as np
11-
from tqdm import tqdm
1210
from pychunkedgraph.graph import ChunkedGraph, types
1311
from pychunkedgraph.graph.attributes import Connectivity, Hierarchy
1412
from pychunkedgraph.graph.utils import serializers
15-
from pychunkedgraph.utils.general import chunked
1613

1714
from .utils import get_end_timestamps, get_parent_timestamps
1815

1916
CHILDREN = {}
2017

2118

19+
def _get_parents_at_timestamp(nodes, parents_ts_map, time_stamp):
20+
"""
21+
Search for the first parent with ts <= `time_stamp`.
22+
`parents_ts_map[node]` is a map of ts:parent with sorted timestamps (desc).
23+
"""
24+
parents = []
25+
for node in nodes:
26+
for ts, parent in parents_ts_map[node].items():
27+
if time_stamp >= ts:
28+
parents.append(parent)
29+
break
30+
return parents
31+
32+
2233
def update_cross_edges(
2334
cg: ChunkedGraph,
2435
node,
2536
cx_edges_d: dict,
2637
node_ts,
2738
node_end_ts,
28-
timestamps_d: defaultdict[int, set],
39+
timestamps_map: defaultdict[int, set],
40+
parents_ts_map: defaultdict[int, dict],
2941
) -> list:
3042
"""
3143
Helper function to update a single L2 ID.
@@ -35,9 +47,9 @@ def update_cross_edges(
3547
edges = np.concatenate(list(cx_edges_d.values()))
3648
partners = np.unique(edges[:, 1])
3749

38-
timestamps = copy(timestamps_d[node])
50+
timestamps = copy(timestamps_map[node])
3951
for partner in partners:
40-
timestamps.update(timestamps_d[partner])
52+
timestamps.update(timestamps_map[partner])
4153

4254
node_end_ts = node_end_ts or datetime.now(timezone.utc)
4355
for ts in sorted(timestamps):
@@ -47,7 +59,7 @@ def update_cross_edges(
4759
break
4860

4961
val_dict = {}
50-
parents = cg.get_parents(partners, time_stamp=ts)
62+
parents = _get_parents_at_timestamp(partners, parents_ts_map, ts)
5163
edge_parents_d = dict(zip(partners, parents))
5264
for layer, layer_edges in cx_edges_d.items():
5365
layer_edges = fastremap.remap(
@@ -63,6 +75,7 @@ def update_cross_edges(
6375

6476

6577
def update_nodes(cg: ChunkedGraph, nodes, nodes_ts, children_map=None) -> list:
78+
start = time.time()
6679
if children_map is None:
6780
children_map = CHILDREN
6881
end_timestamps = get_end_timestamps(cg, nodes, nodes_ts, children_map, layer=2)
@@ -75,31 +88,39 @@ def update_nodes(cg: ChunkedGraph, nodes, nodes_ts, children_map=None) -> list:
7588
all_partners = np.unique(np.concatenate(all_cx_edges)[:, 1])
7689
timestamps_d = get_parent_timestamps(cg, np.concatenate([nodes, all_partners]))
7790

91+
parents_ts_map = defaultdict(dict)
92+
all_parents = cg.get_parents(all_partners, current=False)
93+
for partner, parents in zip(all_partners, all_parents):
94+
for parent, ts in parents:
95+
parents_ts_map[partner][ts] = parent
96+
logging.info(f"update_nodes init {len(nodes)}: {time.time() - start}")
97+
7898
rows = []
99+
skipped = []
79100
for node, node_ts, end_ts in zip(nodes, nodes_ts, end_timestamps):
80101
is_stale = end_ts is not None
81102
_cx_edges_d = cx_edges_d.get(node, {})
82-
if not _cx_edges_d:
83-
continue
84103
if is_stale:
85104
end_ts -= timedelta(milliseconds=1)
86-
87-
_rows = update_cross_edges(cg, node, _cx_edges_d, node_ts, end_ts, timestamps_d)
88-
if is_stale:
89105
row_id = serializers.serialize_uint64(node)
90106
val_dict = {Hierarchy.StaleTimeStamp: 0}
91-
_rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp=end_ts))
92-
rows.extend(_rows)
93-
94-
return rows
107+
rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp=end_ts))
95108

109+
if not _cx_edges_d:
110+
skipped.append(node)
111+
continue
96112

97-
def _update_nodes_helper(args):
98-
cg, nodes, nodes_ts = args
99-
return update_nodes(cg, nodes, nodes_ts)
113+
_rows = update_cross_edges(
114+
cg, node, _cx_edges_d, node_ts, end_ts, timestamps_d, parents_ts_map
115+
)
116+
rows.extend(_rows)
117+
parents = cg.get_roots(skipped)
118+
layers = cg.get_chunk_layers(parents)
119+
assert np.all(layers == cg.meta.layer_count)
120+
return rows
100121

101122

102-
def update_chunk(cg: ChunkedGraph, chunk_coords: list[int], debug: bool = False):
123+
def update_chunk(cg: ChunkedGraph, chunk_coords: list[int]):
103124
"""
104125
Iterate over all L2 IDs in a chunk and update their cross chunk edges,
105126
within the periods they were valid/active.
@@ -132,23 +153,6 @@ def update_chunk(cg: ChunkedGraph, chunk_coords: list[int], debug: bool = False)
132153
else:
133154
return
134155

135-
if debug:
136-
rows = update_nodes(cg, nodes, nodes_ts)
137-
else:
138-
task_size = int(math.ceil(len(nodes) / 16))
139-
chunked_nodes = chunked(nodes, task_size)
140-
chunked_nodes_ts = chunked(nodes_ts, task_size)
141-
tasks = []
142-
for chunk, ts_chunk in zip(chunked_nodes, chunked_nodes_ts):
143-
args = (cg, chunk, ts_chunk)
144-
tasks.append(args)
145-
logging.info(f"task size {task_size}, count {len(tasks)}.")
146-
147-
rows = []
148-
with ThreadPoolExecutor(max_workers=8) as executor:
149-
futures = [executor.submit(_update_nodes_helper, task) for task in tasks]
150-
for future in tqdm(as_completed(futures), total=len(futures)):
151-
rows.extend(future.result())
152-
156+
rows = update_nodes(cg, nodes, nodes_ts)
153157
cg.client.write(rows)
154-
logging.info(f"total elaspsed time: {time.time() - start}")
158+
logging.info(f"mutations: {len(rows)}, time: {time.time() - start}")

0 commit comments

Comments
 (0)