1use std::{fmt, io};
2
3use bitflags::bitflags;
4use bytes::BytesMut;
5use http::{Method, Version};
6use tokio_util::codec::{Decoder, Encoder};
7
8use super::{
9 decoder::{self, PayloadDecoder, PayloadItem, PayloadType},
10 encoder, Message, MessageType,
11};
12use crate::{body::BodySize, error::ParseError, ConnectionType, Request, Response, ServiceConfig};
13
14bitflags! {
15 #[derive(Debug, Clone, Copy)]
16 struct Flags: u8 {
17 const HEAD = 0b0000_0001;
18 const KEEP_ALIVE_ENABLED = 0b0000_0010;
19 const STREAM = 0b0000_0100;
20 }
21}
22
23pub struct Codec {
25 config: ServiceConfig,
26 decoder: decoder::MessageDecoder<Request>,
27 payload: Option<PayloadDecoder>,
28 version: Version,
29 conn_type: ConnectionType,
30
31 flags: Flags,
33 encoder: encoder::MessageEncoder<Response<()>>,
34}
35
36impl Default for Codec {
37 fn default() -> Self {
38 Codec::new(ServiceConfig::default())
39 }
40}
41
42impl fmt::Debug for Codec {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 f.debug_struct("h1::Codec")
45 .field("flags", &self.flags)
46 .finish_non_exhaustive()
47 }
48}
49
50impl Codec {
51 pub fn new(config: ServiceConfig) -> Self {
55 let flags = if config.keep_alive().enabled() {
56 Flags::KEEP_ALIVE_ENABLED
57 } else {
58 Flags::empty()
59 };
60
61 Codec {
62 config,
63 flags,
64 decoder: decoder::MessageDecoder::default(),
65 payload: None,
66 version: Version::HTTP_11,
67 conn_type: ConnectionType::Close,
68 encoder: encoder::MessageEncoder::default(),
69 }
70 }
71
72 #[inline]
74 pub fn upgrade(&self) -> bool {
75 self.conn_type == ConnectionType::Upgrade
76 }
77
78 #[inline]
80 pub fn keep_alive(&self) -> bool {
81 self.conn_type == ConnectionType::KeepAlive
82 }
83
84 #[inline]
86 pub fn keep_alive_enabled(&self) -> bool {
87 self.flags.contains(Flags::KEEP_ALIVE_ENABLED)
88 }
89
90 #[inline]
92 pub fn message_type(&self) -> MessageType {
93 if self.flags.contains(Flags::STREAM) {
94 MessageType::Stream
95 } else if self.payload.is_none() {
96 MessageType::None
97 } else {
98 MessageType::Payload
99 }
100 }
101
102 #[inline]
103 pub fn config(&self) -> &ServiceConfig {
104 &self.config
105 }
106}
107
108impl Decoder for Codec {
109 type Item = Message<Request>;
110 type Error = ParseError;
111
112 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
113 if let Some(ref mut payload) = self.payload {
114 Ok(match payload.decode(src)? {
115 Some(PayloadItem::Chunk(chunk)) => Some(Message::Chunk(Some(chunk))),
116 Some(PayloadItem::Eof) => {
117 self.payload.take();
118 Some(Message::Chunk(None))
119 }
120 None => None,
121 })
122 } else if let Some((req, payload)) = self.decoder.decode(src)? {
123 let head = req.head();
124 self.flags.set(Flags::HEAD, head.method == Method::HEAD);
125 self.version = head.version;
126 self.conn_type = head.connection_type();
127
128 if self.conn_type == ConnectionType::KeepAlive
129 && !self.flags.contains(Flags::KEEP_ALIVE_ENABLED)
130 {
131 self.conn_type = ConnectionType::Close
132 }
133
134 match payload {
135 PayloadType::None => self.payload = None,
136 PayloadType::Payload(pl) => self.payload = Some(pl),
137 PayloadType::Stream(pl) => {
138 self.payload = Some(pl);
139 self.flags.insert(Flags::STREAM);
140 }
141 }
142 Ok(Some(Message::Item(req)))
143 } else {
144 Ok(None)
145 }
146 }
147}
148
149impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
150 type Error = io::Error;
151
152 fn encode(
153 &mut self,
154 item: Message<(Response<()>, BodySize)>,
155 dst: &mut BytesMut,
156 ) -> Result<(), Self::Error> {
157 match item {
158 Message::Item((mut res, length)) => {
159 res.head_mut().version = self.version;
161
162 self.conn_type = if let Some(ct) = res.head().conn_type() {
164 if ct == ConnectionType::KeepAlive {
165 self.conn_type
166 } else {
167 ct
168 }
169 } else {
170 self.conn_type
171 };
172
173 self.encoder.encode(
175 dst,
176 &mut res,
177 self.flags.contains(Flags::HEAD),
178 self.flags.contains(Flags::STREAM),
179 self.version,
180 length,
181 self.conn_type,
182 &self.config,
183 )?;
184 }
185
186 Message::Chunk(Some(bytes)) => {
187 self.encoder.encode_chunk(bytes.as_ref(), dst)?;
188 }
189
190 Message::Chunk(None) => {
191 self.encoder.encode_eof(dst)?;
192 }
193 }
194
195 Ok(())
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HttpMessage as _;

    /// Decodes a chunked request, drains its body chunk by chunk, and checks
    /// that the request pipelined after it is parsed as a new message.
    #[actix_rt::test]
    async fn test_http_request_chunked_payload_and_next_message() {
        let mut codec = Codec::default();

        let mut input = BytesMut::from(
            "GET /test HTTP/1.1\r\n\
             transfer-encoding: chunked\r\n\r\n",
        );
        let first = codec.decode(&mut input).unwrap().unwrap();
        let first_req = first.message();

        assert_eq!(first_req.method(), Method::GET);
        assert!(first_req.chunked().unwrap());

        // Body of the first request followed directly by a second request head.
        input.extend_from_slice(
            b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n\
              POST /test2 HTTP/1.1\r\n\
              transfer-encoding: chunked\r\n\r\n",
        );

        for expected in [b"data", b"line"] {
            let msg = codec.decode(&mut input).unwrap().unwrap();
            assert_eq!(msg.chunk().as_ref(), expected);
        }

        let end = codec.decode(&mut input).unwrap().unwrap();
        assert!(end.eof());

        // With the body finished, the codec goes back to parsing request heads.
        let second = codec.decode(&mut input).unwrap().unwrap();
        let second_req = second.message();
        assert_eq!(*second_req.method(), Method::POST);
        assert!(second_req.chunked().unwrap());
    }
}