1use std::any::Any;
60use std::convert::Infallible;
61use std::error::Error as StdError;
62use std::fmt;
63
64use http::{
65 header::{HeaderValue, CONTENT_TYPE},
66 StatusCode,
67};
68use hyper::Body;
69
70pub(crate) use self::sealed::{CombineRejection, IsReject};
71
72#[inline]
74pub fn reject() -> Rejection {
75 not_found()
76}
77
78#[inline]
80pub fn not_found() -> Rejection {
81 Rejection {
82 reason: Reason::NotFound,
83 }
84}
85
86#[inline]
88pub(crate) fn invalid_query() -> Rejection {
89 known(InvalidQuery { _p: () })
90}
91
92#[inline]
94pub(crate) fn missing_header(name: &'static str) -> Rejection {
95 known(MissingHeader { name })
96}
97
98#[inline]
100pub(crate) fn invalid_header(name: &'static str) -> Rejection {
101 known(InvalidHeader { name })
102}
103
104#[inline]
106pub(crate) fn missing_cookie(name: &'static str) -> Rejection {
107 known(MissingCookie { name })
108}
109
110#[inline]
112pub(crate) fn method_not_allowed() -> Rejection {
113 known(MethodNotAllowed { _p: () })
114}
115
116#[inline]
118pub(crate) fn length_required() -> Rejection {
119 known(LengthRequired { _p: () })
120}
121
122#[inline]
124pub(crate) fn payload_too_large() -> Rejection {
125 known(PayloadTooLarge { _p: () })
126}
127
128#[inline]
133pub(crate) fn unsupported_media_type() -> Rejection {
134 known(UnsupportedMediaType { _p: () })
135}
136
137pub fn custom<T: Reject>(err: T) -> Rejection {
144 Rejection::custom(Box::new(err))
145}
146
// NOTE(review): this empty, never-called function appears (judging by its name) to exist
// only as an anchor for a `compile_fail` doctest, presumably one asserting that a
// `Rejection` cannot be passed back into `custom`. No such doctest is visible here —
// confirm it has not been lost.
fn __reject_custom_compilefail() {}
155
/// Marker trait for user-defined rejection causes.
///
/// Implement it for your own type to pass values of that type to `custom`, or to convert
/// them into a `Rejection` via `From`/`Into`. The `Debug + Send + Sync + 'static` bounds
/// are what allow such a value to be stored as a type-erased, printable cause.
pub trait Reject: fmt::Debug + Sized + Send + Sync + 'static {}
177
/// Object-safe, type-erased storage for custom rejection causes.
///
/// Every `Debug + Send + Sync + 'static` type implements it through the blanket impl
/// below, so a `Box<dyn Cause>` can hold any custom cause and later be downcast back to
/// its concrete type.
trait Cause: fmt::Debug + Send + Sync + 'static {
    /// Upcasts to `&dyn Any` so the concrete type can be recovered.
    fn as_any(&self) -> &dyn Any;
}

impl<C> Cause for C
where
    C: fmt::Debug + Send + Sync + 'static,
{
    fn as_any(&self) -> &dyn Any {
        let erased: &dyn Any = self;
        erased
    }
}

impl dyn Cause {
    /// Returns the stored cause as `&T` if its concrete type is exactly `T`.
    fn downcast_ref<T: Any>(&self) -> Option<&T> {
        // `as_any` is called on `dyn Cause`, so it dispatches through the vtable and the
        // resulting `&dyn Any` describes the concrete stored type.
        let any = self.as_any();
        any.downcast_ref::<T>()
    }
}
196
197pub(crate) fn known<T: Into<Known>>(err: T) -> Rejection {
198 Rejection::known(err.into())
199}
200
/// Rejection of a request.
///
/// Internally it is either a plain "not found" or a boxed tree of one or more concrete
/// rejections. Keeping the tree behind a single `Box` is what keeps `Rejection`
/// pointer-sized (asserted by the `size_of_rejection` test).
pub struct Rejection {
    reason: Reason,
}

