Skip to content

[PoC] Push simple filter conditions into $lookup stage. #345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When
from django.db.models.functions.comparison import Coalesce
from django.db.models.functions.math import Power
from django.db.models.lookups import IsNull
from django.db.models.lookups import IsNull, Lookup
from django.db.models.sql import compiler
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE
from django.db.models.sql.datastructures import BaseTable
from django.db.models.sql.where import AND
from django.utils.functional import cached_property
from pymongo import ASCENDING, DESCENDING

from .query import MongoQuery, wrap_database_errors
from .query_utils import is_direct_value


class SQLCompiler(compiler.SQLCompiler):
Expand Down Expand Up @@ -550,10 +552,22 @@ def get_combinator_queries(self):

def get_lookup_pipeline(self):
result = []
where = self.get_where()
promote_filters = defaultdict(list)
for expr in where.children if where and where.connector == AND else ():
if (
isinstance(expr, Lookup)
and isinstance(expr.lhs, Col)
and (is_direct_value(expr.rhs) or isinstance(expr.rhs, Value))
):
promote_filters[expr.lhs.alias].append(expr)

for alias in tuple(self.query.alias_map):
if not self.query.alias_refcount[alias] or self.collection_name == alias:
continue
result += self.query.alias_map[alias].as_mql(self, self.connection)
result += self.query.alias_map[alias].as_mql(
self, self.connection, promote_filters[alias]
)
return result

def _get_aggregate_expressions(self, expr):
Expand Down
57 changes: 37 additions & 20 deletions django_mongodb_backend/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,25 +129,12 @@ def extra_where(self, compiler, connection): # noqa: ARG001
raise NotSupportedError("QuerySet.extra() is not supported on MongoDB.")


def join(self, compiler, connection):
lookup_pipeline = []
lhs_fields = []
rhs_fields = []
# Add a join condition for each pair of joining fields.
parent_template = "parent__field__"
for lhs, rhs in self.join_fields:
lhs, rhs = connection.ops.prepare_join_on_clause(
self.parent_alias, lhs, compiler.collection_name, rhs
)
lhs_fields.append(lhs.as_mql(compiler, connection))
# In the lookup stage, the reference to this column doesn't include
# the collection name.
rhs_fields.append(rhs.as_mql(compiler, connection))
# Handle any join conditions besides matching field pairs.
extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias)
if extra:
def join(self, compiler, connection, pushed_expressions=None):
def _get_reroot_replacements(expressions):
if not expressions:
return None
columns = []
for expr in extra.leaves():
for expr in expressions:
# Determine whether the column needs to be transformed or rerouted
# as part of the subquery.
for hand_side in ["lhs", "rhs"]:
Expand All @@ -165,18 +152,47 @@ def join(self, compiler, connection):
# based on their rerouted positions in the join pipeline.
replacements = {}
for col, parent_pos in columns:
column_target = Col(compiler.collection_name, expr.output_field.__class__())
target = col.target.clone()
target.remote_field = col.target.remote_field
column_target = Col(compiler.collection_name, target)
if parent_pos is not None:
target_col = f"${parent_template}{parent_pos}"
column_target.target.db_column = target_col
column_target.target.set_attributes_from_name(target_col)
else:
column_target.target = col.target
replacements[col] = column_target
# Return the mapping; callers apply it to their expressions via replace_expressions().
return replacements

lookup_pipeline = []
lhs_fields = []
rhs_fields = []
# Add a join condition for each pair of joining fields.
parent_template = "parent__field__"
for lhs, rhs in self.join_fields:
lhs, rhs = connection.ops.prepare_join_on_clause(
self.parent_alias, lhs, compiler.collection_name, rhs
)
lhs_fields.append(lhs.as_mql(compiler, connection))
# In the lookup stage, the reference to this column doesn't include
# the collection name.
rhs_fields.append(rhs.as_mql(compiler, connection))
# Handle any join conditions besides matching field pairs.
extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias)

if extra:
replacements = _get_reroot_replacements(extra.leaves())
extra_condition = [extra.replace_expressions(replacements).as_mql(compiler, connection)]
else:
extra_condition = []
if self.join_type == INNER:
rerooted_replacement = _get_reroot_replacements(pushed_expressions)
resolved_pushed_expressions = [
expr.replace_expressions(rerooted_replacement).as_mql(compiler, connection)
for expr in pushed_expressions
]
else:
resolved_pushed_expressions = []

lookup_pipeline = [
{
Expand Down Expand Up @@ -204,6 +220,7 @@ def join(self, compiler, connection):
for i, field in enumerate(rhs_fields)
]
+ extra_condition
+ resolved_pushed_expressions
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion tests/queries_/test_mql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def test_join(self):
"{'$lookup': {'from': 'queries__author', "
"'let': {'parent__field__0': '$author_id'}, "
"'pipeline': [{'$match': {'$expr': "
"{'$and': [{'$eq': ['$$parent__field__0', '$_id']}]}}}], 'as': 'queries__author'}}, "
"{'$and': [{'$eq': ['$$parent__field__0', '$_id']}, "
"{'$eq': ['$name', 'Bob']}]}}}], 'as': 'queries__author'}}, "
"{'$unwind': '$queries__author'}, "
"{'$match': {'$expr': {'$eq': ['$queries__author.name', 'Bob']}}}])",
)