Skip to content

Commit 29889c8

Browse files
Michael0x2ailevkivskyi
authored andcommitted
Make overloads support classmethod and staticmethod (#5224)
* Move 'is_class' and 'is_static' into FuncBase This commit moves the `is_class` and `is_static` fields into FuncBase. It also cleans up the list of flags so they don't repeat the 'is_property' entry, which is now present in `FUNCBASE_FLAGS`. The high-level plan is to modify the `is_class` and `is_static` fields in OverloadedFuncDef for use later in mypy. * Make semantic analysis phase record class/static methods with overloads This commit adjusts the semantic analysis phase to detect and record when an overload appears to be a classmethod or staticmethod. * Broaden class/static method checks to catch overloads This commit modifies mypy to use the `is_static` and `is_class` fields of OverloadedFuncDef as appropriate. I found the code snippets to modify by asking PyCharm for all instances of code using those two fields and modified the surrounding code as appropriate. * Add support for overloaded classmethods in attrs/dataclasses Both the attrs and dataclasses plugins manually patch classmethods -- we do the same for overloads. * Respond to code review This commit: 1. Updates astdiff.py and adds a case to one of the fine-grained dependency test files. 2. Adds some helper methods to FunctionLike. 3. Performs a few misc cleanups. * Respond to code review; add tests for self types
1 parent e66d53b commit 29889c8

16 files changed

+677
-30
lines changed

mypy/checker.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,12 +1289,10 @@ def check_override(self, override: FunctionLike, original: FunctionLike,
12891289
# this could be unsafe with reverse operator methods.
12901290
fail = True
12911291

1292-
if isinstance(original, CallableType) and isinstance(override, CallableType):
1293-
if (isinstance(original.definition, FuncItem) and
1294-
isinstance(override.definition, FuncItem)):
1295-
if ((original.definition.is_static or original.definition.is_class) and
1296-
not (override.definition.is_static or override.definition.is_class)):
1297-
fail = True
1292+
if isinstance(original, FunctionLike) and isinstance(override, FunctionLike):
1293+
if ((original.is_classmethod() or original.is_staticmethod()) and
1294+
not (override.is_classmethod() or override.is_staticmethod())):
1295+
fail = True
12981296

12991297
if fail:
13001298
emitted_msg = False
@@ -3911,8 +3909,6 @@ def is_untyped_decorator(typ: Optional[Type]) -> bool:
39113909
def is_static(func: Union[FuncBase, Decorator]) -> bool:
39123910
if isinstance(func, Decorator):
39133911
return is_static(func.func)
3914-
elif isinstance(func, OverloadedFuncDef):
3915-
return any(is_static(item) for item in func.items)
3916-
elif isinstance(func, FuncItem):
3912+
elif isinstance(func, FuncBase):
39173913
return func.is_static
3918-
return False
3914+
assert False, "Unexpected func type: {}".format(type(func))

mypy/checkmember.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,8 @@ def analyze_class_attribute_access(itype: Instance,
448448
return handle_partial_attribute_type(t, is_lvalue, msg, symnode)
449449
if not is_method and (isinstance(t, TypeVarType) or get_type_vars(t)):
450450
msg.fail(messages.GENERIC_INSTANCE_VAR_CLASS_ACCESS, context)
451-
is_classmethod = is_decorated and cast(Decorator, node.node).func.is_class
451+
is_classmethod = ((is_decorated and cast(Decorator, node.node).func.is_class)
452+
or (isinstance(node.node, FuncBase) and node.node.is_class))
452453
return add_class_tvars(t, itype, is_classmethod, builtin_type, original_type)
453454
elif isinstance(node.node, Var):
454455
not_ready_callback(name, context)

mypy/messages.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
TypeInfo, Context, MypyFile, op_methods, FuncDef, reverse_type_aliases,
2828
ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2,
2929
ReturnStmt, NameExpr, Var, CONTRAVARIANT, COVARIANT, SymbolNode,
30-
CallExpr, Expression
30+
CallExpr, Expression, OverloadedFuncDef,
3131
)
3232

3333
# Constants that represent simple type checker error message, i.e. messages
@@ -942,6 +942,12 @@ def incompatible_typevar_value(self,
942942
self.format(typ)),
943943
context)
944944

945+
def overload_inconsistently_applies_decorator(self, decorator: str, context: Context) -> None:
946+
self.fail(
947+
'Overload does not consistently use the "@{}" '.format(decorator)
948+
+ 'decorator on all function signatures.',
949+
context)
950+
945951
def overloaded_signatures_overlap(self, index1: int, index2: int, context: Context) -> None:
946952
self.fail('Overloaded function signatures {} and {} overlap with '
947953
'incompatible return types'.format(index1, index2), context)

mypy/nodes.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -370,13 +370,20 @@ def __str__(self) -> str:
370370
return 'ImportedName(%s)' % self.target_fullname
371371

372372

373+
FUNCBASE_FLAGS = [
374+
'is_property', 'is_class', 'is_static',
375+
]
376+
377+
373378
class FuncBase(Node):
374379
"""Abstract base class for function-like nodes"""
375380

376381
__slots__ = ('type',
377382
'unanalyzed_type',
378383
'info',
379384
'is_property',
385+
'is_class', # Uses "@classmethod"
386+
'is_static', # USes "@staticmethod"
380387
'_fullname',
381388
)
382389

@@ -391,6 +398,8 @@ def __init__(self) -> None:
391398
# TODO: Type should be Optional[TypeInfo]
392399
self.info = cast(TypeInfo, None)
393400
self.is_property = False
401+
self.is_class = False
402+
self.is_static = False
394403
# Name with module prefix
395404
# TODO: Type should be Optional[str]
396405
self._fullname = cast(str, None)
@@ -436,8 +445,8 @@ def serialize(self) -> JsonDict:
436445
'items': [i.serialize() for i in self.items],
437446
'type': None if self.type is None else self.type.serialize(),
438447
'fullname': self._fullname,
439-
'is_property': self.is_property,
440-
'impl': None if self.impl is None else self.impl.serialize()
448+
'impl': None if self.impl is None else self.impl.serialize(),
449+
'flags': get_flags(self, FUNCBASE_FLAGS),
441450
}
442451

