Skip to content

Commit 84a085c

Browse files
authored
Add better support for ManyToManyField's through model (#1719)
1 parent af38823 commit 84a085c

File tree

13 files changed

+678
-64
lines changed

13 files changed

+678
-64
lines changed

django-stubs/db/models/fields/related.pyi

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Callable, Iterable, Sequence
2-
from typing import Any, Literal, TypeVar, overload
2+
from typing import Any, Generic, Literal, TypeVar, overload
33
from uuid import UUID
44

55
from django.core import validators # due to weird mypy.stubtest error
@@ -11,14 +11,14 @@ from django.db.models.fields.related_descriptors import ForwardManyToOneDescript
1111
from django.db.models.fields.related_descriptors import ( # noqa: F401
1212
ForwardOneToOneDescriptor as ForwardOneToOneDescriptor,
1313
)
14+
from django.db.models.fields.related_descriptors import ManyRelatedManager
1415
from django.db.models.fields.related_descriptors import ManyToManyDescriptor as ManyToManyDescriptor
1516
from django.db.models.fields.related_descriptors import ReverseManyToOneDescriptor as ReverseManyToOneDescriptor
1617
from django.db.models.fields.related_descriptors import ReverseOneToOneDescriptor as ReverseOneToOneDescriptor
1718
from django.db.models.fields.reverse_related import ForeignObjectRel as ForeignObjectRel # noqa: F401
1819
from django.db.models.fields.reverse_related import ManyToManyRel as ManyToManyRel
1920
from django.db.models.fields.reverse_related import ManyToOneRel as ManyToOneRel
2021
from django.db.models.fields.reverse_related import OneToOneRel as OneToOneRel
21-
from django.db.models.manager import RelatedManager
2222
from django.db.models.query_utils import FilteredRelation, PathInfo, Q
2323
from django.utils.functional import _StrOrPromise
2424
from typing_extensions import Self
@@ -27,6 +27,7 @@ RECURSIVE_RELATIONSHIP_CONSTANT: Literal["self"]
2727

2828
def resolve_relation(scope_model: type[Model], relation: str | type[Model]) -> str | type[Model]: ...
2929

30+
_M = TypeVar("_M", bound=Model)
3031
# __set__ value type
3132
_ST = TypeVar("_ST")
3233
# __get__ return type
@@ -204,10 +205,9 @@ class OneToOneField(ForeignKey[_ST, _GT]):
204205
@overload
205206
def __get__(self, instance: Any, owner: Any) -> Self: ...
206207

207-
class ManyToManyField(RelatedField[_ST, _GT]):
208-
_pyi_private_set_type: Sequence[Any]
209-
_pyi_private_get_type: RelatedManager[Any]
208+
_To = TypeVar("_To", bound=Model)
210209

210+
class ManyToManyField(RelatedField[Any, Any], Generic[_To, _M]):
211211
description: str
212212
has_null_arg: bool
213213
swappable: bool
@@ -221,12 +221,12 @@ class ManyToManyField(RelatedField[_ST, _GT]):
221221
rel_class: type[ManyToManyRel]
222222
def __init__(
223223
self,
224-
to: type[Model] | str,
224+
to: type[_To] | str,
225225
related_name: str | None = ...,
226226
related_query_name: str | None = ...,
227227
limit_choices_to: _AllLimitChoicesTo | None = ...,
228228
symmetrical: bool | None = ...,
229-
through: str | type[Model] | None = ...,
229+
through: type[_M] | str | None = ...,
230230
through_fields: tuple[str, str] | None = ...,
231231
db_constraint: bool = ...,
232232
db_table: str | None = ...,
@@ -255,10 +255,10 @@ class ManyToManyField(RelatedField[_ST, _GT]):
255255
) -> None: ...
256256
# class access
257257
@overload
258-
def __get__(self, instance: None, owner: Any) -> ManyToManyDescriptor[Self]: ...
258+
def __get__(self, instance: None, owner: Any) -> ManyToManyDescriptor[_M]: ...
259259
# Model instance access
260260
@overload
261-
def __get__(self, instance: Model, owner: Any) -> _GT: ...
261+
def __get__(self, instance: Model, owner: Any) -> ManyRelatedManager[_To]: ...
262262
# non-Model instances
263263
@overload
264264
def __get__(self, instance: Any, owner: Any) -> Self: ...

