Skip to content
59 changes: 59 additions & 0 deletions examples/python-runtime.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"DBID = \"beefbeef\"\n",
"ENVIRONMENT = \"\"\n",
"PASSWORD = \"\"\n",
"\n",
"from graphdatascience import GraphDataScience\n",
"\n",
"gds = GraphDataScience(f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD))\n",
"gds.set_database(\"neo4j\")\n",
"\n",
"gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" gds.graph.load_cora()\n",
"except Exception:\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gds.gnn.nodeClassification.predict(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
5 changes: 4 additions & 1 deletion graphdatascience/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .algo.single_mode_algo_endpoints import SingleModeAlgoEndpoints
from .call_builder import IndirectAlphaCallBuilder, IndirectBetaCallBuilder
from .gnn.gnn_endpoints import GnnEndpoints
from .graph.graph_endpoints import (
GraphAlphaEndpoints,
GraphBetaEndpoints,
Expand Down Expand Up @@ -32,7 +33,9 @@
"""


class DirectEndpoints(DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints):
class DirectEndpoints(
    DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints, GnnEndpoints
):
    """Aggregate of all directly callable endpoint namespaces (system, util,
    graph, pipeline, model and the experimental GNN endpoints)."""

    def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion):
        # Each mixin shares the same (query_runner, namespace, server_version)
        # state via cooperative super().__init__ chaining.
        super().__init__(query_runner, namespace, server_version)

Expand Down
Empty file.
18 changes: 18 additions & 0 deletions graphdatascience/gnn/gnn_endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from ..caller_base import CallerBase
from ..error.illegal_attr_checker import IllegalAttrChecker
from ..error.uncallable_namespace import UncallableNamespace
from .gnn_nc_runner import GNNNodeClassificationRunner


class GNNRunner(UncallableNamespace, IllegalAttrChecker):
    """Namespace holder for the experimental ``gds.gnn`` endpoints."""

    @property
    def nodeClassification(self) -> GNNNodeClassificationRunner:
        """Return the runner for GNN node-classification train/predict calls."""
        nc_namespace = f"{self._namespace}.nodeClassification"
        return GNNNodeClassificationRunner(self._query_runner, nc_namespace, self._server_version)


class GnnEndpoints(CallerBase):
    """Mixin exposing the experimental ``gds.gnn`` namespace on the client."""

    @property
    def gnn(self) -> GNNRunner:
        """Entry point for GNN procedures under ``<namespace>.gnn``."""
        gnn_namespace = f"{self._namespace}.gnn"
        return GNNRunner(self._query_runner, gnn_namespace, self._server_version)
79 changes: 79 additions & 0 deletions graphdatascience/gnn/gnn_nc_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import json
from typing import Any, Dict, List, Optional

from ..error.illegal_attr_checker import IllegalAttrChecker
from ..error.uncallable_namespace import UncallableNamespace


class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker):
    """Runner for the experimental GNN node-classification endpoints.

    Both :meth:`train` and :meth:`predict` serialize their configuration to
    JSON and submit it through a ``CALL gds.upload.graph`` query; the Arrow
    query runner injects the auth token and Arrow endpoint URI into the
    config before the call is executed.
    """

    # Defaults for the GraphSAGE model configuration; user-supplied
    # overrides may only use keys from this set.
    _GRAPH_SAGE_DEFAULT_CONFIG: Dict[str, Any] = {
        "layer_config": {},
        "num_neighbors": [25, 10],
        "dropout": 0.5,
        "hidden_channels": 256,
        "learning_rate": 0.003,
    }

    def make_graph_sage_config(self, graph_sage_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Merge user-supplied overrides into the default GraphSAGE config.

        Args:
            graph_sage_config: optional overrides; every key must be one of
                the default configuration's keys.

        Returns:
            The merged configuration (a fresh copy; the class-level defaults
            are never mutated).

        Raises:
            ValueError: if any override key is not a known configuration key.
        """
        final_sage_config = dict(self._GRAPH_SAGE_DEFAULT_CONFIG)
        if graph_sage_config:
            bad_keys = [key for key in graph_sage_config if key not in final_sage_config]
            if bad_keys:
                raise ValueError(f"Argument graph_sage_config contains invalid keys {', '.join(bad_keys)}.")
            final_sage_config.update(graph_sage_config)
        return final_sage_config

    def train(
        self,
        graph_name: str,
        model_name: str,
        feature_properties: List[str],
        target_property: str,
        relationship_types: Optional[List[str]] = None,
        target_node_label: Optional[str] = None,
        node_labels: Optional[List[str]] = None,
        graph_sage_config: Optional[Dict[str, Any]] = None,
    ) -> "Series[Any]":  # noqa: F821
        """Train a GNN node-classification model on a remote graph.

        Args:
            graph_name: name of the projected graph to train on.
            model_name: name under which the trained model is stored.
            feature_properties: node properties used as model features.
            target_property: node property holding the class label.
            relationship_types: relationship types to include; when omitted
                the key is left out so the server applies its own default.
            target_node_label: restrict training targets to this label.
            node_labels: node labels to include in the projection.
            graph_sage_config: GraphSAGE overrides, validated and merged
                with the defaults by :meth:`make_graph_sage_config`.

        Returns:
            The result of the underlying ``gds.upload.graph`` call.
        """
        mlConfigMap: Dict[str, Any] = {
            "featureProperties": feature_properties,
            "targetProperty": target_property,
            "job_type": "train",
            "nodeProperties": feature_properties + [target_property],
            "graph_sage_config": self.make_graph_sage_config(graph_sage_config),
        }

        # Optional filters are only sent when provided.
        if relationship_types is not None:
            mlConfigMap["relationshipTypes"] = relationship_types
        if target_node_label:
            mlConfigMap["targetNodeLabel"] = target_node_label
        if node_labels:
            mlConfigMap["nodeLabels"] = node_labels

        mlTrainingConfig = json.dumps(mlConfigMap)

        # token and uri will be injected by arrow_query_runner
        return self._query_runner.run_query(
            "CALL gds.upload.graph($config)",
            params={
                "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name},
            },
        )

    def predict(
        self,
        graph_name: str,
        model_name: str,
        mutateProperty: str,
        predictedProbabilityProperty: Optional[str] = None,
    ) -> "Series[Any]":  # noqa: F821
        """Predict classes with a previously trained model.

        Args:
            graph_name: name of the projected graph to predict on.
            model_name: name of the trained model to use.
            mutateProperty: node property to write the predicted class to.
            predictedProbabilityProperty: optional node property to write
                the class probabilities to.

        Returns:
            The result of the underlying ``gds.upload.graph`` call.
        """
        mlConfigMap: Dict[str, Any] = {
            "job_type": "predict",
            "mutateProperty": mutateProperty,
        }
        if predictedProbabilityProperty:
            mlConfigMap["predictedProbabilityProperty"] = predictedProbabilityProperty

        mlTrainingConfig = json.dumps(mlConfigMap)

        # token and uri will be injected by arrow_query_runner
        return self._query_runner.run_query(
            "CALL gds.upload.graph($config)",
            params={
                "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name},
            },
        )
1 change: 1 addition & 0 deletions graphdatascience/ignored_server_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"gds.alpha.pipeline.nodeRegression.predict.stream",
"gds.alpha.pipeline.nodeRegression.selectFeatures",
"gds.alpha.pipeline.nodeRegression.train",
"gds.gnn.nc",
"gds.similarity.cosine",
"gds.similarity.euclidean",
"gds.similarity.euclideanDistance",
Expand Down
21 changes: 19 additions & 2 deletions graphdatascience/query_runner/arrow_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def __init__(
):
self._fallback_query_runner = fallback_query_runner
self._server_version = server_version
# FIXME handle version were tls cert is given
self._auth = auth
self._uri = uri

host, port_string = uri.split(":")

Expand All @@ -39,8 +42,9 @@ def __init__(
)

client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification}
self._auth_factory = AuthFactory(auth)
if auth:
client_options["middleware"] = [AuthFactory(auth)]
client_options["middleware"] = [self._auth_factory]
if tls_root_certs:
client_options["tls_root_certs"] = tls_root_certs

