// warp/server.rs

1#[cfg(feature = "tls")]
2use crate::tls::TlsConfigBuilder;
3use std::convert::Infallible;
4use std::error::Error as StdError;
5use std::future::Future;
6use std::net::SocketAddr;
7#[cfg(feature = "tls")]
8use std::path::Path;
9
10use futures_util::{future, FutureExt, TryFuture, TryStream, TryStreamExt};
11use hyper::server::conn::AddrIncoming;
12use hyper::service::{make_service_fn, service_fn};
13use hyper::Server as HyperServer;
14use tokio::io::{AsyncRead, AsyncWrite};
15use tracing::Instrument;
16
17use crate::filter::Filter;
18use crate::reject::IsReject;
19use crate::reply::Reply;
20use crate::transport::Transport;
21
22/// Create a `Server` with the provided `Filter`.
23pub fn serve<F>(filter: F) -> Server<F>
24where
25    F: Filter + Clone + Send + Sync + 'static,
26    F::Extract: Reply,
27    F::Error: IsReject,
28{
29    Server {
30        pipeline: false,
31        filter,
32    }
33}
34
/// A Warp server, ready to filter requests.
///
/// Create one with [`serve`].
#[derive(Debug)]
pub struct Server<F> {
    /// Value passed to hyper's `http1_pipeline_flush` when the server is built
    /// (only turned on by `unstable_pipeline`).
    pipeline: bool,
    /// The filter every incoming request is run through.
    filter: F,
}
41
/// A Warp server, ready to filter requests over TLS.
///
/// Obtained from `Server::tls`. TLS settings accumulate in the builder until
/// the server is bound, at which point the configuration is built.
///
/// *This type requires the `"tls"` feature.*
#[cfg(feature = "tls")]
pub struct TlsServer<F> {
    /// The plain server whose filter and options are reused.
    server: Server<F>,
    /// Builder for the TLS configuration; consumed when binding.
    tls: TlsConfigBuilder,
}
50
// Expressing all the generic bounds needed to make this a reusable function is
// very complicated, so it is a macro instead.
//
// Expands to a hyper `make_service_fn`. For every new connection, the closure
// reads the peer address via `Transport::remote_addr`. It then returns a
// per-connection service that forwards each request, together with that
// address, to the warp service.
macro_rules! into_service {
    ($into:expr) => {{
        let warp_svc = crate::service($into);
        make_service_fn(move |conn| {
            let peer = Transport::remote_addr(conn);
            let warp_svc = warp_svc.clone();
            future::ok::<_, Infallible>(service_fn(move |request| {
                warp_svc.call_with_addr(request, peer)
            }))
        })
    }};
}
65
// Binds an `AddrIncoming` listener to `$addr` with nodelay enabled and
// evaluates to `(local_addr, listener)`. Bind failures are propagated with `?`,
// so this must be expanded inside a scope that returns a compatible `Result`.
macro_rules! addr_incoming {
    ($addr:expr) => {{
        let mut listener = AddrIncoming::bind($addr)?;
        listener.set_nodelay(true);
        (listener.local_addr(), listener)
    }};
}
74
// Shared body of `bind!` and `try_bind!`. It builds the warp service, binds the
// listener and wraps both in a hyper server, evaluating to
// `Result<(SocketAddr, server), _>`.
//
// The plain form's error type is `hyper::Error`. The `tls:` form also builds
// the TLS configuration (after binding), so its error type is a boxed error.
macro_rules! bind_inner {
    ($this:ident, $addr:expr) => {{
        let make_svc = into_service!($this.filter);
        let (bound, listener) = addr_incoming!($addr);
        let server = HyperServer::builder(listener)
            .http1_pipeline_flush($this.pipeline)
            .serve(make_svc);
        Ok::<_, hyper::Error>((bound, server))
    }};

    (tls: $this:ident, $addr:expr) => {{
        let make_svc = into_service!($this.server.filter);
        let (bound, listener) = addr_incoming!($addr);
        let tls_config = $this.tls.build()?;
        let acceptor = crate::tls::TlsAcceptor::new(tls_config, listener);
        let server = HyperServer::builder(acceptor)
            .http1_pipeline_flush($this.server.pipeline)
            .serve(make_svc);
        Ok::<_, Box<dyn std::error::Error + Send + Sync>>((bound, server))
    }};
}
95
// Converts `$addr` into an address and binds via `bind_inner!`. If binding
// fails, it panics with the address and the error; otherwise it evaluates to
// `(SocketAddr, server)`.
//
// The immediately-invoked closure gives the `?` operators inside
// `bind_inner!` a `Result`-returning scope to return from.
macro_rules! bind {
    ($this:ident, $addr:expr) => {{
        let addr = $addr.into();
        let bound = (|addr| bind_inner!($this, addr))(&addr);
        match bound {
            Ok(pair) => pair,
            Err(e) => panic!("error binding to {}: {}", addr, e),
        }
    }};

    (tls: $this:ident, $addr:expr) => {{
        let addr = $addr.into();
        let bound = (|addr| bind_inner!(tls: $this, addr))(&addr);
        match bound {
            Ok(pair) => pair,
            Err(e) => panic!("error binding to {}: {}", addr, e),
        }
    }};
}
111
// Fallible counterpart of `bind!`: it evaluates to the `Result` produced by
// `bind_inner!` instead of panicking. `$addr` is passed straight through, so
// callers supply a `&SocketAddr`.
//
// The closure gives the `?` operators inside `bind_inner!` a
// `Result`-returning scope to return from.
macro_rules! try_bind {
    ($this:ident, $addr:expr) => {{
        let attempt = |addr| bind_inner!($this, addr);
        attempt($addr)
    }};

    (tls: $this:ident, $addr:expr) => {{
        let attempt = |addr| bind_inner!(tls: $this, addr);
        attempt($addr)
    }};
}
121
// ===== impl Server =====
123
124impl<F> Server<F>
125where
126    F: Filter + Clone + Send + Sync + 'static,
127    <F::Future as TryFuture>::Ok: Reply,
128    <F::Future as TryFuture>::Error: IsReject,
129{
130    /// Run this `Server` forever on the current thread.
131    ///
132    /// # Panics
133    ///
134    /// Panics if we are unable to bind to the provided address.
135    pub async fn run(self, addr: impl Into<SocketAddr>) {
136        let (addr, fut) = self.bind_ephemeral(addr);
137        let span = tracing::info_span!("Server::run", ?addr);
138        tracing::info!(parent: &span, "listening on http://{}", addr);
139
140        fut.instrument(span).await;
141    }
142
143    /// Run this `Server` forever on the current thread with a specific stream
144    /// of incoming connections.
145    ///
146    /// This can be used for Unix Domain Sockets, or TLS, etc.
147    pub async fn run_incoming<I>(self, incoming: I)
148    where
149        I: TryStream + Send,
150        I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin,
151        I::Error: Into<Box<dyn StdError + Send + Sync>>,
152    {
153        self.run_incoming2(incoming.map_ok(crate::transport::LiftIo).into_stream())
154            .instrument(tracing::info_span!("Server::run_incoming"))
155            .await;
156    }
157
158    async fn run_incoming2<I>(self, incoming: I)
159    where
160        I: TryStream + Send,
161        I::Ok: Transport + Send + 'static + Unpin,
162        I::Error: Into<Box<dyn StdError + Send + Sync>>,
163    {
164        let fut = self.serve_incoming2(incoming);
165
166        tracing::info!("listening with custom incoming");
167
168        fut.await;
169    }
170
171    /// Bind to a socket address, returning a `Future` that can be
172    /// executed on the current runtime.
173    ///
174    /// # Panics
175    ///
176    /// Panics if we are unable to bind to the provided address.
177    pub fn bind(self, addr: impl Into<SocketAddr> + 'static) -> impl Future<Output = ()> + 'static {
178        let (_, fut) = self.bind_ephemeral(addr);
179        fut
180    }
181
182    /// Bind to a socket address, returning a `Future` that can be
183    /// executed on any runtime.
184    ///
185    /// In case we are unable to bind to the specified address, resolves to an
186    /// error and logs the reason.
187    pub async fn try_bind(self, addr: impl Into<SocketAddr>) {
188        let addr = addr.into();
189        let srv = match try_bind!(self, &addr) {
190            Ok((_, srv)) => srv,
191            Err(err) => {
192                tracing::error!("error binding to {}: {}", addr, err);
193                return;
194            }
195        };
196
197        srv.map(|result| {
198            if let Err(err) = result {
199                tracing::error!("server error: {}", err)
200            }
201        })
202        .await;
203    }
204
205    /// Bind to a possibly ephemeral socket address.
206    ///
207    /// Returns the bound address and a `Future` that can be executed on
208    /// the current runtime.
209    ///
210    /// # Panics
211    ///
212    /// Panics if we are unable to bind to the provided address.
213    pub fn bind_ephemeral(
214        self,
215        addr: impl Into<SocketAddr>,
216    ) -> (SocketAddr, impl Future<Output = ()> + 'static) {
217        let (addr, srv) = bind!(self, addr);
218        let srv = srv.map(|result| {
219            if let Err(err) = result {
220                tracing::error!("server error: {}", err)
221            }
222        });
223
224        (addr, srv)
225    }
226
227    /// Tried to bind a possibly ephemeral socket address.
228    ///
229    /// Returns a `Result` which fails in case we are unable to bind with the
230    /// underlying error.
231    ///
232    /// Returns the bound address and a `Future` that can be executed on
233    /// the current runtime.
234    pub fn try_bind_ephemeral(
235        self,
236        addr: impl Into<SocketAddr>,
237    ) -> Result<(SocketAddr, impl Future<Output = ()> + 'static), crate::Error> {
238        let addr = addr.into();
239        let (addr, srv) = try_bind!(self, &addr).map_err(crate::Error::new)?;
240        let srv = srv.map(|result| {
241            if let Err(err) = result {
242                tracing::error!("server error: {}", err)
243            }
244        });
245
246        Ok((addr, srv))
247    }
248
249    /// Create a server with graceful shutdown signal.
250    ///
251    /// When the signal completes, the server will start the graceful shutdown
252    /// process.
253    ///
254    /// Returns the bound address and a `Future` that can be executed on
255    /// the current runtime.
256    ///
257    /// # Example
258    ///
259    /// ```no_run
260    /// use warp::Filter;
261    /// use futures_util::future::TryFutureExt;
262    /// use tokio::sync::oneshot;
263    ///
264    /// # fn main() {
265    /// let routes = warp::any()
266    ///     .map(|| "Hello, World!");
267    ///
268    /// let (tx, rx) = oneshot::channel();
269    ///
270    /// let (addr, server) = warp::serve(routes)
271    ///     .bind_with_graceful_shutdown(([127, 0, 0, 1], 3030), async {
272    ///          rx.await.ok();
273    ///     });
274    ///
275    /// // Spawn the server into a runtime
276    /// tokio::task::spawn(server);
277    ///
278    /// // Later, start the shutdown...
279    /// let _ = tx.send(());
280    /// # }
281    /// ```
282    ///
283    /// # Panics
284    ///
285    /// Panics if we are unable to bind to the provided address.
286    pub fn bind_with_graceful_shutdown(
287        self,
288        addr: impl Into<SocketAddr> + 'static,
289        signal: impl Future<Output = ()> + Send + 'static,
290    ) -> (SocketAddr, impl Future<Output = ()> + 'static) {
291        let (addr, srv) = bind!(self, addr);
292        let fut = srv.with_graceful_shutdown(signal).map(|result| {
293            if let Err(err) = result {
294                tracing::error!("server error: {}", err)
295            }
296        });
297        (addr, fut)
298    }
299
300    /// Create a server with graceful shutdown signal.
301    ///
302    /// When the signal completes, the server will start the graceful shutdown
303    /// process.
304    pub fn try_bind_with_graceful_shutdown(
305        self,
306        addr: impl Into<SocketAddr> + 'static,
307        signal: impl Future<Output = ()> + Send + 'static,
308    ) -> Result<(SocketAddr, impl Future<Output = ()> + 'static), crate::Error> {
309        let addr = addr.into();
310        let (addr, srv) = try_bind!(self, &addr).map_err(crate::Error::new)?;
311        let srv = srv.with_graceful_shutdown(signal).map(|result| {
312            if let Err(err) = result {
313                tracing::error!("server error: {}", err)
314            }
315        });
316
317        Ok((addr, srv))
318    }
319
320    /// Setup this `Server` with a specific stream of incoming connections.
321    ///
322    /// This can be used for Unix Domain Sockets, or TLS, etc.
323    ///
324    /// Returns a `Future` that can be executed on the current runtime.
325    pub fn serve_incoming<I>(self, incoming: I) -> impl Future<Output = ()>
326    where
327        I: TryStream + Send,
328        I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin,
329        I::Error: Into<Box<dyn StdError + Send + Sync>>,
330    {
331        let incoming = incoming.map_ok(crate::transport::LiftIo);
332        self.serve_incoming2(incoming)
333            .instrument(tracing::info_span!("Server::serve_incoming"))
334    }
335
336    /// Setup this `Server` with a specific stream of incoming connections and a
337    /// signal to initiate graceful shutdown.
338    ///
339    /// This can be used for Unix Domain Sockets, or TLS, etc.
340    ///
341    /// When the signal completes, the server will start the graceful shutdown
342    /// process.
343    ///
344    /// Returns a `Future` that can be executed on the current runtime.
345    pub fn serve_incoming_with_graceful_shutdown<I>(
346        self,
347        incoming: I,
348        signal: impl Future<Output = ()> + Send + 'static,
349    ) -> impl Future<Output = ()>
350    where
351        I: TryStream + Send,
352        I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin,
353        I::Error: Into<Box<dyn StdError + Send + Sync>>,
354    {
355        let incoming = incoming.map_ok(crate::transport::LiftIo);
356        let service = into_service!(self.filter);
357        let pipeline = self.pipeline;
358
359        async move {
360            let srv =
361                HyperServer::builder(hyper::server::accept::from_stream(incoming.into_stream()))
362                    .http1_pipeline_flush(pipeline)
363                    .serve(service)
364                    .with_graceful_shutdown(signal)
365                    .await;
366
367            if let Err(err) = srv {
368                tracing::error!("server error: {}", err);
369            }
370        }
371        .instrument(tracing::info_span!(
372            "Server::serve_incoming_with_graceful_shutdown"
373        ))
374    }
375
376    async fn serve_incoming2<I>(self, incoming: I)
377    where
378        I: TryStream + Send,
379        I::Ok: Transport + Send + 'static + Unpin,
380        I::Error: Into<Box<dyn StdError + Send + Sync>>,
381    {
382        let service = into_service!(self.filter);
383
384        let srv = HyperServer::builder(hyper::server::accept::from_stream(incoming.into_stream()))
385            .http1_pipeline_flush(self.pipeline)
386            .serve(service)
387            .await;
388
389        if let Err(err) = srv {
390            tracing::error!("server error: {}", err);
391        }
392    }
393
394    // Generally shouldn't be used, as it can slow down non-pipelined responses.
395    //
396    // It's only real use is to make silly pipeline benchmarks look better.
397    #[doc(hidden)]
398    pub fn unstable_pipeline(mut self) -> Self {
399        self.pipeline = true;
400        self
401    }
402
403    /// Configure a server to use TLS.
404    ///
405    /// *This function requires the `"tls"` feature.*
406    #[cfg(feature = "tls")]
407    pub fn tls(self) -> TlsServer<F> {
408        TlsServer {
409            server: self,
410            tls: TlsConfigBuilder::new(),
411        }
412    }
413}
414
// ===== impl TlsServer =====
416
#[cfg(feature = "tls")]
impl<F> TlsServer<F>
where
    F: Filter + Clone + Send + Sync + 'static,
    <F::Future as TryFuture>::Ok: Reply,
    <F::Future as TryFuture>::Error: IsReject,
{
    // TLS config methods

    /// Specify the file path to read the private key.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn key_path(self, path: impl AsRef<Path>) -> Self {
        self.with_tls(|tls| tls.key_path(path))
    }

    /// Specify the file path to read the certificate.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn cert_path(self, path: impl AsRef<Path>) -> Self {
        self.with_tls(|tls| tls.cert_path(path))
    }

    /// Specify the file path to read the trust anchor for optional client authentication.
    ///
    /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any
    /// of the `client_auth_` methods, then client authentication is disabled by default.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn client_auth_optional_path(self, path: impl AsRef<Path>) -> Self {
        self.with_tls(|tls| tls.client_auth_optional_path(path))
    }

    /// Specify the file path to read the trust anchor for required client authentication.
    ///
    /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the
    /// `client_auth_` methods, then client authentication is disabled by default.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn client_auth_required_path(self, path: impl AsRef<Path>) -> Self {
        self.with_tls(|tls| tls.client_auth_required_path(path))
    }

    /// Specify the in-memory contents of the private key.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn key(self, key: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.key(key.as_ref()))
    }

    /// Specify the in-memory contents of the certificate.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn cert(self, cert: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.cert(cert.as_ref()))
    }

    /// Specify the in-memory contents of the trust anchor for optional client authentication.
    ///
    /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any
    /// of the `client_auth_` methods, then client authentication is disabled by default.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn client_auth_optional(self, trust_anchor: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.client_auth_optional(trust_anchor.as_ref()))
    }

    /// Specify the in-memory contents of the trust anchor for required client authentication.
    ///
    /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the
    /// `client_auth_` methods, then client authentication is disabled by default.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn client_auth_required(self, trust_anchor: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.client_auth_required(trust_anchor.as_ref()))
    }

    /// Specify the DER-encoded OCSP response.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn ocsp_resp(self, resp: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.ocsp_resp(resp.as_ref()))
    }

    /// Apply `func` to the TLS config builder, keeping the wrapped server as is.
    fn with_tls<Func>(self, func: Func) -> Self
    where
        Func: FnOnce(TlsConfigBuilder) -> TlsConfigBuilder,
    {
        let TlsServer { server, tls } = self;
        let tls = func(tls);
        TlsServer { server, tls }
    }

    // Server run methods

    /// Run this `TlsServer` forever on the current thread.
    ///
    /// The bound address is logged at `info` level inside a `TlsServer::run` span.
    ///
    /// *This function requires the `"tls"` feature.*
    ///
    /// # Panics
    ///
    /// Panics if we are unable to bind to the provided address, or if the TLS
    /// configuration cannot be built.
    pub async fn run(self, addr: impl Into<SocketAddr>) {
        let (addr, fut) = self.bind_ephemeral(addr);
        let span = tracing::info_span!("TlsServer::run", %addr);
        tracing::info!(parent: &span, "listening on https://{}", addr);

        fut.instrument(span).await;
    }

    /// Bind to a socket address, returning a `Future` that can be
    /// executed on a runtime.
    ///
    /// Because this is an `async fn`, nothing happens until the returned future
    /// is first polled. The socket is bound, and the TLS configuration built,
    /// at that point.
    ///
    /// *This function requires the `"tls"` feature.*
    ///
    /// # Panics
    ///
    /// The returned future panics when first polled if we are unable to bind to
    /// the provided address, or if the TLS configuration cannot be built.
    pub async fn bind(self, addr: impl Into<SocketAddr>) {
        let (_, fut) = self.bind_ephemeral(addr);
        fut.await;
    }

    /// Bind to a possibly ephemeral socket address.
    ///
    /// Returns the bound address and a `Future` that can be executed on
    /// the current runtime.
    ///
    /// *This function requires the `"tls"` feature.*
    ///
    /// # Panics
    ///
    /// Panics if we are unable to bind to the provided address, or if the TLS
    /// configuration cannot be built.
    pub fn bind_ephemeral(
        self,
        addr: impl Into<SocketAddr>,
    ) -> (SocketAddr, impl Future<Output = ()> + 'static) {
        let (addr, srv) = bind!(tls: self, addr);
        (addr, srv.map(Self::log_if_error))
    }

    /// Create a server with graceful shutdown signal.
    ///
    /// When the signal completes, the server will start the graceful shutdown
    /// process.
    ///
    /// Returns the bound address and a `Future` that can be executed on
    /// the current runtime.
    ///
    /// *This function requires the `"tls"` feature.*
    ///
    /// # Panics
    ///
    /// Panics if we are unable to bind to the provided address, or if the TLS
    /// configuration cannot be built.
    pub fn bind_with_graceful_shutdown(
        self,
        addr: impl Into<SocketAddr> + 'static,
        signal: impl Future<Output = ()> + Send + 'static,
    ) -> (SocketAddr, impl Future<Output = ()> + 'static) {
        let (addr, srv) = bind!(tls: self, addr);

        let fut = srv
            .with_graceful_shutdown(signal)
            .map(Self::log_if_error);
        (addr, fut)
    }

    /// Create a server with graceful shutdown signal.
    ///
    /// When the signal completes, the server will start the graceful shutdown
    /// process.
    ///
    /// Returns the bound address and a `Future` that can be executed on
    /// the current runtime.
    ///
    /// *This function requires the `"tls"` feature.*
    ///
    /// # Errors
    ///
    /// Returns the underlying error if we are unable to bind to the provided
    /// address, or if the TLS configuration cannot be built.
    pub fn try_bind_with_graceful_shutdown(
        self,
        addr: impl Into<SocketAddr> + 'static,
        signal: impl Future<Output = ()> + Send + 'static,
    ) -> Result<(SocketAddr, impl Future<Output = ()> + 'static), crate::Error> {
        let addr = addr.into();
        let (addr, srv) = try_bind!(tls: self, &addr).map_err(crate::Error::new)?;
        let srv = srv
            .with_graceful_shutdown(signal)
            .map(Self::log_if_error);

        Ok((addr, srv))
    }

    /// Log the error, if any, that a finished hyper server resolved with.
    fn log_if_error(result: Result<(), hyper::Error>) {
        if let Err(err) = result {
            tracing::error!("server error: {}", err);
        }
    }
}
607
#[cfg(feature = "tls")]
impl<F: ::std::fmt::Debug> ::std::fmt::Debug for TlsServer<F> {
    /// Formats only the wrapped `Server`. The TLS config builder is not
    /// included in the output.
    // NOTE(review): presumably omitted because the builder may hold key
    // material (or lacks a `Debug` impl); confirm before exposing it.
    fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
        let mut out = f.debug_struct("TlsServer");
        out.field("server", &self.server);
        out.finish()
    }
}