Skip to content

Commit d041ff2

Browse files
committed
fix(DataFrame): to_dict("index") and typevar
1 parent f340905 commit d041ff2

File tree

2 files changed

+99
-27
lines changed

2 files changed

+99
-27
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ from typing import (
1919
Generic,
2020
Literal,
2121
NoReturn,
22+
TypeVar,
2223
final,
2324
overload,
2425
)
@@ -165,6 +166,8 @@ from pandas._typing import (
165166
from pandas.io.formats.style import Styler
166167
from pandas.plotting import PlotAccessor
167168

169+
_T_MUTABLE_MAPPING = TypeVar("_T_MUTABLE_MAPPING", bound=MutableMapping, covariant=True)
170+
168171
class _iLocIndexerFrame(_iLocIndexer, Generic[_T]):
169172
@overload
170173
def __getitem__(self, idx: tuple[int, int]) -> Scalar: ...
@@ -396,9 +399,9 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
396399
self,
397400
orient: Literal["records"],
398401
*,
399-
into: MutableMapping | type[MutableMapping],
402+
into: _T_MUTABLE_MAPPING | type[_T_MUTABLE_MAPPING],
400403
index: Literal[True] = ...,
401-
) -> list[MutableMapping[Hashable, Any]]: ...
404+
) -> list[_T_MUTABLE_MAPPING]: ...
402405
@overload
403406
def to_dict(
404407
self,
@@ -410,39 +413,55 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
410413
@overload
411414
def to_dict(
412415
self,
413-
orient: Literal["dict", "list", "series", "index"],
416+
orient: Literal["index"],
417+
*,
418+
into: _T_MUTABLE_MAPPING | type[_T_MUTABLE_MAPPING],
419+
index: Literal[True] = ...,
420+
) -> MutableMapping[Hashable, _T_MUTABLE_MAPPING]: ...
421+
@overload
422+
def to_dict(
423+
self,
424+
orient: Literal["index"],
425+
*,
426+
into: type[dict] = ...,
427+
index: Literal[True] = ...,
428+
) -> dict[Hashable, dict[Hashable, Any]]: ...
429+
@overload
430+
def to_dict(
431+
self,
432+
orient: Literal["dict", "list", "series"],
414433
*,
415-
into: MutableMapping | type[MutableMapping],
434+
into: _T_MUTABLE_MAPPING | type[_T_MUTABLE_MAPPING],
416435
index: Literal[True] = ...,
417-
) -> MutableMapping[Hashable, Any]: ...
436+
) -> _T_MUTABLE_MAPPING: ...
418437
@overload
419438
def to_dict(
420439
self,
421440
orient: Literal["split", "tight"],
422441
*,
423-
into: MutableMapping | type[MutableMapping],
442+
into: _T_MUTABLE_MAPPING | type[_T_MUTABLE_MAPPING],
424443
index: bool = ...,
425-
) -> MutableMapping[Hashable, Any]: ...
444+
) -> _T_MUTABLE_MAPPING: ...
426445
@overload
427446
def to_dict(
428447
self,
429-
orient: Literal["dict", "list", "series", "index"] = ...,
448+
orient: Literal["dict", "list", "series"] = ...,
430449
*,
431-
into: MutableMapping | type[MutableMapping],
450+
into: _T_MUTABLE_MAPPING | type[_T_MUTABLE_MAPPING],
432451
index: Literal[True] = ...,
433-
) -> MutableMapping[Hashable, Any]: ...
452+
) -> _T_MUTABLE_MAPPING: ...
434453
@overload
435454
def to_dict(
436455
self,
437456
orient: Literal["split", "tight"] = ...,
438457
*,
439-
into: MutableMapping | type[MutableMapping],
458+
into: _T_MUTABLE_MAPPING | type[_T_MUTABLE_MAPPING],
440459
index: bool = ...,
441-
) -> MutableMapping[Hashable, Any]: ...
460+
) -> _T_MUTABLE_MAPPING: ...
442461
@overload
443462
def to_dict(
444463
self,
445-
orient: Literal["dict", "list", "series", "index"] = ...,
464+
orient: Literal["dict", "list", "series"] = ...,
446465
*,
447466
into: type[dict] = ...,
448467
index: Literal[True] = ...,

tests/test_frame.py

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3638,33 +3638,83 @@ def test_to_records() -> None:
36383638
)
36393639

36403640

