// actix_http/h1/client.rs

1use std::{fmt, io};
2
3use bitflags::bitflags;
4use bytes::{Bytes, BytesMut};
5use http::{Method, Version};
6use tokio_util::codec::{Decoder, Encoder};
7
8use super::{
9    decoder::{self, PayloadDecoder, PayloadItem, PayloadType},
10    encoder, reserve_readbuf, Message, MessageType,
11};
12use crate::{
13    body::BodySize,
14    error::{ParseError, PayloadError},
15    ConnectionType, RequestHeadType, ResponseHead, ServiceConfig,
16};
17
18bitflags! {
19    #[derive(Debug, Clone, Copy)]
20    struct Flags: u8 {
21        const HEAD               = 0b0000_0001;
22        const KEEP_ALIVE_ENABLED = 0b0000_1000;
23        const STREAM             = 0b0001_0000;
24    }
25}
26
27/// HTTP/1 Codec
28pub struct ClientCodec {
29    inner: ClientCodecInner,
30}
31
32/// HTTP/1 Payload Codec
33pub struct ClientPayloadCodec {
34    inner: ClientCodecInner,
35}
36
37struct ClientCodecInner {
38    config: ServiceConfig,
39    decoder: decoder::MessageDecoder<ResponseHead>,
40    payload: Option<PayloadDecoder>,
41    version: Version,
42    conn_type: ConnectionType,
43
44    // encoder part
45    flags: Flags,
46    encoder: encoder::MessageEncoder<RequestHeadType>,
47}
48
49impl Default for ClientCodec {
50    fn default() -> Self {
51        ClientCodec::new(ServiceConfig::default())
52    }
53}
54
55impl fmt::Debug for ClientCodec {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        f.debug_struct("h1::ClientCodec")
58            .field("flags", &self.inner.flags)
59            .finish_non_exhaustive()
60    }
61}
62
63impl ClientCodec {
64    /// Create HTTP/1 codec.
65    ///
66    /// `keepalive_enabled` how response `connection` header get generated.
67    pub fn new(config: ServiceConfig) -> Self {
68        let flags = if config.keep_alive().enabled() {
69            Flags::KEEP_ALIVE_ENABLED
70        } else {
71            Flags::empty()
72        };
73
74        ClientCodec {
75            inner: ClientCodecInner {
76                config,
77                decoder: decoder::MessageDecoder::default(),
78                payload: None,
79                version: Version::HTTP_11,
80                conn_type: ConnectionType::Close,
81
82                flags,
83                encoder: encoder::MessageEncoder::default(),
84            },
85        }
86    }
87
88    /// Check if request is upgrade
89    pub fn upgrade(&self) -> bool {
90        self.inner.conn_type == ConnectionType::Upgrade
91    }
92
93    /// Check if last response is keep-alive
94    pub fn keep_alive(&self) -> bool {
95        self.inner.conn_type == ConnectionType::KeepAlive
96    }
97
98    /// Check last request's message type
99    pub fn message_type(&self) -> MessageType {
100        if self.inner.flags.contains(Flags::STREAM) {
101            MessageType::Stream
102        } else if self.inner.payload.is_none() {
103            MessageType::None
104        } else {
105            MessageType::Payload
106        }
107    }
108
109    /// Convert message codec to a payload codec
110    pub fn into_payload_codec(self) -> ClientPayloadCodec {
111        ClientPayloadCodec { inner: self.inner }
112    }
113}
114
115impl ClientPayloadCodec {
116    /// Check if last response is keep-alive
117    pub fn keep_alive(&self) -> bool {
118        self.inner.conn_type == ConnectionType::KeepAlive
119    }
120
121    /// Transform payload codec to a message codec
122    pub fn into_message_codec(self) -> ClientCodec {
123        ClientCodec { inner: self.inner }
124    }
125}
126
127impl Decoder for ClientCodec {
128    type Item = ResponseHead;
129    type Error = ParseError;
130
131    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
132        debug_assert!(
133            self.inner.payload.is_none(),
134            "Payload decoder should not be set"
135        );
136
137        if let Some((req, payload)) = self.inner.decoder.decode(src)? {
138            if let Some(conn_type) = req.conn_type() {
139                // do not use peer's keep-alive
140                self.inner.conn_type = if conn_type == ConnectionType::KeepAlive {
141                    self.inner.conn_type
142                } else {
143                    conn_type
144                };
145            }
146
147            if !self.inner.flags.contains(Flags::HEAD) {
148                match payload {
149                    PayloadType::None => self.inner.payload = None,
150                    PayloadType::Payload(pl) => self.inner.payload = Some(pl),
151                    PayloadType::Stream(pl) => {
152                        self.inner.payload = Some(pl);
153                        self.inner.flags.insert(Flags::STREAM);
154                    }
155                }
156            } else {
157                self.inner.payload = None;
158            }
159            reserve_readbuf(src);
160            Ok(Some(req))
161        } else {
162            Ok(None)
163        }
164    }
165}
166
167impl Decoder for ClientPayloadCodec {
168    type Item = Option<Bytes>;
169    type Error = PayloadError;
170
171    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
172        debug_assert!(
173            self.inner.payload.is_some(),
174            "Payload decoder is not specified"
175        );
176
177        Ok(match self.inner.payload.as_mut().unwrap().decode(src)? {
178            Some(PayloadItem::Chunk(chunk)) => {
179                reserve_readbuf(src);
180                Some(Some(chunk))
181            }
182            Some(PayloadItem::Eof) => {
183                self.inner.payload.take();
184                Some(None)
185            }
186            None => None,
187        })
188    }
189}
190
191impl Encoder<Message<(RequestHeadType, BodySize)>> for ClientCodec {
192    type Error = io::Error;
193
194    fn encode(
195        &mut self,
196        item: Message<(RequestHeadType, BodySize)>,
197        dst: &mut BytesMut,
198    ) -> Result<(), Self::Error> {
199        match item {
200            Message::Item((mut head, length)) => {
201                let inner = &mut self.inner;
202                inner.version = head.as_ref().version;
203                inner
204                    .flags
205                    .set(Flags::HEAD, head.as_ref().method == Method::HEAD);
206
207                // connection status
208                inner.conn_type = match head.as_ref().connection_type() {
209                    ConnectionType::KeepAlive => {
210                        if inner.flags.contains(Flags::KEEP_ALIVE_ENABLED) {
211                            ConnectionType::KeepAlive
212                        } else {
213                            ConnectionType::Close
214                        }
215                    }
216                    ConnectionType::Upgrade => ConnectionType::Upgrade,
217                    ConnectionType::Close => ConnectionType::Close,
218                };
219
220                inner.encoder.encode(
221                    dst,
222                    &mut head,
223                    false,
224                    false,
225                    inner.version,
226                    length,
227                    inner.conn_type,
228                    &inner.config,
229                )?;
230            }
231            Message::Chunk(Some(bytes)) => {
232                self.inner.encoder.encode_chunk(bytes.as_ref(), dst)?;
233            }
234            Message::Chunk(None) => {
235                self.inner.encoder.encode_eof(dst)?;
236            }
237        }
238        Ok(())
239    }
240}