Skip to content

Add EmbeddedModelArrayField #292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions django_mongodb_backend/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
from .auto import ObjectIdAutoField
from .duration import register_duration_field
from .embedded_model import EmbeddedModelField
from .embedded_model_array import EmbeddedModelArrayField
from .json import register_json_field
from .objectid import ObjectIdField

__all__ = [
"register_fields",
"ArrayField",
"EmbeddedModelArrayField",
"EmbeddedModelField",
"ObjectIdAutoField",
"ObjectIdField",
Expand Down
62 changes: 62 additions & 0 deletions django_mongodb_backend/fields/embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from django.core import checks
from django.core.exceptions import FieldDoesNotExist
from django.db import models
from django.db.models import lookups
from django.db.models.expressions import Col
from django.db.models.fields.related import lazy_related_operation
from django.db.models.lookups import Transform

from .. import forms
from ..query_utils import process_lhs, process_rhs


class EmbeddedModelField(models.Field):
Expand Down Expand Up @@ -148,6 +151,65 @@ def formfield(self, **kwargs):
)


@EmbeddedModelField.register_lookup
class EMFExact(lookups.Exact):
def model_to_dict(self, instance):
"""
Return a dict containing the data in a model instance, as well as a
dict containing the data for any embedded model fields.
"""
data = {}
emf_data = {}
for f in instance._meta.concrete_fields:
value = f.value_from_object(instance)
if isinstance(f, EmbeddedModelField):
emf_data[f.name] = self.model_to_dict(value) if value is not None else (None, {})
continue
# Unless explicitly set, primary keys aren't included in embedded
# models.
if f.primary_key and value is None:
continue
data[f.name] = value
return data, emf_data

def get_conditions(self, emf_data, prefix=None):
"""
Recursively transform a dictionary of {"field_name": {<model_to_dict>}}
lookups into MQL. `prefix` tracks the string that must be appended to
nested fields.
"""
conditions = []
for k, v in emf_data.items():
v, emf_data = v
subprefix = f"{prefix}.{k}" if prefix else k
conditions += self.get_conditions(emf_data, subprefix)
if v is not None:
# Match all field of the EmbeddedModelField.
conditions += [{"$eq": [f"{subprefix}.{x}", y]} for x, y in v.items()]
else:
# Match a null EmbeddedModelField.
conditions += [{"$eq": [f"{subprefix}", None]}]
return conditions

def as_mql(self, compiler, connection):
lhs_mql = process_lhs(self, compiler, connection)
value = process_rhs(self, compiler, connection)
if isinstance(self.lhs, Col) or (
isinstance(self.lhs, KeyTransform)
and isinstance(self.lhs.ref_field, EmbeddedModelField)
):
if isinstance(value, models.Model):
value, emf_data = self.model_to_dict(value)
# Get conditions for any nested EmbeddedModelFields.
conditions = self.get_conditions({lhs_mql: (value, emf_data)})
return {"$and": conditions}
raise TypeError(
"An EmbeddedModelField must be queried using a model instance, got %s."
% type(value)
)
return connection.mongo_operators[self.lookup_name](lhs_mql, value)


