Skip to content

Commit 329b2b7

Browse files
committed
Add AggregateFilter, StringgAgg.as_mql() as per
django/django@4b977a5
1 parent 450a5f3 commit 329b2b7

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

django_mongodb_backend/aggregates.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
1-
from django.db.models.aggregates import Aggregate, Count, StdDev, Variance
2-
from django.db.models.expressions import Case, Value, When
1+
from django.db.models.aggregates import (
2+
Aggregate,
3+
AggregateFilter,
4+
Count,
5+
StdDev,
6+
StringAgg,
7+
Variance,
8+
)
9+
from django.db.models.expressions import Case, Col, Value, When
310
from django.db.models.lookups import IsNull
411

512
from .query_utils import process_lhs
@@ -16,7 +23,11 @@ def aggregate(
1623
resolve_inner_expression=False,
1724
**extra_context, # noqa: ARG001
1825
):
19-
if self.filter:
26+
# TODO: isinstance(self.filter, Col) works around failure of
27+
# aggregation.tests.AggregateTestCase.test_distinct_on_aggregate. Is this
28+
# correct?
29+
if self.filter is not None and not isinstance(self.filter, Col):
30+
# Generate a CASE statement for this aggregate.
2031
node = self.copy()
2132
node.filter = None
2233
source_expressions = node.get_source_expressions()
@@ -31,6 +42,10 @@ def aggregate(
3142
return {f"${operator}": lhs_mql}
3243

3344

45+
def aggregate_filter(self, compiler, connection, **extra_context):
46+
return self.condition.as_mql(compiler, connection, **extra_context)
47+
48+
3449
def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001
3550
"""
3651
When resolve_inner_expression=True, return the MQL that resolves as a
@@ -72,8 +87,16 @@ def stddev_variance(self, compiler, connection, **extra_context):
7287
return aggregate(self, compiler, connection, operator=operator, **extra_context)
7388

7489

90+
def string_agg(self, compiler, connection, **extra_context): # # noqa: ARG001
91+
from django.db import NotSupportedError
92+
93+
raise NotSupportedError("StringAgg is not supported.")
94+
95+
7596
def register_aggregates():
7697
Aggregate.as_mql = aggregate
98+
AggregateFilter.as_mql = aggregate_filter
7799
Count.as_mql = count
78100
StdDev.as_mql = stddev_variance
101+
StringAgg.as_mql = string_agg
79102
Variance.as_mql = stddev_variance

django_mongodb_backend/features.py

+6
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,12 @@ class DatabaseFeatures(BaseDatabaseFeatures):
104104
"backends.base.test_creation.TestDbCreationTests.test_serialize_deprecation",
105105
# This backend has a custom format_debug_sql().
106106
"backends.tests.LastExecutedQueryTest.test_debug_sql",
107+
# StringAgg is not supported.
108+
"aggregation.tests.AggregateTestCase.test_distinct_on_stringagg",
109+
"aggregation.tests.AggregateTestCase.test_string_agg_escapes_delimiter",
110+
"aggregation.tests.AggregateTestCase.test_string_agg_filter",
111+
"aggregation.tests.AggregateTestCase.test_string_agg_filter_in_subquery",
112+
"aggregation.tests.AggregateTestCase.test_stringagg_default_value",
107113
}
108114
# $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3.
109115
_django_test_expected_failures_bitwise = {

0 commit comments

Comments
 (0)