11# pylint: disable=invalid-name, missing-docstring, c-extension-no-member
22
33from collections import defaultdict
4- from concurrent .futures import ThreadPoolExecutor , as_completed
54from datetime import datetime , timedelta , timezone
6- import logging , math , time
5+ import logging , time
76from copy import copy
87
98import fastremap
109import numpy as np
11- from tqdm import tqdm
1210from pychunkedgraph .graph import ChunkedGraph , types
1311from pychunkedgraph .graph .attributes import Connectivity , Hierarchy
1412from pychunkedgraph .graph .utils import serializers
15- from pychunkedgraph .utils .general import chunked
1613
1714from .utils import get_end_timestamps , get_parent_timestamps
1815
1916CHILDREN = {}
2017
2118
19+ def _get_parents_at_timestamp (nodes , parents_ts_map , time_stamp ):
20+ """
21+ Search for the first parent with ts <= `time_stamp`.
22+ `parents_ts_map[node]` is a map of ts:parent with sorted timestamps (desc).
23+ """
24+ parents = []
25+ for node in nodes :
26+ for ts , parent in parents_ts_map [node ].items ():
27+ if time_stamp >= ts :
28+ parents .append (parent )
29+ break
30+ return parents
31+
32+
2233def update_cross_edges (
2334 cg : ChunkedGraph ,
2435 node ,
2536 cx_edges_d : dict ,
2637 node_ts ,
2738 node_end_ts ,
28- timestamps_d : defaultdict [int , set ],
39+ timestamps_map : defaultdict [int , set ],
40+ parents_ts_map : defaultdict [int , dict ],
2941) -> list :
3042 """
3143 Helper function to update a single L2 ID.
@@ -35,9 +47,9 @@ def update_cross_edges(
3547 edges = np .concatenate (list (cx_edges_d .values ()))
3648 partners = np .unique (edges [:, 1 ])
3749
38- timestamps = copy (timestamps_d [node ])
50+ timestamps = copy (timestamps_map [node ])
3951 for partner in partners :
40- timestamps .update (timestamps_d [partner ])
52+ timestamps .update (timestamps_map [partner ])
4153
4254 node_end_ts = node_end_ts or datetime .now (timezone .utc )
4355 for ts in sorted (timestamps ):
@@ -47,7 +59,7 @@ def update_cross_edges(
4759 break
4860
4961 val_dict = {}
50- parents = cg . get_parents (partners , time_stamp = ts )
62+ parents = _get_parents_at_timestamp (partners , parents_ts_map , ts )
5163 edge_parents_d = dict (zip (partners , parents ))
5264 for layer , layer_edges in cx_edges_d .items ():
5365 layer_edges = fastremap .remap (
@@ -63,6 +75,7 @@ def update_cross_edges(
6375
6476
6577def update_nodes (cg : ChunkedGraph , nodes , nodes_ts , children_map = None ) -> list :
78+ start = time .time ()
6679 if children_map is None :
6780 children_map = CHILDREN
6881 end_timestamps = get_end_timestamps (cg , nodes , nodes_ts , children_map , layer = 2 )
@@ -75,31 +88,39 @@ def update_nodes(cg: ChunkedGraph, nodes, nodes_ts, children_map=None) -> list:
7588 all_partners = np .unique (np .concatenate (all_cx_edges )[:, 1 ])
7689 timestamps_d = get_parent_timestamps (cg , np .concatenate ([nodes , all_partners ]))
7790
91+ parents_ts_map = defaultdict (dict )
92+ all_parents = cg .get_parents (all_partners , current = False )
93+ for partner , parents in zip (all_partners , all_parents ):
94+ for parent , ts in parents :
95+ parents_ts_map [partner ][ts ] = parent
96+ logging .info (f"update_nodes init { len (nodes )} : { time .time () - start } " )
97+
7898 rows = []
99+ skipped = []
79100 for node , node_ts , end_ts in zip (nodes , nodes_ts , end_timestamps ):
80101 is_stale = end_ts is not None
81102 _cx_edges_d = cx_edges_d .get (node , {})
82- if not _cx_edges_d :
83- continue
84103 if is_stale :
85104 end_ts -= timedelta (milliseconds = 1 )
86-
87- _rows = update_cross_edges (cg , node , _cx_edges_d , node_ts , end_ts , timestamps_d )
88- if is_stale :
89105 row_id = serializers .serialize_uint64 (node )
90106 val_dict = {Hierarchy .StaleTimeStamp : 0 }
91- _rows .append (cg .client .mutate_row (row_id , val_dict , time_stamp = end_ts ))
92- rows .extend (_rows )
93-
94- return rows
107+ rows .append (cg .client .mutate_row (row_id , val_dict , time_stamp = end_ts ))
95108
109+ if not _cx_edges_d :
110+ skipped .append (node )
111+ continue
96112
97- def _update_nodes_helper (args ):
98- cg , nodes , nodes_ts = args
99- return update_nodes (cg , nodes , nodes_ts )
113+ _rows = update_cross_edges (
114+ cg , node , _cx_edges_d , node_ts , end_ts , timestamps_d , parents_ts_map
115+ )
116+ rows .extend (_rows )
117+ parents = cg .get_roots (skipped )
118+ layers = cg .get_chunk_layers (parents )
119+ assert np .all (layers == cg .meta .layer_count )
120+ return rows
100121
101122
102- def update_chunk (cg : ChunkedGraph , chunk_coords : list [int ], debug : bool = False ):
123+ def update_chunk (cg : ChunkedGraph , chunk_coords : list [int ]):
103124 """
104125 Iterate over all L2 IDs in a chunk and update their cross chunk edges,
105126 within the periods they were valid/active.
@@ -132,23 +153,6 @@ def update_chunk(cg: ChunkedGraph, chunk_coords: list[int], debug: bool = False)
132153 else :
133154 return
134155
135- if debug :
136- rows = update_nodes (cg , nodes , nodes_ts )
137- else :
138- task_size = int (math .ceil (len (nodes ) / 16 ))
139- chunked_nodes = chunked (nodes , task_size )
140- chunked_nodes_ts = chunked (nodes_ts , task_size )
141- tasks = []
142- for chunk , ts_chunk in zip (chunked_nodes , chunked_nodes_ts ):
143- args = (cg , chunk , ts_chunk )
144- tasks .append (args )
145- logging .info (f"task size { task_size } , count { len (tasks )} ." )
146-
147- rows = []
148- with ThreadPoolExecutor (max_workers = 8 ) as executor :
149- futures = [executor .submit (_update_nodes_helper , task ) for task in tasks ]
150- for future in tqdm (as_completed (futures ), total = len (futures )):
151- rows .extend (future .result ())
152-
156+ rows = update_nodes (cg , nodes , nodes_ts )
153157 cg .client .write (rows )
154- logging .info (f"total elaspsed time: { time .time () - start } " )
158+ logging .info (f"mutations: { len ( rows ) } , time: { time .time () - start } " )
0 commit comments