Skip to content

Commit cab6aeb

Browse files
Add a test for missing generics in stubs (#2659)
1 parent e2ef0d8 commit cab6aeb

File tree

4 files changed

+137
-7
lines changed

4 files changed

+137
-7
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ jobs:
6868
6969
# Must match `shard` definition in the test matrix:
7070
- name: Run pytest tests
71-
run: PYTHONPATH='.' pytest --num-shards=4 --shard-id=${{ matrix.shard }} -n auto tests
71+
run: PYTHONPATH='.' pytest --num-shards=4 --shard-id=${{ matrix.shard }} -n auto tests --durations=0
7272
- name: Run mypy on the test cases
7373
run: mypy --strict tests
7474

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,11 @@ This happens because these Django classes do not support [`__class_getitem__`](h
164164

165165
You can add extra types to patch with `django_stubs_ext.monkeypatch(extra_classes=[YourDesiredType])`
166166

167+
**If you use generic symbols in `django.contrib.auth.forms`**, you will have to do the monkeypatching
168+
again in your first [`AppConfig.ready`](https://docs.djangoproject.com/en/5.2/ref/applications/#django.apps.AppConfig.ready).
169+
This is currently required because `django.contrib.auth.forms` cannot be imported until django is initialized.
170+
171+
167172
2. You can use strings instead: `'QuerySet[MyModel]'` and `'Manager[MyModel]'`, this way it will work as a type for `mypy` and as a regular `str` in runtime.
168173

169174
### How can I create a HttpRequest that's guaranteed to have an authenticated user?

ext/django_stubs_ext/patch.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import builtins
2+
import logging
23
from collections.abc import Iterable
34
from typing import Any, Generic, TypeVar
45

@@ -8,24 +9,33 @@
89
from django.contrib.messages.views import SuccessMessageMixin
910
from django.contrib.sitemaps import Sitemap
1011
from django.contrib.syndication.views import Feed
12+
from django.core.exceptions import AppRegistryNotReady, ImproperlyConfigured
1113
from django.core.files.utils import FileProxyMixin
1214
from django.core.paginator import Paginator
1315
from django.db.models.expressions import ExpressionWrapper
1416
from django.db.models.fields import Field
1517
from django.db.models.fields.related import ForeignKey
16-
from django.db.models.fields.related_descriptors import ReverseManyToOneDescriptor
18+
from django.db.models.fields.related_descriptors import (
19+
ForwardManyToOneDescriptor,
20+
ReverseManyToOneDescriptor,
21+
ReverseOneToOneDescriptor,
22+
)
1723
from django.db.models.lookups import Lookup
1824
from django.db.models.manager import BaseManager
19-
from django.db.models.query import ModelIterable, QuerySet, RawQuerySet
25+
from django.db.models.options import Options
26+
from django.db.models.query import BaseIterable, ModelIterable, QuerySet, RawQuerySet
2027
from django.forms.formsets import BaseFormSet
21-
from django.forms.models import BaseModelForm, BaseModelFormSet, ModelChoiceField
22-
from django.utils.connection import BaseConnectionHandler
28+
from django.forms.models import BaseModelForm, BaseModelFormSet, ModelChoiceField, ModelFormOptions
29+
from django.utils.connection import BaseConnectionHandler, ConnectionProxy
30+
from django.utils.functional import classproperty
2331
from django.views.generic.detail import SingleObjectMixin
2432
from django.views.generic.edit import DeletionMixin, FormMixin
2533
from django.views.generic.list import MultipleObjectMixin
2634

2735
__all__ = ["monkeypatch"]
2836

37+
logger = logging.getLogger(__name__)
38+
2939
_T = TypeVar("_T")
3040
_VersionSpec = tuple[int, int]
3141

@@ -81,16 +91,41 @@ def __repr__(self) -> str:
8191
# These types do have native `__class_getitem__` method since django 4.1:
8292
MPGeneric(ForeignKey, (4, 1)),
8393
MPGeneric(RawQuerySet),
94+
MPGeneric(classproperty),
95+
MPGeneric(ConnectionProxy),
96+
MPGeneric(ModelFormOptions),
97+
MPGeneric(Options),
98+
MPGeneric(BaseIterable),
99+
MPGeneric(ForwardManyToOneDescriptor),
100+
MPGeneric(ReverseOneToOneDescriptor),
84101
]
85102

86103

104+
def _get_need_generic() -> list[MPGeneric[Any]]:
105+
try:
106+
if VERSION >= (5, 1):
107+
from django.contrib.auth.forms import SetPasswordMixin, SetUnusablePasswordMixin
108+
109+
return [MPGeneric(SetPasswordMixin), MPGeneric(SetUnusablePasswordMixin), *_need_generic]
110+
else:
111+
from django.contrib.auth.forms import AdminPasswordChangeForm, SetPasswordForm
112+
113+
return [MPGeneric(SetPasswordForm), MPGeneric(AdminPasswordChangeForm), *_need_generic]
114+
115+
except (ImproperlyConfigured, AppRegistryNotReady):
116+
# We cannot patch symbols in `django.contrib.auth.forms` if the `monkeypatch()` call
117+
# is in the settings file because django is not initialized yet.
118+
# To solve this, you'll have to call `monkeypatch()` again later, in an `AppConfig.ready` for ex.
119+
# See https://docs.djangoproject.com/en/5.2/ref/applications/#django.apps.AppConfig.ready
120+
return _need_generic
121+
122+
87123
def monkeypatch(extra_classes: Iterable[type] | None = None, include_builtins: bool = True) -> None:
88124
"""Monkey patch django as necessary to work properly with mypy."""
89-
90125
# Add the __class_getitem__ dunder.
91126
suited_for_this_version = filter(
92127
lambda spec: spec.version is None or VERSION[:2] <= spec.version,
93-
_need_generic,
128+
_get_need_generic(),
94129
)
95130
for el in suited_for_this_version:
96131
el.cls.__class_getitem__ = classmethod(lambda cls, *args, **kwargs: cls)

tests/test_generic_consistency.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import ast
2+
import glob
3+
import importlib
4+
import os
5+
from typing import final
6+
from unittest import mock
7+
8+
import django
9+
10+
# The root directory of the django-stubs package
11+
STUBS_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "django-stubs"))
12+
13+
14+
@final
15+
class GenericInheritanceVisitor(ast.NodeVisitor):
16+
"""AST visitor to find classes inheriting from `typing.Generic` in stubs."""
17+
18+
def __init__(self) -> None:
19+
self.generic_classes: set[str] = set()
20+
21+
def visit_ClassDef(self, node: ast.ClassDef) -> None:
22+
for base in node.bases:
23+
if (
24+
isinstance(base, ast.Subscript)
25+
and isinstance(base.value, ast.Name)
26+
and base.value.id == "Generic"
27+
and not any(dec.id == "type_check_only" for dec in node.decorator_list if isinstance(dec, ast.Name))
28+
):
29+
self.generic_classes.add(node.name)
30+
break
31+
self.generic_visit(node)
32+
33+
34+
def test_find_classes_inheriting_from_generic() -> None:
35+
"""
36+
This test ensures that the `ext/django_stubs_ext/patch.py` stays up-to-date with the stubs.
37+
It works as follows:
38+
1. Parse the ast of each .pyi file, and collects classes inheriting from Generic.
39+
2. For each Generic in the stubs, import the associated module and capture every class in the MRO
40+
3. Ensure that at least one class in the mro is patched in `ext/django_stubs_ext/patch.py`.
41+
"""
42+
with mock.patch.dict(os.environ, {"DJANGO_SETTINGS_MODULE": "scripts.django_tests_settings"}):
43+
# We need this to be able to do django import
44+
django.setup()
45+
46+
# A dict of class_name -> [subclasses names] for each Generic in the stubs.
47+
all_generic_classes: dict[str, list[str]] = {}
48+
49+
print(f"Searching for classes inheriting from Generic in: {STUBS_ROOT}")
50+
pyi_files = glob.glob("**/*.pyi", root_dir=STUBS_ROOT, recursive=True)
51+
for file_path in pyi_files:
52+
with open(os.path.join(STUBS_ROOT, file_path)) as f:
53+
source = f.read()
54+
55+
tree = ast.parse(source)
56+
generic_visitor = GenericInheritanceVisitor()
57+
generic_visitor.visit(tree)
58+
59+
# For each Generic in the stubs, import the associated module and capture every class in the MRO
60+
if generic_visitor.generic_classes:
61+
module_name = _get_module_from_pyi(file_path)
62+
django_module = importlib.import_module(module_name)
63+
all_generic_classes.update(
64+
{
65+
cls: [subcls.__name__ for subcls in getattr(django_module, cls).mro()[1:-1]]
66+
for cls in generic_visitor.generic_classes
67+
}
68+
)
69+
70+
print(f"Processed {len(pyi_files)} .pyi files.")
71+
print(f"Found {len(all_generic_classes)} unique classes inheriting from Generic in stubs")
72+
73+
# Class patched in `ext/django_stubs_ext/patch.py`
74+
import django_stubs_ext
75+
76+
patched_classes = {mp_generic.cls.__name__ for mp_generic in django_stubs_ext.patch._get_need_generic()}
77+
78+
# Pretty-print missing patch in `ext/django_stubs_ext/patch.py`
79+
errors = []
80+
for cls_name, subcls_names in all_generic_classes.items():
81+
if not any(name in patched_classes for name in [*subcls_names, cls_name]):
82+
bases = f"({', '.join(subcls_names)})" if subcls_names else ""
83+
errors.append(f"{cls_name}{bases} is not patched in `ext/django_stubs_ext/patch.py`")
84+
85+
assert not errors, "\n".join(errors)
86+
87+
88+
def _get_module_from_pyi(pyi_path: str) -> str:
89+
py_module = "django." + pyi_path.replace(".pyi", "").replace("/", ".")
90+
return py_module.removesuffix(".__init__")

0 commit comments

Comments
 (0)