@@ -129,25 +129,12 @@ def extra_where(self, compiler, connection): # noqa: ARG001
129
129
raise NotSupportedError ("QuerySet.extra() is not supported on MongoDB." )
130
130
131
131
132
- def join (self , compiler , connection ):
133
- lookup_pipeline = []
134
- lhs_fields = []
135
- rhs_fields = []
136
- # Add a join condition for each pair of joining fields.
137
- parent_template = "parent__field__"
138
- for lhs , rhs in self .join_fields :
139
- lhs , rhs = connection .ops .prepare_join_on_clause (
140
- self .parent_alias , lhs , compiler .collection_name , rhs
141
- )
142
- lhs_fields .append (lhs .as_mql (compiler , connection ))
143
- # In the lookup stage, the reference to this column doesn't include
144
- # the collection name.
145
- rhs_fields .append (rhs .as_mql (compiler , connection ))
146
- # Handle any join conditions besides matching field pairs.
147
- extra = self .join_field .get_extra_restriction (self .table_alias , self .parent_alias )
148
- if extra :
132
+ def join (self , compiler , connection , pushed_expressions = None ):
133
+ def _get_reroot_replacements (expressions ):
134
+ if not expressions :
135
+ return []
149
136
columns = []
150
- for expr in extra . leaves () :
137
+ for expr in expressions :
151
138
# Determine whether the column needs to be transformed or rerouted
152
139
# as part of the subquery.
153
140
for hand_side in ["lhs" , "rhs" ]:
@@ -165,18 +152,45 @@ def join(self, compiler, connection):
165
152
# based on their rerouted positions in the join pipeline.
166
153
replacements = {}
167
154
for col , parent_pos in columns :
168
- column_target = Col (compiler .collection_name , expr . output_field . __class__ () )
155
+ column_target = Col (compiler .collection_name , col . target , col . output_field )
169
156
if parent_pos is not None :
170
157
target_col = f"${ parent_template } { parent_pos } "
171
158
column_target .target .db_column = target_col
172
159
column_target .target .set_attributes_from_name (target_col )
173
160
else :
174
161
column_target .target = col .target
175
162
replacements [col ] = column_target
176
- # Apply the transformed expressions in the extra condition.
163
+ return replacements
164
+
165
+ lookup_pipeline = []
166
+ lhs_fields = []
167
+ rhs_fields = []
168
+ # Add a join condition for each pair of joining fields.
169
+ parent_template = "parent__field__"
170
+ for lhs , rhs in self .join_fields :
171
+ lhs , rhs = connection .ops .prepare_join_on_clause (
172
+ self .parent_alias , lhs , compiler .collection_name , rhs
173
+ )
174
+ lhs_fields .append (lhs .as_mql (compiler , connection ))
175
+ # In the lookup stage, the reference to this column doesn't include
176
+ # the collection name.
177
+ rhs_fields .append (rhs .as_mql (compiler , connection ))
178
+ # Handle any join conditions besides matching field pairs.
179
+ extra = self .join_field .get_extra_restriction (self .table_alias , self .parent_alias )
180
+
181
+ if extra :
182
+ replacements = _get_reroot_replacements (extra .leaves ())
177
183
extra_condition = [extra .replace_expressions (replacements ).as_mql (compiler , connection )]
178
184
else :
179
185
extra_condition = []
186
+ if self .join_type == INNER :
187
+ rerooted_replacement = _get_reroot_replacements (pushed_expressions )
188
+ resolved_pushed_expressions = [
189
+ expr .replace_expressions (rerooted_replacement ).as_mql (compiler , connection )
190
+ for expr in pushed_expressions
191
+ ]
192
+ else :
193
+ resolved_pushed_expressions = []
180
194
181
195
lookup_pipeline = [
182
196
{
@@ -204,6 +218,7 @@ def join(self, compiler, connection):
204
218
for i , field in enumerate (rhs_fields )
205
219
]
206
220
+ extra_condition
221
+ + resolved_pushed_expressions
207
222
}
208
223
}
209
224
}
0 commit comments