Skip to content

Commit ad7a7e3

Browse files
zaniebtomchristie
andauthored
Add check for h2.connection.ConnectionState.CLOSED in AsyncHTTP2Connection.is_available (#679)
* Add check for `h2.connection.ConnectionState.CLOSED` in `AsyncHTTP2Connection.is_available` * Add sync implementation * Add test for closed connection * Regenerate sync tests with `unasync` * Use async with * Add anyio annotation --------- Co-authored-by: Tom Christie <[email protected]>
1 parent 9c42d41 commit ad7a7e3

File tree

4 files changed

+76
-0
lines changed

4 files changed

+76
-0
lines changed

Diff for: httpcore/_async/http2.py

+4
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,10 @@ def is_available(self) -> bool:
433433
self._state != HTTPConnectionState.CLOSED
434434
and not self._connection_error
435435
and not self._used_all_stream_ids
436+
and not (
437+
self._h2_state.state_machine.state
438+
== h2.connection.ConnectionState.CLOSED
439+
)
436440
)
437441

438442
def has_expired(self) -> bool:

Diff for: httpcore/_sync/http2.py

+4
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,10 @@ def is_available(self) -> bool:
433433
self._state != HTTPConnectionState.CLOSED
434434
and not self._connection_error
435435
and not self._used_all_stream_ids
436+
and not (
437+
self._h2_state.state_machine.state
438+
== h2.connection.ConnectionState.CLOSED
439+
)
436440
)
437441

438442
def has_expired(self) -> bool:

Diff for: tests/_async/test_http2.py

+34
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,40 @@ async def test_http2_connection():
5252
)
5353

5454

55+
@pytest.mark.anyio
56+
async def test_http2_connection_closed():
57+
origin = Origin(b"https", b"example.com", 443)
58+
stream = AsyncMockStream(
59+
[
60+
hyperframe.frame.SettingsFrame().serialize(),
61+
hyperframe.frame.HeadersFrame(
62+
stream_id=1,
63+
data=hpack.Encoder().encode(
64+
[
65+
(b":status", b"200"),
66+
(b"content-type", b"plain/text"),
67+
]
68+
),
69+
flags=["END_HEADERS"],
70+
).serialize(),
71+
hyperframe.frame.DataFrame(
72+
stream_id=1, data=b"Hello, world!", flags=["END_STREAM"]
73+
).serialize(),
74+
# Connection is closed after the first response
75+
hyperframe.frame.GoAwayFrame(stream_id=0, error_code=0).serialize(),
76+
]
77+
)
78+
async with AsyncHTTP2Connection(
79+
origin=origin, stream=stream, keepalive_expiry=5.0
80+
) as conn:
81+
await conn.request("GET", "https://example.com/")
82+
83+
with pytest.raises(RemoteProtocolError):
84+
await conn.request("GET", "https://example.com/")
85+
86+
assert not conn.is_available()
87+
88+
5589
@pytest.mark.anyio
5690
async def test_http2_connection_post_request():
5791
origin = Origin(b"https", b"example.com", 443)

Diff for: tests/_sync/test_http2.py

+34
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,40 @@ def test_http2_connection():
5353

5454

5555

56+
def test_http2_connection_closed():
57+
origin = Origin(b"https", b"example.com", 443)
58+
stream = MockStream(
59+
[
60+
hyperframe.frame.SettingsFrame().serialize(),
61+
hyperframe.frame.HeadersFrame(
62+
stream_id=1,
63+
data=hpack.Encoder().encode(
64+
[
65+
(b":status", b"200"),
66+
(b"content-type", b"plain/text"),
67+
]
68+
),
69+
flags=["END_HEADERS"],
70+
).serialize(),
71+
hyperframe.frame.DataFrame(
72+
stream_id=1, data=b"Hello, world!", flags=["END_STREAM"]
73+
).serialize(),
74+
# Connection is closed after the first response
75+
hyperframe.frame.GoAwayFrame(stream_id=0, error_code=0).serialize(),
76+
]
77+
)
78+
with HTTP2Connection(
79+
origin=origin, stream=stream, keepalive_expiry=5.0
80+
) as conn:
81+
conn.request("GET", "https://example.com/")
82+
83+
with pytest.raises(RemoteProtocolError):
84+
conn.request("GET", "https://example.com/")
85+
86+
assert not conn.is_available()
87+
88+
89+
5690
def test_http2_connection_post_request():
5791
origin = Origin(b"https", b"example.com", 443)
5892
stream = MockStream(

0 commit comments

Comments
 (0)