class KeyTransform(Transform):
def __init__(self, key_name, ref_field, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down
54 changes: 54 additions & 0 deletions django_mongodb_backend/fields/embedded_model_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from ..forms import EmbeddedModelArrayFormField
from . import EmbeddedModelField
from .array import ArrayField
from .embedded_model import EMFExact


class EmbeddedModelArrayField(ArrayField):
def __init__(self, model, **kwargs):
super().__init__(EmbeddedModelField(model), **kwargs)

def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if path == "django_mongodb_backend.fields.multiple_embedded_model.EmbeddedModelArrayField":
path = "django_mongodb_backend.fields.EmbeddedModelArrayField"
kwargs.update(
{
"model": self.base_field.embedded_model,
"size": self.size,
}
)
del kwargs["base_field"]
return name, path, args, kwargs

def get_db_prep_value(self, value, connection, prepared=False):
if isinstance(value, list | tuple):
return [self.base_field.get_db_prep_save(i, connection) for i in value]
return value

def formfield(self, **kwargs):
return super().formfield(
**{
"form_class": EmbeddedModelArrayFormField,
"model": self.base_field.embedded_model,
"max_length": self.size,
"prefix": self.name,
**kwargs,
}
)

def get_transform(self, name):
# TODO: ...
return self.base_field.get_transform(name)
# Copied from EmbedddedModelField -- customize?
# transform = super().get_transform(name)
# if transform:
# return transform
# field = self.embedded_model._meta.get_field(name)
# return KeyTransformFactory(name, field)


@EmbeddedModelArrayField.register_lookup
class EMFArrayExact(EMFExact):
# TODO
pass
2 changes: 2 additions & 0 deletions django_mongodb_backend/forms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .fields import (
EmbeddedModelArrayFormField,
EmbeddedModelField,
ObjectIdField,
SimpleArrayField,
Expand All @@ -7,6 +8,7 @@
)

__all__ = [
"EmbeddedModelArrayFormField",
"EmbeddedModelField",
"SimpleArrayField",
"SplitArrayField",
Expand Down
2 changes: 2 additions & 0 deletions django_mongodb_backend/forms/fields/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from .array import SimpleArrayField, SplitArrayField, SplitArrayWidget
from .embedded_model import EmbeddedModelField
from .embedded_model_array import EmbeddedModelArrayFormField
from .objectid import ObjectIdField

__all__ = [
"EmbeddedModelArrayFormField",
"EmbeddedModelField",
"SimpleArrayField",
"SplitArrayField",
Expand Down
76 changes: 76 additions & 0 deletions django_mongodb_backend/forms/fields/embedded_model_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from django import forms
from django.core.exceptions import ValidationError
from django.forms import formset_factory, model_to_dict
from django.forms.models import modelform_factory
from django.utils.html import format_html, format_html_join


def models_to_dicts(models):
"""
Convert initial data (which is a list of model instances or None) to a
list of dictionary data suitable for a formset.
"""
return [model_to_dict(model) for model in models or []]


class EmbeddedModelArrayFormField(forms.Field):
def __init__(self, model, prefix, max_length=None, *args, **kwargs):
kwargs.pop("base_field")
self.model = model
self.prefix = prefix
self.formset = formset_factory(
form=modelform_factory(model, fields="__all__"),
can_delete=True,
max_num=max_length,
)
kwargs["widget"] = MultipleEmbeddedModelWidget()
super().__init__(*args, **kwargs)

def clean(self, value):
if not value:
# TODO: null or empty list?
return []
formset = self.formset(value, prefix=self.prefix)
if not formset.is_valid():
raise ValidationError(formset.errors)
cleaned_data = []
for data in formset.cleaned_data:
# The fallback to True skips empty forms.
if data.get("DELETE", True):
continue
data.pop("DELETE") # The "delete" checkbox isn't part of model data.
cleaned_data.append(self.model(**data))
return cleaned_data

def has_changed(self, initial, data):
formset = self.formset(data, initial=models_to_dicts(initial), prefix=self.prefix)
return formset.has_changed()

def get_bound_field(self, form, field_name):
return MultipleEmbeddedModelBoundField(form, self, field_name)


class MultipleEmbeddedModelBoundField(forms.BoundField):
def __init__(self, form, field, name):
super().__init__(form, field, name)
self.formset = field.formset(
self.data if form.is_bound else None,
initial=models_to_dicts(self.initial),
prefix=self.html_name,
)

def __str__(self):
body = format_html_join(
"\n", "<tbody>{}</tbody>", ((form.as_table(),) for form in self.formset)
)
return format_html("<table>\n{}\n</table>\n{}", body, self.formset.management_form)


class MultipleEmbeddedModelWidget(forms.Widget):
"""
This widget extracts the data for EmbeddedModelArrayFormField's formset.
It is never rendered.
"""

def value_from_datadict(self, data, files, name):
return {key: data[key] for key in data if key.startswith(f"{name}-")}
7 changes: 7 additions & 0 deletions docs/source/releases/5.2.x.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ Initial release from the state of :ref:`django-mongodb-backend 5.1.0 beta 2
Regarding new features in Django 5.2,
:class:`~django.db.models.CompositePrimaryKey` isn't supported.

New features
------------

*These features won't appear in Django MongoDB Backend 5.1.x.*

- Allowed ``EmbeddedModelField``’s ``exact`` lookup to use a model instance.

Bug fixes
---------

Expand Down
13 changes: 13 additions & 0 deletions docs/source/topics/embedded-models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,16 @@ as relational fields. For example, to retrieve all customers who have an
address with the city "New York"::

>>> Customer.objects.filter(address__city="New York")

You can also query using a model instance. Unlike a normal relational lookup
which does the lookup by primary key, since embedded models typically don't
have a primary key set, the query requires that every field match. For example,
this query gives customers with addresses with the city "New York" and all
other fields of the address equal to their default (:attr:`Field.default
<django.db.models.Field.default>`, ``None``, or an empty string).

>>> Customer.objects.filter(address=Address(city="New York"))

.. versionadded:: 5.2.0b0

The ability to query by model instance was added.
52 changes: 51 additions & 1 deletion tests/model_fields_/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

from django.db import models

from django_mongodb_backend.fields import ArrayField, EmbeddedModelField, ObjectIdField
from django_mongodb_backend.fields import (
ArrayField,
EmbeddedModelArrayField,
EmbeddedModelField,
ObjectIdField,
)
from django_mongodb_backend.models import EmbeddedModel


Expand Down Expand Up @@ -132,3 +137,48 @@ class Library(models.Model):

def __str__(self):
return self.name


class A(models.Model):
b = EmbeddedModelField("B")


class B(EmbeddedModel):
c = EmbeddedModelField("C")
name = models.CharField(max_length=100)
value = models.IntegerField()


class C(EmbeddedModel):
d = EmbeddedModelField("D")
name = models.CharField(max_length=100)
value = models.IntegerField()


class D(EmbeddedModel):
e = EmbeddedModelField("E")
nullable_e = EmbeddedModelField("E", null=True, blank=True)
name = models.CharField(max_length=100)
value = models.IntegerField()


class E(EmbeddedModel):
name = models.CharField(max_length=100)
value = models.IntegerField()


# ArrayField + EmbeddedModelField
class Review(EmbeddedModel):
title = models.CharField(max_length=255)
rating = models.IntegerField()

def __str__(self):
return self.title


class Movie(models.Model):
title = models.CharField(max_length=255)
reviews = EmbeddedModelArrayField(Review, null=True)

def __str__(self):
return self.title
Loading
Loading