Skip to content

fix: prevent reuse of the stream after an error #7014

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ use crate::sql::Sql;
use crate::tools::time;

pub(crate) mod dns;
pub(crate) mod error_capturing_stream;
pub(crate) mod http;
pub(crate) mod proxy;
pub(crate) mod session;
pub(crate) mod tls;

use dns::lookup_host_with_cache;
pub(crate) use error_capturing_stream::ErrorCapturingStream;
pub use http::{Response as HttpResponse, read_url, read_url_blob};
use tls::wrap_tls;

Expand Down Expand Up @@ -105,7 +107,7 @@ pub(crate) async fn load_connection_timestamp(
/// to the network, which is important to reduce the latency of interactive protocols such as IMAP.
pub(crate) async fn connect_tcp_inner(
addr: SocketAddr,
) -> Result<Pin<Box<TimeoutStream<TcpStream>>>> {
) -> Result<Pin<Box<ErrorCapturingStream<TimeoutStream<TcpStream>>>>> {
let tcp_stream = timeout(TIMEOUT, TcpStream::connect(addr))
.await
.context("connection timeout")?
Expand All @@ -118,7 +120,9 @@ pub(crate) async fn connect_tcp_inner(
timeout_stream.set_write_timeout(Some(TIMEOUT));
timeout_stream.set_read_timeout(Some(TIMEOUT));

Ok(Box::pin(timeout_stream))
let error_capturing_stream = ErrorCapturingStream::new(timeout_stream);

Ok(Box::pin(error_capturing_stream))
}

/// Attempts to establish TLS connection
Expand Down Expand Up @@ -235,7 +239,7 @@ pub(crate) async fn connect_tcp(
host: &str,
port: u16,
load_cache: bool,
) -> Result<Pin<Box<TimeoutStream<TcpStream>>>> {
) -> Result<Pin<Box<ErrorCapturingStream<TimeoutStream<TcpStream>>>>> {
let connection_futures = lookup_host_with_cache(context, host, port, "", load_cache)
.await?
.into_iter()
Expand Down
136 changes: 136 additions & 0 deletions src/net/error_capturing_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
use std::io::IoSlice;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf};

use pin_project::pin_project;

use crate::net::SessionStream;

/// Stream that remembers the first error
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only remembers whether an error took place. The comment should be fixed

/// and keeps returning it afterwards.
///
/// It is needed to avoid accidentally using
/// the stream after read timeout.
#[derive(Debug)]
#[pin_project]
pub(crate) struct ErrorCapturingStream<T: AsyncRead + AsyncWrite + std::fmt::Debug> {
#[pin]
inner: T,

/// If true, the stream has already returned an error once.
///
/// All read and write operations return error in this case.
is_broken: bool,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, separate flags is_{read,write}_broken may be introduced, though this is probably not necessary. Still, this way the reader may finish reading useful responses from the server even if writing breaks

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we write some command, and writing breaks recoverably (e.g. times out) we don't want whatever response arrives to be read and misinterpreted as a response to the next command.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can't be interpreted as a response to the next command because the next command can't even be written to the socket (because is_write_broken persists). But responses to the previous commands can be read, what is wrong with them? Usually protocols allow commands to be sent in batches and responses can be processed asynchronously

}

impl<T: AsyncRead + AsyncWrite + std::fmt::Debug> ErrorCapturingStream<T> {
pub fn new(inner: T) -> Self {
Self {
inner,
is_broken: false,
}
}

/// Gets a reference to the underlying stream.
pub fn get_ref(&self) -> &T {
&self.inner
}

/// Gets a pinned mutable reference to the underlying stream.
pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
self.project().inner
}
}

impl<T: AsyncRead + AsyncWrite + std::fmt::Debug> AsyncRead for ErrorCapturingStream<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
let this = self.project();
if *this.is_broken {
return Poll::Ready(Err(io::Error::other("Broken stream")));
}
let res = this.inner.poll_read(cx, buf);
if let Poll::Ready(Err(_)) = res {
*this.is_broken = true;
}
res
}
}

impl<T: AsyncRead + AsyncWrite + std::fmt::Debug> AsyncWrite for ErrorCapturingStream<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.project();
if *this.is_broken {
return Poll::Ready(Err(io::Error::other("Broken stream")));
}
let res = this.inner.poll_write(cx, buf);
if let Poll::Ready(Err(_)) = res {
*this.is_broken = true;
}
res
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.project();
if *this.is_broken {
return Poll::Ready(Err(io::Error::other("Broken stream")));
}
let res = this.inner.poll_flush(cx);
if let Poll::Ready(Err(_)) = res {
*this.is_broken = true;
}
res
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.project();
if *this.is_broken {
return Poll::Ready(Err(io::Error::other("Broken stream")));
}
let res = this.inner.poll_shutdown(cx);
if let Poll::Ready(Err(_)) = res {
*this.is_broken = true;
}
res
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
let this = self.project();
if *this.is_broken {
return Poll::Ready(Err(io::Error::other("Broken stream")));
}
let res = this.inner.poll_write_vectored(cx, bufs);
if let Poll::Ready(Err(_)) = res {
*this.is_broken = true;
}
res
}

fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}

impl<T: SessionStream> SessionStream for ErrorCapturingStream<T> {
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
self.inner.set_read_timeout(timeout)
}

fn peer_addr(&self) -> anyhow::Result<SocketAddr> {
self.inner.peer_addr()
}
}
4 changes: 2 additions & 2 deletions src/net/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ use url::Url;
use crate::config::Config;
use crate::constants::NON_ALPHANUMERIC_WITHOUT_DOT;
use crate::context::Context;
use crate::net::connect_tcp;
use crate::net::session::SessionStream;
use crate::net::tls::wrap_rustls;
use crate::net::{ErrorCapturingStream, connect_tcp};
use crate::sql::Sql;

/// Default SOCKS5 port according to [RFC 1928](https://tools.ietf.org/html/rfc1928).
Expand Down Expand Up @@ -118,7 +118,7 @@ impl Socks5Config {
target_host: &str,
target_port: u16,
load_dns_cache: bool,
) -> Result<Socks5Stream<Pin<Box<TimeoutStream<TcpStream>>>>> {
) -> Result<Socks5Stream<Pin<Box<ErrorCapturingStream<TimeoutStream<TcpStream>>>>>> {
let tcp_stream = connect_tcp(context, &self.host, self.port, load_dns_cache)
.await
.context("Failed to connect to SOCKS5 proxy")?;
Expand Down
8 changes: 5 additions & 3 deletions src/net/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufStream, BufWriter};
use tokio::net::TcpStream;
use tokio_io_timeout::TimeoutStream;

use crate::net::ErrorCapturingStream;

pub(crate) trait SessionStream:
AsyncRead + AsyncWrite + Unpin + Send + Sync + std::fmt::Debug
{
Expand Down Expand Up @@ -61,13 +63,13 @@ impl<T: SessionStream> SessionStream for BufWriter<T> {
self.get_ref().peer_addr()
}
}
impl SessionStream for Pin<Box<TimeoutStream<TcpStream>>> {
impl SessionStream for Pin<Box<ErrorCapturingStream<TimeoutStream<TcpStream>>>> {
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
self.as_mut().set_read_timeout_pinned(timeout);
self.as_mut().get_pin_mut().set_read_timeout_pinned(timeout);
}

fn peer_addr(&self) -> Result<SocketAddr> {
Ok(self.get_ref().peer_addr()?)
Ok(self.get_ref().get_ref().peer_addr()?)
}
}
impl<T: SessionStream> SessionStream for Socks5Stream<T> {
Expand Down
Loading