1
1
# pylint: disable=invalid-name, missing-docstring, c-extension-no-member
2
2
3
- from datetime import timedelta
4
-
5
3
import fastremap
6
4
import numpy as np
7
5
from pychunkedgraph .graph import ChunkedGraph
8
- from pychunkedgraph .graph .attributes import Connectivity
6
+ from pychunkedgraph .graph .attributes import Connectivity , Hierarchy
9
7
from pychunkedgraph .graph .utils import serializers
10
8
11
- from .utils import exists_as_parent , get_parent_timestamps
9
+ from .utils import exists_as_parent , get_end_timestamps , get_parent_timestamps
10
+
11
+ CHILDREN = {}
12
12
13
13
14
14
def update_cross_edges (
15
- cg : ChunkedGraph , node , cx_edges_d : dict , node_ts , timestamps : set , earliest_ts
15
+ cg : ChunkedGraph , node , cx_edges_d : dict , node_ts , node_end_ts , timestamps : set
16
16
) -> list :
17
17
"""
18
18
Helper function to update a single L2 ID.
@@ -27,13 +27,15 @@ def update_cross_edges(
27
27
assert not exists_as_parent (cg , node , edges [:, 0 ])
28
28
return rows
29
29
30
- partner_parent_ts_d = get_parent_timestamps (cg , edges [:, 1 ])
30
+ partner_parent_ts_d = get_parent_timestamps (cg , np . unique ( edges [:, 1 ]) )
31
31
for v in partner_parent_ts_d .values ():
32
32
timestamps .update (v )
33
33
34
34
for ts in sorted (timestamps ):
35
- if ts < earliest_ts :
36
- ts = earliest_ts
35
+ if ts < node_ts :
36
+ continue
37
+ if ts > node_end_ts :
38
+ break
37
39
val_dict = {}
38
40
svs = edges [:, 1 ]
39
41
parents = cg .get_parents (svs , time_stamp = ts )
@@ -51,21 +53,22 @@ def update_cross_edges(
51
53
return rows
52
54
53
55
54
- def update_nodes (cg : ChunkedGraph , nodes ) -> list :
55
- nodes_ts = cg .get_node_timestamps (nodes , return_numpy = False , normalize = True )
56
- earliest_ts = cg .get_earliest_timestamp ()
56
+ def update_nodes (cg : ChunkedGraph , nodes , nodes_ts , children_map = None ) -> list :
57
+ if children_map is None :
58
+ children_map = CHILDREN
59
+ end_timestamps = get_end_timestamps (cg , nodes , nodes_ts , children_map )
57
60
timestamps_d = get_parent_timestamps (cg , nodes )
58
61
cx_edges_d = cg .get_atomic_cross_edges (nodes )
59
62
rows = []
60
- for node , node_ts in zip (nodes , nodes_ts ):
63
+ for node , node_ts , end_ts in zip (nodes , nodes_ts , end_timestamps ):
61
64
if cg .get_parent (node ) is None :
62
- # invalid id caused by failed ingest task
65
+ # invalid id caused by failed ingest task / edits
63
66
continue
64
67
_cx_edges_d = cx_edges_d .get (node , {})
65
68
if not _cx_edges_d :
66
69
continue
67
70
_rows = update_cross_edges (
68
- cg , node , _cx_edges_d , node_ts , timestamps_d [node ], earliest_ts
71
+ cg , node , _cx_edges_d , node_ts , end_ts , timestamps_d [node ]
69
72
)
70
73
rows .extend (_rows )
71
74
return rows
@@ -76,10 +79,26 @@ def update_chunk(cg: ChunkedGraph, chunk_coords: list[int], layer: int = 2):
76
79
Iterate over all L2 IDs in a chunk and update their cross chunk edges,
77
80
within the periods they were valid/active.
78
81
"""
82
+ global CHILDREN
83
+
79
84
x , y , z = chunk_coords
80
85
chunk_id = cg .get_chunk_id (layer = layer , x = x , y = y , z = z )
81
86
cg .copy_fake_edges (chunk_id )
82
87
rr = cg .range_read_chunk (chunk_id )
83
- nodes = list (rr .keys ())
84
- rows = update_nodes (cg , nodes )
88
+
89
+ nodes = []
90
+ nodes_ts = []
91
+ earliest_ts = cg .get_earliest_timestamp ()
92
+ for k , v in rr .items ():
93
+ nodes .append (k )
94
+ CHILDREN [k ] = v [Hierarchy .Child ][0 ].value
95
+ ts = v [Hierarchy .Child ][0 ].timestamp
96
+ nodes_ts .append (earliest_ts if ts < earliest_ts else ts )
97
+
98
+ if len (nodes ):
99
+ assert len (CHILDREN ) > 0 , (nodes , CHILDREN )
100
+ else :
101
+ return
102
+
103
+ rows = update_nodes (cg , nodes , nodes_ts )
85
104
cg .client .write (rows )
0 commit comments