1use std::collections::HashSet;
4use std::convert::TryFrom;
5use std::error::Error as StdError;
6use std::fmt;
7use std::sync::Arc;
8
9use headers::{
10 AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders, HeaderMapExt,
11};
12use http::header::{self, HeaderName, HeaderValue};
13
14use crate::filter::{Filter, WrapSealed};
15use crate::reject::{CombineRejection, Rejection};
16use crate::reply::Reply;
17
18use self::internal::{CorsFilter, IntoOrigin, Seconds};
19
20pub fn cors() -> Builder {
46 Builder {
47 credentials: false,
48 allowed_headers: HashSet::new(),
49 exposed_headers: HashSet::new(),
50 max_age: None,
51 methods: HashSet::new(),
52 origins: None,
53 }
54}
55
/// A finished CORS configuration, produced by [`Builder::build`].
///
/// The configuration lives behind an `Arc`, so cloning a `Cors` (or wrapping
/// several filters with it) shares one `Configured` value rather than copying it.
#[derive(Clone, Debug)]
pub struct Cors {
    config: Arc<Configured>,
}
61
/// Accumulates CORS settings; obtain one from [`cors`] and finish with [`Builder::build`].
#[derive(Clone, Debug)]
pub struct Builder {
    /// When `true`, responses carry `access-control-allow-credentials: true`.
    credentials: bool,
    /// Header names a preflight may list in `access-control-request-headers`.
    allowed_headers: HashSet<HeaderName>,
    /// Header names advertised via `access-control-expose-headers`.
    exposed_headers: HashSet<HeaderName>,
    /// Value for `access-control-max-age` on preflight responses, in seconds.
    max_age: Option<u64>,
    /// Methods a preflight may name in `access-control-request-method`.
    methods: HashSet<http::Method>,
    /// Allowed `Origin` header values; `None` means any origin is accepted.
    origins: Option<HashSet<HeaderValue>>,
}
72
73impl Builder {
74 pub fn allow_credentials(mut self, allow: bool) -> Self {
76 self.credentials = allow;
77 self
78 }
79
80 pub fn allow_method<M>(mut self, method: M) -> Self
86 where
87 http::Method: TryFrom<M>,
88 {
89 let method = match TryFrom::try_from(method) {
90 Ok(m) => m,
91 Err(_) => panic!("illegal Method"),
92 };
93 self.methods.insert(method);
94 self
95 }
96
97 pub fn allow_methods<I>(mut self, methods: I) -> Self
103 where
104 I: IntoIterator,
105 http::Method: TryFrom<I::Item>,
106 {
107 let iter = methods.into_iter().map(|m| match TryFrom::try_from(m) {
108 Ok(m) => m,
109 Err(_) => panic!("illegal Method"),
110 });
111 self.methods.extend(iter);
112 self
113 }
114
115 pub fn allow_header<H>(mut self, header: H) -> Self
123 where
124 HeaderName: TryFrom<H>,
125 {
126 let header = match TryFrom::try_from(header) {
127 Ok(m) => m,
128 Err(_) => panic!("illegal Header"),
129 };
130 self.allowed_headers.insert(header);
131 self
132 }
133
134 pub fn allow_headers<I>(mut self, headers: I) -> Self
142 where
143 I: IntoIterator,
144 HeaderName: TryFrom<I::Item>,
145 {
146 let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
147 Ok(h) => h,
148 Err(_) => panic!("illegal Header"),
149 });
150 self.allowed_headers.extend(iter);
151 self
152 }
153
154 pub fn expose_header<H>(mut self, header: H) -> Self
160 where
161 HeaderName: TryFrom<H>,
162 {
163 let header = match TryFrom::try_from(header) {
164 Ok(m) => m,
165 Err(_) => panic!("illegal Header"),
166 };
167 self.exposed_headers.insert(header);
168 self
169 }
170
171 pub fn expose_headers<I>(mut self, headers: I) -> Self
177 where
178 I: IntoIterator,
179 HeaderName: TryFrom<I::Item>,
180 {
181 let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
182 Ok(h) => h,
183 Err(_) => panic!("illegal Header"),
184 });
185 self.exposed_headers.extend(iter);
186 self
187 }
188
189 pub fn allow_any_origin(mut self) -> Self {
196 self.origins = None;
197 self
198 }
199
200 pub fn allow_origin(self, origin: impl IntoOrigin) -> Self {
206 self.allow_origins(Some(origin))
207 }
208
209 pub fn allow_origins<I>(mut self, origins: I) -> Self
215 where
216 I: IntoIterator,
217 I::Item: IntoOrigin,
218 {
219 let iter = origins
220 .into_iter()
221 .map(IntoOrigin::into_origin)
222 .map(|origin| {
223 origin
224 .to_string()
225 .parse()
226 .expect("Origin is always a valid HeaderValue")
227 });
228
229 self.origins.get_or_insert_with(HashSet::new).extend(iter);
230
231 self
232 }
233
234 pub fn max_age(mut self, seconds: impl Seconds) -> Self {
248 self.max_age = Some(seconds.seconds());
249 self
250 }
251
252 pub fn build(self) -> Cors {
258 let expose_headers_header = if self.exposed_headers.is_empty() {
259 None
260 } else {
261 Some(self.exposed_headers.iter().cloned().collect())
262 };
263 let allowed_headers_header = self.allowed_headers.iter().cloned().collect();
264 let methods_header = self.methods.iter().cloned().collect();
265
266 let config = Arc::new(Configured {
267 cors: self,
268 allowed_headers_header,
269 expose_headers_header,
270 methods_header,
271 });
272
273 Cors { config }
274 }
275}
276
277impl<F> WrapSealed<F> for Builder
278where
279 F: Filter + Clone + Send + Sync + 'static,
280 F::Extract: Reply,
281 F::Error: CombineRejection<Rejection>,
282 <F::Error as CombineRejection<Rejection>>::One: CombineRejection<Rejection>,
283{
284 type Wrapped = CorsFilter<F>;
285
286 fn wrap(&self, inner: F) -> Self::Wrapped {
287 let Cors { config } = self.clone().build();
288
289 CorsFilter { config, inner }
290 }
291}
292
293impl<F> WrapSealed<F> for Cors
294where
295 F: Filter + Clone + Send + Sync + 'static,
296 F::Extract: Reply,
297 F::Error: CombineRejection<Rejection>,
298 <F::Error as CombineRejection<Rejection>>::One: CombineRejection<Rejection>,
299{
300 type Wrapped = CorsFilter<F>;
301
302 fn wrap(&self, inner: F) -> Self::Wrapped {
303 let config = self.config.clone();
304
305 CorsFilter { config, inner }
306 }
307}
308
/// The error a CORS filter rejects a request with when the request is forbidden
/// by its configuration.
pub struct CorsForbidden {
    /// Which CORS check failed.
    kind: Forbidden,
}
313
/// The reason a CORS check rejected a request.
#[derive(Debug)]
enum Forbidden {
    /// The request's `Origin` value is not in the configured allow-list.
    OriginNotAllowed,
    /// A preflight's `access-control-request-method` is missing or not allowed.
    MethodNotAllowed,
    /// A preflight's `access-control-request-headers` could not be read as a
    /// string, or named a header that is not allowed.
    HeaderNotAllowed,
}
320
321impl fmt::Debug for CorsForbidden {
322 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
323 f.debug_tuple("CorsForbidden").field(&self.kind).finish()
324 }
325}
326
327impl fmt::Display for CorsForbidden {
328 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329 let detail = match self.kind {
330 Forbidden::OriginNotAllowed => "origin not allowed",
331 Forbidden::MethodNotAllowed => "request-method not allowed",
332 Forbidden::HeaderNotAllowed => "header not allowed",
333 };
334 write!(f, "CORS request forbidden: {}", detail)
335 }
336}
337
// `CorsForbidden` has hand-written `Debug` and `Display` impls in this file, so
// the default `Error` methods (e.g. `source()` returning `None`) are enough here.
impl StdError for CorsForbidden {}
339
/// A finalized CORS configuration, created by `Builder::build` and shared
/// behind an `Arc` by `Cors` and the wrapping filter.
#[derive(Clone, Debug)]
struct Configured {
    /// The builder settings this configuration was built from.
    cors: Builder,
    /// Pre-built `access-control-allow-headers` value from `cors.allowed_headers`.
    allowed_headers_header: AccessControlAllowHeaders,
    /// Pre-built `access-control-expose-headers` value; `None` when no headers are exposed.
    expose_headers_header: Option<AccessControlExposeHeaders>,
    /// Pre-built `access-control-allow-methods` value from `cors.methods`.
    methods_header: AccessControlAllowMethods,
}
347
/// The successful outcome of `Configured::check_request`.
enum Validated {
    /// An `OPTIONS` request with an allowed `Origin`; holds that `Origin` value.
    Preflight(HeaderValue),
    /// A non-`OPTIONS` request with an allowed `Origin`; holds that `Origin` value.
    Simple(HeaderValue),
    /// No `Origin` header was present, so this is not a CORS request.
    NotCors,
}
353
354impl Configured {
355 fn check_request(
356 &self,
357 method: &http::Method,
358 headers: &http::HeaderMap,
359 ) -> Result<Validated, Forbidden> {
360 match (headers.get(header::ORIGIN), method) {
361 (Some(origin), &http::Method::OPTIONS) => {
362 if !self.is_origin_allowed(origin) {
365 return Err(Forbidden::OriginNotAllowed);
366 }
367
368 if let Some(req_method) = headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) {
369 if !self.is_method_allowed(req_method) {
370 return Err(Forbidden::MethodNotAllowed);
371 }
372 } else {
373 tracing::trace!(
374 "preflight request missing access-control-request-method header"
375 );
376 return Err(Forbidden::MethodNotAllowed);
377 }
378
379 if let Some(req_headers) = headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
380 let headers = req_headers
381 .to_str()
382 .map_err(|_| Forbidden::HeaderNotAllowed)?;
383 for header in headers.split(',') {
384 if !self.is_header_allowed(header.trim()) {
385 return Err(Forbidden::HeaderNotAllowed);
386 }
387 }
388 }
389
390 Ok(Validated::Preflight(origin.clone()))
391 }
392 (Some(origin), _) => {
393 tracing::trace!("origin header: {:?}", origin);
396 if self.is_origin_allowed(origin) {
397 Ok(Validated::Simple(origin.clone()))
398 } else {
399 Err(Forbidden::OriginNotAllowed)
400 }
401 }
402 (None, _) => {
403 Ok(Validated::NotCors)
405 }
406 }
407 }
408
409 fn is_method_allowed(&self, header: &HeaderValue) -> bool {
410 http::Method::from_bytes(header.as_bytes())
411 .map(|method| self.cors.methods.contains(&method))
412 .unwrap_or(false)
413 }
414
415 fn is_header_allowed(&self, header: &str) -> bool {
416 HeaderName::from_bytes(header.as_bytes())
417 .map(|header| self.cors.allowed_headers.contains(&header))
418 .unwrap_or(false)
419 }
420
421 fn is_origin_allowed(&self, origin: &HeaderValue) -> bool {
422 if let Some(ref allowed) = self.cors.origins {
423 allowed.contains(origin)
424 } else {
425 true
426 }
427 }
428
429 fn append_preflight_headers(&self, headers: &mut http::HeaderMap) {
430 self.append_common_headers(headers);
431
432 headers.typed_insert(self.allowed_headers_header.clone());
433 headers.typed_insert(self.methods_header.clone());
434
435 if let Some(max_age) = self.cors.max_age {
436 headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.into());
437 }
438 }
439
440 fn append_common_headers(&self, headers: &mut http::HeaderMap) {
441 if self.cors.credentials {
442 headers.insert(
443 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
444 HeaderValue::from_static("true"),
445 );
446 }
447 if let Some(expose_headers_header) = &self.expose_headers_header {
448 headers.typed_insert(expose_headers_header.clone())
449 }
450 }
451}
452
/// The filter, reply, and future types that implement CORS wrapping, plus the
/// conversion traits used by `Builder`.
mod internal {
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll};

    use futures_util::{future, ready, TryFuture};
    use headers::Origin;
    use http::header;
    use pin_project::pin_project;

    use super::{Configured, CorsForbidden, Validated};
    use crate::filter::{Filter, FilterBase, Internal, One};
    use crate::generic::Either;
    use crate::reject::{CombineRejection, Rejection};
    use crate::route;

    /// A filter that wraps `inner` and applies the shared CORS configuration to it.
    #[derive(Clone, Debug)]
    pub struct CorsFilter<F> {
        pub(super) config: Arc<Configured>,
        pub(super) inner: F,
    }

    impl<F> FilterBase for CorsFilter<F>
    where
        F: Filter,
        F::Extract: Send,
        F::Future: Future,
        F::Error: CombineRejection<Rejection>,
    {
        // The outer `Either` has two arms:
        // - `A`: a preflight reply;
        // - `B`: the inner filter's output, which is in turn either wrapped with
        //   CORS headers (`A`) or passed through unchanged for non-CORS requests (`B`).
        type Extract =
            One<Either<One<Preflight>, One<Either<One<Wrapped<F::Extract>>, F::Extract>>>>;
        type Error = <F::Error as CombineRejection<Rejection>>::One;
        // `Left` is an already-resolved result (a preflight reply or a CORS
        // rejection). `Right` drives the inner filter's future.
        type Future = future::Either<
            future::Ready<Result<Self::Extract, Self::Error>>,
            WrappedFuture<F::Future>,
        >;

        fn filter(&self, _: Internal) -> Self::Future {
            let validated =
                route::with(|route| self.config.check_request(route.method(), route.headers()));

            match validated {
                // Preflight: answered directly. The inner filter is never invoked.
                Ok(Validated::Preflight(origin)) => {
                    let preflight = Preflight {
                        config: self.config.clone(),
                        origin,
                    };
                    future::Either::Left(future::ok((Either::A((preflight,)),)))
                }
                // Simple CORS request: run the inner filter, then add CORS headers.
                Ok(Validated::Simple(origin)) => future::Either::Right(WrappedFuture {
                    inner: self.inner.filter(Internal),
                    wrapped: Some((self.config.clone(), origin)),
                }),
                // Not CORS: run the inner filter and pass its output through untouched.
                Ok(Validated::NotCors) => future::Either::Right(WrappedFuture {
                    inner: self.inner.filter(Internal),
                    wrapped: None,
                }),
                // Forbidden: reject with a `CorsForbidden` carrying the reason.
                Err(err) => {
                    let rejection = crate::reject::known(CorsForbidden { kind: err });
                    future::Either::Left(future::err(rejection.into()))
                }
            }
        }
    }

    /// The reply to an accepted preflight request.
    #[derive(Debug)]
    pub struct Preflight {
        config: Arc<Configured>,
        // The request's `Origin` value, echoed in `access-control-allow-origin`.
        origin: header::HeaderValue,
    }

    impl crate::reply::Reply for Preflight {
        /// Starts from `Response::default()`, adds the preflight headers, and sets
        /// `access-control-allow-origin` to the request's origin.
        fn into_response(self) -> crate::reply::Response {
            let mut res = crate::reply::Response::default();
            self.config.append_preflight_headers(res.headers_mut());
            res.headers_mut()
                .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, self.origin);
            res
        }
    }

    /// The inner filter's reply, decorated with CORS headers for a simple CORS request.
    #[derive(Debug)]
    pub struct Wrapped<R> {
        config: Arc<Configured>,
        inner: R,
        // The request's `Origin` value, echoed in `access-control-allow-origin`.
        origin: header::HeaderValue,
    }

    impl<R> crate::reply::Reply for Wrapped<R>
    where
        R: crate::reply::Reply,
    {
        /// Builds the inner response, then adds the common CORS headers and
        /// `access-control-allow-origin`.
        fn into_response(self) -> crate::reply::Response {
            let mut res = self.inner.into_response();
            self.config.append_common_headers(res.headers_mut());
            res.headers_mut()
                .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, self.origin);
            res
        }
    }

    /// Future that drives the inner filter and, for simple CORS requests, wraps its output.
    #[pin_project]
    #[derive(Debug)]
    pub struct WrappedFuture<F> {
        // The inner filter's future.
        #[pin]
        inner: F,
        // `Some` for simple CORS requests (wrap the output); `None` for non-CORS requests.
        wrapped: Option<(Arc<Configured>, header::HeaderValue)>,
    }

    impl<F> Future for WrappedFuture<F>
    where
        F: TryFuture,
        F::Error: CombineRejection<Rejection>,
    {
        type Output = Result<
            One<Either<One<Preflight>, One<Either<One<Wrapped<F::Ok>>, F::Ok>>>>,
            <F::Error as CombineRejection<Rejection>>::One,
        >;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let pin = self.project();
            match ready!(pin.inner.try_poll(cx)) {
                Ok(inner) => {
                    // `take()` moves the config/origin out through the `&mut` projection.
                    let item = if let Some((config, origin)) = pin.wrapped.take() {
                        (Either::A((Wrapped {
                            config,
                            inner,
                            origin,
                        },)),)
                    } else {
                        (Either::B(inner),)
                    };
                    // Always the outer `B` arm: this future never produces a preflight.
                    let item = (Either::B(item),);
                    Poll::Ready(Ok(item))
                }
                Err(err) => Poll::Ready(Err(err.into())),
            }
        }
    }

    /// Types accepted by `Builder::max_age`, converted to whole seconds.
    pub trait Seconds {
        fn seconds(self) -> u64;
    }

    impl Seconds for u32 {
        fn seconds(self) -> u64 {
            self.into()
        }
    }

    impl Seconds for ::std::time::Duration {
        // `as_secs` returns whole seconds; any sub-second part is dropped.
        fn seconds(self) -> u64 {
            self.as_secs()
        }
    }

    /// Types accepted by `Builder::allow_origin(s)`, converted to an `Origin`.
    pub trait IntoOrigin {
        fn into_origin(self) -> Origin;
    }

    impl<'a> IntoOrigin for &'a str {
        /// Parses `scheme://rest` by splitting at the first `://`.
        ///
        /// # Panics
        ///
        /// Panics if there is no `://`, or if `Origin::try_from_parts` rejects the parts.
        fn into_origin(self) -> Origin {
            let mut parts = self.splitn(2, "://");
            // `splitn` always yields a first item, so only the second `expect`
            // can fire: it does so when the string contains no `://`.
            let scheme = parts.next().expect("missing scheme");
            let rest = parts.next().expect("missing scheme");

            Origin::try_from_parts(scheme, rest, None).expect("invalid Origin")
        }
    }
}