Skip to content

Commit f6d7abc

Browse files
committed
add argument validation
1 parent 3e23241 commit f6d7abc

File tree

2 files changed

+28
-8
lines changed

2 files changed

+28
-8
lines changed

django_mongodb_backend/fields/embedded_model.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from django.core.exceptions import FieldDoesNotExist
55
from django.db import models
66
from django.db.models import lookups
7+
from django.db.models.expressions import Col
78
from django.db.models.fields.related import lazy_related_operation
89
from django.db.models.lookups import Transform
910

@@ -194,14 +195,22 @@ def get_conditions(self, emf_data, prefix):
194195
def as_mql(self, compiler, connection):
195196
lhs_mql = process_lhs(self, compiler, connection)
196197
value = process_rhs(self, compiler, connection)
197-
if isinstance(value, models.Model):
198-
value, emf_data = self.model_to_dict(value)
199-
prefix = self.lhs.as_mql(compiler, connection)
200-
# Get conditions for top-level EmbeddedModelField.
201-
conditions = [{"$eq": [f"{prefix}.{k}", v]} for k, v in value.items()]
202-
# Get conditions for any nested EmbeddedModelFields.
203-
conditions += self.get_conditions(emf_data, prefix)
204-
return {"$and": conditions}
198+
if isinstance(self.lhs, Col) or (
199+
isinstance(self.lhs, KeyTransform)
200+
and isinstance(self.lhs.ref_field, EmbeddedModelField)
201+
):
202+
if isinstance(value, models.Model):
203+
value, emf_data = self.model_to_dict(value)
204+
prefix = self.lhs.as_mql(compiler, connection)
205+
# Get conditions for top-level EmbeddedModelField.
206+
conditions = [{"$eq": [f"{prefix}.{k}", v]} for k, v in value.items()]
207+
# Get conditions for any nested EmbeddedModelFields.
208+
conditions += self.get_conditions(emf_data, prefix)
209+
return {"$and": conditions}
210+
raise TypeError(
211+
"An EmbeddedModelField must be queried using a model instance, got %s."
212+
% type(value)
213+
)
205214
return connection.mongo_operators[self.lookup_name](lhs_mql, value)
206215

207216

tests/model_fields_/test_embedded_model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,17 @@ def test_exact_with_deeply_nested_models(self):
159159
self.assertCountEqual(A.objects.filter(b__c__d=d2), [a2])
160160
self.assertCountEqual(A.objects.filter(b__c__d__nullable_e=e2), [a2])
161161

162+
def test_exact_validates_argument(self):
163+
msg = "An EmbeddedModelField must be queried using a model instance, got <class 'dict'>."
164+
with self.assertRaisesMessage(TypeError, msg):
165+
str(A.objects.filter(b={}))
166+
with self.assertRaisesMessage(TypeError, msg):
167+
str(A.objects.filter(b__c={}))
168+
with self.assertRaisesMessage(TypeError, msg):
169+
str(A.objects.filter(b__c__d={}))
170+
with self.assertRaisesMessage(TypeError, msg):
171+
str(A.objects.filter(b__c__d__e={}))
172+
162173
def test_embedded_json_field_lookups(self):
163174
objs = [
164175
Holder.objects.create(

0 commit comments

Comments
 (0)