Skip to content

Commit fd1cc92

Browse files
committed
Implement support for returning TypedDict for dataclasses.asdict
Relates to python#5152
1 parent 52c0a63 commit fd1cc92

File tree

7 files changed

+461
-17
lines changed

7 files changed

+461
-17
lines changed

docs/source/additional_features.rst

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,22 @@ and :pep:`557`.
7272
Caveats/Known Issues
7373
====================
7474

75-
Some functions in the :py:mod:`dataclasses` module, such as :py:func:`~dataclasses.replace` and :py:func:`~dataclasses.asdict`,
75+
Some functions in the :py:mod:`dataclasses` module, such as :py:func:`~dataclasses.replace`,
7676
have imprecise (too permissive) types. This will be fixed in future releases.
7777

78+
Calls to :py:func:`~dataclasses.asdict` will return a ``TypedDict`` based on the original dataclass
79+
definition, transforming it recursively. There are, however, some limitations:
80+
81+
* Subclasses of ``List``, ``Dict``, and ``Tuple`` appearing within dataclasses are transformed into reparameterized
82+
versions of the respective base class, rather than a transformed version of the original subclass.
83+
84+
* Recursion (e.g. dataclasses which reference each other) is not supported and results in an error.
85+
86+
* ``NamedTuples`` appearing within dataclasses are transformed to ``Any``
87+
88+
* A more precise return type cannot be inferred for calls where ``dict_factory`` is set.
89+
90+
7891
Mypy does not yet recognize aliases of :py:func:`dataclasses.dataclass <dataclasses.dataclass>`, and will
7992
probably never recognize dynamically computed decorators. The following examples
8093
do **not** work:

