Skip to content

Commit 8da8ab4

Browse files
authored
Fix/673/from queryset then custom qs method (#680)
* Fix `MyModel.objects.filter(...).my_method()` * Fix regression: `MyModel.objects.filter(...).my_method()` no longer worked when using from_queryset This also fixes the self-type of the copied-over methods of the manager generated by from_queryset. Previously it was not parameterized by the model class, but used Any. The handling of unbound types is not tested here as I have not been able to find a way to create a test case for it. It has been manually tested against an internal codebase. * Remove unneeded defer.
1 parent 08a662e commit 8da8ab4

File tree

4 files changed

+129
-33
lines changed

4 files changed

+129
-33
lines changed

mypy_django_plugin/lib/helpers.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
MemberExpr,
1919
MypyFile,
2020
NameExpr,
21-
PlaceholderNode,
2221
StrExpr,
2322
SymbolNode,
2423
SymbolTable,
@@ -33,12 +32,13 @@
3332
DynamicClassDefContext,
3433
FunctionContext,
3534
MethodContext,
35+
SemanticAnalyzerPluginInterface,
3636
)
3737
from mypy.plugins.common import add_method
3838
from mypy.semanal import SemanticAnalyzer
3939
from mypy.types import AnyType, CallableType, Instance, NoneTyp, TupleType
4040
from mypy.types import Type as MypyType
41-
from mypy.types import TypedDictType, TypeOfAny, UnionType
41+
from mypy.types import TypedDictType, TypeOfAny, UnboundType, UnionType
4242

4343
from mypy_django_plugin.lib import fullnames
4444
from mypy_django_plugin.lib.fullnames import WITH_ANNOTATIONS_FULLNAME
@@ -355,8 +355,26 @@ def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument],
355355
return prepared_arguments, return_type
356356

357357

358+
def bind_or_analyze_type(t: MypyType, api: SemanticAnalyzer, module_name: Optional[str] = None) -> Optional[MypyType]:
359+
"""Analyze a type. If an unbound type, try to look it up in the given module name.
360+
361+
That should hopefully give a bound type."""
362+
if isinstance(t, UnboundType) and module_name is not None:
363+
node = api.lookup_fully_qualified_or_none(module_name + "." + t.name)
364+
if node is None:
365+
return None
366+
return node.type
367+
else:
368+
return api.anal_type(t)
369+
370+
358371
def copy_method_to_another_class(
359-
ctx: ClassDefContext, self_type: Instance, new_method_name: str, method_node: FuncDef
372+
ctx: ClassDefContext,
373+
self_type: Instance,
374+
new_method_name: str,
375+
method_node: FuncDef,
376+
return_type: Optional[MypyType] = None,
377+
original_module_name: Optional[str] = None,
360378
) -> None:
361379
semanal_api = get_semanal_api(ctx)
362380
if method_node.type is None:
@@ -374,23 +392,20 @@ def copy_method_to_another_class(
374392
semanal_api.defer()
375393
return
376394

377-
arguments = []
378-
bound_return_type = semanal_api.anal_type(method_type.ret_type, allow_placeholder=True)
379-
380-
assert bound_return_type is not None
381-
382-
if isinstance(bound_return_type, PlaceholderNode):
395+
if return_type is None:
396+
return_type = bind_or_analyze_type(method_type.ret_type, semanal_api, original_module_name)
397+
if return_type is None:
383398
return
384-
385399
try:
386400
original_arguments = method_node.arguments[1:]
387401
except AttributeError:
388402
original_arguments = []
389403

404+
arguments = []
390405
for arg_name, arg_type, original_argument in zip(
391406
method_type.arg_names[1:], method_type.arg_types[1:], original_arguments
392407
):
393-
bound_arg_type = semanal_api.anal_type(arg_type)
408+
bound_arg_type = bind_or_analyze_type(arg_type, semanal_api, original_module_name)
394409
if bound_arg_type is None:
395410
return
396411

@@ -406,4 +421,10 @@ def copy_method_to_another_class(
406421
argument.set_line(original_argument)
407422
arguments.append(argument)
408423

409-
add_method(ctx, new_method_name, args=arguments, return_type=bound_return_type, self_type=self_type)
424+
add_method(ctx, new_method_name, args=arguments, return_type=return_type, self_type=self_type)
425+
426+
427+
def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) -> None:
428+
sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
429+
if sym is not None and isinstance(sym.node, TypeInfo):
430+
get_django_metadata(sym.node)["manager_bases"][fullname] = 1

mypy_django_plugin/main.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,8 @@ def transform_form_class(ctx: ClassDefContext) -> None:
5353
forms.make_meta_nested_class_inherit_from_any(ctx)
5454

5555

56-
def add_new_manager_base(ctx: ClassDefContext) -> None:
57-
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
58-
if sym is not None and isinstance(sym.node, TypeInfo):
59-
helpers.get_django_metadata(sym.node)["manager_bases"][ctx.cls.fullname] = 1
56+
def add_new_manager_base_hook(ctx: ClassDefContext) -> None:
57+
helpers.add_new_manager_base(ctx.api, ctx.cls.fullname)
6058

6159

6260
def extract_django_settings_module(config_file_path: Optional[str]) -> str:
@@ -235,7 +233,12 @@ def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]:
235233
related_model_module = related_model_cls.__module__
236234
if related_model_module != file.fullname:
237235
deps.add(self._new_dependency(related_model_module))
238-
return list(deps) + [self._new_dependency("django_stubs_ext")] # for annotate
236+
return list(deps) + [
237+
# for QuerySet.annotate
238+
self._new_dependency("django_stubs_ext"),
239+
# For BaseManager.from_queryset
240+
self._new_dependency("django.db.models.query"),
241+
]
239242

