Add SearchIndex and VectorSearchIndex #264

Open: wants to merge 102 commits into base: main

102 commits
3bad834
Refactor index creation
WaVEV Feb 16, 2025
bd8f3fc
stubed up method
WaVEV Feb 22, 2025
a5b6826
Atlas index creation
WaVEV Feb 24, 2025
f987e7f
Test
WaVEV Feb 24, 2025
ea32d94
Fix mapping typing issue
WaVEV Feb 25, 2025
01cd73a
Refactor unit test
WaVEV Feb 26, 2025
d96fdbd
Add test with data
WaVEV Feb 26, 2025
44a6dcd
Add atlas vector search index
WaVEV Mar 3, 2025
f6fec5e
ObjectId Fields does not support integers
WaVEV Mar 4, 2025
d77b6a9
Refactor: remove singledispatchmethod
WaVEV Mar 4, 2025
ed25939
Add directConnection in testing settings
WaVEV Mar 7, 2025
1cb3a7e
Add lenght validator
WaVEV Mar 9, 2025
cf7284e
Add unit test
WaVEV Mar 9, 2025
ffcc0dd
Using fixed_size instead of size
WaVEV Mar 9, 2025
9a50963
Testing CI
WaVEV Mar 9, 2025
028898e
[Testing CI] Set mongo image version tag
WaVEV Mar 9, 2025
fdc0c39
Edits
WaVEV Mar 11, 2025
9fbf6a9
Testing CI
WaVEV Mar 11, 2025
ef3d3c5
Remove test with data
WaVEV Mar 12, 2025
59726da
Move mongo_data_types mapping
WaVEV Mar 12, 2025
c51b69a
Pumping mongo version to fix some unit test
WaVEV Mar 20, 2025
a4a3076
Add validators test
WaVEV Mar 20, 2025
9fe24d5
Refactor fixed_size and size in array field
WaVEV Mar 21, 2025
c456b00
Add create_search_index to OperationDebugWrapper
WaVEV Mar 21, 2025
8a08fd9
change tuple for frozenset
WaVEV Mar 21, 2025
fc87a80
Refactor
WaVEV Mar 21, 2025
951795d
Move import to the top
WaVEV Mar 21, 2025
4272630
Add drop_search_index and list_search_indexes to OperationDebugWrapper
WaVEV Mar 22, 2025
a36d11a
rename size and fixed_size
WaVEV Mar 22, 2025
09101a5
Docstring and search type mapping
WaVEV Mar 22, 2025
01cec0e
Refactor if conditions
WaVEV Mar 22, 2025
0a9771f
Add unit tests
WaVEV Mar 22, 2025
cadf68c
Simplify with
WaVEV Mar 22, 2025
985754c
Add deconstruct method
WaVEV Mar 24, 2025
5a19e26
Remove duplicate function
WaVEV Mar 26, 2025
33d1f76
Update django_mongodb_backend/introspection.py
WaVEV Mar 27, 2025
5a81285
Remove redundant tests
WaVEV Mar 27, 2025
dcd22df
Index prefix cannot be longer than 3 chars
WaVEV Mar 27, 2025
d9b1941
Check function in VectorSearchIndex.
WaVEV Mar 27, 2025
8b8e960
Add parameter connection to the vector search index check
WaVEV Mar 28, 2025
926cfbb
add system check and unit test
WaVEV Mar 28, 2025
db62c59
Refactor imports
WaVEV Mar 28, 2025
d0f9927
Fix invalid similarity checker
WaVEV Mar 28, 2025
f784c9a
Add similarities unit test
WaVEV Mar 28, 2025
7dfd5bf
Remove check test, were moved to system check
WaVEV Mar 28, 2025
a6ccf88
Remove valitador get_pymongo_index_model.
WaVEV Apr 1, 2025
de09466
Add docstring
WaVEV Apr 1, 2025
6e6043d
remove validators unit test
WaVEV Apr 1, 2025
75aac72
stub docs
timgraham Apr 6, 2025
262069b
Add supports_search_indexes flag
WaVEV Apr 9, 2025
731f349
Edits
WaVEV Apr 10, 2025
50d9b10
Test
WaVEV Apr 10, 2025
34dc035
Create an dummy collection when check atlas support
WaVEV Apr 12, 2025
10fdf69
Edits
WaVEV Apr 12, 2025
56e89a1
Add search index support check in search indexes
WaVEV Apr 12, 2025
feaead9
Rename feature flag to supports_atlas_search
WaVEV Apr 15, 2025
e0fe059
Edit warning message
WaVEV Apr 15, 2025
3ad033b
Dont drop dummy collection
WaVEV Apr 15, 2025
1d39d32
Refactor unit test
WaVEV Apr 16, 2025
1cb2d3f
Edit the error message
WaVEV Apr 16, 2025
21ba183
Remove unused field
WaVEV Apr 16, 2025
9a392e6
register checks
WaVEV Apr 16, 2025
2d6b574
Move system_check/test to indexes_
WaVEV Apr 16, 2025
debb056
Return None if the index cannot be created
WaVEV Apr 16, 2025
3cbae58
Simplify atlas search check compatibility
WaVEV Apr 16, 2025
36a9cf6
Docstring
WaVEV Apr 16, 2025
56e323f
Fix dropping atlas index when atlas is unsuported
WaVEV Apr 16, 2025
3ccb17e
Fix dropping a non existing atlas index
WaVEV Apr 16, 2025
63d74e7
Add field repeated checks in vector search.
WaVEV Apr 17, 2025
d7526e4
trim blank spaces
WaVEV Apr 17, 2025
4ac6a90
Refactor unit test
WaVEV Apr 17, 2025
a9f0d8b
Return None if the index cannot be created
WaVEV Apr 17, 2025
71d649e
Rename field
WaVEV Apr 17, 2025
3fa4b81
Revert GitHub Action
WaVEV Apr 17, 2025
07dae4c
Add atlas CI
WaVEV Apr 17, 2025
74b711a
Docstring
WaVEV Apr 18, 2025
5be8cea
Remove dead code
WaVEV Apr 18, 2025
b0d3482
Rename check index function
WaVEV Apr 18, 2025
afed7e7
Edit Error message
WaVEV Apr 18, 2025
c3b0b4c
Use edit literal instead of tuple.
WaVEV Apr 18, 2025
cfa9f01
Drop search index if atlas is supported
WaVEV Apr 18, 2025
82aa32e
Added system check for similarities and vector fields count mismatch.
WaVEV Apr 18, 2025
cad0ebf
Simplify if
WaVEV Apr 18, 2025
8dbd0f1
Parametrize mongo atlas ci
WaVEV Apr 18, 2025
ed97399
Check similiarities function in __init__
WaVEV Apr 18, 2025
c5a45fb
Fix error message
WaVEV Apr 18, 2025
6b84168
edits
timgraham Apr 19, 2025
80fffa5
Move field repeated check to init
WaVEV Apr 20, 2025
0e987fc
Amend errors messages
WaVEV Apr 20, 2025
69984f9
Refactor unit test
WaVEV Apr 20, 2025
ad3d6ec
edits
timgraham Apr 22, 2025
c74604e
Fix unit test
WaVEV Apr 22, 2025
b868f59
Replace decimal in vector index for integer
WaVEV Apr 23, 2025
3596a99
Handle UUID as string and define embeddedfield.db_type()
WaVEV Apr 23, 2025
4792689
Add support for arrayfield in SearchIndex.
WaVEV Apr 23, 2025
5fa8fbc
Edits
WaVEV Apr 23, 2025
3f43d46
Improve test coverage.
WaVEV Apr 23, 2025
0d6f719
reuse assertAddRemoveIndex
timgraham Apr 24, 2025
252f364
edits
timgraham Apr 24, 2025
0f28afe
more edits
timgraham Apr 25, 2025
dcdf271
Add check when VectorSearchIndex has no vector field
WaVEV Apr 25, 2025
c59297c
edits
timgraham Apr 25, 2025
3 changes: 3 additions & 0 deletions .github/workflows/mongodb_settings.py
@@ -17,10 +17,13 @@
    "default": {
        "ENGINE": "django_mongodb_backend",
        "NAME": "djangotests",
        # Required when connecting to the Atlas image in Docker.
        "OPTIONS": {"directConnection": True},
    },
    "other": {
        "ENGINE": "django_mongodb_backend",
        "NAME": "djangotests-other",
        "OPTIONS": {"directConnection": True},
    },
}

24 changes: 24 additions & 0 deletions .github/workflows/start_local_atlas.sh
@@ -0,0 +1,24 @@
#!/bin/bash
set -eu

echo "Starting the container"

IMAGE=${1:-mongodb/mongodb-atlas-local:latest}
DOCKER=$(which docker || which podman)

$DOCKER pull $IMAGE

$DOCKER kill mongodb_atlas_local || true

CONTAINER_ID=$($DOCKER run --rm -d --name mongodb_atlas_local -p 27017:27017 $IMAGE)

function wait() {
  CONTAINER_ID=$1
  echo "waiting for container to become healthy..."
  $DOCKER logs mongodb_atlas_local
}

wait "$CONTAINER_ID"

# Sleep for a bit to let all services start.
sleep 5
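
The `wait` function above prints the container logs, and the script then relies on a fixed `sleep 5`. One alternative is to poll until the container reports a healthy state. The following is a minimal sketch of that idea, not part of this PR: `wait_until_healthy` and its `probe` argument are hypothetical names, and it assumes the image defines a Docker health check, in which case `probe` could wrap `docker inspect --format '{{.State.Health.Status}}' mongodb_atlas_local`.

```python
import time


def wait_until_healthy(probe, timeout=60.0, interval=1.0,
                       sleep=time.sleep, clock=time.monotonic):
    """Poll `probe()` until it returns "healthy" or `timeout` seconds pass.

    `probe` is a hypothetical callable returning a status string. `sleep`
    and `clock` are injectable so the loop can be exercised without waiting.
    """
    deadline = clock() + timeout
    while True:
        if probe() == "healthy":
            return True
        if clock() >= deadline:
            return False
        sleep(interval)


# Simulated status sequence standing in for a real container probe.
statuses = iter(["starting", "starting", "healthy"])
became_healthy = wait_until_healthy(lambda: next(statuses), sleep=lambda _: None)
```

Polling bounds the wait by a timeout instead of guessing a fixed delay, at the cost of depending on the image's health check.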
56 changes: 56 additions & 0 deletions .github/workflows/test-python-atlas.yml
@@ -0,0 +1,56 @@
name: Python Tests on Atlas

on:
  pull_request:
    paths:
      - '**.py'
      - '!setup.py'
      - '.github/workflows/test-python-atlas.yml'
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

defaults:
  run:
    shell: bash -eux {0}

jobs:
  build:
    name: Django Test Suite
    runs-on: ubuntu-latest
    steps:
      - name: Checkout django-mongodb-backend
        uses: actions/checkout@v4
        with:
          persist-credentials: false
      - name: install django-mongodb-backend
        run: |
          pip3 install --upgrade pip
          pip3 install -e .
      - name: Checkout Django
        uses: actions/checkout@v4
        with:
          repository: 'mongodb-forks/django'
          ref: 'mongodb-5.2.x'
          path: 'django_repo'
          persist-credentials: false
      - name: Install system packages for Django's Python test dependencies
        run: |
          sudo apt-get update
          sudo apt-get install libmemcached-dev
      - name: Install Django and its Python test dependencies
        run: |
          cd django_repo/tests/
          pip3 install -e ..
          pip3 install -r requirements/py3.txt
      - name: Copy the test settings file
        run: cp .github/workflows/mongodb_settings.py django_repo/tests/
      - name: Copy the test runner file
        run: cp .github/workflows/runtests.py django_repo/tests/runtests_.py
      - name: Start local Atlas
        working-directory: .
        run: bash .github/workflows/start_local_atlas.sh mongodb/mongodb-atlas-local:7
      - name: Run tests
        run: python3 django_repo/tests/runtests_.py
2 changes: 2 additions & 0 deletions django_mongodb_backend/__init__.py
@@ -7,6 +7,7 @@
check_django_compatability()

from .aggregates import register_aggregates # noqa: E402
from .checks import register_checks # noqa: E402
from .expressions import register_expressions # noqa: E402
from .fields import register_fields # noqa: E402
from .functions import register_functions # noqa: E402
@@ -17,6 +18,7 @@
__all__ = ["parse_uri"]

register_aggregates()
register_checks()
register_expressions()
register_fields()
register_functions()
32 changes: 32 additions & 0 deletions django_mongodb_backend/checks.py
@@ -0,0 +1,32 @@
from itertools import chain

from django.apps import apps
from django.core.checks import Tags, register
from django.db import connections, router


def check_indexes(app_configs, databases=None, **kwargs):  # noqa: ARG001
    """
    Call Index.check() on all model indexes.

    This function will be obsolete when Django calls Index.check() after
    https://code.djangoproject.com/ticket/36273.
    """
    errors = []
    if app_configs is None:
        models = apps.get_models()
    else:
        models = chain.from_iterable(app_config.get_models() for app_config in app_configs)
    for model in models:
        for db in databases or ():
            if not router.allow_migrate_model(db, model):
                continue
            connection = connections[db]
            for model_index in model._meta.indexes:
                if hasattr(model_index, "check"):
                    errors.extend(model_index.check(model, connection))
    return errors


def register_checks():
    register(check_indexes, Tags.models)
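
To illustrate the control flow of `check_indexes` without a Django project, here is a minimal standalone sketch. Everything in it is a hypothetical stand-in: `PlainIndex`, `CheckedIndex`, and `collect_index_errors` are not part of the PR, models are plain strings, and the router filter is omitted. It shows two properties of the loop: indexes without a `check()` method are skipped, and nothing is checked when `databases` is `None` or empty.

```python
class PlainIndex:
    """Stand-in for an index class that defines no check() method."""


class CheckedIndex:
    """Stand-in for an index class that defines check(model, connection)."""

    def check(self, model, connection):
        return [f"{model} on {connection}: search indexes unsupported"]


def collect_index_errors(model_indexes, databases=None):
    """Mirror the nested loops of check_indexes over plain Python data.

    `model_indexes` maps a model name to its list of index objects.
    """
    errors = []
    for model, indexes in model_indexes.items():
        for db in databases or ():
            for index in indexes:
                if hasattr(index, "check"):
                    errors.extend(index.check(model, db))
    return errors


model_indexes = {"Article": [PlainIndex(), CheckedIndex()]}
no_db_errors = collect_index_errors(model_indexes)
default_db_errors = collect_index_errors(model_indexes, databases=["default"])
```

With `databases=["default"]` only the `CheckedIndex` contributes a message; with no databases the result is an empty list.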
16 changes: 16 additions & 0 deletions django_mongodb_backend/features.py
@@ -1,5 +1,6 @@
from django.db.backends.base.features import BaseDatabaseFeatures
from django.utils.functional import cached_property
from pymongo.errors import OperationFailure


class DatabaseFeatures(BaseDatabaseFeatures):
@@ -625,3 +626,18 @@ def django_test_expected_failures(self):
    @cached_property
    def is_mongodb_6_3(self):
        return self.connection.get_database_version() >= (6, 3)

    @cached_property
    def supports_atlas_search(self):
        """Does the server support Atlas search queries and search indexes?"""
        try:
            # An existing collection must be used on MongoDB 6, otherwise
            # the operation will not error when unsupported.
            self.connection.get_collection("django_migrations").list_search_indexes()
        except OperationFailure:
            # It would be best to check the error message or error code too, but
            # they vary across MongoDB versions. Example: "$listSearchIndexes
            # stage is only allowed on MongoDB Atlas".
            return False
        else:
            return True
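
The feature flag above is a capability probe: attempt the operation and treat a failure as "unsupported". The sketch below reproduces that pattern with stand-ins so it runs without a MongoDB server. The local `OperationFailure` class and `FakeCollection` are hypothetical substitutes for `pymongo.errors.OperationFailure` and a real collection, and the error text is copied from the comment in the diff.

```python
class OperationFailure(Exception):
    """Local stand-in for pymongo.errors.OperationFailure."""


class FakeCollection:
    """Hypothetical collection that mimics a server with or without Atlas search."""

    def __init__(self, atlas_enabled):
        self.atlas_enabled = atlas_enabled

    def list_search_indexes(self):
        if not self.atlas_enabled:
            raise OperationFailure(
                "$listSearchIndexes stage is only allowed on MongoDB Atlas"
            )
        return []


def supports_atlas_search(collection):
    """Return True if listing search indexes succeeds, False if it fails."""
    try:
        collection.list_search_indexes()
    except OperationFailure:
        return False
    else:
        return True


atlas_result = supports_atlas_search(FakeCollection(atlas_enabled=True))
community_result = supports_atlas_search(FakeCollection(atlas_enabled=False))
```

Because the probe catches every `OperationFailure`, an unrelated failure (for example, a permissions error) would also be reported as "unsupported", which is the trade-off the code comment acknowledges.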
3 changes: 3 additions & 0 deletions django_mongodb_backend/fields/embedded_model.py
@@ -21,6 +21,9 @@ def __init__(self, embedded_model, *args, **kwargs):
        self.embedded_model = embedded_model
        super().__init__(*args, **kwargs)

    def db_type(self, connection):
        return "embeddedDocuments"

    def check(self, **kwargs):
        from ..models import EmbeddedModel
184 changes: 182 additions & 2 deletions django_mongodb_backend/indexes.py
@@ -1,12 +1,16 @@
import itertools
from collections import defaultdict

from django.core.checks import Error, Warning
from django.db import NotSupportedError
from django.db.models import Index
from django.db.models import FloatField, Index, IntegerField
from django.db.models.lookups import BuiltinLookup
from django.db.models.sql.query import Query
from django.db.models.sql.where import AND, XOR, WhereNode
from pymongo import ASCENDING, DESCENDING
from pymongo.operations import IndexModel
from pymongo.operations import IndexModel, SearchIndexModel

from django_mongodb_backend.fields import ArrayField

from .query_utils import process_rhs

@@ -101,6 +105,182 @@ def where_node_idx(self, compiler, connection):
    return mql


class SearchIndex(Index):
    suffix = "six"
    _error_id_prefix = "django_mongodb_backend.indexes.SearchIndex"

    def __init__(self, *, fields=(), name=None):
        super().__init__(fields=fields, name=name)

    def check(self, model, connection):
        errors = []
        if not connection.features.supports_atlas_search:
            errors.append(
                Warning(
                    f"This MongoDB server does not support {self.__class__.__name__}.",
                    hint=(
                        "The index won't be created. Use an Atlas-enabled version of MongoDB, "
                        "or silence this warning if you don't care about it."
                    ),
                    obj=model,
                    id=f"{self._error_id_prefix}.W001",
                )
            )
        return errors

    def search_index_data_types(self, db_type):
        """
        Map a model field's type to search index type.
        https://www.mongodb.com/docs/atlas/atlas-search/define-field-mappings/#data-types
        """
        if db_type in {"double", "int", "long"}:
            return "number"
        if db_type == "binData":
            return "string"
        if db_type == "bool":
            return "boolean"
        if db_type == "object":
            return "document"
        if db_type == "array":
            return "embeddedDocuments"
        return db_type

    def get_pymongo_index_model(
        self, model, schema_editor, field=None, unique=False, column_prefix=""
    ):
        if not schema_editor.connection.features.supports_atlas_search:
            return None
        fields = {}
        for field_name, _ in self.fields_orders:
            field = model._meta.get_field(field_name)
            type_ = self.search_index_data_types(field.db_type(schema_editor.connection))
            field_path = column_prefix + model._meta.get_field(field_name).column
            fields[field_path] = {"type": type_}
        return SearchIndexModel(
            definition={"mappings": {"dynamic": False, "fields": fields}}, name=self.name
        )
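
As a rough illustration of what `SearchIndex.get_pymongo_index_model()` passes to `SearchIndexModel`, the sketch below rebuilds the `definition` dictionary in plain Python. The helper names `search_type` and `build_search_definition` and the column names are hypothetical; the type mapping and the dictionary shape are copied from the diff above.

```python
def search_type(db_type):
    """Map a MongoDB field type to a search index type (same rules as the diff)."""
    if db_type in {"double", "int", "long"}:
        return "number"
    if db_type == "binData":
        return "string"
    if db_type == "bool":
        return "boolean"
    if db_type == "object":
        return "document"
    if db_type == "array":
        return "embeddedDocuments"
    return db_type


def build_search_definition(columns, column_prefix=""):
    """Build a static-mapping definition from {column: db_type} pairs."""
    fields = {
        column_prefix + column: {"type": search_type(db_type)}
        for column, db_type in columns.items()
    }
    return {"mappings": {"dynamic": False, "fields": fields}}


# Hypothetical columns: a string title and an integer page count.
definition = build_search_definition({"title": "string", "pages": "int"})
```

Types that have no special case, such as `"string"` here, pass through unchanged.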


class VectorSearchIndex(SearchIndex):
    suffix = "vsi"
    VALID_SIMILARITIES = frozenset(("cosine", "dotProduct", "euclidean"))
    VALID_FIELD_TYPES = frozenset(("boolean", "date", "number", "objectId", "string", "uuid"))
    _error_id_prefix = "django_mongodb_backend.indexes.VectorSearchIndex"

    def __init__(self, *, fields=(), similarities="cosine", name=None):
Collaborator

How did you decide that cosine should be the default?

Collaborator Author

To be honest, whenever I have worked with embeddings, this similarity was the first one I tried, except when the data was normalized (in that case, cosine and dot product give the same result, and dot product is faster). The L2 norm is also used, but less often than cosine in semantic searches.
To simplify the index, I decided to make cosine the default.
I have no strong preference; I'm fine with making similarities a required parameter.

Collaborator

Reading the Atlas docs, I didn't get the sense that there is a sensible default that works for most situations, but I really don't know, so let's ask the team's opinion on Monday.
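
For readers following the thread, the claim about normalized data can be checked numerically. This is a small standalone sketch with made-up vectors, independent of Atlas and of how Atlas scales its scores: for unit-length vectors, cosine similarity equals the dot product, and squared Euclidean distance is `2 - 2 * dot`, so all three measures rank neighbors identically.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


def squared_euclidean(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def normalize(v):
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


# Arbitrary example vectors, scaled to unit length.
a = normalize([3.0, 4.0])
b = normalize([1.0, 2.0])

cosine_equals_dot = math.isclose(cosine(a, b), dot(a, b))
euclidean_matches = math.isclose(squared_euclidean(a, b), 2 - 2 * dot(a, b))
```

For vectors that are not normalized the measures can disagree, which is why the choice of default matters.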

        super().__init__(fields=fields, name=name)
        self.similarities = similarities
        self._multiple_similarities = isinstance(similarities, tuple | list)
        for func in similarities if self._multiple_similarities else (similarities,):
            if func not in self.VALID_SIMILARITIES:
                raise ValueError(
                    f"'{func}' isn't a valid similarity function "
                    f"({', '.join(sorted(self.VALID_SIMILARITIES))})."
                )
        seen_fields = set()
        for field_name, _ in self.fields_orders:
            if field_name in seen_fields:
                raise ValueError(f"Field '{field_name}' is duplicated in fields.")
            seen_fields.add(field_name)

    def check(self, model, connection):
        errors = super().check(model, connection)
        num_arrayfields = 0
        for field_name, _ in self.fields_orders:
            field = model._meta.get_field(field_name)
            if isinstance(field, ArrayField):
                num_arrayfields += 1
                try:
                    int(field.size)
                except (ValueError, TypeError):
                    errors.append(
                        Error(
                            f"VectorSearchIndex requires 'size' on field '{field_name}'.",
                            obj=model,
                            id=f"{self._error_id_prefix}.E002",
                        )
                    )
                if not isinstance(field.base_field, FloatField | IntegerField):
                    errors.append(
                        Error(
                            "VectorSearchIndex requires the base field of "
                            f"ArrayField '{field.name}' to be FloatField or "
                            "IntegerField but is "
                            f"{field.base_field.get_internal_type()}.",
                            obj=model,
                            id=f"{self._error_id_prefix}.E003",
                        )
                    )
            else:
                search_type = self.search_index_data_types(field.db_type(connection))
                if search_type not in self.VALID_FIELD_TYPES:
                    errors.append(
                        Error(
                            "VectorSearchIndex does not support field "
                            f"'{field_name}' ({field.get_internal_type()}).",
                            obj=model,
                            id=f"{self._error_id_prefix}.E004",
                            hint=f"Allowed types are {', '.join(sorted(self.VALID_FIELD_TYPES))}.",
                        )
                    )
        if self._multiple_similarities and num_arrayfields != len(self.similarities):
            errors.append(
                Error(
                    f"VectorSearchIndex requires the same number of similarities "
                    f"and vector fields; {model._meta.object_name} has "
                    f"{num_arrayfields} ArrayField(s) but similarities "
                    f"has {len(self.similarities)} element(s).",
                    obj=model,
                    id=f"{self._error_id_prefix}.E005",
                )
            )
        if num_arrayfields == 0:
            errors.append(
                Error(
                    "VectorSearchIndex requires at least one ArrayField to store vector data.",
                    obj=model,
                    id=f"{self._error_id_prefix}.E006",
                    hint="If you want to perform search operations without vectors, "
                    "use SearchIndex instead.",
                )
            )
        return errors

    def deconstruct(self):
        path, args, kwargs = super().deconstruct()
        if self.similarities != "cosine":
            kwargs["similarities"] = self.similarities
        return path, args, kwargs

    def get_pymongo_index_model(
        self, model, schema_editor, field=None, unique=False, column_prefix=""
    ):
        if not schema_editor.connection.features.supports_atlas_search:
            return None
        similarities = (
            itertools.cycle([self.similarities])
            if not self._multiple_similarities
            else iter(self.similarities)
        )
        fields = []
        for field_name, _ in self.fields_orders:
            field_ = model._meta.get_field(field_name)
            field_path = column_prefix + model._meta.get_field(field_name).column
            mappings = {"path": field_path}
            if isinstance(field_, ArrayField):
                mappings.update(
                    {
                        "type": "vector",
                        "numDimensions": int(field_.size),
                        "similarity": next(similarities),
                    }
                )
            else:
                mappings["type"] = "filter"
            fields.append(mappings)
        return SearchIndexModel(definition={"fields": fields}, name=self.name, type="vectorSearch")
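
To make the similarity handling easier to follow, here is a standalone sketch of the `definition` that `VectorSearchIndex.get_pymongo_index_model()` assembles. The function name `build_vector_definition`, the `(path, size)` input format, and the field names are hypothetical simplifications: a size of `None` stands for a non-array (filter) field, whereas the real method inspects Django model fields.

```python
import itertools

VALID_SIMILARITIES = frozenset(("cosine", "dotProduct", "euclidean"))


def build_vector_definition(fields, similarities="cosine"):
    """Build a vectorSearch definition from (path, size) pairs.

    A single similarity string is reused for every vector field; a list or
    tuple is consumed in order, one entry per vector field.
    """
    multiple = isinstance(similarities, (tuple, list))
    for func in similarities if multiple else (similarities,):
        if func not in VALID_SIMILARITIES:
            raise ValueError(f"'{func}' isn't a valid similarity function.")
    remaining = iter(similarities) if multiple else itertools.cycle([similarities])
    result = []
    for path, size in fields:
        mapping = {"path": path}
        if size is not None:
            mapping.update(
                {"type": "vector", "numDimensions": int(size), "similarity": next(remaining)}
            )
        else:
            mapping["type"] = "filter"
        result.append(mapping)
    return {"fields": result}


# Hypothetical index over two vector fields and one filter field.
definition = build_vector_definition(
    [("title_embedding", 3), ("category", None), ("body_embedding", 5)],
    similarities=["cosine", "dotProduct"],
)
```

Filter fields do not consume a similarity, which is why the system check compares the number of similarities against the number of `ArrayField`s rather than the total number of fields.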


def register_indexes():
    BuiltinLookup.as_mql_idx = builtin_lookup_idx
    Index._get_condition_mql = _get_condition_mql