warp/filters/
cors.rs

1//! CORS Filters
2
3use std::collections::HashSet;
4use std::convert::TryFrom;
5use std::error::Error as StdError;
6use std::fmt;
7use std::sync::Arc;
8
9use headers::{
10    AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders, HeaderMapExt,
11};
12use http::header::{self, HeaderName, HeaderValue};
13
14use crate::filter::{Filter, WrapSealed};
15use crate::reject::{CombineRejection, Rejection};
16use crate::reply::Reply;
17
18use self::internal::{CorsFilter, IntoOrigin, Seconds};
19
20/// Create a wrapping [`Filter`](crate::Filter) that exposes [CORS][] behavior for a wrapped
21/// filter.
22///
23/// [CORS]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
24///
25/// # Example
26///
27/// ```
28/// use warp::Filter;
29///
30/// let cors = warp::cors()
31///     .allow_origin("https://hyper.rs")
32///     .allow_methods(vec!["GET", "POST", "DELETE"]);
33///
34/// let route = warp::any()
35///     .map(warp::reply)
36///     .with(cors);
37/// ```
38/// If you want to allow any route:
39/// ```
40/// use warp::Filter;
41/// let cors = warp::cors()
42///     .allow_any_origin();
43/// ```
44/// You can find more usage examples [here](https://github.com/seanmonstar/warp/blob/7fa54eaecd0fe12687137372791ff22fc7995766/tests/cors.rs).
45pub fn cors() -> Builder {
46    Builder {
47        credentials: false,
48        allowed_headers: HashSet::new(),
49        exposed_headers: HashSet::new(),
50        max_age: None,
51        methods: HashSet::new(),
52        origins: None,
53    }
54}
55
56/// A wrapping [`Filter`](crate::Filter) constructed via `warp::cors()`.
57#[derive(Clone, Debug)]
58pub struct Cors {
59    config: Arc<Configured>,
60}
61
62/// A constructed via `warp::cors()`.
63#[derive(Clone, Debug)]
64pub struct Builder {
65    credentials: bool,
66    allowed_headers: HashSet<HeaderName>,
67    exposed_headers: HashSet<HeaderName>,
68    max_age: Option<u64>,
69    methods: HashSet<http::Method>,
70    origins: Option<HashSet<HeaderValue>>,
71}
72
73impl Builder {
74    /// Sets whether to add the `Access-Control-Allow-Credentials` header.
75    pub fn allow_credentials(mut self, allow: bool) -> Self {
76        self.credentials = allow;
77        self
78    }
79
80    /// Adds a method to the existing list of allowed request methods.
81    ///
82    /// # Panics
83    ///
84    /// Panics if the provided argument is not a valid `http::Method`.
85    pub fn allow_method<M>(mut self, method: M) -> Self
86    where
87        http::Method: TryFrom<M>,
88    {
89        let method = match TryFrom::try_from(method) {
90            Ok(m) => m,
91            Err(_) => panic!("illegal Method"),
92        };
93        self.methods.insert(method);
94        self
95    }
96
97    /// Adds multiple methods to the existing list of allowed request methods.
98    ///
99    /// # Panics
100    ///
101    /// Panics if the provided argument is not a valid `http::Method`.
102    pub fn allow_methods<I>(mut self, methods: I) -> Self
103    where
104        I: IntoIterator,
105        http::Method: TryFrom<I::Item>,
106    {
107        let iter = methods.into_iter().map(|m| match TryFrom::try_from(m) {
108            Ok(m) => m,
109            Err(_) => panic!("illegal Method"),
110        });
111        self.methods.extend(iter);
112        self
113    }
114
115    /// Adds a header to the list of allowed request headers.
116    ///
117    /// **Note**: These should match the values the browser sends via `Access-Control-Request-Headers`, e.g. `content-type`.
118    ///
119    /// # Panics
120    ///
121    /// Panics if the provided argument is not a valid `http::header::HeaderName`.
122    pub fn allow_header<H>(mut self, header: H) -> Self
123    where
124        HeaderName: TryFrom<H>,
125    {
126        let header = match TryFrom::try_from(header) {
127            Ok(m) => m,
128            Err(_) => panic!("illegal Header"),
129        };
130        self.allowed_headers.insert(header);
131        self
132    }
133
134    /// Adds multiple headers to the list of allowed request headers.
135    ///
136    /// **Note**: These should match the values the browser sends via `Access-Control-Request-Headers`, e.g.`content-type`.
137    ///
138    /// # Panics
139    ///
140    /// Panics if any of the headers are not a valid `http::header::HeaderName`.
141    pub fn allow_headers<I>(mut self, headers: I) -> Self
142    where
143        I: IntoIterator,
144        HeaderName: TryFrom<I::Item>,
145    {
146        let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
147            Ok(h) => h,
148            Err(_) => panic!("illegal Header"),
149        });
150        self.allowed_headers.extend(iter);
151        self
152    }
153
154    /// Adds a header to the list of exposed headers.
155    ///
156    /// # Panics
157    ///
158    /// Panics if the provided argument is not a valid `http::header::HeaderName`.
159    pub fn expose_header<H>(mut self, header: H) -> Self
160    where
161        HeaderName: TryFrom<H>,
162    {
163        let header = match TryFrom::try_from(header) {
164            Ok(m) => m,
165            Err(_) => panic!("illegal Header"),
166        };
167        self.exposed_headers.insert(header);
168        self
169    }
170
171    /// Adds multiple headers to the list of exposed headers.
172    ///
173    /// # Panics
174    ///
175    /// Panics if any of the headers are not a valid `http::header::HeaderName`.
176    pub fn expose_headers<I>(mut self, headers: I) -> Self
177    where
178        I: IntoIterator,
179        HeaderName: TryFrom<I::Item>,
180    {
181        let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
182            Ok(h) => h,
183            Err(_) => panic!("illegal Header"),
184        });
185        self.exposed_headers.extend(iter);
186        self
187    }
188
189    /// Sets that *any* `Origin` header is allowed.
190    ///
191    /// # Warning
192    ///
193    /// This can allow websites you didn't intend to access this resource,
194    /// it is usually better to set an explicit list.
195    pub fn allow_any_origin(mut self) -> Self {
196        self.origins = None;
197        self
198    }
199
200    /// Add an origin to the existing list of allowed `Origin`s.
201    ///
202    /// # Panics
203    ///
204    /// Panics if the provided argument is not a valid `Origin`.
205    pub fn allow_origin(self, origin: impl IntoOrigin) -> Self {
206        self.allow_origins(Some(origin))
207    }
208
209    /// Add multiple origins to the existing list of allowed `Origin`s.
210    ///
211    /// # Panics
212    ///
213    /// Panics if the provided argument is not a valid `Origin`.
214    pub fn allow_origins<I>(mut self, origins: I) -> Self
215    where
216        I: IntoIterator,
217        I::Item: IntoOrigin,
218    {
219        let iter = origins
220            .into_iter()
221            .map(IntoOrigin::into_origin)
222            .map(|origin| {
223                origin
224                    .to_string()
225                    .parse()
226                    .expect("Origin is always a valid HeaderValue")
227            });
228
229        self.origins.get_or_insert_with(HashSet::new).extend(iter);
230
231        self
232    }
233
234    /// Sets the `Access-Control-Max-Age` header.
235    ///
236    /// # Example
237    ///
238    ///
239    /// ```
240    /// use std::time::Duration;
241    /// use warp::Filter;
242    ///
243    /// let cors = warp::cors()
244    ///     .max_age(30) // 30u32 seconds
245    ///     .max_age(Duration::from_secs(30)); // or a Duration
246    /// ```
247    pub fn max_age(mut self, seconds: impl Seconds) -> Self {
248        self.max_age = Some(seconds.seconds());
249        self
250    }
251
252    /// Builds the `Cors` wrapper from the configured settings.
253    ///
254    /// This step isn't *required*, as the `Builder` itself can be passed
255    /// to `Filter::with`. This just allows constructing once, thus not needing
256    /// to pay the cost of "building" every time.
257    pub fn build(self) -> Cors {
258        let expose_headers_header = if self.exposed_headers.is_empty() {
259            None
260        } else {
261            Some(self.exposed_headers.iter().cloned().collect())
262        };
263        let allowed_headers_header = self.allowed_headers.iter().cloned().collect();
264        let methods_header = self.methods.iter().cloned().collect();
265
266        let config = Arc::new(Configured {
267            cors: self,
268            allowed_headers_header,
269            expose_headers_header,
270            methods_header,
271        });
272
273        Cors { config }
274    }
275}
276
277impl<F> WrapSealed<F> for Builder
278where
279    F: Filter + Clone + Send + Sync + 'static,
280    F::Extract: Reply,
281    F::Error: CombineRejection<Rejection>,
282    <F::Error as CombineRejection<Rejection>>::One: CombineRejection<Rejection>,
283{
284    type Wrapped = CorsFilter<F>;
285
286    fn wrap(&self, inner: F) -> Self::Wrapped {
287        let Cors { config } = self.clone().build();
288
289        CorsFilter { config, inner }
290    }
291}
292
293impl<F> WrapSealed<F> for Cors
294where
295    F: Filter + Clone + Send + Sync + 'static,
296    F::Extract: Reply,
297    F::Error: CombineRejection<Rejection>,
298    <F::Error as CombineRejection<Rejection>>::One: CombineRejection<Rejection>,
299{
300    type Wrapped = CorsFilter<F>;
301
302    fn wrap(&self, inner: F) -> Self::Wrapped {
303        let config = self.config.clone();
304
305        CorsFilter { config, inner }
306    }
307}
308
/// An error used to reject requests that are forbidden by a `cors` filter.
pub struct CorsForbidden {
    kind: Forbidden,
}

