Skip to content

Commit 6242951

Browse files
committed
Implement Future for SpawnedTask.
It allows polling a SpawnedTask, instead of just joining it. The implementation is changed from `JoinSet` to a `JoinHandle` to simplify the code, as `JoinSet` doesn't provide any additional benefits.
1 parent 784df33 commit 6242951

File tree

1 file changed

+32
-15
lines changed

1 file changed

+32
-15
lines changed

datafusion/common-runtime/src/common.rs

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,23 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::future::Future;
18+
use std::{
19+
future::Future,
20+
pin::Pin,
21+
task::{Context, Poll},
22+
};
1923

20-
use crate::JoinSet;
21-
use tokio::task::JoinError;
24+
use tokio::task::{JoinError, JoinHandle};
2225

2326
/// Helper that provides a simple API to spawn a single task and join it.
2427
/// Provides guarantees of aborting on `Drop` to keep it cancel-safe.
28+
/// Note that if the task was spawned with `spawn_blocking`, it will only be
29+
/// aborted if it hasn't started yet.
2530
///
26-
/// Technically, it's just a wrapper of `JoinSet` (with size=1).
31+
/// Technically, it's just a wrapper of a `JoinHandle` overriding drop.
2732
#[derive(Debug)]
2833
pub struct SpawnedTask<R> {
29-
inner: JoinSet<R>,
34+
inner: JoinHandle<R>,
3035
}
3136

3237
impl<R: 'static> SpawnedTask<R> {
@@ -36,8 +41,8 @@ impl<R: 'static> SpawnedTask<R> {
3641
T: Send + 'static,
3742
R: Send,
3843
{
39-
let mut inner = JoinSet::new();
40-
inner.spawn(task);
44+
#[allow(clippy::disallowed_methods)]
45+
let inner = tokio::task::spawn(task);
4146
Self { inner }
4247
}
4348

@@ -47,22 +52,20 @@ impl<R: 'static> SpawnedTask<R> {
4752
T: Send + 'static,
4853
R: Send,
4954
{
50-
let mut inner = JoinSet::new();
51-
inner.spawn_blocking(task);
55+
#[allow(clippy::disallowed_methods)]
56+
let inner = tokio::task::spawn_blocking(task);
5257
Self { inner }
5358
}
5459

5560
/// Joins the task, returning the result of join (`Result<R, JoinError>`).
56-
pub async fn join(mut self) -> Result<R, JoinError> {
57-
self.inner
58-
.join_next()
59-
.await
60-
.expect("`SpawnedTask` instance always contains exactly 1 task")
61+
/// Same as awaiting the spawned task, but left for backwards compatibility.
62+
pub async fn join(self) -> Result<R, JoinError> {
63+
self.await
6164
}
6265

6366
/// Joins the task and unwinds the panic if it happens.
6467
pub async fn join_unwind(self) -> Result<R, JoinError> {
65-
self.join().await.map_err(|e| {
68+
self.await.map_err(|e| {
6669
// `JoinError` can be caused either by panic or cancellation. We have to handle panics:
6770
if e.is_panic() {
6871
std::panic::resume_unwind(e.into_panic());
@@ -77,6 +80,20 @@ impl<R: 'static> SpawnedTask<R> {
7780
}
7881
}
7982

83+
impl<R> Future for SpawnedTask<R> {
84+
type Output = Result<R, JoinError>;
85+
86+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
87+
Pin::new(&mut self.inner).poll(cx)
88+
}
89+
}
90+
91+
impl<R> Drop for SpawnedTask<R> {
92+
fn drop(&mut self) {
93+
self.inner.abort();
94+
}
95+
}
96+
8097
#[cfg(test)]
8198
mod tests {
8299
use super::*;

0 commit comments

Comments
 (0)