mypy/plugin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ class CheckerPluginInterface:
209209
docstrings in checker.py for more details.
210210
"""
211211

212+
modules = None # type: Dict[str, MypyFile]
212213
msg = None # type: MessageBuilder
213214
options = None # type: Options
214215
path = None # type: str

mypy/plugins/common.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
from typing import List, Optional, Union
1+
from collections import OrderedDict
2+
from typing import List, Optional, Union, Set
23

34
from mypy.nodes import (
45
ARG_POS, MDEF, Argument, Block, CallExpr, ClassDef, Expression, SYMBOL_FUNCBASE_TYPES,
56
FuncDef, PassStmt, RefExpr, SymbolTableNode, Var, JsonDict,
67
)
7-
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface
8+
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface, CheckerPluginInterface
89
from mypy.semanal import set_callable_name
910
from mypy.types import (
1011
CallableType, Overloaded, Type, TypeVarDef, deserialize_type, get_proper_type,
11-
)
12+
TypedDictType, Instance, TPDICT_FB_NAMES)
1213
from mypy.typevars import fill_typevars
1314
from mypy.util import get_unique_redefinition_name
1415
from mypy.typeops import try_getting_str_literals # noqa: F401 # Part of public API
@@ -155,8 +156,26 @@ def add_method_to_class(
155156

156157

157158
def deserialize_and_fixup_type(
158-
data: Union[str, JsonDict], api: SemanticAnalyzerPluginInterface
159+
data: Union[str, JsonDict],
160+
api: Union[SemanticAnalyzerPluginInterface, CheckerPluginInterface]
159161
) -> Type:
160162
typ = deserialize_type(data)
161163
typ.accept(TypeFixer(api.modules, allow_missing=False))
162164
return typ
165+
166+
167+
def get_anonymous_typeddict_type(api: CheckerPluginInterface) -> Instance:
168+
for type_fullname in TPDICT_FB_NAMES:
169+
try:
170+
anonymous_typeddict_type = api.named_generic_type(type_fullname, [])
171+
if anonymous_typeddict_type is not None:
172+
return anonymous_typeddict_type
173+
except KeyError:
174+
continue
175+
raise RuntimeError("No TypedDict fallback type found")
176+
177+
178+
def make_anonymous_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, Type]',
179+
required_keys: Set[str]) -> TypedDictType:
180+
return TypedDictType(fields, required_keys=required_keys,
181+
fallback=get_anonymous_typeddict_type(api))

mypy/plugins/dataclasses.py

Lines changed: 116 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
"""Plugin that provides support for dataclasses."""
22

3-
from typing import Dict, List, Set, Tuple, Optional
3+
from collections import OrderedDict
4+
from typing import Dict, List, Set, Tuple, Optional, FrozenSet, Callable, Union
5+
46
from typing_extensions import Final
57

8+
from mypy.maptype import map_instance_to_supertype
69
from mypy.nodes import (
7-
ARG_OPT, ARG_POS, MDEF, Argument, AssignmentStmt, CallExpr,
8-
Context, Expression, JsonDict, NameExpr, RefExpr,
9-
SymbolTableNode, TempNode, TypeInfo, Var, TypeVarExpr, PlaceholderNode
10-
)
11-
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface
12-
from mypy.plugins.common import (
13-
add_method, _get_decorator_bool_argument, deserialize_and_fixup_type,
10+
ARG_OPT, ARG_POS, MDEF, Argument, AssignmentStmt, CallExpr, Context,
11+
Expression, JsonDict, NameExpr, RefExpr, SymbolTableNode, TempNode,
12+
TypeInfo, Var, TypeVarExpr, PlaceholderNode
1413
)
15-
from mypy.types import Type, Instance, NoneType, TypeVarDef, TypeVarType, get_proper_type
14+
from mypy.plugin import ClassDefContext, FunctionContext, CheckerPluginInterface
15+
from mypy.plugin import SemanticAnalyzerPluginInterface
16+
from mypy.plugins.common import (add_method, _get_decorator_bool_argument,
17+
make_anonymous_typeddict, deserialize_and_fixup_type)
1618
from mypy.server.trigger import make_wildcard_trigger
19+
from mypy.types import (Instance, NoneType, TypeVarDef, TypeVarType, get_proper_type, Type,
20+
TupleType, UnionType, AnyType, TypeOfAny)
1721

1822
# The set of decorators that generate dataclasses.
1923
dataclass_makers = {
@@ -24,6 +28,10 @@
2428
SELF_TVAR_NAME = '_DT' # type: Final
2529

2630

31+
def is_type_dataclass(info: TypeInfo) -> bool:
32+
return 'dataclass' in info.metadata
33+
34+
2735
class DataclassAttribute:
2836
def __init__(
2937
self,
@@ -68,7 +76,8 @@ def serialize(self) -> JsonDict:
6876

6977
@classmethod
7078
def deserialize(
71-
cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface
79+
cls, info: TypeInfo, data: JsonDict,
80+
api: Union[SemanticAnalyzerPluginInterface, CheckerPluginInterface]
7281
) -> 'DataclassAttribute':
7382
data = data.copy()
7483
typ = deserialize_and_fixup_type(data.pop('type'), api)
@@ -297,7 +306,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
297306
# we'll have unmodified attrs laying around.
298307
all_attrs = attrs.copy()
299308
for info in cls.info.mro[1:-1]:
300-
if 'dataclass' not in info.metadata:
309+
if not is_type_dataclass(info):
301310
continue
302311

303312
super_attrs = []
@@ -386,3 +395,99 @@ def _collect_field_args(expr: Expression) -> Tuple[bool, Dict[str, Expression]]:
386395
args[name] = arg
387396
return True, args
388397
return False, {}
398+
399+
400+
def asdict_callback(ctx: FunctionContext) -> Type:
401+
positional_arg_types = ctx.arg_types[0]
402+
403+
if positional_arg_types:
404+
if len(ctx.arg_types) == 2:
405+
# We can't infer a more precise for calls where dict_factory is set.
406+
# At least for now, typeshed stubs for asdict don't allow you to pass in `dict` as
407+
# dict_factory, so we can't special-case that.
408+
return ctx.default_return_type
409+
dataclass_instance = positional_arg_types[0]
410+
dataclass_instance = get_proper_type(dataclass_instance)
411+
if isinstance(dataclass_instance, Instance):
412+
info = dataclass_instance.type
413+
if not is_type_dataclass(info):
414+
ctx.api.fail('asdict() should be called on dataclass instances',
415+
dataclass_instance)
416+
return _asdictify(ctx.api, ctx.context, dataclass_instance)
417+
return ctx.default_return_type
418+
419+
420+
def _transform_type_args(*, typ: Instance, transform: Callable[[Instance], Type]) -> List[Type]:
421+
"""For each type arg used in the Instance, call transform function on it if the arg is an
422+
Instance."""
423+
new_args = []
424+
for arg in typ.args:
425+
proper_arg = get_proper_type(arg)
426+
if isinstance(proper_arg, Instance):
427+
new_args.append(transform(proper_arg))
428+
else:
429+
new_args.append(arg)
430+
return new_args
431+
432+
433+
def _asdictify(api: CheckerPluginInterface, context: Context, typ: Type) -> Type:
434+
"""Convert dataclasses into TypedDicts, recursively looking into built-in containers.
435+
436+
It will look for dataclasses inside of tuples, lists, and dicts and convert them to TypedDicts.
437+
"""
438+
439+
def _asdictify_inner(typ: Type, seen_dataclasses: FrozenSet[str]) -> Type:
440+
typ = get_proper_type(typ)
441+
if isinstance(typ, UnionType):
442+
return UnionType([_asdictify_inner(item, seen_dataclasses) for item in typ.items])
443+
if isinstance(typ, Instance):
444+
info = typ.type
445+
if is_type_dataclass(info):
446+
if info.fullname in seen_dataclasses:
447+
api.fail(
448+
"Recursive types are not supported in call to asdict, so falling back to "
449+
"Dict[str, Any]",
450+
context)
451+
# Note: Would be nicer to fallback to default_return_type, but that is Any
452+
# (due to overloads?)
453+
return api.named_generic_type('builtins.dict',
454+
[api.named_generic_type('builtins.str', []),
455+
AnyType(TypeOfAny.implementation_artifact)])
456+
seen_dataclasses |= {info.fullname}
457+
attrs = info.metadata['dataclass']['attributes']
458+
fields = OrderedDict() # type: OrderedDict[str, Type]
459+
for data in attrs:
460+
attr = DataclassAttribute.deserialize(info, data, api)
461+
sym_node = info.names[attr.name]
462+
attr_type = sym_node.type
463+
assert attr_type is not None
464+
fields[attr.name] = _asdictify_inner(attr_type, seen_dataclasses)
465+
return make_anonymous_typeddict(api, fields=fields,
466+
required_keys=set(fields.keys()))
467+
elif info.has_base('builtins.list'):
468+
supertype_instance = map_instance_to_supertype(typ, api.named_generic_type(
469+
'builtins.list', []).type)
470+
new_args = _transform_type_args(
471+
typ=supertype_instance,
472+
transform=lambda arg: _asdictify_inner(arg, seen_dataclasses))
473+
return api.named_generic_type('builtins.list', new_args)
474+
elif info.has_base('builtins.dict'):
475+
supertype_instance = map_instance_to_supertype(typ, api.named_generic_type(
476+
'builtins.dict', []).type)
477+
new_args = _transform_type_args(
478+
typ=supertype_instance,
479+
transform=lambda arg: _asdictify_inner(arg, seen_dataclasses))
480+
return api.named_generic_type('builtins.dict', new_args)
481+
elif isinstance(typ, TupleType):
482+
if typ.partial_fallback.type.is_named_tuple:
483+
# For namedtuples, return Any. To properly support transforming namedtuples,
484+
# we would have to generate a partial_fallback type for the TupleType and add it
485+
# to the symbol table. It's not currently possibl to do this via the
486+
# CheckerPluginInterface. Ideally it would use the same code as
487+
# NamedTupleAnalyzer.build_namedtuple_typeinfo.
488+
return AnyType(TypeOfAny.implementation_artifact)
489+
return TupleType([_asdictify_inner(item, seen_dataclasses) for item in typ.items],
490+
api.named_generic_type('builtins.tuple', []), implicit=typ.implicit)
491+
return typ
492+
493+
return _asdictify_inner(typ, seen_dataclasses=frozenset())

mypy/plugins/default.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,16 @@ class DefaultPlugin(Plugin):
2323
def get_function_hook(self, fullname: str
2424
) -> Optional[Callable[[FunctionContext], Type]]:
2525
from mypy.plugins import ctypes
26+
from mypy.plugins import dataclasses
2627

2728
if fullname == 'contextlib.contextmanager':
2829
return contextmanager_callback
2930
elif fullname == 'builtins.open' and self.python_version[0] == 3:
3031
return open_callback
3132
elif fullname == 'ctypes.Array':
3233
return ctypes.array_constructor_callback
34+
elif fullname == 'dataclasses.asdict':
35+
return dataclasses.asdict_callback
3336
return None
3437

3538
def get_method_signature_hook(self, fullname: str

0 commit comments

Comments
 (0)