1use std::io;
7
8use derive_more::{Display, Error, From};
9use http::{header, Method, StatusCode};
10
11use crate::{body::BoxBody, header::HeaderValue, RequestHead, Response, ResponseBuilder};
12
13mod codec;
14mod dispatcher;
15mod frame;
16mod mask;
17mod proto;
18
19pub use self::{
20 codec::{Codec, Frame, Item, Message},
21 dispatcher::Dispatcher,
22 frame::Parser,
23 proto::{hash_key, CloseCode, CloseReason, OpCode},
24};
25
/// WebSocket protocol errors.
///
/// `Display` text comes from each variant's `#[display(..)]` attribute. Variants whose payload is
/// not itself an error are marked `#[error(not(source))]` so that payload is not reported as the
/// error's `source()`.
// NOTE(review): derive_more's `From` presumably generates `From<T>` impls for the single-field
// variants (e.g. `From<io::Error>`), which lets `?` lift I/O errors into `Io` — confirm before
// relying on the `u8`/`usize` conversions.
#[derive(Debug, Display, Error, From)]
pub enum ProtocolError {
    /// Received an unmasked frame from client.
    #[display("received an unmasked frame from client")]
    UnmaskedFrame,

    /// Received a masked frame from server.
    #[display("received a masked frame from server")]
    MaskedFrame,

