actix_http/h1/codec.rs

1use std::{fmt, io};
2
3use bitflags::bitflags;
4use bytes::BytesMut;
5use http::{Method, Version};
6use tokio_util::codec::{Decoder, Encoder};
7
8use super::{
9    decoder::{self, PayloadDecoder, PayloadItem, PayloadType},
10    encoder, Message, MessageType,
11};
12use crate::{body::BodySize, error::ParseError, ConnectionType, Request, Response, ServiceConfig};
13
14bitflags! {
15    #[derive(Debug, Clone, Copy)]
16    struct Flags: u8 {
17        const HEAD               = 0b0000_0001;
18        const KEEP_ALIVE_ENABLED = 0b0000_0010;
19        const STREAM             = 0b0000_0100;
20    }
21}
22
23/// HTTP/1 Codec
24pub struct Codec {
25    config: ServiceConfig,
26    decoder: decoder::MessageDecoder<Request>,
27    payload: Option<PayloadDecoder>,
28    version: Version,
29    conn_type: ConnectionType,
30
31    // encoder part
32    flags: Flags,
33    encoder: encoder::MessageEncoder<Response<()>>,
34}
35
36impl Default for Codec {
37    fn default() -> Self {
38        Codec::new(ServiceConfig::default())
39    }
40}
41
42impl fmt::Debug for Codec {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        f.debug_struct("h1::Codec")
45            .field("flags", &self.flags)
46            .finish_non_exhaustive()
47    }
48}
49
50impl Codec {
51    /// Create HTTP/1 codec.
52    ///
53    /// `keepalive_enabled` how response `connection` header get generated.
54    pub fn new(config: ServiceConfig) -> Self {
55        let flags = if config.keep_alive().enabled() {
56            Flags::KEEP_ALIVE_ENABLED
57        } else {
58            Flags::empty()
59        };
60
61        Codec {
62            config,
63            flags,
64            decoder: decoder::MessageDecoder::default(),
65            payload: None,
66            version: Version::HTTP_11,
67            conn_type: ConnectionType::Close,
68            encoder: encoder::MessageEncoder::default(),
69        }
70    }
71
72    /// Check if request is upgrade.
73    #[inline]
74    pub fn upgrade(&self) -> bool {
75        self.conn_type == ConnectionType::Upgrade
76    }
77
78    /// Check if last response is keep-alive.
79    #[inline]
80    pub fn keep_alive(&self) -> bool {
81        self.conn_type == ConnectionType::KeepAlive
82    }
83
84    /// Check if keep-alive enabled on server level.
85    #[inline]
86    pub fn keep_alive_enabled(&self) -> bool {
87        self.flags.contains(Flags::KEEP_ALIVE_ENABLED)
88    }
89
90    /// Check last request's message type.
91    #[inline]
92    pub fn message_type(&self) -> MessageType {
93        if self.flags.contains(Flags::STREAM) {
94            MessageType::Stream
95        } else if self.payload.is_none() {
96            MessageType::None
97        } else {
98            MessageType::Payload
99        }
100    }
101
102    #[inline]
103    pub fn config(&self) -> &ServiceConfig {
104        &self.config
105    }
106}
107
108impl Decoder for Codec {
109    type Item = Message<Request>;
110    type Error = ParseError;
111
112    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
113        if let Some(ref mut payload) = self.payload {
114            Ok(match payload.decode(src)? {
115                Some(PayloadItem::Chunk(chunk)) => Some(Message::Chunk(Some(chunk))),
116                Some(PayloadItem::Eof) => {
117                    self.payload.take();
118                    Some(Message::Chunk(None))
119                }
120                None => None,
121            })
122        } else if let Some((req, payload)) = self.decoder.decode(src)? {
123            let head = req.head();
124            self.flags.set(Flags::HEAD, head.method == Method::HEAD);
125            self.version = head.version;
126            self.conn_type = head.connection_type();
127
128            if self.conn_type == ConnectionType::KeepAlive
129                && !self.flags.contains(Flags::KEEP_ALIVE_ENABLED)
130            {
131                self.conn_type = ConnectionType::Close
132            }
133
134            match payload {
135                PayloadType::None => self.payload = None,
136                PayloadType::Payload(pl) => self.payload = Some(pl),
137                PayloadType::Stream(pl) => {
138                    self.payload = Some(pl);
139                    self.flags.insert(Flags::STREAM);
140                }
141            }
142            Ok(Some(Message::Item(req)))
143        } else {
144            Ok(None)
145        }
146    }
147}
148
149impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
150    type Error = io::Error;
151
152    fn encode(
153        &mut self,
154        item: Message<(Response<()>, BodySize)>,
155        dst: &mut BytesMut,
156    ) -> Result<(), Self::Error> {
157        match item {
158            Message::Item((mut res, length)) => {
159                // set response version
160                res.head_mut().version = self.version;
161
162                // connection status
163                self.conn_type = if let Some(ct) = res.head().conn_type() {
164                    if ct == ConnectionType::KeepAlive {
165                        self.conn_type
166                    } else {
167                        ct
168                    }
169                } else {
170                    self.conn_type
171                };
172
173                // encode message
174                self.encoder.encode(
175                    dst,
176                    &mut res,
177                    self.flags.contains(Flags::HEAD),
178                    self.flags.contains(Flags::STREAM),
179                    self.version,
180                    length,
181                    self.conn_type,
182                    &self.config,
183                )?;
184            }
185
186            Message::Chunk(Some(bytes)) => {
187                self.encoder.encode_chunk(bytes.as_ref(), dst)?;
188            }
189
190            Message::Chunk(None) => {
191                self.encoder.encode_eof(dst)?;
192            }
193        }
194
195        Ok(())
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HttpMessage as _;

    /// A chunked request body is decoded chunk by chunk. A pipelined request that follows it
    /// in the same buffer is decoded afterwards.
    #[actix_rt::test]
    async fn test_http_request_chunked_payload_and_next_message() {
        let mut codec = Codec::default();

        let mut buf = BytesMut::from(
            "GET /test HTTP/1.1\r\n\
             transfer-encoding: chunked\r\n\r\n",
        );
        let first = codec.decode(&mut buf).unwrap().unwrap();
        let first_req = first.message();
        assert_eq!(*first_req.method(), Method::GET);
        assert!(first_req.chunked().unwrap());

        // Two body chunks and the terminating zero-length chunk, immediately followed by the
        // head of the next request.
        buf.extend_from_slice(
            b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n\
              POST /test2 HTTP/1.1\r\n\
              transfer-encoding: chunked\r\n\r\n",
        );

        for expected in [&b"data"[..], &b"line"[..]] {
            let msg = codec.decode(&mut buf).unwrap().unwrap();
            assert_eq!(msg.chunk().as_ref(), expected);
        }

        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert!(msg.eof());

        // decode next message
        let second = codec.decode(&mut buf).unwrap().unwrap();
        let second_req = second.message();
        assert_eq!(*second_req.method(), Method::POST);
        assert!(second_req.chunked().unwrap());
    }
}