Skip to content

Commit 59e32d6

Browse files
committed
Allow querying an EmbeddedModelField by model instance
1 parent 223a271 commit 59e32d6

File tree

5 files changed

+164
-10
lines changed

5 files changed

+164
-10
lines changed

django_mongodb_backend/fields/embedded_model.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
from django.core import checks
44
from django.core.exceptions import FieldDoesNotExist
55
from django.db import models
6+
from django.db.models import lookups
7+
from django.db.models.expressions import Col
68
from django.db.models.fields.related import lazy_related_operation
79
from django.db.models.lookups import Transform
810

911
from .. import forms
12+
from ..query_utils import process_lhs, process_rhs
1013

1114

1215
class EmbeddedModelField(models.Field):
@@ -151,6 +154,67 @@ def formfield(self, **kwargs):
151154
)
152155

153156

157+
@EmbeddedModelField.register_lookup
158+
class EMFExact(lookups.Exact):
159+
def model_to_dict(self, instance, connection):
160+
"""
161+
Return a dict containing the data in a model instance, as well as a
162+
dict containing the data for any embedded model fields.
163+
"""
164+
data = {}
165+
emf_data = {}
166+
for f in instance._meta.concrete_fields:
167+
value = f.get_db_prep_value(f.value_from_object(instance), connection)
168+
if isinstance(f, EmbeddedModelField):
169+
emf_data[f.name] = (
170+
self.model_to_dict(value, connection) if value is not None else (None, {})
171+
)
172+
continue
173+
# Unless explicitly set, primary keys aren't included in embedded
174+
# models.
175+
if f.primary_key and value is None:
176+
continue
177+
data[f.name] = value
178+
return data, emf_data
179+
180+
def get_conditions(self, emf_data, prefix=None):
181+
"""
182+
Recursively transform a dictionary of {"field_name": {<model_to_dict>}}
183+
lookups into MQL. `prefix` tracks the string that must be appended to
184+
nested fields.
185+
"""
186+
conditions = []
187+
for k, v in emf_data.items():
188+
v, emf_data = v
189+
subprefix = f"{prefix}.{k}" if prefix else k
190+
conditions += self.get_conditions(emf_data, subprefix)
191+
if v is not None:
192+
# Match all fields of the EmbeddedModelField.
193+
conditions += [{"$eq": [f"{subprefix}.{x}", y]} for x, y in v.items()]
194+
else:
195+
# Match a null EmbeddedModelField.
196+
conditions += [{"$eq": [f"{subprefix}", None]}]
197+
return conditions
198+
199+
def as_mql(self, compiler, connection):
200+
lhs_mql = process_lhs(self, compiler, connection)
201+
value = process_rhs(self, compiler, connection)
202+
if isinstance(self.lhs, Col) or (
203+
isinstance(self.lhs, KeyTransform)
204+
and isinstance(self.lhs.ref_field, EmbeddedModelField)
205+
):
206+
if isinstance(value, models.Model):
207+
value, emf_data = self.model_to_dict(value, connection)
208+
# Get conditions for any nested EmbeddedModelFields.
209+
conditions = self.get_conditions({lhs_mql: (value, emf_data)})
210+
return {"$and": conditions}
211+
raise TypeError(
212+
"An EmbeddedModelField must be queried using a model instance, got %s."
213+
% type(value)
214+
)
215+
return connection.mongo_operators[self.lookup_name](lhs_mql, value)
216+
217+
154218
class KeyTransform(Transform):
155219
def __init__(self, key_name, ref_field, *args, **kwargs):
156220
super().__init__(*args, **kwargs)

docs/source/releases/5.2.x.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ New features
2222
a model's :attr:`Meta.indexes <django.db.models.Options.indexes>`.
2323
- PyMongo's connection pooling is now used by default. See
2424
:ref:`connection-management`.
25+
- Allowed ``EmbeddedModelField``’s ``exact`` lookup to use a model instance.
2526

2627
Backwards incompatible changes
2728
------------------------------

docs/source/topics/embedded-models.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,16 @@ as relational fields. For example, to retrieve all customers who have an
5454
address with the city "New York"::
5555

5656
>>> Customer.objects.filter(address__city="New York")
57+
58+
You can also query using a model instance. Unlike a normal relational lookup
59+
which does the lookup by primary key, since embedded models typically don't
60+
have a primary key set, the query requires that every field match. For example,
61+
this query gives customers with addresses with the city "New York" and all
62+
other fields of the address equal to their default (:attr:`Field.default
63+
<django.db.models.Field.default>`, ``None``, or an empty string).
64+
65+
>>> Customer.objects.filter(address=Address(city="New York"))
66+
67+
.. versionadded:: 5.2.0b0
68+
69+
The ability to query by model instance was added.

tests/model_fields_/models.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,31 @@ class Library(models.Model):
138138