    /// Encountered an invalid opcode; carries the raw opcode byte.
    #[display("invalid opcode ({})", _0)]
    InvalidOpcode(#[error(not(source))] u8),

    /// Invalid control frame length; carries the offending length.
    #[display("invalid control frame length ({})", _0)]
    InvalidLength(#[error(not(source))] usize),

    /// Bad opcode.
    #[display("bad opcode")]
    BadOpCode,

    /// A payload reached the size limit.
    #[display("payload reached size limit")]
    Overflow,

    /// A continuation frame arrived before any continuation was started.
    #[display("continuation has not started")]
    ContinuationNotStarted,

    /// A new continuation was started while another one is still in progress.
    #[display("received new continuation but it has already started")]
    ContinuationStarted,

    /// Unknown continuation fragment; carries the fragment's opcode.
    #[display("unknown continuation fragment: {}", _0)]
    ContinuationFragment(#[error(not(source))] OpCode),

    /// I/O error; the wrapped `io::Error` is not excluded from being the error source.
    #[display("I/O error: {}", _0)]
    Io(io::Error),
}
69
/// WebSocket handshake errors, as returned by `verify_handshake`.
///
/// Each variant converts into an HTTP error response via `From<HandshakeError> for
/// Response<BoxBody>`: `GetMethodRequired` becomes `405 Method Not Allowed`, every other variant
/// becomes `400 Bad Request`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display, Error)]
pub enum HandshakeError {
    /// The request method was not `GET`.
    #[display("method not allowed")]
    GetMethodRequired,

    /// The `Upgrade` header is missing or does not mention `websocket`.
    #[display("WebSocket upgrade is expected")]
    NoWebsocketUpgrade,

    /// The request does not ask for a connection upgrade.
    #[display("connection upgrade is expected")]
    NoConnectionUpgrade,

    /// The `Sec-WebSocket-Version` header is missing.
    #[display("WebSocket version header is required")]
    NoVersionHeader,

    /// The `Sec-WebSocket-Version` header is not one of `13`, `8` or `7`.
    #[display("unsupported WebSocket version")]
    UnsupportedVersion,

    /// The `Sec-WebSocket-Key` header is missing.
    #[display("unknown WebSocket key")]
    BadWebsocketKey,
}
97
98impl From<HandshakeError> for Response<BoxBody> {
99 fn from(err: HandshakeError) -> Self {
100 match err {
101 HandshakeError::GetMethodRequired => {
102 let mut res = Response::new(StatusCode::METHOD_NOT_ALLOWED);
103 #[allow(clippy::declare_interior_mutable_const)]
104 const HV_GET: HeaderValue = HeaderValue::from_static("GET");
105 res.headers_mut().insert(header::ALLOW, HV_GET);
106 res
107 }
108
109 HandshakeError::NoWebsocketUpgrade => {
110 let mut res = Response::bad_request();
111 res.head_mut().reason = Some("No WebSocket Upgrade header found");
112 res
113 }
114
115 HandshakeError::NoConnectionUpgrade => {
116 let mut res = Response::bad_request();
117 res.head_mut().reason = Some("No Connection upgrade");
118 res
119 }
120
121 HandshakeError::NoVersionHeader => {
122 let mut res = Response::bad_request();
123 res.head_mut().reason = Some("WebSocket version header is required");
124 res
125 }
126
127 HandshakeError::UnsupportedVersion => {
128 let mut res = Response::bad_request();
129 res.head_mut().reason = Some("Unsupported WebSocket version");
130 res
131 }
132
133 HandshakeError::BadWebsocketKey => {
134 let mut res = Response::bad_request();
135 res.head_mut().reason = Some("Handshake error");
136 res
137 }
138 }
139 }
140}
141
142impl From<&HandshakeError> for Response<BoxBody> {
143 fn from(err: &HandshakeError) -> Self {
144 (*err).into()
145 }
146}
147
148pub fn handshake(req: &RequestHead) -> Result<ResponseBuilder, HandshakeError> {
150 verify_handshake(req)?;
151 Ok(handshake_response(req))
152}
153
154pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> {
156 if req.method != Method::GET {
158 return Err(HandshakeError::GetMethodRequired);
159 }
160
161 let has_hdr = if let Some(hdr) = req.headers().get(header::UPGRADE) {
163 if let Ok(s) = hdr.to_str() {
164 s.to_ascii_lowercase().contains("websocket")
165 } else {
166 false
167 }
168 } else {
169 false
170 };
171 if !has_hdr {
172 return Err(HandshakeError::NoWebsocketUpgrade);
173 }
174
175 if !req.upgrade() {
177 return Err(HandshakeError::NoConnectionUpgrade);
178 }
179
180 if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) {
182 return Err(HandshakeError::NoVersionHeader);
183 }
184 let supported_ver = {
185 if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) {
186 hdr == "13" || hdr == "8" || hdr == "7"
187 } else {
188 false
189 }
190 };
191 if !supported_ver {
192 return Err(HandshakeError::UnsupportedVersion);
193 }
194
195 if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) {
197 return Err(HandshakeError::BadWebsocketKey);
198 }
199 Ok(())
200}
201
202pub fn handshake_response(req: &RequestHead) -> ResponseBuilder {
206 let key = {
207 let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap();
208 proto::hash_key(key.as_ref())
209 };
210
211 Response::build(StatusCode::SWITCHING_PROTOCOLS)
212 .upgrade("websocket")
213 .insert_header((
214 header::SEC_WEBSOCKET_ACCEPT,
215 HeaderValue::from_bytes(&key).unwrap(),
217 ))
218 .take()
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{header, test::TestRequest};

    #[test]
    fn test_handshake() {
        let req = TestRequest::default().method(Method::POST).finish();
        assert_eq!(
            HandshakeError::GetMethodRequired,
            verify_handshake(req.head()).unwrap_err(),
        );

        let req = TestRequest::default().finish();
        assert_eq!(
            HandshakeError::NoWebsocketUpgrade,
            verify_handshake(req.head()).unwrap_err(),
        );

        let req = TestRequest::default()
            .insert_header((header::UPGRADE, header::HeaderValue::from_static("test")))
            .finish();
        assert_eq!(
            HandshakeError::NoWebsocketUpgrade,
            verify_handshake(req.head()).unwrap_err(),
        );

        let req = TestRequest::default()
            .insert_header((
                header::UPGRADE,
                header::HeaderValue::from_static("websocket"),
            ))
            .finish();
        assert_eq!(
            HandshakeError::NoConnectionUpgrade,
            verify_handshake(req.head()).unwrap_err(),
        );

        let req = TestRequest::default()
            .insert_header((
                header::UPGRADE,
                header::HeaderValue::from_static("websocket"),
            ))
            .insert_header((
                header::CONNECTION,
                header::HeaderValue::from_static("upgrade"),
            ))
            .finish();
        assert_eq!(
            HandshakeError::NoVersionHeader,
            verify_handshake(req.head()).unwrap_err(),
        );

        let req = TestRequest::default()
            .insert_header((
                header::UPGRADE,
                header::HeaderValue::from_static("websocket"),
            ))
            .insert_header((
                header::CONNECTION,
                header::HeaderValue::from_static("upgrade"),
            ))
            .insert_header((
                header::SEC_WEBSOCKET_VERSION,
                header::HeaderValue::from_static("5"),
            ))
            .finish();
        assert_eq!(
            HandshakeError::UnsupportedVersion,
            verify_handshake(req.head()).unwrap_err(),
        );

        let req = TestRequest::default()
            .insert_header((
                header::UPGRADE,
                header::HeaderValue::from_static("websocket"),
            ))
            .insert_header((
                header::CONNECTION,
                header::HeaderValue::from_static("upgrade"),
            ))
            .insert_header((
                header::SEC_WEBSOCKET_VERSION,
                header::HeaderValue::from_static("13"),
            ))
            .finish();
        assert_eq!(
            HandshakeError::BadWebsocketKey,
            verify_handshake(req.head()).unwrap_err(),
        );

        let req = TestRequest::default()
            .insert_header((
                header::UPGRADE,
                header::HeaderValue::from_static("websocket"),
            ))
            .insert_header((
                header::CONNECTION,
                header::HeaderValue::from_static("upgrade"),
            ))
            .insert_header((
                header::SEC_WEBSOCKET_VERSION,
                header::HeaderValue::from_static("13"),
            ))
            .insert_header((
                header::SEC_WEBSOCKET_KEY,
                header::HeaderValue::from_static("13"),
            ))
            .finish();
        assert_eq!(
            StatusCode::SWITCHING_PROTOCOLS,
            handshake_response(req.head()).finish().status()
        );
    }

    /// Every supported version is accepted, `Upgrade` is matched case-insensitively, and the
    /// accept header matches the worked example from RFC 6455 section 1.3.
    #[test]
    fn test_handshake_accepts_supported_variants() {
        const KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";
        const ACCEPT: &str = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";

        for version in ["13", "8", "7"] {
            let req = TestRequest::default()
                .insert_header((
                    header::UPGRADE,
                    header::HeaderValue::from_static("WebSocket"),
                ))
                .insert_header((
                    header::CONNECTION,
                    header::HeaderValue::from_static("upgrade"),
                ))
                .insert_header((
                    header::SEC_WEBSOCKET_VERSION,
                    header::HeaderValue::from_static(version),
                ))
                .insert_header((
                    header::SEC_WEBSOCKET_KEY,
                    header::HeaderValue::from_static(KEY),
                ))
                .finish();
            assert_eq!(verify_handshake(req.head()), Ok(()), "version {version}");

            let res = handshake(req.head()).unwrap().finish();
            assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS);
            assert_eq!(
                res.headers().get(header::SEC_WEBSOCKET_ACCEPT).unwrap(),
                ACCEPT
            );
        }
    }

    #[test]
    fn test_ws_error_http_response() {
        let resp: Response<BoxBody> = HandshakeError::GetMethodRequired.into();
        assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(resp.headers().get(header::ALLOW).unwrap(), "GET");
        let resp: Response<BoxBody> = HandshakeError::NoWebsocketUpgrade.into();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let resp: Response<BoxBody> = HandshakeError::NoConnectionUpgrade.into();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let resp: Response<BoxBody> = HandshakeError::NoVersionHeader.into();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let resp: Response<BoxBody> = HandshakeError::UnsupportedVersion.into();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let resp: Response<BoxBody> = HandshakeError::BadWebsocketKey.into();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // The by-reference conversion must agree with the by-value one.
        let resp: Response<BoxBody> = (&HandshakeError::BadWebsocketKey).into();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }
}