Skip to content

Commit 1a6d2ff

Browse files
WaVEV authored and timgraham committed
EmbeddedModelArrayField Querying
1 parent 2617e5b commit 1a6d2ff

File tree

5 files changed

+534
-8
lines changed

5 files changed

+534
-8
lines changed

django_mongodb_backend/fields/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ class ArrayLenTransform(Transform):
338338

339339
def as_mql(self, compiler, connection):
340340
lhs_mql = process_lhs(self, compiler, connection)
341-
return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$size": lhs_mql}}}
341+
return {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}}
342342

343343

344344
@ArrayField.register_lookup

django_mongodb_backend/fields/embedded_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,9 @@ def as_mql(self, compiler, connection):
186186
key_transforms.insert(0, previous.key_name)
187187
previous = previous.lhs
188188
mql = previous.as_mql(compiler, connection)
189-
transforms = ".".join(key_transforms)
190-
return f"{mql}.{transforms}"
189+
for key in key_transforms:
190+
mql = {"$getField": {"input": mql, "field": key}}
191+
return mql
191192

192193
@property
193194
def output_field(self):

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 235 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1-
from django.db.models import Field
1+
import difflib
2+
3+
from django.core.exceptions import FieldDoesNotExist
4+
from django.db.models import Field, lookups
5+
from django.db.models.expressions import Col
26
from django.db.models.fields.related import lazy_related_operation
7+
from django.db.models.lookups import Lookup, Transform
38

49
from .. import forms
10+
from ..query_utils import process_lhs, process_rhs
511
from . import EmbeddedModelField
6-
from .array import ArrayField
12+
from .array import ArrayField, ArrayLenTransform
713

814

915
class EmbeddedModelArrayField(ArrayField):
@@ -56,3 +62,230 @@ def formfield(self, **kwargs):
5662
**kwargs,
5763
},
5864
)
65+
66+
def get_transform(self, name):
    """
    Return the transform for `name`, falling back to a KeyTransformFactory.

    The parent class is consulted first. When it yields nothing truthy,
    `name` is wrapped in a KeyTransformFactory bound to this field.
    """
    return super().get_transform(name) or KeyTransformFactory(name, self)
71+
72+
def _get_lookup(self, lookup_name):
    """
    Resolve `lookup_name`, replacing most results with a lookup that raises.

    A missing lookup (None) and ArrayLenTransform are returned unchanged.
    Anything else the parent class finds is swapped for a Lookup subclass
    whose as_mql() raises ValueError.
    """
    found = super()._get_lookup(lookup_name)
    if found is not None and found is not ArrayLenTransform:

        class EmbeddedModelArrayFieldLookups(Lookup):
            def as_mql(self, compiler, connection):
                raise ValueError(
                    "Cannot apply this lookup directly to EmbeddedModelArrayField. "
                    "Try querying one of its embedded fields instead."
                )

        return EmbeddedModelArrayFieldLookups
    return found
85+
86+
87+
class _EmbeddedModelArrayOutputField(ArrayField):
    """
    Output field for an EmbeddedModelArrayField that is traversed in a query path.

    This field is not meant to appear in model definitions; it exists only so
    that query output resolution treats the traversed value as an array of the
    embedded model's target type.

    Its lookups mirror ArrayField's by name, but they are resolved with the
    semantics of EmbeddedModelArrayField rather than native array behavior.
    """

    # Names of the only lookups get_lookup() will resolve.
    ALLOWED_LOOKUPS = {
        "in",
        "exact",
        "iexact",
        "gt",
        "gte",
        "lt",
        "lte",
        "all",
        "contained_by",
    }

    def get_lookup(self, name):
        """Return the lookup for `name` if it is allowed, otherwise None."""
        if name not in self.ALLOWED_LOOKUPS:
            return None
        return super().get_lookup(name)
113+
114+
115+
class EmbeddedModelArrayFieldBuiltinLookup(Lookup):
    """
    Shared behavior for lookups whose left-hand side maps over array elements.

    The left-hand side is expected to expose `_lhs` (the per-element
    expression) and to compile to `{"$ifNull": [{"$map": {...}}, ...]}`.
    """

    def process_rhs(self, compiler, connection):
        """
        Return `(None, values)` with the right-hand side prepared for the database.

        A right-hand side that is not flagged as iterable is wrapped in a
        single-element list first. Values exposing `as_mql` pass through
        untouched; every other value is prepared by the output field of the
        left-hand side's per-element expression (`self.lhs._lhs`).
        """
        candidates = self.rhs
        if not self.get_db_prep_lookup_value_is_iterable:
            candidates = [candidates]
        prepare = self.lhs._lhs.output_field.get_db_prep_value
        prepared_values = []
        for candidate in candidates:
            if hasattr(candidate, "as_mql"):
                prepared_values.append(candidate)
            else:
                prepared_values.append(prepare(candidate, connection, prepared=True))
        return None, prepared_values

    def as_mql(self, compiler, connection):
        """
        Build an any-match expression over the mapped array.

        The `in` expression of the left-hand side's `$map` is replaced by this
        lookup's operator applied to that expression and the right-hand side
        values; the whole mapping is then wrapped in `$anyElementTrue`.
        """
        lhs_mql = process_lhs(self, compiler, connection)
        map_stage = lhs_mql["$ifNull"][0]["$map"]
        per_item_mql = map_stage["in"]
        values = process_rhs(self, compiler, connection)
        operator = connection.mongo_operators[self.lookup_name]
        map_stage["in"] = operator(per_item_mql, values)
        return {"$anyElementTrue": lhs_mql}
