Skip to content

Commit d56dd0b

Browse files
committed
Push simple filter conditions into $lookup stage.
1 parent 880c906 commit d56dd0b

File tree

3 files changed

+53
-23
lines changed

3 files changed

+53
-23
lines changed

django_mongodb_backend/compiler.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@
1010
from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When
1111
from django.db.models.functions.comparison import Coalesce
1212
from django.db.models.functions.math import Power
13-
from django.db.models.lookups import IsNull
13+
from django.db.models.lookups import IsNull, Lookup
1414
from django.db.models.sql import compiler
1515
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE
1616
from django.db.models.sql.datastructures import BaseTable
17+
from django.db.models.sql.where import AND
1718
from django.utils.functional import cached_property
1819
from pymongo import ASCENDING, DESCENDING
1920

2021
from .query import MongoQuery, wrap_database_errors
22+
from .query_utils import is_direct_value
2123

2224

2325
class SQLCompiler(compiler.SQLCompiler):
@@ -550,10 +552,22 @@ def get_combinator_queries(self):
550552

551553
def get_lookup_pipeline(self):
552554
result = []
555+
where = self.get_where()
556+
promote_filters = defaultdict(list)
557+
for expr in where.children if where and where.connector == AND else ():
558+
if (
559+
isinstance(expr, Lookup)
560+
and isinstance(expr.lhs, Col)
561+
and (is_direct_value(expr.rhs) or isinstance(expr.rhs, Value))
562+
):
563+
promote_filters[expr.lhs.alias].append(expr)
564+
553565
for alias in tuple(self.query.alias_map):
554566
if not self.query.alias_refcount[alias] or self.collection_name == alias:
555567
continue
556-
result += self.query.alias_map[alias].as_mql(self, self.connection)
568+
result += self.query.alias_map[alias].as_mql(
569+
self, self.connection, promote_filters[alias]
570+
)
557571
return result
558572

559573
def _get_aggregate_expressions(self, expr):

django_mongodb_backend/query.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -129,25 +129,12 @@ def extra_where(self, compiler, connection): # noqa: ARG001
129129
raise NotSupportedError("QuerySet.extra() is not supported on MongoDB.")
130130

131131

132-
def join(self, compiler, connection):
133-
lookup_pipeline = []
134-
lhs_fields = []
135-
rhs_fields = []
136-
# Add a join condition for each pair of joining fields.
137-
parent_template = "parent__field__"
138-
for lhs, rhs in self.join_fields:
139-
lhs, rhs = connection.ops.prepare_join_on_clause(
140-
self.parent_alias, lhs, compiler.collection_name, rhs
141-
)
142-
lhs_fields.append(lhs.as_mql(compiler, connection))
143-
# In the lookup stage, the reference to this column doesn't include
144-
# the collection name.
145-
rhs_fields.append(rhs.as_mql(compiler, connection))
146-
# Handle any join conditions besides matching field pairs.
147-
extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias)
148-
if extra:
132+
def join(self, compiler, connection, pushed_expressions=None):
133+
def _get_reroot_replacements(expressions):
134+
if not expressions:
135+
return []
149136
columns = []
150-
for expr in extra.leaves():
137+
for expr in expressions:
151138
# Determine whether the column needs to be transformed or rerouted
152139
# as part of the subquery.
153140
for hand_side in ["lhs", "rhs"]:
@@ -165,18 +152,45 @@ def join(self, compiler, connection):
165152
# based on their rerouted positions in the join pipeline.
166153
replacements = {}
167154
for col, parent_pos in columns:
168-
column_target = Col(compiler.collection_name, expr.output_field.__class__())
155+
column_target = Col(compiler.collection_name, col.target, col.output_field)
169156
if parent_pos is not None:
170157
target_col = f"${parent_template}{parent_pos}"
171158
column_target.target.db_column = target_col
172159
column_target.target.set_attributes_from_name(target_col)
173160
else:
174161
column_target.target = col.target
175162
replacements[col] = column_target
176-
# Apply the transformed expressions in the extra condition.
163+
return replacements
164+
165+
lookup_pipeline = []
166+
lhs_fields = []
167+
rhs_fields = []
168+
# Add a join condition for each pair of joining fields.
169+
parent_template = "parent__field__"
170+
for lhs, rhs in self.join_fields:
171+
lhs, rhs = connection.ops.prepare_join_on_clause(
172+
self.parent_alias, lhs, compiler.collection_name, rhs
173+
)
174+
lhs_fields.append(lhs.as_mql(compiler, connection))
175+
# In the lookup stage, the reference to this column doesn't include
176+
# the collection name.
177+
rhs_fields.append(rhs.as_mql(compiler, connection))
178+
# Handle any join conditions besides matching field pairs.
179+
extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias)
180+
181+
if extra:
182+
replacements = _get_reroot_replacements(extra.leaves())
177183
extra_condition = [extra.replace_expressions(replacements).as_mql(compiler, connection)]
178184
else:
179185
extra_condition = []
186+
if self.join_type == INNER:
187+
rerooted_replacement = _get_reroot_replacements(pushed_expressions)
188+
resolved_pushed_expressions = [
189+
expr.replace_expressions(rerooted_replacement).as_mql(compiler, connection)
190+
for expr in pushed_expressions
191+
]
192+
else:
193+
resolved_pushed_expressions = []
180194

181195
lookup_pipeline = [
182196
{
@@ -204,6 +218,7 @@ def join(self, compiler, connection):
204218
for i, field in enumerate(rhs_fields)
205219
]
206220
+ extra_condition
221+
+ resolved_pushed_expressions
207222
}
208223
}
209224
}

tests/queries_/test_mql.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ def test_join(self):
2020
"{'$lookup': {'from': 'queries__author', "
2121
"'let': {'parent__field__0': '$author_id'}, "
2222
"'pipeline': [{'$match': {'$expr': "
23-
"{'$and': [{'$eq': ['$$parent__field__0', '$_id']}]}}}], 'as': 'queries__author'}}, "
23+
"{'$and': [{'$eq': ['$$parent__field__0', '$_id']}, "
24+
"{'$eq': ['$name', 'Bob']}]}}}], 'as': 'queries__author'}}, "
2425
"{'$unwind': '$queries__author'}, "
2526
"{'$match': {'$expr': {'$eq': ['$queries__author.name', 'Bob']}}}])",
2627
)

0 commit comments

Comments
 (0)