Skip to content

Commit 8804c87

Browse files
WaVEVtimgraham
andcommitted
Add SearchIndex and VectorSearchIndex
Co-authored-by: Tim Graham <[email protected]>
1 parent 38cf2b4 commit 8804c87

File tree

20 files changed

+976
-6
lines changed

20 files changed

+976
-6
lines changed

.github/workflows/mongodb_settings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@
1717
"default": {
1818
"ENGINE": "django_mongodb_backend",
1919
"NAME": "djangotests",
20+
# Required when connecting to the Atlas image in Docker.
21+
"OPTIONS": {"directConnection": True},
2022
},
2123
"other": {
2224
"ENGINE": "django_mongodb_backend",
2325
"NAME": "djangotests-other",
26+
"OPTIONS": {"directConnection": True},
2427
},
2528
}
2629

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#!/bin/bash
2+
set -eu
3+
4+
echo "Starting the container"
5+
6+
IMAGE=${1:-mongodb/mongodb-atlas-local:latest}
7+
DOCKER=$(which docker || which podman)
8+
9+
$DOCKER pull $IMAGE
10+
11+
$DOCKER kill mongodb_atlas_local || true
12+
13+
CONTAINER_ID=$($DOCKER run --rm -d --name mongodb_atlas_local -p 27017:27017 $IMAGE)
14+
15+
function wait() {
16+
CONTAINER_ID=$1
17+
echo "waiting for container to become healthy..."
18+
$DOCKER logs mongodb_atlas_local
19+
}
20+
21+
wait "$CONTAINER_ID"
22+
23+
# Sleep for a bit to let all services start.
24+
sleep 5
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
name: Python Tests on Atlas
2+
3+
on:
4+
pull_request:
5+
paths:
6+
- '**.py'
7+
- '!setup.py'
8+
- '.github/workflows/test-python-atlas.yml'
9+
workflow_dispatch:
10+
11+
concurrency:
12+
group: ${{ github.workflow }}-${{ github.ref }}
13+
cancel-in-progress: true
14+
15+
defaults:
16+
run:
17+
shell: bash -eux {0}
18+
19+
jobs:
20+
build:
21+
name: Django Test Suite
22+
runs-on: ubuntu-latest
23+
steps:
24+
- name: Checkout django-mongodb-backend
25+
uses: actions/checkout@v4
26+
with:
27+
persist-credentials: false
28+
- name: install django-mongodb-backend
29+
run: |
30+
pip3 install --upgrade pip
31+
pip3 install -e .
32+
- name: Checkout Django
33+
uses: actions/checkout@v4
34+
with:
35+
repository: 'mongodb-forks/django'
36+
ref: 'mongodb-5.2.x'
37+
path: 'django_repo'
38+
persist-credentials: false
39+
- name: Install system packages for Django's Python test dependencies
40+
run: |
41+
sudo apt-get update
42+
sudo apt-get install libmemcached-dev
43+
- name: Install Django and its Python test dependencies
44+
run: |
45+
cd django_repo/tests/
46+
pip3 install -e ..
47+
pip3 install -r requirements/py3.txt
48+
- name: Copy the test settings file
49+
run: cp .github/workflows/mongodb_settings.py django_repo/tests/
50+
- name: Copy the test runner file
51+
run: cp .github/workflows/runtests.py django_repo/tests/runtests_.py
52+
- name: Start local Atlas
53+
working-directory: .
54+
run: bash .github/workflows/start_local_atlas.sh mongodb/mongodb-atlas-local:7
55+
- name: Run tests
56+
run: python3 django_repo/tests/runtests_.py

django_mongodb_backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
check_django_compatability()
88

99
from .aggregates import register_aggregates # noqa: E402
10+
from .checks import register_checks # noqa: E402
1011
from .expressions import register_expressions # noqa: E402
1112
from .fields import register_fields # noqa: E402
1213
from .functions import register_functions # noqa: E402
@@ -17,6 +18,7 @@
1718
__all__ = ["parse_uri"]
1819

1920
register_aggregates()
21+
register_checks()
2022
register_expressions()
2123
register_fields()
2224
register_functions()

django_mongodb_backend/checks.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from itertools import chain
2+
3+
from django.apps import apps
4+
from django.core.checks import Tags, register
5+
from django.db import connections, router
6+
7+
8+
def check_indexes(app_configs, databases=None, **kwargs): # noqa: ARG001
9+
"""
10+
Call Index.check() on all model indexes.
11+
12+
This function will be obsolete when Django calls Index.check() after
13+
https://code.djangoproject.com/ticket/36273.
14+
"""
15+
errors = []
16+
if app_configs is None:
17+
models = apps.get_models()
18+
else:
19+
models = chain.from_iterable(app_config.get_models() for app_config in app_configs)
20+
for model in models:
21+
for db in databases or ():
22+
if not router.allow_migrate_model(db, model):
23+
continue
24+
connection = connections[db]
25+
for model_index in model._meta.indexes:
26+
if hasattr(model_index, "check"):
27+
errors.extend(model_index.check(model, connection))
28+
return errors
29+
30+
31+
def register_checks():
32+
register(check_indexes, Tags.models)

