Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 19 additions & 32 deletions quinn/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::{
ConnectionEvent, Duration, Instant, VarInt,
mutex::Mutex,
recv_stream::RecvStream,
runtime::{AsyncTimer, AsyncUdpSocket, Runtime, UdpPoller},
runtime::{AsyncTimer, Runtime, UdpSender},
send_stream::SendStream,
udp_transmit,
};
Expand All @@ -43,7 +43,7 @@ impl Connecting {
conn: proto::Connection,
endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
socket: Arc<dyn AsyncUdpSocket>,
sender: Pin<Box<dyn UdpSender>>,
runtime: Arc<dyn Runtime>,
) -> Self {
let (on_handshake_data_send, on_handshake_data_recv) = oneshot::channel();
Expand All @@ -55,7 +55,7 @@ impl Connecting {
conn_events,
on_handshake_data_send,
on_connected_send,
socket,
sender,
runtime.clone(),
);

Expand Down Expand Up @@ -882,7 +882,7 @@ impl ConnectionRef {
conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
on_handshake_data: oneshot::Sender<()>,
on_connected: oneshot::Sender<bool>,
socket: Arc<dyn AsyncUdpSocket>,
sender: Pin<Box<dyn UdpSender>>,
runtime: Arc<dyn Runtime>,
) -> Self {
Self(Arc::new(ConnectionInner {
Expand All @@ -902,8 +902,7 @@ impl ConnectionRef {
stopped: FxHashMap::default(),
error: None,
ref_count: 0,
io_poller: socket.clone().create_io_poller(),
socket,
sender,
runtime,
send_buffer: Vec::new(),
buffered_transmit: None,
Expand Down Expand Up @@ -983,8 +982,7 @@ pub(crate) struct State {
pub(crate) error: Option<ConnectionError>,
/// Number of live handles that can be used to initiate or handle I/O; excludes the driver
ref_count: usize,
socket: Arc<dyn AsyncUdpSocket>,
io_poller: Pin<Box<dyn UdpPoller>>,
sender: Pin<Box<dyn UdpSender>>,
runtime: Arc<dyn Runtime>,
send_buffer: Vec<u8>,
/// We buffer a transmit when the underlying I/O would block
Expand All @@ -997,7 +995,7 @@ impl State {
let mut transmits = 0;

let max_datagrams = self
.socket
.sender
.max_transmit_segments()
.min(MAX_TRANSMIT_SEGMENTS);

Expand All @@ -1024,28 +1022,18 @@ impl State {
}
};

if self.io_poller.as_mut().poll_writable(cx)?.is_pending() {
// Retry after a future wakeup
self.buffered_transmit = Some(t);
return Ok(false);
}

let len = t.size;
let retry = match self
.socket
.try_send(&udp_transmit(&t, &self.send_buffer[..len]))
match self
.sender
.as_mut()
.poll_send(&udp_transmit(&t, &self.send_buffer[..len]), cx)
{
Ok(()) => false,
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true,
Err(e) => return Err(e),
};
if retry {
// We thought the socket was writable, but it wasn't. Retry so that either another
// `poll_writable` call determines that the socket is indeed not writable and
// registers us for a wakeup, or the send succeeds if this really was just a
// transient failure.
self.buffered_transmit = Some(t);
continue;
Poll::Pending => {
self.buffered_transmit = Some(t);
return Ok(false);
}
Poll::Ready(Err(e)) => return Err(e),
Poll::Ready(Ok(())) => {}
}

if transmits >= MAX_TRANSMIT_DATAGRAMS {
Expand Down Expand Up @@ -1075,9 +1063,8 @@ impl State {
) -> Result<(), ConnectionError> {
loop {
match self.conn_events.poll_recv(cx) {
Poll::Ready(Some(ConnectionEvent::Rebind(socket))) => {
self.socket = socket;
self.io_poller = self.socket.clone().create_io_poller();
Poll::Ready(Some(ConnectionEvent::Rebind(sender))) => {
self.sender = sender;
self.inner.local_address_changed();
}
Poll::Ready(Some(ConnectionEvent::Proto(event))) => {
Expand Down
101 changes: 70 additions & 31 deletions quinn/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,20 @@ use std::{
collections::VecDeque,
fmt,
future::Future,
io,
io::IoSliceMut,
io::{self, IoSliceMut},
mem,
net::{SocketAddr, SocketAddrV6},
pin::Pin,
str,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
};

#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))]
use crate::runtime::default_runtime;
use crate::{
Instant,
runtime::{AsyncUdpSocket, Runtime},
runtime::{AsyncUdpSocket, Runtime, UdpSender},
udp_transmit,
};
use bytes::{Bytes, BytesMut};
Expand Down Expand Up @@ -130,7 +129,7 @@ impl Endpoint {
pub fn new_with_abstract_socket(
config: EndpointConfig,
server_config: Option<ServerConfig>,
socket: Arc<dyn AsyncUdpSocket>,
socket: Box<dyn AsyncUdpSocket>,
runtime: Arc<dyn Runtime>,
) -> io::Result<Self> {
let addr = socket.local_addr()?;
Expand Down Expand Up @@ -225,12 +224,12 @@ impl Endpoint {
.inner
.connect(self.runtime.now(), config, addr, server_name)?;

let socket = endpoint.socket.clone();
let sender = endpoint.socket.create_sender();
endpoint.stats.outgoing_handshakes += 1;
Ok(endpoint
.recv_state
.connections
.insert(ch, conn, socket, self.runtime.clone()))
.insert(ch, conn, sender, self.runtime.clone()))
}

/// Switch to a new UDP socket
Expand All @@ -247,7 +246,7 @@ impl Endpoint {
/// connections and connections to servers unreachable from the new address will be lost.
///
/// On error, the old UDP socket is retained.
pub fn rebind_abstract(&self, socket: Arc<dyn AsyncUdpSocket>) -> io::Result<()> {
pub fn rebind_abstract(&self, socket: Box<dyn AsyncUdpSocket>) -> io::Result<()> {
let addr = socket.local_addr()?;
let mut inner = self.inner.state.lock().unwrap();
inner.prev_socket = Some(mem::replace(&mut inner.socket, socket));
Expand All @@ -256,7 +255,7 @@ impl Endpoint {
// Update connection socket references
for sender in inner.recv_state.connections.senders.values() {
// Ignoring errors from dropped connections
let _ = sender.send(ConnectionEvent::Rebind(inner.socket.clone()));
let _ = sender.send(ConnectionEvent::Rebind(inner.socket.create_sender()));
}
if let Some(driver) = inner.driver.take() {
// Ensure the driver can register for wake-ups from the new socket
Expand Down Expand Up @@ -425,16 +424,16 @@ impl EndpointInner {
{
Ok((handle, conn)) => {
state.stats.accepted_handshakes += 1;
let socket = state.socket.clone();
let sender = state.socket.create_sender();
let runtime = state.runtime.clone();
Ok(state
.recv_state
.connections
.insert(handle, conn, socket, runtime))
.insert(handle, conn, sender, runtime))
}
Err(error) => {
if let Some(transmit) = error.response {
respond(transmit, &response_buffer, &*state.socket);
respond(transmit, &response_buffer, &mut state.sender);
}
Err(error.cause)
}
Expand All @@ -446,14 +445,14 @@ impl EndpointInner {
state.stats.refused_handshakes += 1;
let mut response_buffer = Vec::new();
let transmit = state.inner.refuse(incoming, &mut response_buffer);
respond(transmit, &response_buffer, &*state.socket);
respond(transmit, &response_buffer, &mut state.sender);
}

pub(crate) fn retry(&self, incoming: proto::Incoming) -> Result<(), proto::RetryError> {
let mut state = self.state.lock().unwrap();
let mut response_buffer = Vec::new();
let transmit = state.inner.retry(incoming, &mut response_buffer)?;
respond(transmit, &response_buffer, &*state.socket);
respond(transmit, &response_buffer, &mut state.sender);
Ok(())
}

Expand All @@ -466,10 +465,11 @@ impl EndpointInner {

#[derive(Debug)]
pub(crate) struct State {
socket: Arc<dyn AsyncUdpSocket>,
socket: Box<dyn AsyncUdpSocket>,
sender: Pin<Box<dyn UdpSender>>,
/// During an active migration, abandoned_socket receives traffic
/// until the first packet arrives on the new socket.
prev_socket: Option<Arc<dyn AsyncUdpSocket>>,
prev_socket: Option<Box<dyn AsyncUdpSocket>>,
inner: proto::Endpoint,
recv_state: RecvState,
driver: Option<Waker>,
Expand All @@ -492,18 +492,28 @@ impl State {
fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
let get_time = || self.runtime.now();
self.recv_state.recv_limiter.start_cycle(get_time);
if let Some(socket) = &self.prev_socket {
if let Some(socket) = &mut self.prev_socket {
// We don't care about the `PollProgress` from old sockets.
let poll_res =
self.recv_state
.poll_socket(cx, &mut self.inner, &**socket, &*self.runtime, now);
let poll_res = self.recv_state.poll_socket(
cx,
&mut self.inner,
&mut **socket,
&mut self.sender,
&*self.runtime,
now,
);
if poll_res.is_err() {
self.prev_socket = None;
}
};
let poll_res =
self.recv_state
.poll_socket(cx, &mut self.inner, &*self.socket, &*self.runtime, now);
let poll_res = self.recv_state.poll_socket(
cx,
&mut self.inner,
&mut *self.socket,
&mut self.sender,
&*self.runtime,
now,
);
self.recv_state.recv_limiter.finish_cycle(get_time);
let poll_res = poll_res?;
if poll_res.received_connection_packet {
Expand Down Expand Up @@ -555,7 +565,11 @@ impl Drop for State {
}
}

fn respond(transmit: proto::Transmit, response_buffer: &[u8], socket: &dyn AsyncUdpSocket) {
fn respond(
transmit: proto::Transmit,
response_buffer: &[u8],
sender: &mut Pin<Box<dyn UdpSender>>,
) {
// Send if there's kernel buffer space; otherwise, drop it
//
// As an endpoint-generated packet, we know this is an
Expand All @@ -576,7 +590,29 @@ fn respond(transmit: proto::Transmit, response_buffer: &[u8], socket: &dyn Async
// to transmit. This is morally equivalent to the packet getting
// lost due to congestion further along the link, which
// similarly relies on peer retries for recovery.
_ = socket.try_send(&udp_transmit(&transmit, &response_buffer[..transmit.size]));

// Copied from rust 1.85's std::task::Waker::noop() implementation for backwards compatibility
const NOOP: RawWaker = {
const VTABLE: RawWakerVTable = RawWakerVTable::new(
// Cloning just returns a new no-op raw waker
|_| NOOP,
// `wake` does nothing
|_| {},
// `wake_by_ref` does nothing
|_| {},
// Dropping does nothing as we don't allocate anything
|_| {},
);
RawWaker::new(std::ptr::null(), &VTABLE)
};
// SAFETY: Copied from rust stdlib, the NOOP waker is thread-safe and doesn't violate the RawWakerVTable contract,
// it doesn't access the data pointer at all.
let waker = unsafe { Waker::from_raw(NOOP) };
let mut cx = Context::from_waker(&waker);
_ = sender.as_mut().poll_send(
&udp_transmit(&transmit, &response_buffer[..transmit.size]),
&mut cx,
);
}

#[inline]
Expand All @@ -603,7 +639,7 @@ impl ConnectionSet {
&mut self,
handle: ConnectionHandle,
conn: proto::Connection,
socket: Arc<dyn AsyncUdpSocket>,
sender: Pin<Box<dyn UdpSender>>,
runtime: Arc<dyn Runtime>,
) -> Connecting {
let (send, recv) = mpsc::unbounded_channel();
Expand All @@ -615,7 +651,7 @@ impl ConnectionSet {
.unwrap();
}
self.senders.insert(handle, send);
Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime)
Connecting::new(handle, conn, self.sender.clone(), recv, sender, runtime)
}

fn is_empty(&self) -> bool {
Expand Down Expand Up @@ -674,20 +710,22 @@ pub(crate) struct EndpointRef(Arc<EndpointInner>);

impl EndpointRef {
pub(crate) fn new(
socket: Arc<dyn AsyncUdpSocket>,
socket: Box<dyn AsyncUdpSocket>,
inner: proto::Endpoint,
ipv6: bool,
runtime: Arc<dyn Runtime>,
) -> Self {
let (sender, events) = mpsc::unbounded_channel();
let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner);
let sender = socket.create_sender();
Self(Arc::new(EndpointInner {
shared: Shared {
incoming: Notify::new(),
idle: Notify::new(),
},
state: Mutex::new(State {
socket,
sender,
prev_socket: None,
inner,
ipv6,
Expand Down Expand Up @@ -769,7 +807,8 @@ impl RecvState {
&mut self,
cx: &mut Context,
endpoint: &mut proto::Endpoint,
socket: &dyn AsyncUdpSocket,
socket: &mut dyn AsyncUdpSocket,
sender: &mut Pin<Box<dyn UdpSender>>,
runtime: &dyn Runtime,
now: Instant,
) -> Result<PollProgress, io::Error> {
Expand Down Expand Up @@ -809,7 +848,7 @@ impl RecvState {
} else {
let transmit =
endpoint.refuse(incoming, &mut response_buffer);
respond(transmit, &response_buffer, socket);
respond(transmit, &response_buffer, sender);
}
}
Some(DatagramEvent::ConnectionEvent(handle, event)) => {
Expand All @@ -823,7 +862,7 @@ impl RecvState {
.send(ConnectionEvent::Proto(event));
}
Some(DatagramEvent::Response(transmit)) => {
respond(transmit, &response_buffer, socket);
respond(transmit, &response_buffer, sender);
}
None => {}
}
Expand Down
Loading
Loading