140+
141+
142+
@_EmbeddedModelArrayOutputField.register_lookup
class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In):
    """Django's `In` lookup combined with EmbeddedModelArrayFieldBuiltinLookup."""
145+
146+
147+
@_EmbeddedModelArrayOutputField.register_lookup
class EmbeddedModelArrayFieldExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.Exact):
    """Django's `Exact` lookup combined with EmbeddedModelArrayFieldBuiltinLookup."""
150+
151+
152+
@_EmbeddedModelArrayOutputField.register_lookup
class EmbeddedModelArrayFieldIExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.IExact):
    """Django's `IExact` lookup combined with EmbeddedModelArrayFieldBuiltinLookup."""

    # The right-hand side is treated as a single value, not an iterable of values.
    get_db_prep_lookup_value_is_iterable = False
155+
156+
157+
@_EmbeddedModelArrayOutputField.register_lookup
class EmbeddedModelArrayFieldGreaterThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThan):
    """Django's `GreaterThan` lookup combined with EmbeddedModelArrayFieldBuiltinLookup."""
160+
161+
162+
@_EmbeddedModelArrayOutputField.register_lookup
class EmbeddedModelArrayFieldGreaterThanOrEqual(
    EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThanOrEqual
):
    """Django's `GreaterThanOrEqual` lookup combined with EmbeddedModelArrayFieldBuiltinLookup."""
167+
168+
169+
@_EmbeddedModelArrayOutputField.register_lookup
class EmbeddedModelArrayFieldLessThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThan):
    """Django's `LessThan` lookup combined with EmbeddedModelArrayFieldBuiltinLookup."""
172+
173+
174+
@_EmbeddedModelArrayOutputField.register_lookup
class EmbeddedModelArrayFieldLessThanOrEqual(
    EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThanOrEqual
):
    """Django's `LessThanOrEqual` lookup combined with EmbeddedModelArrayFieldBuiltinLookup."""
179+
180+
181+
@_EmbeddedModelArrayOutputField.register_lookup
class EmbeddedModelArrayFieldAll(EmbeddedModelArrayFieldBuiltinLookup, Lookup):
    """`all` lookup: every given value must appear among the mapped array values."""

    lookup_name = "all"
    # The right-hand side is treated as a single value, not an iterable of values.
    get_db_prep_lookup_value_is_iterable = False

    def as_mql(self, compiler, connection):
        """
        Return an `$and` of three checks: the left-hand side is not null, the
        given values are not null, and the values are a subset of the
        left-hand side.
        """
        lhs_mql = process_lhs(self, compiler, connection)
        values = process_rhs(self, compiler, connection)
        lhs_not_null = {"$ne": [lhs_mql, None]}
        values_not_null = {"$ne": [values, None]}
        values_in_lhs = {"$setIsSubset": [values, lhs_mql]}
        return {"$and": [lhs_not_null, values_not_null, values_in_lhs]}
196+
197+
198+
@_EmbeddedModelArrayOutputField.register_lookup
class ArrayContainedBy(EmbeddedModelArrayFieldBuiltinLookup, Lookup):
    """`contained_by` lookup: the mapped array values must be a subset of the given values."""

    lookup_name = "contained_by"
    # The right-hand side is treated as a single value, not an iterable of values.
    get_db_prep_lookup_value_is_iterable = False

    def as_mql(self, compiler, connection):
        """
        Return an `$and` of three checks: the left-hand side is not null, the
        given value is not null, and the left-hand side is a subset of the
        given value.
        """
        lhs_mql = process_lhs(self, compiler, connection)
        value = process_rhs(self, compiler, connection)
        lhs_not_null = {"$ne": [lhs_mql, None]}
        value_not_null = {"$ne": [value, None]}
        lhs_in_value = {"$setIsSubset": [lhs_mql, value]}
        return {"$and": [lhs_not_null, value_not_null, lhs_in_value]}
