1use std::borrow::Cow;
4use std::fmt;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use super::header;
10use crate::filter::{filter_fn_one, Filter, One};
11use crate::reject::Rejection;
12use crate::reply::{Reply, Response};
13use futures_util::{future, ready, FutureExt, Sink, Stream, TryFutureExt};
14use headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade};
15use hyper::upgrade::OnUpgrade;
16use tokio_tungstenite::{
17 tungstenite::protocol::{self, WebSocketConfig},
18 WebSocketStream,
19};
20
21pub fn ws() -> impl Filter<Extract = One<Ws>, Error = Rejection> + Copy {
43 let connection_has_upgrade = header::header2()
44 .and_then(|conn: ::headers::Connection| {
45 if conn.contains("upgrade") {
46 future::ok(())
47 } else {
48 future::err(crate::reject::known(MissingConnectionUpgrade))
49 }
50 })
51 .untuple_one();
52
53 crate::get()
54 .and(connection_has_upgrade)
55 .and(header::exact_ignore_case("upgrade", "websocket"))
56 .and(header::exact("sec-websocket-version", "13"))
57 .and(header::header2::<SecWebsocketKey>())
60 .and(on_upgrade())
61 .map(
62 move |key: SecWebsocketKey, on_upgrade: Option<OnUpgrade>| Ws {
63 config: None,
64 key,
65 on_upgrade,
66 },
67 )
68}
69
/// Extracted by the [`ws`] filter; used to finish the WebSocket upgrade.
///
/// Call [`Ws::on_upgrade`] to turn this value into a reply.
pub struct Ws {
    // Optional protocol configuration; stays `None` until one of the
    // `max_*` builder methods is called.
    config: Option<WebSocketConfig>,
    // Key taken from the request; used to build the `SecWebsocketAccept`
    // response header.
    key: SecWebsocketKey,
    // Upgrade handle removed from the request extensions, if one was present.
    on_upgrade: Option<OnUpgrade>,
}
76
77impl Ws {
78 pub fn on_upgrade<F, U>(self, func: F) -> impl Reply
82 where
83 F: FnOnce(WebSocket) -> U + Send + 'static,
84 U: Future<Output = ()> + Send + 'static,
85 {
86 WsReply {
87 ws: self,
88 on_upgrade: func,
89 }
90 }
91
92 #[deprecated = "use max_write_buffer_size instead"]
100 pub fn max_send_queue(self, _max: usize) -> Self {
101 self
102 }
103
104 pub fn max_write_buffer_size(mut self, max: usize) -> Self {
106 self.config
107 .get_or_insert_with(WebSocketConfig::default)
108 .max_write_buffer_size = max;
109 self
110 }
111
112 pub fn max_message_size(mut self, max: usize) -> Self {
114 self.config
115 .get_or_insert_with(WebSocketConfig::default)
116 .max_message_size = Some(max);
117 self
118 }
119
120 pub fn max_frame_size(mut self, max: usize) -> Self {
122 self.config
123 .get_or_insert_with(WebSocketConfig::default)
124 .max_frame_size = Some(max);
125 self
126 }
127}
128
129impl fmt::Debug for Ws {
130 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131 f.debug_struct("Ws").finish()
132 }
133}
134
// Reply produced by `Ws::on_upgrade`: pairs the extracted handshake state
// with the user callback to run once the connection has been upgraded.
// NOTE(review): `Debug` is presumably skipped because `F` is an arbitrary
// callback type — confirm.
#[allow(missing_debug_implementations)]
struct WsReply<F> {
    // Handshake state extracted by the `ws()` filter.
    ws: Ws,
    // User callback invoked with the established `WebSocket`.
    on_upgrade: F,
}
140
141impl<F, U> Reply for WsReply<F>
142where
143 F: FnOnce(WebSocket) -> U + Send + 'static,
144 U: Future<Output = ()> + Send + 'static,
145{
146 fn into_response(self) -> Response {
147 if let Some(on_upgrade) = self.ws.on_upgrade {
148 let on_upgrade_cb = self.on_upgrade;
149 let config = self.ws.config;
150 let fut = on_upgrade
151 .and_then(move |upgraded| {
152 tracing::trace!("websocket upgrade complete");
153 WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok)
154 })
155 .and_then(move |socket| on_upgrade_cb(socket).map(Ok))
156 .map(|result| {
157 if let Err(err) = result {
158 tracing::debug!("ws upgrade error: {}", err);
159 }
160 });
161 ::tokio::task::spawn(fut);
162 } else {
163 tracing::debug!("ws couldn't be upgraded since no upgrade state was present");
164 }
165
166 let mut res = http::Response::default();
167
168 *res.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS;
169
170 res.headers_mut().typed_insert(Connection::upgrade());
171 res.headers_mut().typed_insert(Upgrade::websocket());
172 res.headers_mut()
173 .typed_insert(SecWebsocketAccept::from(self.ws.key));
174
175 res
176 }
177}
178
179fn on_upgrade() -> impl Filter<Extract = (Option<OnUpgrade>,), Error = Rejection> + Copy {
181 filter_fn_one(|route| future::ready(Ok(route.extensions_mut().remove::<OnUpgrade>())))
182}
183
/// A WebSocket connection over an upgraded hyper connection.
///
/// It implements `Stream` to yield received [`Message`]s and `Sink<Message>`
/// to send them.
pub struct WebSocket {
    // Underlying tokio-tungstenite stream over the upgraded connection.
    inner: WebSocketStream<hyper::upgrade::Upgraded>,
}
196
197impl WebSocket {
198 pub(crate) async fn from_raw_socket(
199 upgraded: hyper::upgrade::Upgraded,
200 role: protocol::Role,
201 config: Option<protocol::WebSocketConfig>,
202 ) -> Self {
203 WebSocketStream::from_raw_socket(upgraded, role, config)
204 .map(|inner| WebSocket { inner })
205 .await
206 }
207
208 pub async fn close(mut self) -> Result<(), crate::Error> {
210 future::poll_fn(|cx| Pin::new(&mut self).poll_close(cx)).await
211 }
212}
213
214impl Stream for WebSocket {
215 type Item = Result<Message, crate::Error>;
216
217 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
218 match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
219 Some(Ok(item)) => Poll::Ready(Some(Ok(Message { inner: item }))),
220 Some(Err(e)) => {
221 tracing::debug!("websocket poll error: {}", e);
222 Poll::Ready(Some(Err(crate::Error::new(e))))
223 }
224 None => {
225 tracing::trace!("websocket closed");
226 Poll::Ready(None)
227 }
228 }
229 }
230}
231
232impl Sink<Message> for WebSocket {
233 type Error = crate::Error;
234
235 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
236 match ready!(Pin::new(&mut self.inner).poll_ready(cx)) {
237 Ok(()) => Poll::Ready(Ok(())),
238 Err(e) => Poll::Ready(Err(crate::Error::new(e))),
239 }
240 }
241
242 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
243 match Pin::new(&mut self.inner).start_send(item.inner) {
244 Ok(()) => Ok(()),
245 Err(e) => {
246 tracing::debug!("websocket start_send error: {}", e);
247 Err(crate::Error::new(e))
248 }
249 }
250 }
251
252 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
253 match ready!(Pin::new(&mut self.inner).poll_flush(cx)) {
254 Ok(()) => Poll::Ready(Ok(())),
255 Err(e) => Poll::Ready(Err(crate::Error::new(e))),
256 }
257 }
258
259 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
260 match ready!(Pin::new(&mut self.inner).poll_close(cx)) {
261 Ok(()) => Poll::Ready(Ok(())),
262 Err(err) => {
263 tracing::debug!("websocket close error: {}", err);
264 Poll::Ready(Err(crate::Error::new(err)))
265 }
266 }
267 }
268}
269
270impl fmt::Debug for WebSocket {
271 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
272 f.debug_struct("WebSocket").finish()
273 }
274}
275
/// A WebSocket message.
///
/// This is a thin wrapper around the underlying protocol message; build one
/// with the constructors on this type (`text`, `binary`, `ping`, `pong`,
/// `close`, `close_with`).
#[derive(Eq, PartialEq, Clone)]
pub struct Message {
    // The wrapped tungstenite protocol message.
    inner: protocol::Message,
}
284
285impl Message {
286 pub fn text<S: Into<String>>(s: S) -> Message {
288 Message {
289 inner: protocol::Message::text(s),
290 }
291 }
292
293 pub fn binary<V: Into<Vec<u8>>>(v: V) -> Message {
295 Message {
296 inner: protocol::Message::binary(v),
297 }
298 }
299
300 pub fn ping<V: Into<Vec<u8>>>(v: V) -> Message {
302 Message {
303 inner: protocol::Message::Ping(v.into()),
304 }
305 }
306
307 pub fn pong<V: Into<Vec<u8>>>(v: V) -> Message {
313 Message {
314 inner: protocol::Message::Pong(v.into()),
315 }
316 }
317
318 pub fn close() -> Message {
320 Message {
321 inner: protocol::Message::Close(None),
322 }
323 }
324
325 pub fn close_with(code: impl Into<u16>, reason: impl Into<Cow<'static, str>>) -> Message {
327 Message {
328 inner: protocol::Message::Close(Some(protocol::frame::CloseFrame {
329 code: protocol::frame::coding::CloseCode::from(code.into()),
330 reason: reason.into(),
331 })),
332 }
333 }
334
335 pub fn is_text(&self) -> bool {
337 self.inner.is_text()
338 }
339
340 pub fn is_binary(&self) -> bool {
342 self.inner.is_binary()
343 }
344
345 pub fn is_close(&self) -> bool {
347 self.inner.is_close()
348 }
349
350 pub fn is_ping(&self) -> bool {
352 self.inner.is_ping()
353 }
354
355 pub fn is_pong(&self) -> bool {
357 self.inner.is_pong()
358 }
359
360 pub fn close_frame(&self) -> Option<(u16, &str)> {
362 if let protocol::Message::Close(Some(ref close_frame)) = self.inner {
363 Some((close_frame.code.into(), close_frame.reason.as_ref()))
364 } else {
365 None
366 }
367 }
368
369 pub fn to_str(&self) -> Result<&str, ()> {
371 match self.inner {
372 protocol::Message::Text(ref s) => Ok(s),
373 _ => Err(()),
374 }
375 }
376
377 pub fn as_bytes(&self) -> &[u8] {
379 match self.inner {
380 protocol::Message::Text(ref s) => s.as_bytes(),
381 protocol::Message::Binary(ref v) => v,
382 protocol::Message::Ping(ref v) => v,
383 protocol::Message::Pong(ref v) => v,
384 protocol::Message::Close(_) => &[],
385 protocol::Message::Frame(ref frame) => frame.payload(),
386 }
387 }
388
389 pub fn into_bytes(self) -> Vec<u8> {
391 self.inner.into_data()
392 }
393}
394
395impl fmt::Debug for Message {
396 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
397 fmt::Debug::fmt(&self.inner, f)
398 }
399}
400
401impl From<Message> for Vec<u8> {
402 fn from(m: Message) -> Self {
403 m.into_bytes()
404 }
405}
406
/// Rejection cause used when the `Connection` header does not contain
/// `upgrade`.
#[derive(Debug)]
pub struct MissingConnectionUpgrade;

impl fmt::Display for MissingConnectionUpgrade {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed message with no arguments, so it is written directly.
        f.write_str("Connection header did not include 'upgrade'")
    }
}

impl ::std::error::Error for MissingConnectionUpgrade {}