// classicube_helpers/async_manager.rs

1use std::{
2    cell::{Cell, RefCell},
3    future::Future,
4    pin::Pin,
5    rc::Rc,
6    sync::Mutex,
7    task::{Context, Poll, Waker},
8    time::Duration,
9};
10
11use async_dispatcher::{Dispatcher, DispatcherHandle, LocalDispatcherHandle};
12use futures::{future::Either, prelude::*};
13use futures_timer::Delay;
14use lazy_static::lazy_static;
15use tokio::task::{JoinError, JoinHandle};
16use tracing::{debug, warn, Instrument};
17
18use crate::{tick::TickEventHandler, WithInner};
19
20thread_local!(
21    static ASYNC_DISPATCHER: RefCell<Option<Dispatcher>> = RefCell::default();
22);
23
24thread_local!(
25    static ASYNC_DISPATCHER_LOCAL_HANDLE: RefCell<Option<LocalDispatcherHandle>> =
26        RefCell::default();
27);
28
29lazy_static! {
30    static ref ASYNC_DISPATCHER_HANDLE: Mutex<Option<DispatcherHandle>> = Mutex::default();
31}
32
33lazy_static! {
34    static ref TOKIO_RUNTIME: Mutex<Option<tokio::runtime::Runtime>> = Mutex::default();
35}
36
37thread_local!(
38    static TICK_HANDLER: RefCell<Option<TickEventHandler>> = RefCell::default();
39);
40
41#[tracing::instrument]
42pub fn initialize() {
43    debug!("async_manager");
44
45    {
46        debug!("async_dispatcher");
47        let async_dispatcher = Dispatcher::new();
48        let async_dispatcher_handle = async_dispatcher.get_handle();
49        ASYNC_DISPATCHER_LOCAL_HANDLE.with(|cell| {
50            *cell.borrow_mut() = Some(async_dispatcher.get_handle_local());
51        });
52        ASYNC_DISPATCHER.with(|cell| {
53            *cell.borrow_mut() = Some(async_dispatcher);
54        });
55
56        *ASYNC_DISPATCHER_HANDLE.lock().unwrap() = Some(async_dispatcher_handle);
57    }
58
59    {
60        debug!("tokio");
61        let rt = tokio::runtime::Builder::new_multi_thread()
62            .enable_all()
63            .build()
64            .unwrap();
65
66        *TOKIO_RUNTIME.lock().unwrap() = Some(rt);
67    }
68
69    {
70        debug!("tick_handler");
71        TICK_HANDLER.with(|cell| {
72            let mut tick_handler = TickEventHandler::new();
73            tick_handler.on(|_task| {
74                step();
75            });
76
77            *cell.borrow_mut() = Some(tick_handler);
78        });
79    }
80}
81
82#[tracing::instrument]
83pub fn shutdown() {
84    debug!("async_manager");
85
86    {
87        let mut option = TOKIO_RUNTIME.lock().unwrap();
88        if option.is_some() {
89            debug!("tokio");
90            if let Some(rt) = option.take() {
91                rt.shutdown_timeout(Duration::from_millis(100));
92            }
93        } else {
94            warn!("tokio already shutdown?");
95        }
96    }
97
98    {
99        if ASYNC_DISPATCHER.with_inner(|_| ()).is_some() {
100            debug!("async_dispatcher");
101
102            ASYNC_DISPATCHER_HANDLE.lock().unwrap().take();
103            ASYNC_DISPATCHER_LOCAL_HANDLE.with(|cell| cell.borrow_mut().take());
104            ASYNC_DISPATCHER.with(|cell| cell.borrow_mut().take());
105        } else {
106            warn!("async_dispatcher already shutdown?");
107        }
108    }
109
110    {
111        if TICK_HANDLER.with_inner(|_| ()).is_some() {
112            debug!("tick_handler");
113
114            TICK_HANDLER.with(|cell| cell.borrow_mut().take());
115        } else {
116            warn!("tick_handler already shutdown?");
117        }
118    }
119}
120
/// Future handed out by [`yield_now`]; resolves once [`step`] has flagged it as woken.
///
/// Clones share their state through `Rc`. One clone is parked in `YIELDED_WAKERS`, and the other
/// is awaited by the yielding task.
#[derive(Default, Clone)]
struct YieldedWaker {
    /// Waker of the task that last polled this future while it was still pending.
    waker: Rc<Cell<Option<Waker>>>,
    /// Set to `true` by [`step`] to let the awaiting task continue.
    woke: Rc<Cell<bool>>,
}

