actix_http/ws/codec.rs

1use bitflags::bitflags;
2use bytes::{Bytes, BytesMut};
3use bytestring::ByteString;
4use tokio_util::codec::{Decoder, Encoder};
5use tracing::error;
6
7use super::{
8    frame::Parser,
9    proto::{CloseReason, OpCode},
10    ProtocolError,
11};
12
/// A WebSocket message.
///
/// This is the item type accepted by the codec's `Encoder` implementation.
#[derive(Debug, PartialEq, Eq)]
pub enum Message {
    /// Text message.
    Text(ByteString),

    /// Binary message.
    Binary(Bytes),

    /// One piece of a message that is sent fragmented across several frames; see [`Item`].
    Continuation(Item),

    /// Ping message.
    Ping(Bytes),

    /// Pong message.
    Pong(Bytes),

    /// Close message with optional reason.
    Close(Option<CloseReason>),

    /// No-op. Useful for low-level services. The codec writes nothing when encoding it.
    Nop,
}
37
/// A WebSocket frame.
///
/// This is the item type produced by the codec's `Decoder` implementation.
#[derive(Debug, PartialEq, Eq)]
pub enum Frame {
    /// Text frame. Note that the codec does not validate UTF-8 encoding.
    Text(Bytes),

    /// Binary frame.
    Binary(Bytes),

    /// One piece of a fragmented message; see [`Item`].
    Continuation(Item),

    /// Ping frame.
    Ping(Bytes),

    /// Pong frame.
    Pong(Bytes),

