-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathplatform.py
515 lines (433 loc) · 18.3 KB
/
platform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Awaitable, Callable, Collection
from dataclasses import dataclass
import json
import re
import ssl
import time
import typing
from typing import Any, ParamSpec, TypeVar
import httpx
from httpx import AsyncClient
from nonebot.log import logger
from nonebot_plugin_saa import PlatformTarget
from nonebot_bison.plugin_config import plugin_config
from nonebot_bison.post import Post
from nonebot_bison.types import Category, RawPost, SubUnit, Tag, Target
from nonebot_bison.utils import ProcessContext, Site
class CategoryNotSupport(Exception):
"""raise in get_category, when you know the category of the post
but don't want to support it or don't support its parsing yet
"""
class CategoryNotRecognize(Exception):
"""raise in get_category, when you don't know the category of post"""
class RegistryMeta(type):
def __new__(cls, name, bases, namespace, **kwargs):
return super().__new__(cls, name, bases, namespace)
def __init__(cls, name, bases, namespace, **kwargs):
if kwargs.get("base"):
# this is the base class
cls.registry = []
elif not kwargs.get("abstract"):
# this is the subclass
cls.registry.append(cls)
super().__init__(name, bases, namespace, **kwargs)
P = ParamSpec("P")
R = TypeVar("R")
def logger_custom_warning(_msg: str) -> None:
if plugin_config.bison_collapse_network_warning:
_msg = re.sub(r"\s*[\r\n]+\s*", "", _msg)
_max_length = plugin_config.bison_collapse_network_warning_length
logger.warning(_msg if len(_msg) < _max_length else f"{_msg[:_max_length]}...")
return None
else:
logger.warning(_msg)
return None
async def catch_network_error(func: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs) -> R | None:
if plugin_config.bison_show_network_warning:
try:
return await func(*args, **kwargs)
except httpx.RequestError as err:
return logger_custom_warning(f"network connection error: {type(err)}, url: {err.request.url}")
except ssl.SSLError as err:
return logger_custom_warning(f"ssl error: {err}")
except json.JSONDecodeError as err:
return logger_custom_warning(f"json error, parsing: {err.doc}")
except Exception as err:
return logger_custom_warning(f"unmatched exception: {err}")
else:
return await func(*args, **kwargs)
class PlatformMeta(RegistryMeta):
categories: dict[Category, str]
store: dict[Target, Any]
def __init__(cls, name, bases, namespace, **kwargs):
cls.reverse_category = {}
cls.store = {}
if hasattr(cls, "categories") and cls.categories:
for key, val in cls.categories.items():
cls.reverse_category[val] = key
super().__init__(name, bases, namespace, **kwargs)
class PlatformABCMeta(PlatformMeta, ABC): ...
class Platform(metaclass=PlatformABCMeta, base=True):
site: type[Site]
ctx: ProcessContext
is_common: bool
enabled: bool
name: str
has_target: bool
categories: dict[Category, str]
enable_tag: bool
platform_name: str
parse_target_promot: str | None = None
registry: list[type["Platform"]]
reverse_category: dict[str, Category]
use_batch: bool = False
# TODO: 限定可使用的theme名称
default_theme: str = "basic"
@classmethod
@abstractmethod
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: ...
@abstractmethod
async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]: ...
async def do_fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
return await catch_network_error(self.fetch_new_post, sub_unit) or []
@abstractmethod
async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]: ...
async def do_batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
return await catch_network_error(self.batch_fetch_new_post, sub_units) or []
@abstractmethod
async def parse(self, raw_post: RawPost) -> Post: ...
async def do_parse(self, raw_post: RawPost) -> Post:
"actually function called"
return await self.parse(raw_post)
def __init__(self, context: ProcessContext):
super().__init__()
self.ctx = context
class ParseTargetException(Exception):
def __init__(self, *args: object, prompt: str | None = None) -> None:
super().__init__(*args)
self.prompt = prompt
"""用户输入提示信息"""
@classmethod
async def parse_target(cls, target_string: str) -> Target:
return Target(target_string)
@abstractmethod
def get_tags(self, raw_post: RawPost) -> Collection[Tag] | None:
"Return Tag list of given RawPost"
@classmethod
def get_stored_data(cls, target: Target) -> Any:
return cls.store.get(target)
@classmethod
def set_stored_data(cls, target: Target, data: Any):
cls.store[target] = data
def tag_separator(self, stored_tags: list[Tag]) -> tuple[list[Tag], list[Tag]]:
"""返回分离好的正反tag元组"""
subscribed_tags = []
banned_tags = []
for tag in stored_tags:
if tag.startswith("~"):
banned_tags.append(tag.lstrip("~"))
else:
subscribed_tags.append(tag)
return subscribed_tags, banned_tags
def is_banned_post(
self,
post_tags: Collection[Tag],
subscribed_tags: list[Tag],
banned_tags: list[Tag],
) -> bool:
"""只要存在任意屏蔽tag则返回真,此行为优先级最高。
存在任意被订阅tag则返回假,此行为优先级次之。
若被订阅tag为空,则返回假。
"""
# 存在任意需要屏蔽的tag则为真
if banned_tags:
for tag in post_tags or []:
if tag in banned_tags:
return True
# 检测屏蔽tag后,再检测订阅tag
# 存在任意需要订阅的tag则为假
if subscribed_tags:
ban_it = True
for tag in post_tags or []:
if tag in subscribed_tags:
ban_it = False
return ban_it
else:
return False
async def filter_user_custom(
self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]
) -> list[RawPost]:
res: list[RawPost] = []
for raw_post in raw_post_list:
if self.categories:
cat = self.get_category(raw_post)
if cats and cat not in cats:
continue
if self.enable_tag and tags:
raw_post_tags = self.get_tags(raw_post)
if isinstance(raw_post_tags, Collection) and self.is_banned_post(
raw_post_tags, *self.tag_separator(tags)
):
continue
res.append(raw_post)
return res
async def dispatch_user_post(
self, new_posts: list[RawPost], sub_unit: SubUnit
) -> list[tuple[PlatformTarget, list[Post]]]:
res: list[tuple[PlatformTarget, list[Post]]] = []
for user, cats, required_tags in sub_unit.user_sub_infos:
user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
user_post: list[Post] = []
for raw_post in user_raw_post:
user_post.append(await self.do_parse(raw_post))
res.append((user, user_post))
return res
@abstractmethod
def get_category(self, post: RawPost) -> Category | None:
"Return category of given Rawpost"
raise NotImplementedError()
class MessageProcess(Platform, abstract=True):
"General message process fetch, parse, filter progress"
def __init__(self, ctx: ProcessContext):
super().__init__(ctx)
self.parse_cache: dict[Any, Post] = {}
@abstractmethod
def get_id(self, post: RawPost) -> Any:
"Get post id of given RawPost"
async def do_parse(self, raw_post: RawPost) -> Post:
post_id = self.get_id(raw_post)
if post_id not in self.parse_cache:
retry_times = 3
while retry_times:
try:
self.parse_cache[post_id] = await self.parse(raw_post)
break
except Exception as err:
retry_times -= 1
if not retry_times:
raise err
return self.parse_cache[post_id]
@abstractmethod
async def get_sub_list(self, target: Target) -> list[RawPost]:
"Get post list of the given target"
raise NotImplementedError()
@abstractmethod
async def batch_get_sub_list(self, targets: list[Target]) -> list[list[RawPost]]:
"Get post list of the given targets"
raise NotImplementedError()
@abstractmethod
def get_date(self, post: RawPost) -> int | None:
"Get post timestamp and return, return None if can't get the time"
async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
res = []
for raw_post in raw_post_list:
# post_id = self.get_id(raw_post)
# if post_id in exists_posts_set:
# continue
if (
(post_time := self.get_date(raw_post))
and time.time() - post_time > 2 * 60 * 60
and plugin_config.bison_init_filter
):
continue
try:
self.get_category(raw_post)
except CategoryNotSupport as e:
logger.info("未支持解析的推文类别:" + repr(e) + ",忽略")
continue
except CategoryNotRecognize as e:
logger.warning("未知推文类别:" + repr(e))
msgs = self.ctx.gen_req_records()
for m in msgs:
logger.warning(m)
continue
except NotImplementedError:
pass
res.append(raw_post)
return res
class NewMessage(MessageProcess, abstract=True):
"Fetch a list of messages, filter the new messages, dispatch it to different users"
@dataclass
class MessageStorage:
inited: bool
exists_posts: set[Any]
async def filter_common_with_diff(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]:
filtered_post = await self.filter_common(raw_post_list)
store = self.get_stored_data(target) or self.MessageStorage(False, set())
res = []
if not store.inited and plugin_config.bison_init_filter:
# target not init
for raw_post in filtered_post:
post_id = self.get_id(raw_post)
store.exists_posts.add(post_id)
logger.info(f"init {self.platform_name}-{target} with {store.exists_posts}")
store.inited = True
else:
for raw_post in filtered_post:
post_id = self.get_id(raw_post)
if post_id in store.exists_posts:
continue
res.append(raw_post)
store.exists_posts.add(post_id)
self.set_stored_data(target, store)
return res
async def _handle_new_post(
self,
post_list: list[RawPost],
sub_unit: SubUnit,
) -> list[tuple[PlatformTarget, list[Post]]]:
new_posts = await self.filter_common_with_diff(sub_unit.sub_target, post_list)
if not new_posts:
return []
else:
for post in new_posts:
logger.info(
"fetch new post from {} {}: {}".format(
self.platform_name,
sub_unit.sub_target if self.has_target else "-",
self.get_id(post),
)
)
res = await self.dispatch_user_post(new_posts, sub_unit)
self.parse_cache = {}
return res
async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
post_list = await self.get_sub_list(sub_unit.sub_target)
return await self._handle_new_post(post_list, sub_unit)
async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
if not self.has_target:
raise RuntimeError("Target without target should not use batch api") # pragma: no cover
posts_set = await self.batch_get_sub_list([x[0] for x in sub_units])
res = []
for sub_unit, posts in zip(sub_units, posts_set):
res.extend(await self._handle_new_post(posts, sub_unit))
return res
class StatusChange(Platform, abstract=True):
"Watch a status, and fire a post when status changes"
class FetchError(RuntimeError):
pass
@abstractmethod
async def get_status(self, target: Target) -> Any: ...
@abstractmethod
async def batch_get_status(self, targets: list[Target]) -> list[Any]: ...
@abstractmethod
def compare_status(self, target: Target, old_status, new_status) -> list[RawPost]: ...
@abstractmethod
async def parse(self, raw_post: RawPost) -> Post: ...
async def _handle_status_change(
self, new_status: Any, sub_unit: SubUnit
) -> list[tuple[PlatformTarget, list[Post]]]:
res = []
if old_status := self.get_stored_data(sub_unit.sub_target):
diff = self.compare_status(sub_unit.sub_target, old_status, new_status)
if diff:
logger.info(
"status changes {} {}: {} -> {}".format(
self.platform_name,
sub_unit.sub_target if self.has_target else "-",
old_status,
new_status,
)
)
res = await self.dispatch_user_post(diff, sub_unit)
self.set_stored_data(sub_unit.sub_target, new_status)
return res
async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
try:
new_status = await self.get_status(sub_unit.sub_target)
except self.FetchError as err:
logger.warning(f"fetching {self.name}-{sub_unit.sub_target} error: {err}")
raise
return await self._handle_status_change(new_status, sub_unit)
async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
if not self.has_target:
raise RuntimeError("Target without target should not use batch api") # pragma: no cover
new_statuses = await self.batch_get_status([x[0] for x in sub_units])
res = []
for sub_unit, new_status in zip(sub_units, new_statuses):
res.extend(await self._handle_status_change(new_status, sub_unit))
return res
class SimplePost(NewMessage, abstract=True):
"Fetch a list of messages, dispatch it to different users"
async def _handle_new_post(
self,
post_list: list[RawPost],
sub_unit: SubUnit,
) -> list[tuple[PlatformTarget, list[Post]]]:
if not post_list:
return []
else:
for post in post_list:
logger.info(
"fetch new post from {} {}: {}".format(
self.platform_name,
sub_unit.sub_target if self.has_target else "-",
self.get_id(post),
)
)
res = await self.dispatch_user_post(post_list, sub_unit)
self.parse_cache = {}
return res
def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]:
if typing.TYPE_CHECKING:
class NoTargetGroup(Platform, abstract=True):
platform_list: list[type[Platform]]
platform_obj_list: list[Platform]
DUMMY_STR = "_DUMMY"
platform_name = platform_list[0].platform_name
name = DUMMY_STR
categories_keys = set()
categories = {}
site = platform_list[0].site
for platform in platform_list:
if platform.has_target:
raise RuntimeError(f"Platform {platform.name} should have no target")
if name == DUMMY_STR:
name = platform.name
elif name != platform.name:
raise RuntimeError(f"Platform name for {platform_name} not fit")
platform_category_key_set = set(platform.categories.keys())
if platform_category_key_set & categories_keys:
raise RuntimeError(f"Platform categories for {platform_name} duplicate")
categories_keys |= platform_category_key_set
categories.update(platform.categories)
if platform.site != site:
raise RuntimeError(f"Platform scheduler for {platform_name} not fit")
def __init__(self: "NoTargetGroup", ctx: ProcessContext):
Platform.__init__(self, ctx)
self.platform_obj_list = []
for platform_class in self.platform_list:
self.platform_obj_list.append(platform_class(ctx))
def __str__(self: "NoTargetGroup") -> str:
return "[" + " ".join(x.name for x in self.platform_list) + "]"
@classmethod
async def get_target_name(cls, client: AsyncClient, target: Target):
return await platform_list[0].get_target_name(client, target)
async def fetch_new_post(self: "NoTargetGroup", sub_unit: SubUnit):
res = defaultdict(list)
for platform in self.platform_obj_list:
platform_res = await platform.fetch_new_post(sub_unit)
for user, posts in platform_res:
res[user].extend(posts)
return [[key, val] for key, val in res.items()]
return type(
"NoTargetGroup",
(Platform,),
{
"platform_list": platform_list,
"platform_name": platform_list[0].platform_name,
"name": name,
"categories": categories,
"site": site,
"is_common": platform_list[0].is_common,
"enabled": True,
"has_target": False,
"enable_tag": False,
"__init__": __init__,
"get_target_name": get_target_name,
"fetch_new_post": fetch_new_post,
},
abstract=True,
)