Skip to content

Commit bd33b3f

Browse files
committed
add system check and unit test
1 parent 26aa350 commit bd33b3f

File tree

3 files changed

+116
-0
lines changed

3 files changed

+116
-0
lines changed

django_mongodb_backend/checks.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from itertools import chain
2+
3+
from django.apps import apps
4+
from django.core.checks import Tags, register
5+
from django.db import connections, router
6+
7+
from django_mongodb_backend.indexes import VectorSearchIndex
8+
9+
10+
@register(Tags.models)
11+
def check_vector_search_indexes(app_configs, databases=None, **kwargs): # noqa: ARG001
12+
errors = []
13+
if app_configs is None:
14+
models = apps.get_models()
15+
else:
16+
models = chain.from_iterable(app_config.get_models() for app_config in app_configs)
17+
for model in models:
18+
for db in databases or ():
19+
if not router.allow_migrate_model(db, model):
20+
continue
21+
connection = connections[db]
22+
for model_index in model._meta.indexes:
23+
if not isinstance(model_index, VectorSearchIndex):
24+
continue
25+
errors.extend(model_index.check(model, connection))
26+
return errors

tests/system_checks/__init__.py

Whitespace-only changes.

tests/system_checks/tests.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from django.core import checks
2+
from django.db import models
3+
from django.test import SimpleTestCase
4+
from django.test.utils import (
5+
isolate_apps,
6+
override_system_checks,
7+
)
8+
9+
from django_mongodb_backend.checks import check_vector_search_indexes
10+
from django_mongodb_backend.fields import ArrayField
11+
from django_mongodb_backend.indexes import VectorSearchIndex
12+
13+
14+
@isolate_apps("system_checks", attr_name="apps")
15+
@override_system_checks([check_vector_search_indexes])
16+
class InvalidSearchIndexesTest(SimpleTestCase):
17+
def test_vectorsearch_requires_size(self):
18+
class Article(models.Model):
19+
title_embedded = ArrayField(models.FloatField())
20+
21+
class Meta:
22+
indexes = [
23+
VectorSearchIndex(fields=["title_embedded"]),
24+
]
25+
26+
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
27+
self.assertEqual(
28+
errors,
29+
[
30+
checks.Error(
31+
"Atlas vector search requires size.",
32+
id="django_mongodb_backend.indexes.VectorSearchIndex.E001",
33+
obj=Article._meta.indexes[0],
34+
)
35+
],
36+
)
37+
38+
def test_vectorsearch_requires_float_inner_field(self):
39+
class Article(models.Model):
40+
title_embedded = ArrayField(models.CharField(), size=30)
41+
42+
class Meta:
43+
indexes = [
44+
VectorSearchIndex(fields=["title_embedded"]),
45+
]
46+
47+
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
48+
self.assertEqual(
49+
errors,
50+
[
51+
checks.Error(
52+
"Base type must be Float or Decimal.",
53+
id="django_mongodb_backend.indexes.VectorSearchIndex.E002",
54+
obj=Article._meta.indexes[0],
55+
)
56+
],
57+
)
58+
59+
def test_vectorsearch_unsupported_type(self):
60+
class Article(models.Model):
61+
data = models.JSONField()
62+
63+
class Meta:
64+
indexes = [
65+
VectorSearchIndex(fields=["data"]),
66+
]
67+
68+
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
69+
self.assertEqual(
70+
errors,
71+
[
72+
checks.Error(
73+
"Unsupported filter of type JSONField.",
74+
id="django_mongodb_backend.indexes.VectorSearchIndex.E003",
75+
obj=Article._meta.indexes[0],
76+
)
77+
],
78+
)
79+
80+
def test_vectorsearch(self):
81+
class Article(models.Model):
82+
vector_data = ArrayField(models.DecimalField(), size=10)
83+
84+
class Meta:
85+
indexes = [
86+
VectorSearchIndex(fields=["vector_data"]),
87+
]
88+
89+
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
90+
self.assertEqual(errors, [])

0 commit comments

Comments
 (0)