impl Future for YieldedWaker {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        if !self.woke.get() {
            // Remember who to wake; `step` takes this waker and calls `wake()` on it.
            self.waker.set(Some(cx.waker().clone()));
            return Poll::Pending;
        }
        Poll::Ready(())
    }
}

thread_local! {
    /// Yielded tasks on this thread that are waiting for the next [`step`].
    static YIELDED_WAKERS: RefCell<Vec<YieldedWaker>> = RefCell::new(Vec::new());
}
142
143pub fn step() {
144    YIELDED_WAKERS.with(move |cell| {
145        let vec = &mut *cell.borrow_mut();
146        for state in vec.drain(..) {
147            state.woke.set(true);
148            if let Some(waker) = state.waker.take() {
149                waker.wake();
150            }
151        }
152    });
153
154    // process futures
155    ASYNC_DISPATCHER
156        .with_inner_mut(|async_dispatcher| {
157            async_dispatcher.run_until_stalled();
158        })
159        .unwrap();
160}
161
162/// Run all tasks in the pool to completion.
163pub fn run() {
164    ASYNC_DISPATCHER
165        .with_inner_mut(|async_dispatcher| {
166            async_dispatcher.run();
167        })
168        .unwrap();
169}
170
171pub async fn sleep(duration: Duration) {
172    Delay::new(duration).await;
173}
174
175pub async fn yield_now() {
176    let waker = YieldedWaker::default();
177    {
178        let waker = waker.clone();
179        YIELDED_WAKERS.with(move |ref_cell| {
180            let vec = &mut *ref_cell.borrow_mut();
181            vec.push(waker);
182        });
183    }
184    waker.await;
185}
186
187pub async fn timeout<T, F>(duration: Duration, f: F) -> Option<T>
188where
189    F: Future<Output = T> + Send,
190{
191    let delay = Delay::new(duration);
192
193    match future::select(delay, f.boxed()).await {
194        Either::Left(((), _f)) => None,
195        Either::Right((r, _delay)) => Some(r),
196    }
197}
198
199pub async fn timeout_local<T, F>(duration: Duration, f: F) -> Option<T>
200where
201    F: Future<Output = T>,
202{
203    let delay = Delay::new(duration);
204
205    match future::select(delay, f.boxed_local()).await {
206        Either::Left(((), _f)) => None,
207        Either::Right((r, _delay)) => Some(r),
208    }
209}
210
211/// Block thread until future is resolved.
212///
213/// This will continue to call the same executor so `cef_step()` will still be called!
214pub fn block_on_local<F>(f: F)
215where
216    F: Future<Output = ()> + 'static,
217{
218    let shared = f.in_current_span().shared();
219
220    {
221        let shared = shared.clone();
222        spawn_local_on_main_thread(async move {
223            shared.await;
224        });
225    }
226
227    loop {
228        match shared.peek() {
229            Some(()) => {
230                return;
231            }
232
233            None => {
234                step();
235            }
236        }
237
238        // don't burn anything
239        std::thread::sleep(Duration::from_millis(16));
240    }
241}
242
243pub fn spawn<F>(f: F) -> JoinHandle<F::Output>
244where
245    F: Future + Send + 'static,
246    F::Output: Send + 'static,
247{
248    TOKIO_RUNTIME
249        .with_inner(move |rt| rt.spawn(f.in_current_span()))
250        .unwrap()
251}
252
253pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<Result<R, JoinError>>
254where
255    F: FnOnce() -> R + Send + 'static,
256    R: Send + 'static,
257{
258    let span = tracing::Span::current();
259    spawn(async move {
260        tokio::task::spawn_blocking(move || {
261            let _enter = span.enter();
262            f()
263        })
264        .await
265    })
266}
267
268pub fn spawn_on_main_thread<F>(f: F)
269where
270    F: Future<Output = ()> + 'static + Send,
271{
272    let mut handle = {
273        let mut handle = ASYNC_DISPATCHER_HANDLE.lock().unwrap();
274        handle.as_mut().expect("handle.as_mut()").clone()
275    };
276
277    handle.spawn(f.in_current_span());
278}
279
280pub async fn run_on_main_thread<F, O>(f: F) -> O
281where
282    F: Future<Output = O> + 'static + Send,
283    O: 'static + Send + std::fmt::Debug,
284{
285    let mut handle = {
286        let mut handle = ASYNC_DISPATCHER_HANDLE.lock().unwrap();
287        handle.as_mut().expect("handle.as_mut()").clone()
288    };
289
290    handle.dispatch(f.in_current_span()).await
291}
292
293pub fn spawn_local_on_main_thread<F>(f: F)
294where
295    F: Future<Output = ()> + 'static,
296{
297    let mut handle = ASYNC_DISPATCHER_LOCAL_HANDLE
298        .with_inner(Clone::clone)
299        .expect("ASYNC_DISPATCHER_LOCAL_HANDLE is None");
300
301    handle.spawn(f.in_current_span());
302}
303
304crate::test_noop_fn!(ScheduledTask_Add);
305
306#[test]
307fn test_async_manager() {
308    use tracing::info;
309    use tracing_subscriber::{filter::EnvFilter, prelude::*};
310
311    tracing_subscriber::fmt()
312        .with_env_filter(EnvFilter::from_default_env().add_directive("debug".parse().unwrap()))
313        .with_target(false)
314        .with_thread_ids(false)
315        .with_thread_names(false)
316        .with_ansi(true)
317        .without_time()
318        .finish()
319        .init();
320
321    initialize();
322
323    {
324        #[tracing::instrument]
325        fn test() {
326            let a = tracing::info_span!("A");
327            let span = a.enter();
328            spawn(async move {
329                let b = tracing::info_span!("B");
330                let span = b.enter();
331                run_on_main_thread(async move {
332                    let c = tracing::info_span!("C");
333                    let span = c.enter();
334                    info!("run_on_main_thread with instrument test:A:B:C");
335                    drop(span);
336                })
337                .await;
338                drop(span);
339            });
340            drop(span);
341        }
342        test();
343    }
344
345    {
346        #[tracing::instrument]
347        fn test() {
348            let a = tracing::info_span!("A");
349            let span = a.enter();
350            spawn_blocking(|| {
351                let b = tracing::info_span!("B");
352                let span = b.enter();
353                info!("spawn_blocking with instrument test:A:B");
354                drop(span);
355            });
356            drop(span);
357        }
358        test();
359    }
360
361    {
362        #[tracing::instrument]
363        fn test() {
364            let a = tracing::info_span!("A");
365            let span = a.enter();
366            block_on_local(async move {
367                let b = tracing::info_span!("B");
368                let span = b.enter();
369                info!("block_on_local with instrument test:A:B");
370                spawn(async move {
371                    let c = tracing::info_span!("C");
372                    let span = c.enter();
373                    info!("block_on_local spawn with instrument test:A:B:C");
374                    drop(span);
375                });
376                drop(span);
377            });
378            drop(span);
379        }
380        test();
381    }
382
383    let stopped = std::sync::Arc::new(Mutex::new(false));
384    {
385        let stopped = stopped.clone();
386        spawn(async move {
387            sleep(Duration::from_secs(1)).await;
388            *stopped.lock().unwrap() = true;
389        });
390    }
391    while !*stopped.lock().unwrap() {
392        step();
393        std::thread::sleep(Duration::from_millis(10));
394    }
395
396    shutdown();
397}