1use std::fmt;
2use std::{ops::Deref, pin::Pin, future::Future};
3use std::task::{Context, Poll};
4use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
5
6use tokio::sync::Notify;
7
8#[doc(hidden)]
9pub struct State {
10 tripped: AtomicBool,
11 notify: Notify,
12}
13
14#[must_use = "`TripWire` does nothing unless polled or `trip()`ed"]
15pub struct TripWire {
16 state: Arc<State>,
17 event: Option<Pin<Box<dyn Future<Output = ()> + Send + Sync>>>,
19}
20
21impl Deref for TripWire {
22 type Target = State;
23
24 fn deref(&self) -> &Self::Target {
25 &self.state
26 }
27}
28
29impl Clone for TripWire {
30 fn clone(&self) -> Self {
31 TripWire {
32 state: self.state.clone(),
33 event: None
34 }
35 }
36}
37
38impl fmt::Debug for TripWire {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 f.debug_struct("TripWire")
41 .field("tripped", &self.tripped)
42 .finish()
43 }
44}
45
46impl Future for TripWire {
47 type Output = ();
48
49 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
50 if self.tripped.load(Ordering::Acquire) {
51 self.event = None;
52 return Poll::Ready(());
53 }
54
55 if self.event.is_none() {
56 let state = self.state.clone();
57 self.event = Some(Box::pin(async move {
58 let notified = state.notify.notified();
59 notified.await
60 }));
61 }
62
63 if let Some(ref mut event) = self.event {
64 if event.as_mut().poll(cx).is_ready() {
65 self.trip();
79 self.event = None;
80 return Poll::Ready(());
81 }
82 }
83
84 Poll::Pending
85 }
86}
87
88impl TripWire {
89 pub fn new() -> Self {
90 TripWire {
91 state: Arc::new(State {
92 tripped: AtomicBool::new(false),
93 notify: Notify::new()
94 }),
95 event: None,
96 }
97 }
98
99 pub fn trip(&self) {
100 self.tripped.store(true, Ordering::Release);
101 self.notify.notify_waiters();
102 self.notify.notify_one();
103 }
104
105 #[inline(always)]
106 pub fn tripped(&self) -> bool {
107 self.tripped.load(Ordering::Acquire)
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::TripWire;

    /// Compile-time check that `TripWire` is `Send + Sync + Clone + Unpin`.
    #[test]
    fn ensure_is_send_sync_clone_unpin() {
        fn is_send_sync_clone_unpin<T: Send + Sync + Clone + Unpin>() {}
        is_send_sync_clone_unpin::<TripWire>();
    }

    /// A wire tripped before it is awaited completes when awaited.
    #[tokio::test]
    async fn simple_trip() {
        let wire = TripWire::new();
        wire.trip();
        wire.await;
    }

    /// Untripped wires stay pending.
    ///
    /// Ten clones (each mapped to `false`) race a one-second sleep (mapped to
    /// `true`). The first future to finish must be the sleep.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn no_trip() {
        use tokio::time::{sleep, Duration};
        use futures::stream::{FuturesUnordered as Set, StreamExt};
        use futures::future::{BoxFuture, FutureExt};

        let wire = TripWire::new();
        let mut futs: Set<BoxFuture<'static, bool>> = Set::new();
        for _ in 0..10 {
            futs.push(Box::pin(wire.clone().map(|_| false)));
        }

        let sleep = sleep(Duration::from_secs(1));
        futs.push(Box::pin(sleep.map(|_| true)));
        assert!(futs.next().await.unwrap());
    }

    /// A single `trip()` completes 1000 spawned clones.
    #[tokio::test(flavor = "multi_thread", worker_threads = 10)]
    async fn general_trip() {
        let wire = TripWire::new();
        let mut tasks = vec![];
        for _ in 0..1000 {
            tasks.push(tokio::spawn(wire.clone()));
            // Yield so spawned waiters get a chance to run before the trip.
            tokio::task::yield_now().await;
        }

        wire.trip();
        for task in tasks {
            task.await.unwrap();
        }
    }

    /// Races one waiter against one trip on a fresh wire, 1000 times.
    ///
    /// The iterations alternate which of the two tasks is spawned first.
    #[tokio::test(flavor = "multi_thread", worker_threads = 10)]
    async fn single_stage_trip() {
        let mut tasks = vec![];
        for i in 0..1000 {
            if i % 2 == 0 {
                // Waiter spawned before the tripper.
                let wire = TripWire::new();
                tasks.push(tokio::spawn(wire.clone()));
                tasks.push(tokio::spawn(async move { wire.trip() }));
            } else {
                // Tripper spawned before the waiter.
                let wire = TripWire::new();
                let wire2 = wire.clone();
                tasks.push(tokio::spawn(async move { wire.trip() }));
                tasks.push(tokio::spawn(wire2));
            }
        }

        for task in tasks {
            task.await.unwrap();
        }
    }

    /// Interleaves waiters and repeated trips on one shared wire.
    ///
    /// Every 100th spawned task (starting with the first) trips the wire; the
    /// rest wait on it. Yielding every 20 iterations lets tasks interleave.
    #[tokio::test(flavor = "multi_thread", worker_threads = 10)]
    async fn staged_trip() {
        let wire = TripWire::new();
        let mut tasks = vec![];
        for i in 0..1050 {
            let wire = wire.clone();
            let task = if i % 100 == 0 {
                tokio::spawn(async move { wire.trip() })
            } else {
                tokio::spawn(wire)
            };

            if i % 20 == 0 {
                tokio::task::yield_now().await;
            }

            tasks.push(task);
        }

        for task in tasks {
            task.await.unwrap();
        }
    }
}