/// The reason a request was refused by the CORS checks.
#[derive(Debug)]
enum Forbidden {
    OriginNotAllowed,
    MethodNotAllowed,
    HeaderNotAllowed,
}

impl Forbidden {
    /// Human-readable description of this reason, used by `CorsForbidden`'s `Display`.
    fn describe(&self) -> &'static str {
        match *self {
            Forbidden::OriginNotAllowed => "origin not allowed",
            Forbidden::MethodNotAllowed => "request-method not allowed",
            Forbidden::HeaderNotAllowed => "header not allowed",
        }
    }
}

impl fmt::Debug for CorsForbidden {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Rendered as a tuple struct wrapping the reason, e.g. `CorsForbidden(OriginNotAllowed)`.
        let mut tuple = f.debug_tuple("CorsForbidden");
        tuple.field(&self.kind);
        tuple.finish()
    }
}

impl fmt::Display for CorsForbidden {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "CORS request forbidden: {}", self.kind.describe())
    }
}

impl StdError for CorsForbidden {}
339
340#[derive(Clone, Debug)]
341struct Configured {
342    cors: Builder,
343    allowed_headers_header: AccessControlAllowHeaders,
344    expose_headers_header: Option<AccessControlExposeHeaders>,
345    methods_header: AccessControlAllowMethods,
346}
347
348enum Validated {
349    Preflight(HeaderValue),
350    Simple(HeaderValue),
351    NotCors,
352}
353
354impl Configured {
355    fn check_request(
356        &self,
357        method: &http::Method,
358        headers: &http::HeaderMap,
359    ) -> Result<Validated, Forbidden> {
360        match (headers.get(header::ORIGIN), method) {
361            (Some(origin), &http::Method::OPTIONS) => {
362                // OPTIONS requests are preflight CORS requests...
363
364                if !self.is_origin_allowed(origin) {
365                    return Err(Forbidden::OriginNotAllowed);
366                }
367
368                if let Some(req_method) = headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) {
369                    if !self.is_method_allowed(req_method) {
370                        return Err(Forbidden::MethodNotAllowed);
371                    }
372                } else {
373                    tracing::trace!(
374                        "preflight request missing access-control-request-method header"
375                    );
376                    return Err(Forbidden::MethodNotAllowed);
377                }
378
379                if let Some(req_headers) = headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
380                    let headers = req_headers
381                        .to_str()
382                        .map_err(|_| Forbidden::HeaderNotAllowed)?;
383                    for header in headers.split(',') {
384                        if !self.is_header_allowed(header.trim()) {
385                            return Err(Forbidden::HeaderNotAllowed);
386                        }
387                    }
388                }
389
390                Ok(Validated::Preflight(origin.clone()))
391            }
392            (Some(origin), _) => {
393                // Any other method, simply check for a valid origin...
394
395                tracing::trace!("origin header: {:?}", origin);
396                if self.is_origin_allowed(origin) {
397                    Ok(Validated::Simple(origin.clone()))
398                } else {
399                    Err(Forbidden::OriginNotAllowed)
400                }
401            }
402            (None, _) => {
403                // No `ORIGIN` header means this isn't CORS!
404                Ok(Validated::NotCors)
405            }
406        }
407    }
408
409    fn is_method_allowed(&self, header: &HeaderValue) -> bool {
410        http::Method::from_bytes(header.as_bytes())
411            .map(|method| self.cors.methods.contains(&method))
412            .unwrap_or(false)
413    }
414
415    fn is_header_allowed(&self, header: &str) -> bool {
416        HeaderName::from_bytes(header.as_bytes())
417            .map(|header| self.cors.allowed_headers.contains(&header))
418            .unwrap_or(false)
419    }
420
421    fn is_origin_allowed(&self, origin: &HeaderValue) -> bool {
422        if let Some(ref allowed) = self.cors.origins {
423            allowed.contains(origin)
424        } else {
425            true
426        }
427    }
428
429    fn append_preflight_headers(&self, headers: &mut http::HeaderMap) {
430        self.append_common_headers(headers);
431
432        headers.typed_insert(self.allowed_headers_header.clone());
433        headers.typed_insert(self.methods_header.clone());
434
435        if let Some(max_age) = self.cors.max_age {
436            headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.into());
437        }
438    }
439
440    fn append_common_headers(&self, headers: &mut http::HeaderMap) {
441        if self.cors.credentials {
442            headers.insert(
443                header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
444                HeaderValue::from_static("true"),
445            );
446        }
447        if let Some(expose_headers_header) = &self.expose_headers_header {
448            headers.typed_insert(expose_headers_header.clone())
449        }
450    }
451}
452
453mod internal {
454    use std::future::Future;
455    use std::pin::Pin;
456    use std::sync::Arc;
457    use std::task::{Context, Poll};
458
459    use futures_util::{future, ready, TryFuture};
460    use headers::Origin;
461    use http::header;
462    use pin_project::pin_project;
463
464    use super::{Configured, CorsForbidden, Validated};
465    use crate::filter::{Filter, FilterBase, Internal, One};
466    use crate::generic::Either;
467    use crate::reject::{CombineRejection, Rejection};
468    use crate::route;
469
470    #[derive(Clone, Debug)]
471    pub struct CorsFilter<F> {
472        pub(super) config: Arc<Configured>,
473        pub(super) inner: F,
474    }
475
476    impl<F> FilterBase for CorsFilter<F>
477    where
478        F: Filter,
479        F::Extract: Send,
480        F::Future: Future,
481        F::Error: CombineRejection<Rejection>,
482    {
483        type Extract =
484            One<Either<One<Preflight>, One<Either<One<Wrapped<F::Extract>>, F::Extract>>>>;
485        type Error = <F::Error as CombineRejection<Rejection>>::One;
486        type Future = future::Either<
487            future::Ready<Result<Self::Extract, Self::Error>>,
488            WrappedFuture<F::Future>,
489        >;
490
491        fn filter(&self, _: Internal) -> Self::Future {
492            let validated =
493                route::with(|route| self.config.check_request(route.method(), route.headers()));
494
495            match validated {
496                Ok(Validated::Preflight(origin)) => {
497                    let preflight = Preflight {
498                        config: self.config.clone(),
499                        origin,
500                    };
501                    future::Either::Left(future::ok((Either::A((preflight,)),)))
502                }
503                Ok(Validated::Simple(origin)) => future::Either::Right(WrappedFuture {
504                    inner: self.inner.filter(Internal),
505                    wrapped: Some((self.config.clone(), origin)),
506                }),
507                Ok(Validated::NotCors) => future::Either::Right(WrappedFuture {
508                    inner: self.inner.filter(Internal),
509                    wrapped: None,
510                }),
511                Err(err) => {
512                    let rejection = crate::reject::known(CorsForbidden { kind: err });
513                    future::Either::Left(future::err(rejection.into()))
514                }
515            }
516        }
517    }
518
519    #[derive(Debug)]
520    pub struct Preflight {
521        config: Arc<Configured>,
522        origin: header::HeaderValue,
523    }
524
525    impl crate::reply::Reply for Preflight {
526        fn into_response(self) -> crate::reply::Response {
527            let mut res = crate::reply::Response::default();
528            self.config.append_preflight_headers(res.headers_mut());
529            res.headers_mut()
530                .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, self.origin);
531            res
532        }
533    }
534
535    #[derive(Debug)]
536    pub struct Wrapped<R> {
537        config: Arc<Configured>,
538        inner: R,
539        origin: header::HeaderValue,
540    }
541
542    impl<R> crate::reply::Reply for Wrapped<R>
543    where
544        R: crate::reply::Reply,
545    {
546        fn into_response(self) -> crate::reply::Response {
547            let mut res = self.inner.into_response();
548            self.config.append_common_headers(res.headers_mut());
549            res.headers_mut()
550                .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, self.origin);
551            res
552        }
553    }
554
555    #[pin_project]
556    #[derive(Debug)]
557    pub struct WrappedFuture<F> {
558        #[pin]
559        inner: F,
560        wrapped: Option<(Arc<Configured>, header::HeaderValue)>,
561    }
562
563    impl<F> Future for WrappedFuture<F>
564    where
565        F: TryFuture,
566        F::Error: CombineRejection<Rejection>,
567    {
568        type Output = Result<
569            One<Either<One<Preflight>, One<Either<One<Wrapped<F::Ok>>, F::Ok>>>>,
570            <F::Error as CombineRejection<Rejection>>::One,
571        >;
572
573        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
574            let pin = self.project();
575            match ready!(pin.inner.try_poll(cx)) {
576                Ok(inner) => {
577                    let item = if let Some((config, origin)) = pin.wrapped.take() {
578                        (Either::A((Wrapped {
579                            config,
580                            inner,
581                            origin,
582                        },)),)
583                    } else {
584                        (Either::B(inner),)
585                    };
586                    let item = (Either::B(item),);
587                    Poll::Ready(Ok(item))
588                }
589                Err(err) => Poll::Ready(Err(err.into())),
590            }
591        }
592    }
593
594    pub trait Seconds {
595        fn seconds(self) -> u64;
596    }
597
598    impl Seconds for u32 {
599        fn seconds(self) -> u64 {
600            self.into()
601        }
602    }
603
604    impl Seconds for ::std::time::Duration {
605        fn seconds(self) -> u64 {
606            self.as_secs()
607        }
608    }
609
610    pub trait IntoOrigin {
611        fn into_origin(self) -> Origin;
612    }
613
614    impl<'a> IntoOrigin for &'a str {
615        fn into_origin(self) -> Origin {
616            let mut parts = self.splitn(2, "://");
617            let scheme = parts.next().expect("missing scheme");
618            let rest = parts.next().expect("missing scheme");
619
620            Origin::try_from_parts(scheme, rest, None).expect("invalid Origin")
621        }
622    }
623}