Skip to content

Commit ad51dde

Browse files
authored
add unittest.case._AssertRaisesBaseContext class (#11158)
1 parent f8e7386 commit ad51dde

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

stdlib/unittest/_log.pyi

+2-3
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,16 @@ import logging
22
import sys
33
from types import TracebackType
44
from typing import ClassVar, Generic, NamedTuple, TypeVar
5-
from unittest.case import TestCase
5+
from unittest.case import TestCase, _BaseTestCaseContext
66

77
_L = TypeVar("_L", None, _LoggingWatcher)
88

99
class _LoggingWatcher(NamedTuple):
1010
records: list[logging.LogRecord]
1111
output: list[str]
1212

13-
class _AssertLogsContext(Generic[_L]):
13+
class _AssertLogsContext(_BaseTestCaseContext, Generic[_L]):
1414
LOGGING_FORMAT: ClassVar[str]
15-
test_case: TestCase
1615
logger_name: str
1716
level: int
1817
msg: None

stdlib/unittest/case.pyi

+20-3
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,26 @@ _P = ParamSpec("_P")
2525
DIFF_OMITTED: str
2626

2727
class _BaseTestCaseContext:
28+
test_case: TestCase
2829
def __init__(self, test_case: TestCase) -> None: ...
2930

31+
class _AssertRaisesBaseContext(_BaseTestCaseContext):
32+
expected: type[BaseException] | tuple[type[BaseException], ...]
33+
expected_regex: Pattern[str] | None
34+
obj_name: str | None
35+
msg: str | None
36+
37+
def __init__(
38+
self,
39+
expected: type[BaseException] | tuple[type[BaseException], ...],
40+
test_case: TestCase,
41+
expected_regex: str | Pattern[str] | None = None,
42+
) -> None: ...
43+
44+
# This returns Self if args is the empty list, and None otherwise.
45+
# but it's not possible to construct an overload which expresses that
46+
def handle(self, name: str, args: list[Any], kwargs: dict[str, Any]) -> Any: ...
47+
3048
if sys.version_info >= (3, 9):
3149
from unittest._log import _AssertLogsContext, _LoggingWatcher
3250
else:
@@ -41,7 +59,6 @@ else:
4159

4260
class _AssertLogsContext(_BaseTestCaseContext, Generic[_L]):
4361
LOGGING_FORMAT: ClassVar[str]
44-
test_case: TestCase
4562
logger_name: str
4663
level: int
4764
msg: None
@@ -310,7 +327,7 @@ class FunctionTestCase(TestCase):
310327
def __hash__(self) -> int: ...
311328
def __eq__(self, other: object) -> bool: ...
312329

313-
class _AssertRaisesContext(Generic[_E]):
330+
class _AssertRaisesContext(_AssertRaisesBaseContext, Generic[_E]):
314331
exception: _E
315332
def __enter__(self) -> Self: ...
316333
def __exit__(
@@ -319,7 +336,7 @@ class _AssertRaisesContext(Generic[_E]):
319336
if sys.version_info >= (3, 9):
320337
def __class_getitem__(cls, item: Any) -> GenericAlias: ...
321338

322-
class _AssertWarnsContext:
339+
class _AssertWarnsContext(_AssertRaisesBaseContext):
323340
warning: WarningMessage
324341
filename: str
325342
lineno: int

0 commit comments

Comments
 (0)