213+
214+
215+
class KeyTransform(Transform):
    """
    Transform that addresses one field of the models embedded in an array.

    It holds a per-element expression (`_lhs`) that as_mql() evaluates for each
    array item inside a `$map`. Further transforms requested via
    get_transform() are stacked onto that per-element expression.
    """

    def __init__(self, key_name, array_field, *args, **kwargs):
        """
        Store the key and array field and build the per-element column.

        `key_name` must name a field of `array_field.embedded_model`; the
        lookup is delegated to the embedded model's `_meta.get_field`.
        """
        super().__init__(*args, **kwargs)
        self.array_field = array_field
        self.key_name = key_name
        # as_mql() maps over the array with the variable name "item", so the
        # per-element expression is a virtual column named `$item.<key_name>`
        # whose type is a clone of the embedded model's field.
        column_name = f"$item.{key_name}"
        column_target = array_field.embedded_model._meta.get_field(key_name).clone()
        column_target.db_column = column_name
        column_target.set_attributes_from_name(column_name)
        self._lhs = Col(None, column_target)
        self._sub_transform = None

    def __call__(self, this, *args, **kwargs):
        """
        Apply the pending sub-transform to the per-element expression.

        get_transform() returns this instance in place of a transform class,
        so it must be callable like one. `this` is accepted but not used.
        """
        self._lhs = self._sub_transform(self._lhs, *args, **kwargs)
        return self

    def get_lookup(self, name):
        """Resolve lookups through the array output field."""
        return self.output_field.get_lookup(name)

    def get_transform(self, name):
        """
        Validate that `name` is either a field of an embedded model or a
        lookup on an embedded model's field.

        Returns this instance after recording the transform to apply.
        Raises ValueError when `name` would start another array traversal and
        FieldDoesNotExist when `name` cannot be resolved at all.
        """
        # Transforms are always resolved against the per-element expression.
        transform = self._lhs.get_transform(name)
        if transform:
            if isinstance(transform, KeyTransformFactory):
                raise ValueError("Cannot perform multiple levels of array traversal in a query.")
            self._sub_transform = transform
            return self
        # Nothing matched: build an error, suggesting close allowed lookup names.
        output_field = self._lhs.output_field
        allowed_lookups = self.output_field.ALLOWED_LOOKUPS & set(output_field.get_lookups())
        suggested_lookups = difflib.get_close_matches(name, allowed_lookups)
        suggestion = ""
        if suggested_lookups:
            suggested_lookups = " or ".join(suggested_lookups)
            suggestion = f", perhaps you meant {suggested_lookups}?"
        raise FieldDoesNotExist(
            f"Unsupported lookup '{name}' for "
            f"EmbeddedModelArrayField of '{output_field.__class__.__name__}'"
            f"{suggestion}"
        )

    def as_mql(self, compiler, connection):
        """
        Return a `$map` of the per-element expression over the array.

        The mapping is wrapped in `$ifNull` with an empty list as fallback.
        """
        inner_lhs_mql = self._lhs.as_mql(compiler, connection)
        lhs_mql = process_lhs(self, compiler, connection)
        map_stage = {"input": lhs_mql, "as": "item", "in": inner_lhs_mql}
        return {"$ifNull": [{"$map": map_stage}, []]}

    @property
    def output_field(self):
        """A new array output field wrapping the per-element expression's output field."""
        return _EmbeddedModelArrayOutputField(self._lhs.output_field)
283+
284+
285+
class KeyTransformFactory:
    """Callable that builds a KeyTransform for a fixed key and array field."""

    def __init__(self, key_name, base_field):
        """Remember the key name and the field passed on to each KeyTransform."""
        self.key_name = key_name
        self.base_field = base_field

    def __call__(self, *args, **kwargs):
        """Create a KeyTransform, forwarding any extra arguments to it."""
        key_name, base_field = self.key_name, self.base_field
        return KeyTransform(key_name, base_field, *args, **kwargs)

tests/model_fields_/models.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,40 @@ class Review(EmbeddedModel):
165165

166166
def __str__(self):
167167
return self.title
168+
169+
170+
class RestorationRecord(EmbeddedModel):
    """Embedded record of a restoration: its date and who performed it."""

    date = models.DateField()
    restored_by = models.CharField(max_length=255)
173+
174+
175+
class ArtifactDetail(EmbeddedModel):
    """Details about a specific artifact, including its restoration records."""

    name = models.CharField(max_length=255)
    metadata = models.JSONField()
    restorations = EmbeddedModelArrayField(RestorationRecord, null=True)
    last_restoration = EmbeddedModelField(RestorationRecord, null=True)
181+
182+
183+
class ExhibitSection(EmbeddedModel):
    """A section within an exhibit, containing multiple artifacts."""

    section_number = models.IntegerField()
    artifacts = EmbeddedModelArrayField(ArtifactDetail, null=True)
187+
188+
189+
class MuseumExhibit(models.Model):
    """An exhibit in the museum, composed of multiple sections."""

    exhibit_name = models.CharField(max_length=255)
    sections = EmbeddedModelArrayField(ExhibitSection, null=True)
    main_section = EmbeddedModelField(ExhibitSection, null=True)

    def __str__(self):
        """Use the exhibit's name as its string form."""
        return self.exhibit_name
197+
198+
199+
class Tour(models.Model):
    """A guided tour of a single MuseumExhibit."""

    guide = models.CharField(max_length=100)
    exhibit = models.ForeignKey(MuseumExhibit, on_delete=models.CASCADE)

    def __str__(self):
        """Describe the tour by its guide."""
        return f"Tour by {self.guide}"

0 commit comments

Comments
 (0)