Skip to content

Commit 7dd2330

Browse files
authored
INTPYTHON-452 Add hybrid retriever test with nested field (#54)
1 parent fe6f9eb commit 7dd2330

File tree

2 files changed

+80
-0
lines changed

2 files changed

+80
-0
lines changed

libs/langchain-mongodb/langchain_mongodb/vectorstores.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,8 @@ def _similarity_search_with_score(
758758

759759
# Format
760760
for res in cursor:
761+
if self._text_key not in res:
762+
continue
761763
text = res.pop(self._text_key)
762764
score = res.pop("score")
763765
make_serializable(res)

libs/langchain-mongodb/tests/integration_tests/test_retrievers.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,13 @@
2121

2222
# Database and collections used by the retriever integration tests.
DB_NAME = "langchain_test_db"
COLLECTION_NAME = "langchain_test_retrievers"
COLLECTION_NAME_NESTED = "langchain_test_retrievers_nested"

# Document fields: the embedding vector and the page content. The nested
# variant is a dotted path ("text" inside a "title" sub-document).
EMBEDDING_FIELD = "embedding"
PAGE_CONTENT_FIELD = "text"
PAGE_CONTENT_FIELD_NESTED = "title.text"

# Names of the vector index and the full-text search indexes.
VECTOR_INDEX_NAME = "vector_index"
SEARCH_INDEX_NAME = "text_index"
SEARCH_INDEX_NAME_NESTED = "text_index_nested"

# TIMEOUT is handed to the index-creation helpers as ``wait_until_complete``.
TIMEOUT = 60.0
# NOTE(review): INTERVAL is presumably a polling interval; its use is not
# visible in this part of the file.
INTERVAL = 0.5
@@ -71,6 +74,39 @@ def collection(client: MongoClient, dimensions: int) -> Collection:
7174
return clxn
7275

7376

77+
@pytest.fixture(scope="module")
def collection_nested(client: MongoClient, dimensions: int) -> Collection:
    """Return an emptied collection indexed for the nested-text-field tests.

    The collection ``COLLECTION_NAME_NESTED`` is created in ``DB_NAME`` if it
    does not exist yet, then cleared of all documents. Two search indexes are
    ensured on it: a vector index on the embedding field, and a full-text
    index on the nested content path ``PAGE_CONTENT_FIELD_NESTED``.

    Args:
        client: Client connected to the test deployment.
        dimensions: Number of dimensions of the embedding vectors.

    Returns:
        The (empty) collection with both indexes in place.
    """
    db = client[DB_NAME]
    if COLLECTION_NAME_NESTED not in db.list_collection_names():
        clxn = db.create_collection(COLLECTION_NAME_NESTED)
    else:
        clxn = db[COLLECTION_NAME_NESTED]

    # Start from a clean slate; indexes survive document deletion.
    clxn.delete_many({})

    # List the existing search indexes once and reuse the names for both
    # checks. The two index names differ, so creating the first cannot
    # change the answer for the second.
    existing_indexes = {ix["name"] for ix in clxn.list_search_indexes()}

    if VECTOR_INDEX_NAME not in existing_indexes:
        create_vector_search_index(
            collection=clxn,
            index_name=VECTOR_INDEX_NAME,
            dimensions=dimensions,
            path=EMBEDDING_FIELD,
            similarity="cosine",
            wait_until_complete=TIMEOUT,
        )

    if SEARCH_INDEX_NAME_NESTED not in existing_indexes:
        create_fulltext_search_index(
            collection=clxn,
            index_name=SEARCH_INDEX_NAME_NESTED,
            field=PAGE_CONTENT_FIELD_NESTED,
            wait_until_complete=TIMEOUT,
        )

    return clxn
108+
109+
74110
@pytest.fixture(scope="module")
75111
def indexed_vectorstore(
76112
collection: Collection,
@@ -93,6 +129,28 @@ def indexed_vectorstore(
93129
vectorstore.collection.delete_many({})
94130

95131

132+
@pytest.fixture(scope="module")
def indexed_nested_vectorstore(
    collection_nested: Collection,
    example_documents: List[Document],
    embedding: Embeddings,
) -> Generator[MongoDBAtlasVectorSearch, None, None]:
    """Yield a vector store whose text key is the nested content field.

    The store wraps ``collection_nested`` and is loaded with
    ``example_documents`` before being handed to the tests. After the
    module's tests finish, every document is removed from the store's
    collection again.
    """
    store = PatchedMongoDBAtlasVectorSearch(
        collection=collection_nested,
        embedding=embedding,
        index_name=VECTOR_INDEX_NAME,
        text_key=PAGE_CONTENT_FIELD_NESTED,
    )
    store.add_documents(example_documents)

    yield store

    # Teardown: leave the collection empty for subsequent runs.
    store.collection.delete_many({})
152+
153+
96154
def test_vector_retriever(indexed_vectorstore: PatchedMongoDBAtlasVectorSearch) -> None:
97155
"""Test VectorStoreRetriever"""
98156
retriever = indexed_vectorstore.as_retriever()
@@ -125,6 +183,26 @@ def test_hybrid_retriever(indexed_vectorstore: PatchedMongoDBAtlasVectorSearch)
125183
assert "New Orleans" in results[0].page_content
126184

127185

186+
def test_hybrid_retriever_nested(
    indexed_nested_vectorstore: PatchedMongoDBAtlasVectorSearch,
) -> None:
    """Test MongoDBAtlasHybridSearchRetriever against the nested-field store.

    Uses the full-text index named ``SEARCH_INDEX_NAME_NESTED`` and checks
    that the top-ranked document mentions the place asked about.
    """
    hybrid = MongoDBAtlasHybridSearchRetriever(
        vectorstore=indexed_nested_vectorstore,
        search_index_name=SEARCH_INDEX_NAME_NESTED,
        top_k=3,
    )

    # First query: expect exactly top_k hits, best one about Paris.
    france_hits = hybrid.invoke("What did I visit France?")
    assert len(france_hits) == 3
    assert "Paris" in france_hits[0].page_content

    # Second query: only the top-ranked hit is checked.
    new_orleans_hits = hybrid.invoke("When was the last time I visited new orleans?")
    assert "New Orleans" in new_orleans_hits[0].page_content
204+
205+
128206
def test_fulltext_retriever(
129207
indexed_vectorstore: PatchedMongoDBAtlasVectorSearch,
130208
) -> None:

0 commit comments

Comments
 (0)