240243
def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], MypyType]]:
241244
if fullname == "django.contrib.auth.get_user_model":
@@ -305,7 +308,7 @@ def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefConte
305308
return partial(transform_model_class, django_context=self.django_context)
306309

307310
if fullname in self._get_current_manager_bases():
308-
return add_new_manager_base
311+
return add_new_manager_base_hook
309312

310313
if fullname in self._get_current_form_bases():
311314
return transform_form_class

mypy_django_plugin/transformers/managers.py

Lines changed: 79 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from mypy.checker import fill_typevars
12
from mypy.nodes import GDEF, Decorator, FuncDef, MemberExpr, NameExpr, RefExpr, StrExpr, SymbolTableNode, TypeInfo
23
from mypy.plugin import ClassDefContext, DynamicClassDefContext
3-
from mypy.types import AnyType, Instance, TypeOfAny
4+
from mypy.types import CallableType, Instance, TypeVarType, UnboundType, get_proper_type
45

56
from mypy_django_plugin.lib import fullnames, helpers
67

@@ -29,15 +30,11 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
2930
# But it should be analyzed again, so this isn't a problem.
3031
return
3132

33+
base_manager_instance = fill_typevars(base_manager_info)
34+
assert isinstance(base_manager_instance, Instance)
3235
new_manager_info = semanal_api.basic_new_typeinfo(
33-
ctx.name, basetype_or_fallback=Instance(base_manager_info, [AnyType(TypeOfAny.unannotated)]), line=ctx.call.line
36+
ctx.name, basetype_or_fallback=base_manager_instance, line=ctx.call.line
3437
)
35-
new_manager_info.line = ctx.call.line
36-
new_manager_info.defn.line = ctx.call.line
37-
new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type()
38-
39-
current_module = semanal_api.cur_mod_node
40-
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True)
4138

4239
sym = semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname)
4340
assert sym is not None
@@ -52,6 +49,15 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
5249
derived_queryset_info = sym.node
5350
assert isinstance(derived_queryset_info, TypeInfo)
5451

52+
new_manager_info.line = ctx.call.line
53+
new_manager_info.type_vars = base_manager_info.type_vars
54+
new_manager_info.defn.type_vars = base_manager_info.defn.type_vars
55+
new_manager_info.defn.line = ctx.call.line
56+
new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type()
57+
58+
current_module = semanal_api.cur_mod_node
59+
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True)
60+
5561
if len(ctx.call.args) > 1:
5662
expr = ctx.call.args[1]
5763
assert isinstance(expr, StrExpr)
@@ -64,11 +70,19 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
6470
base_manager_info.metadata["from_queryset_managers"] = {}
6571
base_manager_info.metadata["from_queryset_managers"][custom_manager_generated_fullname] = new_manager_info.fullname
6672

