4
4
from django .db import models
5
5
from django .db .models import Field
6
6
from django .db .models .expressions import Col
7
- from django .db .models .lookups import Transform
7
+ from django .db .models .lookups import Lookup , Transform
8
8
9
9
from .. import forms
10
10
from ..query_utils import process_lhs , process_rhs
11
11
from . import EmbeddedModelField
12
12
from .array import ArrayField
13
- from .embedded_model import EMFExact
13
+ from .embedded_model import EMFExact , EMFMixin
14
14
15
15
16
16
class EmbeddedModelArrayField (ArrayField ):
@@ -63,17 +63,8 @@ def get_transform(self, name):
63
63
return KeyTransformFactory (name , self )
64
64
65
65
66
- class ProcessRHSMixin :
67
- def process_rhs (self , compiler , connection ):
68
- if isinstance (self .lhs , KeyTransform ):
69
- get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
70
- else :
71
- get_db_prep_value = self .lhs .output_field .get_db_prep_value
72
- return None , [get_db_prep_value (v , connection , prepared = True ) for v in self .rhs ]
73
-
74
-
75
66
@EmbeddedModelArrayField .register_lookup
76
- class EMFArrayExact (EMFExact , ProcessRHSMixin ):
67
+ class EMFArrayExact (EMFExact ):
77
68
def as_mql (self , compiler , connection ):
78
69
lhs_mql = process_lhs (self , compiler , connection )
79
70
value = process_rhs (self , compiler , connection )
@@ -116,12 +107,29 @@ def as_mql(self, compiler, connection):
116
107
117
108
118
109
@EmbeddedModelArrayField .register_lookup
119
- class ArrayOverlap (EMFExact , ProcessRHSMixin ):
110
+ class ArrayOverlap (EMFMixin , Lookup ):
120
111
lookup_name = "overlap"
112
+ get_db_prep_lookup_value_is_iterable = True
113
+
114
+ def process_rhs (self , compiler , connection ):
115
+ values = self .rhs
116
+ if self .get_db_prep_lookup_value_is_iterable :
117
+ values = [values ]
118
+ # Compute how to serialize each value based on the query target.
119
+ # If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
120
+ # field of the subfield. Otherwise, use the base field of the array itself.
121
+ if isinstance (self .lhs , KeyTransform ):
122
+ get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
123
+ else :
124
+ get_db_prep_value = self .lhs .output_field .base_field .get_db_prep_value
125
+ return None , [get_db_prep_value (v , connection , prepared = True ) for v in values ]
121
126
122
127
def as_mql (self , compiler , connection ):
123
128
lhs_mql = process_lhs (self , compiler , connection )
124
129
values = process_rhs (self , compiler , connection )
130
+ # Querying a subfield within the array elements (via nested KeyTransform).
131
+ # Replicates MongoDB's implicit ANY-match by mapping over the array and applying
132
+ # `$in` on the subfield.
125
133
if isinstance (self .lhs , KeyTransform ):
126
134
lhs_mql , inner_lhs_mql = lhs_mql
127
135
return {
@@ -140,11 +148,12 @@ def as_mql(self, compiler, connection):
140
148
}
141
149
conditions = []
142
150
inner_lhs_mql = "$$item"
151
+ # Querying full embedded documents in the array.
152
+ # Builds `$or` conditions and maps them over the array to match any full document.
143
153
for value in values :
144
- if isinstance (value , models .Model ):
145
- value , emf_data = self .model_to_dict (value )
146
- # Get conditions for any nested EmbeddedModelFields.
147
- conditions .append ({"$and" : self .get_conditions ({inner_lhs_mql : (value , emf_data )})})
154
+ value , emf_data = self .model_to_dict (value )
155
+ # Get conditions for any nested EmbeddedModelFields.
156
+ conditions .append ({"$and" : self .get_conditions ({inner_lhs_mql : (value , emf_data )})})
148
157
return {
149
158
"$anyElementTrue" : {
150
159
"$ifNull" : [
0 commit comments