diff --git a/libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py b/libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py index 1e5de15..a7310ad 100644 --- a/libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py +++ b/libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py @@ -179,12 +179,12 @@ def _write_entities(self, entities: List[Entity]) -> BulkWriteResult: operations = [] for entity in entities: relationships = entity.get("relationships", {}) - targets = relationships.get("targets", []) + target_ids = relationships.get("target_ids", []) types = relationships.get("types", []) attributes = relationships.get("attributes", []) - # Ensure the lengths of targets, types, and attributes align - if not (len(targets) == len(types) == len(attributes)): + # Ensure the lengths of target_ids, types, and attributes align + if not (len(target_ids) == len(types) == len(attributes)): logger.warning( f"Targets, types, and attributes do not have the same length for {entity['_id']}!" 
) @@ -204,7 +204,7 @@ def _write_entities(self, entities: List[Entity]) -> BulkWriteResult: }, }, "$push": { # Push new entries into arrays - "relationships.targets": {"$each": targets}, + "relationships.target_ids": {"$each": target_ids}, "relationships.types": {"$each": types}, "relationships.attributes": {"$each": attributes}, }, @@ -320,8 +320,8 @@ def related_entities( { "$graphLookup": { "from": self.collection.name, - "startWith": "$relationships.targets", # Start traversal with relationships.targets - "connectFromField": "relationships.targets", # Traverse via relationships.targets + "startWith": "$relationships.target_ids", # Start traversal with relationships.target_ids + "connectFromField": "relationships.target_ids", # Traverse via relationships.target_ids "connectToField": "_id", # Match to entity _id field "as": "connections", # Store connections "maxDepth": max_depth or self.max_depth, # Limit traversal depth diff --git a/libs/langchain-mongodb/langchain_mongodb/graphrag/prompts.py b/libs/langchain-mongodb/langchain_mongodb/graphrag/prompts.py index ad2d7e8..05b8f25 100644 --- a/libs/langchain-mongodb/langchain_mongodb/graphrag/prompts.py +++ b/libs/langchain-mongodb/langchain_mongodb/graphrag/prompts.py @@ -79,12 +79,16 @@ Each object must conform to the following schema: {entity_schema} -## Input Example + +## Examples +Use the following examples to guide your work. + +#### Input Alice Palace, has been the CEO of MongoDB since January 1, 2018. She maintains close friendships with Jarnail Singh, whom she has known since May 1, 2019, and Jasbinder Kaur, who she has been seeing weekly since May 1, 2015. 
-## Output Example +#### Output (If `allowed_entity_types` is ["Person"] and `allowed_relationship_types` is ["Friend"]) {{ "entities": [ @@ -96,7 +100,7 @@ "startDate": ["2018-01-01"] }}, "relationships": {{ - "targets": ["Jasbinder Kaur", "Jarnail Singh"], + "target_ids": ["Jasbinder Kaur", "Jarnail Singh"], "types": ["Friend", "Friend"], "attributes": [ {{ "since": ["2019-05-01"], "frequency": ["weekly"] }}, @@ -106,6 +110,8 @@ }} ] }} + + """ diff --git a/libs/langchain-mongodb/langchain_mongodb/graphrag/schema.py b/libs/langchain-mongodb/langchain_mongodb/graphrag/schema.py index 5a4c3dc..8979aa1 100644 --- a/libs/langchain-mongodb/langchain_mongodb/graphrag/schema.py +++ b/libs/langchain-mongodb/langchain_mongodb/graphrag/schema.py @@ -38,23 +38,23 @@ "relationships": { "bsonType": "object", "description": "Key-value pairs of relationships", - "required": ["targets"], + "required": ["target_ids"], "properties": { - "targets": { + "target_ids": { "bsonType": "array", "description": "name/_id values of the target entities", "items": {"bsonType": "string"}, }, "types": { "bsonType": "array", - "description": "An array of relationships to corresponding targets (in same array position).", + "description": "An array of relationships to corresponding target_ids (in same array position).", "items": {"bsonType": "string"}, # Note: When constrained, predefined types are added. For example: # "enum": ["used_in", "owns", "written_by", "located_in"], # Predefined types }, "attributes": { "bsonType": "array", - "description": "An array of attributes describing the relationships to corresponding targets (in same array position). Each element is an object containing key-value pairs, where values are arrays of strings.", + "description": "An array of attributes describing the relationships to corresponding target_ids (in same array position). 
Each element is an object containing key-value pairs, where values are arrays of strings.", "items": { "bsonType": "object", "additionalProperties": { diff --git a/libs/langchain-mongodb/tests/integration_tests/test_graphrag.py b/libs/langchain-mongodb/tests/integration_tests/test_graphrag.py index 427bdcc..ac5a53e 100644 --- a/libs/langchain-mongodb/tests/integration_tests/test_graphrag.py +++ b/libs/langchain-mongodb/tests/integration_tests/test_graphrag.py @@ -96,21 +96,18 @@ def entity_example(): "testing": "Integration Tests for new functionality, regression tests for bug-fixes" }}, "relationships": {{ - "plannedFeature": [ - {{ - "target": "PYTHON-1834", - "attributes": {{ - "description": "Auto code formatting" - }} - }} + "target_ids": [ + "PYTHON-1834", + "Node Team Practices" + + ], + "types": [ + "plannedFeature", + "reference" ], - "reference": [ - {{ - "target": "Node Team Practices", - "attributes": {{ - "url": "https://wiki.corp.mongodb.com/display/DRIVERS/Node+Team+Practices" - }} - }} + "attributes": [ + {{"description": "Auto code formatting"}}, + {{"url": "https://wiki.corp.mongodb.com/display/DRIVERS/Node+Team+Practices"}} ] }} }} @@ -164,7 +161,7 @@ def test_additional_entity_examples(entity_extraction_model, entity_example, doc db = client[DB_NAME] clxn_name = f"{COLLECTION_NAME}_addl_examples" db[clxn_name].drop() - collection = db[clxn_name] + collection = db.create_collection(clxn_name) store_with_addl_examples = MongoDBGraphStore( collection, entity_extraction_model, entity_examples=entity_example ) @@ -221,7 +218,7 @@ def test_allowed_entity_types(documents, entity_extraction_model): assert len(bulkwrite_results) == len(documents) entities = store.collection.find({}).to_list() assert set(e["type"] for e in entities) == {"Person"} - all([len(e["relationships"].get("targets", [])) == 0 for e in entities]) + all([len(e["relationships"].get("target_ids", [])) == 0 for e in entities]) all([len(e["relationships"].get("types", [])) == 0 for e in
entities]) all([len(e["relationships"].get("attributes", [])) == 0 for e in entities])