actix_http/ws/
dispatcher.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use actix_codec::{AsyncRead, AsyncWrite, Framed};
8use actix_service::{IntoService, Service};
9use pin_project_lite::pin_project;
10
11use super::{Codec, Frame, Message};
12
pin_project! {
    /// WebSocket frame dispatcher.
    ///
    /// Reads [`Frame`]s from an I/O object through a [`Codec`], hands each frame to the
    /// wrapped service, and writes the service's [`Message`] responses back to the transport.
    /// This is a thin wrapper around the generic `inner::Dispatcher` specialised to the
    /// WebSocket codec; polling it polls the inner dispatcher.
    pub struct Dispatcher<S, T>
    where
        S: Service<Frame, Response = Message>,
        S: 'static,
        T: AsyncRead,
        T: AsyncWrite,
    {
        // Generic framed dispatcher that does the actual work; structurally pinned so it can
        // be polled through `Pin<&mut _>`.
        #[pin]
        inner: inner::Dispatcher<S, T, Codec, Message>,
    }
}
25
26impl<S, T> Dispatcher<S, T>
27where
28    T: AsyncRead + AsyncWrite,
29    S: Service<Frame, Response = Message>,
30    S::Future: 'static,
31    S::Error: 'static,
32{
33    pub fn new<F: IntoService<S, Frame>>(io: T, service: F) -> Self {
34        Dispatcher {
35            inner: inner::Dispatcher::new(Framed::new(io, Codec::new()), service),
36        }
37    }
38
39    pub fn with<F: IntoService<S, Frame>>(framed: Framed<T, Codec>, service: F) -> Self {
40        Dispatcher {
41            inner: inner::Dispatcher::new(framed, service),
42        }
43    }
44}
45
46impl<S, T> Future for Dispatcher<S, T>
47where
48    T: AsyncRead + AsyncWrite,
49    S: Service<Frame, Response = Message>,
50    S::Future: 'static,
51    S::Error: 'static,
52{
53    type Output = Result<(), inner::DispatcherError<S::Error, Codec, Message>>;
54
55    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
56        self.project().inner.poll(cx)
57    }
58}
59
60/// Framed dispatcher service and related utilities.
61mod inner {
62    // allow dead code since this mod was ripped from actix-utils
63    #![allow(dead_code)]
64
65    use core::{
66        fmt,
67        future::Future,
68        mem,
69        pin::Pin,
70        task::{Context, Poll},
71    };
72
73    use actix_codec::Framed;
74    use actix_service::{IntoService, Service};
75    use futures_core::stream::Stream;
76    use local_channel::mpsc;
77    use pin_project_lite::pin_project;
78    use tokio::io::{AsyncRead, AsyncWrite};
79    use tokio_util::codec::{Decoder, Encoder};
80    use tracing::debug;
81
82    use crate::{body::BoxBody, Response};
83
84    /// Framed transport errors
85    pub enum DispatcherError<E, U, I>
86    where
87        U: Encoder<I> + Decoder,
88    {
89        /// Inner service error.
90        Service(E),
91
92        /// Frame encoding error.
93        Encoder(<U as Encoder<I>>::Error),
94
95        /// Frame decoding error.
96        Decoder(<U as Decoder>::Error),
97    }
98
99    impl<E, U, I> From<E> for DispatcherError<E, U, I>
100    where
101        U: Encoder<I> + Decoder,
102    {
103        fn from(err: E) -> Self {
104            DispatcherError::Service(err)
105        }
106    }
107
108    impl<E, U, I> fmt::Debug for DispatcherError<E, U, I>
109    where
110        E: fmt::Debug,
111        U: Encoder<I> + Decoder,
112        <U as Encoder<I>>::Error: fmt::Debug,
113        <U as Decoder>::Error: fmt::Debug,
114    {
115        fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
116            match *self {
117                DispatcherError::Service(ref err) => {
118                    write!(fmt, "DispatcherError::Service({err:?})")
119                }
120                DispatcherError::Encoder(ref err) => {
121                    write!(fmt, "DispatcherError::Encoder({err:?})")
122                }
123                DispatcherError::Decoder(ref err) => {
124                    write!(fmt, "DispatcherError::Decoder({err:?})")
125                }
126            }
127        }
128    }
129
130    impl<E, U, I> fmt::Display for DispatcherError<E, U, I>
131    where
132        E: fmt::Display,
133        U: Encoder<I> + Decoder,
134        <U as Encoder<I>>::Error: fmt::Debug,
135        <U as Decoder>::Error: fmt::Debug,
136    {
137        fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
138            match *self {
139                DispatcherError::Service(ref err) => write!(fmt, "{err}"),
140                DispatcherError::Encoder(ref err) => write!(fmt, "{err:?}"),
141                DispatcherError::Decoder(ref err) => write!(fmt, "{err:?}"),
142            }
143        }
144    }
145
146    impl<E, U, I> From<DispatcherError<E, U, I>> for Response<BoxBody>
147    where
148        E: fmt::Debug + fmt::Display,
149        U: Encoder<I> + Decoder,
150        <U as Encoder<I>>::Error: fmt::Debug,
151        <U as Decoder>::Error: fmt::Debug,
152    {
153        fn from(err: DispatcherError<E, U, I>) -> Self {
154            Response::internal_server_error().set_body(BoxBody::new(err.to_string()))
155        }
156    }
157
158    /// Message type wrapper for signalling end of message stream.
159    pub enum Message<T> {
160        /// Message item.
161        Item(T),
162
163        /// Signal from service to flush all messages and stop processing.
164        Close,
165    }
166
    pin_project! {
        /// A future that reads frames from a [`Framed`] object and passes them to a [`Service`].
        ///
        /// Each decoded frame is handed to the service on a spawned task. The service's result
        /// is sent back over an internal channel, then encoded and written to the transport.
        /// The future resolves when the inbound stream ends, when a `Message::Close` is
        /// received, or when an error occurs.
        pub struct Dispatcher<S, T, U, I>
        where
            S: Service<<U as Decoder>::Item, Response = I>,
            S::Error: 'static,
            S::Future: 'static,
            T: AsyncRead,
            T: AsyncWrite,
            U: Encoder<I>,
            U: Decoder,
            I: 'static,
            <U as Encoder<I>>::Error: fmt::Debug,
        {
            // The wrapped service; called once per decoded frame.
            service: S,
            // Current phase of the dispatcher's state machine.
            state: State<S, U, I>,
            // Transport plus codec; structurally pinned, so it is accessed as
            // `Pin<&mut Framed<T, U>>`.
            #[pin]
            framed: Framed<T, U>,
            // Receiving end of the response channel; drained by `poll_write`.
            rx: mpsc::Receiver<Result<Message<I>, S::Error>>,
            // Sending end; cloned into every spawned service task and handed out by `tx()`.
            tx: mpsc::Sender<Result<Message<I>, S::Error>>,
        }
    }
