Skip to content
Closed
24 changes: 20 additions & 4 deletions mypy/plugins/common.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import List, Optional, Union
from typing import List, Optional, Union, Set

from mypy.nodes import (
ARG_POS, MDEF, Argument, Block, CallExpr, Expression, SYMBOL_FUNCBASE_TYPES,
FuncDef, PassStmt, RefExpr, SymbolTableNode, Var, JsonDict,
)
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface, CheckerPluginInterface
from mypy.semanal import set_callable_name
from mypy.types import (
CallableType, Overloaded, Type, TypeVarDef, deserialize_type, get_proper_type,
)
TypedDictType, Instance, TPDICT_FB_NAMES)
from mypy.typevars import fill_typevars
from mypy.util import get_unique_redefinition_name
from mypy.typeops import try_getting_str_literals # noqa: F401 # Part of public API
Expand Down Expand Up @@ -134,8 +134,24 @@ def add_method(


def deserialize_and_fixup_type(
data: Union[str, JsonDict], api: SemanticAnalyzerPluginInterface
data: Union[str, JsonDict], api: SemanticAnalyzerPluginInterface
) -> Type:
typ = deserialize_type(data)
typ.accept(TypeFixer(api.modules, allow_missing=False))
return typ


def get_anonymous_typeddict_type(api: CheckerPluginInterface) -> Instance:
for type_fullname in TPDICT_FB_NAMES:
try:
anonymous_typeddict_type = api.named_generic_type(type_fullname, [])
if anonymous_typeddict_type is not None:
return anonymous_typeddict_type
except KeyError:
continue
raise RuntimeError("No TypedDict fallback type found")


def make_anonymous_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, Type]',
required_keys: Set[str]) -> TypedDictType:
return TypedDictType(fields, required_keys=required_keys, fallback=get_anonymous_typeddict_type(api))
98 changes: 93 additions & 5 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
"""Plugin that provides support for dataclasses."""

from typing import Dict, List, Set, Tuple, Optional
from collections import OrderedDict
from typing import Dict, List, Set, Tuple, Optional, FrozenSet, Callable

from typing_extensions import Final

from mypy.maptype import map_instance_to_supertype
from mypy.nodes import (
ARG_OPT, ARG_POS, MDEF, Argument, AssignmentStmt, CallExpr,
Context, Expression, JsonDict, NameExpr, RefExpr,
SymbolTableNode, TempNode, TypeInfo, Var, TypeVarExpr, PlaceholderNode
)
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.plugin import ClassDefContext, FunctionContext, CheckerPluginInterface
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import add_method, _get_decorator_bool_argument, make_anonymous_typeddict
from mypy.plugins.common import (
add_method, _get_decorator_bool_argument, deserialize_and_fixup_type,
deserialize_and_fixup_type,
)
from mypy.types import Type, Instance, NoneType, TypeVarDef, TypeVarType, get_proper_type
from mypy.server.trigger import make_wildcard_trigger
from mypy.typeops import tuple_fallback
from mypy.types import Instance, NoneType, TypeVarDef, TypeVarType, get_proper_type, Type, TupleType, UnionType, \
AnyType, TypeOfAny