django_mongodb_backend/features.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from django.db.backends.base.features import BaseDatabaseFeatures
22
from django.utils.functional import cached_property
3+
from pymongo.errors import OperationFailure
34

45

56
class DatabaseFeatures(BaseDatabaseFeatures):
@@ -548,3 +549,19 @@ def django_test_expected_failures(self):
548549
@cached_property
549550
def is_mongodb_6_3(self):
550551
return self.connection.get_database_version() >= (6, 3)
552+
553+
@cached_property
554+
def supports_atlas_search(self):
555+
"""Does the server support Atlas search queries and search indexes?"""
556+
try:
557+
# An existing collection must be used on MongoDB 6, otherwise
558+
# the operation will not error when unsupported.
559+
self.connection.get_collection("django_migrations").list_search_indexes()
560+
except OperationFailure:
561+
# It would be best to check the error message or error code to
562+
# avoid hiding some other exception, but the message/code varies
563+
# across MongoDB versions. Example error message:
564+
# "$listSearchIndexes stage is only allowed on MongoDB Atlas".
565+
return False
566+
else:
567+
return True

django_mongodb_backend/fields/embedded_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ def __init__(self, embedded_model, *args, **kwargs):
2020
self.embedded_model = embedded_model
2121
super().__init__(*args, **kwargs)
2222

23+
def db_type(self, connection):
24+
return "embeddedDocuments"
25+
2326
def check(self, **kwargs):
2427
from ..models import EmbeddedModel
2528

django_mongodb_backend/indexes.py

Lines changed: 181 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
import itertools
12
from collections import defaultdict
23

4+
from django.core.checks import Error, Warning
35
from django.db import NotSupportedError
4-
from django.db.models import Index
6+
from django.db.models import FloatField, Index, IntegerField
57
from django.db.models.lookups import BuiltinLookup
68
from django.db.models.sql.query import Query
79
from django.db.models.sql.where import AND, XOR, WhereNode
810
from pymongo import ASCENDING, DESCENDING
9-
from pymongo.operations import IndexModel
11+
from pymongo.operations import IndexModel, SearchIndexModel
12+
13+
from django_mongodb_backend.fields import ArrayField
1014

1115
from .query_utils import process_rhs
1216

@@ -101,6 +105,181 @@ def where_node_idx(self, compiler, connection):
101105
return mql
102106

103107

