Skip to content

Catalog Endpoints Part 1 #926

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
7b1bb98
Implement graph.list for arrow
DarthMax Jul 29, 2025
3d97f96
Implement remote projection for arrow endpoints
DarthMax Jul 31, 2025
37f6d3f
Implement graph.filter for arrow endpoints
DarthMax Jul 31, 2025
0dfc1aa
Implement graph.drop for arrow endpoints
DarthMax Jul 31, 2025
d69d58b
AuthenticatedArrowClient accepts dictionaries as action payload
DarthMax Jul 31, 2025
f5bf427
Minor cleanups
DarthMax Jul 31, 2025
0646005
Introduce shared GdsBaseModel
DarthMax Jul 31, 2025
221c983
Fix arrow test cleanup
DarthMax Aug 1, 2025
ca21bd3
Expose correct env variables
DarthMax Aug 1, 2025
81ac881
Try to fix projection tests
DarthMax Aug 1, 2025
548a7a3
Try to fix projection tests
DarthMax Aug 1, 2025
7fa5945
Try to fix projection tests
DarthMax Aug 1, 2025
de20406
Try to fix projection tests
DarthMax Aug 1, 2025
690ca3c
Try to fix projection tests
DarthMax Aug 1, 2025
27d37bc
Try to fix projection tests
DarthMax Aug 1, 2025
d524e86
Try to fix projection tests
DarthMax Aug 4, 2025
c67f23c
Try to fix projection tests
DarthMax Aug 4, 2025
b5513fe
Try to fix projection tests
DarthMax Aug 4, 2025
21cf2cf
Try to fix projection tests
DarthMax Aug 4, 2025
c2c3ae5
Try to fix projection tests
DarthMax Aug 4, 2025
d256325
Try to fix projection tests
DarthMax Aug 4, 2025
8604ca2
Try to fix projection tests
DarthMax Aug 4, 2025
75f6a2e
Try to fix projection tests
DarthMax Aug 4, 2025
3686581
Try to fix projection tests
DarthMax Aug 4, 2025
61d04e4
Try to fix projection tests
DarthMax Aug 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions graphdatascience/arrow_client/arrow_base_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from typing import Any

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel
from pydantic.alias_generators import to_camel


class ArrowBaseModel(BaseModel):
model_config = ConfigDict(alias_generator=to_camel)

class ArrowBaseModel(BaseModel, alias_generator=to_camel):
def dump_camel(self) -> dict[str, Any]:
return self.model_dump(by_alias=True)

Expand Down
32 changes: 28 additions & 4 deletions graphdatascience/arrow_client/authenticated_flight_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import json
import logging
from dataclasses import dataclass
from typing import Any, Iterator, Optional
from typing import Any, Iterator, Optional, Tuple, Union

from pyarrow import __version__ as arrow_version
from pyarrow import flight
Expand Down Expand Up @@ -36,6 +37,7 @@ def create(
arrow_client_options: Optional[dict[str, Any]] = None,
connection_string_override: Optional[str] = None,
retry_config: Optional[RetryConfig] = None,
advertised_listen_address: Optional[Tuple[str, int]] = None,
) -> AuthenticatedArrowClient:
connection_string: str
if connection_string_override is not None:
Expand Down Expand Up @@ -63,6 +65,7 @@ def create(
auth=auth,
encrypted=encrypted,
arrow_client_options=arrow_client_options,
advertised_listen_address=advertised_listen_address,
)