/// Why a request was rejected.
enum Reason {
    /// Plain not-found; rendered as a `404 Not Found` response.
    NotFound,
    /// One or more concrete rejections (built-in, custom, or combined).
    Other(Box<Rejections>),
}

/// A tree of concrete rejections.
enum Rejections {
    /// A built-in rejection type listed in the `enum_known!` invocation.
    Known(Known),
    /// A user-defined rejection created via `custom`.
    Custom(Box<dyn Cause>),
    /// Two rejections merged by `combine`; the one actually reported is chosen by
    /// `Rejections::preferred`.
    Combined(Box<Rejections>, Box<Rejections>),
}
218
// Generates the `Known` enum (one variant per built-in rejection type) together with:
// - `Known::inner_as_any`, which `Rejections::find` uses to downcast the wrapped value;
// - `Debug` and `Display` impls that forward to the wrapped value's own impl;
// - one `From<$ty> for Known` impl per variant, so `known(...)` accepts the bare type.
// Attributes written on a variant (e.g. `#[cfg(feature = "...")]`) are repeated on every
// generated match arm and `From` impl, so feature-gated variants stay consistent.
macro_rules! enum_known {
    ($($(#[$attr:meta])* $var:ident($ty:path),)+) => (
        pub(crate) enum Known {
            $(
            $(#[$attr])*
            $var($ty),
            )+
        }

        impl Known {
            // Type-erased view of the wrapped rejection, for downcasting.
            fn inner_as_any(&self) -> &dyn Any {
                match *self {
                    $(
                    $(#[$attr])*
                    Known::$var(ref t) => t,
                    )+
                }
            }
        }

        impl fmt::Debug for Known {
            // Inside this impl, `t.fmt` resolves to `Debug::fmt` of the wrapped value.
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                match *self {
                    $(
                    $(#[$attr])*
                    Known::$var(ref t) => t.fmt(f),
                    )+
                }
            }
        }

        impl fmt::Display for Known {
            // Inside this impl, `t.fmt` resolves to `Display::fmt` of the wrapped value;
            // this text becomes the response body for built-in rejections.
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                match *self {
                    $(
                    $(#[$attr])*
                    Known::$var(ref t) => t.fmt(f),
                    )+
                }
            }
        }

        $(
        #[doc(hidden)]
        $(#[$attr])*
        impl From<$ty> for Known {
            fn from(ty: $ty) -> Known {
                Known::$var(ty)
            }
        }
        )+
    );
}
272
// The complete set of built-in rejection types. Each entry becomes a `Known` variant;
// the HTTP status for each is assigned in `Rejections::status`.
enum_known! {
    MethodNotAllowed(MethodNotAllowed),
    InvalidHeader(InvalidHeader),
    MissingHeader(MissingHeader),
    MissingCookie(MissingCookie),
    InvalidQuery(InvalidQuery),
    LengthRequired(LengthRequired),
    PayloadTooLarge(PayloadTooLarge),
    UnsupportedMediaType(UnsupportedMediaType),
    FileOpenError(crate::fs::FileOpenError),
    FilePermissionError(crate::fs::FilePermissionError),
    BodyReadError(crate::body::BodyReadError),
    BodyDeserializeError(crate::body::BodyDeserializeError),
    CorsForbidden(crate::cors::CorsForbidden),
    // Only present with the `websocket` feature; the macro repeats this attribute on every
    // generated match arm and `From` impl for this variant.
    #[cfg(feature = "websocket")]
    MissingConnectionUpgrade(crate::ws::MissingConnectionUpgrade),
    MissingExtension(crate::ext::MissingExtension),
    BodyConsumedMultipleTimes(crate::body::BodyConsumedMultipleTimes),
}
292
293impl Rejection {
294 fn known(known: Known) -> Self {
295 Rejection {
296 reason: Reason::Other(Box::new(Rejections::Known(known))),
297 }
298 }
299
300 fn custom(other: Box<dyn Cause>) -> Self {
301 Rejection {
302 reason: Reason::Other(Box::new(Rejections::Custom(other))),
303 }
304 }
305
306 pub fn find<T: 'static>(&self) -> Option<&T> {
326 if let Reason::Other(ref rejections) = self.reason {
327 return rejections.find();
328 }
329 None
330 }
331
332 pub fn is_not_found(&self) -> bool {
342 matches!(self.reason, Reason::NotFound)
343 }
344}
345
346impl<T: Reject> From<T> for Rejection {
347 #[inline]
348 fn from(err: T) -> Rejection {
349 custom(err)
350 }
351}
352
353impl From<Infallible> for Rejection {
354 #[inline]
355 fn from(infallible: Infallible) -> Rejection {
356 match infallible {}
357 }
358}
359
// Lets `Infallible` act as a rejection type that can never occur. Because it has no
// values, neither method can ever be called; the empty matches prove that to the compiler.
impl IsReject for Infallible {
    fn status(&self) -> StatusCode {
        match *self {}
    }

    fn into_response(&self) -> crate::reply::Response {
        match *self {}
    }
}
369
370impl IsReject for Rejection {
371 fn status(&self) -> StatusCode {
372 match self.reason {
373 Reason::NotFound => StatusCode::NOT_FOUND,
374 Reason::Other(ref other) => other.status(),
375 }
376 }
377
378 fn into_response(&self) -> crate::reply::Response {
379 match self.reason {
380 Reason::NotFound => {
381 let mut res = http::Response::default();
382 *res.status_mut() = StatusCode::NOT_FOUND;
383 res
384 }
385 Reason::Other(ref other) => other.into_response(),
386 }
387 }
388}
389
390impl fmt::Debug for Rejection {
391 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
392 f.debug_tuple("Rejection").field(&self.reason).finish()
393 }
394}
395
396impl fmt::Debug for Reason {
397 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
398 match *self {
399 Reason::NotFound => f.write_str("NotFound"),
400 Reason::Other(ref other) => match **other {
401 Rejections::Known(ref e) => fmt::Debug::fmt(e, f),
402 Rejections::Custom(ref e) => fmt::Debug::fmt(e, f),
403 Rejections::Combined(ref a, ref b) => {
404 let mut list = f.debug_list();
405 a.debug_list(&mut list);
406 b.debug_list(&mut list);
407 list.finish()
408 }
409 },
410 }
411 }
412}
413
414impl Rejections {
417 fn status(&self) -> StatusCode {
418 match *self {
419 Rejections::Known(ref k) => match *k {
420 Known::MethodNotAllowed(_) => StatusCode::METHOD_NOT_ALLOWED,
421 Known::InvalidHeader(_)
422 | Known::MissingHeader(_)
423 | Known::MissingCookie(_)
424 | Known::InvalidQuery(_)
425 | Known::BodyReadError(_)
426 | Known::BodyDeserializeError(_) => StatusCode::BAD_REQUEST,
427 #[cfg(feature = "websocket")]
428 Known::MissingConnectionUpgrade(_) => StatusCode::BAD_REQUEST,
429 Known::LengthRequired(_) => StatusCode::LENGTH_REQUIRED,
430 Known::PayloadTooLarge(_) => StatusCode::PAYLOAD_TOO_LARGE,
431 Known::UnsupportedMediaType(_) => StatusCode::UNSUPPORTED_MEDIA_TYPE,
432 Known::FilePermissionError(_) | Known::CorsForbidden(_) => StatusCode::FORBIDDEN,
433 Known::FileOpenError(_)
434 | Known::MissingExtension(_)
435 | Known::BodyConsumedMultipleTimes(_) => StatusCode::INTERNAL_SERVER_ERROR,
436 },
437 Rejections::Custom(..) => StatusCode::INTERNAL_SERVER_ERROR,
438 Rejections::Combined(..) => self.preferred().status(),
439 }
440 }
441
442 fn into_response(&self) -> crate::reply::Response {
443 match *self {
444 Rejections::Known(ref e) => {
445 let mut res = http::Response::new(Body::from(e.to_string()));
446 *res.status_mut() = self.status();
447 res.headers_mut().insert(
448 CONTENT_TYPE,
449 HeaderValue::from_static("text/plain; charset=utf-8"),
450 );
451 res
452 }
453 Rejections::Custom(ref e) => {
454 tracing::error!(
455 "unhandled custom rejection, returning 500 response: {:?}",
456 e
457 );
458 let body = format!("Unhandled rejection: {:?}", e);
459 let mut res = http::Response::new(Body::from(body));
460 *res.status_mut() = self.status();
461 res.headers_mut().insert(
462 CONTENT_TYPE,
463 HeaderValue::from_static("text/plain; charset=utf-8"),
464 );
465 res
466 }
467 Rejections::Combined(..) => self.preferred().into_response(),
468 }
469 }
470
471 fn find<T: 'static>(&self) -> Option<&T> {
472 match *self {
473 Rejections::Known(ref e) => e.inner_as_any().downcast_ref(),
474 Rejections::Custom(ref e) => e.downcast_ref(),
475 Rejections::Combined(ref a, ref b) => a.find().or_else(|| b.find()),
476 }
477 }
478
479 fn debug_list(&self, f: &mut fmt::DebugList<'_, '_>) {
480 match *self {
481 Rejections::Known(ref e) => {
482 f.entry(e);
483 }
484 Rejections::Custom(ref e) => {
485 f.entry(e);
486 }
487 Rejections::Combined(ref a, ref b) => {
488 a.debug_list(f);
489 b.debug_list(f);
490 }
491 }
492 }
493
494 fn preferred(&self) -> &Rejections {
495 match self {
496 Rejections::Known(_) | Rejections::Custom(_) => self,
497 Rejections::Combined(a, b) => {
498 let a = a.preferred();
499 let b = b.preferred();
500 match (a.status(), b.status()) {
508 (_, StatusCode::NOT_FOUND) => a,
509 (StatusCode::NOT_FOUND, _) => b,
510 (_, StatusCode::METHOD_NOT_ALLOWED) => a,
511 (StatusCode::METHOD_NOT_ALLOWED, _) => b,
512 (sa, sb) if sa < sb => b,
513 _ => a,
514 }
515 }
516 }
517 }
518}
519
// Built-in rejection created by `invalid_query()`; `Rejections::status` maps it to
// `400 Bad Request`.
unit_error! {
    pub InvalidQuery: "Invalid query string"
}

// Built-in rejection created by `method_not_allowed()`; mapped to `405 Method Not Allowed`.
// When rejections are combined, it is preferred only over a plain 404.
unit_error! {
    pub MethodNotAllowed: "HTTP method not allowed"
}

// Built-in rejection created by `length_required()`; mapped to `411 Length Required`.
unit_error! {
    pub LengthRequired: "A content-length header is required"
}

// Built-in rejection created by `payload_too_large()`; mapped to `413 Payload Too Large`.
unit_error! {
    pub PayloadTooLarge: "The request payload is too large"
}

// Built-in rejection created by `unsupported_media_type()`; mapped to
// `415 Unsupported Media Type`.
unit_error! {
    pub UnsupportedMediaType: "The request's content-type is not supported"
}
544
/// Rejection reason: a required request header was absent.
///
/// Its `Display` text (e.g. `Missing request header "x-api-key"`) is used as the body of
/// the `400 Bad Request` response.
#[derive(Debug)]
pub struct MissingHeader {
    name: &'static str,
}

impl MissingHeader {
    /// Name of the header that was expected but not found.
    pub fn name(&self) -> &str {
        self.name
    }
}

impl fmt::Display for MissingHeader {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("Missing request header {:?}", self.name))
    }
}