139139
def __str__(self):
140140
return self.name
141+
142+
143+
class A(models.Model):
144+
b = EmbeddedModelField("B")
145+
146+
147+
class B(EmbeddedModel):
148+
c = EmbeddedModelField("C")
149+
name = models.CharField(max_length=100)
150+
value = models.IntegerField()
151+
152+
153+
class C(EmbeddedModel):
154+
d = EmbeddedModelField("D")
155+
name = models.CharField(max_length=100)
156+
value = models.IntegerField()
157+
158+
159+
class D(EmbeddedModel):
160+
e = EmbeddedModelField("E")
161+
nullable_e = EmbeddedModelField("E", null=True, blank=True)
162+
name = models.CharField(max_length=100)
163+
value = models.IntegerField()
164+
165+
166+
class E(EmbeddedModel):
167+
name = models.CharField(max_length=100)
168+
value = models.IntegerField()

tests/model_fields_/test_embedded_model.py

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from datetime import timedelta
33

44
from django.core.exceptions import FieldDoesNotExist, ValidationError
5-
from django.db import models
5+
from django.db import connection, models
66
from django.db.models import (
77
Exists,
88
ExpressionWrapper,
@@ -17,15 +17,7 @@
1717
from django_mongodb_backend.fields import EmbeddedModelField
1818
from django_mongodb_backend.models import EmbeddedModel
1919

20-
from .models import (
21-
Address,
22-
Author,
23-
Book,
24-
Data,
25-
Holder,
26-
Library,
27-
NestedData,
28-
)
20+
from .models import A, Address, Author, B, Book, C, D, Data, E, Holder, Library, NestedData
2921
from .utils import truncate_ms
3022

3123

@@ -145,6 +137,62 @@ def test_order_by_embedded_field(self):
145137
qs = Holder.objects.filter(data__integer__gt=3).order_by("-data__integer")
146138
self.assertSequenceEqual(qs, list(reversed(self.objs[4:])))
147139

140+
def test_exact_with_model(self):
141+
data = Holder.objects.first().data
142+
self.assertEqual(
143+
Holder.objects.filter(data=data).get().data.integer, self.objs[0].data.integer
144+
)
145+
146+
def test_exact_with_model_ignores_key_order(self):
147+
# Due to the possibility of schema changes or the reordering of a
148+
# model's fields, a lookup must work if an embedded document has its
149+
# keys in a different order than what's declared on the embedded model.
150+
data = {}
151+
for field in reversed(Data._meta.fields):
152+
data[field.name] = None
153+
del data["id"]
154+
data["integer"] = 100
155+
connection.get_collection("model_fields__holder").insert_one({"data": data})
156+
self.assertEqual(Holder.objects.filter(data=Data(integer=100)).get().data.integer, 100)
157+
158+
def test_exact_with_nested_model(self):
159+
address = Address(city="NYC", state="NY")
160+
author = Author(name="Shakespeare", age=55, address=address)
161+
obj = Book.objects.create(author=author)
162+
self.assertCountEqual(Book.objects.filter(author=author), [obj])
163+
self.assertCountEqual(Book.objects.filter(author__address=address), [obj])
164+
165+
def test_exact_with_deeply_nested_models(self):
166+
e1 = E(name="E1", value=5)
167+
d1 = D(name="D1", value=4, e=e1)
168+
c1 = C(name="C1", value=3, d=d1)
169+
b1 = B(name="B1", value=2, c=c1)
170+
a1 = A.objects.create(b=b1)
171+
e2 = E(name="E2", value=6)
172+
d2 = D(name="D2", value=4, e=e1, nullable_e=e2)
173+
c2 = C(name="C2", value=3, d=d2)
174+
b2 = B(name="B2", value=2, c=c2)
175+
a2 = A.objects.create(b=b2)
176+
self.assertCountEqual(A.objects.filter(b=b1), [a1])
177+
self.assertCountEqual(A.objects.filter(b__c=c1), [a1])
178+
self.assertCountEqual(A.objects.filter(b__c__d=d1), [a1])
179+
self.assertCountEqual(A.objects.filter(b__c__d__e=e1), [a1, a2])
180+
self.assertCountEqual(A.objects.filter(b=b2), [a2])
181+
self.assertCountEqual(A.objects.filter(b__c=c2), [a2])
182+
self.assertCountEqual(A.objects.filter(b__c__d=d2), [a2])
183+
self.assertCountEqual(A.objects.filter(b__c__d__nullable_e=e2), [a2])
184+
185+
def test_exact_validates_argument(self):
186+
msg = "An EmbeddedModelField must be queried using a model instance, got <class 'dict'>."
187+
with self.assertRaisesMessage(TypeError, msg):
188+
str(A.objects.filter(b={}))
189+
with self.assertRaisesMessage(TypeError, msg):
190+
str(A.objects.filter(b__c={}))
191+
with self.assertRaisesMessage(TypeError, msg):
192+
str(A.objects.filter(b__c__d={}))
193+
with self.assertRaisesMessage(TypeError, msg):
194+
str(A.objects.filter(b__c__d__e={}))
195+
148196
def test_embedded_json_field_lookups(self):
149197
objs = [
150198
Holder.objects.create(

0 commit comments

Comments
 (0)