warp/
test.rs

1//! Test utilities to test your filters.
2//!
3//! [`Filter`](../trait.Filter.html)s can be easily tested without starting up an HTTP
4//! server, by making use of the [`RequestBuilder`](./struct.RequestBuilder.html) in this
5//! module.
6//!
7//! # Testing Filters
8//!
9//! It's easy to test filters, especially if smaller filters are used to build
10//! up your full set. Consider these example filters:
11//!
12//! ```
13//! use warp::Filter;
14//!
15//! fn sum() -> impl Filter<Extract = (u32,), Error = warp::Rejection> + Copy {
16//!     warp::path::param()
17//!         .and(warp::path::param())
18//!         .map(|x: u32, y: u32| {
19//!             x + y
20//!         })
21//! }
22//!
23//! fn math() -> impl Filter<Extract = (String,), Error = warp::Rejection> + Copy {
24//!     warp::post()
25//!         .and(sum())
26//!         .map(|z: u32| {
27//!             format!("Sum = {}", z)
28//!         })
29//! }
30//! ```
31//!
32//! We can test some requests against the `sum` filter like this:
33//!
34//! ```
35//! # use warp::Filter;
36//! #[tokio::test]
37//! async fn test_sum() {
38//! #    let sum = || warp::any().map(|| 3);
39//!     let filter = sum();
40//!
41//!     // Execute `sum` and get the `Extract` back.
42//!     let value = warp::test::request()
43//!         .path("/1/2")
44//!         .filter(&filter)
45//!         .await
46//!         .unwrap();
47//!     assert_eq!(value, 3);
48//!
49//!     // Or simply test if a request matches (doesn't reject).
50//!     assert!(
51//!         warp::test::request()
52//!             .path("/1/-5")
53//!             .matches(&filter)
54//!             .await
55//!     );
56//! }
57//! ```
58//!
59//! If the filter returns something that implements `Reply`, and thus can be
60//! turned into a response sent back to the client, we can test what exact
61//! response is returned. The `math` filter uses the `sum` filter, but returns
62//! a `String` that can be turned into a response.
63//!
//! ```
//! # use warp::Filter;
//! #[tokio::test]
//! async fn test_math() {
//! #    let math = || warp::any().map(warp::reply);
//!     let filter = math();
//!
//!     let res = warp::test::request()
//!         .path("/1/2")
//!         .reply(&filter)
//!         .await;
//!     assert_eq!(res.status(), 405, "GET is not allowed");
//!
//!     let res = warp::test::request()
//!         .method("POST")
//!         .path("/1/2")
//!         .reply(&filter)
//!         .await;
//!     assert_eq!(res.status(), 200);
//!     assert_eq!(res.body(), "Sum = 3");
//! }
//! ```
84use std::convert::TryFrom;
85use std::error::Error as StdError;
86use std::fmt;
87use std::future::Future;
88use std::net::SocketAddr;
89#[cfg(feature = "websocket")]
90use std::pin::Pin;
91#[cfg(feature = "websocket")]
92use std::task::Context;
93#[cfg(feature = "websocket")]
94use std::task::{self, Poll};
95
96use bytes::Bytes;
97#[cfg(feature = "websocket")]
98use futures_channel::mpsc;
99#[cfg(feature = "websocket")]
100use futures_util::StreamExt;
101use futures_util::{future, FutureExt, TryFutureExt};
102use http::{
103    header::{HeaderName, HeaderValue},
104    Response,
105};
106use serde::Serialize;
107#[cfg(feature = "websocket")]
108use tokio::sync::oneshot;
109
110use crate::filter::Filter;
111#[cfg(feature = "websocket")]
112use crate::filters::ws::Message;
113use crate::reject::IsReject;
114use crate::reply::Reply;
115use crate::route::{self, Route};
116use crate::Request;
117#[cfg(feature = "websocket")]
118use crate::{Sink, Stream};
119
120use self::inner::OneOrTuple;
121
122/// Starts a new test `RequestBuilder`.
123pub fn request() -> RequestBuilder {
124    RequestBuilder {
125        remote_addr: None,
126        req: Request::default(),
127    }
128}
129
/// Creates a fresh [`WsBuilder`] for exercising Websocket filters in tests.
///
/// The builder wraps the same starting request that [`request()`] produces.
#[cfg(feature = "websocket")]
pub fn ws() -> WsBuilder {
    let req = request();
    WsBuilder { req }
}
135
136/// A request builder for testing filters.
137///
138/// See [module documentation](crate::test) for an overview.
139#[must_use = "RequestBuilder does nothing on its own"]
140#[derive(Debug)]
141pub struct RequestBuilder {
142    remote_addr: Option<SocketAddr>,
143    req: Request,
144}
145
/// Builds a Websocket handshake request that can be run against a filter in
/// a test.
///
/// See [module documentation](crate::test) for an overview.
#[cfg(feature = "websocket")]
#[derive(Debug)]
#[must_use = "WsBuilder does nothing on its own"]
pub struct WsBuilder {
    // The underlying plain request; the upgrade headers are added when the
    // handshake is performed.
    req: RequestBuilder,
}
155
/// The client half of a Websocket connection established against a filter
/// under test.
#[cfg(feature = "websocket")]
pub struct WsClient {
    // Messages queued here are forwarded to the server by the background task.
    tx: mpsc::UnboundedSender<Message>,
    // Messages (or errors) read from the server by the background task.
    rx: mpsc::UnboundedReceiver<Result<Message, crate::error::Error>>,
}
162
/// An error from Websocket filter tests.
///
/// Wraps the underlying failure as a boxed error; the `Display` output is
/// the cause's message prefixed with `websocket error: `.
#[derive(Debug)]
pub struct WsError {
    // The underlying error that caused the Websocket test operation to fail.
    cause: Box<dyn StdError + Send + Sync>,
}
168
169impl RequestBuilder {
170    /// Sets the method of this builder.
171    ///
172    /// The default if not set is `GET`.
173    ///
174    /// # Example
175    ///
176    /// ```
177    /// let req = warp::test::request()
178    ///     .method("POST");
179    /// ```
180    ///
181    /// # Panic
182    ///
183    /// This panics if the passed string is not able to be parsed as a valid
184    /// `Method`.
185    pub fn method(mut self, method: &str) -> Self {
186        *self.req.method_mut() = method.parse().expect("valid method");
187        self
188    }
189
190    /// Sets the request path of this builder.
191    ///
192    /// The default is not set is `/`.
193    ///
194    /// # Example
195    ///
196    /// ```
197    /// let req = warp::test::request()
198    ///     .path("/todos/33");
199    /// ```
200    ///
201    /// # Panic
202    ///
203    /// This panics if the passed string is not able to be parsed as a valid
204    /// `Uri`.
205    pub fn path(mut self, p: &str) -> Self {
206        let uri = p.parse().expect("test request path invalid");
207        *self.req.uri_mut() = uri;
208        self
209    }
210
211    /// Set a header for this request.
212    ///
213    /// # Example
214    ///
215    /// ```
216    /// let req = warp::test::request()
217    ///     .header("accept", "application/json");
218    /// ```
219    ///
220    /// # Panic
221    ///
222    /// This panics if the passed strings are not able to be parsed as a valid
223    /// `HeaderName` and `HeaderValue`.
224    pub fn header<K, V>(mut self, key: K, value: V) -> Self
225    where
226        HeaderName: TryFrom<K>,
227        HeaderValue: TryFrom<V>,
228    {
229        let name: HeaderName = TryFrom::try_from(key)
230            .map_err(|_| ())
231            .expect("invalid header name");
232        let value = TryFrom::try_from(value)
233            .map_err(|_| ())
234            .expect("invalid header value");
235        self.req.headers_mut().insert(name, value);
236        self
237    }
238
239    /// Set the remote address of this request
240    ///
241    /// Default is no remote address.
242    ///
243    /// # Example
244    /// ```
245    /// use std::net::{IpAddr, Ipv4Addr, SocketAddr};
246    ///
247    /// let req = warp::test::request()
248    ///     .remote_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
249    /// ```
250    pub fn remote_addr(mut self, addr: SocketAddr) -> Self {
251        self.remote_addr = Some(addr);
252        self
253    }
254
255    /// Add a type to the request's `http::Extensions`.
256    pub fn extension<T>(mut self, ext: T) -> Self
257    where
258        T: Send + Sync + 'static,
259    {
260        self.req.extensions_mut().insert(ext);
261        self
262    }
263
264    /// Set the bytes of this request body.
265    ///
266    /// Default is an empty body.
267    ///
268    /// # Example
269    ///
270    /// ```
271    /// let req = warp::test::request()
272    ///     .body("foo=bar&baz=quux");
273    /// ```
274    pub fn body(mut self, body: impl AsRef<[u8]>) -> Self {
275        let body = body.as_ref().to_vec();
276        let len = body.len();
277        *self.req.body_mut() = body.into();
278        self.header("content-length", len.to_string())
279    }
280
281    /// Set the bytes of this request body by serializing a value into JSON.
282    ///
283    /// # Example
284    ///
285    /// ```
286    /// let req = warp::test::request()
287    ///     .json(&true);
288    /// ```
289    pub fn json(mut self, val: &impl Serialize) -> Self {
290        let vec = serde_json::to_vec(val).expect("json() must serialize to JSON");
291        let len = vec.len();
292        *self.req.body_mut() = vec.into();
293        self.header("content-length", len.to_string())
294            .header("content-type", "application/json")
295    }
296
297    /// Tries to apply the `Filter` on this request.
298    ///
299    /// # Example
300    ///
301    /// ```no_run
302    /// async {
303    ///     let param = warp::path::param::<u32>();
304    ///
305    ///     let ex = warp::test::request()
306    ///         .path("/41")
307    ///         .filter(&param)
308    ///         .await
309    ///         .unwrap();
310    ///
311    ///     assert_eq!(ex, 41);
312    ///
313    ///     assert!(
314    ///         warp::test::request()
315    ///             .path("/foo")
316    ///             .filter(&param)
317    ///             .await
318    ///             .is_err()
319    ///     );
320    ///};
321    /// ```
322    pub async fn filter<F>(self, f: &F) -> Result<<F::Extract as OneOrTuple>::Output, F::Error>
323    where
324        F: Filter,
325        F::Future: Send + 'static,
326        F::Extract: OneOrTuple + Send + 'static,
327        F::Error: Send + 'static,
328    {
329        self.apply_filter(f).await.map(|ex| ex.one_or_tuple())
330    }
331
332    /// Returns whether the `Filter` matches this request, or rejects it.
333    ///
334    /// # Example
335    ///
336    /// ```no_run
337    /// async {
338    ///     let get = warp::get();
339    ///     let post = warp::post();
340    ///
341    ///     assert!(
342    ///         warp::test::request()
343    ///             .method("GET")
344    ///             .matches(&get)
345    ///             .await
346    ///     );
347    ///
348    ///     assert!(
349    ///         !warp::test::request()
350    ///             .method("GET")
351    ///             .matches(&post)
352    ///             .await
353    ///     );
354    ///};
355    /// ```
356    pub async fn matches<F>(self, f: &F) -> bool
357    where
358        F: Filter,
359        F::Future: Send + 'static,
360        F::Extract: Send + 'static,
361        F::Error: Send + 'static,
362    {
363        self.apply_filter(f).await.is_ok()
364    }
365
366    /// Returns `Response` provided by applying the `Filter`.
367    ///
368    /// This requires that the supplied `Filter` return a [`Reply`].
369    pub async fn reply<F>(self, f: &F) -> Response<Bytes>
370    where
371        F: Filter + 'static,
372        F::Extract: Reply + Send,
373        F::Error: IsReject + Send,
374    {
375        // TODO: de-duplicate this and apply_filter()
376        assert!(!route::is_set(), "nested test filter calls");
377
378        let route = Route::new(self.req, self.remote_addr);
379        let mut fut = Box::pin(
380            route::set(&route, move || f.filter(crate::filter::Internal)).then(|result| {
381                let res = match result {
382                    Ok(rep) => rep.into_response(),
383                    Err(rej) => {
384                        tracing::debug!("rejected: {:?}", rej);
385                        rej.into_response()
386                    }
387                };
388                let (parts, body) = res.into_parts();
389                hyper::body::to_bytes(body).map_ok(|chunk| Response::from_parts(parts, chunk))
390            }),
391        );
392
393        let fut = future::poll_fn(move |cx| route::set(&route, || fut.as_mut().poll(cx)));
394
395        fut.await.expect("reply shouldn't fail")
396    }
397
398    fn apply_filter<F>(self, f: &F) -> impl Future<Output = Result<F::Extract, F::Error>>
399    where
400        F: Filter,
401        F::Future: Send + 'static,
402        F::Extract: Send + 'static,
403        F::Error: Send + 'static,
404    {
405        assert!(!route::is_set(), "nested test filter calls");
406
407        let route = Route::new(self.req, self.remote_addr);
408        let mut fut = Box::pin(route::set(&route, move || {
409            f.filter(crate::filter::Internal)
410        }));
411        future::poll_fn(move |cx| route::set(&route, || fut.as_mut().poll(cx)))
412    }
413}
414
#[cfg(feature = "websocket")]
impl WsBuilder {
    /// Sets the request path of this builder.
    ///
    /// The default if not set is `/`.
    ///
    /// # Example
    ///
    /// ```
    /// let req = warp::test::ws()
    ///     .path("/chat");
    /// ```
    ///
    /// # Panics
    ///
    /// This panics if the passed string is not able to be parsed as a valid
    /// `Uri`.
    pub fn path(self, p: &str) -> Self {
        WsBuilder {
            req: self.req.path(p),
        }
    }

