diff --git a/aiohttp_sse/__init__.py b/aiohttp_sse/__init__.py index ea36987..0a90d90 100644 --- a/aiohttp_sse/__init__.py +++ b/aiohttp_sse/__init__.py @@ -39,6 +39,7 @@ def __init__( reason: Optional[str] = None, headers: Optional[Mapping[str, str]] = None, sep: Optional[str] = None, + timeout: Optional[float] = None, ): super().__init__(status=status, reason=reason) @@ -54,6 +55,7 @@ def __init__( self._ping_interval: float = self.DEFAULT_PING_INTERVAL self._ping_task: Optional[asyncio.Task[None]] = None self._sep = sep if sep is not None else self.DEFAULT_SEPARATOR + self._timeout = timeout def is_connected(self) -> bool: """Check connection is prepared and ping task is not done.""" @@ -130,10 +132,16 @@ async def send( buffer.write(self._sep) try: - await self.write(buffer.getvalue().encode("utf-8")) + await asyncio.wait_for( # TODO(PY311): Use asyncio.timeout + self.write(buffer.getvalue().encode("utf-8")), + timeout=self._timeout, + ) except ConnectionResetError: self.stop_streaming() raise + except asyncio.TimeoutError: + self.stop_streaming() + raise TimeoutError async def wait(self) -> None: """EventSourceResponse object is used for streaming data to the client, @@ -202,8 +210,16 @@ async def _ping(self) -> None: while True: await asyncio.sleep(self._ping_interval) try: - await self.write(message) - except (ConnectionResetError, RuntimeError): + await asyncio.wait_for( # TODO(PY311): Use asyncio.timeout + self.write(message), + timeout=self._timeout, + ) + except ( + ConnectionResetError, + RuntimeError, + TimeoutError, + asyncio.TimeoutError, + ): # RuntimeError - on writing after EOF break @@ -256,6 +272,7 @@ def sse_response( headers: Optional[Mapping[str, str]] = None, sep: Optional[str] = None, response_cls: Type[EventSourceResponse] = EventSourceResponse, + timeout: Optional[float] = None, ) -> Any: if not issubclass(response_cls, EventSourceResponse): raise TypeError( @@ -263,5 +280,11 @@ def sse_response( "aiohttp_sse.EventSourceResponse, got {}".format(response_cls) ) - sse = response_cls(status=status, reason=reason, headers=headers, sep=sep) + sse = response_cls( + status=status, + reason=reason, + headers=headers, + sep=sep, + timeout=timeout, + ) return _ContextManager(sse._prepare(request)) diff --git a/tests/test_sse.py b/tests/test_sse.py index 466df76..3151c63 100644 --- a/tests/test_sse.py +++ b/tests/test_sse.py @@ -1,6 +1,6 @@ import asyncio import sys -from typing import Awaitable, Callable, List +from typing import Awaitable, Callable, List, Optional import pytest from aiohttp import web @@ -559,3 +559,47 @@ async def handler(request: web.Request) -> EventSourceResponse: async with client.get("/") as response: assert 200 == response.status + + +@pytest.mark.parametrize("timeout", (None, 0.1)) +async def test_with_timeout( + aiohttp_client: ClientFixture, + monkeypatch: pytest.MonkeyPatch, + timeout: Optional[float], +) -> None: + """Test write timeout. + + Relates to this issue: + https://github.com/sysid/sse-starlette/issues/89 + """ + timeout_raised = False + + async def frozen_write(_data: bytes) -> None: + await asyncio.sleep(42) + + async def handler(request: web.Request) -> EventSourceResponse: + sse = EventSourceResponse(timeout=timeout) + sse.ping_interval = 42 + await sse.prepare(request) + monkeypatch.setattr(sse, "write", frozen_write) + + async with sse: + try: + await sse.send("foo") + except TimeoutError: + nonlocal timeout_raised + timeout_raised = True + raise + + return sse + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + async with client.get("/") as resp: + assert resp.status == 200 + await asyncio.sleep(0.5) + assert resp.connection.closed is bool(timeout) + + assert timeout_raised is bool(timeout)