impl StdError for MissingHeader {}
565
/// Rejection reason: a request header was present but invalid.
///
/// Its `Display` text (e.g. `Invalid request header "content-type"`) is used as the body
/// of the `400 Bad Request` response.
#[derive(Debug)]
pub struct InvalidHeader {
    name: &'static str,
}

impl InvalidHeader {
    /// Name of the header that was rejected.
    pub fn name(&self) -> &str {
        self.name
    }
}

impl fmt::Display for InvalidHeader {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("Invalid request header {:?}", self.name))
    }
}

impl StdError for InvalidHeader {}
586
/// Rejection reason: a required cookie was absent from the request.
///
/// Its `Display` text (e.g. `Missing request cookie "session"`) is used as the body of
/// the `400 Bad Request` response.
#[derive(Debug)]
pub struct MissingCookie {
    name: &'static str,
}

impl MissingCookie {
    /// Name of the cookie that was expected but not found.
    pub fn name(&self) -> &str {
        self.name
    }
}

impl fmt::Display for MissingCookie {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("Missing request cookie {:?}", self.name))
    }
}

impl StdError for MissingCookie {}
607
608mod sealed {
609 use super::{Reason, Rejection, Rejections};
610 use http::StatusCode;
611 use std::convert::Infallible;
612 use std::fmt;
613
614 pub trait IsReject: fmt::Debug + Send + Sync {
617 fn status(&self) -> StatusCode;
618 fn into_response(&self) -> crate::reply::Response;
619 }
620
621 fn _assert_object_safe() {
622 fn _assert(_: &dyn IsReject) {}
623 }
624
625 pub trait CombineRejection<E>: Send + Sized {
635 type One: IsReject + From<Self> + From<E> + Into<Rejection>;
649
650 type Combined: IsReject;
653
654 fn combine(self, other: E) -> Self::Combined;
655 }
656
657 impl CombineRejection<Rejection> for Rejection {
658 type One = Rejection;
659 type Combined = Rejection;
660
661 fn combine(self, other: Rejection) -> Self::Combined {
662 let reason = match (self.reason, other.reason) {
663 (Reason::Other(left), Reason::Other(right)) => {
664 Reason::Other(Box::new(Rejections::Combined(left, right)))
665 }
666 (Reason::Other(other), Reason::NotFound)
667 | (Reason::NotFound, Reason::Other(other)) => {
668 Reason::Other(other)
670 }
671 (Reason::NotFound, Reason::NotFound) => Reason::NotFound,
672 };
673
674 Rejection { reason }
675 }
676 }
677
678 impl CombineRejection<Infallible> for Rejection {
679 type One = Rejection;
680 type Combined = Infallible;
681
682 fn combine(self, other: Infallible) -> Self::Combined {
683 match other {}
684 }
685 }
686
687 impl CombineRejection<Rejection> for Infallible {
688 type One = Rejection;
689 type Combined = Infallible;
690
691 fn combine(self, _: Rejection) -> Self::Combined {
692 match self {}
693 }
694 }
695
696 impl CombineRejection<Infallible> for Infallible {
697 type One = Infallible;
698 type Combined = Infallible;
699
700 fn combine(self, _: Infallible) -> Self::Combined {
701 match self {}
702 }
703 }
704}
705
#[cfg(test)]
mod tests {
    use super::*;

