Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions simple/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ pandas==2.1.0
parameterized==0.9.0
platformdirs==3.10.0
protobuf==4.25.3
PyLD==2.0.4
PyMySQL==1.1.0
python-dateutil==2.8.2
pytest==7.4.2
PyYAML==6.0.1
pytz==2023.3.post1
redis==5.2.1
requests==2.31.0
rdflib==7.4.0
s2sphere==0.2.5
six==1.16.0
tomli==2.0.1
Expand Down
112 changes: 101 additions & 11 deletions simple/stats/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,17 @@
from google.cloud.sql.connector.connector import Connector
from google.cloud.sql.connector.connector import IPTypes
import pandas as pd
from pyld import jsonld
from pymysql.connections import Connection
from pymysql.cursors import Cursor
from rdflib import Graph
from rdflib import Literal
from rdflib import Namespace
from rdflib import RDF
from rdflib import URIRef
import requests
from stats import constants
from stats import schema_constants as sc
from stats.data import McfNode
from stats.data import STAT_VAR_GROUP
from stats.data import STATISTICAL_VARIABLE
Expand Down Expand Up @@ -58,10 +66,12 @@
ENV_DB_PASS = "DB_PASS"
ENV_DB_NAME = "DB_NAME"

# NOTE(review): both the DATACOMMONS_* and DATA_COMMONS_* spellings of the
# platform constants are visible here. The DATA_COMMONS_* names look like the
# intended replacements (they are the ones used by the later of each pair of
# duplicated statements in this file) -- confirm and drop the older spelling.
DATACOMMONS_PLATFORM_URL = "datacommons_platform_url"
# Prefix that NS_MAP maps to DATA_COMMONS_NAMESPACE_URL.
DATA_COMMONS_NAMESPACE = "dcid"
# Base URL that _expand_id prepends to ids to form full URIs.
DATA_COMMONS_NAMESPACE_URL = "https://datacommons.org/browser/"
# Key under config[FIELD_DB_PARAMS] that holds the platform URL.
DATA_COMMONS_PLATFORM_URL = "data_commons_platform_url"

ENV_USE_DATACOMMONS_PLATFORM = "USE_DATACOMMONS_PLATFORM"
ENV_DATACOMMONS_PLATFORM_URL = "DATACOMMONS_PLATFORM_URL"
# Environment variable that must be "true" (case-insensitive) to enable the
# Data Commons Platform config.
ENV_USE_DATA_COMMONS_PLATFORM = "USE_DATA_COMMONS_PLATFORM"
# Environment variable that supplies the platform URL.
ENV_DATA_COMMONS_PLATFORM_URL = "DATA_COMMONS_PLATFORM_URL"

ENV_SQLITE_PATH = "SQLITE_PATH"

Expand Down Expand Up @@ -397,18 +407,35 @@ def _import_metadata(self) -> dict:

class DataCommonsPlatformDb(Db):
"""Class to insert triples and observations into Data Commons Platform."""
# Default namespace map for Data Commons Platform.
NS_MAP = {DATA_COMMONS_NAMESPACE: DATA_COMMONS_NAMESPACE_URL}

# Path to the nodes endpoint in the Data Commons Platform.
NODES_PATH = "/nodes"

def __init__(self, config: dict) -> None:
    """Store the platform base URL taken from the db config.

    Args:
        config: Db config dict. The URL is read from
            config[FIELD_DB_PARAMS][DATA_COMMONS_PLATFORM_URL].

    Raises:
        KeyError: If the params section or the URL entry is missing.
    """
    # Only the DATA_COMMONS_PLATFORM_URL key is read. The earlier lookup of
    # the superseded DATACOMMONS_PLATFORM_URL key raised KeyError for configs
    # that carry only the new key and was overwritten immediately anyway.
    self.url = config[FIELD_DB_PARAMS][DATA_COMMONS_PLATFORM_URL]

def maybe_clear_before_import(self):
    """Do nothing: clearing is not applicable for Data Commons Platform."""
    return None

def insert_triples(self, triples: list[Triple]):
    """Convert triples to a JSON-LD graph and write it to the platform.

    The graph is POSTed as JSON to the nodes endpoint (self.url followed by
    NODES_PATH). A response whose status code is not 200 is logged as a
    warning; no exception is raised for it.

    Args:
        triples: The triples to write.

    Raises:
        requests.exceptions.RequestException: If the request cannot be
            completed (connection failure, timeout, ...).
    """
    graph = self._triples_to_graph(triples)
    # Named `payload` so the imported `jsonld` module is not shadowed.
    payload = self._graph_to_jsonld(graph)
    logging.info(
        "Writing %s triples (%s nodes) to Data Commons Platform at [%s]",
        len(triples), len(payload["@graph"]), self.url)
    # Pretty-printing the whole payload is costly for large graphs, so only
    # do it when the message will actually be emitted.
    if logging.getLogger().isEnabledFor(logging.INFO):
        logging.info("Writing jsonld: %s", json.dumps(payload, indent=2))

    # Tolerate a trailing slash on the configured URL ("…/" + "/nodes").
    nodes_url = self.url.rstrip("/") + self.NODES_PATH
    # (connect, read) timeouts in seconds: without a timeout requests waits
    # indefinitely on an unresponsive server.
    response = requests.post(nodes_url, json=payload, timeout=(10, 300))
    if response.status_code != 200:
        # TODO: For now, we just log a warning, but we should raise an exception.
        logging.warning("Failed to write triples to Data Commons Platform: %s",
                        response.text)

