1use std::{
2 collections::VecDeque,
3 fmt,
4 future::Future,
5 io, mem, net,
6 pin::Pin,
7 rc::Rc,
8 task::{Context, Poll},
9};
10
11use actix_codec::{Framed, FramedParts};
12use actix_rt::time::sleep_until;
13use actix_service::Service;
14use bitflags::bitflags;
15use bytes::{Buf, BytesMut};
16use futures_core::ready;
17use pin_project_lite::pin_project;
18use tokio::io::{AsyncRead, AsyncWrite};
19use tokio_util::codec::{Decoder as _, Encoder as _};
20use tracing::{error, trace};
21
22use super::{
23 codec::Codec,
24 decoder::MAX_BUFFER_SIZE,
25 payload::{Payload, PayloadSender, PayloadStatus},
26 timer::TimerState,
27 Message, MessageType,
28};
29use crate::{
30 body::{BodySize, BoxBody, MessageBody},
31 config::ServiceConfig,
32 error::{DispatchError, ParseError, PayloadError},
33 service::HttpFlow,
34 Error, Extensions, OnConnectData, Request, Response, StatusCode,
35};
36
/// When the free capacity of the read buffer drops below this many bytes, the buffer is grown
/// before the next read.
const LW_BUFFER_SIZE: usize = 1024;
/// Initial capacity of the read and write buffers; also used to size read-buffer growth.
const HW_BUFFER_SIZE: usize = 1024 * 8;
/// Maximum number of decoded-but-not-yet-dispatched messages; request decoding pauses once the
/// queue holds this many.
const MAX_PIPELINED_MESSAGES: usize = 16;
40
bitflags! {
    /// Connection-level state bits tracked by the HTTP/1 dispatcher.
    #[derive(Debug, Clone, Copy)]
    pub struct Flags: u8 {
        /// Set the first time the dispatcher runs its normal read path (this is also when the
        /// request-head timer is armed).
        const STARTED = 0b0000_0001;

        /// Set once a response has been completely encoded into the write buffer; cleared when
        /// new bytes are read from the stream.
        const FINISHED = 0b0000_0010;

        /// Mirrors the codec's keep-alive decision once all queued messages are handled;
        /// cleared as soon as new request bytes arrive.
        const KEEP_ALIVE = 0b0000_0100;

        /// Set when the connection is shutting down; later polls only flush and shut down I/O.
        const SHUTDOWN = 0b0000_1000;

        /// Set when no more data will be read from the stream (EOF, I/O or parse failure).
        const READ_DISCONNECT = 0b0001_0000;

        /// Set when the write half is considered gone; the dispatcher then completes.
        const WRITE_DISCONNECT = 0b0010_0000;
    }
}
63
#[cfg(not(test))]
// HTTP/1.1 connection dispatcher: a future that drives one connection (request decoding,
// service calls, response writing and protocol upgrades) until it closes.
pin_project! {
    pub struct Dispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // Either normal HTTP/1 processing or an in-progress upgrade handler.
        #[pin]
        inner: DispatcherState<T, S, B, X, U>,
    }
}
89
#[cfg(test)]
// Test build of the dispatcher: identical to the production type but exposes its state to the
// parent module and counts how many times it has been polled.
pin_project! {
    pub struct Dispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // Either normal HTTP/1 processing or an in-progress upgrade handler.
        #[pin]
        pub(super) inner: DispatcherState<T, S, B, X, U>,

        // Incremented on every call to `Future::poll`.
        pub(super) poll_count: u64,
    }
}
113
pin_project! {
    // Top-level mode of a dispatcher.
    #[project = DispatcherStateProj]
    pub(super) enum DispatcherState<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // Regular HTTP/1 request/response processing.
        Normal { #[pin] inner: InnerDispatcher<T, S, B, X, U> },
        // The connection was handed to the upgrade service; only its future is polled.
        Upgrade { #[pin] fut: U::Future },
    }
}
133
pin_project! {
    // All state needed to process HTTP/1 requests on a single connection.
    #[project = InnerDispatcherProj]
    pub(super) struct InnerDispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // Main service, `Expect` handler and optional upgrade handler.
        flow: Rc<HttpFlow<S, X, U>>,
        // Connection state bits.
        pub(super) flags: Flags,
        // Copied into every decoded request head.
        peer_addr: Option<net::SocketAddr>,
        // Cloned into every decoded request.
        conn_data: Option<Rc<Extensions>>,
        config: ServiceConfig,
        // Deferred error, returned once in-flight work is done and the write buffer is empty.
        error: Option<DispatchError>,

        // Work currently in flight (service call, body streaming, ...).
        #[pin]
        pub(super) state: State<S, B, X>,
        // Sender side of the streaming body of the request currently being decoded.
        payload: Option<PayloadSender>,
        // Decoded messages waiting for the current state to become idle (pipelining).
        messages: VecDeque<DispatcherMessage>,

        // Deadline for receiving the first request head.
        head_timer: TimerState,
        // Idle keep-alive deadline between requests.
        ka_timer: TimerState,
        // Deadline for completing connection shutdown.
        shutdown_timer: TimerState,

        // Wrapped in `Option` so it can be moved out on upgrade.
        pub(super) io: Option<T>,
        read_buf: BytesMut,
        write_buf: BytesMut,
        codec: Codec,
    }
}
172
/// A decoded unit of work queued behind the request currently being processed.
enum DispatcherMessage {
    /// Request to dispatch to the service (or to the `Expect` handler first).
    Item(Request),
    /// Request that switches the connection to the upgrade service.
    Upgrade(Request),
    /// Error response (sent with an empty body) produced while decoding.
    Error(Response<()>),
}
178
pin_project! {
    // Work currently in flight for the request at the head of the pipeline.
    #[project = StateProj]
    pub(super) enum State<S, B, X>
    where
        S: Service<Request>,
        X: Service<Request, Response = Request>,
        B: MessageBody,
    {
        // Idle: the next queued message can be started.
        None,
        // Waiting on the `Expect` handler for the current request.
        ExpectCall { #[pin] fut: X::Future },
        // Waiting on the main service's response future.
        ServiceCall { #[pin] fut: S::Future },
        // Streaming the service's response body into the write buffer.
        SendPayload { #[pin] body: B },
        // Streaming an error response body into the write buffer.
        SendErrorPayload { #[pin] body: BoxBody },
    }
}
194
195impl<S, B, X> State<S, B, X>
196where
197 S: Service<Request>,
198 X: Service<Request, Response = Request>,
199 B: MessageBody,
200{
201 pub(super) fn is_none(&self) -> bool {
202 matches!(self, State::None)
203 }
204}
205
206impl<S, B, X> fmt::Debug for State<S, B, X>
207where
208 S: Service<Request>,
209 X: Service<Request, Response = Request>,
210 B: MessageBody,
211{
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 match self {
214 Self::None => write!(f, "State::None"),
215 Self::ExpectCall { .. } => f.debug_struct("State::ExpectCall").finish_non_exhaustive(),
216 Self::ServiceCall { .. } => {
217 f.debug_struct("State::ServiceCall").finish_non_exhaustive()
218 }
219 Self::SendPayload { .. } => {
220 f.debug_struct("State::SendPayload").finish_non_exhaustive()
221 }
222 Self::SendErrorPayload { .. } => f
223 .debug_struct("State::SendErrorPayload")
224 .finish_non_exhaustive(),
225 }
226 }
227}
228
/// Outcome of one `InnerDispatcher::poll_response` pass.
#[derive(Debug)]
enum PollResponse {
    /// The queued request asks for a protocol upgrade; switch to the upgrade handler.
    Upgrade(Request),
    /// Nothing further can be produced right now.
    DoNothing,
    /// The write buffer reached its limit and should be flushed before polling again.
    DrainWriteBuf,
}
235
236impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
237where
238 T: AsyncRead + AsyncWrite + Unpin,
239
240 S: Service<Request>,
241 S::Error: Into<Response<BoxBody>>,
242 S::Response: Into<Response<B>>,
243
244 B: MessageBody,
245
246 X: Service<Request, Response = Request>,
247 X::Error: Into<Response<BoxBody>>,
248
249 U: Service<(Request, Framed<T, Codec>), Response = ()>,
250 U::Error: fmt::Display,
251{
252 pub(crate) fn new(
254 io: T,
255 flow: Rc<HttpFlow<S, X, U>>,
256 config: ServiceConfig,
257 peer_addr: Option<net::SocketAddr>,
258 conn_data: OnConnectData,
259 ) -> Self {
260 Dispatcher {
261 inner: DispatcherState::Normal {
262 inner: InnerDispatcher {
263 flow,
264 flags: Flags::empty(),
265 peer_addr,
266 conn_data: conn_data.0.map(Rc::new),
267 config: config.clone(),
268 error: None,
269
270 state: State::None,
271 payload: None,
272 messages: VecDeque::new(),
273
274 head_timer: TimerState::new(config.client_request_deadline().is_some()),
275 ka_timer: TimerState::new(config.keep_alive().enabled()),
276 shutdown_timer: TimerState::new(config.client_disconnect_deadline().is_some()),
277
278 io: Some(io),
279 read_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
280 write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
281 codec: Codec::new(config),
282 },
283 },
284
285 #[cfg(test)]
286 poll_count: 0,
287 }
288 }
289}
290
291impl<T, S, B, X, U> InnerDispatcher<T, S, B, X, U>
292where
293 T: AsyncRead + AsyncWrite + Unpin,
294
295 S: Service<Request>,
296 S::Error: Into<Response<BoxBody>>,
297 S::Response: Into<Response<B>>,
298
299 B: MessageBody,
300
301 X: Service<Request, Response = Request>,
302 X::Error: Into<Response<BoxBody>>,
303
304 U: Service<(Request, Framed<T, Codec>), Response = ()>,
305 U::Error: fmt::Display,
306{
307 fn can_read(&self, cx: &mut Context<'_>) -> bool {
308 if self.flags.contains(Flags::READ_DISCONNECT) {
309 false
310 } else if let Some(ref info) = self.payload {
311 info.need_read(cx) == PayloadStatus::Read
312 } else {
313 true
314 }
315 }
316
317 fn client_disconnected(self: Pin<&mut Self>) {
318 let this = self.project();
319
320 this.flags
321 .insert(Flags::READ_DISCONNECT | Flags::WRITE_DISCONNECT);
322
323 if let Some(mut payload) = this.payload.take() {
324 payload.set_error(PayloadError::Incomplete(None));
325 }
326 }
327
328 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
329 let InnerDispatcherProj { io, write_buf, .. } = self.project();
330 let mut io = Pin::new(io.as_mut().unwrap());
331
332 let len = write_buf.len();
333 let mut written = 0;
334
335 while written < len {
336 match io.as_mut().poll_write(cx, &write_buf[written..])? {
337 Poll::Ready(0) => {
338 error!("write zero; closing");
339 return Poll::Ready(Err(io::Error::new(io::ErrorKind::WriteZero, "")));
340 }
341
342 Poll::Ready(n) => written += n,
343
344 Poll::Pending => {
345 write_buf.advance(written);
346 return Poll::Pending;
347 }
348 }
349 }
350
351 write_buf.clear();
353
354 io.poll_flush(cx)
356 }
357
358 fn send_response_inner(
359 self: Pin<&mut Self>,
360 res: Response<()>,
361 body: &impl MessageBody,
362 ) -> Result<BodySize, DispatchError> {
363 let this = self.project();
364
365 let size = body.size();
366
367 this.codec
368 .encode(Message::Item((res, size)), this.write_buf)
369 .map_err(|err| {
370 if let Some(mut payload) = this.payload.take() {
371 payload.set_error(PayloadError::Incomplete(None));
372 }
373
374 DispatchError::Io(err)
375 })?;
376
377 Ok(size)
378 }
379
380 fn send_response(
381 mut self: Pin<&mut Self>,
382 res: Response<()>,
383 body: B,
384 ) -> Result<(), DispatchError> {
385 let size = self.as_mut().send_response_inner(res, &body)?;
386 let mut this = self.project();
387 this.state.set(match size {
388 BodySize::None | BodySize::Sized(0) => {
389 this.flags.insert(Flags::FINISHED);
390 State::None
391 }
392 _ => State::SendPayload { body },
393 });
394
395 Ok(())
396 }
397
398 fn send_error_response(
399 mut self: Pin<&mut Self>,
400 res: Response<()>,
401 body: BoxBody,
402 ) -> Result<(), DispatchError> {
403 let size = self.as_mut().send_response_inner(res, &body)?;
404 let mut this = self.project();
405 this.state.set(match size {
406 BodySize::None | BodySize::Sized(0) => {
407 this.flags.insert(Flags::FINISHED);
408 State::None
409 }
410 _ => State::SendErrorPayload { body },
411 });
412
413 Ok(())
414 }
415
416 fn send_continue(self: Pin<&mut Self>) {
417 self.project()
418 .write_buf
419 .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
420 }
421
422 fn poll_response(
423 mut self: Pin<&mut Self>,
424 cx: &mut Context<'_>,
425 ) -> Result<PollResponse, DispatchError> {
426 'res: loop {
427 let mut this = self.as_mut().project();
428 match this.state.as_mut().project() {
429 StateProj::None => match this.messages.pop_front() {
431 Some(DispatcherMessage::Item(req)) => {
433 if req.head().expect() {
435 let fut = this.flow.expect.call(req);
437 this.state.set(State::ExpectCall { fut });
438 } else {
439 let fut = this.flow.service.call(req);
441 this.state.set(State::ServiceCall { fut });
442 };
443 }
444
445 Some(DispatcherMessage::Error(res)) => {
447 self.as_mut().send_error_response(res, BoxBody::new(()))?;
451 }
452
453 Some(DispatcherMessage::Upgrade(req)) => return Ok(PollResponse::Upgrade(req)),
455
456 None => {
458 this.flags.set(Flags::KEEP_ALIVE, this.codec.keep_alive());
460
461 return Ok(PollResponse::DoNothing);
462 }
463 },
464
465 StateProj::ServiceCall { fut } => {
466 match fut.poll(cx) {
467 Poll::Ready(Ok(res)) => {
469 let (res, body) = res.into().replace_body(());
470 self.as_mut().send_response(res, body)?;
471 }
472
473 Poll::Ready(Err(err)) => {
475 let res: Response<BoxBody> = err.into();
476 let (res, body) = res.replace_body(());
477 self.as_mut().send_error_response(res, body)?;
478 }
479
480 Poll::Pending => {
483 if !self.as_mut().poll_request(cx)? {
486 return Ok(PollResponse::DoNothing);
487 }
488 }
490 }
491 }
492
493 StateProj::SendPayload { mut body } => {
494 while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
497 match body.as_mut().poll_next(cx) {
498 Poll::Ready(Some(Ok(item))) => {
499 this.codec
500 .encode(Message::Chunk(Some(item)), this.write_buf)?;
501 }
502
503 Poll::Ready(None) => {
504 this.codec.encode(Message::Chunk(None), this.write_buf)?;
505
506 this.state.set(State::None);
509 this.flags.insert(Flags::FINISHED);
510
511 continue 'res;
512 }
513
514 Poll::Ready(Some(Err(err))) => {
515 let err = err.into();
516 tracing::error!("Response payload stream error: {err:?}");
517 this.flags.insert(Flags::FINISHED);
518 return Err(DispatchError::Body(err));
519 }
520
521 Poll::Pending => return Ok(PollResponse::DoNothing),
522 }
523 }
524
525 return Ok(PollResponse::DrainWriteBuf);
528 }
529
530 StateProj::SendErrorPayload { mut body } => {
531 while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
536 match body.as_mut().poll_next(cx) {
537 Poll::Ready(Some(Ok(item))) => {
538 this.codec
539 .encode(Message::Chunk(Some(item)), this.write_buf)?;
540 }
541
542 Poll::Ready(None) => {
543 this.codec.encode(Message::Chunk(None), this.write_buf)?;
544
545 this.state.set(State::None);
548 this.flags.insert(Flags::FINISHED);
549
550 continue 'res;
551 }
552
553 Poll::Ready(Some(Err(err))) => {
554 tracing::error!("Response payload stream error: {err:?}");
555 this.flags.insert(Flags::FINISHED);
556 return Err(DispatchError::Body(
557 Error::new_body().with_cause(err).into(),
558 ));
559 }
560
561 Poll::Pending => return Ok(PollResponse::DoNothing),
562 }
563 }
564
565 return Ok(PollResponse::DrainWriteBuf);
568 }
569
570 StateProj::ExpectCall { fut } => {
571 trace!(" calling expect service");
572
573 match fut.poll(cx) {
574 Poll::Ready(Ok(req)) => {
577 this.write_buf
578 .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
579 let fut = this.flow.service.call(req);
580 this.state.set(State::ServiceCall { fut });
581 }
582
583 Poll::Ready(Err(err)) => {
585 let res: Response<BoxBody> = err.into();
586 let (res, body) = res.replace_body(());
587 self.as_mut().send_error_response(res, body)?;
588 }
589
590 Poll::Pending => return Ok(PollResponse::DoNothing),
592 }
593 }
594 }
595 }
596 }
597
598 fn handle_request(
599 mut self: Pin<&mut Self>,
600 req: Request,
601 cx: &mut Context<'_>,
602 ) -> Result<(), DispatchError> {
603 {
605 let mut this = self.as_mut().project();
606
607 if req.head().expect() {
609 let fut = this.flow.expect.call(req);
611 this.state.set(State::ExpectCall { fut });
612 } else {
613 let fut = this.flow.service.call(req);
615 this.state.set(State::ServiceCall { fut });
616 };
617 };
618
619 loop {
621 match self.as_mut().project().state.project() {
622 StateProj::ExpectCall { fut } => {
623 match fut.poll(cx) {
624 Poll::Ready(Ok(req)) => {
626 self.as_mut().send_continue();
627
628 let mut this = self.as_mut().project();
629 let fut = this.flow.service.call(req);
630 this.state.set(State::ServiceCall { fut });
631
632 continue;
633 }
634
635 Poll::Ready(Err(err)) => {
639 let res: Response<BoxBody> = err.into();
640 let (res, body) = res.replace_body(());
641 return self.send_error_response(res, body);
642 }
643
644 Poll::Pending => return Ok(()),
647 }
648 }
649
650 StateProj::ServiceCall { fut } => {
651 return match fut.poll(cx) {
653 Poll::Ready(Ok(res)) => {
657 let (res, body) = res.into().replace_body(());
658 self.as_mut().send_response(res, body)
659 }
660
661 Poll::Pending => Ok(()),
663
664 Poll::Ready(Err(err)) => {
666 let res: Response<BoxBody> = err.into();
667 let (res, body) = res.replace_body(());
668 self.as_mut().send_error_response(res, body)
669 }
670 };
671 }
672
673 _ => {
674 unreachable!("State must be set to ServiceCall or ExceptCall in handle_request")
675 }
676 }
677 }
678 }
679
680 fn poll_request(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<bool, DispatchError> {
684 let pipeline_queue_full = self.messages.len() >= MAX_PIPELINED_MESSAGES;
685 let can_not_read = !self.can_read(cx);
686
687 if pipeline_queue_full || can_not_read {
689 return Ok(false);
690 }
691
692 let mut this = self.as_mut().project();
693
694 let mut updated = false;
695
696 loop {
698 match this.codec.decode(this.read_buf) {
699 Ok(Some(msg)) => {
700 updated = true;
701
702 match msg {
703 Message::Item(mut req) => {
704 this.head_timer.clear(line!());
706
707 req.head_mut().peer_addr = *this.peer_addr;
708
709 req.conn_data.clone_from(this.conn_data);
710
711 match this.codec.message_type() {
712 MessageType::None => {}
714
715 MessageType::Stream if this.flow.upgrade.is_some() => {
719 this.messages.push_back(DispatcherMessage::Upgrade(req));
720 break;
721 }
722
723 MessageType::Payload | MessageType::Stream => {
725 let (sender, payload) = Payload::create(false);
731 *req.payload() = crate::Payload::H1 { payload };
732 *this.payload = Some(sender);
733 }
734 }
735
736 if this.state.is_none() {
738 self.as_mut().handle_request(req, cx)?;
739 this = self.as_mut().project();
740 } else {
741 this.messages.push_back(DispatcherMessage::Item(req));
742 }
743 }
744
745 Message::Chunk(Some(chunk)) => {
746 if let Some(ref mut payload) = this.payload {
747 payload.feed_data(chunk);
748 } else {
749 error!("Internal server error: unexpected payload chunk");
750 this.flags.insert(Flags::READ_DISCONNECT);
751 this.messages.push_back(DispatcherMessage::Error(
752 Response::internal_server_error().drop_body(),
753 ));
754 *this.error = Some(DispatchError::InternalError);
755 break;
756 }
757 }
758
759 Message::Chunk(None) => {
760 if let Some(mut payload) = this.payload.take() {
761 payload.feed_eof();
762 } else {
763 error!("Internal server error: unexpected eof");
764 this.flags.insert(Flags::READ_DISCONNECT);
765 this.messages.push_back(DispatcherMessage::Error(
766 Response::internal_server_error().drop_body(),
767 ));
768 *this.error = Some(DispatchError::InternalError);
769 break;
770 }
771 }
772 }
773 }
774
775 Ok(None) => break,
778
779 Err(ParseError::Io(err)) => {
780 trace!("I/O error: {}", &err);
781 self.as_mut().client_disconnected();
782 this = self.as_mut().project();
783 *this.error = Some(DispatchError::Io(err));
784 break;
785 }
786
787 Err(ParseError::TooLarge) => {
788 trace!("request head was too big; returning 431 response");
789
790 if let Some(mut payload) = this.payload.take() {
791 payload.set_error(PayloadError::Overflow);
792 }
793
794 this.messages
796 .push_back(DispatcherMessage::Error(Response::with_body(
797 StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
798 (),
799 )));
800
801 this.flags.insert(Flags::READ_DISCONNECT);
802 *this.error = Some(ParseError::TooLarge.into());
803
804 break;
805 }
806
807 Err(err) => {
808 trace!("parse error {}", &err);
809
810 if let Some(mut payload) = this.payload.take() {
811 payload.set_error(PayloadError::EncodingCorrupted);
812 }
813
814 this.messages.push_back(DispatcherMessage::Error(
816 Response::bad_request().drop_body(),
817 ));
818
819 this.flags.insert(Flags::READ_DISCONNECT);
820 *this.error = Some(err.into());
821 break;
822 }
823 }
824 }
825
826 Ok(updated)
827 }
828
829 fn poll_head_timer(
830 mut self: Pin<&mut Self>,
831 cx: &mut Context<'_>,
832 ) -> Result<(), DispatchError> {
833 let this = self.as_mut().project();
834
835 if let TimerState::Active { timer } = this.head_timer {
836 if timer.as_mut().poll(cx).is_ready() {
837 trace!("timed out on slow request; replying with 408 and closing connection");
840
841 let _ = self.as_mut().send_error_response(
842 Response::with_body(StatusCode::REQUEST_TIMEOUT, ()),
843 BoxBody::new(()),
844 );
845
846 self.project().flags.insert(Flags::SHUTDOWN);
847 }
848 };
849
850 Ok(())
851 }
852
853 fn poll_ka_timer(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), DispatchError> {
854 let this = self.as_mut().project();
855 if let TimerState::Active { timer } = this.ka_timer {
856 debug_assert!(
857 this.flags.contains(Flags::KEEP_ALIVE),
858 "keep-alive flag should be set when timer is active",
859 );
860 debug_assert!(
861 this.state.is_none(),
862 "dispatcher should not be in keep-alive phase if state is not none: {:?}",
863 this.state,
864 );
865
866 if timer.as_mut().poll(cx).is_ready() {
878 trace!("timer timed out; closing connection");
880 this.flags.insert(Flags::SHUTDOWN);
881
882 if let Some(deadline) = this.config.client_disconnect_deadline() {
883 this.shutdown_timer
885 .set_and_init(cx, sleep_until(deadline.into()), line!());
886 } else {
887 this.flags.insert(Flags::WRITE_DISCONNECT);
889 }
890 }
891 }
892
893 Ok(())
894 }
895
896 fn poll_shutdown_timer(
897 mut self: Pin<&mut Self>,
898 cx: &mut Context<'_>,
899 ) -> Result<(), DispatchError> {
900 let this = self.as_mut().project();
901 if let TimerState::Active { timer } = this.shutdown_timer {
902 debug_assert!(
903 this.flags.contains(Flags::SHUTDOWN),
904 "shutdown flag should be set when timer is active",
905 );
906
907 if timer.as_mut().poll(cx).is_ready() {
909 trace!("timed-out during shutdown");
910 return Err(DispatchError::DisconnectTimeout);
911 }
912 }
913
914 Ok(())
915 }
916
917 fn poll_timers(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), DispatchError> {
919 self.as_mut().poll_head_timer(cx)?;
920 self.as_mut().poll_ka_timer(cx)?;
921 self.as_mut().poll_shutdown_timer(cx)?;
922
923 Ok(())
924 }
925
926 #[inline(always)] fn read_available(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<bool, DispatchError> {
933 let this = self.project();
934
935 if this.flags.contains(Flags::READ_DISCONNECT) {
936 return Ok(false);
937 };
938
939 let mut io = Pin::new(this.io.as_mut().unwrap());
940
941 let mut read_some = false;
942
943 loop {
944 if this.read_buf.len() >= MAX_BUFFER_SIZE {
946 match this.payload {
965 Some(ref p) if p.need_read(cx) != PayloadStatus::Read => {}
980 _ => cx.waker().wake_by_ref(),
981 }
982
983 return Ok(false);
984 }
985
986 let remaining = this.read_buf.capacity() - this.read_buf.len();
988 if remaining < LW_BUFFER_SIZE {
989 this.read_buf.reserve(HW_BUFFER_SIZE - remaining);
990 }
991
992 match tokio_util::io::poll_read_buf(io.as_mut(), cx, this.read_buf) {
993 Poll::Ready(Ok(n)) => {
994 this.flags.remove(Flags::FINISHED);
995
996 if n == 0 {
997 return Ok(true);
998 }
999
1000 read_some = true;
1001 }
1002
1003 Poll::Pending => {
1004 return Ok(false);
1005 }
1006
1007 Poll::Ready(Err(err)) => {
1008 return match err.kind() {
1009 io::ErrorKind::WouldBlock => Ok(false),
1011
1012 io::ErrorKind::ConnectionReset if read_some => Ok(true),
1014
1015 _ => Err(DispatchError::Io(err)),
1016 };
1017 }
1018 }
1019 }
1020 }
1021
1022 fn upgrade(self: Pin<&mut Self>, req: Request) -> U::Future {
1024 let this = self.project();
1025 let mut parts = FramedParts::with_read_buf(
1026 this.io.take().unwrap(),
1027 mem::take(this.codec),
1028 mem::take(this.read_buf),
1029 );
1030 parts.write_buf = mem::take(this.write_buf);
1031 let framed = Framed::from_parts(parts);
1032 this.flow.upgrade.as_ref().unwrap().call((req, framed))
1033 }
1034}
1035
impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
where
    T: AsyncRead + AsyncWrite + Unpin,

    S: Service<Request>,
    S::Error: Into<Response<BoxBody>>,
    S::Response: Into<Response<B>>,

    B: MessageBody,

    X: Service<Request, Response = Request>,
    X::Error: Into<Response<BoxBody>>,

    U: Service<(Request, Framed<T, Codec>), Response = ()>,
    U::Error: fmt::Display,
{
    /// `Ok(())` when the connection closes cleanly, otherwise the dispatch error that ended it.
    type Output = Result<(), DispatchError>;

    /// Drives the connection one step.
    ///
    /// In `Upgrade` mode this only polls the upgrade handler's future. In `Normal` mode it polls
    /// the timers and then takes one of two paths:
    /// - if shutting down: flush the write buffer and shut down the I/O;
    /// - otherwise: read, decode requests, produce and flush responses, and re-poll itself
    ///   after switching to upgrade or entering shutdown.
    #[inline]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.as_mut().project();

        #[cfg(test)]
        {
            *this.poll_count += 1;
        }

        match this.inner.project() {
            DispatcherStateProj::Upgrade { fut: upgrade } => upgrade.poll(cx).map_err(|err| {
                error!("Upgrade handler error: {}", err);
                DispatchError::Upgrade
            }),

            DispatcherStateProj::Normal { mut inner } => {
                trace!("start flags: {:?}", &inner.flags);

                trace_timer_states(
                    "start",
                    &inner.head_timer,
                    &inner.ka_timer,
                    &inner.shutdown_timer,
                );

                inner.as_mut().poll_timers(cx)?;

                let poll = if inner.flags.contains(Flags::SHUTDOWN) {
                    if inner.flags.contains(Flags::WRITE_DISCONNECT) {
                        // write half is gone; nothing left to flush
                        Poll::Ready(Ok(()))
                    } else {
                        // flush pending output first, then shut down the I/O object
                        ready!(inner.as_mut().poll_flush(cx))?;
                        Pin::new(inner.as_mut().project().io.as_mut().unwrap())
                            .poll_shutdown(cx)
                            .map_err(DispatchError::from)
                    }
                } else {
                    // fill the read buffer; `true` means the peer has disconnected
                    let should_disconnect = inner.as_mut().read_available(cx)?;

                    // new bytes arrived while idle: leave keep-alive and stop its timer
                    if !inner.read_buf.is_empty() && inner.flags.contains(Flags::KEEP_ALIVE) {
                        let inner = inner.as_mut().project();
                        inner.flags.remove(Flags::KEEP_ALIVE);
                        inner.ka_timer.clear(line!());
                    }

                    // first run: arm the request-head deadline if one is configured
                    if !inner.flags.contains(Flags::STARTED) {
                        inner.as_mut().project().flags.insert(Flags::STARTED);

                        if let Some(deadline) = inner.config.client_request_deadline() {
                            inner.as_mut().project().head_timer.set_and_init(
                                cx,
                                sleep_until(deadline.into()),
                                line!(),
                            );
                        }
                    }

                    inner.as_mut().poll_request(cx)?;

                    // peer closed: stop reading and end any request payload being streamed
                    if should_disconnect {
                        let inner = inner.as_mut().project();
                        inner.flags.insert(Flags::READ_DISCONNECT);
                        if let Some(mut payload) = inner.payload.take() {
                            payload.feed_eof();
                        }
                    };

                    loop {
                        // `drain` is true when the write buffer filled up and must be flushed
                        // before more response data is produced
                        let drain = match inner.as_mut().poll_response(cx)? {
                            PollResponse::DrainWriteBuf => true,

                            PollResponse::DoNothing => {
                                // response done and client allows keep-alive: arm idle timer
                                if inner.flags.contains(Flags::KEEP_ALIVE | Flags::FINISHED) {
                                    if let Some(timer) = inner.config.keep_alive_deadline() {
                                        inner.as_mut().project().ka_timer.set_and_init(
                                            cx,
                                            sleep_until(timer.into()),
                                            line!(),
                                        );
                                    }
                                }

                                false
                            }

                            // switch to upgrade mode and poll the upgrade future right away
                            PollResponse::Upgrade(req) => {
                                let upgrade = inner.upgrade(req);
                                self.as_mut()
                                    .project()
                                    .inner
                                    .set(DispatcherState::Upgrade { fut: upgrade });
                                return self.poll(cx);
                            }
                        };

                        let flush_was_ready = inner.as_mut().poll_flush(cx)?.is_ready();

                        // keep producing output only while flushes complete and a drain was asked
                        if !flush_was_ready || !drain {
                            break;
                        }
                    }

                    if inner.flags.contains(Flags::WRITE_DISCONNECT) {
                        trace!("client is gone; disconnecting");
                        return Poll::Ready(Ok(()));
                    }

                    let inner_p = inner.as_mut().project();
                    let state_is_none = inner_p.state.is_none();

                    // no more input and nothing in flight: begin shutdown
                    if inner_p.flags.contains(Flags::READ_DISCONNECT) && state_is_none {
                        trace!("read half closed; start shutdown");
                        inner_p.flags.insert(Flags::SHUTDOWN);
                    }

                    // idle with all output written: report deferred errors or close
                    if state_is_none && inner_p.write_buf.is_empty() {
                        if let Some(err) = inner_p.error.take() {
                            error!("stream error: {}", &err);
                            return Poll::Ready(Err(err));
                        }

                        // response finished and keep-alive not allowed: shut down
                        if inner_p.flags.contains(Flags::FINISHED)
                            && !inner_p.flags.contains(Flags::KEEP_ALIVE)
                        {
                            inner_p.flags.remove(Flags::FINISHED);
                            inner_p.flags.insert(Flags::SHUTDOWN);
                            return self.poll(cx);
                        }

                        // shutdown was requested: re-poll to take the shutdown path
                        if inner_p.flags.contains(Flags::SHUTDOWN) {
                            return self.poll(cx);
                        }
                    }

                    trace_timer_states(
                        "end",
                        inner_p.head_timer,
                        inner_p.ka_timer,
                        inner_p.shutdown_timer,
                    );

                    Poll::Pending
                };

                trace!("end flags: {:?}", &inner.flags);

                poll
            }
        }
    }
}
1229
1230#[allow(dead_code)]
1231fn trace_timer_states(
1232 label: &str,
1233 head_timer: &TimerState,
1234 ka_timer: &TimerState,
1235 shutdown_timer: &TimerState,
1236) {
1237 trace!("{} timers:", label);
1238
1239 if head_timer.is_enabled() {
1240 trace!(" head {}", &head_timer);
1241 }
1242
1243 if ka_timer.is_enabled() {
1244 trace!(" keep-alive {}", &ka_timer);
1245 }
1246
1247 if shutdown_timer.is_enabled() {
1248 trace!(" shutdown {}", &shutdown_timer);
1249 }
1250}