|
1 | 1 | import builtins
|
| 2 | +import logging |
2 | 3 | from collections.abc import Iterable
|
3 | 4 | from typing import Any, Generic, TypeVar
|
4 | 5 |
|
|
8 | 9 | from django.contrib.messages.views import SuccessMessageMixin
|
9 | 10 | from django.contrib.sitemaps import Sitemap
|
10 | 11 | from django.contrib.syndication.views import Feed
|
| 12 | +from django.core.exceptions import AppRegistryNotReady, ImproperlyConfigured |
11 | 13 | from django.core.files.utils import FileProxyMixin
|
12 | 14 | from django.core.paginator import Paginator
|
13 | 15 | from django.db.models.expressions import ExpressionWrapper
|
14 | 16 | from django.db.models.fields import Field
|
15 | 17 | from django.db.models.fields.related import ForeignKey
|
16 |
| -from django.db.models.fields.related_descriptors import ReverseManyToOneDescriptor |
| 18 | +from django.db.models.fields.related_descriptors import ( |
| 19 | + ForwardManyToOneDescriptor, |
| 20 | + ReverseManyToOneDescriptor, |
| 21 | + ReverseOneToOneDescriptor, |
| 22 | +) |
17 | 23 | from django.db.models.lookups import Lookup
|
18 | 24 | from django.db.models.manager import BaseManager
|
19 |
| -from django.db.models.query import ModelIterable, QuerySet, RawQuerySet |
| 25 | +from django.db.models.options import Options |
| 26 | +from django.db.models.query import BaseIterable, ModelIterable, QuerySet, RawQuerySet |
20 | 27 | from django.forms.formsets import BaseFormSet
|
21 |
| -from django.forms.models import BaseModelForm, BaseModelFormSet, ModelChoiceField |
22 |
| -from django.utils.connection import BaseConnectionHandler |
| 28 | +from django.forms.models import BaseModelForm, BaseModelFormSet, ModelChoiceField, ModelFormOptions |
| 29 | +from django.utils.connection import BaseConnectionHandler, ConnectionProxy |
| 30 | +from django.utils.functional import classproperty |
23 | 31 | from django.views.generic.detail import SingleObjectMixin
|
24 | 32 | from django.views.generic.edit import DeletionMixin, FormMixin
|
25 | 33 | from django.views.generic.list import MultipleObjectMixin
|
26 | 34 |
|
27 | 35 | __all__ = ["monkeypatch"]
|
28 | 36 |
|
| 37 | +logger = logging.getLogger(__name__) |
| 38 | + |
29 | 39 | _T = TypeVar("_T")
|
30 | 40 | _VersionSpec = tuple[int, int]
|
31 | 41 |
|
@@ -81,16 +91,41 @@ def __repr__(self) -> str:
|
81 | 91 | # These types do have native `__class_getitem__` method since django 4.1:
|
82 | 92 | MPGeneric(ForeignKey, (4, 1)),
|
83 | 93 | MPGeneric(RawQuerySet),
|
| 94 | + MPGeneric(classproperty), |
| 95 | + MPGeneric(ConnectionProxy), |
| 96 | + MPGeneric(ModelFormOptions), |
| 97 | + MPGeneric(Options), |
| 98 | + MPGeneric(BaseIterable), |
| 99 | + MPGeneric(ForwardManyToOneDescriptor), |
| 100 | + MPGeneric(ReverseOneToOneDescriptor), |
84 | 101 | ]
|
85 | 102 |
|
86 | 103 |
|
| 104 | +def _get_need_generic() -> list[MPGeneric[Any]]: |
| 105 | + try: |
| 106 | + if VERSION >= (5, 1): |
| 107 | + from django.contrib.auth.forms import SetPasswordMixin, SetUnusablePasswordMixin |
| 108 | + |
| 109 | + return [MPGeneric(SetPasswordMixin), MPGeneric(SetUnusablePasswordMixin), *_need_generic] |
| 110 | + else: |
| 111 | + from django.contrib.auth.forms import AdminPasswordChangeForm, SetPasswordForm |
| 112 | + |
| 113 | + return [MPGeneric(SetPasswordForm), MPGeneric(AdminPasswordChangeForm), *_need_generic] |
| 114 | + |
| 115 | + except (ImproperlyConfigured, AppRegistryNotReady): |
| 116 | + # We cannot patch symbols in `django.contrib.auth.forms` if the `monkeypatch()` call |
| 117 | + # is in the settings file because django is not initialized yet. |
| 118 | + # To solve this, you'll have to call `monkeypatch()` again later, in an `AppConfig.ready` for ex. |
| 119 | + # See https://docs.djangoproject.com/en/5.2/ref/applications/#django.apps.AppConfig.ready |
| 120 | + return _need_generic |
| 121 | + |
| 122 | + |
87 | 123 | def monkeypatch(extra_classes: Iterable[type] | None = None, include_builtins: bool = True) -> None:
|
88 | 124 | """Monkey patch django as necessary to work properly with mypy."""
|
89 |
| - |
90 | 125 | # Add the __class_getitem__ dunder.
|
91 | 126 | suited_for_this_version = filter(
|
92 | 127 | lambda spec: spec.version is None or VERSION[:2] <= spec.version,
|
93 |
| - _need_generic, |
| 128 | + _get_need_generic(), |
94 | 129 | )
|
95 | 130 | for el in suited_for_this_version:
|
96 | 131 | el.cls.__class_getitem__ = classmethod(lambda cls, *args, **kwargs: cls)
|
|
0 commit comments