443452
@classmethod
@@ -451,7 +460,7 @@ def deserialize(cls, data: JsonDict) -> 'OverloadedFuncDef':
451460
if data.get('type') is not None:
452461
res.type = mypy.types.deserialize_type(data['type'])
453462
res._fullname = data['fullname']
454-
res.is_property = data['is_property']
463+
set_flags(res, data['flags'])
455464
# NOTE: res.info will be set in the fixup phase.
456465
return res
457466

@@ -481,9 +490,9 @@ def set_line(self, target: Union[Context, int], column: Optional[int] = None) ->
481490
self.variable.set_line(self.line, self.column)
482491

483492

484-
FUNCITEM_FLAGS = [
493+
FUNCITEM_FLAGS = FUNCBASE_FLAGS + [
485494
'is_overload', 'is_generator', 'is_coroutine', 'is_async_generator',
486-
'is_awaitable_coroutine', 'is_static', 'is_class',
495+
'is_awaitable_coroutine',
487496
]
488497

489498

@@ -503,8 +512,6 @@ class FuncItem(FuncBase):
503512
'is_coroutine', # Defined using 'async def' syntax?
504513
'is_async_generator', # Is an async def generator?
505514
'is_awaitable_coroutine', # Decorated with '@{typing,asyncio}.coroutine'?
506-
'is_static', # Uses @staticmethod?
507-
'is_class', # Uses @classmethod?
508515
'expanded', # Variants of function with type variables with values expanded
509516
)
510517

