1#![warn(clippy::pedantic)]
2#![allow(clippy::missing_panics_doc)]
3
4use std::{
5 fmt::Debug,
6 sync::mpsc::{channel, Receiver, Sender},
7};
8
9use futures::{
10 channel::oneshot,
11 executor::{LocalPool, LocalSpawner},
12 future::{BoxFuture, LocalBoxFuture},
13 prelude::*,
14 task::{LocalSpawnExt, SpawnExt},
15};
16
17type Task = BoxFuture<'static, ()>;
18
19type LocalTask = LocalBoxFuture<'static, ()>;
20
21pub struct Dispatcher {
23 task_sender: Sender<Task>,
24 task_receiver: Receiver<Task>,
25
26 local_task_sender: Sender<LocalTask>,
27 local_task_receiver: Receiver<LocalTask>,
28
29 future_pool: LocalPool,
30 future_spawner: LocalSpawner,
31}
32
33impl Default for Dispatcher {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl Dispatcher {
40 #[must_use]
41 pub fn new() -> Self {
42 let (task_sender, task_receiver) = channel();
43 let (local_task_sender, local_task_receiver) = channel();
44
45 let future_pool = LocalPool::new();
46 let future_spawner = future_pool.spawner();
47
48 Self {
49 task_sender,
50 task_receiver,
51 local_task_sender,
52 local_task_receiver,
53 future_pool,
54 future_spawner,
55 }
56 }
57
58 fn spawn_queued(&mut self) {
59 for task in self.local_task_receiver.try_iter() {
60 self.future_spawner.spawn_local(task).unwrap();
61 }
62
63 for task in self.task_receiver.try_iter() {
64 self.future_spawner.spawn(task).unwrap();
65 }
66 }
67
68 pub fn run(&mut self) {
72 self.spawn_queued();
73
74 self.future_pool.run();
75 }
76
77 pub fn run_until_stalled(&mut self) {
81 self.spawn_queued();
82
83 self.future_pool.run_until_stalled();
84 }
85
86 pub fn try_run_one(&mut self) -> bool {
90 self.spawn_queued();
91
92 self.future_pool.try_run_one()
93 }
94
95 pub fn get_handle(&self) -> DispatcherHandle {
97 DispatcherHandle {
98 task_sender: self.task_sender.clone(),
99 }
100 }
101
102 pub fn get_handle_local(&self) -> LocalDispatcherHandle {
104 LocalDispatcherHandle {
105 local_task_sender: self.local_task_sender.clone(),
106 }
107 }
108}
109
110#[derive(Clone)]
112pub struct DispatcherHandle {
113 task_sender: Sender<Task>,
114}
115
116impl DispatcherHandle {
117 pub fn spawn<F>(&mut self, future: F)
121 where
122 F: Future<Output = ()> + 'static + Send,
123 {
124 self.task_sender.send(future.boxed()).unwrap();
125 }
126
127 pub async fn dispatch<F, O>(&mut self, future: F) -> O
131 where
132 F: Future<Output = O> + 'static + Send,
133 O: 'static + Send + Debug,
134 {
135 let (return_sender, return_receiver) = oneshot::channel();
136
137 self.spawn(
138 async {
139 let ret = future.await;
140 return_sender.send(ret).unwrap();
141 }
142 .boxed(),
143 );
144
145 return_receiver.await.unwrap()
146 }
147}
148
149#[derive(Clone)]
151pub struct LocalDispatcherHandle {
152 local_task_sender: Sender<LocalTask>,
153}
154
155impl LocalDispatcherHandle {
156 pub fn spawn<F>(&mut self, future: F)
160 where
161 F: Future<Output = ()> + 'static,
162 {
163 self.local_task_sender.send(future.boxed_local()).unwrap();
164 }
165}
166
#[cfg(test)]
mod tests {
    use std::{
        cell::Cell,
        rc::Rc,
        thread::{self, JoinHandle},
    };

    use futures::executor::block_on;

    use super::*;

    /// Drives `dispatcher` until `worker` has finished, then returns the worker's result.
    ///
    /// The worker queues its task from another thread at an unknown time. A
    /// fixed sleep followed by a single `run()` can miss that task and leave
    /// `join` blocked forever. Polling until the worker exits removes the race.
    fn drive_until_finished<T>(dispatcher: &mut Dispatcher, worker: JoinHandle<T>) -> T {
        while !worker.is_finished() {
            dispatcher.run_until_stalled();
            thread::yield_now();
        }
        worker.join().unwrap()
    }

    #[test]
    fn it_works() {
        println!(
            "dispatcher running on thread {:?}",
            thread::current().name()
        );
        let mut main_thread_dispatcher = Dispatcher::new();

        let mut handle = main_thread_dispatcher.get_handle();
        let t = thread::spawn(move || {
            block_on(async move {
                println!(
                    "new thread running on thread {:?}",
                    thread::current().name()
                );

                println!("dispatching");
                handle
                    .dispatch(async {
                        println!(
                            "dispatched task running on thread {:?}",
                            thread::current().name()
                        );

                        99
                    })
                    .await
            })
        });

        println!("running dispatcher");
        let result = drive_until_finished(&mut main_thread_dispatcher, t);
        println!("dispatcher done running");

        assert_eq!(result, 99);
    }

    #[test]
    fn it_works_local() {
        println!(
            "dispatcher running on thread {:?}",
            thread::current().name()
        );
        let mut main_thread_dispatcher = Dispatcher::new();

        let mut local_handle = main_thread_dispatcher.get_handle_local();

        // Flags prove that both the outer and the nested task actually ran.
        let outer_ran = Rc::new(Cell::new(false));
        let inner_ran = Rc::new(Cell::new(false));
        let outer_flag = Rc::clone(&outer_ran);
        let inner_flag = Rc::clone(&inner_ran);

        let mut inner_local_handle = local_handle.clone();
        println!("spawning");
        local_handle.spawn(async move {
            println!(
                "spawned task running on same thread {:?}",
                thread::current().name()
            );
            outer_flag.set(true);

            inner_local_handle.spawn(async move {
                println!("inner running");
                inner_flag.set(true);
            });
        });

        println!("running dispatcher");
        main_thread_dispatcher.run();
        println!("dispatcher done running");
        assert!(outer_ran.get());

        println!("running dispatcher again");
        main_thread_dispatcher.run();
        println!("dispatcher done running again");
        assert!(inner_ran.get());
    }

    #[test]
    fn test_tokio() {
        let mut main_thread_dispatcher = Dispatcher::new();
        let mut handle = main_thread_dispatcher.get_handle();

        let tokio_thread = thread::spawn(move || {
            let rt = tokio::runtime::Runtime::new().unwrap();

            rt.block_on(async move {
                println!("dispatching");
                handle
                    .dispatch(async {
                        println!(
                            "dispatched task running on thread {:?}",
                            thread::current().name()
                        );

                        99u8
                    })
                    .await
            })
        });

        println!("running dispatcher");
        let result = drive_until_finished(&mut main_thread_dispatcher, tokio_thread);
        println!("dispatcher done running");

        assert_eq!(result, 99);
    }
}