15
15
// specific language governing permissions and limitations
16
16
// under the License.
17
17
18
- use std:: future:: Future ;
18
+ use std:: {
19
+ future:: Future ,
20
+ pin:: Pin ,
21
+ task:: { Context , Poll } ,
22
+ } ;
19
23
20
- use crate :: JoinSet ;
21
- use tokio:: task:: JoinError ;
24
+ use tokio:: task:: { JoinError , JoinHandle } ;
22
25
23
26
/// Helper that provides a simple API to spawn a single task and join it.
24
27
/// Provides guarantees of aborting on `Drop` to keep it cancel-safe.
28
+ /// Note that if the task was spawned with `spawn_blocking`, it will only be
29
+ /// aborted if it hasn't started yet.
25
30
///
26
- /// Technically, it's just a wrapper of `JoinSet` (with size=1) .
31
+ /// Technically, it's just a wrapper of a `JoinHandle` overriding drop .
27
32
#[ derive( Debug ) ]
28
33
pub struct SpawnedTask < R > {
29
- inner : JoinSet < R > ,
34
+ inner : JoinHandle < R > ,
30
35
}
31
36
32
37
impl < R : ' static > SpawnedTask < R > {
@@ -36,8 +41,8 @@ impl<R: 'static> SpawnedTask<R> {
36
41
T : Send + ' static ,
37
42
R : Send ,
38
43
{
39
- let mut inner = JoinSet :: new ( ) ;
40
- inner. spawn ( task) ;
44
+ # [ allow ( clippy :: disallowed_methods ) ]
45
+ let inner = tokio :: task :: spawn ( task) ;
41
46
Self { inner }
42
47
}
43
48
@@ -47,22 +52,20 @@ impl<R: 'static> SpawnedTask<R> {
47
52
T : Send + ' static ,
48
53
R : Send ,
49
54
{
50
- let mut inner = JoinSet :: new ( ) ;
51
- inner. spawn_blocking ( task) ;
55
+ # [ allow ( clippy :: disallowed_methods ) ]
56
+ let inner = tokio :: task :: spawn_blocking ( task) ;
52
57
Self { inner }
53
58
}
54
59
55
60
/// Joins the task, returning the result of join (`Result<R, JoinError>`).
56
- pub async fn join ( mut self ) -> Result < R , JoinError > {
57
- self . inner
58
- . join_next ( )
59
- . await
60
- . expect ( "`SpawnedTask` instance always contains exactly 1 task" )
61
+ /// Same as awaiting the spawned task, but left for backwards compatibility.
62
+ pub async fn join ( self ) -> Result < R , JoinError > {
63
+ self . await
61
64
}
62
65
63
66
/// Joins the task and unwinds the panic if it happens.
64
67
pub async fn join_unwind ( self ) -> Result < R , JoinError > {
65
- self . join ( ) . await . map_err ( |e| {
68
+ self . await . map_err ( |e| {
66
69
// `JoinError` can be caused either by panic or cancellation. We have to handle panics:
67
70
if e. is_panic ( ) {
68
71
std:: panic:: resume_unwind ( e. into_panic ( ) ) ;
@@ -77,6 +80,20 @@ impl<R: 'static> SpawnedTask<R> {
77
80
}
78
81
}
79
82
83
+ impl < R > Future for SpawnedTask < R > {
84
+ type Output = Result < R , JoinError > ;
85
+
86
+ fn poll ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
87
+ Pin :: new ( & mut self . inner ) . poll ( cx)
88
+ }
89
+ }
90
+
91
+ impl < R > Drop for SpawnedTask < R > {
92
+ fn drop ( & mut self ) {
93
+ self . inner . abort ( ) ;
94
+ }
95
+ }
96
+
80
97
#[ cfg( test) ]
81
98
mod tests {
82
99
use super :: * ;
0 commit comments