    // Two distinguishable custom rejection causes.
    #[derive(Debug, PartialEq)]
    struct Left;

    #[derive(Debug, PartialEq)]
    struct Right;

    impl Reject for Left {}
    impl Reject for Right {}

    // Each built-in constructor maps to its expected status; custom rejections are 500.
    #[test]
    fn rejection_status() {
        assert_eq!(not_found().status(), StatusCode::NOT_FOUND);
        assert_eq!(
            method_not_allowed().status(),
            StatusCode::METHOD_NOT_ALLOWED
        );
        assert_eq!(length_required().status(), StatusCode::LENGTH_REQUIRED);
        assert_eq!(payload_too_large().status(), StatusCode::PAYLOAD_TOO_LARGE);
        assert_eq!(
            unsupported_media_type().status(),
            StatusCode::UNSUPPORTED_MEDIA_TYPE
        );
        assert_eq!(custom(Left).status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    // A not-found on the right is dropped; the custom rejection on the left is rendered.
    #[tokio::test]
    async fn combine_rejection_causes_with_some_left_and_none_right() {
        let left = custom(Left);
        let right = not_found();
        let reject = left.combine(right);
        let resp = reject.into_response();

        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            response_body_string(resp).await,
            "Unhandled rejection: Left"
        )
    }

    // Mirror image: a not-found on the left is dropped in favour of the custom rejection.
    #[tokio::test]
    async fn combine_rejection_causes_with_none_left_and_some_right() {
        let left = not_found();
        let right = custom(Right);
        let reject = left.combine(right);
        let resp = reject.into_response();

        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            response_body_string(resp).await,
            "Unhandled rejection: Right"
        );
    }

