@@ -220,7 +220,7 @@ def get_stale_nodes(
220220 )
221221 stale_mask = layer_nodes != _nodes
222222 stale_nodes .append (layer_nodes [stale_mask ])
223- return np .concatenate (stale_nodes ), edge_supervoxels
223+ return np .concatenate (stale_nodes )
224224
225225
226226def get_latest_edges (
@@ -279,7 +279,7 @@ def _get_l2chunkids_along_boundary(max_layer, coord_a, coord_b):
279279 def _get_filtered_l2ids (node_a , node_b , chunks_map ):
280280 def _filter (node ):
281281 result = []
282- children = cg . get_children ( node )
282+ children = np . array ([ node ], dtype = basetypes . NODE_ID )
283283 while True :
284284 chunk_ids = cg .get_chunk_ids_from_node_ids (children )
285285 mask = np .isin (chunk_ids , chunks_map [node ])
@@ -296,24 +296,24 @@ def _filter(node):
296296
297297 return _filter (node_a ), _filter (node_b )
298298
299- result = []
299+ result = [types . empty_2d ]
300300 chunks_map = {}
301301 for edge_layer , _edge in zip (edge_layers , stale_edges ):
302302 node_a , node_b = _edge
303303 mlayer , coord_a , coord_b = _get_normalized_coords (node_a , node_b )
304304 chunks_a , chunks_b = _get_l2chunkids_along_boundary (mlayer , coord_a , coord_b )
305305
306- chunks_map [node_a ] = []
307- chunks_map [node_b ] = []
306+ chunks_map [node_a ] = [np . array ([ cg . get_chunk_id ( node_a )]) ]
307+ chunks_map [node_b ] = [np . array ([ cg . get_chunk_id ( node_b )]) ]
308308 _layer = 2
309309 while _layer < mlayer :
310310 chunks_map [node_a ].append (chunks_a )
311311 chunks_map [node_b ].append (chunks_b )
312312 chunks_a = np .unique (cg .get_parent_chunk_id_multiple (chunks_a ))
313313 chunks_b = np .unique (cg .get_parent_chunk_id_multiple (chunks_b ))
314314 _layer += 1
315- chunks_map [node_a ] = np .concatenate (chunks_map [node_a ])
316- chunks_map [node_b ] = np .concatenate (chunks_map [node_b ])
315+ chunks_map [node_a ] = np .concatenate (chunks_map [node_a ]). astype ( basetypes . NODE_ID )
316+ chunks_map [node_b ] = np .concatenate (chunks_map [node_b ]). astype ( basetypes . NODE_ID )
317317
318318 l2ids_a , l2ids_b = _get_filtered_l2ids (node_a , node_b , chunks_map )
319319 edges_d = cg .get_cross_chunk_edges (
@@ -326,32 +326,57 @@ def _filter(node):
326326 _edges = np .concatenate (_edges )
327327 mask = np .isin (_edges [:, 1 ], l2ids_b )
328328
329- children_a = cg .get_children (_edges [mask ][:, 0 ], flatten = True )
330329 children_b = cg .get_children (_edges [mask ][:, 1 ], flatten = True )
331- if 85431849467249595 in children_a and 85502218144317440 in children_b :
332- print ("woohoo0" )
333- continue
334-
335- if 85502218144317440 in children_a and 85431849467249595 in children_b :
336- print ("woohoo1" )
337- continue
338- parents_a = np .unique (
339- cg .get_roots (
340- children_a , stop_layer = mlayer , ceil = False , time_stamp = parent_ts
341- )
342- )
343- assert parents_a .size == 1 and parents_a [0 ] == node_a , (
344- node_a ,
345- parents_a ,
346- children_a ,
347- )
348330
331+ parents_a = _edges [mask ][:, 0 ]
332+ parents_b = np .unique (cg .get_parents (children_b , time_stamp = parent_ts ))
333+ _cx_edges_d = cg .get_cross_chunk_edges (parents_b )
334+ parents_b = []
335+ for _node , _edges_d in _cx_edges_d .items ():
336+ for _edges in _edges_d .values ():
337+ _mask = np .isin (_edges [:,1 ], parents_a )
338+ if np .any (_mask ):
339+ parents_b .append (_node )
340+
341+ parents_b = np .array (parents_b , dtype = basetypes .NODE_ID )
349342 parents_b = np .unique (
350343 cg .get_roots (
351- children_b , stop_layer = mlayer , ceil = False , time_stamp = parent_ts
344+ parents_b , stop_layer = mlayer , ceil = False , time_stamp = parent_ts
352345 )
353346 )
354347
355348 parents_a = np .array ([node_a ] * parents_b .size , dtype = basetypes .NODE_ID )
356349 result .append (np .column_stack ((parents_a , parents_b )))
357350 return np .concatenate (result )
351+
352+
353+ def get_latest_edges_wrapper (
354+ cg ,
355+ cx_edges_d : dict ,
356+ parent_ts : datetime .datetime = None ,
357+ ) -> np .ndarray :
358+ """Helper function to filter stale edges and replace with latest edges."""
359+ _cx_edges = [types .empty_2d ]
360+ _edge_layers = [types .empty_1d ]
361+ for k , v in cx_edges_d .items ():
362+ _cx_edges .append (v )
363+ _edge_layers .append ([k ] * len (v ))
364+ _cx_edges = np .concatenate (_cx_edges )
365+ _edge_layers = np .concatenate (_edge_layers , dtype = int )
366+
367+ edge_nodes = np .unique (_cx_edges )
368+ stale_nodes = get_stale_nodes (cg , edge_nodes , parent_ts = parent_ts )
369+ stale_nodes_mask = np .isin (edge_nodes , stale_nodes )
370+
371+ latest_edges = types .empty_2d .copy ()
372+ if np .any (stale_nodes_mask ):
373+ stalte_edges_mask = np .isin (_cx_edges [:, 1 ], stale_nodes )
374+ stale_edges = _cx_edges [stalte_edges_mask ]
375+ stale_edge_layers = _edge_layers [stalte_edges_mask ]
376+ latest_edges = get_latest_edges (
377+ cg ,
378+ stale_edges ,
379+ stale_edge_layers ,
380+ parent_ts = parent_ts ,
381+ )
382+ return np .concatenate ([_cx_edges , latest_edges ])
0 commit comments