    /// Close frame, with the reason parsed from its payload when one could be parsed.
    Close(Option<CloseReason>),
}
59
/// A WebSocket continuation item.
///
/// One piece of a fragmented message. A fragmented message is a `FirstText` or `FirstBinary`
/// item, followed by any number of `Continue` items, and terminated by a `Last` item. The codec
/// rejects items that arrive out of this order.
#[derive(Debug, PartialEq, Eq)]
pub enum Item {
    /// First fragment of a fragmented text message. As with [`Frame::Text`], the codec does not
    /// validate UTF-8 encoding.
    FirstText(Bytes),
    /// First fragment of a fragmented binary message.
    FirstBinary(Bytes),
    /// An intermediate (non-final) fragment.
    Continue(Bytes),
    /// The final fragment, which completes the message.
    Last(Bytes),
}
68
/// WebSocket protocol codec.
///
/// Encodes [`Message`]s and decodes [`Frame`]s. The codec is stateful: it records whether it acts
/// as a server or a client, and whether a fragmented message is currently open on the read side
/// and on the write side. Cloning a codec copies this state.
#[derive(Debug, Clone)]
pub struct Codec {
    // Role (server/client) plus read-side and write-side fragmentation state.
    flags: Flags,
    // Maximum frame size handed to the frame parser when decoding; unused when encoding.
    max_size: usize,
}
75
bitflags! {
    /// Role and fragmentation state of a [`Codec`].
    #[derive(Debug, Clone, Copy)]
    struct Flags: u8 {
        /// Set when the codec acts as a server; cleared by `Codec::client_mode`.
        const SERVER         = 0b0000_0001;
        /// Set while the decoder is inside a fragmented incoming message.
        const CONTINUATION   = 0b0000_0010;
        /// Set while the encoder is inside a fragmented outgoing message.
        const W_CONTINUATION = 0b0000_0100;
    }
}
84
85impl Codec {
86    /// Create new WebSocket frames decoder.
87    pub const fn new() -> Codec {
88        Codec {
89            max_size: 65_536,
90            flags: Flags::SERVER,
91        }
92    }
93
94    /// Set max frame size.
95    ///
96    /// By default max size is set to 64KiB.
97    #[must_use = "This returns the a new Codec, without modifying the original."]
98    pub fn max_size(mut self, size: usize) -> Self {
99        self.max_size = size;
100        self
101    }
102
103    /// Set decoder to client mode.
104    ///
105    /// By default decoder works in server mode.
106    #[must_use = "This returns the a new Codec, without modifying the original."]
107    pub fn client_mode(mut self) -> Self {
108        self.flags.remove(Flags::SERVER);
109        self
110    }
111}
112
113impl Default for Codec {
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
119impl Encoder<Message> for Codec {
120    type Error = ProtocolError;
121
122    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
123        match item {
124            Message::Text(txt) => Parser::write_message(
125                dst,
126                txt,
127                OpCode::Text,
128                true,
129                !self.flags.contains(Flags::SERVER),
130            ),
131            Message::Binary(bin) => Parser::write_message(
132                dst,
133                bin,
134                OpCode::Binary,
135                true,
136                !self.flags.contains(Flags::SERVER),
137            ),
138            Message::Ping(txt) => Parser::write_message(
139                dst,
140                txt,
141                OpCode::Ping,
142                true,
143                !self.flags.contains(Flags::SERVER),
144            ),
145            Message::Pong(txt) => Parser::write_message(
146                dst,
147                txt,
148                OpCode::Pong,
149                true,
150                !self.flags.contains(Flags::SERVER),
151            ),
152            Message::Close(reason) => {
153                Parser::write_close(dst, reason, !self.flags.contains(Flags::SERVER))
154            }
155            Message::Continuation(cont) => match cont {
156                Item::FirstText(data) => {
157                    if self.flags.contains(Flags::W_CONTINUATION) {
158                        return Err(ProtocolError::ContinuationStarted);
159                    } else {
160                        self.flags.insert(Flags::W_CONTINUATION);
161                        Parser::write_message(
162                            dst,
163                            &data[..],
164                            OpCode::Text,
165                            false,
166                            !self.flags.contains(Flags::SERVER),
167                        )
168                    }
169                }
170                Item::FirstBinary(data) => {
171                    if self.flags.contains(Flags::W_CONTINUATION) {
172                        return Err(ProtocolError::ContinuationStarted);
173                    } else {
174                        self.flags.insert(Flags::W_CONTINUATION);
175                        Parser::write_message(
176                            dst,
177                            &data[..],
178                            OpCode::Binary,
179                            false,
180                            !self.flags.contains(Flags::SERVER),
181                        )
182                    }
183                }
184                Item::Continue(data) => {
185                    if self.flags.contains(Flags::W_CONTINUATION) {
186                        Parser::write_message(
187                            dst,
188                            &data[..],
189                            OpCode::Continue,
190                            false,
191                            !self.flags.contains(Flags::SERVER),
192                        )
193                    } else {
194                        return Err(ProtocolError::ContinuationNotStarted);
195                    }
196                }
197                Item::Last(data) => {
198                    if self.flags.contains(Flags::W_CONTINUATION) {
199                        self.flags.remove(Flags::W_CONTINUATION);
200                        Parser::write_message(
201                            dst,
202                            &data[..],
203                            OpCode::Continue,
204                            true,
205                            !self.flags.contains(Flags::SERVER),
206                        )
207                    } else {
208                        return Err(ProtocolError::ContinuationNotStarted);
209                    }
210                }
211            },
212            Message::Nop => {}
213        }
214        Ok(())
215    }
216}
217
218impl Decoder for Codec {
219    type Item = Frame;
220    type Error = ProtocolError;
221
222    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
223        match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size) {
224            Ok(Some((finished, opcode, payload))) => {
225                // continuation is not supported
226                if !finished {
227                    return match opcode {
228                        OpCode::Continue => {
229                            if self.flags.contains(Flags::CONTINUATION) {
230                                Ok(Some(Frame::Continuation(Item::Continue(
231                                    payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
232                                ))))
233                            } else {
234                                Err(ProtocolError::ContinuationNotStarted)
235                            }
236                        }
237                        OpCode::Binary => {
238                            if !self.flags.contains(Flags::CONTINUATION) {
239                                self.flags.insert(Flags::CONTINUATION);
240                                Ok(Some(Frame::Continuation(Item::FirstBinary(
241                                    payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
242                                ))))
243                            } else {
244                                Err(ProtocolError::ContinuationStarted)
245                            }
246                        }
247                        OpCode::Text => {
248                            if !self.flags.contains(Flags::CONTINUATION) {
249                                self.flags.insert(Flags::CONTINUATION);
250                                Ok(Some(Frame::Continuation(Item::FirstText(
251                                    payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
252                                ))))
253                            } else {
254                                Err(ProtocolError::ContinuationStarted)
255                            }
256                        }
257                        _ => {
258                            error!("Unfinished fragment {:?}", opcode);
259                            Err(ProtocolError::ContinuationFragment(opcode))
260                        }
261                    };
262                }
263
264                match opcode {
265                    OpCode::Continue => {
266                        if self.flags.contains(Flags::CONTINUATION) {
267                            self.flags.remove(Flags::CONTINUATION);
268                            Ok(Some(Frame::Continuation(Item::Last(
269                                payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
270                            ))))
271                        } else {
272                            Err(ProtocolError::ContinuationNotStarted)
273                        }
274                    }
275                    OpCode::Bad => Err(ProtocolError::BadOpCode),
276                    OpCode::Close => {
277                        if let Some(ref pl) = payload {
278                            let close_reason = Parser::parse_close_payload(pl);
279                            Ok(Some(Frame::Close(close_reason)))
280                        } else {
281                            Ok(Some(Frame::Close(None)))
282                        }
283                    }
284                    OpCode::Ping => Ok(Some(Frame::Ping(
285                        payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
286                    ))),
287                    OpCode::Pong => Ok(Some(Frame::Pong(
288                        payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
289                    ))),
290                    OpCode::Binary => Ok(Some(Frame::Binary(
291                        payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
292                    ))),
293                    OpCode::Text => Ok(Some(Frame::Text(
294                        payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
295                    ))),
296                }
297            }
298            Ok(None) => Ok(None),
299            Err(err) => Err(err),
300        }
301    }
302}