    /// Set a header for this request.
    ///
    /// # Example
    ///
    /// ```
    /// let req = warp::test::ws()
    ///     .header("foo", "bar");
    /// ```
    ///
    /// # Panics
    ///
    /// This panics if the passed strings are not able to be parsed as a valid
    /// `HeaderName` and `HeaderValue`.
    pub fn header<K, V>(self, key: K, value: V) -> Self
    where
        HeaderName: TryFrom<K>,
        HeaderValue: TryFrom<V>,
    {
        WsBuilder {
            req: self.req.header(key, value),
        }
    }

    /// Execute this Websocket request against the provided filter.
    ///
    /// The filter is served on an ephemeral `127.0.0.1` port from a spawned
    /// task, and the request is sent to it with the Websocket upgrade headers
    /// added. If the handshake succeeds, returns a `WsClient` connected to
    /// that server.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use futures_util::future;
    /// use warp::Filter;
    /// #[tokio::main]
    /// # async fn main() {
    ///
    /// // Some route that accepts websockets (but drops them immediately).
    /// let route = warp::ws()
    ///     .map(|ws: warp::ws::Ws| {
    ///         ws.on_upgrade(|_| future::ready(()))
    ///     });
    ///
    /// let client = warp::test::ws()
    ///     .handshake(route)
    ///     .await
    ///     .expect("handshake");
    /// # }
    /// ```
    ///
    /// # Errors
    ///
    /// Returns a `WsError` if the HTTP request or the connection upgrade
    /// fails.
    ///
    /// # Panics
    ///
    /// Panics if the background task ends before it reports the outcome of
    /// the handshake.
    pub async fn handshake<F>(self, f: F) -> Result<WsClient, WsError>
    where
        F: Filter + Clone + Send + Sync + 'static,
        F::Extract: Reply + Send,
        F::Error: IsReject + Send,
    {
        // Reports the handshake outcome back to the caller.
        let (upgraded_tx, upgraded_rx) = oneshot::channel();
        // Client -> server messages.
        let (wr_tx, wr_rx) = mpsc::unbounded();
        // Server -> client messages.
        let (rd_tx, rd_rx) = mpsc::unbounded();

        tokio::spawn(async move {
            use tokio_tungstenite::tungstenite::protocol;

            let (addr, srv) = crate::serve(f).bind_ephemeral(([127, 0, 0, 1], 0));

            let mut req = self
                .req
                .header("connection", "upgrade")
                .header("upgrade", "websocket")
                .header("sec-websocket-version", "13")
                .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
                .req;

            let query_string = match req.uri().query() {
                Some(q) => format!("?{}", q),
                None => String::new(),
            };

            // Point the request at the ephemeral server, keeping the
            // configured path and query.
            let uri = format!("http://{}{}{}", addr, req.uri().path(), query_string)
                .parse()
                .expect("addr + path is valid URI");

            *req.uri_mut() = uri;

            tokio::spawn(srv);

            let upgrade = ::hyper::Client::builder()
                .build(AddrConnect(addr))
                .request(req)
                .and_then(hyper::upgrade::on);

            let upgraded = match upgrade.await {
                Ok(up) => {
                    let _ = upgraded_tx.send(Ok(()));
                    up
                }
                Err(err) => {
                    let _ = upgraded_tx.send(Err(err));
                    return;
                }
            };
            let ws = crate::ws::WebSocket::from_raw_socket(
                upgraded,
                protocol::Role::Client,
                Default::default(),
            )
            .await;

            let (tx, rx) = ws.split();
            let write = wr_rx.map(Ok).forward(tx).map(|_| ());

            // Stop reading at the first error or close frame.
            let read = rx
                .take_while(|result| match result {
                    Err(_) => future::ready(false),
                    Ok(m) => future::ready(!m.is_close()),
                })
                .for_each(move |item| {
                    // The test may already have dropped its `WsClient`; a
                    // failed send then only means nobody is listening, so it
                    // must not panic this background task.
                    let _ = rd_tx.unbounded_send(item);
                    future::ready(())
                });

            future::join(write, read).await;
        });

        match upgraded_rx.await {
            Ok(Ok(())) => Ok(WsClient {
                tx: wr_tx,
                rx: rd_rx,
            }),
            Ok(Err(err)) => Err(WsError::new(err)),
            Err(_canceled) => panic!("websocket handshake thread panicked"),
        }
    }
}
570
#[cfg(feature = "websocket")]
impl WsClient {
    /// Sends `text` to the server as a "text" websocket message.
    pub async fn send_text(&mut self, text: impl Into<String>) {
        let msg = Message::text(text);
        self.send(msg).await;
    }

