classicube_helpers/
async_manager.rs1use std::{
2 cell::{Cell, RefCell},
3 future::Future,
4 pin::Pin,
5 rc::Rc,
6 sync::{LazyLock, Mutex},
7 task::{Context, Poll, Waker},
8 time::Duration,
9};
10
11use async_dispatcher::{Dispatcher, DispatcherHandle, LocalDispatcherHandle};
12use futures::{future::Either, prelude::*};
13use futures_timer::Delay;
14use tokio::task::{JoinError, JoinHandle};
15use tracing::{Instrument, debug, warn};
16
17use crate::{WithInner, tick::TickEventHandler};
18
19thread_local!(
20 static ASYNC_DISPATCHER: RefCell<Option<Dispatcher>> = RefCell::default();
21);
22
23thread_local!(
24 static ASYNC_DISPATCHER_LOCAL_HANDLE: RefCell<Option<LocalDispatcherHandle>> =
25 RefCell::default();
26);
27
28static ASYNC_DISPATCHER_HANDLE: LazyLock<Mutex<Option<DispatcherHandle>>> =
29 LazyLock::new(Mutex::default);
30
31static TOKIO_RUNTIME: LazyLock<Mutex<Option<tokio::runtime::Runtime>>> =
32 LazyLock::new(Mutex::default);
33
34thread_local!(
35 static TICK_HANDLER: RefCell<Option<TickEventHandler>> = RefCell::default();
36);
37
38#[tracing::instrument]
39pub fn initialize() {
40 debug!("async_manager");
41
42 {
43 debug!("async_dispatcher");
44 let async_dispatcher = Dispatcher::new();
45 let async_dispatcher_handle = async_dispatcher.get_handle();
46 ASYNC_DISPATCHER_LOCAL_HANDLE.with(|cell| {
47 *cell.borrow_mut() = Some(async_dispatcher.get_handle_local());
48 });
49 ASYNC_DISPATCHER.with(|cell| {
50 *cell.borrow_mut() = Some(async_dispatcher);
51 });
52
53 *ASYNC_DISPATCHER_HANDLE.lock().unwrap() = Some(async_dispatcher_handle);
54 }
55
56 {
57 debug!("tokio");
58 let rt = tokio::runtime::Builder::new_multi_thread()
59 .enable_all()
60 .build()
61 .unwrap();
62
63 *TOKIO_RUNTIME.lock().unwrap() = Some(rt);
64 }
65
66 {
67 debug!("tick_handler");
68 TICK_HANDLER.with(|cell| {
69 let mut tick_handler = TickEventHandler::new();
70 tick_handler.on(|_task| {
71 step();
72 });
73
74 *cell.borrow_mut() = Some(tick_handler);
75 });
76 }
77}
78
79#[tracing::instrument]
80pub fn shutdown() {
81 debug!("async_manager");
82
83 {
84 let mut option = TOKIO_RUNTIME.lock().unwrap();
85 if option.is_some() {
86 debug!("tokio");
87 if let Some(rt) = option.take() {
88 rt.shutdown_timeout(Duration::from_millis(100));
89 }
90 } else {
91 warn!("tokio already shutdown?");
92 }
93 }
94
95 {
96 if ASYNC_DISPATCHER.with_inner(|_| ()).is_some() {
97 debug!("async_dispatcher");
98
99 ASYNC_DISPATCHER_HANDLE.lock().unwrap().take();
100 ASYNC_DISPATCHER_LOCAL_HANDLE.with(|cell| cell.borrow_mut().take());
101 ASYNC_DISPATCHER.with(|cell| cell.borrow_mut().take());
102 } else {
103 warn!("async_dispatcher already shutdown?");
104 }
105 }
106
107 {
108 if TICK_HANDLER.with_inner(|_| ()).is_some() {
109 debug!("tick_handler");
110
111 TICK_HANDLER.with(|cell| cell.borrow_mut().take());
112 } else {
113 warn!("tick_handler already shutdown?");
114 }
115 }
116}
117
/// Future that stays pending until its shared `woke` flag is set.
///
/// Clones share the same state through `Rc`, so one clone can be queued for a later
/// wake-up while another clone is being awaited.
#[derive(Default, Clone)]
struct YieldedWaker {
    /// Waker captured by the most recent pending poll, if any.
    waker: Rc<Cell<Option<Waker>>>,
    /// Once `true`, the next poll resolves the future.
    woke: Rc<Cell<bool>>,
}

