Skip to content

Arrayfield support querying by EMF #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions django_mongodb_backend/fields/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,7 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr)
def as_mql(self, compiler, connection):
lhs_mql = process_lhs(self, compiler, connection)
value = process_rhs(self, compiler, connection)
return {
"$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}]
}
return {"$and": [{"$isArray": lhs_mql}, {"$size": {"$setIntersection": [value, lhs_mql]}}]}


@ArrayField.register_lookup
Expand All @@ -338,7 +336,7 @@ class ArrayLenTransform(Transform):

def as_mql(self, compiler, connection):
lhs_mql = process_lhs(self, compiler, connection)
return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$size": lhs_mql}}}
return {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}}


@ArrayField.register_lookup
Expand Down
7 changes: 5 additions & 2 deletions django_mongodb_backend/fields/embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,11 @@ def as_mql(self, compiler, connection):
key_transforms.insert(0, previous.key_name)
previous = previous.lhs
mql = previous.as_mql(compiler, connection)
transforms = ".".join(key_transforms)
return f"{mql}.{transforms}"
# transform = ".".join(key_transforms)
for key in key_transforms:
mql = {"$getField": {"input": mql, "field": key}}
return mql
# return f"{mql}.{transform}"

@property
def output_field(self):
Expand Down
195 changes: 195 additions & 0 deletions django_mongodb_backend/fields/embedded_model_array.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import difflib

from django.core.exceptions import FieldDoesNotExist
from django.db import models
from django.db.models import Field
from django.db.models.expressions import Col
from django.db.models.lookups import Lookup, Transform

from .. import forms
from ..query_utils import process_lhs, process_rhs
from . import EmbeddedModelField
from .array import ArrayField
from .embedded_model import EMFExact, EMFMixin


class EmbeddedModelArrayField(ArrayField):
Expand Down Expand Up @@ -44,3 +52,190 @@ def formfield(self, **kwargs):
**kwargs,
},
)

def get_transform(self, name):
transform = super().get_transform(name)
if transform:
return transform
return KeyTransformFactory(name, self)


@EmbeddedModelArrayField.register_lookup
class EMFArrayExact(EMFExact):
def as_mql(self, compiler, connection):
lhs_mql = process_lhs(self, compiler, connection)
value = process_rhs(self, compiler, connection)
if isinstance(self.lhs, KeyTransform):
lhs_mql, inner_lhs_mql = lhs_mql
else:
inner_lhs_mql = "$$item"
if isinstance(value, models.Model):
value, emf_data = self.model_to_dict(value)
# Get conditions for any nested EmbeddedModelFields.
conditions = self.get_conditions({inner_lhs_mql: (value, emf_data)})
return {
"$anyElementTrue": {
"$ifNull": [
{
"$map": {
"input": lhs_mql,
"as": "item",
"in": {"$and": conditions},
}
},
[],
]
}
}
return {
"$anyElementTrue": {
"$ifNull": [
{
"$map": {
"input": lhs_mql,
"as": "item",
"in": {"$eq": [inner_lhs_mql, value]},
}
},
[],
]
}
}


@EmbeddedModelArrayField.register_lookup
class ArrayOverlap(EMFMixin, Lookup):
lookup_name = "overlap"
get_db_prep_lookup_value_is_iterable = True

def process_rhs(self, compiler, connection):
values = self.rhs
if self.get_db_prep_lookup_value_is_iterable:
values = [values]
# Compute how to serialize each value based on the query target.
# If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
# field of the subfield. Otherwise, use the base field of the array itself.
if isinstance(self.lhs, KeyTransform):
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
else:
get_db_prep_value = self.lhs.output_field.base_field.get_db_prep_value
return None, [get_db_prep_value(v, connection, prepared=True) for v in values]

