1use std::convert::TryFrom;
85use std::error::Error as StdError;
86use std::fmt;
87use std::future::Future;
88use std::net::SocketAddr;
89#[cfg(feature = "websocket")]
90use std::pin::Pin;
91#[cfg(feature = "websocket")]
92use std::task::Context;
93#[cfg(feature = "websocket")]
94use std::task::{self, Poll};
95
96use bytes::Bytes;
97#[cfg(feature = "websocket")]
98use futures_channel::mpsc;
99#[cfg(feature = "websocket")]
100use futures_util::StreamExt;
101use futures_util::{future, FutureExt, TryFutureExt};
102use http::{
103 header::{HeaderName, HeaderValue},
104 Response,
105};
106use serde::Serialize;
107#[cfg(feature = "websocket")]
108use tokio::sync::oneshot;
109
110use crate::filter::Filter;
111#[cfg(feature = "websocket")]
112use crate::filters::ws::Message;
113use crate::reject::IsReject;
114use crate::reply::Reply;
115use crate::route::{self, Route};
116use crate::Request;
117#[cfg(feature = "websocket")]
118use crate::{Sink, Stream};
119
120use self::inner::OneOrTuple;
121
122pub fn request() -> RequestBuilder {
124 RequestBuilder {
125 remote_addr: None,
126 req: Request::default(),
127 }
128}
129
#[cfg(feature = "websocket")]
/// Starts building a WebSocket handshake test.
///
/// The underlying request starts from the same default request as
/// [`request`]; call `handshake` on the returned builder to connect.
pub fn ws() -> WsBuilder {
    let req = request();
    WsBuilder { req }
}
135
/// A builder for a test [`Request`], plus the remote address to report for it.
///
/// Created with [`request`]. Every setter consumes and returns the builder,
/// so calls can be chained before the request is run against a filter.
#[must_use = "RequestBuilder does nothing on its own"]
#[derive(Debug)]
pub struct RequestBuilder {
    // Passed to `Route::new` together with `req`; stays `None` unless
    // `remote_addr` is called.
    remote_addr: Option<SocketAddr>,
    // The request being built; consumed when a filter is applied.
    req: Request,
}
145
#[cfg(feature = "websocket")]
/// A builder for a WebSocket handshake test, created with [`ws`].
///
/// Wraps a [`RequestBuilder`]. `handshake` adds the upgrade headers to it and
/// performs the handshake against a filter served on a local ephemeral port.
#[must_use = "WsBuilder does nothing on its own"]
#[derive(Debug)]
pub struct WsBuilder {
    // The underlying request; the handshake headers are added in `handshake`.
    req: RequestBuilder,
}
155
#[cfg(feature = "websocket")]
/// The client side of a WebSocket connection established by
/// `WsBuilder::handshake`.
///
/// Messages travel to and from a background task over unbounded channels.
pub struct WsClient {
    // Outgoing messages; the background task forwards them to the socket.
    tx: mpsc::UnboundedSender<crate::ws::Message>,
    // Incoming messages (or errors) forwarded by the background task.
    rx: mpsc::UnboundedReceiver<Result<crate::ws::Message, crate::error::Error>>,
}
162
/// An error reported by the WebSocket test client or by its handshake.
#[derive(Debug)]
pub struct WsError {
    // The underlying cause; included in this error's `Display` output.
    cause: Box<dyn StdError + Send + Sync>,
}
168
169impl RequestBuilder {
170 pub fn method(mut self, method: &str) -> Self {
186 *self.req.method_mut() = method.parse().expect("valid method");
187 self
188 }
189
190 pub fn path(mut self, p: &str) -> Self {
206 let uri = p.parse().expect("test request path invalid");
207 *self.req.uri_mut() = uri;
208 self
209 }
210
211 pub fn header<K, V>(mut self, key: K, value: V) -> Self
225 where
226 HeaderName: TryFrom<K>,
227 HeaderValue: TryFrom<V>,
228 {
229 let name: HeaderName = TryFrom::try_from(key)
230 .map_err(|_| ())
231 .expect("invalid header name");
232 let value = TryFrom::try_from(value)
233 .map_err(|_| ())
234 .expect("invalid header value");
235 self.req.headers_mut().insert(name, value);
236 self
237 }
238
239 pub fn remote_addr(mut self, addr: SocketAddr) -> Self {
251 self.remote_addr = Some(addr);
252 self
253 }
254
255 pub fn extension<T>(mut self, ext: T) -> Self
257 where
258 T: Send + Sync + 'static,
259 {
260 self.req.extensions_mut().insert(ext);
261 self
262 }
263
264 pub fn body(mut self, body: impl AsRef<[u8]>) -> Self {
275 let body = body.as_ref().to_vec();
276 let len = body.len();
277 *self.req.body_mut() = body.into();
278 self.header("content-length", len.to_string())
279 }
280
281 pub fn json(mut self, val: &impl Serialize) -> Self {
290 let vec = serde_json::to_vec(val).expect("json() must serialize to JSON");
291 let len = vec.len();
292 *self.req.body_mut() = vec.into();
293 self.header("content-length", len.to_string())
294 .header("content-type", "application/json")
295 }
296
297 pub async fn filter<F>(self, f: &F) -> Result<<F::Extract as OneOrTuple>::Output, F::Error>
323 where
324 F: Filter,
325 F::Future: Send + 'static,
326 F::Extract: OneOrTuple + Send + 'static,
327 F::Error: Send + 'static,
328 {
329 self.apply_filter(f).await.map(|ex| ex.one_or_tuple())
330 }
331
332 pub async fn matches<F>(self, f: &F) -> bool
357 where
358 F: Filter,
359 F::Future: Send + 'static,
360 F::Extract: Send + 'static,
361 F::Error: Send + 'static,
362 {
363 self.apply_filter(f).await.is_ok()
364 }
365
366 pub async fn reply<F>(self, f: &F) -> Response<Bytes>
370 where
371 F: Filter + 'static,
372 F::Extract: Reply + Send,
373 F::Error: IsReject + Send,
374 {
375 assert!(!route::is_set(), "nested test filter calls");
377
378 let route = Route::new(self.req, self.remote_addr);
379 let mut fut = Box::pin(
380 route::set(&route, move || f.filter(crate::filter::Internal)).then(|result| {
381 let res = match result {
382 Ok(rep) => rep.into_response(),
383 Err(rej) => {
384 tracing::debug!("rejected: {:?}", rej);
385 rej.into_response()
386 }
387 };
388 let (parts, body) = res.into_parts();
389 hyper::body::to_bytes(body).map_ok(|chunk| Response::from_parts(parts, chunk))
390 }),
391 );
392
393 let fut = future::poll_fn(move |cx| route::set(&route, || fut.as_mut().poll(cx)));
394
395 fut.await.expect("reply shouldn't fail")
396 }
397
398 fn apply_filter<F>(self, f: &F) -> impl Future<Output = Result<F::Extract, F::Error>>
399 where
400 F: Filter,
401 F::Future: Send + 'static,
402 F::Extract: Send + 'static,
403 F::Error: Send + 'static,
404 {
405 assert!(!route::is_set(), "nested test filter calls");
406
407 let route = Route::new(self.req, self.remote_addr);
408 let mut fut = Box::pin(route::set(&route, move || {
409 f.filter(crate::filter::Internal)
410 }));
411 future::poll_fn(move |cx| route::set(&route, || fut.as_mut().poll(cx)))
412 }
413}
414
#[cfg(feature = "websocket")]
impl WsBuilder {
    /// Sets the URI path (and optional query string) of the handshake request.
    ///
    /// # Panics
    ///
    /// Panics if `p` cannot be parsed as a URI.
    pub fn path(self, p: &str) -> Self {
        WsBuilder {
            req: self.req.path(p),
        }
    }