    /// Queues a websocket message for delivery to the server.
    ///
    /// # Panics
    ///
    /// Panics if the message cannot be queued on the outgoing channel.
    pub async fn send(&mut self, msg: crate::ws::Message) {
        let queued = self.tx.unbounded_send(msg);
        queued.unwrap();
    }

    /// Waits for the next websocket message from the server.
    ///
    /// Returns an error if the server reported one, or a `closed` error once
    /// the incoming channel has ended.
    pub async fn recv(&mut self) -> Result<crate::filters::ws::Message, WsError> {
        match self.rx.next().await {
            Some(Ok(msg)) => Ok(msg),
            Some(Err(err)) => Err(WsError::new(err)),
            // websocket is closed
            None => Err(WsError::new("closed")),
        }
    }

    /// Asserts the server has closed the connection.
    ///
    /// Returns `Ok(())` when the incoming channel has ended, and an error if
    /// a message or an error arrives instead.
    pub async fn recv_closed(&mut self) -> Result<(), WsError> {
        match self.rx.next().await {
            // closed successfully
            None => Ok(()),
            Some(Ok(msg)) => Err(WsError::new(format!("received message: {:?}", msg))),
            Some(Err(err)) => Err(WsError::new(err)),
        }
    }

    /// Projects a pinned `WsClient` to its pinned outgoing sender.
    fn pinned_tx(self: Pin<&mut Self>) -> Pin<&mut mpsc::UnboundedSender<crate::ws::Message>> {
        Pin::new(&mut self.get_mut().tx)
    }
}
615
#[cfg(feature = "websocket")]
impl fmt::Debug for WsClient {
    /// Formats the client as an opaque `WsClient` struct without any fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("WsClient");
        repr.finish()
    }
}
622
// Every `Sink` operation is delegated to the outgoing channel sender, with
// the sender's error wrapped in a `WsError`.
#[cfg(feature = "websocket")]
impl Sink<crate::ws::Message> for WsClient {
    type Error = WsError;

