Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions benchmarks/graph_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
N_LINEAGES = 50


def _noop(*_args) -> None:
"""No-op slot to enable the `node_updated` signal-payload path in benchmarks."""


def _build_graph(backend_name: str, n_nodes: int) -> td.graph.BaseGraph:
graph = BACKENDS[backend_name]()
graph.add_node_attr_key("score", dtype=pl.Float64)
Expand Down Expand Up @@ -60,6 +64,16 @@ def setup(self, backend_name: str, n_nodes: int) -> None:
self.removal_targets = all_ids[:N_OPS]
self.update_targets = all_ids[: N_OPS * 4]

# Separate view with a no-op listener attached. Without a listener,
# update_node_attrs skips the signal-payload computation entirely, so
# the P2-2 optimization (deriving new_attrs from old + applied) isn't
# exercised. This view is the BBoxSpatialFilter / GraphArrayView use case.
self.listened_view = self.graph.filter().subgraph()
self.listened_view.node_updated.connect(_noop)
# Smaller batch, representative of interactive editing where the saved
# query overhead is a larger fraction of the total work.
self.listener_update_targets = all_ids[:N_OPS]

# --- remove_node ------------------------------------------------------

def time_remove_node_root(self, backend_name: str, n_nodes: int) -> None:
Expand All @@ -78,6 +92,9 @@ def time_update_node_attrs_root(self, backend_name: str, n_nodes: int) -> None:
def time_update_node_attrs_view(self, backend_name: str, n_nodes: int) -> None:
self.view.update_node_attrs(node_ids=self.update_targets, attrs={"score": 1.0})

def time_update_node_attrs_view_with_listener(self, backend_name: str, n_nodes: int) -> None:
self.listened_view.update_node_attrs(node_ids=self.listener_update_targets, attrs={"score": 1.0})

# --- filter (standalone, materialized to ids) ------------------------

def time_filter_node_ids(self, backend_name: str, n_nodes: int) -> None:
Expand Down
36 changes: 20 additions & 16 deletions src/tracksdata/graph/_graph_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Literal, cast, overload

import bidict
import numpy as np
import polars as pl
import rustworkx as rx

Expand Down Expand Up @@ -465,23 +466,20 @@ def remove_node(self, node_id: int) -> None:
# Get the local node ID and remove from local graph
local_node_id = self._external_to_local[node_id]

# Capture incident edges BEFORE removal. rustworkx drops them along with
# the node; afterwards we'd have no way to identify which entries to
# clean from `_edge_map_to_root` without scanning the whole bookkeeping.
incident_local_edge_ids = list(self.rx_graph.incident_edges(local_node_id))

with self.node_removed.blocked():
super().remove_node(local_node_id)

# Remove the node mapping
self._remove_id_mapping(external_id=node_id)

# Update edge mappings - remove edges involving this node
edges_to_remove = []
edge_indices = self.rx_graph.edge_indices()
for local_edge_id, _ in list(self._edge_map_to_root.items()):
# Check if this edge is still in the local graph
if local_edge_id not in edge_indices:
edges_to_remove.append(local_edge_id)

for edge_id in edges_to_remove:
if edge_id in self._edge_map_to_root:
del self._edge_map_to_root[edge_id]
# Drop just the affected edges from the bookkeeping
for edge_id in incident_local_edge_ids:
self._edge_map_to_root.pop(edge_id, None)
else:
self._out_of_sync = True

Expand Down Expand Up @@ -740,12 +738,18 @@ def update_node_attrs(
self._out_of_sync = True

if view_signal_on or root_signal_on:
new_attrs_by_id = (
self._root.filter(node_ids=node_ids)
.node_attrs(attr_keys=signal_keys)
.rows_by_key(key=DEFAULT_ATTR_KEYS.NODE_ID, named=True, unique=True, include_key=True)
)
old_attrs_by_id = cast(dict[int, dict[str, Any]], old_attrs_by_id) # for mypy
# Derive new_attrs by overlaying applied `attrs` onto old_attrs, instead of
# re-querying root. Mirrors the broadcasting semantics of
# `_root.update_node_attrs`: scalars apply to all nodes, sequences index by
# position in `node_ids`.
new_attrs_by_id: dict[int, dict[str, Any]] = {}
for i, node_id in enumerate(node_ids):
new_attrs = dict(old_attrs_by_id[node_id])
for k, v in attrs.items():
if k in new_attrs:
new_attrs[k] = v if np.isscalar(v) else v[i]
new_attrs_by_id[node_id] = new_attrs
if root_signal_on:
for node_id in node_ids:
self._root.node_updated.emit(
Expand Down
Loading