    /// Inserts a header into the handshake request.
    ///
    /// # Panics
    ///
    /// Panics if `key` is not a valid header name or `value` is not a valid
    /// header value.
    pub fn header<K, V>(self, key: K, value: V) -> Self
    where
        HeaderName: TryFrom<K>,
        HeaderValue: TryFrom<V>,
    {
        WsBuilder {
            req: self.req.header(key, value),
        }
    }

    /// Serves `f` on an ephemeral port on 127.0.0.1 and performs a WebSocket
    /// handshake against it.
    ///
    /// After the upgrade, a background task moves messages between the
    /// socket and the returned [`WsClient`]. Incoming messages are forwarded
    /// until a close frame or a read error arrives.
    ///
    /// # Errors
    ///
    /// Returns a [`WsError`] if the HTTP request or the upgrade fails.
    ///
    /// # Panics
    ///
    /// Panics if the background task ends before reporting the handshake
    /// result.
    pub async fn handshake<F>(self, f: F) -> Result<WsClient, WsError>
    where
        F: Filter + Clone + Send + Sync + 'static,
        F::Extract: Reply + Send,
        F::Error: IsReject + Send,
    {
        let (upgraded_tx, upgraded_rx) = oneshot::channel();
        let (wr_tx, wr_rx) = mpsc::unbounded();
        let (rd_tx, rd_rx) = mpsc::unbounded();

        tokio::spawn(async move {
            use tokio_tungstenite::tungstenite::protocol;

            let (addr, srv) = crate::serve(f).bind_ephemeral(([127, 0, 0, 1], 0));

            // Client upgrade headers, using a fixed `sec-websocket-key`.
            let mut req = self
                .req
                .header("connection", "upgrade")
                .header("upgrade", "websocket")
                .header("sec-websocket-version", "13")
                .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
                .req;

            // Aim the request at the bound server, keeping path and query.
            let query_string = match req.uri().query() {
                Some(q) => format!("?{}", q),
                None => String::from(""),
            };

            let uri = format!("http://{}{}{}", addr, req.uri().path(), query_string)
                .parse()
                .expect("addr + path is valid URI");

            *req.uri_mut() = uri;

            tokio::spawn(srv);

            let upgrade = ::hyper::Client::builder()
                .build(AddrConnect(addr))
                .request(req)
                .and_then(hyper::upgrade::on);

            let upgraded = match upgrade.await {
                Ok(up) => {
                    let _ = upgraded_tx.send(Ok(()));
                    up
                }
                Err(err) => {
                    let _ = upgraded_tx.send(Err(err));
                    return;
                }
            };
            let ws = crate::ws::WebSocket::from_raw_socket(
                upgraded,
                protocol::Role::Client,
                Default::default(),
            )
            .await;

            let (tx, rx) = ws.split();
            // Send whatever the client queues on `wr_tx` to the socket.
            let write = wr_rx.map(Ok).forward(tx).map(|_| ());

            // Forward incoming messages to the client until a close frame or
            // a read error.
            let read = rx
                .take_while(|result| match result {
                    Err(_) => future::ready(false),
                    Ok(m) => future::ready(!m.is_close()),
                })
                .for_each(move |item| {
                    // The test may already have dropped its `WsClient` (and
                    // with it the receiver). That is not an error, so the
                    // message is discarded instead of panicking this task.
                    let _ = rd_tx.unbounded_send(item);
                    future::ready(())
                });

            future::join(write, read).await;
        });

        match upgraded_rx.await {
            Ok(Ok(())) => Ok(WsClient {
                tx: wr_tx,
                rx: rd_rx,
            }),
            Ok(Err(err)) => Err(WsError::new(err)),
            Err(_canceled) => panic!("websocket handshake thread panicked"),
        }
    }
}
570
#[cfg(feature = "websocket")]
impl WsClient {
    /// Sends a text message.
    ///
    /// # Panics
    ///
    /// Panics if the channel to the background connection task is closed.
    pub async fn send_text(&mut self, text: impl Into<String>) {
        let msg = crate::ws::Message::text(text);
        self.send(msg).await;
    }

