Skip to content

Commit

Permalink
Merge pull request #1565 from njsmith/stop-leaking-sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
pquentin authored May 29, 2020
2 parents e7d0571 + 0376be0 commit 4081a05
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 149 deletions.
2 changes: 1 addition & 1 deletion trio/_core/tests/test_multierror.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ def run_script(name, use_ipython=False):
print("subprocess PYTHONPATH:", env.get("PYTHONPATH"))

if use_ipython:
lines = [script_path.open().read(), "exit()"]
lines = [script_path.read_text(), "exit()"]

cmd = [
sys.executable,
Expand Down
1 change: 1 addition & 0 deletions trio/_highlevel_open_tcp_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ async def open_tcp_stream(
open_ssl_over_tcp_stream
"""

# To keep our public API surface smaller, rule out some cases that
# getaddrinfo will accept in some circumstances, but that act weird or
# have non-portable behavior or are just plain not useful.
Expand Down
15 changes: 8 additions & 7 deletions trio/tests/test_highlevel_open_tcp_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,11 @@ async def check_backlog(nominal, required_min, required_max):
async def test_open_tcp_listeners_ipv6_v6only():
# Check IPV6_V6ONLY is working properly
(ipv6_listener,) = await open_tcp_listeners(0, host="::1")
_, port, *_ = ipv6_listener.socket.getsockname()
async with ipv6_listener:
_, port, *_ = ipv6_listener.socket.getsockname()

with pytest.raises(OSError):
await open_tcp_stream("127.0.0.1", port)
with pytest.raises(OSError):
await open_tcp_stream("127.0.0.1", port)


async def test_open_tcp_listeners_rebind():
Expand All @@ -127,10 +128,10 @@ async def test_open_tcp_listeners_rebind():

# Plain old rebinding while it's still there should fail, even if we have
# SO_REUSEADDR set
probe = stdlib_socket.socket()
probe.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_REUSEADDR, 1)
with pytest.raises(OSError):
probe.bind(sockaddr1)
with stdlib_socket.socket() as probe:
probe.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_REUSEADDR, 1)
with pytest.raises(OSError):
probe.bind(sockaddr1)

# Now use the first listener to set up some connections in various states,
# and make sure that they don't create any obstacle to rebinding a second
Expand Down
11 changes: 6 additions & 5 deletions trio/tests/test_highlevel_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,9 @@ async def accept(self):

async def test_socket_stream_works_when_peer_has_already_closed():
sock_a, sock_b = tsocket.socketpair()
await sock_b.send(b"x")
sock_b.close()
stream = SocketStream(sock_a)
assert await stream.receive_some(1) == b"x"
assert await stream.receive_some(1) == b""
with sock_a, sock_b:
await sock_b.send(b"x")
sock_b.close()
stream = SocketStream(sock_a)
assert await stream.receive_some(1) == b"x"
assert await stream.receive_some(1) == b""
90 changes: 47 additions & 43 deletions trio/tests/test_highlevel_ssl_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,53 +43,57 @@ async def getnameinfo(self, *args): # pragma: no cover

# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners...
# noqa is needed because flake8 doesn't understand how pytest fixtures work.
async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx,): # noqa: F811
async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx): # noqa: F811
async with trio.open_nursery() as nursery:
(listener,) = await nursery.start(
partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1",)
)
sockaddr = listener.transport_listener.socket.getsockname()
hostname_resolver = FakeHostnameResolver(sockaddr)
trio.socket.set_custom_hostname_resolver(hostname_resolver)

# We don't have the right trust set up
# (checks that ssl_context=None is doing some validation)
stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80)
with pytest.raises(trio.BrokenResourceError):
await stream.do_handshake()

# We have the trust but not the hostname
# (checks custom ssl_context + hostname checking)
stream = await open_ssl_over_tcp_stream(
"xyzzy.example.org", 80, ssl_context=client_ctx,
)
with pytest.raises(trio.BrokenResourceError):
await stream.do_handshake()

# This one should work!
stream = await open_ssl_over_tcp_stream(
"trio-test-1.example.org", 80, ssl_context=client_ctx,
)
assert isinstance(stream, trio.SSLStream)
assert stream.server_hostname == "trio-test-1.example.org"
await stream.send_all(b"x")
assert await stream.receive_some(1) == b"x"
await stream.aclose()

# Check https_compatible settings are being passed through
assert not stream._https_compatible
stream = await open_ssl_over_tcp_stream(
"trio-test-1.example.org",
80,
ssl_context=client_ctx,
https_compatible=True,
# also, smoke test happy_eyeballs_delay
happy_eyeballs_delay=1,
)
assert stream._https_compatible

# Stop the echo server
nursery.cancel_scope.cancel()
async with listener:
sockaddr = listener.transport_listener.socket.getsockname()
hostname_resolver = FakeHostnameResolver(sockaddr)
trio.socket.set_custom_hostname_resolver(hostname_resolver)

# We don't have the right trust set up
# (checks that ssl_context=None is doing some validation)
stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80)
async with stream:
with pytest.raises(trio.BrokenResourceError):
await stream.do_handshake()

# We have the trust but not the hostname
# (checks custom ssl_context + hostname checking)
stream = await open_ssl_over_tcp_stream(
"xyzzy.example.org", 80, ssl_context=client_ctx,
)
async with stream:
with pytest.raises(trio.BrokenResourceError):
await stream.do_handshake()

# This one should work!
stream = await open_ssl_over_tcp_stream(
"trio-test-1.example.org", 80, ssl_context=client_ctx,
)
async with stream:
assert isinstance(stream, trio.SSLStream)
assert stream.server_hostname == "trio-test-1.example.org"
await stream.send_all(b"x")
assert await stream.receive_some(1) == b"x"

# Check https_compatible settings are being passed through
assert not stream._https_compatible
stream = await open_ssl_over_tcp_stream(
"trio-test-1.example.org",
80,
ssl_context=client_ctx,
https_compatible=True,
# also, smoke test happy_eyeballs_delay
happy_eyeballs_delay=1,
)
async with stream:
assert stream._https_compatible

# Stop the echo server
nursery.cancel_scope.cancel()


async def test_open_ssl_over_tcp_listeners():
Expand Down
147 changes: 78 additions & 69 deletions trio/tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,9 @@ async def test_from_stdlib_socket():
class MySocket(stdlib_socket.socket):
pass

mysock = MySocket()
with pytest.raises(TypeError):
tsocket.from_stdlib_socket(mysock)
with MySocket() as mysock:
with pytest.raises(TypeError):
tsocket.from_stdlib_socket(mysock)


async def test_from_fd():
Expand Down Expand Up @@ -292,12 +292,15 @@ async def test_sniff_sockopts():
# check family / type for correctness:
assert tsocket_socket.family == socket.family
assert tsocket_socket.type == socket.type
tsocket_socket.detach()

# fromfd constructor
tsocket_from_fd = tsocket.fromfd(socket.fileno(), AF_INET, SOCK_STREAM)
# check family / type for correctness:
assert tsocket_from_fd.family == socket.family
assert tsocket_from_fd.type == socket.type
tsocket_from_fd.close()

socket.close()


Expand Down Expand Up @@ -482,73 +485,78 @@ class Addresses:
async def test_SocketType_resolve(socket_type, addrs):
v6 = socket_type == tsocket.AF_INET6

# For some reason the stdlib special-cases "" to pass NULL to getaddrinfo
# They also error out on None, but whatever, None is much more consistent,
# so we accept it too.
for null in [None, ""]:
sock = tsocket.socket(family=socket_type)
got = await sock._resolve_local_address((null, 80))
assert got == (addrs.bind_all, 80, *addrs.extra)
got = await sock._resolve_remote_address((null, 80))
assert got == (addrs.localhost, 80, *addrs.extra)

# AI_PASSIVE only affects the wildcard address, so for everything else
# _resolve_local_address and _resolve_remote_address should work the same:
for resolver in ["_resolve_local_address", "_resolve_remote_address"]:

async def res(*args):
return await getattr(sock, resolver)(*args)

assert await res((addrs.arbitrary, "http")) == (
addrs.arbitrary,
80,
*addrs.extra,
)
if v6:
assert await res(("1::2", 80, 1)) == ("1::2", 80, 1, 0)
assert await res(("1::2", 80, 1, 2)) == ("1::2", 80, 1, 2)

# V4 mapped addresses resolved if V6ONLY is False
sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, False)
assert await res(("1.2.3.4", "http")) == ("::ffff:1.2.3.4", 80, 0, 0,)

# Check the <broadcast> special case, because why not
assert await res(("<broadcast>", 123)) == (addrs.broadcast, 123, *addrs.extra,)

# But not if it's true (at least on systems where getaddrinfo works
# correctly)
if v6 and not gai_without_v4mapped_is_buggy():
sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, True)
with pytest.raises(tsocket.gaierror) as excinfo:
await res(("1.2.3.4", 80))
# Windows, macOS
expected_errnos = {tsocket.EAI_NONAME}
# Linux
if hasattr(tsocket, "EAI_ADDRFAMILY"):
expected_errnos.add(tsocket.EAI_ADDRFAMILY)
assert excinfo.value.errno in expected_errnos

# A family where we know nothing about the addresses, so should just
# pass them through. This should work on Linux, which is enough to
# smoke test the basic functionality...
try:
netlink_sock = tsocket.socket(
family=tsocket.AF_NETLINK, type=tsocket.SOCK_DGRAM
with tsocket.socket(family=socket_type) as sock:
# For some reason the stdlib special-cases "" to pass NULL to
# getaddrinfo They also error out on None, but whatever, None is much
# more consistent, so we accept it too.
for null in [None, ""]:
got = await sock._resolve_local_address((null, 80))
assert got == (addrs.bind_all, 80, *addrs.extra)
got = await sock._resolve_remote_address((null, 80))
assert got == (addrs.localhost, 80, *addrs.extra)

# AI_PASSIVE only affects the wildcard address, so for everything else
# _resolve_local_address and _resolve_remote_address should work the same:
for resolver in ["_resolve_local_address", "_resolve_remote_address"]:

async def res(*args):
return await getattr(sock, resolver)(*args)

assert await res((addrs.arbitrary, "http")) == (
addrs.arbitrary,
80,
*addrs.extra,
)
except (AttributeError, OSError):
pass
else:
assert await getattr(netlink_sock, resolver)("asdf") == "asdf"

with pytest.raises(ValueError):
await res("1.2.3.4")
with pytest.raises(ValueError):
await res(("1.2.3.4",))
with pytest.raises(ValueError):
if v6:
await res(("1.2.3.4", 80, 0, 0, 0))
assert await res(("1::2", 80, 1)) == ("1::2", 80, 1, 0)
assert await res(("1::2", 80, 1, 2)) == ("1::2", 80, 1, 2)

# V4 mapped addresses resolved if V6ONLY is False
sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, False)
assert await res(("1.2.3.4", "http")) == ("::ffff:1.2.3.4", 80, 0, 0,)

# Check the <broadcast> special case, because why not
assert await res(("<broadcast>", 123)) == (
addrs.broadcast,
123,
*addrs.extra,
)

# But not if it's true (at least on systems where getaddrinfo works
# correctly)
if v6 and not gai_without_v4mapped_is_buggy():
sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, True)
with pytest.raises(tsocket.gaierror) as excinfo:
await res(("1.2.3.4", 80))
# Windows, macOS
expected_errnos = {tsocket.EAI_NONAME}
# Linux
if hasattr(tsocket, "EAI_ADDRFAMILY"):
expected_errnos.add(tsocket.EAI_ADDRFAMILY)
assert excinfo.value.errno in expected_errnos

# A family where we know nothing about the addresses, so should just
# pass them through. This should work on Linux, which is enough to
# smoke test the basic functionality...
try:
netlink_sock = tsocket.socket(
family=tsocket.AF_NETLINK, type=tsocket.SOCK_DGRAM
)
except (AttributeError, OSError):
pass
else:
await res(("1.2.3.4", 80, 0, 0))
assert await getattr(netlink_sock, resolver)("asdf") == "asdf"
netlink_sock.close()

with pytest.raises(ValueError):
await res("1.2.3.4")
with pytest.raises(ValueError):
await res(("1.2.3.4",))
with pytest.raises(ValueError):
if v6:
await res(("1.2.3.4", 80, 0, 0, 0))
else:
await res(("1.2.3.4", 80, 0, 0))


async def test_SocketType_unresolved_names():
Expand Down Expand Up @@ -923,8 +931,9 @@ async def check_AF_UNIX(path):
with tsocket.socket(family=tsocket.AF_UNIX) as csock:
await csock.connect(path)
ssock, _ = await lsock.accept()
await csock.send(b"x")
assert await ssock.recv(1) == b"x"
with ssock:
await csock.send(b"x")
assert await ssock.recv(1) == b"x"

# Can't use tmpdir fixture, because we can exceed the maximum AF_UNIX path
# length on macOS.
Expand Down
Loading

0 comments on commit 4081a05

Please sign in to comment.