# The set of decorators that generate dataclasses.
dataclass_makers = {
Expand All @@ -24,6 +31,10 @@
SELF_TVAR_NAME = '_DT' # type: Final


def is_type_dataclass(info: TypeInfo) -> bool:
return 'dataclass' in info.metadata


class DataclassAttribute:
def __init__(
self,
Expand Down Expand Up @@ -297,7 +308,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
# we'll have unmodified attrs laying around.
all_attrs = attrs.copy()
for info in cls.info.mro[1:-1]:
if 'dataclass' not in info.metadata:
if not is_type_dataclass(info):
continue

super_attrs = []
Expand Down Expand Up @@ -386,3 +397,80 @@ def _collect_field_args(expr: Expression) -> Tuple[bool, Dict[str, Expression]]:
args[name] = arg
return True, args
return False, {}


def asdict_callback(ctx: FunctionContext) -> Type:
positional_arg_types = ctx.arg_types[0]

if positional_arg_types:
if len(ctx.arg_types) == 2:
# We can't infer a more precise for calls where dict_factory is set.
# At least for now, typeshed stubs for asdict don't allow you to pass in `dict` as dict_factory,
# so we can't special-case that.
return ctx.default_return_type
dataclass_instance = positional_arg_types[0]
if isinstance(dataclass_instance, Instance):
info = dataclass_instance.type
if not is_type_dataclass(info):
ctx.api.fail('asdict() should be called on dataclass instances', dataclass_instance)
return _type_asdict(ctx.api, ctx.context, dataclass_instance)
return ctx.default_return_type


def _transform_type_args(*, typ: Instance, transform: Callable[[Instance], Type]) -> \
List[Type]:
"""For each type arg used in the Instance, call transform function on it if the arg is an Instance."""
return [transform(arg) if isinstance(arg, Instance) else arg for arg in typ.args]


def _type_asdict(api: CheckerPluginInterface, context: Context, typ: Type) -> Type:
"""Convert dataclasses into TypedDicts, recursively looking into built-in containers.

It will look for dataclasses inside of tuples, lists, and dicts and convert them to TypedDicts.
"""

def _type_asdict_inner(typ: Type, seen_dataclasses: FrozenSet[str]) -> Type:
if isinstance(typ, UnionType):
return UnionType([_type_asdict_inner(item, seen_dataclasses) for item in typ.items])
if isinstance(typ, Instance):
info = typ.type
if is_type_dataclass(info):
if info.fullname in seen_dataclasses:
api.fail("Recursive types are not supported in call to asdict, so falling back to Dict[str, Any]",
context)
# Note: Would be nicer to fallback to default_return_type, but that is Any (due to overloads?)
return api.named_generic_type('builtins.dict', [api.named_generic_type('builtins.str', []),
AnyType(TypeOfAny.implementation_artifact)])
seen_dataclasses |= {info.fullname}
attrs = info.metadata['dataclass']['attributes']
fields = OrderedDict() # type: OrderedDict[str, Type]
for data in attrs:
# TODO: DataclassAttribute.deserialize takes SemanticAnalyzerPluginInterface but we have
# CheckerPluginInterface here.
attr = DataclassAttribute.deserialize(info, data, api)
sym_node = info.names[attr.name]
typ = sym_node.type
assert typ is not None
fields[attr.name] = _type_asdict_inner(typ, seen_dataclasses)
return make_anonymous_typeddict(api, fields=fields, required_keys=set(fields.keys()))
elif info.has_base('builtins.list'):
supertype_instance = map_instance_to_supertype(typ, api.named_generic_type('builtins.list', []).type)
new_args = _transform_type_args(
typ=supertype_instance,
transform=lambda arg: _type_asdict_inner(arg, seen_dataclasses)
)
return api.named_generic_type('builtins.list', new_args)
elif info.has_base('builtins.dict'):
supertype_instance = map_instance_to_supertype(typ, api.named_generic_type('builtins.dict', []).type)
new_args = _transform_type_args(
typ=supertype_instance,
transform=lambda arg: _type_asdict_inner(arg, seen_dataclasses)
)
return api.named_generic_type('builtins.dict', new_args)
elif isinstance(typ, TupleType):
# TODO: Support subclasses/namedtuples properly
return TupleType([_type_asdict_inner(item, seen_dataclasses) for item in typ.items],
tuple_fallback(typ), implicit=typ.implicit)
return typ

return _type_asdict_inner(typ, seen_dataclasses=frozenset())
3 changes: 3 additions & 0 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@ class DefaultPlugin(Plugin):
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
from mypy.plugins import ctypes
from mypy.plugins import dataclasses

if fullname == 'contextlib.contextmanager':
return contextmanager_callback
elif fullname == 'builtins.open' and self.python_version[0] == 3:
return open_callback
elif fullname == 'ctypes.Array':
return ctypes.array_constructor_callback
elif fullname == 'dataclasses.asdict':
return dataclasses.asdict_callback
return None

def get_method_signature_hook(self, fullname: str
Expand Down
2 changes: 1 addition & 1 deletion mypy/typeshed
Submodule typeshed updated 783 files
Loading