@@ -60,11 +60,20 @@ def get_transform(self, name):
60
60
transform = super ().get_transform (name )
61
61
if transform :
62
62
return transform
63
- return KeyTransformFactory (name , self .base_field )
63
+ return KeyTransformFactory (name , self )
64
+
65
+
66
+ class ProcessRHSMixin :
67
+ def process_rhs (self , compiler , connection ):
68
+ if isinstance (self .lhs , KeyTransform ):
69
+ get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
70
+ else :
71
+ get_db_prep_value = self .lhs .output_field .get_db_prep_value
72
+ return None , [get_db_prep_value (v , connection , prepared = True ) for v in self .rhs ]
64
73
65
74
66
75
@EmbeddedModelArrayField .register_lookup
67
- class EMFArrayExact (EMFExact ):
76
+ class EMFArrayExact (EMFExact , ProcessRHSMixin ):
68
77
def as_mql (self , compiler , connection ):
69
78
lhs_mql = process_lhs (self , compiler , connection )
70
79
value = process_rhs (self , compiler , connection )
@@ -106,15 +115,61 @@ def as_mql(self, compiler, connection):
106
115
}
107
116
108
117
118
+ @EmbeddedModelArrayField .register_lookup
119
+ class ArrayOverlap (EMFExact , ProcessRHSMixin ):
120
+ lookup_name = "overlap"
121
+
122
+ def as_mql (self , compiler , connection ):
123
+ lhs_mql = process_lhs (self , compiler , connection )
124
+ values = process_rhs (self , compiler , connection )
125
+ if isinstance (self .lhs , KeyTransform ):
126
+ lhs_mql , inner_lhs_mql = lhs_mql
127
+ return {
128
+ "$anyElementTrue" : {
129
+ "$ifNull" : [
130
+ {
131
+ "$map" : {
132
+ "input" : lhs_mql ,
133
+ "as" : "item" ,
134
+ "in" : {"$in" : [inner_lhs_mql , values ]},
135
+ }
136
+ },
137
+ [],
138
+ ]
139
+ }
140
+ }
141
+ conditions = []
142
+ inner_lhs_mql = "$$item"
143
+ for value in values :
144
+ if isinstance (value , models .Model ):
145
+ value , emf_data = self .model_to_dict (value )
146
+ # Get conditions for any nested EmbeddedModelFields.
147
+ conditions .append ({"$and" : self .get_conditions ({inner_lhs_mql : (value , emf_data )})})
148
+ return {
149
+ "$anyElementTrue" : {
150
+ "$ifNull" : [
151
+ {
152
+ "$map" : {
153
+ "input" : lhs_mql ,
154
+ "as" : "item" ,
155
+ "in" : {"$or" : conditions },
156
+ }
157
+ },
158
+ [],
159
+ ]
160
+ }
161
+ }
162
+
163
+
109
164
class KeyTransform (Transform ):
110
165
# it should be different class than EMF keytransform even most of the methods are equal.
111
- def __init__ (self , key_name , base_field , * args , ** kwargs ):
166
+ def __init__ (self , key_name , array_field , * args , ** kwargs ):
112
167
super ().__init__ (* args , ** kwargs )
113
- self .base_field = base_field
168
+ self .array_field = array_field
114
169
self .key_name = key_name
115
170
# The iteration items begins from the base_field, a virtual column with
116
171
# base field output type is created.
117
- column_target = base_field .clone ()
172
+ column_target = array_field . base_field . embedded_model . _meta . get_field ( key_name ) .clone ()
118
173
column_name = f"$item.{ key_name } "
119
174
column_target .db_column = column_name
120
175
column_target .set_attributes_from_name (column_name )
@@ -137,7 +192,7 @@ def _get_missing_field_or_lookup_exception(self, lhs, name):
137
192
suggestion = "."
138
193
raise FieldDoesNotExist (
139
194
f"Unsupported lookup '{ name } ' for "
140
- f"{ self .base_field .__class__ .__name__ } '{ self .base_field .name } '"
195
+ f"{ self .array_field . base_field .__class__ .__name__ } '{ self . array_field .base_field .name } '"
141
196
f"{ suggestion } "
142
197
)
143
198
@@ -150,7 +205,9 @@ def get_transform(self, name):
150
205
transform = (
151
206
self ._lhs .get_transform (name )
152
207
if isinstance (self ._lhs , Transform )
153
- else self .base_field .embedded_model ._meta .get_field (self .key_name ).get_transform (name )
208
+ else self .array_field .base_field .embedded_model ._meta .get_field (
209
+ self .key_name
210
+ ).get_transform (name )
154
211
)
155
212
if transform :
156
213
self ._sub_transform = transform
@@ -166,7 +223,7 @@ def as_mql(self, compiler, connection):
166
223
167
224
@property
168
225
def output_field (self ):
169
- return EmbeddedModelArrayField ( self .base_field )
226
+ return self .array_field
170
227
171
228
172
229
class KeyTransformFactory :
0 commit comments