1use bitflags::bitflags;
2use bytes::{Bytes, BytesMut};
3use bytestring::ByteString;
4use tokio_util::codec::{Decoder, Encoder};
5use tracing::error;
6
7use super::{
8 frame::Parser,
9 proto::{CloseReason, OpCode},
10 ProtocolError,
11};
12
/// A WebSocket message that [`Codec`] can encode.
#[derive(Debug, PartialEq, Eq)]
pub enum Message {
    /// Text message.
    Text(ByteString),

    /// Binary message.
    Binary(Bytes),

    /// One fragment of a fragmented message; see [`Item`].
    Continuation(Item),

    /// Ping message.
    Ping(Bytes),

    /// Pong message.
    Pong(Bytes),

    /// Close message with an optional close reason.
    Close(Option<CloseReason>),

    /// No-op: the encoder writes nothing for this variant.
    Nop,
}
37
/// A WebSocket frame produced by [`Codec`]'s decoder.
#[derive(Debug, PartialEq, Eq)]
pub enum Frame {
    /// Text frame payload. The codec hands over the raw bytes and does not validate UTF-8.
    Text(Bytes),

    /// Binary frame payload.
    Binary(Bytes),

    /// One fragment of a fragmented message; see [`Item`].
    Continuation(Item),

    /// Ping frame payload.
    Ping(Bytes),

    /// Pong frame payload.
    Pong(Bytes),