django-stubs/db/models/fields/related_descriptors.pyi

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1-
from collections.abc import Callable
2-
from typing import Any, Generic, TypeVar, overload
1+
from collections.abc import Callable, Iterable
2+
from typing import Any, Generic, NoReturn, TypeVar, overload
33

44
from django.core.exceptions import ObjectDoesNotExist
55
from django.db.models.base import Model
66
from django.db.models.fields import Field
7-
from django.db.models.fields.related import ForeignKey, RelatedField
7+
from django.db.models.fields.related import ForeignKey, ManyToManyField, RelatedField
88
from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel
9-
from django.db.models.manager import RelatedManager
9+
from django.db.models.manager import BaseManager, RelatedManager
1010
from django.db.models.query import QuerySet
1111
from django.db.models.query_utils import DeferredAttribute
12+
from typing_extensions import Self
1213

13-
_T = TypeVar("_T")
14+
_M = TypeVar("_M", bound=Model)
1415
_F = TypeVar("_F", bound=Field)
1516
_From = TypeVar("_From", bound=Model)
1617
_To = TypeVar("_To", bound=Model)
@@ -65,28 +66,63 @@ class ReverseOneToOneDescriptor(Generic[_From, _To]):
6566
def __reduce__(self) -> tuple[Callable[..., Any], tuple[type[_To], str]]: ...
6667

6768
class ReverseManyToOneDescriptor:
69+
"""
70+
In the example::
71+
72+
class Child(Model):
73+
parent = ForeignKey(Parent, related_name='children')
74+
75+
``Parent.children`` is a ``ReverseManyToOneDescriptor`` instance.
76+
"""
77+
6878
rel: ManyToOneRel
6979
field: ForeignKey
7080
def __init__(self, rel: ManyToOneRel) -> None: ...
7181
@property
72-
def related_manager_cls(self) -> type[RelatedManager]: ...
73-
def __get__(self, instance: Model | None, cls: type[Model] | None = ...) -> ReverseManyToOneDescriptor: ...
74-
def __set__(self, instance: Model, value: list[Model]) -> Any: ...
82+
def related_manager_cls(self) -> type[RelatedManager[Any]]: ...
83+
@overload
84+
def __get__(self, instance: None, cls: Any = ...) -> Self: ...
85+
@overload
86+
def __get__(self, instance: Model, cls: Any = ...) -> type[RelatedManager[Any]]: ...
87+
def __set__(self, instance: Any, value: Any) -> NoReturn: ...
88+
89+
def create_reverse_many_to_one_manager(
90+
superclass: type[BaseManager[_M]], rel: ManyToOneRel
91+
) -> type[RelatedManager[_M]]: ...
7592

76-
def create_reverse_many_to_one_manager(superclass: type, rel: Any) -> type[RelatedManager]: ...
93+
class ManyToManyDescriptor(ReverseManyToOneDescriptor, Generic[_M]):
94+
"""
95+
In the example::
96+
97+
class Pizza(Model):
98+
toppings = ManyToManyField(Topping, related_name='pizzas')
99+
100+
``Pizza.toppings`` and ``Topping.pizzas`` are ``ManyToManyDescriptor``
101+
instances.
102+
"""
77103

78-
class ManyToManyDescriptor(ReverseManyToOneDescriptor, Generic[_F]):
79-
field: _F # type: ignore[assignment]
104+
# 'field' here is 'rel.field'
80105
rel: ManyToManyRel # type: ignore[assignment]
106+
field: ManyToManyField[Any, _M] # type: ignore[assignment]
81107
reverse: bool
82108
def __init__(self, rel: ManyToManyRel, reverse: bool = ...) -> None: ...
83109
@property
84-
def through(self) -> type[Model]: ...
110+
def through(self) -> type[_M]: ...
85111
@property
86-
def related_manager_cls(self) -> type[Any]: ... # ManyRelatedManager
112+
def related_manager_cls(self) -> type[ManyRelatedManager[Any]]: ... # type: ignore[override]
87113