    /// Sends a message.
    ///
    /// # Panics
    ///
    /// Panics if the channel to the background connection task is closed.
    pub async fn send(&mut self, msg: crate::ws::Message) {
        self.tx.unbounded_send(msg).unwrap();
    }

    /// Receives the next message.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection yielded an error, or a `"closed"`
    /// error once no more messages will arrive.
    pub async fn recv(&mut self) -> Result<crate::filters::ws::Message, WsError> {
        match self.rx.next().await {
            Some(Ok(msg)) => Ok(msg),
            Some(Err(err)) => Err(WsError::new(err)),
            None => Err(WsError::new("closed")),
        }
    }

    /// Waits for the connection to close without any further message.
    ///
    /// # Errors
    ///
    /// Returns an error if a message or an error arrives instead.
    pub async fn recv_closed(&mut self) -> Result<(), WsError> {
        match self.rx.next().await {
            None => Ok(()),
            Some(Ok(msg)) => Err(WsError::new(format!("received message: {:?}", msg))),
            Some(Err(err)) => Err(WsError::new(err)),
        }
    }

    /// Projects the pinned client onto its (`Unpin`) outgoing sender.
    fn pinned_tx(self: Pin<&mut Self>) -> Pin<&mut mpsc::UnboundedSender<crate::ws::Message>> {
        Pin::new(&mut self.get_mut().tx)
    }
}
615
#[cfg(feature = "websocket")]
impl fmt::Debug for WsClient {
    /// Prints just the type name; the channel handles carry nothing useful.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same output as an empty `debug_struct("WsClient").finish()`.
        f.write_str("WsClient")
    }
}
622
#[cfg(feature = "websocket")]
impl Sink<crate::ws::Message> for WsClient {
    type Error = WsError;

