rocket/
trip_wire.rs

1use std::fmt;
2use std::{ops::Deref, pin::Pin, future::Future};
3use std::task::{Context, Poll};
4use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
5
6use tokio::sync::Notify;
7
8#[doc(hidden)]
9pub struct State {
10    tripped: AtomicBool,
11    notify: Notify,
12}
13
14#[must_use = "`TripWire` does nothing unless polled or `trip()`ed"]
15pub struct TripWire {
16    state: Arc<State>,
17    // `Notified` is `!Unpin`. Even if we could name it, we'd need to pin it.
18    event: Option<Pin<Box<dyn Future<Output = ()> + Send + Sync>>>,
19}
20
21impl Deref for TripWire {
22    type Target = State;
23
24    fn deref(&self) -> &Self::Target {
25        &self.state
26    }
27}
28
29impl Clone for TripWire {
30    fn clone(&self) -> Self {
31        TripWire {
32            state: self.state.clone(),
33            event: None
34        }
35    }
36}
37
38impl fmt::Debug for TripWire {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        f.debug_struct("TripWire")
41            .field("tripped", &self.tripped)
42            .finish()
43    }
44}
45
46impl Future for TripWire {
47    type Output = ();
48
49    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
50        if self.tripped.load(Ordering::Acquire) {
51            self.event = None;
52            return Poll::Ready(());
53        }
54
55        if self.event.is_none() {
56            let state = self.state.clone();
57            self.event = Some(Box::pin(async move {
58                let notified = state.notify.notified();
59                notified.await
60            }));
61        }
62
63        if let Some(ref mut event) = self.event {
64            if event.as_mut().poll(cx).is_ready() {
65                // We need to call `trip()` to avoid a race condition where:
66                //   1) many trip wires have seen !self.tripped but have not
67                //      polled for `self.event` yet, so are not subscribed
68                //   2) trip() is called, adding a permit to `event`
69                //   3) some trip wires poll `event` for the first time
70                //   4) one of those wins, returns `Ready()`
71                //   5) the rest return pending
72                //
73                // Without this `self.trip()` those will never be awoken. With
74                // the call to self.trip(), those that made it to poll() in 3)
75                // will be awoken by `notify_waiters()`. For those the didn't,
76                // one will be awoken by `notify_one()`, which will in-turn call
77                // self.trip(), awaking more until there are no more to awake.
78                self.trip();
79                self.event = None;
80                return Poll::Ready(());
81            }
82        }
83
84        Poll::Pending
85    }
86}
87
88impl TripWire {
89    pub fn new() -> Self {
90        TripWire {
91            state: Arc::new(State {
92                tripped: AtomicBool::new(false),
93                notify: Notify::new()
94            }),
95            event: None,
96        }
97    }
98
99    pub fn trip(&self) {
100        self.tripped.store(true, Ordering::Release);
101        self.notify.notify_waiters();
102        self.notify.notify_one();
103    }
104
105    #[inline(always)]
106    pub fn tripped(&self) -> bool {
107        self.tripped.load(Ordering::Acquire)
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::TripWire;

    /// `TripWire` must be shareable and movable across tasks and threads.
    #[test]
    fn ensure_is_send_sync_clone_unpin() {
        fn is_send_sync_clone_unpin<T: Send + Sync + Clone + Unpin>() {}
        is_send_sync_clone_unpin::<TripWire>();
    }

    /// A wire tripped before it is awaited resolves immediately.
    #[tokio::test]
    async fn simple_trip() {
        let wire = TripWire::new();
        wire.trip();
        wire.await;
    }

    /// Untripped wires stay pending: the one-second sleep must win the race.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn no_trip() {
        use tokio::time::{sleep, Duration};
        use futures::stream::{FuturesUnordered as Set, StreamExt};
        use futures::future::{BoxFuture, FutureExt};

        let wire = TripWire::new();
        let mut futs: Set<BoxFuture<'static, bool>> = Set::new();
        for _ in 0..10 {
            futs.push(Box::pin(wire.clone().map(|_| false)));
        }

        let sleep = sleep(Duration::from_secs(1));
        futs.push(Box::pin(sleep.map(|_| true)));
        assert!(futs.next().await.unwrap());
    }

    /// Many waiters on a single wire are all released by one `trip()`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 10)]
    async fn general_trip() {
        let wire = TripWire::new();
        let mut tasks = vec![];
        for _ in 0..1000 {
            tasks.push(tokio::spawn(wire.clone()));
            tokio::task::yield_now().await;
        }

        wire.trip();
        for task in tasks {
            task.await.unwrap();
        }
    }

    /// Each iteration pairs a fresh wire with exactly one waiter and one trip.
    #[tokio::test(flavor = "multi_thread", worker_threads = 10)]
    async fn single_stage_trip() {
        let mut tasks = vec![];
        for i in 0..1000 {
            // Alternate which task is spawned first, so both the "poll before
            // trip" and "trip before poll" interleavings are exercised.
            if i % 2 == 0 {
                let wire = TripWire::new();
                tasks.push(tokio::spawn(wire.clone()));
                tasks.push(tokio::spawn(async move { wire.trip() }));
            } else {
                let wire = TripWire::new();
                let wire2 = wire.clone();
                tasks.push(tokio::spawn(async move { wire.trip() }));
                tasks.push(tokio::spawn(wire2));
            }
        }

        for task in tasks {
            task.await.unwrap();
        }
    }

    /// One shared wire is tripped repeatedly while waiters are still spawned.
    #[tokio::test(flavor = "multi_thread", worker_threads = 10)]
    async fn staged_trip() {
        let wire = TripWire::new();
        let mut tasks = vec![];
        for i in 0..1050 {
            let wire = wire.clone();
            // Trip on every 100th iteration (i = 0, 100, ..., 1000). The
            // wires spawned after the final trip task (i > 1000) are never
            // followed by another `trip()`, but must still resolve.
            let task = if i % 100 == 0 {
                tokio::spawn(async move { wire.trip() })
            } else {
                tokio::spawn(wire)
            };

            if i % 20 == 0 {
                tokio::task::yield_now().await;
            }

            tasks.push(task);
        }

        for task in tasks {
            task.await.unwrap();
        }
    }
}