def __init__(
Expand All @@ -74,6 +77,7 @@ def __init__(
encrypted: bool = False,
arrow_client_options: Optional[dict[str, Any]] = None,
user_agent: Optional[str] = None,
advertised_listen_address: Optional[Tuple[str, int]] = None,
):
"""Creates a new AuthenticatedArrowClient instance.

Expand All @@ -93,6 +97,8 @@ def __init__(
The user agent string to use for the connection. (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION])
retry_config: Optional[RetryConfig]
The retry configuration to use for the Arrow requests sent by the client.
advertised_listen_address: Optional[Tuple[str, int]]
The advertised listen address of the GDS Arrow server. This will be used by remote projection and writeback operations.
"""
self._host = host
self._port = port
Expand All @@ -106,6 +112,7 @@ def __init__(
if auth:
self._auth = auth
self._auth_middleware = AuthMiddleware(auth)
self.advertised_listen_address = advertised_listen_address

self._flight_client = self._instantiate_flight_client()

Expand All @@ -120,6 +127,21 @@ def connection_info(self) -> ConnectionInfo:
"""
return ConnectionInfo(self._host, self._port, self._encrypted)

def advertised_connection_info(self) -> ConnectionInfo:
"""
Returns the advertised host and port of the GDS Arrow server, falling back to the regular connection info when no advertised listen address was configured.

Returns
-------
ConnectionInfo
the host and port of the GDS Arrow server
"""
if self.advertised_listen_address is None:
return self.connection_info()

h, p = self.advertised_listen_address
return ConnectionInfo(h, p, self._encrypted)

def request_token(self) -> Optional[str]:
"""
Requests a token from the server and returns it.
Expand Down Expand Up @@ -152,10 +174,12 @@ def auth_with_retry() -> None:
def get_stream(self, ticket: Ticket) -> FlightStreamReader:
return self._flight_client.do_get(ticket)

def do_action(self, endpoint: str, payload: bytes) -> Iterator[Result]:
return self._flight_client.do_action(Action(endpoint, payload)) # type: ignore
def do_action(self, endpoint: str, payload: Union[bytes, dict[str, Any]]) -> Iterator[Result]:
payload_bytes = payload if isinstance(payload, bytes) else json.dumps(payload).encode("utf-8")

return self._flight_client.do_action(Action(endpoint, payload_bytes)) # type: ignore

def do_action_with_retry(self, endpoint: str, payload: bytes) -> Iterator[Result]:
def do_action_with_retry(self, endpoint: str, payload: Union[bytes, dict[str, Any]]) -> Iterator[Result]:
@retry(
reraise=True,
before=before_log("Send action", self._logger, logging.DEBUG),
Expand Down
13 changes: 4 additions & 9 deletions graphdatascience/arrow_client/v2/job_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,22 @@ def run_job_and_wait(client: AuthenticatedArrowClient, endpoint: str, config: di

@staticmethod
def run_job(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, Any]) -> str:
encoded_config = json.dumps(config).encode("utf-8")
res = client.do_action_with_retry(endpoint, encoded_config)
res = client.do_action_with_retry(endpoint, config)

single = deserialize_single(res)
return JobIdConfig(**single).job_id

@staticmethod
def wait_for_job(client: AuthenticatedArrowClient, job_id: str) -> None:
while True:
encoded_config = JobIdConfig(jobId=job_id).dump_json().encode("utf-8")

arrow_res = client.do_action_with_retry(JOB_STATUS_ENDPOINT, encoded_config)
arrow_res = client.do_action_with_retry(JOB_STATUS_ENDPOINT, JobIdConfig(jobId=job_id).dump_camel())
job_status = JobStatus(**deserialize_single(arrow_res))
if job_status.status == "Done":
break

@staticmethod
def get_summary(client: AuthenticatedArrowClient, job_id: str) -> dict[str, Any]:
encoded_config = JobIdConfig(jobId=job_id).dump_json().encode("utf-8")

res = client.do_action_with_retry(RESULTS_SUMMARY_ENDPOINT, encoded_config)
res = client.do_action_with_retry(RESULTS_SUMMARY_ENDPOINT, JobIdConfig(jobId=job_id).dump_camel())
return deserialize_single(res)

@staticmethod
Expand All @@ -51,7 +46,7 @@ def stream_results(client: AuthenticatedArrowClient, graph_name: str, job_id: st
"jobId": job_id,
}

res = client.do_action_with_retry("v2/results.stream", json.dumps(payload).encode("utf-8"))
res = client.do_action_with_retry("v2/results.stream", payload)
export_job_id = JobIdConfig(**deserialize_single(res)).job_id

stream_payload = {"version": "v2", "name": export_job_id, "body": {}}
Expand Down
4 changes: 1 addition & 3 deletions graphdatascience/arrow_client/v2/mutation_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import math
import time

Expand All @@ -13,8 +12,7 @@ class MutationClient:
@staticmethod
def mutate_node_property(client: AuthenticatedArrowClient, job_id: str, mutate_property: str) -> MutateResult:
mutate_config = {"jobId": job_id, "mutateProperty": mutate_property}
encoded_config = json.dumps(mutate_config).encode("utf-8")
start_time = time.time()
mutate_arrow_res = client.do_action_with_retry(MutationClient.MUTATE_ENDPOINT, encoded_config)
mutate_arrow_res = client.do_action_with_retry(MutationClient.MUTATE_ENDPOINT, mutate_config)
mutate_millis = math.ceil((time.time() - start_time) * 1000)
return MutateResult(mutateMillis=mutate_millis, **deserialize_single(mutate_arrow_res))
99 changes: 99 additions & 0 deletions graphdatascience/procedure_surface/api/catalog_endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import annotations

import re
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, List, Optional, Union

from pydantic import Field, field_validator

from graphdatascience import Graph
from graphdatascience.procedure_surface.utils.GdsBaseModel import GdsBaseModel


class CatalogEndpoints(ABC):
    """Abstract interface for graph catalog operations (list, drop, filter).

    Concrete subclasses implement these operations over a specific transport
    (e.g. Cypher procedures or Arrow endpoints).
    """

    @abstractmethod
    def list(self, G: Optional[Union[Graph, str]] = None) -> List[GraphListResult]:
        """List graphs in the graph catalog.

        Args:
            G (Optional[Union[Graph, str]], optional): Graph object or name to filter results.
                If None, list all graphs. Defaults to None.

        Returns:
            List[GraphListResult]: List of graph metadata objects containing information like
                graph name, node count, relationship count, etc.
        """
        pass

    @abstractmethod
    def drop(self, G: Union[Graph, str], fail_if_missing: Optional[bool] = None) -> Optional[GraphListResult]:
        """Drop a graph from the graph catalog.

        Args:
            G (Union[Graph, str]): Graph object or name to drop.
            fail_if_missing (Optional[bool], optional): Whether to fail if the graph is missing.
                Defaults to None (implementation-defined default).

        Returns:
            Optional[GraphListResult]: Graph metadata object containing information like
                graph name, node count, relationship count, etc. May be None
                (presumably when the graph was missing and ``fail_if_missing``
                permits that — confirm against implementations).
        """

    @abstractmethod
    def filter(
        self,
        G: Graph,
        graph_name: str,
        node_filter: str,
        relationship_filter: str,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
    ) -> GraphFilterResult:
        """Create a subgraph of a graph based on a filter expression.

        Args:
            G (Graph): Graph object to filter on
            graph_name (str): Name of subgraph to create
            node_filter (str): Filter expression for nodes
            relationship_filter (str): Filter expression for relationships
            concurrency (Optional[int], optional): Number of concurrent threads to use. Defaults to None.
            job_id (Optional[str], optional): Unique identifier for the filtering job. Defaults to None.

        Returns:
            GraphFilterResult: Filter result containing information like
                graph name, node count, relationship count, etc.
        """
        pass


Comment on lines +67 to +68
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no project?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, project is very implementation dependent: for Cypher+Plugin it will be native projection, for Arrow+Session remote projection, and for Arrow+Plugin it is unclear — that is why I skipped it.

class GraphListResult(GdsBaseModel):
    """Metadata describing one graph in the graph catalog.

    Field names are snake_case here; the server payload uses camelCase
    (handled by the GdsBaseModel alias generator).
    """

    graph_name: str
    database: str
    database_location: str
    configuration: dict[str, Any]
    memory_usage: str
    size_in_bytes: int
    node_count: int
    relationship_count: int
    creation_time: datetime
    modification_time: datetime
    # The server sends this field under the key "schema"; aliased so the
    # attribute name does not collide with pydantic's BaseModel API.
    graph_schema: dict[str, Any] = Field(alias="schema")
    schema_with_orientation: dict[str, Any]
    # Only present when the degree distribution was computed; otherwise None.
    degree_distribution: Optional[dict[str, Any]] = None

    @field_validator("creation_time", "modification_time", mode="before")
    @classmethod
    def strip_timezone(cls, value: Any) -> Any:
        # Drop a trailing bracketed suffix from timestamp strings (presumably a
        # zone id like "[Europe/Berlin]" appended by the server — TODO confirm)
        # so the remainder parses as a datetime. Non-strings pass through.
        if isinstance(value, str):
            return re.sub(r"\[.*\]$", "", value)
        return value


class GraphFilterResult(GdsBaseModel):
    """Result of a graph filter operation (a subgraph projected from an existing graph)."""

    # Name of the newly created (filtered) graph.
    graph_name: str
    # Name of the graph the filter was applied to.
    from_graph_name: str
    node_filter: str
    relationship_filter: str
    node_count: int
    relationship_count: int
    # Time spent projecting the filtered graph, in milliseconds.
    project_millis: int
10 changes: 2 additions & 8 deletions graphdatascience/procedure_surface/api/estimation_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@

from typing import Any

from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel
from graphdatascience.procedure_surface.utils.GdsBaseModel import GdsBaseModel


class EstimationResult(BaseModel):
model_config = ConfigDict(alias_generator=to_camel)

class EstimationResult(GdsBaseModel):
node_count: int
relationship_count: int
required_memory: str
Expand All @@ -19,9 +16,6 @@ class EstimationResult(BaseModel):
heap_percentage_min: float
heap_percentage_max: float

def __getitem__(self, item: str) -> Any:
return getattr(self, item)

@staticmethod
def from_cypher(cypher_result: dict[str, Any]) -> EstimationResult:
return EstimationResult(**cypher_result)
24 changes: 4 additions & 20 deletions graphdatascience/procedure_surface/api/k1coloring_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
from typing import Any, List, Optional

from pandas import DataFrame
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel

from ...graph.graph_object import Graph
from ..utils.GdsBaseModel import GdsBaseModel
from .estimation_result import EstimationResult


Expand Down Expand Up @@ -245,9 +244,7 @@ def estimate(
pass


class K1ColoringMutateResult(BaseModel):
model_config = ConfigDict(alias_generator=to_camel)

class K1ColoringMutateResult(GdsBaseModel):
node_count: int
color_count: int
ran_iterations: int
Expand All @@ -257,13 +254,8 @@ class K1ColoringMutateResult(BaseModel):
mutate_millis: int
configuration: dict[str, Any]

def __getitem__(self, item: str) -> Any:
return getattr(self, item)


class K1ColoringStatsResult(BaseModel):
model_config = ConfigDict(alias_generator=to_camel)

class K1ColoringStatsResult(GdsBaseModel):
node_count: int
color_count: int
ran_iterations: int
Expand All @@ -272,13 +264,8 @@ class K1ColoringStatsResult(BaseModel):
compute_millis: int
configuration: dict[str, Any]

def __getitem__(self, item: str) -> Any:
return getattr(self, item)


class K1ColoringWriteResult(BaseModel):
model_config = ConfigDict(alias_generator=to_camel)

class K1ColoringWriteResult(GdsBaseModel):
node_count: int
color_count: int
ran_iterations: int
Expand All @@ -287,6 +274,3 @@ class K1ColoringWriteResult(BaseModel):
compute_millis: int
write_millis: int
configuration: dict[str, Any]

def __getitem__(self, item: str) -> Any:
return getattr(self, item)
Loading