Skip to content

Commit 1d393b8

Browse files
committed
Adding support for overlap
1 parent fdea0c2 commit 1d393b8

File tree

3 files changed

+115
-11
lines changed

3 files changed

+115
-11
lines changed

django_mongodb_backend/fields/array.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,7 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr)
326326
def as_mql(self, compiler, connection):
327327
lhs_mql = process_lhs(self, compiler, connection)
328328
value = process_rhs(self, compiler, connection)
329-
return {
330-
"$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}]
331-
}
329+
return {"$and": [{"$isArray": lhs_mql}, {"$size": {"$setIntersection": [value, lhs_mql]}}]}
332330

333331

334332
@ArrayField.register_lookup

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,20 @@ def get_transform(self, name):
6060
transform = super().get_transform(name)
6161
if transform:
6262
return transform
63-
return KeyTransformFactory(name, self.base_field)
63+
return KeyTransformFactory(name, self)
64+
65+
66+
class ProcessRHSMixin:
67+
def process_rhs(self, compiler, connection):
68+
if isinstance(self.lhs, KeyTransform):
69+
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
70+
else:
71+
get_db_prep_value = self.lhs.output_field.get_db_prep_value
72+
return None, [get_db_prep_value(v, connection, prepared=True) for v in self.rhs]
6473

6574

6675
@EmbeddedModelArrayField.register_lookup
67-
class EMFArrayExact(EMFExact):
76+
class EMFArrayExact(EMFExact, ProcessRHSMixin):
6877
def as_mql(self, compiler, connection):
6978
lhs_mql = process_lhs(self, compiler, connection)
7079
value = process_rhs(self, compiler, connection)
@@ -106,15 +115,61 @@ def as_mql(self, compiler, connection):
106115
}
107116

108117

118+
@EmbeddedModelArrayField.register_lookup
119+
class ArrayOverlap(EMFExact, ProcessRHSMixin):
120+
lookup_name = "overlap"
121+
122+
def as_mql(self, compiler, connection):
123+
lhs_mql = process_lhs(self, compiler, connection)
124+
values = process_rhs(self, compiler, connection)
125+
if isinstance(self.lhs, KeyTransform):
126+
lhs_mql, inner_lhs_mql = lhs_mql
127+
return {
128+
"$anyElementTrue": {
129+
"$ifNull": [
130+
{
131+
"$map": {
132+
"input": lhs_mql,
133+
"as": "item",
134+
"in": {"$in": [inner_lhs_mql, values]},
135+
}
136+
},
137+
[],
138+
]
139+
}
140+
}
141+
conditions = []
142+
inner_lhs_mql = "$$item"
143+
for value in values:
144+
if isinstance(value, models.Model):
145+
value, emf_data = self.model_to_dict(value)
146+
# Get conditions for any nested EmbeddedModelFields.
147+
conditions.append({"$and": self.get_conditions({inner_lhs_mql: (value, emf_data)})})
148+
return {
149+
"$anyElementTrue": {
150+
"$ifNull": [
151+
{
152+
"$map": {
153+
"input": lhs_mql,
154+
"as": "item",
155+
"in": {"$or": conditions},
156+
}
157+
},
158+
[],
159+
]
160+
}
161+
}
162+
163+
109164
class KeyTransform(Transform):
110165
# it should be different class than EMF keytransform even most of the methods are equal.
111-
def __init__(self, key_name, base_field, *args, **kwargs):
166+
def __init__(self, key_name, array_field, *args, **kwargs):
112167
super().__init__(*args, **kwargs)
113-
self.base_field = base_field
168+
self.array_field = array_field
114169
self.key_name = key_name
115170
# The iteration items begins from the base_field, a virtual column with
116171
# base field output type is created.
117-
column_target = base_field.clone()
172+
column_target = array_field.base_field.embedded_model._meta.get_field(key_name).clone()
118173
column_name = f"$item.{key_name}"
119174
column_target.db_column = column_name
120175
column_target.set_attributes_from_name(column_name)
@@ -137,7 +192,7 @@ def _get_missing_field_or_lookup_exception(self, lhs, name):
137192
suggestion = "."
138193
raise FieldDoesNotExist(
139194
f"Unsupported lookup '{name}' for "
140-
f"{self.base_field.__class__.__name__} '{self.base_field.name}'"
195+
f"{self.array_field.base_field.__class__.__name__} '{self.array_field.base_field.name}'"
141196
f"{suggestion}"
142197
)
143198

@@ -150,7 +205,9 @@ def get_transform(self, name):
150205
transform = (
151206
self._lhs.get_transform(name)
152207
if isinstance(self._lhs, Transform)
153-
else self.base_field.embedded_model._meta.get_field(self.key_name).get_transform(name)
208+
else self.array_field.base_field.embedded_model._meta.get_field(
209+
self.key_name
210+
).get_transform(name)
154211
)
155212
if transform:
156213
self._sub_transform = transform
@@ -166,7 +223,7 @@ def as_mql(self, compiler, connection):
166223

167224
@property
168225
def output_field(self):
169-
return EmbeddedModelArrayField(self.base_field)
226+
return self.array_field
170227

171228

172229
class KeyTransformFactory:

tests/model_fields_/test_embedded_model.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,55 @@ def test_len(self):
282282
MuseumExhibit.objects.filter(sections__1__artifacts__len=1), [self.wonders]
283283
)
284284

285+
def test_overlap_simplefield(self):
286+
self.assertSequenceEqual(
287+
MuseumExhibit.objects.filter(sections__section_number__overlap=[10]), []
288+
)
289+
self.assertSequenceEqual(
290+
MuseumExhibit.objects.filter(sections__section_number__overlap=[1]),
291+
[self.egypt, self.wonders, self.new_descoveries],
292+
)
293+
self.assertSequenceEqual(
294+
MuseumExhibit.objects.filter(sections__section_number__overlap=[2]), [self.wonders]
295+
)
296+
297+
def test_overlap_emf(self):
298+
self.assertSequenceEqual(
299+
Movie.objects.filter(reviews__overlap=[Review(title="The best", rating=10)]),
300+
[self.clouds],
301+
)
302+
303+
"""
304+
def test_overlap_charfield_including_expression(self):
305+
obj_1 = CharArrayModel.objects.create(field=["TEXT", "lower text"])
306+
obj_2 = CharArrayModel.objects.create(field=["lower text", "TEXT"])
307+
CharArrayModel.objects.create(field=["lower text", "text"])
308+
self.assertSequenceEqual(
309+
CharArrayModel.objects.filter(
310+
field__overlap=[
311+
Upper(Value("text")),
312+
"other",
313+
]
314+
),
315+
[obj_1, obj_2],
316+
)
317+
318+
def test_overlap_values(self):
319+
qs = NullableIntegerArrayModel.objects.filter(order__lt=3)
320+
self.assertCountEqual(
321+
NullableIntegerArrayModel.objects.filter(
322+
field__overlap=qs.values_list("field"),
323+
),
324+
self.objs[:3],
325+
)
326+
self.assertCountEqual(
327+
NullableIntegerArrayModel.objects.filter(
328+
field__overlap=qs.values("field"),
329+
),
330+
self.objs[:3],
331+
)
332+
"""
333+
285334

286335
class QueryingTests(TestCase):
287336
@classmethod

0 commit comments

Comments
 (0)