Skip to content

Commit 139c2c3

Browse files
committed
refactor: use enum for ResumptionTicketState
1 parent b67e570 commit 139c2c3

File tree

1 file changed

+32
-17
lines changed

1 file changed

+32
-17
lines changed

quinn/src/connection.rs

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ use crate::{
2525
udp_transmit,
2626
};
2727
use proto::{
28-
ConnectionError, ConnectionHandle, ConnectionStats, Dir, EndpointEvent, StreamEvent, StreamId,
29-
congestion::Controller,
28+
ConnectionError, ConnectionHandle, ConnectionStats, Dir, EndpointEvent, Side, StreamEvent,
29+
StreamId, congestion::Controller,
3030
};
3131

3232
/// In-progress connection attempt future
@@ -639,23 +639,25 @@ impl Connection {
639639

640640
/// Waits until the connection received TLS resumption tickets.
641641
///
642-
/// Completes immediately if tickets were already received. Otherwise completes
643-
/// once tickets are received.
642+
/// Returns `true` once resumption tickets were received. Resolves immediately
643+
/// if tickets were already received, otherwise it resolves once tickets arrive.
644+
/// If the server does not send any tickets, the returned future will remain pending forever.
644645
///
645-
/// Should only be used on the client. On the server, this will be pending forever.
646-
/// Will also be pending forever if the server does not send any tickets.
647-
pub fn resumption_tickets_received(&self) -> impl Future<Output = ()> + Send + 'static {
646+
/// This should only be used on the client side. On the server side, it will
647+
/// always resolve immediately and return `false`.
648+
pub fn resumption_tickets_received(&self) -> impl Future<Output = bool> + Send + 'static {
648649
let conn = self.0.state.lock("resumption_tickets_received");
649-
let mut notify = if !conn.resumption_tickets_received {
650-
Some(conn.resumption_tickets_received_notify.clone())
651-
} else {
652-
None
650+
let (mut notify, out) = match conn.resumption_tickets.as_ref() {
651+
Some(ResumptionTicketState::Received) => (None, true),
652+
Some(ResumptionTicketState::Pending(notify)) => (Some(notify.clone()), true),
653+
None => (None, false),
653654
};
654655
drop(conn);
655656
async move {
656657
if let Some(notify) = notify.take() {
657658
notify.notified().await;
658659
}
660+
out
659661
}
660662
}
661663
}
@@ -892,6 +894,10 @@ impl ConnectionRef {
892894
socket: Arc<dyn AsyncUdpSocket>,
893895
runtime: Arc<dyn Runtime>,
894896
) -> Self {
897+
let resumption_tickets = match conn.side() {
898+
Side::Client => Some(ResumptionTicketState::Pending(Default::default())),
899+
Side::Server => None,
900+
};
895901
Self(Arc::new(ConnectionInner {
896902
state: Mutex::new(State {
897903
inner: conn,
@@ -914,8 +920,7 @@ impl ConnectionRef {
914920
runtime,
915921
send_buffer: Vec::new(),
916922
buffered_transmit: None,
917-
resumption_tickets_received: false,
918-
resumption_tickets_received_notify: Arc::new(Notify::new()),
923+
resumption_tickets,
919924
}),
920925
shared: Shared::default(),
921926
}))
@@ -998,8 +1003,8 @@ pub(crate) struct State {
9981003
send_buffer: Vec<u8>,
9991004
/// We buffer a transmit when the underlying I/O would block
10001005
buffered_transmit: Option<proto::Transmit>,
1001-
resumption_tickets_received: bool,
1002-
resumption_tickets_received_notify: Arc<Notify>,
1006+
/// Whether we received resumption tickets. None on the server side.
1007+
resumption_tickets: Option<ResumptionTicketState>,
10031008
}
10041009

10051010
impl State {
@@ -1135,8 +1140,12 @@ impl State {
11351140
}
11361141
}
11371142
ResumptionTicketsReceived => {
1138-
self.resumption_tickets_received = true;
1139-
self.resumption_tickets_received_notify.notify_waiters();
1143+
if let Some(ResumptionTicketState::Pending(notify)) =
1144+
self.resumption_tickets.as_mut()
1145+
{
1146+
notify.notify_waiters();
1147+
self.resumption_tickets = Some(ResumptionTicketState::Received);
1148+
}
11401149
}
11411150
ConnectionLost { reason } => {
11421151
self.terminate(reason, shared);
@@ -1308,6 +1317,12 @@ fn wake_all_notify(wakers: &mut FxHashMap<StreamId, Arc<Notify>>) {
13081317
.for_each(|(_, notify)| notify.notify_waiters())
13091318
}
13101319

1320+
#[derive(Debug)]
1321+
enum ResumptionTicketState {
1322+
Received,
1323+
Pending(Arc<Notify>),
1324+
}
1325+
13111326
/// Errors that can arise when sending a datagram
13121327
#[derive(Debug, Error, Clone, Eq, PartialEq)]
13131328
pub enum SendDatagramError {

0 commit comments

Comments
 (0)