108+
class SearchIndex(Index):
109+
suffix = "six"
110+
_error_id_prefix = "django_mongodb_backend.indexes.SearchIndex"
111+
112+
def __init__(self, *, fields=(), name=None):
113+
super().__init__(fields=fields, name=name)
114+
115+
def check(self, model, connection):
116+
errors = []
117+
if not connection.features.supports_atlas_search:
118+
errors.append(
119+
Warning(
120+
f"This MongoDB server does not support {self.__class__.__name__}.",
121+
hint=(
122+
"The index won't be created. Use an Atlas-enabled version of MongoDB, "
123+
"or silence this warning if you don't care about it."
124+
),
125+
obj=model,
126+
id=f"{self._error_id_prefix}.W001",
127+
)
128+
)
129+
return errors
130+
131+
def search_index_data_types(self, db_type):
132+
"""
133+
Map a model field's type to search index type.
134+
https://www.mongodb.com/docs/atlas/atlas-search/define-field-mappings/#data-types
135+
"""
136+
if db_type in {"double", "int", "long"}:
137+
return "number"
138+
if db_type == "binData":
139+
return "string"
140+
if db_type == "bool":
141+
return "boolean"
142+
if db_type == "object":
143+
return "document"
144+
if db_type == "array":
145+
return "embeddedDocuments"
146+
return db_type
147+
148+
def get_pymongo_index_model(
149+
self, model, schema_editor, field=None, unique=False, column_prefix=""
150+
):
151+
if not schema_editor.connection.features.supports_atlas_search:
152+
return None
153+
fields = {}
154+
for field_name, _ in self.fields_orders:
155+
field = model._meta.get_field(field_name)
156+
type_ = self.search_index_data_types(field.db_type(schema_editor.connection))
157+
field_path = column_prefix + model._meta.get_field(field_name).column
158+
fields[field_path] = {"type": type_}
159+
return SearchIndexModel(
160+
definition={"mappings": {"dynamic": False, "fields": fields}}, name=self.name
161+
)
162+
163+
164+
class VectorSearchIndex(SearchIndex):
165+
suffix = "vsi"
166+
_error_id_prefix = "django_mongodb_backend.indexes.VectorSearchIndex"
167+
VALID_FIELD_TYPES = frozenset(("boolean", "date", "number", "objectId", "string", "uuid"))
168+
VALID_SIMILARITIES = frozenset(("cosine", "dotProduct", "euclidean"))
169+
170+
def __init__(self, *, fields=(), name=None, similarities):
171+
super().__init__(fields=fields, name=name)
172+
self.similarities = similarities
173+
self._multiple_similarities = isinstance(similarities, tuple | list)
174+
for func in similarities if self._multiple_similarities else (similarities,):
175+
if func not in self.VALID_SIMILARITIES:
176+
raise ValueError(
177+
f"'{func}' isn't a valid similarity function "
178+
f"({', '.join(sorted(self.VALID_SIMILARITIES))})."
179+
)
180+
seen_fields = set()
181+
for field_name, _ in self.fields_orders:
182+
if field_name in seen_fields:
183+
raise ValueError(f"Field '{field_name}' is duplicated in fields.")
184+
seen_fields.add(field_name)
185+
186+
def check(self, model, connection):
187+
errors = super().check(model, connection)
188+
num_arrayfields = 0
189+
for field_name, _ in self.fields_orders:
190+
field = model._meta.get_field(field_name)
191+
if isinstance(field, ArrayField):
192+
num_arrayfields += 1
193+
try:
194+
int(field.size)
195+
except (ValueError, TypeError):
196+
errors.append(
197+
Error(
198+
f"VectorSearchIndex requires 'size' on field '{field_name}'.",
199+
obj=model,
200+
id=f"{self._error_id_prefix}.E002",
201+
)
202+
)
203+
if not isinstance(field.base_field, FloatField | IntegerField):
204+
errors.append(
205+
Error(
206+
"VectorSearchIndex requires the base field of "
207+
f"ArrayField '{field.name}' to be FloatField or "
208+
"IntegerField but is "
209+
f"{field.base_field.get_internal_type()}.",
210+
obj=model,
211+
id=f"{self._error_id_prefix}.E003",
212+
)
213+
)
214+
else:
215+
search_type = self.search_index_data_types(field.db_type(connection))
216+
if search_type not in self.VALID_FIELD_TYPES:
217+
errors.append(
218+
Error(
219+
"VectorSearchIndex does not support field "
220+
f"'{field_name}' ({field.get_internal_type()}).",
221+
obj=model,
222+
id=f"{self._error_id_prefix}.E004",
223+
hint=f"Allowed types are {', '.join(sorted(self.VALID_FIELD_TYPES))}.",
224+
)
225+
)
226+
if self._multiple_similarities and num_arrayfields != len(self.similarities):
227+
errors.append(
228+
Error(
229+
f"VectorSearchIndex requires the same number of similarities "
230+
f"and vector fields; {model._meta.object_name} has "
231+
f"{num_arrayfields} ArrayField(s) but similarities "
232+
f"has {len(self.similarities)} element(s).",
233+
obj=model,
234+
id=f"{self._error_id_prefix}.E005",
235+
)
236+
)
237+
if num_arrayfields == 0:
238+
errors.append(
239+
Error(
240+
"VectorSearchIndex requires at least one ArrayField to " "store vector data.",
241+
obj=model,
242+
id=f"{self._error_id_prefix}.E006",
243+
hint="If you want to perform search operations without vectors, "
244+
"use SearchIndex instead.",
245+
)
246+
)
247+
return errors
248+
249+
def deconstruct(self):
250+
path, args, kwargs = super().deconstruct()
251+
kwargs["similarities"] = self.similarities
252+
return path, args, kwargs
253+
254+
def get_pymongo_index_model(
255+
self, model, schema_editor, field=None, unique=False, column_prefix=""
256+
):
257+
if not schema_editor.connection.features.supports_atlas_search:
258+
return None
259+
similarities = (
260+
itertools.cycle([self.similarities])
261+
if not self._multiple_similarities
262+
else iter(self.similarities)
263+
)
264+
fields = []
265+
for field_name, _ in self.fields_orders:
266+
field_ = model._meta.get_field(field_name)
267+
field_path = column_prefix + model._meta.get_field(field_name).column
268+
mappings = {"path": field_path}
269+
if isinstance(field_, ArrayField):
270+
mappings.update(
271+
{
272+
"type": "vector",
273+
"numDimensions": int(field_.size),
274+
"similarity": next(similarities),
275+
}
276+
)
277+
else:
278+
mappings["type"] = "filter"
279+
fields.append(mappings)
280+
return SearchIndexModel(definition={"fields": fields}, name=self.name, type="vectorSearch")
281+
282+
104283
def register_indexes():
105284
BuiltinLookup.as_mql_idx = builtin_lookup_idx
106285
Index._get_condition_mql = _get_condition_mql

0 commit comments

Comments
 (0)