warp/filters/
ws.rs

1//! Websockets Filters
2
3use std::borrow::Cow;
4use std::fmt;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use super::header;
10use crate::filter::{filter_fn_one, Filter, One};
11use crate::reject::Rejection;
12use crate::reply::{Reply, Response};
13use futures_util::{future, ready, FutureExt, Sink, Stream, TryFutureExt};
14use headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade};
15use hyper::upgrade::OnUpgrade;
16use tokio_tungstenite::{
17    tungstenite::protocol::{self, WebSocketConfig},
18    WebSocketStream,
19};
20
21/// Creates a Websocket Filter.
22///
23/// The yielded `Ws` is used to finish the websocket upgrade.
24///
25/// # Note
26///
27/// This filter combines multiple filters internally, so you don't need them:
28///
29/// - Method must be `GET`
30/// - Header `connection` must be `upgrade`
31/// - Header `upgrade` must be `websocket`
32/// - Header `sec-websocket-version` must be `13`
33/// - Header `sec-websocket-key` must be set.
34///
35/// If the filters are met, yields a `Ws`. Calling `Ws::on_upgrade` will
36/// return a reply with:
37///
38/// - Status of `101 Switching Protocols`
39/// - Header `connection: upgrade`
40/// - Header `upgrade: websocket`
41/// - Header `sec-websocket-accept` with the hash value of the received key.
42pub fn ws() -> impl Filter<Extract = One<Ws>, Error = Rejection> + Copy {
43    let connection_has_upgrade = header::header2()
44        .and_then(|conn: ::headers::Connection| {
45            if conn.contains("upgrade") {
46                future::ok(())
47            } else {
48                future::err(crate::reject::known(MissingConnectionUpgrade))
49            }
50        })
51        .untuple_one();
52
53    crate::get()
54        .and(connection_has_upgrade)
55        .and(header::exact_ignore_case("upgrade", "websocket"))
56        .and(header::exact("sec-websocket-version", "13"))
57        //.and(header::exact2(Upgrade::websocket()))
58        //.and(header::exact2(SecWebsocketVersion::V13))
59        .and(header::header2::<SecWebsocketKey>())
60        .and(on_upgrade())
61        .map(
62            move |key: SecWebsocketKey, on_upgrade: Option<OnUpgrade>| Ws {
63                config: None,
64                key,
65                on_upgrade,
66            },
67        )
68}
69
70/// Extracted by the [`ws`] filter, and used to finish an upgrade.
71pub struct Ws {
72    config: Option<WebSocketConfig>,
73    key: SecWebsocketKey,
74    on_upgrade: Option<OnUpgrade>,
75}
76
77impl Ws {
78    /// Finish the upgrade, passing a function to handle the `WebSocket`.
79    ///
80    /// The passed function must return a `Future`.
81    pub fn on_upgrade<F, U>(self, func: F) -> impl Reply
82    where
83        F: FnOnce(WebSocket) -> U + Send + 'static,
84        U: Future<Output = ()> + Send + 'static,
85    {
86        WsReply {
87            ws: self,
88            on_upgrade: func,
89        }
90    }
91
92    // config
93
94    /// Does nothing.
95    ///
96    /// # Deprecated
97    ///
98    /// Use `max_write_buffer_size()` instead.
99    #[deprecated = "use max_write_buffer_size instead"]
100    pub fn max_send_queue(self, _max: usize) -> Self {
101        self
102    }
103
104    /// The max size of the write buffer, in bytes.
105    pub fn max_write_buffer_size(mut self, max: usize) -> Self {
106        self.config
107            .get_or_insert_with(WebSocketConfig::default)
108            .max_write_buffer_size = max;
109        self
110    }
111
112    /// Set the maximum message size (defaults to 64 megabytes)
113    pub fn max_message_size(mut self, max: usize) -> Self {
114        self.config
115            .get_or_insert_with(WebSocketConfig::default)
116            .max_message_size = Some(max);
117        self
118    }
119
120    /// Set the maximum frame size (defaults to 16 megabytes)
121    pub fn max_frame_size(mut self, max: usize) -> Self {
122        self.config
123            .get_or_insert_with(WebSocketConfig::default)
124            .max_frame_size = Some(max);
125        self
126    }
127}
128
129impl fmt::Debug for Ws {
130    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131        f.debug_struct("Ws").finish()
132    }
133}
134
135#[allow(missing_debug_implementations)]
136struct WsReply<F> {
137    ws: Ws,
138    on_upgrade: F,
139}
140
141impl<F, U> Reply for WsReply<F>
142where
143    F: FnOnce(WebSocket) -> U + Send + 'static,
144    U: Future<Output = ()> + Send + 'static,
145{
146    fn into_response(self) -> Response {
147        if let Some(on_upgrade) = self.ws.on_upgrade {
148            let on_upgrade_cb = self.on_upgrade;
149            let config = self.ws.config;
150            let fut = on_upgrade
151                .and_then(move |upgraded| {
152                    tracing::trace!("websocket upgrade complete");
153                    WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok)
154                })
155                .and_then(move |socket| on_upgrade_cb(socket).map(Ok))
156                .map(|result| {
157                    if let Err(err) = result {
158                        tracing::debug!("ws upgrade error: {}", err);
159                    }
160                });
161            ::tokio::task::spawn(fut);
162        } else {
163            tracing::debug!("ws couldn't be upgraded since no upgrade state was present");
164        }
165
166        let mut res = http::Response::default();
167
168        *res.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS;
169
170        res.headers_mut().typed_insert(Connection::upgrade());
171        res.headers_mut().typed_insert(Upgrade::websocket());
172        res.headers_mut()
173            .typed_insert(SecWebsocketAccept::from(self.ws.key));
174
175        res
176    }
177}
178
179// Extracts OnUpgrade state from the route.
180fn on_upgrade() -> impl Filter<Extract = (Option<OnUpgrade>,), Error = Rejection> + Copy {
181    filter_fn_one(|route| future::ready(Ok(route.extensions_mut().remove::<OnUpgrade>())))
182}
183
184/// A websocket `Stream` and `Sink`, provided to `ws` filters.
185///
186/// Ping messages sent from the client will be handled internally by replying with a Pong message.
187/// Close messages need to be handled explicitly: usually by closing the `Sink` end of the
188/// `WebSocket`.
189///
190/// **Note!**
191/// Due to rust futures nature, pings won't be handled until read part of `WebSocket` is polled
192
193pub struct WebSocket {
194    inner: WebSocketStream<hyper::upgrade::Upgraded>,
195}
196
197impl WebSocket {
198    pub(crate) async fn from_raw_socket(
199        upgraded: hyper::upgrade::Upgraded,
200        role: protocol::Role,
201        config: Option<protocol::WebSocketConfig>,
202    ) -> Self {
203        WebSocketStream::from_raw_socket(upgraded, role, config)
204            .map(|inner| WebSocket { inner })
205            .await
206    }
207
208    /// Gracefully close this websocket.
209    pub async fn close(mut self) -> Result<(), crate::Error> {
210        future::poll_fn(|cx| Pin::new(&mut self).poll_close(cx)).await
211    }
212}
213
214impl Stream for WebSocket {
215    type Item = Result<Message, crate::Error>;
216
217    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
218        match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
219            Some(Ok(item)) => Poll::Ready(Some(Ok(Message { inner: item }))),
220            Some(Err(e)) => {
221                tracing::debug!("websocket poll error: {}", e);
222                Poll::Ready(Some(Err(crate::Error::new(e))))
223            }
224            None => {
225                tracing::trace!("websocket closed");
226                Poll::Ready(None)
227            }
228        }
229    }
230}
231
232impl Sink<Message> for WebSocket {
233    type Error = crate::Error;
234
235    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
236        match ready!(Pin::new(&mut self.inner).poll_ready(cx)) {
237            Ok(()) => Poll::Ready(Ok(())),
238            Err(e) => Poll::Ready(Err(crate::Error::new(e))),
239        }
240    }
241
242    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
243        match Pin::new(&mut self.inner).start_send(item.inner) {
244            Ok(()) => Ok(()),
245            Err(e) => {
246                tracing::debug!("websocket start_send error: {}", e);
247                Err(crate::Error::new(e))
248            }
249        }
250    }
251
252    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
253        match ready!(Pin::new(&mut self.inner).poll_flush(cx)) {
254            Ok(()) => Poll::Ready(Ok(())),
255            Err(e) => Poll::Ready(Err(crate::Error::new(e))),
256        }
257    }
258
259    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
260        match ready!(Pin::new(&mut self.inner).poll_close(cx)) {
261            Ok(()) => Poll::Ready(Ok(())),
262            Err(err) => {
263                tracing::debug!("websocket close error: {}", err);
264                Poll::Ready(Err(crate::Error::new(err)))
265            }
266        }
267    }
268}
269
270impl fmt::Debug for WebSocket {
271    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
272        f.debug_struct("WebSocket").finish()
273    }
274}
275
276/// A WebSocket message.
277///
278/// This will likely become a `non-exhaustive` enum in the future, once that
279/// language feature has stabilized.
280#[derive(Eq, PartialEq, Clone)]
281pub struct Message {
282    inner: protocol::Message,
283}
284
285impl Message {
286    /// Construct a new Text `Message`.
287    pub fn text<S: Into<String>>(s: S) -> Message {
288        Message {
289            inner: protocol::Message::text(s),
290        }
291    }
292
293    /// Construct a new Binary `Message`.
294    pub fn binary<V: Into<Vec<u8>>>(v: V) -> Message {
295        Message {
296            inner: protocol::Message::binary(v),
297        }
298    }
299
300    /// Construct a new Ping `Message`.
301    pub fn ping<V: Into<Vec<u8>>>(v: V) -> Message {
302        Message {
303            inner: protocol::Message::Ping(v.into()),
304        }
305    }
306
307    /// Construct a new Pong `Message`.
308    ///
309    /// Note that one rarely needs to manually construct a Pong message because the underlying tungstenite socket
310    /// automatically responds to the Ping messages it receives. Manual construction might still be useful in some cases
311    /// like in tests or to send unidirectional heartbeats.
312    pub fn pong<V: Into<Vec<u8>>>(v: V) -> Message {
313        Message {
314            inner: protocol::Message::Pong(v.into()),
315        }
316    }
317
318    /// Construct the default Close `Message`.
319    pub fn close() -> Message {
320        Message {
321            inner: protocol::Message::Close(None),
322        }
323    }
324
325    /// Construct a Close `Message` with a code and reason.
326    pub fn close_with(code: impl Into<u16>, reason: impl Into<Cow<'static, str>>) -> Message {
327        Message {
328            inner: protocol::Message::Close(Some(protocol::frame::CloseFrame {
329                code: protocol::frame::coding::CloseCode::from(code.into()),
330                reason: reason.into(),
331            })),
332        }
333    }
334
335    /// Returns true if this message is a Text message.
336    pub fn is_text(&self) -> bool {
337        self.inner.is_text()
338    }
339
340    /// Returns true if this message is a Binary message.
341    pub fn is_binary(&self) -> bool {
342        self.inner.is_binary()
343    }
344
345    /// Returns true if this message a is a Close message.
346    pub fn is_close(&self) -> bool {
347        self.inner.is_close()
348    }
349
350    /// Returns true if this message is a Ping message.
351    pub fn is_ping(&self) -> bool {
352        self.inner.is_ping()
353    }
354
355    /// Returns true if this message is a Pong message.
356    pub fn is_pong(&self) -> bool {
357        self.inner.is_pong()
358    }
359
360    /// Try to get the close frame (close code and reason)
361    pub fn close_frame(&self) -> Option<(u16, &str)> {
362        if let protocol::Message::Close(Some(ref close_frame)) = self.inner {
363            Some((close_frame.code.into(), close_frame.reason.as_ref()))
364        } else {
365            None
366        }
367    }
368
369    /// Try to get a reference to the string text, if this is a Text message.
370    pub fn to_str(&self) -> Result<&str, ()> {
371        match self.inner {
372            protocol::Message::Text(ref s) => Ok(s),
373            _ => Err(()),
374        }
375    }
376
377    /// Return the bytes of this message, if the message can contain data.
378    pub fn as_bytes(&self) -> &[u8] {
379        match self.inner {
380            protocol::Message::Text(ref s) => s.as_bytes(),
381            protocol::Message::Binary(ref v) => v,
382            protocol::Message::Ping(ref v) => v,
383            protocol::Message::Pong(ref v) => v,
384            protocol::Message::Close(_) => &[],
385            protocol::Message::Frame(ref frame) => frame.payload(),
386        }
387    }
388
389    /// Destructure this message into binary data.
390    pub fn into_bytes(self) -> Vec<u8> {
391        self.inner.into_data()
392    }
393}
394
395impl fmt::Debug for Message {
396    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
397        fmt::Debug::fmt(&self.inner, f)
398    }
399}
400
401impl From<Message> for Vec<u8> {
402    fn from(m: Message) -> Self {
403        m.into_bytes()
404    }
405}
406
// ===== Rejections =====

/// Rejection produced by the `ws()` filter when the `connection` header does
/// not include the `upgrade` option.
#[derive(Debug)]
pub struct MissingConnectionUpgrade;

impl fmt::Display for MissingConnectionUpgrade {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Connection header did not include 'upgrade'")
    }
}

impl ::std::error::Error for MissingConnectionUpgrade {}