88-
# fake
89-
class _ForwardManyToManyManager(Generic[_T]):
90-
def all(self) -> QuerySet: ...
114+
class ManyRelatedManager(BaseManager[_M], Generic[_M]):
115+
related_val: tuple[int, ...]
116+
def add(self, *objs: _M | int, bulk: bool = ...) -> None: ...
117+
async def aadd(self, *objs: _M | int, bulk: bool = ...) -> None: ...
118+
def remove(self, *objs: _M | int, bulk: bool = ...) -> None: ...
119+
async def aremove(self, *objs: _M | int, bulk: bool = ...) -> None: ...
120+
def set(self, objs: QuerySet[_M] | Iterable[_M | int], *, bulk: bool = ..., clear: bool = ...) -> None: ...
121+
async def aset(self, objs: QuerySet[_M] | Iterable[_M | int], *, bulk: bool = ..., clear: bool = ...) -> None: ...
122+
def clear(self) -> None: ...
123+
async def aclear(self) -> None: ...
124+
def __call__(self, *, manager: str) -> ManyRelatedManager[_M]: ...
91125

92-
def create_forward_many_to_many_manager(superclass: type, rel: Any, reverse: Any) -> _ForwardManyToManyManager: ...
126+
def create_forward_many_to_many_manager(
127+
superclass: type[BaseManager[_M]], rel: ManyToManyRel, reverse: bool
128+
) -> type[ManyRelatedManager[_M]]: ...

django-stubs/db/models/fields/reverse_related.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,13 @@ class OneToOneRel(ManyToOneRel):
112112
) -> None: ...
113113

114114
class ManyToManyRel(ForeignObjectRel):
115-
field: ManyToManyField # type: ignore[assignment]
115+
field: ManyToManyField[Any, Any] # type: ignore[assignment]
116116
through: type[Model] | None
117117
through_fields: tuple[str, str] | None
118118
db_constraint: bool
119119
def __init__(
120120
self,
121-
field: ManyToManyField,
121+
field: ManyToManyField[Any, Any],
122122
to: type[Model] | str,
123123
related_name: str | None = ...,
124124
related_query_name: str | None = ...,

mypy_django_plugin/django/context.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,21 @@
33
from collections import defaultdict
44
from contextlib import contextmanager
55
from functools import cached_property
6-
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, Literal, Optional, Sequence, Set, Tuple, Type, Union
6+
from typing import (
7+
TYPE_CHECKING,
8+
Any,
9+
Dict,
10+
Iterable,
11+
Iterator,
12+
Literal,
13+
Mapping,
14+
Optional,
15+
Sequence,
16+
Set,
17+
Tuple,
18+
Type,
19+
Union,
20+
)
721

822
from django.core.exceptions import FieldDoesNotExist, FieldError
923
from django.db import models
@@ -270,6 +284,14 @@ def all_registered_model_classes(self) -> Set[Type[models.Model]]:
270284
def all_registered_model_class_fullnames(self) -> Set[str]:
271285
return {helpers.get_class_fullname(cls) for cls in self.all_registered_model_classes}
272286

287+
@cached_property
288+
def model_class_fullnames_by_label(self) -> Mapping[str, str]:
289+
return {
290+
klass._meta.label: helpers.get_class_fullname(klass)
291+
for klass in self.all_registered_model_classes
292+
if klass is not models.Model
293+
}
294+
273295
def get_field_nullability(self, field: Union["Field[Any, Any]", ForeignObjectRel], method: Optional[str]) -> bool:
274296
if method in ("values", "values_list"):
275297
return field.null

mypy_django_plugin/lib/fullnames.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
FOREIGN_OBJECT_FULLNAME,
3939
FOREIGN_KEY_FULLNAME,
4040
ONETOONE_FIELD_FULLNAME,
41-
MANYTOMANY_FIELD_FULLNAME,
4241
)
4342
)
4443

mypy_django_plugin/lib/helpers.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
from mypy.nodes import (
1111
GDEF,
1212
MDEF,
13+
AssignmentStmt,
1314
Block,
1415
ClassDef,
16+
Context,
1517
Expression,
1618
MemberExpr,
1719
MypyFile,
@@ -33,7 +35,8 @@
3335
SemanticAnalyzerPluginInterface,
3436
)
3537
from mypy.semanal import SemanticAnalyzer
36-
from mypy.types import AnyType, Instance, NoneTyp, TupleType, TypedDictType, TypeOfAny, UnionType
38+
from mypy.semanal_shared import parse_bool
39+
from mypy.types import AnyType, Instance, LiteralType, NoneTyp, TupleType, TypedDictType, TypeOfAny, UnionType
3740
from mypy.types import Type as MypyType
3841
from typing_extensions import TypedDict
3942

@@ -45,12 +48,14 @@
4548

4649

