classicube_helpers/
async_manager.rs1use std::{
2 cell::{Cell, RefCell},
3 future::Future,
4 pin::Pin,
5 rc::Rc,
6 sync::Mutex,
7 task::{Context, Poll, Waker},
8 time::Duration,
9};
10
11use async_dispatcher::{Dispatcher, DispatcherHandle, LocalDispatcherHandle};
12use futures::{future::Either, prelude::*};
13use futures_timer::Delay;
14use lazy_static::lazy_static;
15use tokio::task::{JoinError, JoinHandle};
16use tracing::{debug, warn, Instrument};
17
18use crate::{tick::TickEventHandler, WithInner};
19
20thread_local!(
21 static ASYNC_DISPATCHER: RefCell<Option<Dispatcher>> = RefCell::default();
22);
23
24thread_local!(
25 static ASYNC_DISPATCHER_LOCAL_HANDLE: RefCell<Option<LocalDispatcherHandle>> =
26 RefCell::default();
27);
28
29lazy_static! {
30 static ref ASYNC_DISPATCHER_HANDLE: Mutex<Option<DispatcherHandle>> = Mutex::default();
31}
32
33lazy_static! {
34 static ref TOKIO_RUNTIME: Mutex<Option<tokio::runtime::Runtime>> = Mutex::default();
35}
36
37thread_local!(
38 static TICK_HANDLER: RefCell<Option<TickEventHandler>> = RefCell::default();
39);
40
41#[tracing::instrument]
42pub fn initialize() {
43 debug!("async_manager");
44
45 {
46 debug!("async_dispatcher");
47 let async_dispatcher = Dispatcher::new();
48 let async_dispatcher_handle = async_dispatcher.get_handle();
49 ASYNC_DISPATCHER_LOCAL_HANDLE.with(|cell| {
50 *cell.borrow_mut() = Some(async_dispatcher.get_handle_local());
51 });
52 ASYNC_DISPATCHER.with(|cell| {
53 *cell.borrow_mut() = Some(async_dispatcher);
54 });
55
56 *ASYNC_DISPATCHER_HANDLE.lock().unwrap() = Some(async_dispatcher_handle);
57 }
58
59 {
60 debug!("tokio");
61 let rt = tokio::runtime::Builder::new_multi_thread()
62 .enable_all()
63 .build()
64 .unwrap();
65
66 *TOKIO_RUNTIME.lock().unwrap() = Some(rt);
67 }
68
69 {
70 debug!("tick_handler");
71 TICK_HANDLER.with(|cell| {
72 let mut tick_handler = TickEventHandler::new();
73 tick_handler.on(|_task| {
74 step();
75 });
76
77 *cell.borrow_mut() = Some(tick_handler);
78 });
79 }
80}
81
82#[tracing::instrument]
83pub fn shutdown() {
84 debug!("async_manager");
85
86 {
87 let mut option = TOKIO_RUNTIME.lock().unwrap();
88 if option.is_some() {
89 debug!("tokio");
90 if let Some(rt) = option.take() {
91 rt.shutdown_timeout(Duration::from_millis(100));
92 }
93 } else {
94 warn!("tokio already shutdown?");
95 }
96 }
97
98 {
99 if ASYNC_DISPATCHER.with_inner(|_| ()).is_some() {
100 debug!("async_dispatcher");
101
102 ASYNC_DISPATCHER_HANDLE.lock().unwrap().take();
103 ASYNC_DISPATCHER_LOCAL_HANDLE.with(|cell| cell.borrow_mut().take());
104 ASYNC_DISPATCHER.with(|cell| cell.borrow_mut().take());
105 } else {
106 warn!("async_dispatcher already shutdown?");
107 }
108 }
109
110 {
111 if TICK_HANDLER.with_inner(|_| ()).is_some() {
112 debug!("tick_handler");
113
114 TICK_HANDLER.with(|cell| cell.borrow_mut().take());
115 } else {
116 warn!("tick_handler already shutdown?");
117 }
118 }
119}
120
/// Shared state behind a single [`yield_now`] call.
///
/// Two clones exist: one is awaited by the yielding task, and the other is
/// parked in `YIELDED_WAKERS`. `step()` sets `woke` on the parked clone and
/// wakes the recorded waker, so the awaiting clone then completes.
#[derive(Default, Clone)]
struct YieldedWaker {
    /// Waker of the task awaiting this future, refreshed on every pending poll.
    waker: Rc<Cell<Option<Waker>>>,
    /// Becomes `true` once the yield is allowed to complete.
    woke: Rc<Cell<bool>>,
}

