actix_http/h1/
dispatcher.rs

1use std::{
2    collections::VecDeque,
3    fmt,
4    future::Future,
5    io, mem, net,
6    pin::Pin,
7    rc::Rc,
8    task::{Context, Poll},
9};
10
11use actix_codec::{Framed, FramedParts};
12use actix_rt::time::sleep_until;
13use actix_service::Service;
14use bitflags::bitflags;
15use bytes::{Buf, BytesMut};
16use futures_core::ready;
17use pin_project_lite::pin_project;
18use tokio::io::{AsyncRead, AsyncWrite};
19use tokio_util::codec::{Decoder as _, Encoder as _};
20use tracing::{error, trace};
21
22use super::{
23    codec::Codec,
24    decoder::MAX_BUFFER_SIZE,
25    payload::{Payload, PayloadSender, PayloadStatus},
26    timer::TimerState,
27    Message, MessageType,
28};
29use crate::{
30    body::{BodySize, BoxBody, MessageBody},
31    config::ServiceConfig,
32    error::{DispatchError, ParseError, PayloadError},
33    service::HttpFlow,
34    Error, Extensions, OnConnectData, Request, Response, StatusCode,
35};
36
// 1 KiB. Not referenced in the portion of the file visible here.
// NOTE(review): "LW" presumably stands for "low watermark" — confirm against its users.
const LW_BUFFER_SIZE: usize = 1024;
// 8 KiB. Used as the initial capacity of the dispatcher's read and write buffers.
// NOTE(review): "HW" presumably stands for "high watermark" — confirm.
const HW_BUFFER_SIZE: usize = 1024 * 8;
// Once this many decoded messages are queued, `poll_request` returns without decoding more.
const MAX_PIPELINED_MESSAGES: usize = 16;
40
41bitflags! {
42    #[derive(Debug, Clone, Copy)]
43    pub struct Flags: u8 {
44        /// Set when stream is read for first time.
45        const STARTED          = 0b0000_0001;
46
47        /// Set when full request-response cycle has occurred.
48        const FINISHED         = 0b0000_0010;
49
50        /// Set if connection is in keep-alive (inactive) state.
51        const KEEP_ALIVE       = 0b0000_0100;
52
53        /// Set if in shutdown procedure.
54        const SHUTDOWN         = 0b0000_1000;
55
56        /// Set if read-half is disconnected.
57        const READ_DISCONNECT  = 0b0001_0000;
58
59        /// Set if write-half is disconnected.
60        const WRITE_DISCONNECT = 0b0010_0000;
61    }
62}
63
// There are two versions of the `Dispatcher` struct (the test build needs an extra
// `poll_count` field) because of:
// https://github.com/taiki-e/pin-project-lite/issues/3
//
// tl;dr: pin-project-lite doesn't play well with other attribute macros
68
#[cfg(not(test))]
pin_project! {
    /// Dispatcher for HTTP/1.1 protocol
    ///
    /// This is the definition compiled for non-test builds. It only wraps a
    /// `DispatcherState`; the `#[cfg(test)]` definition carries an additional poll counter.
    pub struct Dispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // current mode of the connection: normal HTTP/1 handling or a running upgrade future
        #[pin]
        inner: DispatcherState<T, S, B, X, U>,
    }
}
89
#[cfg(test)]
pin_project! {
    /// Dispatcher for HTTP/1.1 protocol
    ///
    /// This is the definition compiled for test builds. Compared to the non-test definition,
    /// `inner` is visible to the parent module and a `poll_count` field is added.
    pub struct Dispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // current mode of the connection: normal HTTP/1 handling or a running upgrade future
        #[pin]
        pub(super) inner: DispatcherState<T, S, B, X, U>,

        // used in tests; initialised to 0 by `Dispatcher::new`
        // NOTE(review): presumably incremented on every poll — the update site is not visible here
        pub(super) poll_count: u64,
    }
}
113
pin_project! {
    // Top-level state of a connection owned by `Dispatcher`.
    //
    // `Dispatcher::new` always starts in `Normal`. `Upgrade` holds the future produced by the
    // upgrade service `U`; the transition into it is not visible in this part of the file.
    #[project = DispatcherStateProj]
    pub(super) enum DispatcherState<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // regular HTTP/1 request/response handling
        Normal { #[pin] inner: InnerDispatcher<T, S, B, X, U> },
        // an upgrade service call is being driven
        Upgrade { #[pin] fut: U::Future },
    }
}
133
pin_project! {
    // State machine that reads requests from the I/O stream, calls the services and writes
    // the responses back for a single HTTP/1 connection.
    #[project = InnerDispatcherProj]
    pub(super) struct InnerDispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // shared handle to the `service`, `expect` and `upgrade` services
        flow: Rc<HttpFlow<S, X, U>>,
        // connection status bits, see `Flags`
        pub(super) flags: Flags,
        // copied into the head of every decoded request
        peer_addr: Option<net::SocketAddr>,
        // cloned into `conn_data` of every decoded request
        conn_data: Option<Rc<Extensions>>,
        // clone of the configuration that is also handed to the codec in `Dispatcher::new`
        config: ServiceConfig,
        // error recorded by `poll_request` when decoding fails
        error: Option<DispatchError>,

        // stage of the request/response currently being processed
        #[pin]
        pub(super) state: State<S, B, X>,
        // when Some(_) dispatcher is in state of receiving request payload
        payload: Option<PayloadSender>,
        // decoded messages waiting to be handled by `poll_response`
        messages: VecDeque<DispatcherMessage>,

        // built from `config.client_request_deadline().is_some()`; polled by `poll_head_timer`
        head_timer: TimerState,
        // built from `config.keep_alive().enabled()`; polled by `poll_ka_timer`
        ka_timer: TimerState,
        // built from `config.client_disconnect_deadline().is_some()`
        shutdown_timer: TimerState,

        // the I/O stream; `poll_flush` unwraps it, so it must be `Some` whenever flushing
        // NOTE(review): presumably taken when the connection is upgraded — confirm
        pub(super) io: Option<T>,
        // bytes received from the client that the codec decodes from
        read_buf: BytesMut,
        // encoded response bytes waiting to be written by `poll_flush`
        write_buf: BytesMut,
        // HTTP/1 codec used to decode requests and encode responses
        codec: Codec,
    }
}
172
/// Unit of work queued by `poll_request` and consumed by `poll_response`.
enum DispatcherMessage {
    /// A decoded request waiting for the expect/service call.
    Item(Request),
    /// A request that is to be handed to the upgrade service.
    Upgrade(Request),
    /// Head of an error response; `poll_response` sends it with an empty body.
    Error(Response<()>),
}
178
pin_project! {
    // Stage of the request/response that the dispatcher is currently processing.
    #[project = StateProj]
    pub(super) enum State<S, B, X>
    where
        S: Service<Request>,
        X: Service<Request, Response = Request>,
        B: MessageBody,
    {
        // nothing in flight; the next queued message can be picked up
        None,
        // waiting for the expect service to accept or reject the request
        ExpectCall { #[pin] fut: X::Future },
        // waiting for the service to produce a response
        ServiceCall { #[pin] fut: S::Future },
        // streaming the body of a service response into the write buffer
        SendPayload { #[pin] body: B },
        // streaming the body of an error response into the write buffer
        SendErrorPayload { #[pin] body: BoxBody },
    }
}
194
195impl<S, B, X> State<S, B, X>
196where
197    S: Service<Request>,
198    X: Service<Request, Response = Request>,
199    B: MessageBody,
200{
201    pub(super) fn is_none(&self) -> bool {
202        matches!(self, State::None)
203    }
204}
205
206impl<S, B, X> fmt::Debug for State<S, B, X>
207where
208    S: Service<Request>,
209    X: Service<Request, Response = Request>,
210    B: MessageBody,
211{
212    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213        match self {
214            Self::None => write!(f, "State::None"),
215            Self::ExpectCall { .. } => f.debug_struct("State::ExpectCall").finish_non_exhaustive(),
216            Self::ServiceCall { .. } => {
217                f.debug_struct("State::ServiceCall").finish_non_exhaustive()
218            }
219            Self::SendPayload { .. } => {
220                f.debug_struct("State::SendPayload").finish_non_exhaustive()
221            }
222            Self::SendErrorPayload { .. } => f
223                .debug_struct("State::SendErrorPayload")
224                .finish_non_exhaustive(),
225        }
226    }
227}
228
/// Outcome of `InnerDispatcher::poll_response`.
#[derive(Debug)]
enum PollResponse {
    /// The next queued message is an upgrade request; it is returned to the caller.
    Upgrade(Request),
    /// No further progress can be made on the response side right now.
    DoNothing,
    /// The write buffer reached its size limit while a body was being encoded; the caller
    /// should write it out to the I/O stream.
    DrainWriteBuf,
}
235
236impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
237where
238    T: AsyncRead + AsyncWrite + Unpin,
239
240    S: Service<Request>,
241    S::Error: Into<Response<BoxBody>>,
242    S::Response: Into<Response<B>>,
243
244    B: MessageBody,
245
246    X: Service<Request, Response = Request>,
247    X::Error: Into<Response<BoxBody>>,
248
249    U: Service<(Request, Framed<T, Codec>), Response = ()>,
250    U::Error: fmt::Display,
251{
252    /// Create HTTP/1 dispatcher.
253    pub(crate) fn new(
254        io: T,
255        flow: Rc<HttpFlow<S, X, U>>,
256        config: ServiceConfig,
257        peer_addr: Option<net::SocketAddr>,
258        conn_data: OnConnectData,
259    ) -> Self {
260        Dispatcher {
261            inner: DispatcherState::Normal {
262                inner: InnerDispatcher {
263                    flow,
264                    flags: Flags::empty(),
265                    peer_addr,
266                    conn_data: conn_data.0.map(Rc::new),
267                    config: config.clone(),
268                    error: None,
269
270                    state: State::None,
271                    payload: None,
272                    messages: VecDeque::new(),
273
274                    head_timer: TimerState::new(config.client_request_deadline().is_some()),
275                    ka_timer: TimerState::new(config.keep_alive().enabled()),
276                    shutdown_timer: TimerState::new(config.client_disconnect_deadline().is_some()),
277
278                    io: Some(io),
279                    read_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
280                    write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
281                    codec: Codec::new(config),
282                },
283            },
284
285            #[cfg(test)]
286            poll_count: 0,
287        }
288    }
289}
290
291impl<T, S, B, X, U> InnerDispatcher<T, S, B, X, U>
292where
293    T: AsyncRead + AsyncWrite + Unpin,
294
295    S: Service<Request>,
296    S::Error: Into<Response<BoxBody>>,
297    S::Response: Into<Response<B>>,
298
299    B: MessageBody,
300
301    X: Service<Request, Response = Request>,
302    X::Error: Into<Response<BoxBody>>,
303
304    U: Service<(Request, Framed<T, Codec>), Response = ()>,
305    U::Error: fmt::Display,
306{
307    fn can_read(&self, cx: &mut Context<'_>) -> bool {
308        if self.flags.contains(Flags::READ_DISCONNECT) {
309            false
310        } else if let Some(ref info) = self.payload {
311            info.need_read(cx) == PayloadStatus::Read
312        } else {
313            true
314        }
315    }
316
317    fn client_disconnected(self: Pin<&mut Self>) {
318        let this = self.project();
319
320        this.flags
321            .insert(Flags::READ_DISCONNECT | Flags::WRITE_DISCONNECT);
322
323        if let Some(mut payload) = this.payload.take() {
324            payload.set_error(PayloadError::Incomplete(None));
325        }
326    }
327
328    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
329        let InnerDispatcherProj { io, write_buf, .. } = self.project();
330        let mut io = Pin::new(io.as_mut().unwrap());
331
332        let len = write_buf.len();
333        let mut written = 0;
334
335        while written < len {
336            match io.as_mut().poll_write(cx, &write_buf[written..])? {
337                Poll::Ready(0) => {
338                    error!("write zero; closing");
339                    return Poll::Ready(Err(io::Error::new(io::ErrorKind::WriteZero, "")));
340                }
341
342                Poll::Ready(n) => written += n,
343
344                Poll::Pending => {
345                    write_buf.advance(written);
346                    return Poll::Pending;
347                }
348            }
349        }
350
351        // everything has written to I/O; clear buffer
352        write_buf.clear();
353
354        // flush the I/O and check if get blocked
355        io.poll_flush(cx)
356    }
357
358    fn send_response_inner(
359        self: Pin<&mut Self>,
360        res: Response<()>,
361        body: &impl MessageBody,
362    ) -> Result<BodySize, DispatchError> {
363        let this = self.project();
364
365        let size = body.size();
366
367        this.codec
368            .encode(Message::Item((res, size)), this.write_buf)
369            .map_err(|err| {
370                if let Some(mut payload) = this.payload.take() {
371                    payload.set_error(PayloadError::Incomplete(None));
372                }
373
374                DispatchError::Io(err)
375            })?;
376
377        Ok(size)
378    }
379
380    fn send_response(
381        mut self: Pin<&mut Self>,
382        res: Response<()>,
383        body: B,
384    ) -> Result<(), DispatchError> {
385        let size = self.as_mut().send_response_inner(res, &body)?;
386        let mut this = self.project();
387        this.state.set(match size {
388            BodySize::None | BodySize::Sized(0) => {
389                this.flags.insert(Flags::FINISHED);
390                State::None
391            }
392            _ => State::SendPayload { body },
393        });
394
395        Ok(())
396    }
397
398    fn send_error_response(
399        mut self: Pin<&mut Self>,
400        res: Response<()>,
401        body: BoxBody,
402    ) -> Result<(), DispatchError> {
403        let size = self.as_mut().send_response_inner(res, &body)?;
404        let mut this = self.project();
405        this.state.set(match size {
406            BodySize::None | BodySize::Sized(0) => {
407                this.flags.insert(Flags::FINISHED);
408                State::None
409            }
410            _ => State::SendErrorPayload { body },
411        });
412
413        Ok(())
414    }
415
416    fn send_continue(self: Pin<&mut Self>) {
417        self.project()
418            .write_buf
419            .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
420    }
421
    /// Drives the response side of the dispatcher as far as possible.
    ///
    /// Loops over the `State` machine: when idle it pops the next queued message, otherwise
    /// it polls the expect/service future or encodes response body chunks into the write
    /// buffer.
    ///
    /// Returns
    /// - `PollResponse::Upgrade(req)` when the next queued message is an upgrade request,
    /// - `PollResponse::DrainWriteBuf` when the write buffer reached the size limit while a
    ///   body was being encoded,
    /// - `PollResponse::DoNothing` when no further progress can be made right now.
    ///
    /// # Errors
    /// Returns a `DispatchError` when encoding fails, when a response body stream yields an
    /// error, or when `poll_request` fails.
    fn poll_response(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Result<PollResponse, DispatchError> {
        'res: loop {
            // re-project on every iteration; the calls below that take `self` end the
            // previous projection
            let mut this = self.as_mut().project();
            match this.state.as_mut().project() {
                // no future is in InnerDispatcher state; pop next message
                StateProj::None => match this.messages.pop_front() {
                    // handle request message
                    Some(DispatcherMessage::Item(req)) => {
                        // Handle `EXPECT: 100-Continue` header
                        if req.head().expect() {
                            // set InnerDispatcher state and continue loop to poll it
                            let fut = this.flow.expect.call(req);
                            this.state.set(State::ExpectCall { fut });
                        } else {
                            // set InnerDispatcher state and continue loop to poll it
                            let fut = this.flow.service.call(req);
                            this.state.set(State::ServiceCall { fut });
                        };
                    }

                    // handle error message
                    Some(DispatcherMessage::Error(res)) => {
                        // send_error_response updates InnerDispatcher state to
                        // SendErrorPayload or None (if response body is empty);
                        // continue loop to poll it
                        self.as_mut().send_error_response(res, BoxBody::new(()))?;
                    }

                    // return with upgrade request and poll it exclusively
                    Some(DispatcherMessage::Upgrade(req)) => return Ok(PollResponse::Upgrade(req)),

                    // all messages are dealt with
                    None => {
                        // start keep-alive if last request allowed it
                        this.flags.set(Flags::KEEP_ALIVE, this.codec.keep_alive());

                        return Ok(PollResponse::DoNothing);
                    }
                },

                StateProj::ServiceCall { fut } => {
                    match fut.poll(cx) {
                        // service call resolved. send response.
                        Poll::Ready(Ok(res)) => {
                            // split the response into its head and its body
                            let (res, body) = res.into().replace_body(());
                            self.as_mut().send_response(res, body)?;
                        }

                        // send service call error as response
                        Poll::Ready(Err(err)) => {
                            let res: Response<BoxBody> = err.into();
                            let (res, body) = res.replace_body(());
                            self.as_mut().send_error_response(res, body)?;
                        }

                        // service call pending and could be waiting for more chunk messages
                        // (pipeline message limit and/or payload can_read limit)
                        Poll::Pending => {
                            // no new message is decoded and no new payload is fed
                            // nothing to do except waiting for new incoming data from client
                            if !self.as_mut().poll_request(cx)? {
                                return Ok(PollResponse::DoNothing);
                            }
                            // else loop
                        }
                    }
                }

                StateProj::SendPayload { mut body } => {
                    // keep populate writer buffer until buffer size limit hit,
                    // get blocked or finished.
                    while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
                        match body.as_mut().poll_next(cx) {
                            // next body chunk is available; encode it
                            Poll::Ready(Some(Ok(item))) => {
                                this.codec
                                    .encode(Message::Chunk(Some(item)), this.write_buf)?;
                            }

                            Poll::Ready(None) => {
                                // encode the end-of-body marker
                                this.codec.encode(Message::Chunk(None), this.write_buf)?;

                                // payload stream finished.
                                // set state to None and handle next message
                                this.state.set(State::None);
                                this.flags.insert(Flags::FINISHED);

                                continue 'res;
                            }

                            // body stream failed; log it and abort with a body error
                            Poll::Ready(Some(Err(err))) => {
                                let err = err.into();
                                tracing::error!("Response payload stream error: {err:?}");
                                this.flags.insert(Flags::FINISHED);
                                return Err(DispatchError::Body(err));
                            }

                            // body has no chunk ready yet
                            Poll::Pending => return Ok(PollResponse::DoNothing),
                        }
                    }

                    // buffer is beyond max size
                    // return and try to write the whole buffer to I/O stream.
                    return Ok(PollResponse::DrainWriteBuf);
                }

                StateProj::SendErrorPayload { mut body } => {
                    // TODO: de-dupe impl with SendPayload

                    // keep populate writer buffer until buffer size limit hit,
                    // get blocked or finished.
                    while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
                        match body.as_mut().poll_next(cx) {
                            // next body chunk is available; encode it
                            Poll::Ready(Some(Ok(item))) => {
                                this.codec
                                    .encode(Message::Chunk(Some(item)), this.write_buf)?;
                            }

                            Poll::Ready(None) => {
                                // encode the end-of-body marker
                                this.codec.encode(Message::Chunk(None), this.write_buf)?;

                                // payload stream finished
                                // set state to None and handle next message
                                this.state.set(State::None);
                                this.flags.insert(Flags::FINISHED);

                                continue 'res;
                            }

                            // body stream failed; log it and abort with a body error that
                            // carries the original error as its cause
                            Poll::Ready(Some(Err(err))) => {
                                tracing::error!("Response payload stream error: {err:?}");
                                this.flags.insert(Flags::FINISHED);
                                return Err(DispatchError::Body(
                                    Error::new_body().with_cause(err).into(),
                                ));
                            }

                            // body has no chunk ready yet
                            Poll::Pending => return Ok(PollResponse::DoNothing),
                        }
                    }

                    // buffer is beyond max size
                    // return and try to write the whole buffer to stream
                    return Ok(PollResponse::DrainWriteBuf);
                }

                StateProj::ExpectCall { fut } => {
                    trace!("  calling expect service");

                    match fut.poll(cx) {
                        // expect resolved. write continue to buffer and set InnerDispatcher state
                        // to service call.
                        Poll::Ready(Ok(req)) => {
                            this.write_buf
                                .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
                            let fut = this.flow.service.call(req);
                            this.state.set(State::ServiceCall { fut });
                        }

                        // send expect error as response
                        Poll::Ready(Err(err)) => {
                            let res: Response<BoxBody> = err.into();
                            let (res, body) = res.replace_body(());
                            self.as_mut().send_error_response(res, body)?;
                        }

                        // expect must be solved before progress can be made.
                        Poll::Pending => return Ok(PollResponse::DoNothing),
                    }
                }
            }
        }
    }
597
598    fn handle_request(
599        mut self: Pin<&mut Self>,
600        req: Request,
601        cx: &mut Context<'_>,
602    ) -> Result<(), DispatchError> {
603        // initialize dispatcher state
604        {
605            let mut this = self.as_mut().project();
606
607            // Handle `EXPECT: 100-Continue` header
608            if req.head().expect() {
609                // set dispatcher state to call expect handler
610                let fut = this.flow.expect.call(req);
611                this.state.set(State::ExpectCall { fut });
612            } else {
613                // set dispatcher state to call service handler
614                let fut = this.flow.service.call(req);
615                this.state.set(State::ServiceCall { fut });
616            };
617        };
618
619        // eagerly poll the future once (or twice if expect is resolved immediately).
620        loop {
621            match self.as_mut().project().state.project() {
622                StateProj::ExpectCall { fut } => {
623                    match fut.poll(cx) {
624                        // expect is resolved; continue loop and poll the service call branch.
625                        Poll::Ready(Ok(req)) => {
626                            self.as_mut().send_continue();
627
628                            let mut this = self.as_mut().project();
629                            let fut = this.flow.service.call(req);
630                            this.state.set(State::ServiceCall { fut });
631
632                            continue;
633                        }
634
635                        // future is error; send response and return a result
636                        // on success to notify the dispatcher a new state is set and the outer loop
637                        // should be continued
638                        Poll::Ready(Err(err)) => {
639                            let res: Response<BoxBody> = err.into();
640                            let (res, body) = res.replace_body(());
641                            return self.send_error_response(res, body);
642                        }
643
644                        // future is pending; return Ok(()) to notify that a new state is
645                        // set and the outer loop should be continue.
646                        Poll::Pending => return Ok(()),
647                    }
648                }
649
650                StateProj::ServiceCall { fut } => {
651                    // return no matter the service call future's result.
652                    return match fut.poll(cx) {
653                        // Future is resolved. Send response and return a result. On success
654                        // to notify the dispatcher a new state is set and the outer loop
655                        // should be continue.
656                        Poll::Ready(Ok(res)) => {
657                            let (res, body) = res.into().replace_body(());
658                            self.as_mut().send_response(res, body)
659                        }
660
661                        // see the comment on ExpectCall state branch's Pending
662                        Poll::Pending => Ok(()),
663
664                        // see the comment on ExpectCall state branch's Ready(Err(_))
665                        Poll::Ready(Err(err)) => {
666                            let res: Response<BoxBody> = err.into();
667                            let (res, body) = res.replace_body(());
668                            self.as_mut().send_error_response(res, body)
669                        }
670                    };
671                }
672
673                _ => {
674                    unreachable!("State must be set to ServiceCall or ExceptCall in handle_request")
675                }
676            }
677        }
678    }
679
    /// Decodes buffered request data and dispatches the decoded messages.
    ///
    /// Returns `Ok(false)` without decoding when the message queue already holds
    /// `MAX_PIPELINED_MESSAGES` entries or when `can_read` refuses reading. Otherwise messages
    /// are decoded from the read buffer until the codec needs more data, an upgrade request
    /// is found or decoding fails.
    ///
    /// Decoding failures are not returned as errors: they are recorded in `self.error` and,
    /// except for I/O errors, an error response is queued. Only a failure of
    /// `handle_request` is propagated.
    ///
    /// Returns true if any meaningful work was done, i.e. at least one message was decoded.
    fn poll_request(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<bool, DispatchError> {
        let pipeline_queue_full = self.messages.len() >= MAX_PIPELINED_MESSAGES;
        let can_not_read = !self.can_read(cx);

        // limit amount of non-processed requests
        if pipeline_queue_full || can_not_read {
            return Ok(false);
        }

        let mut this = self.as_mut().project();

        let mut updated = false;

        // decode from read buf as many full requests as possible
        loop {
            match this.codec.decode(this.read_buf) {
                Ok(Some(msg)) => {
                    updated = true;

                    match msg {
                        Message::Item(mut req) => {
                            // head timer only applies to first request on connection
                            this.head_timer.clear(line!());

                            // attach per-connection information to the request
                            req.head_mut().peer_addr = *this.peer_addr;

                            req.conn_data.clone_from(this.conn_data);

                            match this.codec.message_type() {
                                // request has no payload
                                MessageType::None => {}

                                // Request is upgradable. Add upgrade message and break.
                                // Everything remaining in read buffer will be handed to
                                // upgraded Request.
                                MessageType::Stream if this.flow.upgrade.is_some() => {
                                    this.messages.push_back(DispatcherMessage::Upgrade(req));
                                    break;
                                }

                                // request is not upgradable
                                MessageType::Payload | MessageType::Stream => {
                                    // PayloadSender and Payload are smart pointers share the
                                    // same state. PayloadSender is attached to dispatcher and used
                                    // to sink new chunked request data to state. Payload is
                                    // attached to Request and passed to Service::call where the
                                    // state can be collected and consumed.
                                    let (sender, payload) = Payload::create(false);
                                    *req.payload() = crate::Payload::H1 { payload };
                                    *this.payload = Some(sender);
                                }
                            }

                            // handle request early when no future in InnerDispatcher state.
                            if this.state.is_none() {
                                self.as_mut().handle_request(req, cx)?;
                                // the call above ended the previous projection; project again
                                this = self.as_mut().project();
                            } else {
                                this.messages.push_back(DispatcherMessage::Item(req));
                            }
                        }

                        // a chunk of request payload; forward it to the payload sender
                        Message::Chunk(Some(chunk)) => {
                            if let Some(ref mut payload) = this.payload {
                                payload.feed_data(chunk);
                            } else {
                                // a chunk arrived although no payload is being received
                                error!("Internal server error: unexpected payload chunk");
                                this.flags.insert(Flags::READ_DISCONNECT);
                                this.messages.push_back(DispatcherMessage::Error(
                                    Response::internal_server_error().drop_body(),
                                ));
                                *this.error = Some(DispatchError::InternalError);
                                break;
                            }
                        }

                        // end of request payload; detach the sender and signal EOF
                        Message::Chunk(None) => {
                            if let Some(mut payload) = this.payload.take() {
                                payload.feed_eof();
                            } else {
                                // EOF arrived although no payload is being received
                                error!("Internal server error: unexpected eof");
                                this.flags.insert(Flags::READ_DISCONNECT);
                                this.messages.push_back(DispatcherMessage::Error(
                                    Response::internal_server_error().drop_body(),
                                ));
                                *this.error = Some(DispatchError::InternalError);
                                break;
                            }
                        }
                    }
                }

                // decode is partial and buffer is not full yet
                // break and wait for more read
                Ok(None) => break,

                // I/O failure: treat the client as gone; no response is queued
                Err(ParseError::Io(err)) => {
                    trace!("I/O error: {}", &err);
                    self.as_mut().client_disconnected();
                    this = self.as_mut().project();
                    *this.error = Some(DispatchError::Io(err));
                    break;
                }

                Err(ParseError::TooLarge) => {
                    trace!("request head was too big; returning 431 response");

                    if let Some(mut payload) = this.payload.take() {
                        payload.set_error(PayloadError::Overflow);
                    }

                    // request heads that overflow buffer size return a 431 error
                    this.messages
                        .push_back(DispatcherMessage::Error(Response::with_body(
                            StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
                            (),
                        )));

                    // stop reading from this connection and remember the cause
                    this.flags.insert(Flags::READ_DISCONNECT);
                    *this.error = Some(ParseError::TooLarge.into());

                    break;
                }

                // any other parse error
                Err(err) => {
                    trace!("parse error {}", &err);

                    if let Some(mut payload) = this.payload.take() {
                        payload.set_error(PayloadError::EncodingCorrupted);
                    }

                    // malformed requests should be responded with 400
                    this.messages.push_back(DispatcherMessage::Error(
                        Response::bad_request().drop_body(),
                    ));

                    // stop reading from this connection and remember the cause
                    this.flags.insert(Flags::READ_DISCONNECT);
                    *this.error = Some(err.into());
                    break;
                }
            }
        }

        Ok(updated)
    }
828
829    fn poll_head_timer(
830        mut self: Pin<&mut Self>,
831        cx: &mut Context<'_>,
832    ) -> Result<(), DispatchError> {
833        let this = self.as_mut().project();
834
835        if let TimerState::Active { timer } = this.head_timer {
836            if timer.as_mut().poll(cx).is_ready() {
837                // timeout on first request (slow request) return 408
838
839                trace!("timed out on slow request; replying with 408 and closing connection");
840
841                let _ = self.as_mut().send_error_response(
842                    Response::with_body(StatusCode::REQUEST_TIMEOUT, ()),
843                    BoxBody::new(()),
844                );
845
846                self.project().flags.insert(Flags::SHUTDOWN);
847            }
848        };
849
850        Ok(())
851    }
852
853    fn poll_ka_timer(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), DispatchError> {
854        let this = self.as_mut().project();
855        if let TimerState::Active { timer } = this.ka_timer {
856            debug_assert!(
857                this.flags.contains(Flags::KEEP_ALIVE),
858                "keep-alive flag should be set when timer is active",
859            );
860            debug_assert!(
861                this.state.is_none(),
862                "dispatcher should not be in keep-alive phase if state is not none: {:?}",
863                this.state,
864            );
865
866            // Assert removed by @robjtede on account of issue #2655. There are cases where an I/O
867            // flush can be pending after entering the keep-alive state causing the subsequent flush
868            // wake up to panic here. This appears to be a Linux-only problem. Leaving original code
869            // below for posterity because a simple and reliable test could not be found to trigger
870            // the behavior.
871            // debug_assert!(
872            //     this.write_buf.is_empty(),
873            //     "dispatcher should not be in keep-alive phase if write_buf is not empty",
874            // );
875
876            // keep-alive timer has timed out
877            if timer.as_mut().poll(cx).is_ready() {
878                // no tasks at hand
879                trace!("timer timed out; closing connection");
880                this.flags.insert(Flags::SHUTDOWN);
881
882                if let Some(deadline) = this.config.client_disconnect_deadline() {
883                    // start shutdown timeout if enabled
884                    this.shutdown_timer
885                        .set_and_init(cx, sleep_until(deadline.into()), line!());
886                } else {
887                    // no shutdown timeout, drop socket
888                    this.flags.insert(Flags::WRITE_DISCONNECT);
889                }
890            }
891        }
892
893        Ok(())
894    }
895
896    fn poll_shutdown_timer(
897        mut self: Pin<&mut Self>,
898        cx: &mut Context<'_>,
899    ) -> Result<(), DispatchError> {
900        let this = self.as_mut().project();
901        if let TimerState::Active { timer } = this.shutdown_timer {
902            debug_assert!(
903                this.flags.contains(Flags::SHUTDOWN),
904                "shutdown flag should be set when timer is active",
905            );
906
907            // timed-out during shutdown; drop connection
908            if timer.as_mut().poll(cx).is_ready() {
909                trace!("timed-out during shutdown");
910                return Err(DispatchError::DisconnectTimeout);
911            }
912        }
913
914        Ok(())
915    }
916
917    /// Poll head, keep-alive, and disconnect timer.
918    fn poll_timers(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), DispatchError> {
919        self.as_mut().poll_head_timer(cx)?;
920        self.as_mut().poll_ka_timer(cx)?;
921        self.as_mut().poll_shutdown_timer(cx)?;
922
923        Ok(())
924    }
925
926    /// Returns true when I/O stream can be disconnected after write to it.
927    ///
928    /// It covers these conditions:
929    /// - `std::io::ErrorKind::ConnectionReset` after partial read;
930    /// - all data read done.
931    #[inline(always)] // TODO: bench this inline
932    fn read_available(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<bool, DispatchError> {
933        let this = self.project();
934
935        if this.flags.contains(Flags::READ_DISCONNECT) {
936            return Ok(false);
937        };
938
939        let mut io = Pin::new(this.io.as_mut().unwrap());
940
941        let mut read_some = false;
942
943        loop {
944            // Return early when read buf exceed decoder's max buffer size.
945            if this.read_buf.len() >= MAX_BUFFER_SIZE {
946                // At this point it's not known IO stream is still scheduled to be waked up so
947                // force wake up dispatcher just in case.
948                //
949                // Reason:
950                // AsyncRead mostly would only have guarantee wake up when the poll_read
951                // return Poll::Pending.
952                //
953                // Case:
954                // When read_buf is beyond max buffer size the early return could be successfully
955                // be parsed as a new Request. This case would not generate ParseError::TooLarge and
956                // at this point IO stream is not fully read to Pending and would result in
957                // dispatcher stuck until timeout (keep-alive).
958                //
959                // Note:
960                // This is a perf choice to reduce branch on <Request as MessageType>::decode.
961                //
962                // A Request head too large to parse is only checked on `httparse::Status::Partial`.
963
964                match this.payload {
965                    // When dispatcher has a payload the responsibility of wake ups is shifted to
966                    // `h1::payload::Payload` unless the payload is needing a read, in which case it
967                    // might not have access to the waker and could result in the dispatcher
968                    // getting stuck until timeout.
969                    //
970                    // Reason:
971                    // Self wake up when there is payload would waste poll and/or result in
972                    // over read.
973                    //
974                    // Case:
975                    // When payload is (partial) dropped by user there is no need to do
976                    // read anymore. At this case read_buf could always remain beyond
977                    // MAX_BUFFER_SIZE and self wake up would be busy poll dispatcher and
978                    // waste resources.
979                    Some(ref p) if p.need_read(cx) != PayloadStatus::Read => {}
980                    _ => cx.waker().wake_by_ref(),
981                }
982
983                return Ok(false);
984            }
985
986            // grow buffer if necessary.
987            let remaining = this.read_buf.capacity() - this.read_buf.len();
988            if remaining < LW_BUFFER_SIZE {
989                this.read_buf.reserve(HW_BUFFER_SIZE - remaining);
990            }
991
992            match tokio_util::io::poll_read_buf(io.as_mut(), cx, this.read_buf) {
993                Poll::Ready(Ok(n)) => {
994                    this.flags.remove(Flags::FINISHED);
995
996                    if n == 0 {
997                        return Ok(true);
998                    }
999
1000                    read_some = true;
1001                }
1002
1003                Poll::Pending => {
1004                    return Ok(false);
1005                }
1006
1007                Poll::Ready(Err(err)) => {
1008                    return match err.kind() {
1009                        // convert WouldBlock error to the same as Pending return
1010                        io::ErrorKind::WouldBlock => Ok(false),
1011
1012                        // connection reset after partial read
1013                        io::ErrorKind::ConnectionReset if read_some => Ok(true),
1014
1015                        _ => Err(DispatchError::Io(err)),
1016                    };
1017                }
1018            }
1019        }
1020    }
1021
1022    /// call upgrade service with request.
1023    fn upgrade(self: Pin<&mut Self>, req: Request) -> U::Future {
1024        let this = self.project();
1025        let mut parts = FramedParts::with_read_buf(
1026            this.io.take().unwrap(),
1027            mem::take(this.codec),
1028            mem::take(this.read_buf),
1029        );
1030        parts.write_buf = mem::take(this.write_buf);
1031        let framed = Framed::from_parts(parts);
1032        this.flow.upgrade.as_ref().unwrap().call((req, framed))
1033    }
1034}
1035
/// `Future` implementation that drives the dispatcher.
///
/// The future resolves with `Ok(())` once the connection is finished (or the upgrade future
/// completes) and with a `DispatchError` on failure.
impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
where
    T: AsyncRead + AsyncWrite + Unpin,

    S: Service<Request>,
    S::Error: Into<Response<BoxBody>>,
    S::Response: Into<Response<B>>,

    B: MessageBody,

    X: Service<Request, Response = Request>,
    X::Error: Into<Response<BoxBody>>,

    U: Service<(Request, Framed<T, Codec>), Response = ()>,
    U::Error: fmt::Display,
{
    type Output = Result<(), DispatchError>;

    /// Advances the dispatcher according to its current state.
    ///
    /// - `Upgrade` state: polls the upgrade future; an error is logged and reported as
    ///   `DispatchError::Upgrade`.
    /// - `Normal` state: polls the timers first. If the `SHUTDOWN` flag is set, the write
    ///   buffer is flushed and the I/O object is shut down (or the future completes right away
    ///   when `WRITE_DISCONNECT` is also set). Otherwise one read / request / response / flush
    ///   round is run, after which the flags decide whether to finish, start shutdown, or
    ///   return `Poll::Pending`.
    #[inline]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.as_mut().project();

        // count polls in test builds only
        #[cfg(test)]
        {
            *this.poll_count += 1;
        }

        match this.inner.project() {
            DispatcherStateProj::Upgrade { fut: upgrade } => upgrade.poll(cx).map_err(|err| {
                error!("Upgrade handler error: {}", err);
                DispatchError::Upgrade
            }),

            DispatcherStateProj::Normal { mut inner } => {
                trace!("start flags: {:?}", &inner.flags);

                trace_timer_states(
                    "start",
                    &inner.head_timer,
                    &inner.ka_timer,
                    &inner.shutdown_timer,
                );

                // timers may set SHUTDOWN / WRITE_DISCONNECT or fail with a timeout error
                inner.as_mut().poll_timers(cx)?;

                let poll = if inner.flags.contains(Flags::SHUTDOWN) {
                    if inner.flags.contains(Flags::WRITE_DISCONNECT) {
                        // write half is already considered gone; nothing left to flush
                        Poll::Ready(Ok(()))
                    } else {
                        // flush buffer and wait on blocked
                        ready!(inner.as_mut().poll_flush(cx))?;
                        Pin::new(inner.as_mut().project().io.as_mut().unwrap())
                            .poll_shutdown(cx)
                            .map_err(DispatchError::from)
                    }
                } else {
                    // read from I/O stream and fill read buffer
                    let should_disconnect = inner.as_mut().read_available(cx)?;

                    // after reading something from stream, clear keep-alive timer
                    if !inner.read_buf.is_empty() && inner.flags.contains(Flags::KEEP_ALIVE) {
                        let inner = inner.as_mut().project();
                        inner.flags.remove(Flags::KEEP_ALIVE);
                        inner.ka_timer.clear(line!());
                    }

                    // first pass through this branch: mark as started and arm the head timer
                    // if a client request deadline is configured
                    if !inner.flags.contains(Flags::STARTED) {
                        inner.as_mut().project().flags.insert(Flags::STARTED);

                        if let Some(deadline) = inner.config.client_request_deadline() {
                            inner.as_mut().project().head_timer.set_and_init(
                                cx,
                                sleep_until(deadline.into()),
                                line!(),
                            );
                        }
                    }

                    inner.as_mut().poll_request(cx)?;

                    if should_disconnect {
                        // I/O stream should be closed
                        let inner = inner.as_mut().project();
                        inner.flags.insert(Flags::READ_DISCONNECT);
                        // no more bytes will arrive, so signal EOF to any in-flight payload
                        if let Some(mut payload) = inner.payload.take() {
                            payload.feed_eof();
                        }
                    };

                    loop {
                        // poll response to populate write buffer
                        // drain indicates whether write buffer should be emptied before next run
                        let drain = match inner.as_mut().poll_response(cx)? {
                            PollResponse::DrainWriteBuf => true,

                            PollResponse::DoNothing => {
                                // KEEP_ALIVE is set in send_response_inner if client allows it
                                // FINISHED is set after writing last chunk of response
                                if inner.flags.contains(Flags::KEEP_ALIVE | Flags::FINISHED) {
                                    if let Some(timer) = inner.config.keep_alive_deadline() {
                                        inner.as_mut().project().ka_timer.set_and_init(
                                            cx,
                                            sleep_until(timer.into()),
                                            line!(),
                                        );
                                    }
                                }

                                false
                            }

                            // upgrade request and goes Upgrade variant of DispatcherState.
                            // poll is re-entered right away so the upgrade future is polled
                            // within this same call.
                            PollResponse::Upgrade(req) => {
                                let upgrade = inner.upgrade(req);
                                self.as_mut()
                                    .project()
                                    .inner
                                    .set(DispatcherState::Upgrade { fut: upgrade });
                                return self.poll(cx);
                            }
                        };

                        // we didn't get WouldBlock from write operation, so data get written to
                        // kernel completely (macOS) and we have to write again otherwise response
                        // can get stuck
                        //
                        // TODO: want to find a reference for this behavior
                        // see introduced commit: 3872d3ba
                        let flush_was_ready = inner.as_mut().poll_flush(cx)?.is_ready();

                        // this assert seems to always be true but not willing to commit to it until
                        // we understand what Nikolay meant when writing the above comment
                        // debug_assert!(flush_was_ready);

                        // stop when the flush is still pending or no further drain was requested
                        if !flush_was_ready || !drain {
                            break;
                        }
                    }

                    // client is gone
                    if inner.flags.contains(Flags::WRITE_DISCONNECT) {
                        trace!("client is gone; disconnecting");
                        return Poll::Ready(Ok(()));
                    }

                    let inner_p = inner.as_mut().project();
                    let state_is_none = inner_p.state.is_none();

                    // read half is closed; we do not process any responses
                    if inner_p.flags.contains(Flags::READ_DISCONNECT) && state_is_none {
                        trace!("read half closed; start shutdown");
                        inner_p.flags.insert(Flags::SHUTDOWN);
                    }

                    // keep-alive and stream errors
                    // (only acted on once no state is active and the write buffer is empty)
                    if state_is_none && inner_p.write_buf.is_empty() {
                        if let Some(err) = inner_p.error.take() {
                            error!("stream error: {}", &err);
                            return Poll::Ready(Err(err));
                        }

                        // disconnect if keep-alive is not enabled
                        if inner_p.flags.contains(Flags::FINISHED)
                            && !inner_p.flags.contains(Flags::KEEP_ALIVE)
                        {
                            inner_p.flags.remove(Flags::FINISHED);
                            inner_p.flags.insert(Flags::SHUTDOWN);
                            // re-enter poll so the SHUTDOWN branch above runs in this call
                            return self.poll(cx);
                        }

                        // disconnect if shutdown
                        if inner_p.flags.contains(Flags::SHUTDOWN) {
                            return self.poll(cx);
                        }
                    }

                    trace_timer_states(
                        "end",
                        inner_p.head_timer,
                        inner_p.ka_timer,
                        inner_p.shutdown_timer,
                    );

                    Poll::Pending
                };

                // note: the early `return`s above skip this trace line
                trace!("end flags: {:?}", &inner.flags);

                poll
            }
        }
    }
}
1229
1230#[allow(dead_code)]
1231fn trace_timer_states(
1232    label: &str,
1233    head_timer: &TimerState,
1234    ka_timer: &TimerState,
1235    shutdown_timer: &TimerState,
1236) {
1237    trace!("{} timers:", label);
1238
1239    if head_timer.is_enabled() {
1240        trace!("  head {}", &head_timer);
1241    }
1242
1243    if ka_timer.is_enabled() {
1244        trace!("  keep-alive {}", &ka_timer);
1245    }
1246
1247    if shutdown_timer.is_enabled() {
1248        trace!("  shutdown {}", &shutdown_timer);
1249    }
1250}