    /// Delegates to the outgoing channel's readiness.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let tx = self.pinned_tx();
        tx.poll_ready(cx).map_err(WsError::new)
    }

    /// Queues `message` on the outgoing channel.
    fn start_send(self: Pin<&mut Self>, message: Message) -> Result<(), Self::Error> {
        let tx = self.pinned_tx();
        tx.start_send(message).map_err(WsError::new)
    }

    /// Delegates flushing to the outgoing channel.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let tx = self.pinned_tx();
        tx.poll_flush(cx).map_err(WsError::new)
    }

    /// Closes the outgoing channel.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let tx = self.pinned_tx();
        tx.poll_close(cx).map_err(WsError::new)
    }
}
652
#[cfg(feature = "websocket")]
impl Stream for WsClient {
    type Item = Result<crate::ws::Message, WsError>;

    /// Yields incoming messages, wrapping any connection error in a `WsError`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let rx = Pin::new(&mut self.get_mut().rx);
        rx.poll_next(cx)
            .map(|next| next.map(|result| result.map_err(WsError::new)))
    }
}
667
#[cfg(feature = "websocket")]
impl WsError {
    /// Wraps any error convertible into a boxed error, including plain
    /// `&str` and `String` messages.
    fn new<E: Into<Box<dyn StdError + Send + Sync>>>(cause: E) -> Self {
        let boxed: Box<dyn StdError + Send + Sync> = cause.into();
        Self { cause: boxed }
    }
}
678
679impl fmt::Display for WsError {
680 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
681 write!(f, "websocket error: {}", self.cause)
682 }
683}
684
// Lets `WsError` be used as a standard error. The cause is already part of
// the `Display` output, so `source` is left at its default.
impl StdError for WsError {
    // Deprecated `Error` method; returns a fixed summary string.
    fn description(&self) -> &str {
        "websocket error"
    }
}
690
#[cfg(feature = "websocket")]
/// The connector handed to the hyper client in `WsBuilder::handshake`.
///
/// It always opens a TCP connection to the wrapped address and ignores the
/// request URI.
#[derive(Clone)]
struct AddrConnect(SocketAddr);
696
#[cfg(feature = "websocket")]
impl tower_service::Service<::http::Uri> for AddrConnect {
    type Response = ::tokio::net::TcpStream;
    type Error = ::std::io::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always ready: every call opens its own connection.
    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Connects to the stored address; the requested URI is ignored.
    fn call(&mut self, _uri: ::http::Uri) -> Self::Future {
        let target = self.0;
        Box::pin(async move { tokio::net::TcpStream::connect(target).await })
    }
}
711
// Support for `RequestBuilder::filter`: converts a filter's extracted tuple
// into the value handed back to the test.
mod inner {
    /// Converts an extracted tuple into the value returned to the caller.
    ///
    /// `()` stays `()`, a one-element tuple `(T,)` becomes `T`, and tuples
    /// with two to sixteen elements are returned unchanged.
    pub trait OneOrTuple {
        /// The converted value.
        type Output;

        /// Performs the conversion.
        fn one_or_tuple(self) -> Self::Output;
    }

    impl OneOrTuple for () {
        type Output = ();
        fn one_or_tuple(self) -> Self::Output {}
    }

    // Given N type-parameter names, implements `OneOrTuple` for tuples of
    // every arity from 1 to N. The first arm is the base case: the
    // one-element tuple unwraps to its value. The second arm recurses on the
    // tail of the list, then implements the trait for the full list as an
    // identity conversion.
    macro_rules! one_or_tuple {
        ($type1:ident) => {
            impl<$type1> OneOrTuple for ($type1,) {
                type Output = $type1;
                fn one_or_tuple(self) -> Self::Output {
                    self.0
                }
            }
        };
        ($type1:ident, $( $type:ident ),*) => {
            one_or_tuple!($( $type ),*);

            impl<$type1, $($type),*> OneOrTuple for ($type1, $($type),*) {
                type Output = Self;
                fn one_or_tuple(self) -> Self::Output {
                    self
                }
            }
        }
    }

    // Covers tuples of arity 1 through 16.
    one_or_tuple! {
        T1,
        T2,
        T3,
        T4,
        T5,
        T6,
        T7,
        T8,
        T9,
        T10,
        T11,
        T12,
        T13,
        T14,
        T15,
        T16
    }
}