Skip to content

Commit

Permalink
added neo4j to target
Browse files Browse the repository at this point in the history
  • Loading branch information
jogunjobi committed Aug 3, 2024
1 parent ba01210 commit 3d8064e
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ vector_etl/config/*
vector_etl/tempfile_downloads/
*_bkp.py
vector_etl/source_mods/backup/
vector_etl/target_mods/backup/

# Additional files
.DS_Store
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,6 @@ google-api-python-client
unstructured-client
box-sdk-gen
pymongo
neo4j
python-magic
pytest
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"unstructured-client",
"box-sdk-gen",
"pymongo",
"neo4j",
"python-magic",
"pytest",
],
Expand Down
3 changes: 3 additions & 0 deletions vector_etl/target_mods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .lancedb import LanceDBTarget
from .tembo import TemboTarget
from .mongodb import MongoDBTarget
from .neo4j import Neo4jTarget

def get_target_database(config):
target_type = config['target_database']
Expand All @@ -25,5 +26,7 @@ def get_target_database(config):
return TemboTarget(config)
elif target_type == 'MongoDB':
return MongoDBTarget(config)
elif target_type == 'Neo4j':
return Neo4jTarget(config)
else:
raise ValueError(f"Unsupported target database: {target_type}")
141 changes: 141 additions & 0 deletions vector_etl/target_mods/neo4j.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import logging
from neo4j import GraphDatabase
from neo4j.exceptions import ServiceUnavailable, ClientError
from .base import BaseTarget

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class Neo4jTarget(BaseTarget):
def __init__(self, config):
self.config = config
self.driver = None

def connect(self):
logger.info("Connecting to Neo4j...")
try:
self.driver = GraphDatabase.driver(
self.config['neo4j_uri'],
auth=(self.config['username'], self.config['password'])
)
with self.driver.session() as session:
session.run("RETURN 1")
logger.info("Connected to Neo4j successfully.")
except ServiceUnavailable as e:
logger.error(f"Failed to connect to Neo4j: {str(e)}")
raise

def sanitize_property_name(self, name):
return name.replace(' ', '_').replace('-', '_')

def create_index_if_not_exists(self):
if self.driver is None:
self.connect()

with self.driver.session() as session:
# Create index on Entity node
try:
session.run("CREATE INDEX IF NOT EXISTS FOR (e:Entity) ON (e.id)")
logger.info("Created index on Entity.id")
except ClientError as e:
logger.warning(f"Failed to create index on Entity.id: {str(e)}")

for node in self.config['graph_structure']['nodes']:
label = node['label']
for prop in node['properties']:
sanitized_prop = self.sanitize_property_name(prop)
try:
session.run(f"CREATE INDEX IF NOT EXISTS FOR (n:{label}) ON (n.{sanitized_prop})")
logger.info(f"Created index on {label}.{sanitized_prop}")
except ClientError as e:
logger.warning(f"Failed to create index on {label}.{sanitized_prop}: {str(e)}")

# Create vector index on Entity node
vector_prop = self.sanitize_property_name(self.config['vector_property'])
try:
session.run(f"""
CREATE VECTOR INDEX {vector_prop}_index IF NOT EXISTS
FOR (e:Entity) ON (e.{vector_prop})
OPTIONS {{
indexConfig: {{
`vector.dimensions`: {self.config['vector_dimensions']},
`vector.similarity_function`: '{self.config['similarity_function']}'
}}
}}
""")
logger.info(f"Created vector index on Entity.{vector_prop}")
except ClientError as e:
logger.warning(f"Failed to create vector index: {str(e)}. This may be due to Neo4j version limitations.")

def build_cypher_query(self):
nodes = self.config['graph_structure']['nodes']
relationships = self.config['graph_structure']['relationships']

# Create Entity node
vector_prop = self.sanitize_property_name(self.config['vector_property'])
create_entity = f"CREATE (e:Entity {{id: row.id, {vector_prop}: row.embedding}})"

# Create specific nodes and connect to Entity
create_nodes = []
for node in nodes:
props = ", ".join([f"{self.sanitize_property_name(prop)}: row.metadata['{prop}']" for prop in node['properties']])
create_nodes.append(f"""
CREATE (n_{node['label']}:{node['label']} {{{props}}})
CREATE (e)-[:HAS_{node['label']}]->(n_{node['label']})
""")

# Create relationships between specific nodes
create_relationships = []
for rel in relationships:
create_relationships.append(f"CREATE (n_{rel['start_node']})-[:{rel['type']}]->(n_{rel['end_node']})")

# Combine all parts
cypher_query = f"""
UNWIND $batch AS row
{create_entity}
{" ".join(create_nodes)}
{" ".join(create_relationships)}
"""

return cypher_query

def write_data(self, df, columns, domain=None):
logger.info("Writing data to Neo4j...")
if self.driver is None:
self.connect()

self.create_index_if_not_exists()

def create_graph(tx, batch):
query = self.build_cypher_query()
tx.run(query, batch=batch)

batch_size = 1000 # Adjust based on your needs
total_processed = 0

with self.driver.session() as session:
for i in range(0, len(df), batch_size):
batch = []
for _, row in df.iloc[i:i+batch_size].iterrows():
node = {
'id': str(row['df_uuid']),
'embedding': row['embeddings'],
'metadata': {k: v for k, v in row.items() if k not in ['df_uuid', 'embeddings']}
}
if domain:
node['metadata']['domain'] = domain
batch.append(node)

session.execute_write(create_graph, batch)
total_processed += len(batch)
logger.info(f"Processed {total_processed} out of {len(df)} records.")

logger.info("Completed writing data to Neo4j.")

def close(self):
if self.driver:
self.driver.close()
logger.info("Neo4j connection closed.")

def __del__(self):
self.close()

0 comments on commit 3d8064e

Please sign in to comment.