def as_mql(self, compiler, connection):
lhs_mql = process_lhs(self, compiler, connection)
values = process_rhs(self, compiler, connection)
# Querying a subfield within the array elements (via nested KeyTransform).
# Replicates MongoDB's implicit ANY-match by mapping over the array and applying
# `$in` on the subfield.
if isinstance(self.lhs, KeyTransform):
lhs_mql, inner_lhs_mql = lhs_mql
return {
"$anyElementTrue": {
"$ifNull": [
{
"$map": {
"input": lhs_mql,
"as": "item",
"in": {"$in": [inner_lhs_mql, values]},
}
},
[],
]
}
}
conditions = []
inner_lhs_mql = "$$item"
# Querying full embedded documents in the array.
# Builds `$or` conditions and maps them over the array to match any full document.
for value in values:
value, emf_data = self.model_to_dict(value)
# Get conditions for any nested EmbeddedModelFields.
conditions.append({"$and": self.get_conditions({inner_lhs_mql: (value, emf_data)})})
return {
"$anyElementTrue": {
"$ifNull": [
{
"$map": {
"input": lhs_mql,
"as": "item",
"in": {"$or": conditions},
}
},
[],
]
}
}


class KeyTransform(Transform):
# it should be different class than EMF keytransform even most of the methods are equal.
def __init__(self, key_name, array_field, *args, **kwargs):
super().__init__(*args, **kwargs)
self.array_field = array_field
self.key_name = key_name
# The iteration items begins from the base_field, a virtual column with
# base field output type is created.
column_target = array_field.base_field.embedded_model._meta.get_field(key_name).clone()
column_name = f"$item.{key_name}"
column_target.db_column = column_name
column_target.set_attributes_from_name(column_name)
self._lhs = Col(None, column_target)
self._sub_transform = None

def __call__(self, this, *args, **kwargs):
self._lhs = self._sub_transform(self._lhs, *args, **kwargs)
return self

def get_lookup(self, name):
return self.output_field.get_lookup(name)

def _get_missing_field_or_lookup_exception(self, lhs, name):
suggested_lookups = difflib.get_close_matches(name, lhs.get_lookups())
if suggested_lookups:
suggested_lookups = " or ".join(suggested_lookups)
suggestion = f", perhaps you meant {suggested_lookups}?"
else:
suggestion = "."
raise FieldDoesNotExist(
f"Unsupported lookup '{name}' for "
f"{self.array_field.base_field.__class__.__name__} '{self.array_field.base_field.name}'"
f"{suggestion}"
)

def get_transform(self, name):
"""
Validate that `name` is either a field of an embedded model or a
lookup on an embedded model's field.
"""
# Once the sub lhs is a transform, all the filter are applied over it.
transform = (
self._lhs.get_transform(name)
if isinstance(self._lhs, Transform)
else self.array_field.base_field.embedded_model._meta.get_field(
self.key_name
).get_transform(name)
)
if transform:
self._sub_transform = transform
return self
raise self._get_missing_field_or_lookup_exception(
self._lhs if isinstance(self._lhs, Transform) else self.base_field, name
)

def as_mql(self, compiler, connection):
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
lhs_mql = process_lhs(self, compiler, connection)
return lhs_mql, inner_lhs_mql

@property
def output_field(self):
return self.array_field


class KeyTransformFactory:
def __init__(self, key_name, base_field):
self.key_name = key_name
self.base_field = base_field

def __call__(self, *args, **kwargs):
return KeyTransform(self.key_name, self.base_field, *args, **kwargs)
34 changes: 34 additions & 0 deletions tests/model_fields_/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,37 @@ class Movie(models.Model):

def __str__(self):
return self.title


class RestorationRecord(EmbeddedModel):
date = models.DateField()
description = models.TextField()
restored_by = models.CharField(max_length=255)


class ArtifactDetail(EmbeddedModel):
"""Details about a specific artifact."""

name = models.CharField(max_length=255)
description = models.CharField(max_length=255)
metadata = models.JSONField()
restorations = EmbeddedModelArrayField(RestorationRecord, null=True)
last_restoration = EmbeddedModelField(RestorationRecord, null=True)


class ExhibitSection(EmbeddedModel):
"""A section within an exhibit, containing multiple artifacts."""

section_number = models.IntegerField()
artifacts = EmbeddedModelArrayField(ArtifactDetail, null=True)


class MuseumExhibit(models.Model):
"""An exhibit in the museum, composed of multiple sections."""

exhibit_name = models.CharField(max_length=255)
sections = EmbeddedModelArrayField(ExhibitSection, null=True)
main_section = EmbeddedModelField(ExhibitSection, null=True)

def __str__(self):
return self.exhibit_name
Loading
Loading