3641-
def test_to_dict() -> None:
3641+
def test_to_dict_simple() -> None:
36423642
check(assert_type(DF.to_dict(), dict[Hashable, Any]), dict)
36433643
check(assert_type(DF.to_dict("split"), dict[Hashable, Any]), dict)
3644+
check(assert_type(DF.to_dict("records"), list[dict[Hashable, Any]]), list)
3645+
3646+
if TYPE_CHECKING_INVALID_USAGE:
3647+
3648+
def test(mapping: Mapping) -> None: # pyright: ignore[reportUnusedFunction]
3649+
DF.to_dict( # type: ignore[call-overload]
3650+
into=mapping # pyright: ignore[reportArgumentType,reportCallIssue]
3651+
)
3652+
3653+
3654+
def test_to_dict_into_defaultdict_any() -> None:
3655+
"""Test DataFrame.to_dict with `into=defaultdict[Any, list]`"""
3656+
3657+
data = pd.DataFrame({("str", "rts"): [[1, 2, 4], [2, 3], [3]]})
3658+
target: defaultdict[Hashable, list[Any]] = defaultdict(list)
36443659

3645-
target: MutableMapping = defaultdict(list)
36463660
check(
3647-
assert_type(DF.to_dict(into=target), MutableMapping[Hashable, Any]), defaultdict
3661+
assert_type(data.to_dict(into=target), defaultdict[Hashable, list[Any]]),
3662+
defaultdict,
36483663
)
3649-
target = defaultdict(list)
36503664
check(
3651-
assert_type(DF.to_dict("tight", into=target), MutableMapping[Hashable, Any]),
3665+
assert_type(
3666+
data.to_dict("index", into=target),
3667+
MutableMapping[Hashable, defaultdict[Hashable, list[Any]]],
3668+
),
3669+
defaultdict,
3670+
)
3671+
check(
3672+
assert_type(
3673+
data.to_dict("tight", into=target), defaultdict[Hashable, list[Any]]
3674+
),
36523675
defaultdict,
36533676
)
3654-
target = defaultdict(list)
3655-
check(assert_type(DF.to_dict("records"), list[dict[Hashable, Any]]), list)
36563677
check(
36573678
assert_type(
3658-
DF.to_dict("records", into=target), list[MutableMapping[Hashable, Any]]
3679+
data.to_dict("records", into=target), list[defaultdict[Hashable, list[Any]]]
36593680
),
36603681
list,
36613682
)
3662-
if TYPE_CHECKING_INVALID_USAGE:
36633683

3664-
def test(mapping: Mapping) -> None: # pyright: ignore[reportUnusedFunction]
3665-
DF.to_dict( # type: ignore[call-overload]
3666-
into=mapping # pyright: ignore[reportArgumentType,reportCallIssue]
3667-
)
3684+
3685+
def test_to_dict_into_defaultdict_typed() -> None:
3686+
"""Test DataFrame.to_dict with `into=defaultdict[tuple[str, str], list[int]]`"""
3687+
3688+
data = pd.DataFrame({("str", "rts"): [[1, 2, 4], [2, 3], [3]]})
3689+
target: defaultdict[tuple[str, str], list[int]] = defaultdict(list)
3690+
target[("str", "rts")].append(1)
3691+
3692+
check(
3693+
assert_type(data.to_dict(into=target), defaultdict[tuple[str, str], list[int]]),
3694+
defaultdict,
3695+
tuple,
3696+
)
3697+
check(
3698+
assert_type(
3699+
data.to_dict("index", into=target),
3700+
MutableMapping[Hashable, defaultdict[tuple[str, str], list[int]]],
3701+
),
3702+
defaultdict,
3703+
)
3704+
check(
3705+
assert_type(
3706+
data.to_dict("tight", into=target), defaultdict[tuple[str, str], list[int]]
3707+
),
3708+
defaultdict,
3709+
)
3710+
check(
3711+
assert_type(
3712+
data.to_dict("records", into=target),
3713+
list[defaultdict[tuple[str, str], list[int]]],
3714+
),
3715+
list,
3716+
defaultdict,
3717+
)
36683718

36693719

36703720
def test_neg() -> None:
@@ -4111,7 +4161,10 @@ def test_to_dict_index() -> None:
41114161
assert_type(df.to_dict(orient="series", index=True), dict[Hashable, Any]), dict
41124162
)
41134163
check(
4114-
assert_type(df.to_dict(orient="index", index=True), dict[Hashable, Any]), dict
4164+
assert_type(
4165+
df.to_dict(orient="index", index=True), dict[Hashable, dict[Hashable, Any]]
4166+
),
4167+
dict,
41154168
)
41164169
check(
41174170
assert_type(df.to_dict(orient="split", index=True), dict[Hashable, Any]), dict

0 commit comments

Comments
 (0)