189
190    enum State<S, U, I>
191    where
192        S: Service<<U as Decoder>::Item>,
193        U: Encoder<I> + Decoder,
194    {
195        Processing,
196        Error(DispatcherError<S::Error, U, I>),
197        FramedError(DispatcherError<S::Error, U, I>),
198        FlushAndStop,
199        Stopping,
200    }
201
202    impl<S, U, I> State<S, U, I>
203    where
204        S: Service<<U as Decoder>::Item>,
205        U: Encoder<I> + Decoder,
206    {
207        fn take_error(&mut self) -> DispatcherError<S::Error, U, I> {
208            match mem::replace(self, State::Processing) {
209                State::Error(err) => err,
210                _ => panic!(),
211            }
212        }
213
214        fn take_framed_error(&mut self) -> DispatcherError<S::Error, U, I> {
215            match mem::replace(self, State::Processing) {
216                State::FramedError(err) => err,
217                _ => panic!(),
218            }
219        }
220    }
221
222    impl<S, T, U, I> Dispatcher<S, T, U, I>
223    where
224        S: Service<<U as Decoder>::Item, Response = I>,
225        S::Error: 'static,
226        S::Future: 'static,
227        T: AsyncRead + AsyncWrite,
228        U: Decoder + Encoder<I>,
229        I: 'static,
230        <U as Decoder>::Error: fmt::Debug,
231        <U as Encoder<I>>::Error: fmt::Debug,
232    {
233        /// Create new `Dispatcher`.
234        pub fn new<F>(framed: Framed<T, U>, service: F) -> Self
235        where
236            F: IntoService<S, <U as Decoder>::Item>,
237        {
238            let (tx, rx) = mpsc::channel();
239            Dispatcher {
240                framed,
241                rx,
242                tx,
243                service: service.into_service(),
244                state: State::Processing,
245            }
246        }
247
248        /// Construct new `Dispatcher` instance with customer `mpsc::Receiver`
249        pub fn with_rx<F>(
250            framed: Framed<T, U>,
251            service: F,
252            rx: mpsc::Receiver<Result<Message<I>, S::Error>>,
253        ) -> Self
254        where
255            F: IntoService<S, <U as Decoder>::Item>,
256        {
257            let tx = rx.sender();
258            Dispatcher {
259                framed,
260                rx,
261                tx,
262                service: service.into_service(),
263                state: State::Processing,
264            }
265        }
266
267        /// Get sender handle.
268        pub fn tx(&self) -> mpsc::Sender<Result<Message<I>, S::Error>> {
269            self.tx.clone()
270        }
271
272        /// Get reference to a service wrapped by `Dispatcher` instance.
273        pub fn service(&self) -> &S {
274            &self.service
275        }
276
277        /// Get mutable reference to a service wrapped by `Dispatcher` instance.
278        pub fn service_mut(&mut self) -> &mut S {
279            &mut self.service
280        }
281
282        /// Get reference to a framed instance wrapped by `Dispatcher` instance.
283        pub fn framed(&self) -> &Framed<T, U> {
284            &self.framed
285        }
286
287        /// Get mutable reference to a framed instance wrapped by `Dispatcher` instance.
288        pub fn framed_mut(&mut self) -> &mut Framed<T, U> {
289            &mut self.framed
290        }
291
292        /// Read from framed object.
293        fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> bool
294        where
295            S: Service<<U as Decoder>::Item, Response = I>,
296            S::Error: 'static,
297            S::Future: 'static,
298            T: AsyncRead + AsyncWrite,
299            U: Decoder + Encoder<I>,
300            I: 'static,
301            <U as Encoder<I>>::Error: fmt::Debug,
302        {
303            loop {
304                let this = self.as_mut().project();
305                match this.service.poll_ready(cx) {
306                    Poll::Ready(Ok(_)) => {
307                        let item = match this.framed.next_item(cx) {
308                            Poll::Ready(Some(Ok(el))) => el,
309                            Poll::Ready(Some(Err(err))) => {
310                                *this.state = State::FramedError(DispatcherError::Decoder(err));
311                                return true;
312                            }
313                            Poll::Pending => return false,
314                            Poll::Ready(None) => {
315                                *this.state = State::Stopping;
316                                return true;
317                            }
318                        };
319
320                        let tx = this.tx.clone();
321                        let fut = this.service.call(item);
322                        actix_rt::spawn(async move {
323                            let item = fut.await;
324                            let _ = tx.send(item.map(Message::Item));
325                        });
326                    }
327                    Poll::Pending => return false,
328                    Poll::Ready(Err(err)) => {
329                        *this.state = State::Error(DispatcherError::Service(err));
330                        return true;
331                    }
332                }
333            }
334        }
335
336        /// Write to framed object.
337        fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> bool
338        where
339            S: Service<<U as Decoder>::Item, Response = I>,
340            S::Error: 'static,
341            S::Future: 'static,
342            T: AsyncRead + AsyncWrite,
343            U: Decoder + Encoder<I>,
344            I: 'static,
345            <U as Encoder<I>>::Error: fmt::Debug,
346        {
347            loop {
348                let mut this = self.as_mut().project();
349                while !this.framed.is_write_buf_full() {
350                    match Pin::new(&mut this.rx).poll_next(cx) {
351                        Poll::Ready(Some(Ok(Message::Item(msg)))) => {
352                            if let Err(err) = this.framed.as_mut().write(msg) {
353                                *this.state = State::FramedError(DispatcherError::Encoder(err));
354                                return true;
355                            }
356                        }
357                        Poll::Ready(Some(Ok(Message::Close))) => {
358                            *this.state = State::FlushAndStop;
359                            return true;
360                        }
361                        Poll::Ready(Some(Err(err))) => {
362                            *this.state = State::Error(DispatcherError::Service(err));
363                            return true;
364                        }
365                        Poll::Ready(None) | Poll::Pending => break,
366                    }
367                }
368
369                if !this.framed.is_write_buf_empty() {
370                    match this.framed.flush(cx) {
371                        Poll::Pending => break,
372                        Poll::Ready(Ok(_)) => {}
373                        Poll::Ready(Err(err)) => {
374                            debug!("Error sending data: {:?}", err);
375                            *this.state = State::FramedError(DispatcherError::Encoder(err));
376                            return true;
377                        }
378                    }
379                } else {
380                    break;
381                }
382            }
383
384            false
385        }
386    }
387
    /// Drives the dispatcher's state machine until it finishes or reports an error.
    impl<S, T, U, I> Future for Dispatcher<S, T, U, I>
    where
        S: Service<<U as Decoder>::Item, Response = I>,
        S::Error: 'static,
        S::Future: 'static,
        T: AsyncRead + AsyncWrite,
        U: Decoder + Encoder<I>,
        I: 'static,
        <U as Encoder<I>>::Error: fmt::Debug,
        <U as Decoder>::Error: fmt::Debug,
    {
        type Output = Result<(), DispatcherError<S::Error, U, I>>;

        /// Resolves to `Ok(())` when the inbound stream ends, or after flushing once a
        /// `Message::Close` is received. Resolves to `Err` on a service, decode, or
        /// encode/flush error.
        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            loop {
                let this = self.as_mut().project();

                return match this.state {
                    State::Processing => {
                        // Read first. `poll_write` only runs if `poll_read` left the state
                        // unchanged (short-circuit `||`). A `true` from either means `state`
                        // changed, so loop to handle the new state.
                        if self.as_mut().poll_read(cx) || self.as_mut().poll_write(cx) {
                            continue;
                        } else {
                            Poll::Pending
                        }
                    }
                    State::Error(_) => {
                        // flush write buffer
                        // A flush error is ignored here; the original service error is what
                        // gets reported.
                        if !this.framed.is_write_buf_empty() && this.framed.flush(cx).is_pending() {
                            return Poll::Pending;
                        }
                        Poll::Ready(Err(this.state.take_error()))
                    }
                    State::FlushAndStop => {
                        // Flush buffered output. A flush error is only logged; the dispatcher
                        // still completes with `Ok(())`.
                        if !this.framed.is_write_buf_empty() {
                            this.framed.flush(cx).map(|res| {
                                if let Err(err) = res {
                                    debug!("Error sending data: {:?}", err);
                                }

                                Ok(())
                            })
                        } else {
                            Poll::Ready(Ok(()))
                        }
                    }
                    // Codec/transport errors are returned without flushing.
                    State::FramedError(_) => Poll::Ready(Err(this.state.take_framed_error())),
                    // Inbound stream ended: finish immediately.
                    State::Stopping => Poll::Ready(Ok(())),
                };
            }
        }
    }
439}