Skip to content

Commit 2526da3

Browse files
committed
add ArrayField
1 parent 7ee09ac commit 2526da3

File tree

24 files changed

+2504
-8
lines changed

24 files changed

+2504
-8
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ repos:
4444
hooks:
4545
- id: rstcheck
4646
additional_dependencies: [sphinx]
47+
args: ["--ignore-directives=fieldlookup,setting", "--ignore-roles=lookup,setting"]
4748

4849
# We use the Python version instead of the original version which seems to require Docker
4950
# https://github.com/koalaman/shellcheck-precommit

django_mongodb_backend/features.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,19 @@ class DatabaseFeatures(BaseDatabaseFeatures):
8080
"auth_tests.test_views.LoginTest.test_login_session_without_hash_session_key",
8181
# GenericRelation.value_to_string() assumes integer pk.
8282
"contenttypes_tests.test_fields.GenericRelationTests.test_value_to_string",
83+
# contains with Exists() doesn't work:
84+
# https://github.com/mongodb-labs/django-mongodb/issues/204
85+
"model_fields_.test_arrayfield.QueryingTests.test_contains_subquery",
86+
# overlap with values() returns no results:
87+
# https://github.com/mongodb-labs/django-mongodb/issues/209
88+
"model_fields_.test_arrayfield.QueryingTests.test_overlap_values",
89+
# icontains doesn't work on ArrayField:
90+
# Unsupported conversion from array to string in $convert
91+
"model_fields_.test_arrayfield.QueryingTests.test_icontains",
92+
# ArrayField's contained_by lookup crashes with Exists: "both operands "
93+
# of $setIsSubset must be arrays. Second argument is of type: null"
94+
# https://jira.mongodb.org/browse/SERVER-99186
95+
"model_fields_.test_arrayfield.QueryingTests.test_contained_by_subquery",
8396
}
8497
# $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3.
8598
_django_test_expected_failures_bitwise = {

django_mongodb_backend/fields/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
from .array import ArrayField
12
from .auto import ObjectIdAutoField
23
from .duration import register_duration_field
34
from .json import register_json_field
45
from .objectid import ObjectIdField
56

6-
__all__ = ["register_fields", "ObjectIdAutoField", "ObjectIdField"]
7+
__all__ = ["register_fields", "ArrayField", "ObjectIdAutoField", "ObjectIdField"]
78

89

910
def register_fields():
Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
import json
2+
3+
from django.contrib.postgres.validators import ArrayMaxLengthValidator
4+
from django.core import checks, exceptions
5+
from django.db.models import DecimalField, Field, Func, IntegerField, Transform, Value
6+
from django.db.models.fields.mixins import CheckFieldDefaultMixin
7+
from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin, In, Lookup
8+
from django.utils.translation import gettext_lazy as _
9+
10+
from ..forms import SimpleArrayField
11+
from ..query_utils import process_lhs, process_rhs
12+
from ..utils import prefix_validation_error
13+
14+
__all__ = ["ArrayField"]
15+
16+
17+
class AttributeSetter:
18+
def __init__(self, name, value):
19+
setattr(self, name, value)
20+
21+
22+
class ArrayField(CheckFieldDefaultMixin, Field):
23+
empty_strings_allowed = False
24+
default_error_messages = {
25+
"item_invalid": _("Item %(nth)s in the array did not validate:"),
26+
"nested_array_mismatch": _("Nested arrays must have the same length."),
27+
}
28+
_default_hint = ("list", "[]")
29+
30+
def __init__(self, base_field, size=None, **kwargs):
31+
self.base_field = base_field
32+
self.size = size
33+
if self.size:
34+
self.default_validators = [
35+
*self.default_validators,
36+
ArrayMaxLengthValidator(self.size),
37+
]
38+
# For performance, only add a from_db_value() method if the base field
39+
# implements it.
40+
if hasattr(self.base_field, "from_db_value"):
41+
self.from_db_value = self._from_db_value
42+
super().__init__(**kwargs)
43+
44+
@property
45+
def model(self):
46+
try:
47+
return self.__dict__["model"]
48+
except KeyError:
49+
raise AttributeError(
50+
"'%s' object has no attribute 'model'" % self.__class__.__name__
51+
) from None
52+
53+
@model.setter
54+
def model(self, model):
55+
self.__dict__["model"] = model
56+
self.base_field.model = model
57+
58+
@classmethod
59+
def _choices_is_value(cls, value):
60+
return isinstance(value, list | tuple) or super()._choices_is_value(value)
61+
62+
def check(self, **kwargs):
63+
errors = super().check(**kwargs)
64+
if self.base_field.remote_field:
65+
errors.append(
66+
checks.Error(
67+
"Base field for array cannot be a related field.",
68+
obj=self,
69+
id="django_mongodb_backend.array.E002",
70+
)
71+
)
72+
else:
73+
base_checks = self.base_field.check()
74+
if base_checks:
75+
error_messages = "\n ".join(
76+
f"{base_check.msg} ({base_check.id})"
77+
for base_check in base_checks
78+
if isinstance(base_check, checks.Error)
79+
)
80+
if error_messages:
81+
errors.append(
82+
checks.Error(
83+
f"Base field for array has errors:\n {error_messages}",
84+
obj=self,
85+
id="django_mongodb_backend.array.E001",
86+
)
87+
)
88+
warning_messages = "\n ".join(
89+
f"{base_check.msg} ({base_check.id})"
90+
for base_check in base_checks
91+
if isinstance(base_check, checks.Warning)
92+
)
93+
if warning_messages:
94+
errors.append(
95+
checks.Warning(
96+
f"Base field for array has warnings:\n {warning_messages}",
97+
obj=self,
98+
id="django_mongodb_backend.array.W004",
99+
)
100+
)
101+
return errors
102+
103+
def set_attributes_from_name(self, name):
104+
super().set_attributes_from_name(name)
105+
self.base_field.set_attributes_from_name(name)
106+
107+
@property
108+
def description(self):
109+
return f"Array of {self.base_field.description}"
110+
111+
def db_type(self, connection):
112+
return "array"
113+
114+
def get_db_prep_value(self, value, connection, prepared=False):
115+
if isinstance(value, list | tuple):
116+
# Workaround for https://code.djangoproject.com/ticket/35982
117+
# (fixed in Django 5.2).
118+
if isinstance(self.base_field, DecimalField):
119+
return [self.base_field.get_db_prep_save(i, connection) for i in value]
120+
return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value]
121+
return value
122+
123+
def deconstruct(self):
124+
name, path, args, kwargs = super().deconstruct()
125+
if path == "django_mongodb_backend.fields.array.ArrayField":
126+
path = "django_mongodb_backend.fields.ArrayField"
127+
kwargs.update(
128+
{
129+
"base_field": self.base_field.clone(),
130+
"size": self.size,
131+
}
132+
)
133+
return name, path, args, kwargs
134+
135+
def to_python(self, value):
136+
if isinstance(value, str):
137+
# Assume value is being deserialized.
138+
vals = json.loads(value)
139+
value = [self.base_field.to_python(val) for val in vals]
140+
return value
141+
142+
def _from_db_value(self, value, expression, connection):
143+
if value is None:
144+
return value
145+
return [self.base_field.from_db_value(item, expression, connection) for item in value]
146+
147+
def value_to_string(self, obj):
148+
values = []
149+
vals = self.value_from_object(obj)
150+
base_field = self.base_field
151+
152+
for val in vals:
153+
if val is None:
154+
values.append(None)
155+
else:
156+
obj = AttributeSetter(base_field.attname, val)
157+
values.append(base_field.value_to_string(obj))
158+
return json.dumps(values)
159+
160+
def get_transform(self, name):
161+
transform = super().get_transform(name)
162+
if transform:
163+
return transform
164+
if "_" not in name:
165+
try:
166+
index = int(name)
167+
except ValueError:
168+
pass
169+
else:
170+
return IndexTransformFactory(index, self.base_field)
171+
try:
172+
start, end = name.split("_")
173+
start = int(start)
174+
end = int(end)
175+
except ValueError:
176+
pass
177+
else:
178+
return SliceTransformFactory(start, end)
179+
180+
def validate(self, value, model_instance):
181+
super().validate(value, model_instance)
182+
for index, part in enumerate(value):
183+
try:
184+
self.base_field.validate(part, model_instance)
185+
except exceptions.ValidationError as error:
186+
raise prefix_validation_error(
187+
error,
188+
prefix=self.error_messages["item_invalid"],
189+
code="item_invalid",
190+
params={"nth": index + 1},
191+
) from None
192+
if isinstance(self.base_field, ArrayField) and len({len(i) for i in value}) > 1:
193+
raise exceptions.ValidationError(
194+
self.error_messages["nested_array_mismatch"],
195+
code="nested_array_mismatch",
196+
)
197+
198+
def run_validators(self, value):
199+
super().run_validators(value)
200+
for index, part in enumerate(value):
201+
try:
202+
self.base_field.run_validators(part)
203+
except exceptions.ValidationError as error:
204+
raise prefix_validation_error(
205+
error,
206+
prefix=self.error_messages["item_invalid"],
207+
code="item_invalid",
208+
params={"nth": index + 1},
209+
) from None
210+
211+
def formfield(self, **kwargs):
212+
return super().formfield(
213+
**{
214+
"form_class": SimpleArrayField,
215+
"base_field": self.base_field.formfield(),
216+
"max_length": self.size,
217+
**kwargs,
218+
}
219+
)
220+
221+
222+
class Array(Func):
223+
def as_mql(self, compiler, connection):
224+
return [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()]
225+
226+
227+
class ArrayRHSMixin:
228+
def __init__(self, lhs, rhs):
229+
if isinstance(rhs, tuple | list):
230+
expressions = []
231+
for value in rhs:
232+
if not hasattr(value, "resolve_expression"):
233+
field = lhs.output_field
234+
value = Value(field.base_field.get_prep_value(value))
235+
expressions.append(value)
236+
rhs = Array(*expressions)
237+
super().__init__(lhs, rhs)
238+
239+
240+
@ArrayField.register_lookup
241+
class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
242+
lookup_name = "contains"
243+
244+
def as_mql(self, compiler, connection):
245+
lhs_mql = process_lhs(self, compiler, connection)
246+
value = process_rhs(self, compiler, connection)
247+
return {"$and": [{"$ne": [lhs_mql, None]}, {"$setIsSubset": [value, lhs_mql]}]}
248+
249+
250+
@ArrayField.register_lookup
251+
class ArrayContainedBy(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
252+
lookup_name = "contained_by"
253+
254+
def as_mql(self, compiler, connection):
255+
lhs_mql = process_lhs(self, compiler, connection)
256+
value = process_rhs(self, compiler, connection)
257+
return {
258+
"$and": [
259+
{"$ne": [lhs_mql, None]},
260+
{"$ne": [value, None]},
261+
{"$setIsSubset": [lhs_mql, value]},
262+
]
263+
}
264+
265+
266+
@ArrayField.register_lookup
267+
class ArrayExact(ArrayRHSMixin, Exact):
268+
pass
269+
270+
271+
@ArrayField.register_lookup
272+
class ArrayOverlap(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
273+
lookup_name = "overlap"
274+
275+
def as_mql(self, compiler, connection):
276+
lhs_mql = process_lhs(self, compiler, connection)
277+
value = process_rhs(self, compiler, connection)
278+
return {
279+
"$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}]
280+
}
281+
282+
283+
@ArrayField.register_lookup
284+
class ArrayLenTransform(Transform):
285+
lookup_name = "len"
286+
output_field = IntegerField()
287+
288+
def as_mql(self, compiler, connection):
289+
lhs_mql = process_lhs(self, compiler, connection)
290+
return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$size": lhs_mql}}}
291+
292+
293+
@ArrayField.register_lookup
294+
class ArrayInLookup(In):
295+
def get_prep_lookup(self):
296+
values = super().get_prep_lookup()
297+
if hasattr(values, "resolve_expression"):
298+
return values
299+
# process_rhs() expects hashable values, so convert lists to tuples.
300+
prepared_values = []
301+
for value in values:
302+
if hasattr(value, "resolve_expression"):
303+
prepared_values.append(value)
304+
else:
305+
prepared_values.append(tuple(value))
306+
return prepared_values
307+
308+
309+
class IndexTransform(Transform):
310+
def __init__(self, index, base_field, *args, **kwargs):
311+
super().__init__(*args, **kwargs)
312+
self.index = index
313+
self.base_field = base_field
314+
315+
def as_mql(self, compiler, connection):
316+
lhs_mql = process_lhs(self, compiler, connection)
317+
return {"$arrayElemAt": [lhs_mql, self.index]}
318+
319+
@property
320+
def output_field(self):
321+
return self.base_field
322+
323+
324+
class IndexTransformFactory:
325+
def __init__(self, index, base_field):
326+
self.index = index
327+
self.base_field = base_field
328+
329+
def __call__(self, *args, **kwargs):
330+
return IndexTransform(self.index, self.base_field, *args, **kwargs)
331+
332+
333+
class SliceTransform(Transform):
334+
def __init__(self, start, end, *args, **kwargs):
335+
super().__init__(*args, **kwargs)
336+
self.start = start
337+
self.end = end
338+
339+
def as_mql(self, compiler, connection):
340+
lhs_mql = process_lhs(self, compiler, connection)
341+
return {"$slice": [lhs_mql, self.start, self.end]}
342+
343+
344+
class SliceTransformFactory:
345+
def __init__(self, start, end):
346+
self.start = start
347+
self.end = end
348+
349+
def __call__(self, *args, **kwargs):
350+
return SliceTransform(self.start, self.end, *args, **kwargs)

0 commit comments

Comments
 (0)