Skip to content

Commit

Permalink
Renamed the relationships field `targets` to `target_ids`. Updated the test that adds an example entity.
Browse files Browse the repository at this point in the history
  • Loading branch information
caseyclements committed Jan 28, 2025
1 parent 485997b commit 0d01839
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 29 deletions.
12 changes: 6 additions & 6 deletions libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,12 @@ def _write_entities(self, entities: List[Entity]) -> BulkWriteResult:
operations = []
for entity in entities:
relationships = entity.get("relationships", {})
targets = relationships.get("targets", [])
target_ids = relationships.get("target_ids", [])
types = relationships.get("types", [])
attributes = relationships.get("attributes", [])

# Ensure the lengths of targets, types, and attributes align
if not (len(targets) == len(types) == len(attributes)):
# Ensure the lengths of target_ids, types, and attributes align
if not (len(target_ids) == len(types) == len(attributes)):
logger.warning(
f"Targets, types, and attributes do not have the same length for {entity['_id']}!"
)
Expand All @@ -204,7 +204,7 @@ def _write_entities(self, entities: List[Entity]) -> BulkWriteResult:
},
},
"$push": { # Push new entries into arrays
"relationships.targets": {"$each": targets},
"relationships.target_ids": {"$each": target_ids},
"relationships.types": {"$each": types},
"relationships.attributes": {"$each": attributes},
},
Expand Down Expand Up @@ -320,8 +320,8 @@ def related_entities(
{
"$graphLookup": {
"from": self.collection.name,
"startWith": "$relationships.targets", # Start traversal with relationships.targets
"connectFromField": "relationships.targets", # Traverse via relationships.targets
"startWith": "$relationships.target_ids", # Start traversal with relationships.target_ids
"connectFromField": "relationships.target_ids", # Traverse via relationships.target_ids
"connectToField": "_id", # Match to entity _id field
"as": "connections", # Store connections
"maxDepth": max_depth or self.max_depth, # Limit traversal depth
Expand Down
12 changes: 9 additions & 3 deletions libs/langchain-mongodb/langchain_mongodb/graphrag/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,16 @@
Each object must conform to the following schema:
{entity_schema}
## Input Example
## Examples
Use the following examples to guide your work.
#### Input
Alice Palace, has been the CEO of MongoDB since January 1, 2018.
She maintains close friendships with Jarnail Singh, whom she has known since May 1, 2019,
and Jasbinder Kaur, who she has been seeing weekly since May 1, 2015.
## Output Example
#### Output
(If `allowed_entity_types` is ["Person"] and `allowed_relationship_types` is ["Friend"])
{{
"entities": [
Expand All @@ -96,7 +100,7 @@
"startDate": ["2018-01-01"]
}},
"relationships": {{
"targets": ["Jasbinder Kaur", "Jarnail Singh"],
"target_ids": ["Jasbinder Kaur", "Jarnail Singh"],
"types": ["Friend", "Friend"],
"attributes": [
{{ "since": ["2019-05-01"], "frequency": ["weekly"] }},
Expand All @@ -106,6 +110,8 @@
}}
]
}}
"""


Expand Down
8 changes: 4 additions & 4 deletions libs/langchain-mongodb/langchain_mongodb/graphrag/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,23 @@
"relationships": {
"bsonType": "object",
"description": "Key-value pairs of relationships",
"required": ["targets"],
"required": ["target_ids"],
"properties": {
"targets": {
"target_ids": {
"bsonType": "array",
"description": "name/_id values of the target entities",
"items": {"bsonType": "string"},
},
"types": {
"bsonType": "array",
"description": "An array of relationships to corresponding targets (in same array position).",
"description": "An array of relationships to corresponding target_ids (in same array position).",
"items": {"bsonType": "string"},
# Note: When constrained, predefined types are added. For example:
# "enum": ["used_in", "owns", "written_by", "located_in"], # Predefined types
},
"attributes": {
"bsonType": "array",
"description": "An array of attributes describing the relationships to corresponding targets (in same array position). Each element is an object containing key-value pairs, where values are arrays of strings.",
"description": "An array of attributes describing the relationships to corresponding target_ids (in same array position). Each element is an object containing key-value pairs, where values are arrays of strings.",
"items": {
"bsonType": "object",
"additionalProperties": {
Expand Down
29 changes: 13 additions & 16 deletions libs/langchain-mongodb/tests/integration_tests/test_graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,21 +96,18 @@ def entity_example():
"testing": "Integration Tests for new functionality, regression tests for bug-fixes"
}},
"relationships": {{
"plannedFeature": [
{{
"target": "PYTHON-1834",
"attributes": {{
"description": "Auto code formatting"
}}
}}
"target_ids": [
"PYTHON-1834",
"Node Team Practices"
]
"types": [
"plannedFeature",
"reference"
],
"reference": [
{{
"target": "Node Team Practices",
"attributes": {{
"url": "https://wiki.corp.mongodb.com/display/DRIVERS/Node+Team+Practices"
}}
}}
"attributes": [
{{"description": "Auto code formatting"}},
{{"url": "https://wiki.corp.mongodb.com/display/DRIVERS/Node+Team+Practices"}}
]
}}
}}
Expand Down Expand Up @@ -164,7 +161,7 @@ def test_additional_entity_examples(entity_extraction_model, entity_example, doc
db = client[DB_NAME]
clxn_name = f"{COLLECTION_NAME}_addl_examples"
db[clxn_name].drop()
collection = db[clxn_name]
collection = db.create_collection(clxn_name)
store_with_addl_examples = MongoDBGraphStore(
collection, entity_extraction_model, entity_examples=entity_example
)
Expand Down Expand Up @@ -221,7 +218,7 @@ def test_allowed_entity_types(documents, entity_extraction_model):
assert len(bulkwrite_results) == len(documents)
entities = store.collection.find({}).to_list()
assert set(e["type"] for e in entities) == {"Person"}
all([len(e["relationships"].get("targets", [])) == 0 for e in entities])
all([len(e["relationships"].get("target_ids", [])) == 0 for e in entities])
all([len(e["relationships"].get("types", [])) == 0 for e in entities])
all([len(e["relationships"].get("attributes", [])) == 0 for e in entities])

Expand Down

0 comments on commit 0d01839

Please sign in to comment.