async_dispatcher/
lib.rs

1#![warn(clippy::pedantic)]
2#![allow(clippy::missing_panics_doc)]
3
4use std::{
5    fmt::Debug,
6    sync::mpsc::{channel, Receiver, Sender},
7};
8
9use futures::{
10    channel::oneshot,
11    executor::{LocalPool, LocalSpawner},
12    future::{BoxFuture, LocalBoxFuture},
13    prelude::*,
14    task::{LocalSpawnExt, SpawnExt},
15};
16
17type Task = BoxFuture<'static, ()>;
18
19type LocalTask = LocalBoxFuture<'static, ()>;
20
21/// Used to call futures from another thread, getting their output values via `await`.
22pub struct Dispatcher {
23    task_sender: Sender<Task>,
24    task_receiver: Receiver<Task>,
25
26    local_task_sender: Sender<LocalTask>,
27    local_task_receiver: Receiver<LocalTask>,
28
29    future_pool: LocalPool,
30    future_spawner: LocalSpawner,
31}
32
33impl Default for Dispatcher {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl Dispatcher {
40    #[must_use]
41    pub fn new() -> Self {
42        let (task_sender, task_receiver) = channel();
43        let (local_task_sender, local_task_receiver) = channel();
44
45        let future_pool = LocalPool::new();
46        let future_spawner = future_pool.spawner();
47
48        Self {
49            task_sender,
50            task_receiver,
51            local_task_sender,
52            local_task_receiver,
53            future_pool,
54            future_spawner,
55        }
56    }
57
58    fn spawn_queued(&mut self) {
59        for task in self.local_task_receiver.try_iter() {
60            self.future_spawner.spawn_local(task).unwrap();
61        }
62
63        for task in self.task_receiver.try_iter() {
64            self.future_spawner.spawn(task).unwrap();
65        }
66    }
67
68    /// Run all tasks in the pool to completion.
69    ///
70    /// This method will not enqueue new tasks until the next call!
71    pub fn run(&mut self) {
72        self.spawn_queued();
73
74        self.future_pool.run();
75    }
76
77    /// Runs all tasks in the pool and returns if no more progress can be made on any task.
78    ///
79    /// This method will not enqueue new tasks until the next call!
80    pub fn run_until_stalled(&mut self) {
81        self.spawn_queued();
82
83        self.future_pool.run_until_stalled();
84    }
85
86    /// Runs all tasks and returns after completing one future or until no more progress can be made. Returns true if one future was completed, false otherwise.
87    ///
88    /// This method will not enqueue new tasks until the next call!
89    pub fn try_run_one(&mut self) -> bool {
90        self.spawn_queued();
91
92        self.future_pool.try_run_one()
93    }
94
95    /// Get a [`DispatcherHandle`] to control this Dispatcher from another thread.
96    pub fn get_handle(&self) -> DispatcherHandle {
97        DispatcherHandle {
98            task_sender: self.task_sender.clone(),
99        }
100    }
101
102    /// Get a [`LocalDispatcherHandle`] to control this Dispatcher from the same thread.
103    pub fn get_handle_local(&self) -> LocalDispatcherHandle {
104        LocalDispatcherHandle {
105            local_task_sender: self.local_task_sender.clone(),
106        }
107    }
108}
109
110/// Created by calling [`Dispatcher::get_handle`].
111#[derive(Clone)]
112pub struct DispatcherHandle {
113    task_sender: Sender<Task>,
114}
115
116impl DispatcherHandle {
117    /// Enqueue a future to be ran when [`Dispatcher::run`] is called.
118    ///
119    /// If you want an output value from the future, use [`DispatcherHandle::dispatch`]
120    pub fn spawn<F>(&mut self, future: F)
121    where
122        F: Future<Output = ()> + 'static + Send,
123    {
124        self.task_sender.send(future.boxed()).unwrap();
125    }
126
127    /// Enqueue a future to be ran when [`Dispatcher::run`] is called.
128    ///
129    /// This returns a Future that resolves to the input future's output value.
130    pub async fn dispatch<F, O>(&mut self, future: F) -> O
131    where
132        F: Future<Output = O> + 'static + Send,
133        O: 'static + Send + Debug,
134    {
135        let (return_sender, return_receiver) = oneshot::channel();
136
137        self.spawn(
138            async {
139                let ret = future.await;
140                return_sender.send(ret).unwrap();
141            }
142            .boxed(),
143        );
144
145        return_receiver.await.unwrap()
146    }
147}
148
149/// Created by calling [`Dispatcher::get_handle_local`].
150#[derive(Clone)]
151pub struct LocalDispatcherHandle {
152    local_task_sender: Sender<LocalTask>,
153}
154
155impl LocalDispatcherHandle {
156    /// Enqueue a future to be ran when [`Dispatcher::run`] is called.
157    ///
158    /// If you want an output value from the future, use [`LocalDispatcherHandle::dispatch`]
159    pub fn spawn<F>(&mut self, future: F)
160    where
161        F: Future<Output = ()> + 'static,
162    {
163        self.local_task_sender.send(future.boxed_local()).unwrap();
164    }
165}
166
#[cfg(test)]
mod tests {

    use std::{cell::Cell, rc::Rc, thread, time::Duration};

    use futures::executor::block_on;

    use super::*;

    /// Drives `dispatcher` until the thread behind `worker` has finished.
    ///
    /// Polling replaces a fixed sleep: the worker may enqueue its task at any moment, and a
    /// single `run()` after an arbitrary delay could miss it, leaving the later `join` hung.
    fn run_until_finished<T>(dispatcher: &mut Dispatcher, worker: &thread::JoinHandle<T>) {
        while !worker.is_finished() {
            dispatcher.run_until_stalled();
            thread::sleep(Duration::from_millis(1));
        }
    }

    #[test]
    fn it_works() {
        println!(
            "dispatcher running on thread {:?}",
            thread::current().name()
        );
        let main_id = thread::current().id();
        let mut main_thread_dispatcher = Dispatcher::new();

        let mut handle = main_thread_dispatcher.get_handle();
        let t = thread::spawn(move || {
            block_on(async move {
                println!(
                    "new thread running on thread {:?}",
                    thread::current().name()
                );

                println!("dispatching");
                handle
                    .dispatch(async {
                        println!(
                            "dispatched task running on thread {:?}",
                            thread::current().name()
                        );

                        (99, thread::current().id())
                    })
                    .await
            })
        });

        println!("running dispatcher");
        run_until_finished(&mut main_thread_dispatcher, &t);
        println!("dispatcher done running");

        // The dispatched future must have run on the dispatcher's (main) thread.
        assert_eq!(t.join().unwrap(), (99, main_id));
    }

    #[test]
    fn it_works_local() {
        println!(
            "dispatcher running on thread {:?}",
            thread::current().name()
        );
        let mut main_thread_dispatcher = Dispatcher::new();

        let outer_ran = Rc::new(Cell::new(false));
        let inner_ran = Rc::new(Cell::new(false));
        let outer_flag = Rc::clone(&outer_ran);
        let inner_flag = Rc::clone(&inner_ran);

        let mut local_handle = main_thread_dispatcher.get_handle_local();

        let mut inner_local_handle = local_handle.clone();
        println!("spawning");
        local_handle.spawn(async move {
            println!(
                "spawned task running on same thread {:?}",
                thread::current().name()
            );
            outer_flag.set(true);

            inner_local_handle.spawn(async move {
                println!("inner running");
                inner_flag.set(true);
            });
        });

        println!("running dispatcher");
        main_thread_dispatcher.run();
        println!("dispatcher done running");

        assert!(outer_ran.get());
        assert!(
            !inner_ran.get(),
            "tasks enqueued during run() must wait for the next call"
        );

        println!("running dispatcher again");
        main_thread_dispatcher.run();
        println!("dispatcher done running again");

        assert!(inner_ran.get());
    }

    #[test]
    fn test_tokio() {
        // Dispatched tasks run on the Dispatcher's executor, outside the Tokio runtime, so
        // Tokio timers can't be used inside them.

        let main_id = thread::current().id();
        let mut main_thread_dispatcher = Dispatcher::new();
        let mut handle = main_thread_dispatcher.get_handle();

        let tokio_thread = thread::spawn(move || {
            let rt = tokio::runtime::Runtime::new().unwrap();

            rt.block_on(async move {
                println!("dispatching");
                handle
                    .dispatch(async {
                        println!(
                            "dispatched task running on thread {:?}",
                            thread::current().name()
                        );

                        (99u8, thread::current().id())
                    })
                    .await
            })
        });

        println!("running dispatcher");
        run_until_finished(&mut main_thread_dispatcher, &tokio_thread);
        println!("dispatcher done running");

        assert_eq!(tokio_thread.join().unwrap(), (99, main_id));
    }
}