impl Future for YieldedWaker {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if !self.woke.get() {
            // Not released yet: remember whom to wake, then park.
            self.waker.set(Some(cx.waker().clone()));
            return Poll::Pending;
        }

        Poll::Ready(())
    }
}
138
139thread_local!(
140 static YIELDED_WAKERS: RefCell<Vec<YieldedWaker>> = RefCell::default();
141);
142
143pub fn step() {
144 YIELDED_WAKERS.with(move |cell| {
145 let vec = &mut *cell.borrow_mut();
146 for state in vec.drain(..) {
147 state.woke.set(true);
148 if let Some(waker) = state.waker.take() {
149 waker.wake();
150 }
151 }
152 });
153
154 ASYNC_DISPATCHER
156 .with_inner_mut(|async_dispatcher| {
157 async_dispatcher.run_until_stalled();
158 })
159 .unwrap();
160}
161
162pub fn run() {
164 ASYNC_DISPATCHER
165 .with_inner_mut(|async_dispatcher| {
166 async_dispatcher.run();
167 })
168 .unwrap();
169}
170
171pub async fn sleep(duration: Duration) {
172 Delay::new(duration).await;
173}
174
175pub async fn yield_now() {
176 let waker = YieldedWaker::default();
177 {
178 let waker = waker.clone();
179 YIELDED_WAKERS.with(move |ref_cell| {
180 let vec = &mut *ref_cell.borrow_mut();
181 vec.push(waker);
182 });
183 }
184 waker.await;
185}
186
187pub async fn timeout<T, F>(duration: Duration, f: F) -> Option<T>
188where
189 F: Future<Output = T> + Send,
190{
191 let delay = Delay::new(duration);
192
193 match future::select(delay, f.boxed()).await {
194 Either::Left(((), _f)) => None,
195 Either::Right((r, _delay)) => Some(r),
196 }
197}
198
199pub async fn timeout_local<T, F>(duration: Duration, f: F) -> Option<T>
200where
201 F: Future<Output = T>,
202{
203 let delay = Delay::new(duration);
204
205 match future::select(delay, f.boxed_local()).await {
206 Either::Left(((), _f)) => None,
207 Either::Right((r, _delay)) => Some(r),
208 }
209}
210
211pub fn block_on_local<F>(f: F)
215where
216 F: Future<Output = ()> + 'static,
217{
218 let shared = f.in_current_span().shared();
219
220 {
221 let shared = shared.clone();
222 spawn_local_on_main_thread(async move {
223 shared.await;
224 });
225 }
226
227 loop {
228 match shared.peek() {
229 Some(()) => {
230 return;
231 }
232
233 None => {
234 step();
235 }
236 }
237
238 std::thread::sleep(Duration::from_millis(16));
240 }
241}
242
243pub fn spawn<F>(f: F) -> JoinHandle<F::Output>
244where
245 F: Future + Send + 'static,
246 F::Output: Send + 'static,
247{
248 TOKIO_RUNTIME
249 .with_inner(move |rt| rt.spawn(f.in_current_span()))
250 .unwrap()
251}
252
253pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<Result<R, JoinError>>
254where
255 F: FnOnce() -> R + Send + 'static,
256 R: Send + 'static,
257{
258 let span = tracing::Span::current();
259 spawn(async move {
260 tokio::task::spawn_blocking(move || {
261 let _enter = span.enter();
262 f()
263 })
264 .await
265 })
266}
267
268pub fn spawn_on_main_thread<F>(f: F)
269where
270 F: Future<Output = ()> + 'static + Send,
271{
272 let mut handle = {
273 let mut handle = ASYNC_DISPATCHER_HANDLE.lock().unwrap();
274 handle.as_mut().expect("handle.as_mut()").clone()
275 };
276
277 handle.spawn(f.in_current_span());
278}
279
280pub async fn run_on_main_thread<F, O>(f: F) -> O
281where
282 F: Future<Output = O> + 'static + Send,
283 O: 'static + Send + std::fmt::Debug,
284{
285 let mut handle = {
286 let mut handle = ASYNC_DISPATCHER_HANDLE.lock().unwrap();
287 handle.as_mut().expect("handle.as_mut()").clone()
288 };
289
290 handle.dispatch(f.in_current_span()).await
291}
292
293pub fn spawn_local_on_main_thread<F>(f: F)
294where
295 F: Future<Output = ()> + 'static,
296{
297 let mut handle = ASYNC_DISPATCHER_LOCAL_HANDLE
298 .with_inner(Clone::clone)
299 .expect("ASYNC_DISPATCHER_LOCAL_HANDLE is None");
300
301 handle.spawn(f.in_current_span());
302}
303
304crate::test_noop_fn!(ScheduledTask_Add);
305
/// End-to-end smoke test of the async manager.
///
/// Initializes everything, then exercises tracing-span propagation through:
/// - `spawn` → `run_on_main_thread`;
/// - `spawn_blocking`;
/// - `block_on_local`.
///
/// Each path logs an `info!` event whose message names the spans it should be
/// nested in. The test then drives `step()` on this thread for about one
/// second before shutting down. Span nesting is checked by reading the log
/// output; nothing asserts it.
#[test]
fn test_async_manager() {
    use std::sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    };

    use tracing::info;
    use tracing_subscriber::{filter::EnvFilter, prelude::*};

    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env().add_directive("debug".parse().unwrap()))
        .with_target(false)
        .with_thread_ids(false)
        .with_thread_names(false)
        .with_ansi(true)
        .without_time()
        .finish()
        .init();

    initialize();

    {
        #[tracing::instrument]
        fn test() {
            let a = tracing::info_span!("A");
            let span = a.enter();
            // Attach span B with `.instrument(..)` instead of holding an
            // `Entered` guard across the `.await` below. A guard held across an
            // await point stays entered while the task is suspended, which
            // yields incorrect traces. The guard is also `!Send`, which
            // conflicts with `spawn`'s `Send` bound.
            spawn(
                async move {
                    run_on_main_thread(async move {
                        let c = tracing::info_span!("C");
                        let span = c.enter();
                        info!("run_on_main_thread with instrument test:A:B:C");
                        drop(span);
                    })
                    .await;
                }
                .instrument(tracing::info_span!("B")),
            );
            drop(span);
        }
        test();
    }

    {
        #[tracing::instrument]
        fn test() {
            let a = tracing::info_span!("A");
            let span = a.enter();
            spawn_blocking(|| {
                let b = tracing::info_span!("B");
                let span = b.enter();
                info!("spawn_blocking with instrument test:A:B");
                drop(span);
            });
            drop(span);
        }
        test();
    }

    {
        #[tracing::instrument]
        fn test() {
            let a = tracing::info_span!("A");
            let span = a.enter();
            // No `.await` inside this future, so entering B here is fine.
            block_on_local(async move {
                let b = tracing::info_span!("B");
                let span = b.enter();
                info!("block_on_local with instrument test:A:B");
                spawn(async move {
                    let c = tracing::info_span!("C");
                    let span = c.enter();
                    info!("block_on_local spawn with instrument test:A:B:C");
                    drop(span);
                });
                drop(span);
            });
            drop(span);
        }
        test();
    }

    // Keep stepping the main-thread dispatcher until a background task flips
    // the flag after one second, so the tasks above get a chance to run.
    let stopped = Arc::new(AtomicBool::new(false));
    {
        let stopped = Arc::clone(&stopped);
        spawn(async move {
            sleep(Duration::from_secs(1)).await;
            stopped.store(true, Ordering::SeqCst);
        });
    }
    while !stopped.load(Ordering::SeqCst) {
        step();
        std::thread::sleep(Duration::from_millis(10));
    }

    shutdown();
}