1use std::{fmt, io};
2
3use bitflags::bitflags;
4use bytes::{Bytes, BytesMut};
5use http::{Method, Version};
6use tokio_util::codec::{Decoder, Encoder};
7
8use super::{
9 decoder::{self, PayloadDecoder, PayloadItem, PayloadType},
10 encoder, reserve_readbuf, Message, MessageType,
11};
12use crate::{
13 body::BodySize,
14 error::{ParseError, PayloadError},
15 ConnectionType, RequestHeadType, ResponseHead, ServiceConfig,
16};
17
18bitflags! {
19 #[derive(Debug, Clone, Copy)]
20 struct Flags: u8 {
21 const HEAD = 0b0000_0001;
22 const KEEP_ALIVE_ENABLED = 0b0000_1000;
23 const STREAM = 0b0001_0000;
24 }
25}
26
/// HTTP/1 client codec.
///
/// Encodes outgoing request messages and decodes incoming response heads.
/// Payload bytes are decoded by [`ClientPayloadCodec`], which is obtained via
/// `into_payload_codec`.
pub struct ClientCodec {
    // State shared with `ClientPayloadCodec`; moved as-is when converting
    // between the two codecs.
    inner: ClientCodecInner,
}
31
/// HTTP/1 client payload codec.
///
/// Decodes the body of a response whose head was already decoded by
/// [`ClientCodec`]. Convert back with `into_message_codec`.
pub struct ClientPayloadCodec {
    // State shared with `ClientCodec`; moved as-is when converting between
    // the two codecs.
    inner: ClientCodecInner,
}
36
/// State shared by [`ClientCodec`] and [`ClientPayloadCodec`].
struct ClientCodecInner {
    // Service configuration; passed to the message encoder on every request.
    config: ServiceConfig,
    // Decoder for response heads.
    decoder: decoder::MessageDecoder<ResponseHead>,
    // Payload decoder for the response currently being read, if it has a body.
    payload: Option<PayloadDecoder>,
    // HTTP version of the most recently encoded request.
    version: Version,
    // Current connection type; starts as `Close` and is updated while
    // encoding requests and decoding responses.
    conn_type: ConnectionType,

    // Codec state bits, see `Flags`.
    flags: Flags,
    // Encoder for request heads and body chunks.
    encoder: encoder::MessageEncoder<RequestHeadType>,
}
48
49impl Default for ClientCodec {
50 fn default() -> Self {
51 ClientCodec::new(ServiceConfig::default())
52 }
53}
54
55impl fmt::Debug for ClientCodec {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 f.debug_struct("h1::ClientCodec")
58 .field("flags", &self.inner.flags)
59 .finish_non_exhaustive()
60 }
61}
62
63impl ClientCodec {
64 pub fn new(config: ServiceConfig) -> Self {
68 let flags = if config.keep_alive().enabled() {
69 Flags::KEEP_ALIVE_ENABLED
70 } else {
71 Flags::empty()
72 };
73
74 ClientCodec {
75 inner: ClientCodecInner {
76 config,
77 decoder: decoder::MessageDecoder::default(),
78 payload: None,
79 version: Version::HTTP_11,
80 conn_type: ConnectionType::Close,
81
82 flags,
83 encoder: encoder::MessageEncoder::default(),
84 },
85 }
86 }
87
88 pub fn upgrade(&self) -> bool {
90 self.inner.conn_type == ConnectionType::Upgrade
91 }
92
93 pub fn keep_alive(&self) -> bool {
95 self.inner.conn_type == ConnectionType::KeepAlive
96 }
97
98 pub fn message_type(&self) -> MessageType {
100 if self.inner.flags.contains(Flags::STREAM) {
101 MessageType::Stream
102 } else if self.inner.payload.is_none() {
103 MessageType::None
104 } else {
105 MessageType::Payload
106 }
107 }
108
109 pub fn into_payload_codec(self) -> ClientPayloadCodec {
111 ClientPayloadCodec { inner: self.inner }
112 }
113}
114
115impl ClientPayloadCodec {
116 pub fn keep_alive(&self) -> bool {
118 self.inner.conn_type == ConnectionType::KeepAlive
119 }
120
121 pub fn into_message_codec(self) -> ClientCodec {
123 ClientCodec { inner: self.inner }
124 }
125}
126
127impl Decoder for ClientCodec {
128 type Item = ResponseHead;
129 type Error = ParseError;
130
131 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
132 debug_assert!(
133 self.inner.payload.is_none(),
134 "Payload decoder should not be set"
135 );
136
137 if let Some((req, payload)) = self.inner.decoder.decode(src)? {
138 if let Some(conn_type) = req.conn_type() {
139 self.inner.conn_type = if conn_type == ConnectionType::KeepAlive {
141 self.inner.conn_type
142 } else {
143 conn_type
144 };
145 }
146
147 if !self.inner.flags.contains(Flags::HEAD) {
148 match payload {
149 PayloadType::None => self.inner.payload = None,
150 PayloadType::Payload(pl) => self.inner.payload = Some(pl),
151 PayloadType::Stream(pl) => {
152 self.inner.payload = Some(pl);
153 self.inner.flags.insert(Flags::STREAM);
154 }
155 }
156 } else {
157 self.inner.payload = None;
158 }
159 reserve_readbuf(src);
160 Ok(Some(req))
161 } else {
162 Ok(None)
163 }
164 }
165}
166
167impl Decoder for ClientPayloadCodec {
168 type Item = Option<Bytes>;
169 type Error = PayloadError;
170
171 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
172 debug_assert!(
173 self.inner.payload.is_some(),
174 "Payload decoder is not specified"
175 );
176
177 Ok(match self.inner.payload.as_mut().unwrap().decode(src)? {
178 Some(PayloadItem::Chunk(chunk)) => {
179 reserve_readbuf(src);
180 Some(Some(chunk))
181 }
182 Some(PayloadItem::Eof) => {
183 self.inner.payload.take();
184 Some(None)
185 }
186 None => None,
187 })
188 }
189}
190
191impl Encoder<Message<(RequestHeadType, BodySize)>> for ClientCodec {
192 type Error = io::Error;
193
194 fn encode(
195 &mut self,
196 item: Message<(RequestHeadType, BodySize)>,
197 dst: &mut BytesMut,
198 ) -> Result<(), Self::Error> {
199 match item {
200 Message::Item((mut head, length)) => {
201 let inner = &mut self.inner;
202 inner.version = head.as_ref().version;
203 inner
204 .flags
205 .set(Flags::HEAD, head.as_ref().method == Method::HEAD);
206
207 inner.conn_type = match head.as_ref().connection_type() {
209 ConnectionType::KeepAlive => {
210 if inner.flags.contains(Flags::KEEP_ALIVE_ENABLED) {
211 ConnectionType::KeepAlive
212 } else {
213 ConnectionType::Close
214 }
215 }
216 ConnectionType::Upgrade => ConnectionType::Upgrade,
217 ConnectionType::Close => ConnectionType::Close,
218 };
219
220 inner.encoder.encode(
221 dst,
222 &mut head,
223 false,
224 false,
225 inner.version,
226 length,
227 inner.conn_type,
228 &inner.config,
229 )?;
230 }
231 Message::Chunk(Some(bytes)) => {
232 self.inner.encoder.encode_chunk(bytes.as_ref(), dst)?;
233 }
234 Message::Chunk(None) => {
235 self.inner.encoder.encode_eof(dst)?;
236 }
237 }
238 Ok(())
239 }
240}