From 1086033e46f09549ae854c42c09630b4ad56f722 Mon Sep 17 00:00:00 2001 From: Me Date: Wed, 1 Jan 2025 11:09:38 -0700 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=20Fixing=20mypy's=20complaints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- broadcaster/_base.py | 4 +++- broadcaster/backends/redis.py | 18 +++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/broadcaster/_base.py b/broadcaster/_base.py index 1d1de35..8c4b42a 100644 --- a/broadcaster/_base.py +++ b/broadcaster/_base.py @@ -5,12 +5,14 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, cast from urllib.parse import urlparse +from pydantic import BaseModel + if TYPE_CHECKING: # pragma: no cover from broadcaster.backends.base import BroadcastBackend class Event: - def __init__(self, channel: str, message: str) -> None: + def __init__(self, channel: str, message: str | BaseModel) -> None: self.channel = channel self.message = message diff --git a/broadcaster/backends/redis.py b/broadcaster/backends/redis.py index fa3535d..48e09d6 100644 --- a/broadcaster/backends/redis.py +++ b/broadcaster/backends/redis.py @@ -116,16 +116,16 @@ async def next_published(self) -> Event: class RedisPydanticStreamBackend(RedisStreamBackend): """Redis Stream backend for broadcasting messages using Pydantic models.""" - def __init__(self: typing.Self, url: str) -> None: + def __init__(self, url: str) -> None: """Create a new Redis Stream backend.""" url = url.replace("redis-pydantic-stream", "redis", 1) self.streams: dict[bytes | str | memoryview, int | bytes | str | memoryview] = {} self._ready = asyncio.Event() self._producer = redis.Redis.from_url(url) self._consumer = redis.Redis.from_url(url) - self._module_cache: dict[str, type(BaseModel)] = {} + self._module_cache: dict[str, type[BaseModel]] = {} - def _build_module_cache(self: typing.Self) -> None: + def _build_module_cache(self) -> None: """Build a cache of Pydantic models.""" modules = list(sys.modules.keys()) for module_name in modules: @@ -133,13 +133,17 @@ def _build_module_cache(self: typing.Self) -> None: if inspect.isclass(obj) and issubclass(obj, BaseModel): self._module_cache[obj.__name__] = obj - async def publish(self: typing.Self, channel: str, message: BaseModel) -> None: + async def publish(self, channel: str, message: BaseModel) -> None: """Publish a message to a channel.""" msg_type: str = message.__class__.__name__ + + if msg_type not in self._module_cache: + self._module_cache[msg_type] = message.__class__ + message_json: str = message.model_dump_json() await self._producer.xadd(channel, {"msg_type": msg_type, "message": message_json}) - async def wait_for_messages(self: typing.Self) -> list[StreamMessageType]: + async def wait_for_messages(self) -> list[StreamMessageType]: """Wait for messages to be published.""" await self._ready.wait() self._build_module_cache() @@ -148,7 +152,7 @@ async def wait_for_messages(self: typing.Self) -> list[StreamMessageType]: messages = await self._consumer.xread(self.streams, count=1, block=100) return messages - async def next_published(self: typing.Self) -> Event | None: + async def next_published(self) -> Event: """Get the next published message.""" messages = await self.wait_for_messages() stream, events = messages[0] @@ -160,7 +164,7 @@ async def next_published(self: typing.Self) -> Event | None: if msg_type in self._module_cache: message_obj = self._module_cache[msg_type].model_validate_json(message_data) if not message_obj: - return None + return Event(stream.decode("utf-8"), message_data) return Event( channel=stream.decode("utf-8"), message=message_obj,