    fn poll_ready(
        self: Pin<&mut Self>,
        context: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        let tx = self.pinned_tx();
        tx.poll_ready(context)
            .map(|outcome| outcome.map_err(WsError::new))
    }

    fn start_send(self: Pin<&mut Self>, message: Message) -> Result<(), Self::Error> {
        let tx = self.pinned_tx();
        match tx.start_send(message) {
            Ok(()) => Ok(()),
            Err(err) => Err(WsError::new(err)),
        }
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        context: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        let tx = self.pinned_tx();
        tx.poll_flush(context)
            .map(|outcome| outcome.map_err(WsError::new))
    }

    fn poll_close(
        self: Pin<&mut Self>,
        context: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        let tx = self.pinned_tx();
        tx.poll_close(context)
            .map(|outcome| outcome.map_err(WsError::new))
    }
}
652
// Yields whatever the incoming channel yields, with errors wrapped in a
// `WsError`.
#[cfg(feature = "websocket")]
impl Stream for WsClient {
    type Item = Result<crate::ws::Message, WsError>;

    fn poll_next(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let rx = Pin::new(&mut self.get_mut().rx);
        rx.poll_next(context)
            .map(|next| next.map(|item| item.map_err(WsError::new)))
    }
}
667
668// ===== impl WsError =====
669
#[cfg(feature = "websocket")]
impl WsError {
    /// Wraps anything convertible into a boxed error as a `WsError`.
    fn new<E>(cause: E) -> Self
    where
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        let cause = cause.into();
        WsError { cause }
    }
}
678
679impl fmt::Display for WsError {
680    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
681        write!(f, "websocket error: {}", self.cause)
682    }
683}
684
685impl StdError for WsError {
686    fn description(&self) -> &str {
687        "websocket error"
688    }
689}
690
691// ===== impl AddrConnect =====
692
/// Connector given to the hyper client in `WsBuilder::handshake`.
///
/// Wraps the socket address that its `Service` implementation connects to,
/// regardless of the URI it is called with.
#[cfg(feature = "websocket")]
#[derive(Clone)]
struct AddrConnect(SocketAddr);
696
#[cfg(feature = "websocket")]
impl tower_service::Service<::http::Uri> for AddrConnect {
    type Response = ::tokio::net::TcpStream;
    type Error = ::std::io::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always reports readiness.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let ready = Ok(());
        Poll::Ready(ready)
    }

    /// Opens a TCP connection to the wrapped address; the requested URI is
    /// ignored.
    fn call(&mut self, _uri: ::http::Uri) -> Self::Future {
        let target = self.0;
        let connecting = tokio::net::TcpStream::connect(target);
        Box::pin(connecting)
    }
}
711
mod inner {
    /// Collapses an extracted tuple into its most convenient form.
    ///
    /// The unit tuple stays `()`, a 1-tuple is unwrapped to its sole element,
    /// and tuples of 2 to 16 elements are returned unchanged.
    pub trait OneOrTuple {
        /// The type produced by `one_or_tuple`.
        type Output;

        /// Converts `self` into `Self::Output`.
        fn one_or_tuple(self) -> Self::Output;
    }

    impl OneOrTuple for () {
        type Output = ();

        fn one_or_tuple(self) -> Self::Output {
            self
        }
    }

    // Implements `OneOrTuple` for a tuple of the listed type parameters, then
    // recurses on the list minus its first entry until one parameter is left.
    macro_rules! one_or_tuple {
        // A single parameter: unwrap the 1-tuple.
        ($only:ident) => {
            impl<$only> OneOrTuple for ($only,) {
                type Output = $only;

                fn one_or_tuple(self) -> Self::Output {
                    let (value,) = self;
                    value
                }
            }
        };
        // Two or more parameters: the tuple is its own output.
        ($head:ident, $( $tail:ident ),*) => {
            impl<$head, $( $tail ),*> OneOrTuple for ($head, $( $tail ),*) {
                type Output = Self;

                fn one_or_tuple(self) -> Self::Output {
                    self
                }
            }

            one_or_tuple!($( $tail ),*);
        };
    }

    one_or_tuple!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
}