Skip to content

Commit 52eb282

Browse files
committed
Fix emf flow and add subquery unit test
1 parent 1d393b8 commit 52eb282

File tree

2 files changed

+33
-40
lines changed

2 files changed

+33
-40
lines changed

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
from django.db import models
55
from django.db.models import Field
66
from django.db.models.expressions import Col
7-
from django.db.models.lookups import Transform
7+
from django.db.models.lookups import Lookup, Transform
88

99
from .. import forms
1010
from ..query_utils import process_lhs, process_rhs
1111
from . import EmbeddedModelField
1212
from .array import ArrayField
13-
from .embedded_model import EMFExact
13+
from .embedded_model import EMFExact, EMFMixin
1414

1515

1616
class EmbeddedModelArrayField(ArrayField):
@@ -63,17 +63,8 @@ def get_transform(self, name):
6363
return KeyTransformFactory(name, self)
6464

6565

66-
class ProcessRHSMixin:
67-
def process_rhs(self, compiler, connection):
68-
if isinstance(self.lhs, KeyTransform):
69-
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
70-
else:
71-
get_db_prep_value = self.lhs.output_field.get_db_prep_value
72-
return None, [get_db_prep_value(v, connection, prepared=True) for v in self.rhs]
73-
74-
7566
@EmbeddedModelArrayField.register_lookup
76-
class EMFArrayExact(EMFExact, ProcessRHSMixin):
67+
class EMFArrayExact(EMFExact):
7768
def as_mql(self, compiler, connection):
7869
lhs_mql = process_lhs(self, compiler, connection)
7970
value = process_rhs(self, compiler, connection)
@@ -116,12 +107,29 @@ def as_mql(self, compiler, connection):
116107

117108

118109
@EmbeddedModelArrayField.register_lookup
119-
class ArrayOverlap(EMFExact, ProcessRHSMixin):
110+
class ArrayOverlap(EMFMixin, Lookup):
120111
lookup_name = "overlap"
112+
get_db_prep_lookup_value_is_iterable = True
113+
114+
def process_rhs(self, compiler, connection):
115+
values = self.rhs
116+
if self.get_db_prep_lookup_value_is_iterable:
117+
values = [values]
118+
# Compute how to serialize each value based on the query target.
119+
# If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
120+
# field of the subfield. Otherwise, use the base field of the array itself.
121+
if isinstance(self.lhs, KeyTransform):
122+
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
123+
else:
124+
get_db_prep_value = self.lhs.output_field.base_field.get_db_prep_value
125+
return None, [get_db_prep_value(v, connection, prepared=True) for v in values]
121126

122127
def as_mql(self, compiler, connection):
123128
lhs_mql = process_lhs(self, compiler, connection)
124129
values = process_rhs(self, compiler, connection)
130+
# Querying a subfield within the array elements (via nested KeyTransform).
131+
# Replicates MongoDB's implicit ANY-match by mapping over the array and applying
132+
# `$in` on the subfield.
125133
if isinstance(self.lhs, KeyTransform):
126134
lhs_mql, inner_lhs_mql = lhs_mql
127135
return {
@@ -140,11 +148,12 @@ def as_mql(self, compiler, connection):
140148
}
141149
conditions = []
142150
inner_lhs_mql = "$$item"
151+
# Querying full embedded documents in the array.
152+
# Builds `$or` conditions and maps them over the array to match any full document.
143153
for value in values:
144-
if isinstance(value, models.Model):
145-
value, emf_data = self.model_to_dict(value)
146-
# Get conditions for any nested EmbeddedModelFields.
147-
conditions.append({"$and": self.get_conditions({inner_lhs_mql: (value, emf_data)})})
154+
value, emf_data = self.model_to_dict(value)
155+
# Get conditions for any nested EmbeddedModelFields.
156+
conditions.append({"$and": self.get_conditions({inner_lhs_mql: (value, emf_data)})})
148157
return {
149158
"$anyElementTrue": {
150159
"$ifNull": [

tests/model_fields_/test_embedded_model.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -300,36 +300,20 @@ def test_overlap_emf(self):
300300
[self.clouds],
301301
)
302302

303-
"""
304-
def test_overlap_charfield_including_expression(self):
305-
obj_1 = CharArrayModel.objects.create(field=["TEXT", "lower text"])
306-
obj_2 = CharArrayModel.objects.create(field=["lower text", "TEXT"])
307-
CharArrayModel.objects.create(field=["lower text", "text"])
308-
self.assertSequenceEqual(
309-
CharArrayModel.objects.filter(
310-
field__overlap=[
311-
Upper(Value("text")),
312-
"other",
313-
]
314-
),
315-
[obj_1, obj_2],
316-
)
317-
318303
def test_overlap_values(self):
319-
qs = NullableIntegerArrayModel.objects.filter(order__lt=3)
304+
qs = Movie.objects.filter(title__in=["Clouds", "Frozen"])
320305
self.assertCountEqual(
321-
NullableIntegerArrayModel.objects.filter(
322-
field__overlap=qs.values_list("field"),
306+
Movie.objects.filter(
307+
reviews__overlap=qs.values_list("reviews"),
323308
),
324-
self.objs[:3],
309+
[self.clouds, self.frozen],
325310
)
326311
self.assertCountEqual(
327-
NullableIntegerArrayModel.objects.filter(
328-
field__overlap=qs.values("field"),
312+
Movie.objects.filter(
313+
reviews__overlap=qs.values("reviews"),
329314
),
330-
self.objs[:3],
315+
[self.clouds, self.frozen],
331316
)
332-
"""
333317

334318

335319
class QueryingTests(TestCase):

0 commit comments

Comments
 (0)