Skip to content

Commit 3596a99

Browse files
WaVEVtimgraham
authored andcommitted
Handle UUID as string and define embeddedfield.db_type()
1 parent b868f59 commit 3596a99

File tree

4 files changed

+27
-42
lines changed

4 files changed

+27
-42
lines changed

django_mongodb_backend/fields/embedded_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ def __init__(self, embedded_model, *args, **kwargs):
2121
self.embedded_model = embedded_model
2222
super().__init__(*args, **kwargs)
2323

24+
def db_type(self, connection):
25+
return "embeddedDocuments"
26+
2427
def check(self, **kwargs):
2528
from ..models import EmbeddedModel
2629

django_mongodb_backend/indexes.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -125,18 +125,12 @@ def check(self, model, connection):
125125
)
126126
return errors
127127

128-
def search_index_data_types(self, field, db_type):
128+
def search_index_data_types(self, db_type):
129129
"""
130130
Map a model field's internal type to search index type.
131131
Reference: https://www.mongodb.com/docs/atlas/atlas-search/define-field-mappings/#data-types
132132
"""
133-
if field.get_internal_type() == "UUIDField":
134-
return "uuid"
135-
if field.get_internal_type() in {"ObjectIdAutoField", "ObjectIdField"}:
136-
return "ObjectId"
137-
if field.get_internal_type() == "EmbeddedModelField":
138-
return "embeddedDocuments"
139-
if db_type in {"int", "long"}:
133+
if db_type in {"double", "int", "long"}:
140134
return "number"
141135
if db_type == "binData":
142136
return "string"
@@ -153,8 +147,8 @@ def get_pymongo_index_model(
153147
return None
154148
fields = {}
155149
for field_name, _ in self.fields_orders:
156-
field_ = model._meta.get_field(field_name)
157-
type_ = self.search_index_data_types(field_, field_.db_type(schema_editor.connection))
150+
field = model._meta.get_field(field_name)
151+
type_ = self.search_index_data_types(field.db_type(schema_editor.connection))
158152
field_path = column_prefix + model._meta.get_field(field_name).column
159153
fields[field_path] = {"type": type_}
160154
return SearchIndexModel(
@@ -212,8 +206,7 @@ def check(self, model, connection):
212206
)
213207
)
214208
else:
215-
field_type = field_.db_type(connection)
216-
search_type = self.search_index_data_types(field_, field_type)
209+
search_type = self.search_index_data_types(field_.db_type(connection))
217210
# filter - for fields that contain boolean, date, objectId,
218211
# numeric, string, or UUID values. Reference:
219212
# https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/#atlas-vector-search-index-fields
@@ -253,7 +246,7 @@ def get_pymongo_index_model(
253246
return None
254247
similarities = (
255248
itertools.cycle([self.similarities])
256-
if isinstance(self.similarities, str)
249+
if not self._multiple_similarities
257250
else iter(self.similarities)
258251
)
259252
fields = []

tests/indexes_/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from django.db import models
22

3-
from django_mongodb_backend.fields import ArrayField, EmbeddedModelField
3+
from django_mongodb_backend.fields import ArrayField, EmbeddedModelField, ObjectIdField
44
from django_mongodb_backend.models import EmbeddedModel
55

66

@@ -10,6 +10,7 @@ class Data(EmbeddedModel):
1010

1111
class Article(models.Model):
1212
headline = models.CharField(max_length=100)
13+
object_id = ObjectIdField()
1314
number = models.IntegerField()
1415
body = models.TextField()
1516
data = models.JSONField()

tests/indexes_/test_search_indexes.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -79,43 +79,30 @@ def test_define_field_twice(self):
7979
)
8080

8181

82-
class SchemaAssertionMixin:
83-
def assertAddRemoveIndex(self, editor, model, index):
84-
with self.assertNumQueries(1):
85-
editor.add_index(index=index, model=model)
86-
self.assertIn(
87-
index.name,
88-
connection.introspection.get_constraints(
89-
cursor=None,
90-
table_name=model._meta.db_table,
91-
),
92-
)
93-
editor.remove_index(index=index, model=model)
94-
self.assertNotIn(
95-
index.name,
96-
connection.introspection.get_constraints(
97-
cursor=None,
98-
table_name=model._meta.db_table,
99-
),
100-
)
101-
102-
10382
@skipUnlessDBFeature("supports_atlas_search")
104-
class SearchIndexSchemaTests(SchemaAssertionMixin, TestCase):
83+
class SearchIndexSchemaTests(TestCase):
10584
def test_simple(self):
10685
with connection.schema_editor() as editor:
10786
index = SearchIndex(
10887
name="recent_article_idx",
10988
fields=["number"],
11089
)
11190
editor.add_index(index=index, model=Article)
112-
self.assertAddRemoveIndex(editor, Article, index)
91+
editor.remove_index(index=index, model=Article)
11392

11493
def test_multiple_fields(self):
11594
with connection.schema_editor() as editor:
11695
index = SearchIndex(
11796
name="recent_article_idx",
118-
fields=["headline", "number", "body", "data", "embedded", "created_at"],
97+
fields=[
98+
"headline",
99+
"number",
100+
"body",
101+
"data",
102+
"embedded",
103+
"created_at",
104+
"object_id",
105+
],
119106
)
120107
editor.add_index(index=index, model=Article)
121108
index_info = connection.introspection.get_constraints(
@@ -146,20 +133,21 @@ def test_multiple_fields(self):
146133
"representation": "double",
147134
"type": "number",
148135
},
136+
"object_id": {"type": "objectId"},
149137
},
150138
}
151139
self.assertCountEqual(index_info[index.name]["columns"], index.fields)
152140
self.assertEqual(index_info[index.name]["options"], expected_options)
153-
self.assertAddRemoveIndex(editor, Article, index)
141+
editor.remove_index(index=index, model=Article)
154142

155143

156144
@skipUnlessDBFeature("supports_atlas_search")
157-
class VectorSearchIndexSchemaTests(SchemaAssertionMixin, TestCase):
145+
class VectorSearchIndexSchemaTests(TestCase):
158146
def test_simple_vector_search(self):
159147
with connection.schema_editor() as editor:
160148
index = VectorSearchIndex(name="recent_article_idx", fields=["number"])
161149
editor.add_index(index=index, model=Article)
162-
self.assertAddRemoveIndex(editor, Article, index)
150+
editor.remove_index(index=index, model=Article)
163151

164152
def test_multiple_fields(self):
165153
with connection.schema_editor() as editor:
@@ -195,4 +183,4 @@ def test_multiple_fields(self):
195183
index_info[index.name]["options"].pop("id")
196184
index_info[index.name]["options"].pop("status")
197185
self.assertEqual(index_info[index.name]["options"], expected_options)
198-
self.assertAddRemoveIndex(editor, Article, index)
186+
editor.remove_index(index=index, model=Article)

0 commit comments

Comments
 (0)