impl Future for YieldedWaker {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.woke.get() {
            return Poll::Ready(());
        }

        // Not released yet: remember who to wake once the flag is flipped.
        self.waker.set(Some(cx.waker().clone()));
        Poll::Pending
    }
}
135
136thread_local!(
137 static YIELDED_WAKERS: RefCell<Vec<YieldedWaker>> = RefCell::default();
138);
139
140pub fn step() {
141 YIELDED_WAKERS.with(move |cell| {
142 let vec = &mut *cell.borrow_mut();
143 for state in vec.drain(..) {
144 state.woke.set(true);
145 if let Some(waker) = state.waker.take() {
146 waker.wake();
147 }
148 }
149 });
150
151 ASYNC_DISPATCHER
153 .with_inner_mut(|async_dispatcher| {
154 async_dispatcher.run_until_stalled();
155 })
156 .unwrap();
157}
158
159pub fn run() {
161 ASYNC_DISPATCHER
162 .with_inner_mut(|async_dispatcher| {
163 async_dispatcher.run();
164 })
165 .unwrap();
166}
167
168pub async fn sleep(duration: Duration) {
169 Delay::new(duration).await;
170}
171
172pub async fn yield_now() {
173 let waker = YieldedWaker::default();
174 {
175 let waker = waker.clone();
176 YIELDED_WAKERS.with(move |ref_cell| {
177 let vec = &mut *ref_cell.borrow_mut();
178 vec.push(waker);
179 });
180 }
181 waker.await;
182}
183
184pub async fn timeout<T, F>(duration: Duration, f: F) -> Option<T>
185where
186 F: Future<Output = T> + Send,
187{
188 let delay = Delay::new(duration);
189
190 match future::select(delay, f.boxed()).await {
191 Either::Left(((), _f)) => None,
192 Either::Right((r, _delay)) => Some(r),
193 }
194}
195
196pub async fn timeout_local<T, F>(duration: Duration, f: F) -> Option<T>
197where
198 F: Future<Output = T>,
199{
200 let delay = Delay::new(duration);
201
202 match future::select(delay, f.boxed_local()).await {
203 Either::Left(((), _f)) => None,
204 Either::Right((r, _delay)) => Some(r),
205 }
206}
207
208pub fn block_on_local<F>(f: F)
212where
213 F: Future<Output = ()> + 'static,
214{
215 let shared = f.in_current_span().shared();
216
217 {
218 let shared = shared.clone();
219 spawn_local_on_main_thread(async move {
220 shared.await;
221 });
222 }
223
224 loop {
225 match shared.peek() {
226 Some(()) => {
227 return;
228 }
229
230 None => {
231 step();
232 }
233 }
234
235 std::thread::sleep(Duration::from_millis(16));
237 }
238}
239
240pub fn spawn<F>(f: F) -> JoinHandle<F::Output>
241where
242 F: Future + Send + 'static,
243 F::Output: Send + 'static,
244{
245 TOKIO_RUNTIME
246 .with_inner(move |rt| rt.spawn(f.in_current_span()))
247 .unwrap()
248}
249
250pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<Result<R, JoinError>>
251where
252 F: FnOnce() -> R + Send + 'static,
253 R: Send + 'static,
254{
255 let span = tracing::Span::current();
256 spawn(async move {
257 tokio::task::spawn_blocking(move || {
258 let _enter = span.enter();
259 f()
260 })
261 .await
262 })
263}
264
265pub fn spawn_on_main_thread<F>(f: F)
266where
267 F: Future<Output = ()> + 'static + Send,
268{
269 let mut handle = {
270 let mut handle = ASYNC_DISPATCHER_HANDLE.lock().unwrap();
271 handle.as_mut().expect("handle.as_mut()").clone()
272 };
273
274 handle.spawn(f.in_current_span());
275}
276
277pub async fn run_on_main_thread<F, O>(f: F) -> O
278where
279 F: Future<Output = O> + 'static + Send,
280 O: 'static + Send + std::fmt::Debug,
281{
282 let mut handle = {
283 let mut handle = ASYNC_DISPATCHER_HANDLE.lock().unwrap();
284 handle.as_mut().expect("handle.as_mut()").clone()
285 };
286
287 handle.dispatch(f.in_current_span()).await
288}
289
290pub fn spawn_local_on_main_thread<F>(f: F)
291where
292 F: Future<Output = ()> + 'static,
293{
294 let mut handle = ASYNC_DISPATCHER_LOCAL_HANDLE
295 .with_inner(Clone::clone)
296 .expect("ASYNC_DISPATCHER_LOCAL_HANDLE is None");
297
298 handle.spawn(f.in_current_span());
299}
300
// NOTE(review): presumably provides a no-op `ScheduledTask_Add` stand-in so the tests in
// this file can link without the real game symbol — confirm against `test_noop_fn!`.
crate::test_noop_fn!(ScheduledTask_Add);
302
/// End-to-end smoke test: initialises the manager, exercises span propagation through
/// `spawn` + `run_on_main_thread`, `spawn_blocking` and `block_on_local`, then keeps
/// stepping until a tokio task spawned with a one second sleep reports completion.
#[test]
fn test_async_manager() {
    use tracing::info;
    use tracing_subscriber::{filter::EnvFilter, prelude::*};

    let env_filter = EnvFilter::from_default_env().add_directive("debug".parse().unwrap());
    tracing_subscriber::fmt()
        .with_env_filter(env_filter)
        .with_target(false)
        .with_thread_ids(false)
        .with_thread_names(false)
        .with_ansi(true)
        .without_time()
        .finish()
        .init();

    initialize();

    // tokio task -> main thread hop. The nested functions keep the name `test` because
    // `#[tracing::instrument]` derives the span name from it.
    {
        #[tracing::instrument]
        fn test() {
            let span_a = tracing::info_span!("A");
            let _in_a = span_a.enter();
            spawn(async move {
                let span_b = tracing::info_span!("B");
                let _in_b = span_b.enter();
                run_on_main_thread(async move {
                    let span_c = tracing::info_span!("C");
                    let _in_c = span_c.enter();
                    info!("run_on_main_thread with instrument test:A:B:C");
                })
                .await;
            });
        }
        test();
    }

    // Blocking closure on the tokio blocking pool.
    {
        #[tracing::instrument]
        fn test() {
            let span_a = tracing::info_span!("A");
            let _in_a = span_a.enter();
            spawn_blocking(|| {
                let span_b = tracing::info_span!("B");
                let _in_b = span_b.enter();
                info!("spawn_blocking with instrument test:A:B");
            });
        }
        test();
    }

    // Local future driven to completion, spawning a tokio task from inside it.
    {
        #[tracing::instrument]
        fn test() {
            let span_a = tracing::info_span!("A");
            let _in_a = span_a.enter();
            block_on_local(async move {
                let span_b = tracing::info_span!("B");
                let _in_b = span_b.enter();
                info!("block_on_local with instrument test:A:B");
                spawn(async move {
                    let span_c = tracing::info_span!("C");
                    let _in_c = span_c.enter();
                    info!("block_on_local spawn with instrument test:A:B:C");
                });
            });
        }
        test();
    }

    // Keep stepping the dispatcher until the background task flips the flag.
    let finished = std::sync::Arc::new(Mutex::new(false));
    let finished_by_task = std::sync::Arc::clone(&finished);
    spawn(async move {
        sleep(Duration::from_secs(1)).await;
        *finished_by_task.lock().unwrap() = true;
    });
    loop {
        let done = *finished.lock().unwrap();
        if done {
            break;
        }
        step();
        std::thread::sleep(Duration::from_millis(10));
    }

    shutdown();
}