    #[tokio::test]
    async fn unhandled_customs() {
        let reject = not_found().combine(custom(Right));

        let resp = reject.into_response();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            response_body_string(resp).await,
            "Unhandled rejection: Right"
        );

        // Two customs tie on 500, so the left-hand (first) one is reported.
        let reject = custom(Left).combine(custom(Right));

        let resp = reject.into_response();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            response_body_string(resp).await,
            "Unhandled rejection: Left"
        );

        // Not-found rejections on either side of a single custom one never win.
        let reject = not_found()
            .combine(not_found())
            .combine(not_found())
            .combine(custom(Right))
            .combine(not_found());

        let resp = reject.into_response();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            response_body_string(resp).await,
            "Unhandled rejection: Right"
        );
    }

    // Collects a response body into a (lossily decoded) UTF-8 string.
    async fn response_body_string(resp: crate::reply::Response) -> String {
        let (_, body) = resp.into_parts();
        let body_bytes = hyper::body::to_bytes(body).await.expect("failed concat");
        String::from_utf8_lossy(&body_bytes).to_string()
    }

    // `find` reaches both custom causes and built-in rejection types inside a combined tree.
    #[test]
    fn find_cause() {
        let rej = custom(Left);

        assert_eq!(rej.find::<Left>(), Some(&Left));

        let rej = rej.combine(method_not_allowed());

        assert_eq!(rej.find::<Left>(), Some(&Left));
        assert!(rej.find::<MethodNotAllowed>().is_some(), "MethodNotAllowed");
    }

    // `Rejection` must stay pointer-sized.
    #[test]
    fn size_of_rejection() {
        assert_eq!(
            ::std::mem::size_of::<Rejection>(),
            ::std::mem::size_of::<usize>(),
        );
    }

    #[derive(Debug)]
    struct X(#[allow(unused)] u32);
    impl Reject for X {}

    // Starts from not-found and folds in `n` custom rejections built by `new_reject`.
    fn combine_n<F, R>(n: u32, new_reject: F) -> Rejection
    where
        F: Fn(u32) -> R,
        R: Reject,
    {
        let mut rej = not_found();

        for i in 0..n {
            rej = rej.combine(custom(new_reject(i)));
        }

        rej
    }

    // A combined tree is printed as one flat list; the initial not-found is gone.
    #[test]
    fn test_debug() {
        let rej = combine_n(3, X);

        let s = format!("{:?}", rej);
        assert_eq!(s, "Rejection([X(0), X(1), X(2)])");
    }

    // A 50-deep tree nested on the right: every 405 loses to the single custom 500 at the
    // bottom, which exercises the recursion in `preferred`.
    #[test]
    fn convert_big_rejections_into_response() {
        let mut rejections = Rejections::Custom(Box::new(std::io::Error::from_raw_os_error(100)));
        for _ in 0..50 {
            rejections = Rejections::Combined(
                Box::new(Rejections::Known(Known::MethodNotAllowed(
                    MethodNotAllowed { _p: () },
                ))),
                Box::new(rejections),
            );
        }
        let reason = Reason::Other(Box::new(rejections));
        let rejection = Rejection { reason };
        assert_eq!(
            StatusCode::INTERNAL_SERVER_ERROR,
            rejection.into_response().status()
        );
    }
}
870}