Skip to content

Commit fd52b43

Browse files
committed
edits
1 parent dcdf271 commit fd52b43

File tree

5 files changed

+135
-50
lines changed

5 files changed

+135
-50
lines changed

django_mongodb_backend/features.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -634,9 +634,10 @@ def supports_atlas_search(self):
634634
# An existing collection must be used on MongoDB 6, otherwise
635635
# the operation will not error when unsupported.
636636
self.connection.get_collection("django_migrations").list_search_indexes()
637-
except OperationFailure as exc:
638-
if "$listSearchIndexes stage is only allowed on MongoDB Atlas" in str(exc):
639-
return False
640-
raise
637+
except OperationFailure:
638+
# It would be best to check the error message or error code too, but
639+
# they vary across MongoDB versions. Example: "$listSearchIndexes
640+
# stage is only allowed on MongoDB Atlas".
641+
return False
641642
else:
642643
return True

django_mongodb_backend/indexes.py

+13-21
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def check(self, model, connection):
131131
def search_index_data_types(self, db_type):
132132
"""
133133
Map a model field's type to search index type.
134-
Reference: https://www.mongodb.com/docs/atlas/atlas-search/define-field-mappings/#data-types
134+
https://www.mongodb.com/docs/atlas/atlas-search/define-field-mappings/#data-types
135135
"""
136136
if db_type in {"double", "int", "long"}:
137137
return "number"
@@ -163,17 +163,11 @@ def get_pymongo_index_model(
163163

164164
class VectorSearchIndex(SearchIndex):
165165
suffix = "vsi"
166-
VALID_SIMILARITIES = frozenset(("euclidean", "cosine", "dotProduct"))
166+
VALID_SIMILARITIES = frozenset(("cosine", "dotProduct", "euclidean"))
167167
VALID_FIELD_TYPES = frozenset(("boolean", "date", "number", "objectId", "string", "uuid"))
168168
_error_id_prefix = "django_mongodb_backend.indexes.VectorSearchIndex"
169169

170-
def __init__(
171-
self,
172-
*,
173-
fields=(),
174-
similarities="cosine",
175-
name=None,
176-
):
170+
def __init__(self, *, fields=(), similarities="cosine", name=None):
177171
super().__init__(fields=fields, name=name)
178172
self.similarities = similarities
179173
self._multiple_similarities = isinstance(similarities, tuple | list)
@@ -191,11 +185,11 @@ def __init__(
191185

192186
def check(self, model, connection):
193187
errors = super().check(model, connection)
194-
expected_similarities = 0
188+
num_arrayfields = 0
195189
for field_name, _ in self.fields_orders:
196190
field = model._meta.get_field(field_name)
197191
if isinstance(field, ArrayField):
198-
expected_similarities += 1
192+
num_arrayfields += 1
199193
try:
200194
int(field.size)
201195
except (ValueError, TypeError):
@@ -219,7 +213,6 @@ def check(self, model, connection):
219213
)
220214
else:
221215
search_type = self.search_index_data_types(field.db_type(connection))
222-
# Validate allowed search types.
223216
if search_type not in self.VALID_FIELD_TYPES:
224217
errors.append(
225218
Error(
@@ -230,34 +223,33 @@ def check(self, model, connection):
230223
hint=f"Allowed types are {', '.join(sorted(self.VALID_FIELD_TYPES))}.",
231224
)
232225
)
233-
if self._multiple_similarities and expected_similarities != len(self.similarities):
226+
if self._multiple_similarities and num_arrayfields != len(self.similarities):
234227
errors.append(
235228
Error(
236229
f"VectorSearchIndex requires the same number of similarities "
237230
f"and vector fields; {model._meta.object_name} has "
238-
f"{expected_similarities} ArrayField(s) but similarities "
231+
f"{num_arrayfields} ArrayField(s) but similarities "
239232
f"has {len(self.similarities)} element(s).",
240233
obj=model,
241234
id=f"{self._error_id_prefix}.E005",
242235
)
243236
)
244-
# There isn't any vector.
245-
if expected_similarities == 0:
237+
if num_arrayfields == 0:
246238
errors.append(
247239
Error(
248-
"VectorSearchIndex requires at least one field containing vector data "
249-
"(e.g., an ArrayField(FloatField, size=10)). "
250-
"If you're aiming to perform search operations on other data types, "
251-
"consider using SearchIndex instead.",
240+
"VectorSearchIndex requires at least one ArrayField to " "store vector data.",
252241
obj=model,
253242
id=f"{self._error_id_prefix}.E006",
243+
hint="If you want to perform search operations without vectors, "
244+
"use SearchIndex instead.",
254245
)
255246
)
256247
return errors
257248

258249
def deconstruct(self):
259250
path, args, kwargs = super().deconstruct()
260-
kwargs["similarities"] = self.similarities
251+
if self.similarities != "cosine":
252+
kwargs["similarities"] = self.similarities
261253
return path, args, kwargs
262254

263255
def get_pymongo_index_model(

docs/source/ref/models/indexes.rst

+9-6
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,17 @@ model has multiple indexes).
3838
A subclass of :class:`SearchIndex` that creates a :doc:`vector search index
3939
<atlas:atlas-vector-search/vector-search-type>` on the given field(s).
4040

41-
Each index should references at least one vector field: an :class:`.ArrayField`
41+
The index should reference at least one vector field: an :class:`.ArrayField`
4242
with a :attr:`~.ArrayField.base_field` of :class:`~django.db.models.FloatField`
43-
or :class:`~django.db.models.IntegerField`.
43+
or :class:`~django.db.models.IntegerField`. It cannot reference an
44+
:class:`.ArrayField` of any other type. Each :class:`.ArrayField` must have a
45+
:attr:`~.ArrayField.size`.
4446

4547
It may also have other fields to filter on, provided the field stores
4648
``boolean``, ``date``, ``objectId``, ``number``, ``string``, or ``uuid``.
4749

48-
Available values for ``similarities`` are ``"euclidean"``, ``"cosine"``, and
49-
``"dotProduct"``. You can provide this value either a string, in which case
50-
that value will be applied to all vector fields, or a list or tuple of values
51-
with a similarity corresponding to each vector field.
50+
Available values for ``similarities`` are ``"cosine"``, ``"dotProduct"``, and
51+
``"euclidean"`` (see :ref:`atlas:avs-similarity-functions`). You can provide
52+
this value as either a string, in which case that value will be applied to all
53+
vector fields, or a list or tuple of values with a similarity corresponding to
54+
each vector field.

tests/indexes_/test_checks.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ class Meta:
4141
def test_vector_search_requires_atlas_search_support(self):
4242
class Article(models.Model):
4343
title = models.CharField(max_length=10)
44-
vector_float = ArrayField(models.FloatField(), size=10)
44+
vector = ArrayField(models.FloatField(), size=10)
4545

4646
class Meta:
47-
indexes = [VectorSearchIndex(fields=["title", "vector_float"])]
47+
indexes = [VectorSearchIndex(fields=["title", "vector"])]
4848

4949
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
5050
self.assertEqual(
@@ -109,10 +109,10 @@ class Meta:
109109
def test_unsupported_type(self):
110110
class Article(models.Model):
111111
data = models.JSONField()
112-
vector_float = ArrayField(models.FloatField(), size=10)
112+
vector = ArrayField(models.FloatField(), size=10)
113113

114114
class Meta:
115-
indexes = [VectorSearchIndex(fields=["data", "vector_float"])]
115+
indexes = [VectorSearchIndex(fields=["data", "vector"])]
116116

117117
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
118118
self.assertEqual(
@@ -129,12 +129,12 @@ class Meta:
129129

130130
def test_invalid_number_similarity_function_singular(self):
131131
class Article(models.Model):
132-
vector_data = ArrayField(models.FloatField(), size=10)
132+
vector = ArrayField(models.FloatField(), size=10)
133133

134134
class Meta:
135135
indexes = [
136136
VectorSearchIndex(
137-
fields=["vector_data"],
137+
fields=["vector"],
138138
similarities=["dotProduct", "cosine"],
139139
)
140140
]
@@ -182,10 +182,10 @@ class Meta:
182182

183183
def test_simple(self):
184184
class Article(models.Model):
185-
vector_data = ArrayField(models.FloatField(), size=10)
185+
vector = ArrayField(models.FloatField(), size=10)
186186

187187
class Meta:
188-
indexes = [VectorSearchIndex(fields=["vector_data"])]
188+
indexes = [VectorSearchIndex(fields=["vector"])]
189189

190190
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
191191
self.assertEqual(errors, [])
@@ -234,11 +234,11 @@ class Meta:
234234
errors,
235235
[
236236
checks.Error(
237-
"VectorSearchIndex requires at least one field containing vector data "
238-
"(e.g., an ArrayField(FloatField, size=10)). If you're aiming to perform "
239-
"search operations on other data types, consider using SearchIndex instead.",
237+
"VectorSearchIndex requires at least one ArrayField to " "store vector data.",
240238
id="django_mongodb_backend.indexes.VectorSearchIndex.E006",
241239
obj=NoSearchVectorModel,
240+
hint="If you want to perform search operations without vectors, "
241+
"use SearchIndex instead.",
242242
),
243243
],
244244
)

tests/indexes_/test_search_indexes.py

+97-8
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,32 @@
1212
@mock.patch.object(connection.features, "supports_atlas_search", False)
1313
class UnsupportedSearchIndexesTests(TestCase):
1414
def test_search_index_not_created(self):
15-
with connection.schema_editor() as editor:
16-
index = SearchIndex(name="recent_test_idx", fields=["number"])
17-
with self.assertNumQueries(0):
18-
editor.add_index(index=index, model=SearchIndexTestModel)
15+
index = SearchIndex(name="recent_test_idx", fields=["number"])
16+
with connection.schema_editor() as editor, self.assertNumQueries(0):
17+
editor.add_index(index=index, model=SearchIndexTestModel)
1918
self.assertNotIn(
2019
index.name,
2120
connection.introspection.get_constraints(
2221
cursor=None,
2322
table_name=SearchIndexTestModel._meta.db_table,
2423
),
2524
)
25+
with connection.schema_editor() as editor, self.assertNumQueries(0):
26+
editor.remove_index(index=index, model=SearchIndexTestModel)
2627

2728
def test_vector_index_not_created(self):
28-
with connection.schema_editor() as editor:
29-
index = VectorSearchIndex(name="recent_test_idx", fields=["number"])
30-
with self.assertNumQueries(0):
31-
editor.add_index(index=index, model=SearchIndexTestModel)
29+
index = VectorSearchIndex(name="recent_test_idx", fields=["number"])
30+
with connection.schema_editor() as editor, self.assertNumQueries(0):
31+
editor.add_index(index=index, model=SearchIndexTestModel)
3232
self.assertNotIn(
3333
index.name,
3434
connection.introspection.get_constraints(
3535
cursor=None,
3636
table_name=SearchIndexTestModel._meta.db_table,
3737
),
3838
)
39+
with connection.schema_editor() as editor, self.assertNumQueries(0):
40+
editor.remove_index(index=index, model=SearchIndexTestModel)
3941

4042

4143
class SearchIndexTests(SimpleTestCase):
@@ -68,6 +70,7 @@ def test_no_extra_kargs(self):
6870
def test_deconstruct(self):
6971
index = VectorSearchIndex(name="recent_test_idx", fields=["number"])
7072
name, args, kwargs = index.deconstruct()
73+
self.assertEqual(kwargs, {"name": "recent_test_idx", "fields": ["number"]})
7174
new = VectorSearchIndex(*args, **kwargs)
7275
self.assertEqual(new.similarities, index.similarities)
7376

@@ -255,3 +258,89 @@ def test_multiple_fields(self):
255258
finally:
256259
with connection.schema_editor() as editor:
257260
editor.remove_index(index=index, model=SearchIndexTestModel)
261+
262+
def test_similarities_value(self):
263+
index = VectorSearchIndex(
264+
name="recent_test_idx",
265+
fields=["vector_float", "vector_integer"],
266+
similarities="euclidean",
267+
)
268+
with connection.schema_editor() as editor:
269+
editor.add_index(index=index, model=SearchIndexTestModel)
270+
try:
271+
index_info = connection.introspection.get_constraints(
272+
cursor=None,
273+
table_name=SearchIndexTestModel._meta.db_table,
274+
)
275+
expected_options = {
276+
"latestDefinition": {
277+
"fields": [
278+
{
279+
"numDimensions": 10,
280+
"path": "vector_float",
281+
"similarity": "euclidean",
282+
"type": "vector",
283+
},
284+
{
285+
"numDimensions": 10,
286+
"path": "vector_integer",
287+
"similarity": "euclidean",
288+
"type": "vector",
289+
},
290+
]
291+
},
292+
"latestVersion": 0,
293+
"name": "recent_test_idx",
294+
"queryable": False,
295+
"type": "vectorSearch",
296+
}
297+
self.assertCountEqual(index_info[index.name]["columns"], index.fields)
298+
index_info[index.name]["options"].pop("id")
299+
index_info[index.name]["options"].pop("status")
300+
self.assertEqual(index_info[index.name]["options"], expected_options)
301+
finally:
302+
with connection.schema_editor() as editor:
303+
editor.remove_index(index=index, model=SearchIndexTestModel)
304+
305+
def test_similarities_list(self):
306+
index = VectorSearchIndex(
307+
name="recent_test_idx",
308+
fields=["vector_float", "vector_integer"],
309+
similarities=["cosine", "euclidean"],
310+
)
311+
with connection.schema_editor() as editor:
312+
editor.add_index(index=index, model=SearchIndexTestModel)
313+
try:
314+
index_info = connection.introspection.get_constraints(
315+
cursor=None,
316+
table_name=SearchIndexTestModel._meta.db_table,
317+
)
318+
expected_options = {
319+
"latestDefinition": {
320+
"fields": [
321+
{
322+
"numDimensions": 10,
323+
"path": "vector_float",
324+
"similarity": "cosine",
325+
"type": "vector",
326+
},
327+
{
328+
"numDimensions": 10,
329+
"path": "vector_integer",
330+
"similarity": "euclidean",
331+
"type": "vector",
332+
},
333+
]
334+
},
335+
"latestVersion": 0,
336+
"name": "recent_test_idx",
337+
"queryable": False,
338+
"type": "vectorSearch",
339+
}
340+
self.assertCountEqual(index_info[index.name]["columns"], index.fields)
341+
index_info[index.name]["options"].pop("id")
342+
index_info[index.name]["options"].pop("status")
343+
self.assertEqual(index_info[index.name]["options"], expected_options)
344+
finally:
345+
with connection.schema_editor() as editor:
346+
editor.remove_index(index=index, model=SearchIndexTestModel)

0 commit comments

Comments
 (0)