Skip to content

Commit

Permalink
feat: support timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
Olegt0rr committed Feb 12, 2024
1 parent 2cd91ea commit 8d6be2b
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 5 deletions.
31 changes: 27 additions & 4 deletions aiohttp_sse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
reason: Optional[str] = None,
headers: Optional[Mapping[str, str]] = None,
sep: Optional[str] = None,
timeout: Optional[float] = None,
):
super().__init__(status=status, reason=reason)

Expand All @@ -54,6 +55,7 @@ def __init__(
self._ping_interval: float = self.DEFAULT_PING_INTERVAL
self._ping_task: Optional[asyncio.Task[None]] = None
self._sep = sep if sep is not None else self.DEFAULT_SEPARATOR
self._timeout = timeout

def is_connected(self) -> bool:
"""Check connection is prepared and ping task is not done."""
Expand Down Expand Up @@ -130,10 +132,16 @@ async def send(

buffer.write(self._sep)
try:
await self.write(buffer.getvalue().encode("utf-8"))
await asyncio.wait_for( # TODO(PY311): Use asyncio.timeout
self.write(buffer.getvalue().encode("utf-8")),
timeout=self._timeout,
)
except ConnectionResetError:
self.stop_streaming()
raise
except asyncio.TimeoutError:
self.stop_streaming()
raise TimeoutError

async def wait(self) -> None:
"""EventSourceResponse object is used for streaming data to the client,
Expand Down Expand Up @@ -202,8 +210,16 @@ async def _ping(self) -> None:
while True:
await asyncio.sleep(self._ping_interval)
try:
await self.write(message)
except (ConnectionResetError, RuntimeError):
await asyncio.wait_for( # TODO(PY311): Use asyncio.timeout
self.write(message),
timeout=self._timeout,
)
except (
ConnectionResetError,
RuntimeError,
TimeoutError,
asyncio.TimeoutError,
):
# RuntimeError - on writing after EOF
break

Expand Down Expand Up @@ -256,12 +272,19 @@ def sse_response(
headers: Optional[Mapping[str, str]] = None,
sep: Optional[str] = None,
response_cls: Type[EventSourceResponse] = EventSourceResponse,
timeout: Optional[float] = None,
) -> Any:
if not issubclass(response_cls, EventSourceResponse):
raise TypeError(
"response_cls must be subclass of "
"aiohttp_sse.EventSourceResponse, got {}".format(response_cls)
)

sse = response_cls(status=status, reason=reason, headers=headers, sep=sep)
sse = response_cls(
status=status,
reason=reason,
headers=headers,
sep=sep,
timeout=timeout,
)
return _ContextManager(sse._prepare(request))
46 changes: 45 additions & 1 deletion tests/test_sse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import sys
from typing import Awaitable, Callable, List
from typing import Awaitable, Callable, List, Optional

import pytest
from aiohttp import web
Expand Down Expand Up @@ -559,3 +559,47 @@ async def handler(request: web.Request) -> EventSourceResponse:

async with client.get("/") as response:
assert 200 == response.status


@pytest.mark.parametrize("timeout", (None, 0.1))
async def test_with_timeout(
aiohttp_client: ClientFixture,
monkeypatch: pytest.MonkeyPatch,
timeout: Optional[float],
) -> None:
"""Test write timeout.
Relates to this issue:
https://github.com/sysid/sse-starlette/issues/89
"""
timeout_raised = False

async def frozen_write(_data: bytes) -> None:
await asyncio.sleep(42)

async def handler(request: web.Request) -> EventSourceResponse:
sse = EventSourceResponse(timeout=timeout)
sse.ping_interval = 42
await sse.prepare(request)
monkeypatch.setattr(sse, "write", frozen_write)

async with sse:
try:
await sse.send("foo")
except TimeoutError:
nonlocal timeout_raised
timeout_raised = True
raise

return sse

app = web.Application()
app.router.add_route("GET", "/", handler)

client = await aiohttp_client(app)
async with client.get("/") as resp:
assert resp.status == 200
await asyncio.sleep(0.5)
assert resp.connection.closed is bool(timeout)

assert timeout_raised is bool(timeout)

0 comments on commit 8d6be2b

Please sign in to comment.