Expand Down Expand Up @@ -129,6 +133,10 @@ def run_query(
endpoint = "gds.beta.graph.relationships.stream"

return self._run_arrow_property_get(graph_name, endpoint, {"relationship_types": relationship_types})
elif "gds.upload.graph" in query:
# inject parameters
params["config"]["token"] = self._get_or_request_token()
params["config"]["arrowEndpoint"] = self._uri

return self._fallback_query_runner.run_query(query, params, database, custom_error)

Expand Down Expand Up @@ -184,6 +192,10 @@ def create_graph_constructor(
database, graph_name, self._flight_client, concurrency, undirected_relationship_types
)

def _get_or_request_token(self) -> str:
    """Authenticate against the Flight server with the stored basic
    credentials and return the Bearer token captured by the auth
    middleware factory.

    NOTE(review): assumes ``self._auth`` is a (user, password) tuple and
    that a token was received by the middleware — verify error behavior
    when authentication fails.
    """
    self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1])
    return self._auth_factory.token()


class AuthFactory(ClientMiddlewareFactory): # type: ignore
def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -217,9 +229,14 @@ def __init__(self, factory: AuthFactory, *args: Any, **kwargs: Any) -> None:
self._factory = factory

def received_headers(self, headers: Dict[str, Any]) -> None:
    """Inspect response headers and cache any Bearer token on the factory."""
    raw_value = headers.get("authorization", None)
    if not raw_value:
        return

    # authenticate_basic_token() returns a list.
    # TODO We should take the first Bearer element here
    if isinstance(raw_value, list):
        raw_value = raw_value[0]

    scheme, token = raw_value.split(" ", 1)
    if scheme != "Bearer":
        return
    self._factory.set_token(token)
Expand Down