Skip to content

Commit 4d914b5

Browse files
authored
Merge pull request #1851 from TheBlueMatt/2022-11-fix-broken-futures-----again
Unset the needs-notify bit in a Notifier when a Future is fetched
2 parents 8d8ee55 + 0a1e48f commit 4d914b5

File tree

1 file changed

+116
-56
lines changed

1 file changed

+116
-56
lines changed

lightning/src/util/wakers.rs

+116-56
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
1616
use alloc::sync::Arc;
1717
use core::mem;
18-
use crate::sync::{Condvar, Mutex};
18+
use crate::sync::{Condvar, Mutex, MutexGuard};
1919

2020
use crate::prelude::*;
2121

@@ -33,6 +33,20 @@ pub(crate) struct Notifier {
3333
condvar: Condvar,
3434
}
3535

36+
macro_rules! check_woken {
37+
($guard: expr, $retval: expr) => { {
38+
if $guard.0 {
39+
$guard.0 = false;
40+
if $guard.1.as_ref().map(|l| l.lock().unwrap().complete).unwrap_or(false) {
41+
// If we're about to return as woken, and the future state is marked complete, wipe
42+
// the future state and let the next future wait until we get a new notify.
43+
$guard.1.take();
44+
}
45+
return $retval;
46+
}
47+
} }
48+
}
49+
3650
impl Notifier {
3751
pub(crate) fn new() -> Self {
3852
Self {
@@ -41,45 +55,47 @@ impl Notifier {
4155
}
4256
}
4357

58+
fn propagate_future_state_to_notify_flag(&self) -> MutexGuard<(bool, Option<Arc<Mutex<FutureState>>>)> {
59+
let mut lock = self.notify_pending.lock().unwrap();
60+
if let Some(existing_state) = &lock.1 {
61+
if existing_state.lock().unwrap().callbacks_made {
62+
// If the existing `FutureState` has completed and actually made callbacks,
63+
// consider the notification flag to have been cleared and reset the future state.
64+
lock.1.take();
65+
lock.0 = false;
66+
}
67+
}
68+
lock
69+
}
70+
4471
pub(crate) fn wait(&self) {
4572
loop {
46-
let mut guard = self.notify_pending.lock().unwrap();
47-
if guard.0 {
48-
guard.0 = false;
49-
return;
50-
}
73+
let mut guard = self.propagate_future_state_to_notify_flag();
74+
check_woken!(guard, ());
5175
guard = self.condvar.wait(guard).unwrap();
52-
let result = guard.0;
53-
if result {
54-
guard.0 = false;
55-
return
56-
}
76+
check_woken!(guard, ());
5777
}
5878
}
5979

6080
#[cfg(any(test, feature = "std"))]
6181
pub(crate) fn wait_timeout(&self, max_wait: Duration) -> bool {
6282
let current_time = Instant::now();
6383
loop {
64-
let mut guard = self.notify_pending.lock().unwrap();
65-
if guard.0 {
66-
guard.0 = false;
67-
return true;
68-
}
84+
let mut guard = self.propagate_future_state_to_notify_flag();
85+
check_woken!(guard, true);
6986
guard = self.condvar.wait_timeout(guard, max_wait).unwrap().0;
87+
check_woken!(guard, true);
7088
// Due to spurious wakeups that can happen on `wait_timeout`, here we need to check if the
7189
// desired wait time has actually passed, and if not then restart the loop with a reduced wait
7290
// time. Note that this logic can be highly simplified through the use of
7391
// `Condvar::wait_while` and `Condvar::wait_timeout_while`, if and when our MSRV is raised to
7492
// 1.42.0.
7593
let elapsed = current_time.elapsed();
76-
let result = guard.0;
77-
if result || elapsed >= max_wait {
78-
guard.0 = false;
79-
return result;
94+
if elapsed >= max_wait {
95+
return false;
8096
}
8197
match max_wait.checked_sub(elapsed) {
82-
None => return result,
98+
None => return false,
8399
Some(_) => continue
84100
}
85101
}
@@ -88,17 +104,8 @@ impl Notifier {
88104
/// Wake waiters, tracking that wake needs to occur even if there are currently no waiters.
89105
pub(crate) fn notify(&self) {
90106
let mut lock = self.notify_pending.lock().unwrap();
91-
let mut future_probably_generated_calls = false;
92-
if let Some(future_state) = lock.1.take() {
93-
future_probably_generated_calls |= future_state.lock().unwrap().complete();
94-
future_probably_generated_calls |= Arc::strong_count(&future_state) > 1;
95-
}
96-
if future_probably_generated_calls {
97-
// If a future made some callbacks or has not yet been drop'd (i.e. the state has more
98-
// than the one reference we hold), assume the user was notified and skip setting the
99-
// notification-required flag. This will not cause the `wait` functions above to return
100-
// and avoid any future `Future`s starting in a completed state.
101-
return;
107+
if let Some(future_state) = &lock.1 {
108+
future_state.lock().unwrap().complete();
102109
}
103110
lock.0 = true;
104111
mem::drop(lock);
@@ -107,20 +114,14 @@ impl Notifier {
107114

108115
/// Gets a [`Future`] that will get woken up with any waiters
109116
pub(crate) fn get_future(&self) -> Future {
110-
let mut lock = self.notify_pending.lock().unwrap();
111-
if lock.0 {
112-
Future {
113-
state: Arc::new(Mutex::new(FutureState {
114-
callbacks: Vec::new(),
115-
complete: true,
116-
}))
117-
}
118-
} else if let Some(existing_state) = &lock.1 {
117+
let mut lock = self.propagate_future_state_to_notify_flag();
118+
if let Some(existing_state) = &lock.1 {
119119
Future { state: Arc::clone(&existing_state) }
120120
} else {
121121
let state = Arc::new(Mutex::new(FutureState {
122122
callbacks: Vec::new(),
123-
complete: false,
123+
complete: lock.0,
124+
callbacks_made: false,
124125
}));
125126
lock.1 = Some(Arc::clone(&state));
126127
Future { state }
@@ -151,19 +152,21 @@ impl<F: Fn() + Send> FutureCallback for F {
151152
}
152153

153154
pub(crate) struct FutureState {
154-
callbacks: Vec<Box<dyn FutureCallback>>,
155+
// When we're tracking whether a callback counts as having woken the user's code, we check the
156+
// first bool - set to false if we're just calling a Waker, and true if we're calling an actual
157+
// user-provided function.
158+
callbacks: Vec<(bool, Box<dyn FutureCallback>)>,
155159
complete: bool,
160+
callbacks_made: bool,
156161
}
157162

158163
impl FutureState {
159-
fn complete(&mut self) -> bool {
160-
let mut made_calls = false;
161-
for callback in self.callbacks.drain(..) {
164+
fn complete(&mut self) {
165+
for (counts_as_call, callback) in self.callbacks.drain(..) {
162166
callback.call();
163-
made_calls = true;
167+
self.callbacks_made |= counts_as_call;
164168
}
165169
self.complete = true;
166-
made_calls
167170
}
168171
}
169172

@@ -180,10 +183,11 @@ impl Future {
180183
pub fn register_callback(&self, callback: Box<dyn FutureCallback>) {
181184
let mut state = self.state.lock().unwrap();
182185
if state.complete {
186+
state.callbacks_made = true;
183187
mem::drop(state);
184188
callback.call();
185189
} else {
186-
state.callbacks.push(callback);
190+
state.callbacks.push((true, callback));
187191
}
188192
}
189193

@@ -198,12 +202,10 @@ impl Future {
198202
}
199203
}
200204

201-
mod std_future {
202-
use core::task::Waker;
203-
pub struct StdWaker(pub Waker);
204-
impl super::FutureCallback for StdWaker {
205-
fn call(&self) { self.0.wake_by_ref() }
206-
}
205+
use core::task::Waker;
206+
struct StdWaker(pub Waker);
207+
impl FutureCallback for StdWaker {
208+
fn call(&self) { self.0.wake_by_ref() }
207209
}
208210

209211
/// (C-not exported) as Rust Futures aren't usable in language bindings.
@@ -213,10 +215,11 @@ impl<'a> StdFuture for Future {
213215
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
214216
let mut state = self.state.lock().unwrap();
215217
if state.complete {
218+
state.callbacks_made = true;
216219
Poll::Ready(())
217220
} else {
218221
let waker = cx.waker().clone();
219-
state.callbacks.push(Box::new(std_future::StdWaker(waker)));
222+
state.callbacks.push((false, Box::new(StdWaker(waker))));
220223
Poll::Pending
221224
}
222225
}
@@ -285,6 +288,28 @@ mod tests {
285288
assert!(!callback.load(Ordering::SeqCst));
286289
}
287290

291+
#[test]
292+
fn new_future_wipes_notify_bit() {
293+
// Previously, if we were only using the `Future` interface to learn when a `Notifier` has
294+
// been notified, we'd never mark the notifier as not-awaiting-notify if a `Future` is
295+
// fetched after the notify bit has been set.
296+
let notifier = Notifier::new();
297+
notifier.notify();
298+
299+
let callback = Arc::new(AtomicBool::new(false));
300+
let callback_ref = Arc::clone(&callback);
301+
notifier.get_future().register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst))));
302+
assert!(callback.load(Ordering::SeqCst));
303+
304+
let callback = Arc::new(AtomicBool::new(false));
305+
let callback_ref = Arc::clone(&callback);
306+
notifier.get_future().register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst))));
307+
assert!(!callback.load(Ordering::SeqCst));
308+
309+
notifier.notify();
310+
assert!(callback.load(Ordering::SeqCst));
311+
}
312+
288313
#[cfg(feature = "std")]
289314
#[test]
290315
fn test_wait_timeout() {
@@ -336,6 +361,7 @@ mod tests {
336361
state: Arc::new(Mutex::new(FutureState {
337362
callbacks: Vec::new(),
338363
complete: false,
364+
callbacks_made: false,
339365
}))
340366
};
341367
let callback = Arc::new(AtomicBool::new(false));
@@ -354,6 +380,7 @@ mod tests {
354380
state: Arc::new(Mutex::new(FutureState {
355381
callbacks: Vec::new(),
356382
complete: false,
383+
callbacks_made: false,
357384
}))
358385
};
359386
future.state.lock().unwrap().complete();
@@ -391,6 +418,7 @@ mod tests {
391418
state: Arc::new(Mutex::new(FutureState {
392419
callbacks: Vec::new(),
393420
complete: false,
421+
callbacks_made: false,
394422
}))
395423
};
396424
let mut second_future = Future { state: Arc::clone(&future.state) };
@@ -409,4 +437,36 @@ mod tests {
409437
assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
410438
assert_eq!(Pin::new(&mut second_future).poll(&mut Context::from_waker(&second_waker)), Poll::Ready(()));
411439
}
440+
441+
#[test]
442+
fn test_dropped_future_doesnt_count() {
443+
// Tests that if a Future gets drop'd before it is poll()ed `Ready` it doesn't count as
444+
// having been woken, leaving the notify-required flag set.
445+
let notifier = Notifier::new();
446+
notifier.notify();
447+
448+
// If we get a future and don't touch it we're definitely still notify-required.
449+
notifier.get_future();
450+
assert!(notifier.wait_timeout(Duration::from_millis(1)));
451+
assert!(!notifier.wait_timeout(Duration::from_millis(1)));
452+
453+
// Even if we poll'd once but didn't observe a `Ready`, we should be notify-required.
454+
let mut future = notifier.get_future();
455+
let (woken, waker) = create_waker();
456+
assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);
457+
458+
notifier.notify();
459+
assert!(woken.load(Ordering::SeqCst));
460+
assert!(notifier.wait_timeout(Duration::from_millis(1)));
461+
462+
// However, once we do poll `Ready` it should wipe the notify-required flag.
463+
let mut future = notifier.get_future();
464+
let (woken, waker) = create_waker();
465+
assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);
466+
467+
notifier.notify();
468+
assert!(woken.load(Ordering::SeqCst));
469+
assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
470+
assert!(!notifier.wait_timeout(Duration::from_millis(1)));
471+
}
412472
}

0 commit comments

Comments
 (0)