@@ -525,8 +532,6 @@ def __init__(self,
525532
self.is_coroutine = False
526533
self.is_async_generator = False
527534
self.is_awaitable_coroutine = False
528-
self.is_static = False
529-
self.is_class = False
530535
self.expanded = [] # type: List[FuncItem]
531536

532537
self.min_args = 0
@@ -547,7 +552,7 @@ def is_dynamic(self) -> bool:
547552

548553

549554
FUNCDEF_FLAGS = FUNCITEM_FLAGS + [
550-
'is_decorated', 'is_conditional', 'is_abstract', 'is_property',
555+
'is_decorated', 'is_conditional', 'is_abstract',
551556
]
552557

553558

@@ -561,7 +566,6 @@ class FuncDef(FuncItem, SymbolNode, Statement):
561566
'is_decorated',
562567
'is_conditional',
563568
'is_abstract',
564-
'is_property',
565569
'original_def',
566570
)
567571

@@ -575,7 +579,6 @@ def __init__(self,
575579
self.is_decorated = False
576580
self.is_conditional = False # Defined conditionally (within block)?
577581
self.is_abstract = False
578-
self.is_property = False
579582
# Original conditional definition
580583
self.original_def = None # type: Union[None, FuncDef, Var, Decorator]
581584

mypy/plugins/attrs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,16 @@ def _add_init(ctx: 'mypy.plugin.ClassDefContext', attributes: List[Attribute],
470470
func_type = stmt.func.type
471471
if isinstance(func_type, CallableType):
472472
func_type.arg_types[0] = ctx.api.class_type(ctx.cls.info)
473+
if isinstance(stmt, OverloadedFuncDef) and stmt.is_class:
474+
func_type = stmt.type
475+
if isinstance(func_type, Overloaded):
476+
class_type = ctx.api.class_type(ctx.cls.info)
477+
for item in func_type.items():
478+
item.arg_types[0] = class_type
479+
if stmt.impl is not None:
480+
assert isinstance(stmt.impl, Decorator)
481+
if isinstance(stmt.impl.func.type, CallableType):
482+
stmt.impl.func.type.arg_types[0] = class_type
473483

474484

475485
class MethodAdder:

mypy/plugins/dataclasses.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
from mypy.nodes import (
55
ARG_OPT, ARG_POS, MDEF, Argument, AssignmentStmt, CallExpr,
66
Context, Decorator, Expression, FuncDef, JsonDict, NameExpr,
7-
SymbolTableNode, TempNode, TypeInfo, Var,
7+
OverloadedFuncDef, SymbolTableNode, TempNode, TypeInfo, Var,
88
)
99
from mypy.plugin import ClassDefContext
1010
from mypy.plugins.common import _add_method, _get_decorator_bool_argument
1111
from mypy.types import (
12-
CallableType, Instance, NoneTyp, TypeVarDef, TypeVarType,
12+
CallableType, Instance, NoneTyp, Overloaded, TypeVarDef, TypeVarType,
1313
)
1414

1515
# The set of decorators that generate dataclasses.
@@ -95,6 +95,16 @@ def transform(self) -> None:
9595
func_type = stmt.func.type
9696
if isinstance(func_type, CallableType):
9797
func_type.arg_types[0] = self._ctx.api.class_type(self._ctx.cls.info)
98+
if isinstance(stmt, OverloadedFuncDef) and stmt.is_class:
99+
func_type = stmt.type
100+
if isinstance(func_type, Overloaded):
101+
class_type = ctx.api.class_type(ctx.cls.info)
102+
for item in func_type.items():
103+
item.arg_types[0] = class_type
104+
if stmt.impl is not None:
105+
assert isinstance(stmt.impl, Decorator)
106+
if isinstance(stmt.impl.func.type, CallableType):
107+
stmt.impl.func.type.arg_types[0] = class_type
98108

99109
# Add an eq method, but only if the class doesn't already have one.
100110
if decorator_arguments['eq'] and info.get('__eq__') is None:

mypy/semanal.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,37 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
587587
# redefinitions already.
588588
return
589589

590+
# We know this is an overload def -- let's handle classmethod and staticmethod
591+
class_status = []
592+
static_status = []
593+
for item in defn.items:
594+
if isinstance(item, Decorator):
595+
inner = item.func
596+
elif isinstance(item, FuncDef):
597+
inner = item
598+
else:
599+
assert False, "The 'item' variable is an unexpected type: {}".format(type(item))
600+
class_status.append(inner.is_class)
601+
static_status.append(inner.is_static)
602+
603+
if defn.impl is not None:
604+
if isinstance(defn.impl, Decorator):
605+
inner = defn.impl.func
606+
elif isinstance(defn.impl, FuncDef):
607+
inner = defn.impl
608+
else:
609+
assert False, "Unexpected impl type: {}".format(type(defn.impl))
610+
class_status.append(inner.is_class)
611+
static_status.append(inner.is_static)
612+
613+
if len(set(class_status)) != 1:
614+
self.msg.overload_inconsistently_applies_decorator('classmethod', defn)
615+
elif len(set(static_status)) != 1:
616+
self.msg.overload_inconsistently_applies_decorator('staticmethod', defn)
617+
else:
618+
defn.is_class = class_status[0]
619+
defn.is_static = static_status[0]
620+
590621
if self.type and not self.is_func_scope():
591622
self.type.names[defn.name()] = SymbolTableNode(MDEF, defn,
592623
typ=defn.type)

mypy/server/astdiff.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method'
5454

5555
from mypy.nodes import (
5656
SymbolTable, TypeInfo, Var, SymbolNode, Decorator, TypeVarExpr,
57-
OverloadedFuncDef, FuncItem, MODULE_REF, TYPE_ALIAS, UNBOUND_IMPORTED, TVAR
57+
FuncBase, OverloadedFuncDef, FuncItem, MODULE_REF, TYPE_ALIAS, UNBOUND_IMPORTED, TVAR
5858
)
5959
from mypy.types import (
6060
Type, TypeVisitor, UnboundType, AnyType, NoneTyp, UninhabitedType,
@@ -167,13 +167,13 @@ def snapshot_definition(node: Optional[SymbolNode],
167167
The representation is nested tuples and dicts. Only externally
168168
visible attributes are included.
169169
"""
170-
if isinstance(node, (OverloadedFuncDef, FuncItem)):
170+
if isinstance(node, FuncBase):
171171
# TODO: info
172172
if node.type:
173173
signature = snapshot_type(node.type)
174174
else:
175175
signature = snapshot_untyped_signature(node)
176-
return ('Func', common, node.is_property, signature)
176+
return ('Func', common, node.is_property, node.is_class, node.is_static, signature)
177177
elif isinstance(node, Var):
178178
return ('Var', common, snapshot_optional_type(node.type))
179179
elif isinstance(node, Decorator):

mypy/strconv.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,10 @@ def visit_overloaded_func_def(self, o: 'mypy.nodes.OverloadedFuncDef') -> str:
146146
a.insert(0, o.type)
147147
if o.impl:
148148
a.insert(0, o.impl)
149+
if o.is_static:
150+
a.insert(-1, 'Static')
151+
if o.is_class:
152+
a.insert(-1, 'Class')
149153
return self.dump(a, o)
150154

151155
def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> str:

mypy/treetransform.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> OverloadedFuncDe
154154
new._fullname = node._fullname
155155
new.type = self.optional_type(node.type)
156156
new.info = node.info
157+
new.is_static = node.is_static
158+
new.is_class = node.is_class
159+
new.is_property = node.is_property
157160
if node.impl:
158161
new.impl = cast(OverloadPart, node.impl.accept(self))
159162
return new

0 commit comments

Comments
 (0)