-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmocks.py
254 lines (225 loc) · 9.07 KB
/
mocks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Literal

import pychoir
import pytest

import annotated_logger

if TYPE_CHECKING:  # pragma: no cover
    from unittest.mock import MagicMock

    from annotated_logger import AnnotatedLogger
class AssertLogged:
    """Stores the data from a call to `assert_logged` and checks if there is a match.

    `check` compares every record captured by an `AnnotatedLogMock` against the
    stored expectations, counts full matches in `found`, and tallies each
    record's list of differences in `failed_matches` so that a readable
    failure message can be built.
    """

    def __init__(
        self,
        level: str | pychoir.core.Matcher,
        message: str | pychoir.core.Matcher,
        present: dict[str, Any],
        absent: set[str] | Literal["ALL"],
        *,
        count: int | pychoir.core.Matcher,
    ) -> None:
        """Store the arguments that were passed to `assert_logged` and set defaults.

        Args:
            level: Expected level name, or a matcher for it.
            message: Expected unformatted message (`record.msg`), or a matcher.
            present: Keys that must exist on the record, mapped to the values
                they must compare equal to.
            absent: Keys that must not exist on the record, or "ALL" to reject
                every key outside `AnnotatedLogMock.DEFAULT_LOG_KEYS`.
            count: Expected number of matching records, or a matcher for it.
        """
        self.level = level
        self.message = message
        self.present = present
        self.absent = absent
        self.count = count
        # Number of records that satisfied every expectation.
        self.found = 0
        # str(list of differences) -> number of records that produced that list.
        self.failed_matches: dict[str, int] = {}

    def check(self, mock: AnnotatedLogMock) -> None:
        """Loop through calls in passed mock and check for matches.

        Calls `pytest.fail` with the lines from `build_message` when the
        expectations are not met.
        """
        for record in mock.records:
            differences = self._check_record_matches(record)
            if not differences:
                self.found += 1
            diff_str = str(differences)
            self.failed_matches[diff_str] = self.failed_matches.get(diff_str, 0) + 1

        fail_message = self.build_message()
        if fail_message:
            pytest.fail("\n".join(fail_message))

    def _failed_sort_key(self, failed_tuple: tuple[str, int]) -> str:
        """Build a string key used to order failed matches in the failure message."""
        failed, count = failed_tuple
        # 0 when the message matched (no "Desired message" difference), else 1.
        message_match = failed.count("Desired message")
        count_diff = 0  # pragma: no mutate
        if isinstance(self.count, int):
            count_diff = abs(count - self.count)
        number = (
            failed.count("Desired")
            + failed.count("Missing key")
            + failed.count("Unwanted key")
        )
        length = len(failed)
        # This will order by if the message matched then how the count differs
        # then number of incorrect bits and finally the length
        return f"{message_match}-{count_diff:04d}-{number:04d}-{length:04d}"  # pragma: no mutate # noqa: E501

    def build_message(self) -> list[str]:
        """Create failure message.

        Returns:
            An empty list when the expectations were met, otherwise the lines of
            the message to hand to `pytest.fail`.
        """
        if self.count == 0 and self.found == 0:
            return []
        if self.found == 0:
            fail_message = [
                f"No matching log record found. There were {sum(self.failed_matches.values())} log messages.",  # noqa: E501
            ]
            fail_message.append("Desired:")
            if isinstance(self.count, int):
                fail_message.append(f"Count: {self.count}")
            fail_message.append(f"Message: '{self.message}'")
            fail_message.append(f"Level: '{self.level}'")
            # NOTE: present/absent are always listed, even when left at defaults.
            fail_message.append(f"Present: '{self.present}'")
            fail_message.append(f"Absent: '{self.absent}'")
            fail_message.append("")
            if len(self.failed_matches) == 0:
                return fail_message

            fail_message.append(
                "Below is a list of the values for the selected extras for those failed matches.",  # noqa: E501
            )
            for match, count in sorted(
                self.failed_matches.items(), key=self._failed_sort_key
            ):
                msg = match
                if self.count and self.count != count:
                    # Splice the call-count note inside the list's closing "]".
                    msg = (
                        match[:-1]
                        + f', "Desired {self.count} call{"" if self.count == 1 else "s"}, actual {count} call{"" if count == 1 else "s"}"'  # noqa: E501
                        + match[-1:]
                    )
                fail_message.append(msg)
            return fail_message

        if self.count != self.found:
            return [f"Found {self.found} matching messages, {self.count} were desired"]
        return []

    @staticmethod
    def _record_level(record: logging.LogRecord) -> str:
        """Return the level name of `record`, tolerating renamed/removed attributes."""
        if "levelname" in record.__dict__:
            return record.levelname
        # Fall back to a `level` attribute when `levelname` has been removed.
        if "level" in record.__dict__:
            return record.level  # pyright: ignore[reportAttributeAccessIssue]
        # If you have removed levelname and levelno and didn't add level... good luck
        # getLevelName also resolves CRITICAL and levels registered through
        # logging.addLevelName; unknown numbers become "Level N" rather than
        # raising KeyError.
        return logging.getLevelName(record.levelno)

    def _check_record_matches(
        self,
        record: logging.LogRecord,
    ) -> list[str]:
        """Return human-readable differences between `record` and the expectations.

        An empty list means the record matches.
        """
        differences: list[str] = []
        actual_level = self._record_level(record)
        actual_msg = record.msg
        # The extras are already added as attributes, so this is the easiest way
        # to get them. There are more things in here, but that should be fine
        extra = record.__dict__

        if self.level != actual_level:
            differences.append(
                f"Desired level: {self.level}, actual level: {actual_level}",
            )
        # TODO @<crimsonknave>: Do a better string diff here # noqa: FIX002, TD003
        if self.message != actual_msg:
            differences.append(
                f"Desired message: '{self.message}', actual message: '{actual_msg}'",
            )

        actual_keys = set(extra)
        desired_keys = set(self.present)
        missing = desired_keys - actual_keys
        unwanted: set[str] = set()
        if self.absent == AnnotatedLogMock.ALL:
            unwanted = actual_keys - AnnotatedLogMock.DEFAULT_LOG_KEYS
        elif isinstance(self.absent, (set, frozenset)):
            unwanted = actual_keys & self.absent
        shared = desired_keys & actual_keys

        differences.extend(f"Missing key: `{key}`" for key in sorted(missing))
        differences.extend(f"Unwanted key: `{key}`" for key in sorted(unwanted))
        differences.extend(
            f"Extra `{key}` value is incorrect. Desired `{self.present[key]}` ({self.present[key].__class__}) , actual `{extra[key]}` ({extra[key].__class__})"  # noqa: E501
            for key in sorted(shared)
            if self.present[key] != extra[key]
        )
        return differences
