Skip to content

Commit 419b97e

Browse files
WaVEVtimgraham
authored andcommitted
refactor subquery wrapping pipeline
1 parent 5ca07f1 commit 419b97e

File tree

4 files changed

+83
-77
lines changed

4 files changed

+83
-77
lines changed

django_mongodb_backend/expressions.py

Lines changed: 11 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def order_by(self, compiler, connection):
9595
return self.expression.as_mql(compiler, connection)
9696

9797

98-
def query(self, compiler, connection, lookup_name=None):
98+
def query(self, compiler, connection, get_wrapping_pipeline=None):
9999
subquery_compiler = self.get_compiler(connection=connection)
100100
subquery_compiler.pre_sql_setup(with_col_aliases=False)
101101
field_name, expr = subquery_compiler.columns[0]
@@ -119,76 +119,12 @@ def query(self, compiler, connection, lookup_name=None):
119119
for col, i in subquery_compiler.column_indices.items()
120120
},
121121
}
122-
wrapping_result_pipeline = None
123-
# The result must be a list of values. The output is compressed with an
124-
# aggregation pipeline.
125-
if lookup_name in ("in", "range"):
126-
wrapping_result_pipeline = [
127-
{
128-
"$facet": {
129-
"group": [
130-
{
131-
"$group": {
132-
"_id": None,
133-
"tmp_name": {
134-
"$addToSet": expr.as_mql(subquery_compiler, connection)
135-
},
136-
}
137-
}
138-
]
139-
}
140-
},
141-
{
142-
"$project": {
143-
field_name: {
144-
"$ifNull": [
145-
{
146-
"$getField": {
147-
"input": {"$arrayElemAt": ["$group", 0]},
148-
"field": "tmp_name",
149-
}
150-
},
151-
[],
152-
]
153-
}
154-
}
155-
},
156-
]
157-
if lookup_name == "overlap":
158-
wrapping_result_pipeline = [
159-
{
160-
"$facet": {
161-
"group": [
162-
{"$project": {"tmp_name": expr.as_mql(subquery_compiler, connection)}},
163-
{
164-
"$unwind": "$tmp_name",
165-
},
166-
{
167-
"$group": {
168-
"_id": None,
169-
"tmp_name": {"$addToSet": "$tmp_name"},
170-
}
171-
},
172-
]
173-
}
174-
},
175-
{
176-
"$project": {
177-
field_name: {
178-
"$ifNull": [
179-
{
180-
"$getField": {
181-
"input": {"$arrayElemAt": ["$group", 0]},
182-
"field": "tmp_name",
183-
}
184-
},
185-
[],
186-
]
187-
}
188-
}
189-
},
190-
]
191-
if wrapping_result_pipeline:
122+
if get_wrapping_pipeline:
123+
# The results from some lookups must be converted to a list of values.
124+
# The output is compressed with an aggregation pipeline.
125+
wrapping_result_pipeline = get_wrapping_pipeline(
126+
subquery_compiler, connection, field_name, expr
127+
)
192128
# If the subquery is a combinator, wrap the result at the end of the
193129
# combinator pipeline...
194130
if subquery.query.combinator:
@@ -221,13 +157,13 @@ def star(self, compiler, connection): # noqa: ARG001
221157
return {"$literal": True}
222158

223159

224-
def subquery(self, compiler, connection, lookup_name=None):
225-
return self.query.as_mql(compiler, connection, lookup_name=lookup_name)
160+
def subquery(self, compiler, connection, get_wrapping_pipeline=None):
161+
return self.query.as_mql(compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
226162

227163

228-
def exists(self, compiler, connection, lookup_name=None):
164+
def exists(self, compiler, connection, get_wrapping_pipeline=None):
229165
try:
230-
lhs_mql = subquery(self, compiler, connection, lookup_name=lookup_name)
166+
lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
231167
except EmptyResultSet:
232168
return Value(False).as_mql(compiler, connection)
233169
return connection.mongo_operators["isnull"](lhs_mql, False)

django_mongodb_backend/fields/array.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,41 @@ class ArrayExact(ArrayRHSMixin, Exact):
278278
class ArrayOverlap(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
279279
lookup_name = "overlap"
280280

281+
def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr):
282+
return [
283+
{
284+
"$facet": {
285+
"group": [
286+
{"$project": {"tmp_name": expr.as_mql(compiler, connection)}},
287+
{
288+
"$unwind": "$tmp_name",
289+
},
290+
{
291+
"$group": {
292+
"_id": None,
293+
"tmp_name": {"$addToSet": "$tmp_name"},
294+
}
295+
},
296+
]
297+
}
298+
},
299+
{
300+
"$project": {
301+
field_name: {
302+
"$ifNull": [
303+
{
304+
"$getField": {
305+
"input": {"$arrayElemAt": ["$group", 0]},
306+
"field": "tmp_name",
307+
}
308+
},
309+
[],
310+
]
311+
}
312+
}
313+
},
314+
]
315+
281316
def as_mql(self, compiler, connection):
282317
lhs_mql = process_lhs(self, compiler, connection)
283318
value = process_rhs(self, compiler, connection)

django_mongodb_backend/lookups.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,38 @@ def in_(self, compiler, connection):
4545
return builtin_lookup(self, compiler, connection)
4646

4747

48+
def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): # noqa: ARG001
49+
return [
50+
{
51+
"$facet": {
52+
"group": [
53+
{
54+
"$group": {
55+
"_id": None,
56+
"tmp_name": {"$addToSet": expr.as_mql(compiler, connection)},
57+
}
58+
}
59+
]
60+
}
61+
},
62+
{
63+
"$project": {
64+
field_name: {
65+
"$ifNull": [
66+
{
67+
"$getField": {
68+
"input": {"$arrayElemAt": ["$group", 0]},
69+
"field": "tmp_name",
70+
}
71+
},
72+
[],
73+
]
74+
}
75+
}
76+
},
77+
]
78+
79+
4880
def is_null(self, compiler, connection):
4981
if not isinstance(self.rhs, bool):
5082
raise ValueError("The QuerySet value for an isnull lookup must be True or False.")
@@ -97,6 +129,7 @@ def register_lookups():
97129
field_resolve_expression_parameter
98130
)
99131
In.as_mql = RelatedIn.as_mql = in_
132+
In.get_subquery_wrapping_pipeline = get_subquery_wrapping_pipeline
100133
IsNull.as_mql = is_null
101134
PatternLookup.prep_lookup_value_mongo = pattern_lookup_prep_lookup_value
102135
UUIDTextMixin.as_mql = uuid_text_mixin

django_mongodb_backend/query_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ def process_lhs(node, compiler, connection):
2828
def process_rhs(node, compiler, connection):
2929
rhs = node.rhs
3030
if hasattr(rhs, "as_mql"):
31-
if getattr(rhs, "subquery", False):
32-
value = rhs.as_mql(compiler, connection, lookup_name=node.lookup_name)
31+
if getattr(rhs, "subquery", False) and hasattr(node, "get_subquery_wrapping_pipeline"):
32+
value = rhs.as_mql(
33+
compiler, connection, get_wrapping_pipeline=node.get_subquery_wrapping_pipeline
34+
)
3335
else:
3436
value = rhs.as_mql(compiler, connection)
3537
else:

0 commit comments

Comments
 (0)