classicube_helpers/
async_manager.rs

1use std::{
2    cell::{Cell, RefCell},
3    future::Future,
4    pin::Pin,
5    rc::Rc,
6    sync::{LazyLock, Mutex},
7    task::{Context, Poll, Waker},
8    time::Duration,
9};
10
11use async_dispatcher::{Dispatcher, DispatcherHandle, LocalDispatcherHandle};
12use futures::{future::Either, prelude::*};
13use futures_timer::Delay;
14use tokio::task::{JoinError, JoinHandle};
15use tracing::{Instrument, debug, warn};
16
17use crate::{WithInner, tick::TickEventHandler};
18
19thread_local!(
20    static ASYNC_DISPATCHER: RefCell<Option<Dispatcher>> = RefCell::default();
21);
22
23thread_local!(
24    static ASYNC_DISPATCHER_LOCAL_HANDLE: RefCell<Option<LocalDispatcherHandle>> =
25        RefCell::default();
26);
27
28static ASYNC_DISPATCHER_HANDLE: LazyLock<Mutex<Option<DispatcherHandle>>> =
29    LazyLock::new(Mutex::default);
30
31static TOKIO_RUNTIME: LazyLock<Mutex<Option<tokio::runtime::Runtime>>> =
32    LazyLock::new(Mutex::default);
33
34thread_local!(
35    static TICK_HANDLER: RefCell<Option<TickEventHandler>> = RefCell::default();
36);
37
38#[tracing::instrument]
39pub fn initialize() {
40    debug!("async_manager");
41
42    {
43        debug!("async_dispatcher");
44        let async_dispatcher = Dispatcher::new();
45        let async_dispatcher_handle = async_dispatcher.get_handle();
46        ASYNC_DISPATCHER_LOCAL_HANDLE.with(|cell| {
47            *cell.borrow_mut() = Some(async_dispatcher.get_handle_local());
48        });
49        ASYNC_DISPATCHER.with(|cell| {
50            *cell.borrow_mut() = Some(async_dispatcher);
51        });
52
53        *ASYNC_DISPATCHER_HANDLE.lock().unwrap() = Some(async_dispatcher_handle);
54    }
55
56    {
57        debug!("tokio");
58        let rt = tokio::runtime::Builder::new_multi_thread()
59            .enable_all()
60            .build()
61            .unwrap();
62
63        *TOKIO_RUNTIME.lock().unwrap() = Some(rt);
64    }
65
66    {
67        debug!("tick_handler");
68        TICK_HANDLER.with(|cell| {
69            let mut tick_handler = TickEventHandler::new();
70            tick_handler.on(|_task| {
71                step();
72            });
73
74            *cell.borrow_mut() = Some(tick_handler);
75        });
76    }
77}
78
79#[tracing::instrument]
80pub fn shutdown() {
81    debug!("async_manager");
82
83    {
84        let mut option = TOKIO_RUNTIME.lock().unwrap();
85        if option.is_some() {
86            debug!("tokio");
87            if let Some(rt) = option.take() {
88                rt.shutdown_timeout(Duration::from_millis(100));
89            }
90        } else {
91            warn!("tokio already shutdown?");
92        }
93    }
94
95    {
96        if ASYNC_DISPATCHER.with_inner(|_| ()).is_some() {
97            debug!("async_dispatcher");
98
99            ASYNC_DISPATCHER_HANDLE.lock().unwrap().take();
100            ASYNC_DISPATCHER_LOCAL_HANDLE.with(|cell| cell.borrow_mut().take());
101            ASYNC_DISPATCHER.with(|cell| cell.borrow_mut().take());
102        } else {
103            warn!("async_dispatcher already shutdown?");
104        }
105    }
106
107    {
108        if TICK_HANDLER.with_inner(|_| ()).is_some() {
109            debug!("tick_handler");
110
111            TICK_HANDLER.with(|cell| cell.borrow_mut().take());
112        } else {
113            warn!("tick_handler already shutdown?");
114        }
115    }
116}
117
/// Future behind `yield_now`: stays pending until its shared `woke` flag is set.
///
/// Clones share both cells, so one clone can be awaited while another sits in
/// a queue waiting to be released.
#[derive(Default, Clone)]
struct YieldedWaker {
    /// Waker recorded by the most recent pending poll.
    waker: Rc<Cell<Option<Waker>>>,
    /// Becomes `true` once the yield has been released.
    woke: Rc<Cell<bool>>,
}

