From 485997bba25608b87b3b0d53157f52c33b4ab0ea Mon Sep 17 00:00:00 2001 From: Casey Clements Date: Fri, 24 Jan 2025 10:30:21 -0500 Subject: [PATCH] Touch ups to docstrings. --- .../langchain_mongodb/graphrag/graph.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py b/libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py index d6d8bd5..1e5de15 100644 --- a/libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py +++ b/libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py @@ -24,9 +24,9 @@ logger = logging.getLogger(__name__) -# Represents an entity in the knowledge graph with _id, type, attributes, and relationships fields -# See .schema for full schema + Entity: TypeAlias = Dict[str, Any] +"""Represents an Entity in the knowledge graph with specific schema. See .schema""" class MongoDBGraphStore: @@ -248,7 +248,7 @@ def extract_entities(self, raw_document: str, **kwargs: Any) -> List[Entity]: Returns: List of Entity dictionaries. """ - # Combine the llm with the prompt template to form a chain + # Combine the LLM with the prompt template to form a chain chain: RunnableSequence = self.entity_prompt | self.entity_extraction_model # Invoke on a document to extract entities and relationships response: AIMessage = chain.invoke( @@ -290,17 +290,19 @@ def extract_entity_names(self, raw_document: str, **kwargs: Any) -> List[str]: ) return json.loads(json_string) - def find_entity_by_name(self, name: str) -> Optional[List[Entity]]: + def find_entity_by_name(self, name: str) -> Optional[Entity]: """Utility to get Entity dict from Knowledge Graph / Collection. Args: name: _id string to look for Returns: List of Entity dicts if any match name """ - return list(self.collection.find({"_id": name})) + return self.collection.find_one({"_id": name}) def related_entities( - self, starting_entities: List[str], max_depth=3 + self, + starting_entities: List[str], + max_depth: Optional[int] = None, ) -> List[Entity]: """Traverse Graph along relationship edges to find connected entities. @@ -322,7 +324,7 @@ def related_entities( "connectFromField": "relationships.targets", # Traverse via relationships.targets "connectToField": "_id", # Match to entity _id field "as": "connections", # Store connections - "maxDepth": 3, # Limit traversal depth + "maxDepth": max_depth or self.max_depth, # Limit traversal depth "depthField": "depth", # Track depth } }, @@ -392,7 +394,7 @@ def chat_response( """Responds to a query given information found in Knowledge Graph. Args: - query: Query to send the chat_model + query: Prompt before it is augmented by Knowledge Graph. chat_model: ChatBot. Defaults to entity_extraction_model. prompt: Alternative Prompt Template. Defaults to prompts.rag_prompt Returns: @@ -405,7 +407,7 @@ def chat_response( # Perform Retrieval on knowledge graph related_entities = self.similarity_search(query) - # Combine the llm with the prompt template to form a chain + # Combine the LLM with the prompt template to form a chain chain: RunnableSequence = prompt | chat_model # Invoke with query return chain.invoke(