4750
class DjangoTypeMetadata(TypedDict, total=False):
51+
is_abstract_model: bool
4852
from_queryset_manager: str
4953
reverse_managers: Dict[str, str]
5054
baseform_bases: Dict[str, int]
5155
manager_bases: Dict[str, int]
5256
model_bases: Dict[str, int]
5357
queryset_bases: Dict[str, int]
58+
m2m_throughs: Dict[str, str]
5459

5560

5661
def get_django_metadata(model_info: TypeInfo) -> DjangoTypeMetadata:
@@ -385,3 +390,61 @@ def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) ->
385390
if sym is not None and isinstance(sym.node, TypeInfo):
386391
bases = get_django_metadata_bases(sym.node, "manager_bases")
387392
bases[fullname] = 1
393+
394+
395+
def is_abstract_model(model: TypeInfo) -> bool:
396+
if model.metaclass_type is None or model.metaclass_type.type.fullname != fullnames.MODEL_METACLASS_FULLNAME:
397+
return False
398+
399+
metadata = get_django_metadata(model)
400+
if metadata.get("is_abstract_model") is not None:
401+
return metadata["is_abstract_model"]
402+
403+
meta = model.names.get("Meta")
404+
# Check if 'abstract' is declared in this model's 'class Meta' as
405+
# 'abstract = True' won't be inherited from a parent model.
406+
if meta is not None and isinstance(meta.node, TypeInfo) and "abstract" in meta.node.names:
407+
for stmt in meta.node.defn.defs.body:
408+
if (
409+
# abstract =
410+
isinstance(stmt, AssignmentStmt)
411+
and len(stmt.lvalues) == 1
412+
and isinstance(stmt.lvalues[0], NameExpr)
413+
and stmt.lvalues[0].name == "abstract"
414+
):
415+
# abstract = True (builtins.bool)
416+
rhs_is_true = parse_bool(stmt.rvalue) is True
417+
# abstract: Literal[True]
418+
is_literal_true = isinstance(stmt.type, LiteralType) and stmt.type.value is True
419+
metadata["is_abstract_model"] = rhs_is_true or is_literal_true
420+
return metadata["is_abstract_model"]
421+
422+
metadata["is_abstract_model"] = False
423+
return False
424+
425+
426+
def resolve_lazy_reference(
427+
reference: str, *, api: Union[TypeChecker, SemanticAnalyzer], django_context: "DjangoContext", ctx: Context
428+
) -> Optional[TypeInfo]:
429+
"""
430+
Attempts to resolve a lazy reference(e.g. "<app_label>.<object_name>") to a
431+
'TypeInfo' instance.
432+
"""
433+
if "." not in reference:
434+
# <object_name> -- needs prefix of <app_label>. We can't implicitly solve
435+
# what app label this should be, yet.
436+
return None
437+
438+
# Reference conforms to the structure of a lazy reference: '<app_label>.<object_name>'
439+
fullname = django_context.model_class_fullnames_by_label.get(reference)
440+
if fullname is not None:
441+
model_info = lookup_fully_qualified_typeinfo(api, fullname)
442+
if model_info is not None:
443+
return model_info
444+
elif isinstance(api, SemanticAnalyzer) and not api.final_iteration:
445+
# Getting this far, where Django matched the reference but we still can't
446+
# find it, we want to defer
447+
api.defer()
448+
else:
449+
api.fail("Could not match lazy reference with any model", ctx)
450+
return None

mypy_django_plugin/transformers/fields.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from mypy_django_plugin.django.context import DjangoContext
1414
from mypy_django_plugin.exceptions import UnregisteredModelError
1515
from mypy_django_plugin.lib import fullnames, helpers
16+
from mypy_django_plugin.transformers import manytomany
1617

1718
if TYPE_CHECKING:
1819
from django.contrib.contenttypes.fields import GenericForeignKey
@@ -213,6 +214,10 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan
213214

214215
assert isinstance(outer_model_info, TypeInfo)
215216

217+
if default_return_type.type.has_base(fullnames.MANYTOMANY_FIELD_FULLNAME):
218+
return manytomany.fill_model_args_for_many_to_many_field(
219+
ctx=ctx, model_info=outer_model_info, default_return_type=default_return_type, django_context=django_context
220+
)
216221
if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES):
217222
return fill_descriptor_types_for_related_field(ctx, django_context)
218223

0 commit comments

Comments
 (0)