impl Future for YieldedWaker {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if !self.woke.get() {
            // Not released yet: remember who to notify, replacing any older waker.
            self.waker.set(Some(cx.waker().clone()));
            return Poll::Pending;
        }

        Poll::Ready(())
    }
}
135
136thread_local!(
137    static YIELDED_WAKERS: RefCell<Vec<YieldedWaker>> = RefCell::default();
138);
139
140pub fn step() {
141    YIELDED_WAKERS.with(move |cell| {
142        let vec = &mut *cell.borrow_mut();
143        for state in vec.drain(..) {
144            state.woke.set(true);
145            if let Some(waker) = state.waker.take() {
146                waker.wake();
147            }
148        }
149    });
150
151    // process futures
152    ASYNC_DISPATCHER
153        .with_inner_mut(|async_dispatcher| {
154            async_dispatcher.run_until_stalled();
155        })
156        .unwrap();
157}
158
159/// Run all tasks in the pool to completion.
160pub fn run() {
161    ASYNC_DISPATCHER
162        .with_inner_mut(|async_dispatcher| {
163            async_dispatcher.run();
164        })
165        .unwrap();
166}
167
168pub async fn sleep(duration: Duration) {
169    Delay::new(duration).await;
170}
171
172pub async fn yield_now() {
173    let waker = YieldedWaker::default();
174    {
175        let waker = waker.clone();
176        YIELDED_WAKERS.with(move |ref_cell| {
177            let vec = &mut *ref_cell.borrow_mut();
178            vec.push(waker);
179        });
180    }
181    waker.await;
182}
183
184pub async fn timeout<T, F>(duration: Duration, f: F) -> Option<T>
185where
186    F: Future<Output = T> + Send,
187{
188    let delay = Delay::new(duration);
189
190    match future::select(delay, f.boxed()).await {
191        Either::Left(((), _f)) => None,
192        Either::Right((r, _delay)) => Some(r),
193    }
194}
195
196pub async fn timeout_local<T, F>(duration: Duration, f: F) -> Option<T>
197where
198    F: Future<Output = T>,
199{
200    let delay = Delay::new(duration);
201
202    match future::select(delay, f.boxed_local()).await {
203        Either::Left(((), _f)) => None,
204        Either::Right((r, _delay)) => Some(r),
205    }
206}
207
208/// Block thread until future is resolved.
209///
210/// This will continue to call the same executor so `cef_step()` will still be called!
211pub fn block_on_local<F>(f: F)
212where
213    F: Future<Output = ()> + 'static,
214{
215    let shared = f.in_current_span().shared();
216
217    {
218        let shared = shared.clone();
219        spawn_local_on_main_thread(async move {
220            shared.await;
221        });
222    }
223
224    loop {
225        match shared.peek() {
226            Some(()) => {
227                return;
228            }
229
230            None => {
231                step();
232            }
233        }
234
235        // don't burn anything
236        std::thread::sleep(Duration::from_millis(16));
237    }
238}
239
240pub fn spawn<F>(f: F) -> JoinHandle<F::Output>
241where
242    F: Future + Send + 'static,
243    F::Output: Send + 'static,
244{
245    TOKIO_RUNTIME
246        .with_inner(move |rt| rt.spawn(f.in_current_span()))
247        .unwrap()
248}
249
250pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<Result<R, JoinError>>
251where
252    F: FnOnce() -> R + Send + 'static,
253    R: Send + 'static,
254{
255    let span = tracing::Span::current();
256    spawn(async move {
257        tokio::task::spawn_blocking(move || {
258            let _enter = span.enter();
259            f()
260        })
261        .await
262    })
263}
264
265pub fn spawn_on_main_thread<F>(f: F)
266where
267    F: Future<Output = ()> + 'static + Send,
268{
269    let mut handle = {
270        let mut handle = ASYNC_DISPATCHER_HANDLE.lock().unwrap();
271        handle.as_mut().expect("handle.as_mut()").clone()
272    };
273
274    handle.spawn(f.in_current_span());
275}
276
277pub async fn run_on_main_thread<F, O>(f: F) -> O
278where
279    F: Future<Output = O> + 'static + Send,
280    O: 'static + Send + std::fmt::Debug,
281{
282    let mut handle = {
283        let mut handle = ASYNC_DISPATCHER_HANDLE.lock().unwrap();
284        handle.as_mut().expect("handle.as_mut()").clone()
285    };
286
287    handle.dispatch(f.in_current_span()).await
288}
289
290pub fn spawn_local_on_main_thread<F>(f: F)
291where
292    F: Future<Output = ()> + 'static,
293{
294    let mut handle = ASYNC_DISPATCHER_LOCAL_HANDLE
295        .with_inner(Clone::clone)
296        .expect("ASYNC_DISPATCHER_LOCAL_HANDLE is None");
297
298    handle.spawn(f.in_current_span());
299}
300
301crate::test_noop_fn!(ScheduledTask_Add);
302
/// End-to-end smoke test: checks that tracing spans propagate through `spawn`,
/// `run_on_main_thread`, `spawn_blocking` and `block_on_local`, then steps the
/// dispatcher until a task spawned on tokio has flipped a shared flag.
#[test]
fn test_async_manager() {
    use tracing::info;
    use tracing_subscriber::{filter::EnvFilter, prelude::*};

    let env_filter = EnvFilter::from_default_env().add_directive("debug".parse().unwrap());
    tracing_subscriber::fmt()
        .with_env_filter(env_filter)
        .with_target(false)
        .with_thread_ids(false)
        .with_thread_names(false)
        .with_ansi(true)
        .without_time()
        .finish()
        .init();

    initialize();

    // spawn -> run_on_main_thread: expected span chain test:A:B:C
    {
        #[tracing::instrument]
        fn test() {
            let span_a = tracing::info_span!("A");
            let _in_a = span_a.enter();
            spawn(async move {
                let span_b = tracing::info_span!("B");
                let _in_b = span_b.enter();
                run_on_main_thread(async move {
                    let span_c = tracing::info_span!("C");
                    let _in_c = span_c.enter();
                    info!("run_on_main_thread with instrument test:A:B:C");
                })
                .await;
            });
        }
        test();
    }

    // spawn_blocking: expected span chain test:A:B
    {
        #[tracing::instrument]
        fn test() {
            tracing::info_span!("A").in_scope(|| {
                spawn_blocking(|| {
                    let span_b = tracing::info_span!("B");
                    let _in_b = span_b.enter();
                    info!("spawn_blocking with instrument test:A:B");
                });
            });
        }
        test();
    }

    // block_on_local, with a nested spawn: expected span chains test:A:B and test:A:B:C
    {
        #[tracing::instrument]
        fn test() {
            tracing::info_span!("A").in_scope(|| {
                block_on_local(async move {
                    let span_b = tracing::info_span!("B");
                    let _in_b = span_b.enter();
                    info!("block_on_local with instrument test:A:B");
                    spawn(async move {
                        let span_c = tracing::info_span!("C");
                        let _in_c = span_c.enter();
                        info!("block_on_local spawn with instrument test:A:B:C");
                    });
                });
            });
        }
        test();
    }

    // Keep stepping the dispatcher until the tokio task reports completion.
    let finished = std::sync::Arc::new(Mutex::new(false));
    let finished_flag = finished.clone();
    spawn(async move {
        sleep(Duration::from_secs(1)).await;
        *finished_flag.lock().unwrap() = true;
    });

    loop {
        let done = *finished.lock().unwrap();
        if done {
            break;
        }
        step();
        std::thread::sleep(Duration::from_millis(10));
    }

    shutdown();
}
392
393    shutdown();
394}