Skip to content

Commit 99d2b3b

Browse files
committed
Implement graph.filter for arrow endpoints
1 parent 26a9586 commit 99d2b3b

File tree

3 files changed

+96
-4
lines changed

3 files changed

+96
-4
lines changed

graphdatascience/procedure_surface/api/catalog_endpoints.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,42 @@
1414
class CatalogEndpoints(ABC):
    """Abstract interface for graph catalog operations.

    Concrete subclasses must implement every method declared here.
    """

    @abstractmethod
    def list(self, G: Optional[Union[Graph, str]] = None) -> List[GraphListResult]:
        """List graphs in the graph catalog.

        Args:
            G: Graph object or graph name used to narrow the listing.
                When ``None`` (the default), all graphs are listed.

        Returns:
            One metadata object per listed graph, carrying information
            such as graph name, node count and relationship count.
        """

    @abstractmethod
    def filter(
        self,
        G: Graph,
        graph_name: str,
        node_filter: str,
        relationship_filter: str,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
    ) -> GraphFilterResult:
        """Create a subgraph of a graph based on filter expressions.

        Args:
            G: Graph object to filter on.
            graph_name: Name of the subgraph to create.
            node_filter: Filter expression applied to nodes.
            relationship_filter: Filter expression applied to relationships.
            concurrency: Number of concurrent threads to use.
                Defaults to ``None``.
            job_id: Unique identifier for the filtering job.
                Defaults to ``None``.

        Returns:
            Filter result containing information such as graph name,
            node count and relationship count.
        """
1854

1955

@@ -40,3 +76,15 @@ def strip_timezone(cls, value: Any) -> Any:
4076
if isinstance(value, str):
4177
return re.sub(r"\[.*\]$", "", value)
4278
return value
79+
80+
81+
class GraphFilterResult(BaseModel):
    """Result model returned by a graph filter operation.

    Field aliases are generated with ``to_camel``, so each snake_case
    field below is populated from its camelCase counterpart
    (for example ``from_graph_name`` from ``fromGraphName``).
    """

    model_config = ConfigDict(alias_generator=to_camel)

    # Name of the graph produced by the filter and of the graph it came from.
    graph_name: str
    from_graph_name: str

    # Filter expressions that were applied.
    node_filter: str
    relationship_filter: str

    # Size counters reported for the result.
    node_count: int
    relationship_count: int

    # Reported projection time; the field name indicates milliseconds.
    project_millis: int

graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
1212
from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize
1313
from graphdatascience.arrow_client.v2.job_client import JobClient
14-
from graphdatascience.procedure_surface.api.catalog_endpoints import CatalogEndpoints, GraphListResult
14+
from graphdatascience.procedure_surface.api.catalog_endpoints import (
15+
CatalogEndpoints,
16+
GraphFilterResult,
17+
GraphListResult,
18+
)
19+
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
1520
from graphdatascience.query_runner.protocol.project_protocols import ProjectProtocol
1621
from graphdatascience.query_runner.termination_flag import TerminationFlag
1722
from graphdatascience.session.dbms.protocol_resolver import ProtocolVersionResolver
@@ -81,6 +86,28 @@ def project(
8186

8287
return ProjectionResult(**JobClient.get_summary(self._arrow_client, job_id))
8388

89+
def filter(
    self,
    G: Graph,
    graph_name: str,
    node_filter: str,
    relationship_filter: str,
    concurrency: Optional[int] = None,
    job_id: Optional[str] = None,
) -> GraphFilterResult:
    """Create a filtered subgraph by running the ``v2/graph.project.filter`` job.

    Args:
        G: Source graph; its name is sent as ``fromGraphName``.
        graph_name: Name of the subgraph to create.
        node_filter: Filter expression for nodes.
        relationship_filter: Filter expression for relationships.
        concurrency: Optional concurrency setting forwarded in the job config.
        job_id: Optional job identifier forwarded in the job config.

    Returns:
        A ``GraphFilterResult`` built from the summary of the finished job.
    """
    filter_config = ConfigConverter.convert_to_gds_config(
        fromGraphName=G.name(),
        graphName=graph_name,
        nodeFilter=node_filter,
        relationshipFilter=relationship_filter,
        concurrency=concurrency,
        jobId=job_id,
    )

    # The summary is looked up with the id returned by the job client,
    # not with the id the caller passed in.
    finished_job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.project.filter", filter_config)
    summary = JobClient.get_summary(self._arrow_client, finished_job_id)

    return GraphFilterResult(**summary)
110+
84111
def _arrow_config(self) -> dict[str, Any]:
85112
connection_info = self._arrow_client.connection_info()
86113

graphdatascience/tests/integrationV2/procedure_surface/arrow/test_catalog_arrow_endpoints.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
@pytest.fixture
1414
def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[Graph, None, None]:
1515
gdl = """
16-
(a: Node)
17-
(b: Node)
18-
(c: Node)
16+
(a :Node:A)
17+
(b :Node:A)
18+
(c :Node:B)
1919
(a)-[:REL]->(c)
2020
"""
2121

@@ -76,3 +76,20 @@ def test_projection(arrow_client: AuthenticatedArrowClient, query_runner: QueryR
7676
assert len(endpoints.list("g")) == 1
7777
finally:
7878
arrow_client.do_action("v2/graph.drop", json.dumps({"graphName": "g"}).encode("utf-8"))
79+
80+
81+
def test_graph_filter(catalog_endpoints: CatalogArrowEndpoints, sample_graph: Graph) -> None:
    """Filter the sample graph down to ``A``-labelled nodes and check the result."""
    subgraph_name = "filtered"

    try:
        result = catalog_endpoints.filter(
            sample_graph,
            graph_name=subgraph_name,
            node_filter="n:A",
            relationship_filter="*",
        )

        # Two sample nodes carry label A; the only relationship ends at a
        # B-labelled node, so no relationship is expected in the subgraph.
        assert result.node_count == 2
        assert result.relationship_count == 0
        assert result.from_graph_name == sample_graph.name()
        assert result.graph_name == subgraph_name
        assert result.project_millis >= 0
    finally:
        # Always drop the subgraph so later tests do not see it.
        drop_payload = json.dumps({"graphName": subgraph_name}).encode("utf-8")
        catalog_endpoints._arrow_client.do_action("v2/graph.drop", drop_payload)

0 commit comments

Comments
 (0)