73+
# So that the plugin will reparameterize the manager when it is constructed inside of a Model definition
74+
helpers.add_new_manager_base(semanal_api, new_manager_info.fullname)
75+
6776
class_def_context = ClassDefContext(cls=new_manager_info.defn, reason=ctx.call, api=semanal_api)
68-
self_type = Instance(new_manager_info, [])
77+
self_type = fill_typevars(new_manager_info)
78+
assert isinstance(self_type, Instance)
79+
queryset_method_names = []
80+
6981
# we need to copy all methods in MRO before django.db.models.query.QuerySet
7082
for class_mro_info in derived_queryset_info.mro:
7183
if class_mro_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME:
84+
for name, sym in class_mro_info.names.items():
85+
queryset_method_names.append(name)
7286
break
7387
for name, sym in class_mro_info.names.items():
7488
if isinstance(sym.node, FuncDef):
@@ -80,3 +94,59 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
8094
helpers.copy_method_to_another_class(
8195
class_def_context, self_type, new_method_name=name, method_node=func_node
8296
)
97+
98+
# Gather names of all BaseManager methods
99+
manager_method_names = []
100+
for manager_mro_info in new_manager_info.mro:
101+
if manager_mro_info.fullname == fullnames.BASE_MANAGER_CLASS_FULLNAME:
102+
for name, sym in manager_mro_info.names.items():
103+
manager_method_names.append(name)
104+
105+
# Copy/alter all methods in common between BaseManager/QuerySet over to the new manager if their return type is
106+
# the QuerySet's self-type. Alter the return type to be the custom queryset, parameterized by the manager's model
107+
# type variable.
108+
for class_mro_info in derived_queryset_info.mro:
109+
if class_mro_info.fullname != fullnames.QUERYSET_CLASS_FULLNAME:
110+
continue
111+
for name, sym in class_mro_info.names.items():
112+
if name not in manager_method_names:
113+
continue
114+
115+
if isinstance(sym.node, FuncDef):
116+
func_node = sym.node
117+
elif isinstance(sym.node, Decorator):
118+
func_node = sym.node.func
119+
else:
120+
continue
121+
122+
method_type = func_node.type
123+
if not isinstance(method_type, CallableType):
124+
if not semanal_api.final_iteration:
125+
semanal_api.defer()
126+
return None
127+
original_return_type = method_type.ret_type
128+
if original_return_type is None:
129+
continue
130+
131+
# Skip any method that doesn't return _QS
132+
original_return_type = get_proper_type(original_return_type)
133+
if isinstance(original_return_type, UnboundType):
134+
if original_return_type.name != "_QS":
135+
continue
136+
elif isinstance(original_return_type, TypeVarType):
137+
if original_return_type.name != "_QS":
138+
continue
139+
else:
140+
continue
141+
142+
# Return the custom queryset parameterized by the manager's type vars
143+
return_type = Instance(derived_queryset_info, self_type.args)
144+
145+
helpers.copy_method_to_another_class(
146+
class_def_context,
147+
self_type,
148+
new_method_name=name,
149+
method_node=func_node,
150+
return_type=return_type,
151+
original_module_name=class_mro_info.module_name,
152+
)

tests/typecheck/managers/querysets/test_from_queryset.yml

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
- case: from_queryset_with_base_manager
22
main: |
33
from myapp.models import MyModel
4-
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
4+
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
55
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
66
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
7+
reveal_type(MyModel.objects.filter(id=1).queryset_method()) # N: Revealed type is "builtins.str"
8+
reveal_type(MyModel.objects.filter(id=1)) # N: Revealed type is "myapp.models.ModelQuerySet[myapp.models.MyModel*]"
79
installed_apps:
810
- myapp
911
files:
@@ -23,7 +25,7 @@
2325
- case: from_queryset_with_manager
2426
main: |
2527
from myapp.models import MyModel
26-
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
28+
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
2729
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
2830
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
2931
installed_apps:
@@ -97,7 +99,7 @@
9799
- case: from_queryset_with_class_inheritance
98100
main: |
99101
from myapp.models import MyModel
100-
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
102+
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
101103
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
102104
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
103105
installed_apps:
@@ -121,7 +123,7 @@
121123
- case: from_queryset_with_manager_in_another_directory_and_imports
122124
main: |
123125
from myapp.models import MyModel
124-
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
126+
reveal_type(MyModel().objects) # N: Revealed type is "myapp.managers.NewManager[myapp.models.MyModel]"
125127
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
126128
reveal_type(MyModel().objects.queryset_method) # N: Revealed type is "def (param: Union[builtins.str, None] =) -> Union[builtins.str, None]"
127129
reveal_type(MyModel().objects.queryset_method('str')) # N: Revealed type is "Union[builtins.str, None]"
@@ -151,7 +153,7 @@
151153
disable_cache: true
152154
main: |
153155
from myapp.models import MyModel
154-
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
156+
reveal_type(MyModel().objects) # N: Revealed type is "myapp.managers.NewManager[myapp.models.MyModel]"
155157
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
156158
reveal_type(MyModel().objects.base_queryset_method) # N: Revealed type is "def (param: Union[builtins.int, builtins.str]) -> <nothing>"
157159
reveal_type(MyModel().objects.base_queryset_method(2)) # N: Revealed type is "<nothing>"
@@ -183,7 +185,7 @@
183185
- case: from_queryset_with_decorated_queryset_methods
184186
main: |
185187
from myapp.models import MyModel
186-
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
188+
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
187189
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
188190
installed_apps:
189191
- myapp

0 commit comments

Comments
 (0)