class AnnotatedLogMock:
    """Mock that captures logs and provides extra assertion logic.

    It wraps a real `logging.Handler`: every record passed to `handle` is
    formatted and stored before being forwarded to the wrapped handler, and any
    other attribute access falls through to that handler.
    """

    # Sentinel for `assert_logged(absent=...)`: no key outside DEFAULT_LOG_KEYS
    # may be present on the record.
    ALL = "ALL"
    # Keys a record may carry without counting as "extras" when absent=ALL:
    # the standard logging.LogRecord attributes (including `asctime`, set by
    # time-using formatters, and `taskName`, added in Python 3.12) plus
    # `action` and `annotated` (presumably added by annotated_logger itself).
    DEFAULT_LOG_KEYS = frozenset(
        [
            "action",
            "annotated",
            "args",
            "asctime",
            "created",
            "exc_info",
            "exc_text",
            "filename",
            "funcName",
            "levelname",
            "levelno",
            "lineno",
            "message",
            "module",
            "msecs",
            "msg",
            "name",
            "pathname",
            "process",
            "processName",
            "relativeCreated",
            "stack_info",
            "taskName",
            "thread",
            "threadName",
        ]
    )

    def __init__(self, handler: logging.Handler) -> None:
        """Store the handler and initialize the messages and records lists."""
        self.messages: list[str] = []
        self.records: list[logging.LogRecord] = []
        self.handler = handler

    def __getattr__(self, name: str) -> Any:  # noqa: ANN401
        """Fall back to the real handler object.

        Only invoked for attributes not found normally. The handler is read
        straight from the instance dict so that an instance without one (for
        example one created via `__new__` by `copy.copy`) raises AttributeError
        instead of recursing until RecursionError.
        """
        try:
            handler = self.__dict__["handler"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(handler, name)

    def handle(self, record: logging.LogRecord) -> bool:
        """Wrap the real handle method, store the formatted message and log record."""
        self.messages.append(self.handler.format(record))
        self.records.append(record)
        return self.handler.handle(record)

    def assert_logged(
        self,
        level: str | pychoir.core.Matcher | None = None,
        message: str | pychoir.core.Matcher | None = None,
        present: dict[str, Any] | None = None,
        absent: str
        | set[str]
        | frozenset[str]
        | list[str]
        | tuple[str, ...]
        | None = None,
        count: int | pychoir.core.Matcher | None = None,
    ) -> None:
        """Check if the mock received a log call that matches the arguments.

        Args:
            level: Expected level name (upper-cased when a string); any level
                when omitted.
            message: Expected unformatted message; any message when omitted.
            present: Keys that must be on the record, with their expected values.
            absent: A key, or a collection of keys, that must not be on the
                record; `ALL` rejects every key not in `DEFAULT_LOG_KEYS`.
            count: Expected number of matching records; when omitted, any
                positive number (`pychoir.numeric.IsPositive`).

        Fails the current test (via `pytest.fail` in `AssertLogged.check`) when
        the captured records do not satisfy the expectations.
        """
        __tracebackhide__ = True  # pragma: no mutate
        if level is None:
            level = pychoir.existential.Anything()
        elif isinstance(level, str):
            level = level.upper()
        if message is None:
            message = pychoir.existential.Anything()
        if present is None:
            present = {}
        if absent is None:
            absent = set()
        elif isinstance(absent, str):
            # A bare key name means "this one key"; the ALL sentinel is kept as-is.
            if absent != self.ALL:
                absent = {absent}
        else:
            # Lists, tuples, sets and frozensets are copied into a fresh set so
            # AssertLogged always receives a plain set.
            absent = set(absent)
        if count is None:
            count = pychoir.numeric.IsPositive()
        assert_logged = AssertLogged(level, message, present, absent, count=count)
        assert_logged.check(self)
@pytest.fixture()
def annotated_logger_mock(
    mocker: MagicMock,
    annotated_logger_object: AnnotatedLogger,  # noqa: ARG001
) -> AnnotatedLogMock:
    """Fixture for a mock of the annotated logger.

    Patches `annotated_logger.handler` with an `AnnotatedLogMock` that wraps the
    current handler, so tests can inspect the captured records and call
    `assert_logged` on the returned object.

    `annotated_logger_object` is not used in the body; requesting it makes
    pytest set up that fixture before this one.
    """
    return mocker.patch(
        "annotated_logger.handler",
        new_callable=AnnotatedLogMock,
        handler=annotated_logger.handler,
    )