Skip to content

Commit 4f96f34

Browse files
committed
Implement graph.drop for arrow endpoints
1 parent 99d2b3b commit 4f96f34

File tree

3 files changed

+37
-8
lines changed

3 files changed

+37
-8
lines changed

graphdatascience/procedure_surface/api/catalog_endpoints.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,19 @@ def list(self, G: Optional[Union[Graph, str]] = None) -> List[GraphListResult]:
2626
"""
2727
pass
2828

29+
@abstractmethod
30+
def drop(self, G: Union[Graph, str], fail_if_missing: Optional[bool] = None) -> GraphListResult:
31+
"""Drop a graph from the graph catalog.
32+
33+
Args:
34+
G (Union[Graph, str]): Graph object or name to drop.
35+
fail_if_missing (Optional[bool], optional): Whether to fail if the graph is missing. Defaults to None.
36+
37+
Returns:
38+
GraphListResult: Graph metadata object containing information like
39+
graph name, node count, relationship count, etc.
40+
"""
41+
2942
@abstractmethod
3043
def filter(
3144
self,
@@ -68,7 +81,7 @@ class GraphListResult(BaseModel):
6881
modification_time: datetime
6982
graph_schema: dict[str, Any] = Field(alias="schema")
7083
schema_with_orientation: dict[str, Any]
71-
degree_distribution: dict[str, Any]
84+
degree_distribution: Optional[dict[str, Any]] = None
7285

7386
@field_validator("creation_time", "modification_time", mode="before")
7487
@classmethod

graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from graphdatascience import Graph, QueryRunner
1111
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
12-
from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize
12+
from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize, deserialize_single
1313
from graphdatascience.arrow_client.v2.job_client import JobClient
1414
from graphdatascience.procedure_surface.api.catalog_endpoints import (
1515
CatalogEndpoints,
@@ -86,6 +86,12 @@ def project(
8686

8787
return ProjectionResult(**JobClient.get_summary(self._arrow_client, job_id))
8888

89+
def drop(self, G: Union[Graph, str], fail_if_missing: Optional[bool] = None) -> GraphListResult:
90+
graph_name = G if isinstance(G, str) else G.name()
91+
config = ConfigConverter.convert_to_gds_config(graphName=graph_name, failIfMissing=fail_if_missing)
92+
result = self._arrow_client.do_action_with_retry("v2/graph.drop", json.dumps(config).encode("utf-8"))
93+
return GraphListResult(**deserialize_single(result))
94+
8995
def filter(
9096
self,
9197
G: Graph,

graphdatascience/tests/integrationV2/procedure_surface/arrow/test_catalog_arrow_endpoints.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Generator
44

55
import pytest
6+
from pyarrow import ArrowKeyError
67

78
from graphdatascience import Graph, QueryRunner
89
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
@@ -29,7 +30,6 @@ def catalog_endpoints(arrow_client: AuthenticatedArrowClient) -> Generator[Catal
2930

3031

3132
def test_list_with_graph(catalog_endpoints: CatalogArrowEndpoints, sample_graph: Graph) -> None:
32-
"""Test listing graphs with a specific graph."""
3333
results = catalog_endpoints.list(G=sample_graph)
3434

3535
assert len(results) == 1
@@ -46,7 +46,7 @@ def test_list_with_graph(catalog_endpoints: CatalogArrowEndpoints, sample_graph:
4646
assert "KiB" in result.memory_usage
4747
assert result.size_in_bytes > 20000
4848
assert result.modification_time < datetime.datetime.now(datetime.timezone.utc)
49-
assert "p50" in result.degree_distribution
49+
assert "p50" in result.degree_distribution # type: ignore
5050

5151

5252
def test_list_without_graph(
@@ -60,6 +60,18 @@ def test_list_without_graph(
6060
assert set(g.graph_name for g in result) == {sample_graph.name(), g2.name()}
6161

6262

63+
def test_drop(catalog_endpoints: CatalogArrowEndpoints, sample_graph: Graph) -> None:
64+
res = catalog_endpoints.drop(sample_graph)
65+
66+
assert res.graph_name == sample_graph.name()
67+
assert len(catalog_endpoints.list()) == 0
68+
69+
70+
def test_drop_nonexistent(catalog_endpoints: CatalogArrowEndpoints) -> None:
71+
with pytest.raises(ArrowKeyError, match="does not exist on database"):
72+
catalog_endpoints.drop("nonexistent", fail_if_missing=True)
73+
74+
6375
def test_projection(arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner) -> None:
6476
try:
6577
endpoints = CatalogArrowEndpoints(arrow_client, query_runner)
@@ -75,7 +87,7 @@ def test_projection(arrow_client: AuthenticatedArrowClient, query_runner: QueryR
7587

7688
assert len(endpoints.list("g")) == 1
7789
finally:
78-
arrow_client.do_action("v2/graph.drop", json.dumps({"graphName": "g"}).encode("utf-8"))
90+
endpoints.drop("g")
7991

8092

8193
def test_graph_filter(catalog_endpoints: CatalogArrowEndpoints, sample_graph: Graph) -> None:
@@ -90,6 +102,4 @@ def test_graph_filter(catalog_endpoints: CatalogArrowEndpoints, sample_graph: Gr
90102
assert result.graph_name == "filtered"
91103
assert result.project_millis >= 0
92104
finally:
93-
catalog_endpoints._arrow_client.do_action(
94-
"v2/graph.drop", json.dumps({"graphName": "filtered"}).encode("utf-8")
95-
)
105+
catalog_endpoints.drop("filtered")

0 commit comments

Comments
 (0)