def insert_observations(self, observations_df: pd.DataFrame,
input_file: File):
Expand Down Expand Up @@ -441,6 +468,69 @@ def select_entity_names(self, dcids: list[str]) -> dict[str, str]:
# TODO: Implement entity name selection from Data Commons Platform.
return {}

def _expand_id(self, item: str) -> URIRef | None:
    """Expand an id into a full Data Commons URI.

    Leading slashes are stripped from the id before it is appended to
    DATA_COMMONS_NAMESPACE_URL.

    Example:
        _expand_id("country/USA") ->
            URIRef("https://datacommons.org/browser/country/USA")

    Args:
        item: The id to expand.

    Returns:
        The expanded URIRef, or None when item is empty or None.
    """
    if not item:
        return None
    return URIRef(f"{DATA_COMMONS_NAMESPACE_URL}{item.lstrip('/')}")

def _triples_to_graph(self, triples: list[Triple]) -> Graph:
    """Build an rdflib Graph from a list of triples.

    Every prefix in NS_MAP is bound on the graph. Subject, predicate and
    object ids are expanded with _expand_id; a triple without an object id
    gets a Literal built from its object value. A predicate equal to the
    expanded sc.PREDICATE_TYPE_OF is stored as RDF.type.

    A triple that fails to convert is logged with its traceback and skipped
    so that one bad triple does not abort the whole batch.

    Args:
        triples: The triples to convert.

    Returns:
        The populated Graph.
    """
    graph = Graph()
    for prefix, uri in self.NS_MAP.items():
        graph.bind(prefix, Namespace(uri))

    # Loop-invariant: the expanded form of the typeOf predicate.
    type_of = URIRef(f"{DATA_COMMONS_NAMESPACE_URL}{sc.PREDICATE_TYPE_OF}")

    for triple in triples:
        try:
            # The Triple class doesn't include a namespace, so for now we
            # assume that all ids will be expanded using the default "dcid"
            # prefix.
            subject = self._expand_id(triple.subject_id)
            predicate = self._expand_id(triple.predicate)
            if triple.object_id:
                obj = self._expand_id(triple.object_id)
            else:
                obj = Literal(triple.object_value)

            if predicate == type_of:
                predicate = RDF.type
            graph.add((subject, predicate, obj))
        except Exception as e:  # Best effort: log and keep going.
            logging.warning("Error processing triple %s: %s",
                            triple,
                            e,
                            exc_info=True)
    return graph

def _graph_to_jsonld(self, g: Graph) -> dict:
    """Convert an rdflib Graph into a JSON-LD dict that has an "@graph" key.

    The graph is serialized by rdflib with NS_MAP as context, parsed back
    into Python objects, and then compacted again with PyLD. Per the
    original author's note, the second pass is needed because rdflib leaves
    node values containing slashes fully expanded instead of using the
    namespace prefix.

    Args:
        g: The graph to convert.

    Returns:
        The compacted JSON-LD document. If PyLD's output has no "@graph"
        key, everything except "@context" is wrapped as the single entry of
        a new "@graph" list.
    """
    serialized = g.serialize(context=self.NS_MAP, format="json-ld", indent=4)
    parsed = json.loads(serialized)
    compacted = jsonld.compact(parsed, self.NS_MAP)

    if "@graph" in compacted:
        return compacted

    # No "@graph" wrapper in the PyLD output: rebuild one around the
    # non-context entries so callers can always index "@graph".
    node = {key: value for key, value in compacted.items() if key != "@context"}
    return {"@context": compacted.get("@context"), "@graph": [node]}


def from_triple_tuple(tuple: tuple) -> Triple:
    """Build a Triple by unpacking a tuple into its positional arguments.

    Args:
        tuple: The values passed positionally to Triple.

    Returns:
        The constructed Triple.
    """
    return Triple(*tuple)
Expand Down Expand Up @@ -893,14 +983,14 @@ def get_sqlite_path_from_env() -> str | None:


