Skip to content

Commit 595e47d

Browse files
rustypbzweihander
rusty
authored andcommitted
Wait tasks with waitgroup
1 parent 54dbb93 commit 595e47d

File tree

3 files changed

+21
-28
lines changed

3 files changed

+21
-28
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ route-recognizer = "0.2.0"
5151
serde = "1.0.117"
5252
serde_json = "1.0.59"
5353
stopper = "0.2.0"
54+
waitgroup = "0.1.2"
5455

5556
[dev-dependencies]
5657
async-std = { version = "1.6.5", features = ["unstable", "attributes"] }

src/listener/tcp_listener.rs

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ use async_std::prelude::*;
1010
use async_std::{io, task};
1111

1212
use futures_util::future::Either;
13-
use futures_util::stream::FuturesUnordered;
13+
14+
use waitgroup::{WaitGroup, Worker};
1415

1516
/// This represents a tide [Listener](crate::listener::Listener) that
1617
/// wraps an [async_std::net::TcpListener]. It is implemented as an
@@ -25,7 +26,6 @@ pub struct TcpListener<State> {
2526
listener: Option<net::TcpListener>,
2627
server: Option<Server<State>>,
2728
info: Option<ListenInfo>,
28-
join_handles: Vec<task::JoinHandle<()>>,
2929
}
3030

3131
impl<State> TcpListener<State> {
@@ -35,7 +35,6 @@ impl<State> TcpListener<State> {
3535
listener: None,
3636
server: None,
3737
info: None,
38-
join_handles: Vec::new(),
3938
}
4039
}
4140

@@ -45,16 +44,18 @@ impl<State> TcpListener<State> {
4544
listener: Some(tcp_listener.into()),
4645
server: None,
4746
info: None,
48-
join_handles: Vec::new(),
4947
}
5048
}
5149
}
5250

5351
fn handle_tcp<State: Clone + Send + Sync + 'static>(
5452
app: Server<State>,
5553
stream: TcpStream,
56-
) -> task::JoinHandle<()> {
54+
wait_group_worker: Worker,
55+
) {
5756
task::spawn(async move {
57+
let _wait_group_worker = wait_group_worker;
58+
5859
let local_addr = stream.local_addr().ok();
5960
let peer_addr = stream.peer_addr().ok();
6061

@@ -75,7 +76,7 @@ fn handle_tcp<State: Clone + Send + Sync + 'static>(
7576
if let Err(error) = fut.await {
7677
log::error!("async-h1 error", { error: error.to_string() });
7778
}
78-
})
79+
});
7980
}
8081

8182
#[async_trait::async_trait]
@@ -121,6 +122,7 @@ where
121122
} else {
122123
Either::Right(incoming)
123124
};
125+
let wait_group = WaitGroup::new();
124126

125127
while let Some(stream) = incoming.next().await {
126128
match stream {
@@ -133,18 +135,12 @@ where
133135
}
134136

135137
Ok(stream) => {
136-
let handle = handle_tcp(server.clone(), stream);
137-
self.join_handles.push(handle);
138+
handle_tcp(server.clone(), stream, wait_group.worker());
138139
}
139140
};
140141
}
141142

142-
let join_handles = std::mem::take(&mut self.join_handles);
143-
join_handles
144-
.into_iter()
145-
.collect::<FuturesUnordered<task::JoinHandle<()>>>()
146-
.collect::<()>()
147-
.await;
143+
wait_group.wait().await;
148144

149145
Ok(())
150146
}

src/listener/unix_listener.rs

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ use async_std::prelude::*;
1111
use async_std::{io, task};
1212

1313
use futures_util::future::Either;
14-
use futures_util::stream::FuturesUnordered;
14+
15+
use waitgroup::{WaitGroup, Worker};
1516

1617
/// This represents a tide [Listener](crate::listener::Listener) that
1718
/// wraps an [async_std::os::unix::net::UnixListener]. It is implemented as an
@@ -26,7 +27,6 @@ pub struct UnixListener<State> {
2627
listener: Option<net::UnixListener>,
2728
server: Option<Server<State>>,
2829
info: Option<ListenInfo>,
29-
join_handles: Vec<task::JoinHandle<()>>,
3030
}
3131

3232
impl<State> UnixListener<State> {
@@ -36,7 +36,6 @@ impl<State> UnixListener<State> {
3636
listener: None,
3737
server: None,
3838
info: None,
39-
join_handles: Vec::new(),
4039
}
4140
}
4241

@@ -46,16 +45,18 @@ impl<State> UnixListener<State> {
4645
listener: Some(unix_listener.into()),
4746
server: None,
4847
info: None,
49-
join_handles: Vec::new(),
5048
}
5149
}
5250
}
5351

5452
fn handle_unix<State: Clone + Send + Sync + 'static>(
5553
app: Server<State>,
5654
stream: UnixStream,
57-
) -> task::JoinHandle<()> {
55+
wait_group_worker: Worker,
56+
) {
5857
task::spawn(async move {
58+
let _wait_group_worker = wait_group_worker;
59+
5960
let local_addr = unix_socket_addr_to_string(stream.local_addr());
6061
let peer_addr = unix_socket_addr_to_string(stream.peer_addr());
6162

@@ -76,7 +77,7 @@ fn handle_unix<State: Clone + Send + Sync + 'static>(
7677
if let Err(error) = fut.await {
7778
log::error!("async-h1 error", { error: error.to_string() });
7879
}
79-
})
80+
});
8081
}
8182

8283
#[async_trait::async_trait]
@@ -119,6 +120,7 @@ where
119120
} else {
120121
Either::Right(incoming)
121122
};
123+
let wait_group = WaitGroup::new();
122124

123125
while let Some(stream) = incoming.next().await {
124126
match stream {
@@ -131,18 +133,12 @@ where
131133
}
132134

133135
Ok(stream) => {
134-
let handle = handle_unix(server.clone(), stream);
135-
self.join_handles.push(handle);
136+
handle_unix(server.clone(), stream, wait_group.worker());
136137
}
137138
};
138139
}
139140

140-
let join_handles = std::mem::take(&mut self.join_handles);
141-
join_handles
142-
.into_iter()
143-
.collect::<FuturesUnordered<task::JoinHandle<()>>>()
144-
.collect::<()>()
145-
.await;
141+
wait_group.wait().await;
146142

147143
Ok(())
148144
}

0 commit comments

Comments
 (0)