    /// Close frame with the close reason parsed from its payload, if any.
    Close(Option<CloseReason>),
}
59
/// One fragment of a fragmented (continuation) message.
#[derive(Debug, PartialEq, Eq)]
pub enum Item {
    /// First fragment of a fragmented text message.
    FirstText(Bytes),
    /// First fragment of a fragmented binary message.
    FirstBinary(Bytes),
    /// Intermediate fragment (neither first nor last).
    Continue(Bytes),
    /// Final fragment, which completes the fragmented message.
    Last(Bytes),
}
68
/// WebSocket protocol codec: encodes [`Message`]s and decodes [`Frame`]s.
#[derive(Debug, Clone)]
pub struct Codec {
    /// Operating mode (server/client) plus fragmentation state for reading and writing.
    flags: Flags,
    /// Size limit handed to `Parser::parse` when decoding.
    max_size: usize,
}
75
bitflags! {
    /// Internal state bits of [`Codec`].
    #[derive(Debug, Clone, Copy)]
    struct Flags: u8 {
        /// Server mode. The encoder passes `!SERVER` as the final argument of the
        /// `Parser` writers (presumably the masking flag), and the decoder passes this
        /// flag to `Parser::parse`.
        const SERVER = 0b0000_0001;
        /// Decoder side: a fragmented message is currently being received.
        const CONTINUATION = 0b0000_0010;
        /// Encoder side: a fragmented message is currently being written.
        const W_CONTINUATION = 0b0000_0100;
    }
}
84
85impl Codec {
86 pub const fn new() -> Codec {
88 Codec {
89 max_size: 65_536,
90 flags: Flags::SERVER,
91 }
92 }
93
94 #[must_use = "This returns the a new Codec, without modifying the original."]
98 pub fn max_size(mut self, size: usize) -> Self {
99 self.max_size = size;
100 self
101 }
102
103 #[must_use = "This returns the a new Codec, without modifying the original."]
107 pub fn client_mode(mut self) -> Self {
108 self.flags.remove(Flags::SERVER);
109 self
110 }
111}
112
113impl Default for Codec {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119impl Encoder<Message> for Codec {
120 type Error = ProtocolError;
121
122 fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
123 match item {
124 Message::Text(txt) => Parser::write_message(
125 dst,
126 txt,
127 OpCode::Text,
128 true,
129 !self.flags.contains(Flags::SERVER),
130 ),
131 Message::Binary(bin) => Parser::write_message(
132 dst,
133 bin,
134 OpCode::Binary,
135 true,
136 !self.flags.contains(Flags::SERVER),
137 ),
138 Message::Ping(txt) => Parser::write_message(
139 dst,
140 txt,
141 OpCode::Ping,
142 true,
143 !self.flags.contains(Flags::SERVER),
144 ),
145 Message::Pong(txt) => Parser::write_message(
146 dst,
147 txt,
148 OpCode::Pong,
149 true,
150 !self.flags.contains(Flags::SERVER),
151 ),
152 Message::Close(reason) => {
153 Parser::write_close(dst, reason, !self.flags.contains(Flags::SERVER))
154 }
155 Message::Continuation(cont) => match cont {
156 Item::FirstText(data) => {
157 if self.flags.contains(Flags::W_CONTINUATION) {
158 return Err(ProtocolError::ContinuationStarted);
159 } else {
160 self.flags.insert(Flags::W_CONTINUATION);
161 Parser::write_message(
162 dst,
163 &data[..],
164 OpCode::Text,
165 false,
166 !self.flags.contains(Flags::SERVER),
167 )
168 }
169 }
170 Item::FirstBinary(data) => {
171 if self.flags.contains(Flags::W_CONTINUATION) {
172 return Err(ProtocolError::ContinuationStarted);
173 } else {
174 self.flags.insert(Flags::W_CONTINUATION);
175 Parser::write_message(
176 dst,
177 &data[..],
178 OpCode::Binary,
179 false,
180 !self.flags.contains(Flags::SERVER),
181 )
182 }
183 }
184 Item::Continue(data) => {
185 if self.flags.contains(Flags::W_CONTINUATION) {
186 Parser::write_message(
187 dst,
188 &data[..],
189 OpCode::Continue,
190 false,
191 !self.flags.contains(Flags::SERVER),
192 )
193 } else {
194 return Err(ProtocolError::ContinuationNotStarted);
195 }
196 }
197 Item::Last(data) => {
198 if self.flags.contains(Flags::W_CONTINUATION) {
199 self.flags.remove(Flags::W_CONTINUATION);
200 Parser::write_message(
201 dst,
202 &data[..],
203 OpCode::Continue,
204 true,
205 !self.flags.contains(Flags::SERVER),
206 )
207 } else {
208 return Err(ProtocolError::ContinuationNotStarted);
209 }
210 }
211 },
212 Message::Nop => {}
213 }
214 Ok(())
215 }
216}
217
218impl Decoder for Codec {
219 type Item = Frame;
220 type Error = ProtocolError;
221
222 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
223 match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size) {
224 Ok(Some((finished, opcode, payload))) => {
225 if !finished {
227 return match opcode {
228 OpCode::Continue => {
229 if self.flags.contains(Flags::CONTINUATION) {
230 Ok(Some(Frame::Continuation(Item::Continue(
231 payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
232 ))))
233 } else {
234 Err(ProtocolError::ContinuationNotStarted)
235 }
236 }
237 OpCode::Binary => {
238 if !self.flags.contains(Flags::CONTINUATION) {
239 self.flags.insert(Flags::CONTINUATION);
240 Ok(Some(Frame::Continuation(Item::FirstBinary(
241 payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
242 ))))
243 } else {
244 Err(ProtocolError::ContinuationStarted)
245 }
246 }
247 OpCode::Text => {
248 if !self.flags.contains(Flags::CONTINUATION) {
249 self.flags.insert(Flags::CONTINUATION);
250 Ok(Some(Frame::Continuation(Item::FirstText(
251 payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
252 ))))
253 } else {
254 Err(ProtocolError::ContinuationStarted)
255 }
256 }
257 _ => {
258 error!("Unfinished fragment {:?}", opcode);
259 Err(ProtocolError::ContinuationFragment(opcode))
260 }
261 };
262 }
263
264 match opcode {
265 OpCode::Continue => {
266 if self.flags.contains(Flags::CONTINUATION) {
267 self.flags.remove(Flags::CONTINUATION);
268 Ok(Some(Frame::Continuation(Item::Last(
269 payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
270 ))))
271 } else {
272 Err(ProtocolError::ContinuationNotStarted)
273 }
274 }
275 OpCode::Bad => Err(ProtocolError::BadOpCode),
276 OpCode::Close => {
277 if let Some(ref pl) = payload {
278 let close_reason = Parser::parse_close_payload(pl);
279 Ok(Some(Frame::Close(close_reason)))
280 } else {
281 Ok(Some(Frame::Close(None)))
282 }
283 }
284 OpCode::Ping => Ok(Some(Frame::Ping(
285 payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
286 ))),
287 OpCode::Pong => Ok(Some(Frame::Pong(
288 payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
289 ))),
290 OpCode::Binary => Ok(Some(Frame::Binary(
291 payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
292 ))),
293 OpCode::Text => Ok(Some(Frame::Text(
294 payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
295 ))),
296 }
297 }
298 Ok(None) => Ok(None),
299 Err(err) => Err(err),
300 }
301 }
302}