warp/filters/
sse.rs

1//! Server-Sent Events (SSE)
2//!
3//! # Example
4//!
5//! ```
6//!
7//! use std::time::Duration;
8//! use std::convert::Infallible;
9//! use warp::{Filter, sse::Event};
10//! use futures_util::{stream::iter, Stream};
11//!
12//! fn sse_events() -> impl Stream<Item = Result<Event, Infallible>> {
13//!     iter(vec![
14//!         Ok(Event::default().data("unnamed event")),
15//!         Ok(
16//!             Event::default().event("chat")
17//!             .data("chat message")
18//!         ),
19//!         Ok(
20//!             Event::default().id(13.to_string())
21//!             .event("chat")
22//!             .data("other chat message\nwith next line")
23//!             .retry(Duration::from_millis(5000))
24//!         )
25//!     ])
26//! }
27//!
28//! let app = warp::path("push-notifications")
29//!     .and(warp::get())
30//!     .map(|| {
31//!         warp::sse::reply(warp::sse::keep_alive().stream(sse_events()))
32//!     });
33//! ```
34//!
//! Each [`Event`] is built with a builder API: start from `Event::default()` and chain
//! setters such as `data`, `json_data`, `event`, `id`, `retry` and `comment`.
37//!
38//! See also the [EventSource](https://developer.mozilla.org/en-US/docs/Web/API/EventSource) API,
39//! which specifies the expected behavior of Server Sent Events.
40//!
41
42#![allow(rustdoc::invalid_html_tags)]
43
44use serde::Serialize;
45use std::borrow::Cow;
46use std::error::Error as StdError;
47use std::fmt::{self, Write};
48use std::future::Future;
49use std::pin::Pin;
50use std::str::FromStr;
51use std::task::{Context, Poll};
52use std::time::Duration;
53
54use futures_util::{future, Stream, TryStream, TryStreamExt};
55use http::header::{HeaderValue, CACHE_CONTROL, CONTENT_TYPE};
56use hyper::Body;
57use pin_project::pin_project;
58use serde_json::Error;
59use tokio::time::{self, Sleep};
60
61use self::sealed::SseError;
62use super::header;
63use crate::filter::One;
64use crate::reply::Response;
65use crate::{Filter, Rejection, Reply};
66
67// Server-sent event data type
68#[derive(Debug)]
69enum DataType {
70    Text(String),
71    Json(String),
72}
73
74/// Server-sent event
75#[derive(Default, Debug)]
76pub struct Event {
77    id: Option<String>,
78    data: Option<DataType>,
79    event: Option<String>,
80    comment: Option<String>,
81    retry: Option<Duration>,
82}
83
84impl Event {
85    /// Set Server-sent event data
86    /// data field(s) ("data:<content>")
87    pub fn data<T: Into<String>>(mut self, data: T) -> Event {
88        self.data = Some(DataType::Text(data.into()));
89        self
90    }
91
92    /// Set Server-sent event data
93    /// data field(s) ("data:<content>")
94    pub fn json_data<T: Serialize>(mut self, data: T) -> Result<Event, Error> {
95        self.data = Some(DataType::Json(serde_json::to_string(&data)?));
96        Ok(self)
97    }
98
99    /// Set Server-sent event comment
100    /// Comment field (":<comment-text>")
101    pub fn comment<T: Into<String>>(mut self, comment: T) -> Event {
102        self.comment = Some(comment.into());
103        self
104    }
105
106    /// Set Server-sent event event
107    /// Event name field ("event:<event-name>")
108    pub fn event<T: Into<String>>(mut self, event: T) -> Event {
109        self.event = Some(event.into());
110        self
111    }
112
113    /// Set Server-sent event retry
114    /// Retry timeout field ("retry:<timeout>")
115    pub fn retry(mut self, duration: Duration) -> Event {
116        self.retry = Some(duration);
117        self
118    }
119
120    /// Set Server-sent event id
121    /// Identifier field ("id:<identifier>")
122    pub fn id<T: Into<String>>(mut self, id: T) -> Event {
123        self.id = Some(id.into());
124        self
125    }
126}
127
128impl fmt::Display for Event {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        if let Some(ref comment) = &self.comment {
131            ":".fmt(f)?;
132            comment.fmt(f)?;
133            f.write_char('\n')?;
134        }
135
136        if let Some(ref event) = &self.event {
137            "event:".fmt(f)?;
138            event.fmt(f)?;
139            f.write_char('\n')?;
140        }
141
142        match self.data {
143            Some(DataType::Text(ref data)) => {
144                for line in data.split('\n') {
145                    "data:".fmt(f)?;
146                    line.fmt(f)?;
147                    f.write_char('\n')?;
148                }
149            }
150            Some(DataType::Json(ref data)) => {
151                "data:".fmt(f)?;
152                data.fmt(f)?;
153                f.write_char('\n')?;
154            }
155            None => {}
156        }
157
158        if let Some(ref id) = &self.id {
159            "id:".fmt(f)?;
160            id.fmt(f)?;
161            f.write_char('\n')?;
162        }
163
164        if let Some(ref duration) = &self.retry {
165            "retry:".fmt(f)?;
166
167            let secs = duration.as_secs();
168            let millis = duration.subsec_millis();
169
170            if secs > 0 {
171                // format seconds
172                secs.fmt(f)?;
173
174                // pad milliseconds
175                if millis < 10 {
176                    f.write_str("00")?;
177                } else if millis < 100 {
178                    f.write_char('0')?;
179                }
180            }
181
182            // format milliseconds
183            millis.fmt(f)?;
184
185            f.write_char('\n')?;
186        }
187
188        f.write_char('\n')?;
189        Ok(())
190    }
191}
192
193/// Gets the optional last event id from request.
194/// Typically this identifier represented as number or string.
195///
196/// ```
197/// let app = warp::sse::last_event_id::<u32>();
198///
199/// // The identifier is present
200/// async {
201///     assert_eq!(
202///         warp::test::request()
203///            .header("Last-Event-ID", "12")
204///            .filter(&app)
205///            .await
206///            .unwrap(),
207///         Some(12)
208///     );
209///
210///     // The identifier is missing
211///     assert_eq!(
212///        warp::test::request()
213///            .filter(&app)
214///            .await
215///            .unwrap(),
216///         None
217///     );
218///
219///     // The identifier is not a valid
220///     assert!(
221///        warp::test::request()
222///            .header("Last-Event-ID", "abc")
223///            .filter(&app)
224///            .await
225///            .is_err(),
226///     );
227///};
228/// ```
229pub fn last_event_id<T>() -> impl Filter<Extract = One<Option<T>>, Error = Rejection> + Copy
230where
231    T: FromStr + Send + Sync + 'static,
232{
233    header::optional("last-event-id")
234}
235
236/// Server-sent events reply
237///
238/// This function converts stream of server events into a `Reply` with:
239///
240/// - Status of `200 OK`
241/// - Header `content-type: text/event-stream`
242/// - Header `cache-control: no-cache`.
243///
244/// # Example
245///
246/// ```
247///
248/// use std::time::Duration;
249/// use futures_util::Stream;
250/// use futures_util::stream::iter;
251/// use std::convert::Infallible;
252/// use warp::{Filter, sse::Event};
253/// use serde_derive::Serialize;
254///
255/// #[derive(Serialize)]
256/// struct Msg {
257///     from: u32,
258///     text: String,
259/// }
260///
261/// fn event_stream() -> impl Stream<Item = Result<Event, Infallible>> {
262///         iter(vec![
263///             // Unnamed event with data only
264///             Ok(Event::default().data("payload")),
265///             // Named event with ID and retry timeout
266///             Ok(
267///                 Event::default().data("other message\nwith next line")
268///                 .event("chat")
269///                 .id(1.to_string())
270///                 .retry(Duration::from_millis(15000))
271///             ),
272///             // Event with JSON data
273///             Ok(
274///                 Event::default().id(2.to_string())
275///                 .json_data(Msg {
276///                     from: 2,
277///                     text: "hello".into(),
278///                 }).unwrap(),
279///             )
280///         ])
281/// }
282///
283/// async {
284///     let app = warp::path("sse").and(warp::get()).map(|| {
285///        warp::sse::reply(event_stream())
286///     });
287///
288///     let res = warp::test::request()
289///         .method("GET")
290///         .header("Connection", "Keep-Alive")
291///         .path("/sse")
292///         .reply(&app)
293///         .await
294///         .into_body();
295///
296///     assert_eq!(
297///         res,
298///         r#"data:payload
299///
300/// event:chat
301/// data:other message
302/// data:with next line
303/// id:1
304/// retry:15000
305///
306/// data:{"from":2,"text":"hello"}
307/// id:2
308///
309/// "#
310///     );
311/// };
312/// ```
313pub fn reply<S>(event_stream: S) -> impl Reply
314where
315    S: TryStream<Ok = Event> + Send + 'static,
316    S::Error: StdError + Send + Sync + 'static,
317{
318    SseReply { event_stream }
319}
320
321#[allow(missing_debug_implementations)]
322struct SseReply<S> {
323    event_stream: S,
324}
325
326impl<S> Reply for SseReply<S>
327where
328    S: TryStream<Ok = Event> + Send + 'static,
329    S::Error: StdError + Send + Sync + 'static,
330{
331    #[inline]
332    fn into_response(self) -> Response {
333        let body_stream = self
334            .event_stream
335            .map_err(|error| {
336                // FIXME: error logging
337                log::error!("sse stream error: {}", error);
338                SseError
339            })
340            .into_stream()
341            .and_then(|event| future::ready(Ok(event.to_string())));
342
343        let mut res = Response::new(Body::wrap_stream(body_stream));
344        // Set appropriate content type
345        res.headers_mut()
346            .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
347        // Disable response body caching
348        res.headers_mut()
349            .insert(CACHE_CONTROL, HeaderValue::from_static("no-cache"));
350        res
351    }
352}
353
354/// Configure the interval between keep-alive messages, the content
355/// of each message, and the associated stream.
356#[derive(Debug)]
357pub struct KeepAlive {
358    comment_text: Cow<'static, str>,
359    max_interval: Duration,
360}
361
362impl KeepAlive {
363    /// Customize the interval between keep-alive messages.
364    ///
365    /// Default is 15 seconds.
366    pub fn interval(mut self, time: Duration) -> Self {
367        self.max_interval = time;
368        self
369    }
370
371    /// Customize the text of the keep-alive message.
372    ///
373    /// Default is an empty comment.
374    pub fn text(mut self, text: impl Into<Cow<'static, str>>) -> Self {
375        self.comment_text = text.into();
376        self
377    }
378
379    /// Wrap an event stream with keep-alive functionality.
380    ///
381    /// See [`keep_alive`] for more.
382    pub fn stream<S>(
383        self,
384        event_stream: S,
385    ) -> impl TryStream<Ok = Event, Error = impl StdError + Send + Sync + 'static> + Send + 'static
386    where
387        S: TryStream<Ok = Event> + Send + 'static,
388        S::Error: StdError + Send + Sync + 'static,
389    {
390        let alive_timer = time::sleep(self.max_interval);
391        SseKeepAlive {
392            event_stream,
393            comment_text: self.comment_text,
394            max_interval: self.max_interval,
395            alive_timer,
396        }
397    }
398}
399
400#[allow(missing_debug_implementations)]
401#[pin_project]
402struct SseKeepAlive<S> {
403    #[pin]
404    event_stream: S,
405    comment_text: Cow<'static, str>,
406    max_interval: Duration,
407    #[pin]
408    alive_timer: Sleep,
409}
410
411/// Keeps event source connection alive when no events sent over a some time.
412///
413/// Some proxy servers may drop HTTP connection after a some timeout of inactivity.
414/// This function helps to prevent such behavior by sending comment events every
415/// `keep_interval` of inactivity.
416///
417/// By default the comment is `:` (an empty comment) and the time interval between
418/// events is 15 seconds. Both may be customized using the builder pattern
419/// as shown below.
420///
421/// ```
422/// use std::time::Duration;
423/// use std::convert::Infallible;
424/// use futures_util::StreamExt;
425/// use tokio::time::interval;
426/// use tokio_stream::wrappers::IntervalStream;
427/// use warp::{Filter, Stream, sse::Event};
428///
429/// // create server-sent event
430/// fn sse_counter(counter: u64) ->  Result<Event, Infallible> {
431///     Ok(Event::default().data(counter.to_string()))
432/// }
433///
434/// fn main() {
435///     let routes = warp::path("ticks")
436///         .and(warp::get())
437///         .map(|| {
438///             let mut counter: u64 = 0;
439///             let interval = interval(Duration::from_secs(15));
440///             let stream = IntervalStream::new(interval);
441///             let event_stream = stream.map(move |_| {
442///                 counter += 1;
443///                 sse_counter(counter)
444///             });
445///             // reply using server-sent events
446///             let stream = warp::sse::keep_alive()
447///                 .interval(Duration::from_secs(5))
448///                 .text("thump".to_string())
449///                 .stream(event_stream);
450///             warp::sse::reply(stream)
451///         });
452/// }
453/// ```
454///
455/// See [notes](https://www.w3.org/TR/2009/WD-eventsource-20090421/#notes).
456pub fn keep_alive() -> KeepAlive {
457    KeepAlive {
458        comment_text: Cow::Borrowed(""),
459        max_interval: Duration::from_secs(15),
460    }
461}
462
463impl<S> Stream for SseKeepAlive<S>
464where
465    S: TryStream<Ok = Event> + Send + 'static,
466    S::Error: StdError + Send + Sync + 'static,
467{
468    type Item = Result<Event, SseError>;
469
470    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
471        let mut pin = self.project();
472        match pin.event_stream.try_poll_next(cx) {
473            Poll::Pending => match Pin::new(&mut pin.alive_timer).poll(cx) {
474                Poll::Pending => Poll::Pending,
475                Poll::Ready(_) => {
476                    // restart timer
477                    pin.alive_timer
478                        .reset(tokio::time::Instant::now() + *pin.max_interval);
479                    let comment_str = pin.comment_text.clone();
480                    let event = Event::default().comment(comment_str);
481                    Poll::Ready(Some(Ok(event)))
482                }
483            },
484            Poll::Ready(Some(Ok(event))) => {
485                // restart timer
486                pin.alive_timer
487                    .reset(tokio::time::Instant::now() + *pin.max_interval);
488                Poll::Ready(Some(Ok(event)))
489            }
490            Poll::Ready(None) => Poll::Ready(None),
491            Poll::Ready(Some(Err(error))) => {
492                log::error!("sse::keep error: {}", error);
493                Poll::Ready(Some(Err(SseError)))
494            }
495        }
496    }
497}
498
mod sealed {
    use super::*;

    /// Opaque error yielded by the SSE streams in this module after the underlying stream
    /// error has been logged. It carries no details of its own.
    ///
    /// It lives in a private module, so code outside this file cannot name it.
    #[derive(Debug)]
    pub struct SseError;

    impl StdError for SseError {}

    impl fmt::Display for SseError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("sse error")
        }
    }
}