def get_datacommons_platform_config_from_env() -> dict | None:
    """Build the Data Commons Platform db config from environment variables.

    Returns:
        None unless the ENV_USE_DATA_COMMONS_PLATFORM variable is "true"
        (case-insensitive). Otherwise a config dict of type
        TYPE_DATACOMMONS_PLATFORM whose params carry the URL read from the
        ENV_DATA_COMMONS_PLATFORM_URL variable.

    Raises:
        AssertionError: If the platform is enabled but the URL variable is
            unset or empty.
    """
    if os.getenv(ENV_USE_DATA_COMMONS_PLATFORM, "").lower() != "true":
        return None

    dcp_url = os.getenv(ENV_DATA_COMMONS_PLATFORM_URL)
    if not dcp_url:
        # Raised explicitly instead of via `assert` so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        raise AssertionError(
            f"Environment variable {ENV_DATA_COMMONS_PLATFORM_URL} not specified."
        )

    return {
        FIELD_DB_TYPE: TYPE_DATACOMMONS_PLATFORM,
        FIELD_DB_PARAMS: {
            DATA_COMMONS_PLATFORM_URL: dcp_url,
        }
    }

Expand Down
74 changes: 60 additions & 14 deletions simple/tests/stats/db_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,20 +318,6 @@ def test_get_cloud_sql_config_from_env_valid(self):
}
})

# NOTE(review): a test with this same name appears again further down using
# the USE_DATA_COMMONS_PLATFORM / DATA_COMMONS_PLATFORM_URL variables and the
# "data_commons_platform_url" key. If both definitions live in the same class
# the later one replaces this one -- confirm this copy is meant to be removed.
@mock.patch.dict(
os.environ, {
"USE_DATACOMMONS_PLATFORM": "true",
"DATACOMMONS_PLATFORM_URL": "https://test_url"
})
def test_get_datacommons_platform_config_from_env(self):
self.assertEqual(
get_datacommons_platform_config_from_env(), {
"type": "datacommons_platform",
"params": {
"datacommons_platform_url": "https://test_url"
}
})

@mock.patch.dict(os.environ, {
"USE_CLOUDSQL": "true",
"CLOUDSQL_INSTANCE": ""
Expand All @@ -350,6 +336,66 @@ def test_get_sqlite_path_from_env_empty(self):
def test_get_sqlite_path_from_env(self):
self.assertEqual(get_sqlite_path_from_env(), "/path/datacommons.db")

@mock.patch.dict(
    os.environ, {
        "USE_DATA_COMMONS_PLATFORM": "true",
        "DATA_COMMONS_PLATFORM_URL": "https://test_url"
    })
def test_get_datacommons_platform_config_from_env(self):
    """The env helper returns the platform config when enabled with a URL."""
    actual_config = get_datacommons_platform_config_from_env()
    expected_config = {
        "type": "datacommons_platform",
        "params": {
            "data_commons_platform_url": "https://test_url"
        }
    }
    self.assertEqual(actual_config, expected_config)

@mock.patch('requests.post')
@mock.patch.dict(
    os.environ, {
        "USE_DATA_COMMONS_PLATFORM": "true",
        "DATA_COMMONS_PLATFORM_URL": "https://test_url"
    })
def test_insert_triples_into_datacommons_platform(self, mock_post):
    """insert_triples POSTs the triples as JSON-LD to the nodes endpoint."""
    db = create_and_update_db(get_datacommons_platform_config_from_env())

    # Stub a successful HTTP response.
    stub_response = mock_post.return_value
    stub_response.status_code = 200
    stub_response.text = "Success"

    db.insert_triples(_TRIPLES)

    # Exactly one POST, addressed to the nodes endpoint.
    mock_post.assert_called_once()
    call_args, call_kwargs = mock_post.call_args
    self.assertEqual(call_args[0], "https://test_url/nodes")

    # The payload is passed via the `json` keyword and carries an @graph.
    payload = call_kwargs.get('json')
    self.assertIsNotNone(payload)
    self.assertIn('@graph', payload)

    # Index the graph nodes by @id and compare the expected properties.
    nodes_by_id = {node['@id']: node for node in payload['@graph']}
    expected_nodes = {
        "dcid:sub1": {
            '@type': "dcid:StatisticalVariable",
            'dcid:pred1': "objval1",
            'dcid:name': "name1",
        },
        "dcid:sub2": {
            '@type': "dcid:StatisticalVariable",
            'dcid:name': "name2",
        },
    }
    for node_id, expected_properties in expected_nodes.items():
        self.assertIn(node_id, nodes_by_id)
        for key, expected_value in expected_properties.items():
            self.assertEqual(nodes_by_id[node_id][key], expected_value)


class TestBulkImportContext(unittest.TestCase):
"""Tests for BulkImportContext used by CloudSqlDbEngine."""
Expand Down