Skip to main content

reqwest/
redirect.rs

1//! Redirect Handling
2//!
3//! By default, a `Client` will automatically handle HTTP redirects, having a
4//! maximum redirect chain of 10 hops. To customize this behavior, a
5//! `redirect::Policy` can be used with a `ClientBuilder`.
6
7use std::fmt;
8use std::{error::Error as StdError, sync::Arc};
9
10use crate::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, REFERER, WWW_AUTHENTICATE};
11use http::{HeaderMap, HeaderValue};
12use hyper::StatusCode;
13
14use crate::{async_impl, Url};
15use tower_http::follow_redirect::policy::{
16    Action as TowerAction, Attempt as TowerAttempt, Policy as TowerPolicy,
17};
18
19/// A type that controls the policy on how to handle the following of redirects.
20///
21/// The default value will catch redirect loops, and has a maximum of 10
22/// redirects it will follow in a chain before returning an error.
23///
24/// - `limited` can be used have the same as the default behavior, but adjust
25///   the allowed maximum redirect hops in a chain.
26/// - `none` can be used to disable all redirect behavior.
27/// - `custom` can be used to create a customized policy.
28pub struct Policy {
29    inner: PolicyKind,
30}
31
32/// A type that holds information on the next request and previous requests
33/// in redirect chain.
34#[derive(Debug)]
35pub struct Attempt<'a> {
36    status: StatusCode,
37    next: &'a Url,
38    previous: &'a [Url],
39}
40
41/// An action to perform when a redirect status code is found.
42#[derive(Debug)]
43pub struct Action {
44    inner: ActionKind,
45}
46
47impl Policy {
48    /// Create a `Policy` with a maximum number of redirects.
49    ///
50    /// An `Error` will be returned if the max is reached.
51    pub fn limited(max: usize) -> Self {
52        Self {
53            inner: PolicyKind::Limit(max),
54        }
55    }
56
57    /// Create a `Policy` that does not follow any redirect.
58    pub fn none() -> Self {
59        Self {
60            inner: PolicyKind::None,
61        }
62    }
63
64    /// Create a custom `Policy` using the passed function.
65    ///
66    /// # Note
67    ///
68    /// The default `Policy` handles a maximum loop
69    /// chain, but the custom variant does not do that for you automatically.
70    /// The custom policy should have some way of handling those.
71    ///
72    /// Information on the next request and previous requests can be found
73    /// on the [`Attempt`] argument passed to the closure.
74    ///
75    /// Actions can be conveniently created from methods on the
76    /// [`Attempt`].
77    ///
78    /// # Example
79    ///
80    /// ```rust
81    /// # use reqwest::{Error, redirect};
82    /// #
83    /// # fn run() -> Result<(), Error> {
84    /// let custom = redirect::Policy::custom(|attempt| {
85    ///     if attempt.previous().len() > 5 {
86    ///         attempt.error("too many redirects")
87    ///     } else if attempt.url().host_str() == Some("example.domain") {
88    ///         // prevent redirects to 'example.domain'
89    ///         attempt.stop()
90    ///     } else {
91    ///         attempt.follow()
92    ///     }
93    /// });
94    /// let client = reqwest::Client::builder()
95    ///     .redirect(custom)
96    ///     .build()?;
97    /// # Ok(())
98    /// # }
99    /// ```
100    ///
101    /// [`Attempt`]: struct.Attempt.html
102    pub fn custom<T>(policy: T) -> Self
103    where
104        T: Fn(Attempt) -> Action + Send + Sync + 'static,
105    {
106        Self {
107            inner: PolicyKind::Custom(Box::new(policy)),
108        }
109    }
110
111    /// Apply this policy to a given [`Attempt`] to produce a [`Action`].
112    ///
113    /// # Note
114    ///
115    /// This method can be used together with `Policy::custom()`
116    /// to construct one `Policy` that wraps another.
117    ///
118    /// # Example
119    ///
120    /// ```rust
121    /// # use reqwest::{Error, redirect};
122    /// #
123    /// # fn run() -> Result<(), Error> {
124    /// let custom = redirect::Policy::custom(|attempt| {
125    ///     eprintln!("{}, Location: {:?}", attempt.status(), attempt.url());
126    ///     redirect::Policy::default().redirect(attempt)
127    /// });
128    /// # Ok(())
129    /// # }
130    /// ```
131    pub fn redirect(&self, attempt: Attempt) -> Action {
132        match self.inner {
133            PolicyKind::Custom(ref custom) => custom(attempt),
134            PolicyKind::Limit(max) => {
135                // The first URL in the previous is the initial URL and not a redirection. It needs to be excluded.
136                if attempt.previous.len() > max {
137                    attempt.error(TooManyRedirects)
138                } else {
139                    attempt.follow()
140                }
141            }
142            PolicyKind::None => attempt.stop(),
143        }
144    }
145
146    pub(crate) fn check(&self, status: StatusCode, next: &Url, previous: &[Url]) -> ActionKind {
147        self.redirect(Attempt {
148            status,
149            next,
150            previous,
151        })
152        .inner
153    }
154
155    pub(crate) fn is_default(&self) -> bool {
156        matches!(self.inner, PolicyKind::Limit(10))
157    }
158}
159
160impl Default for Policy {
161    fn default() -> Policy {
162        // Keep `is_default` in sync
163        Policy::limited(10)
164    }
165}
166
167impl<'a> Attempt<'a> {
168    /// Get the type of redirect.
169    pub fn status(&self) -> StatusCode {
170        self.status
171    }
172
173    /// Get the next URL to redirect to.
174    pub fn url(&self) -> &Url {
175        self.next
176    }
177
178    /// Get the list of previous URLs that have already been requested in this chain.
179    pub fn previous(&self) -> &[Url] {
180        self.previous
181    }
182    /// Returns an action meaning reqwest should follow the next URL.
183    pub fn follow(self) -> Action {
184        Action {
185            inner: ActionKind::Follow,
186        }
187    }
188
189    /// Returns an action meaning reqwest should not follow the next URL.
190    ///
191    /// The 30x response will be returned as the `Ok` result.
192    pub fn stop(self) -> Action {
193        Action {
194            inner: ActionKind::Stop,
195        }
196    }
197
198    /// Returns an action failing the redirect with an error.
199    ///
200    /// The `Error` will be returned for the result of the sent request.
201    pub fn error<E: Into<Box<dyn StdError + Send + Sync>>>(self, error: E) -> Action {
202        Action {
203            inner: ActionKind::Error(error.into()),
204        }
205    }
206}
207
208enum PolicyKind {
209    Custom(Box<dyn Fn(Attempt) -> Action + Send + Sync + 'static>),
210    Limit(usize),
211    None,
212}
213
214impl fmt::Debug for Policy {
215    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
216        f.debug_tuple("Policy").field(&self.inner).finish()
217    }
218}
219
220impl fmt::Debug for PolicyKind {
221    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
222        match *self {
223            PolicyKind::Custom(..) => f.pad("Custom"),
224            PolicyKind::Limit(max) => f.debug_tuple("Limit").field(&max).finish(),
225            PolicyKind::None => f.pad("None"),
226        }
227    }
228}
229
230// pub(crate)
231
/// Crate-internal form of [`Action`], consumed by the redirect middleware.
#[derive(Debug)]
pub(crate) enum ActionKind {
    /// Follow the redirect to the next URL.
    Follow,
    /// Do not follow; the redirect response is handed back as-is.
    Stop,
    /// Abort the request with this error.
    Error(Box<dyn StdError + Send + Sync>),
}
238
239pub(crate) fn remove_sensitive_headers(headers: &mut HeaderMap, next: &Url, previous: &[Url]) {
240    if let Some(previous) = previous.last() {
241        let cross_host = next.host_str() != previous.host_str()
242            || next.port_or_known_default() != previous.port_or_known_default()
243            || next.scheme() != previous.scheme();
244        if cross_host {
245            headers.remove(AUTHORIZATION);
246            headers.remove(COOKIE);
247            headers.remove("cookie2");
248            headers.remove(PROXY_AUTHORIZATION);
249            headers.remove(WWW_AUTHENTICATE);
250        }
251    }
252}
253
/// Error produced by the limited redirect policy when a chain exceeds its
/// allowed number of redirects.
#[derive(Debug)]
struct TooManyRedirects;

impl fmt::Display for TooManyRedirects {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "too many redirects")
    }
}

// Marker impl: no source error, the default `Error` methods suffice.
impl StdError for TooManyRedirects {}
264
265#[derive(Clone)]
266pub(crate) struct TowerRedirectPolicy {
267    policy: Arc<Policy>,
268    referer: bool,
269    urls: Vec<Url>,
270    https_only: bool,
271}
272
273impl TowerRedirectPolicy {
274    pub(crate) fn new(policy: Policy) -> Self {
275        Self {
276            policy: Arc::new(policy),
277            referer: false,
278            urls: Vec::new(),
279            https_only: false,
280        }
281    }
282
283    pub(crate) fn with_referer(&mut self, referer: bool) -> &mut Self {
284        self.referer = referer;
285        self
286    }
287
288    pub(crate) fn with_https_only(&mut self, https_only: bool) -> &mut Self {
289        self.https_only = https_only;
290        self
291    }
292}
293
294fn make_referer(next: &Url, previous: &Url) -> Option<HeaderValue> {
295    if next.scheme() == "http" && previous.scheme() == "https" {
296        return None;
297    }
298
299    let mut referer = previous.clone();
300    let _ = referer.set_username("");
301    let _ = referer.set_password(None);
302    referer.set_fragment(None);
303    referer.as_str().parse().ok()
304}
305
306impl TowerPolicy<async_impl::body::Body, crate::Error> for TowerRedirectPolicy {
307    fn redirect(&mut self, attempt: &TowerAttempt<'_>) -> Result<TowerAction, crate::Error> {
308        let previous_url =
309            Url::parse(&attempt.previous().to_string()).expect("Previous URL must be valid");
310
311        let next_url = match Url::parse(&attempt.location().to_string()) {
312            Ok(url) => url,
313            Err(e) => return Err(crate::error::builder(e)),
314        };
315
316        self.urls.push(previous_url.clone());
317
318        match self.policy.check(attempt.status(), &next_url, &self.urls) {
319            ActionKind::Follow => {
320                if next_url.scheme() != "http" && next_url.scheme() != "https" {
321                    return Err(crate::error::url_bad_scheme(next_url));
322                }
323
324                if self.https_only && next_url.scheme() != "https" {
325                    return Err(crate::error::redirect(
326                        crate::error::url_bad_scheme(next_url.clone()),
327                        next_url,
328                    ));
329                }
330                Ok(TowerAction::Follow)
331            }
332            ActionKind::Stop => Ok(TowerAction::Stop),
333            ActionKind::Error(e) => Err(crate::error::redirect(e, previous_url)),
334        }
335    }
336
337    fn on_request(&mut self, req: &mut http::Request<async_impl::body::Body>) {
338        if let Ok(next_url) = Url::parse(&req.uri().to_string()) {
339            remove_sensitive_headers(req.headers_mut(), &next_url, &self.urls);
340            if self.referer {
341                if let Some(previous_url) = self.urls.last() {
342                    if let Some(v) = make_referer(&next_url, previous_url) {
343                        req.headers_mut().insert(REFERER, v);
344                    }
345                }
346            }
347        };
348    }
349
350    // This must be implemented to make 307 and 308 redirects work
351    fn clone_body(&self, body: &async_impl::body::Body) -> Option<async_impl::body::Body> {
352        body.try_clone()
353    }
354}
355
#[test]
fn test_redirect_policy_limit() {
    let policy = Policy::default();
    let next = Url::parse("http://x.y/z").unwrap();
    let mut previous: Vec<Url> = (0..10)
        .map(|i| format!("http://a.b/c/{}", i))
        .map(|raw| Url::parse(&raw).unwrap())
        .collect();

    // Ten entries (initial URL + nine redirects): the tenth redirect is allowed.
    let within_limit = policy.check(StatusCode::FOUND, &next, &previous);
    assert!(
        matches!(within_limit, ActionKind::Follow),
        "unexpected {:?}",
        within_limit
    );

    // One more entry pushes the chain past the default limit of ten.
    previous.push(Url::parse("http://a.b.d/e/33").unwrap());
    let over_limit = policy.check(StatusCode::FOUND, &next, &previous);
    assert!(
        matches!(&over_limit, ActionKind::Error(err) if err.is::<TooManyRedirects>()),
        "unexpected {:?}",
        over_limit
    );
}
376
#[test]
fn test_redirect_policy_limit_to_0() {
    // With a limit of zero, even the first redirect (only the initial URL in
    // `previous`) must be rejected.
    let policy = Policy::limited(0);
    let initial = Url::parse("http://a.b/c").unwrap();
    let next = Url::parse("http://x.y/z").unwrap();

    let action = policy.check(StatusCode::FOUND, &next, std::slice::from_ref(&initial));
    assert!(
        matches!(&action, ActionKind::Error(err) if err.is::<TooManyRedirects>()),
        "unexpected {:?}",
        action
    );
}
388
#[test]
fn test_redirect_policy_custom() {
    // Stop on redirects to host "foo", follow everything else.
    let policy = Policy::custom(|attempt| {
        let targets_foo = attempt.url().host_str() == Some("foo");
        if targets_foo {
            attempt.stop()
        } else {
            attempt.follow()
        }
    });

    let elsewhere = Url::parse("http://bar/baz").unwrap();
    let action = policy.check(StatusCode::FOUND, &elsewhere, &[]);
    assert!(matches!(action, ActionKind::Follow), "unexpected {:?}", action);

    let to_foo = Url::parse("http://foo/baz").unwrap();
    let action = policy.check(StatusCode::FOUND, &to_foo, &[]);
    assert!(matches!(action, ActionKind::Stop), "unexpected {:?}", action);
}
411
#[test]
fn test_remove_sensitive_headers() {
    use hyper::header::{HeaderValue, ACCEPT, AUTHORIZATION, COOKIE};

    let mut actual = HeaderMap::new();
    actual.insert(ACCEPT, HeaderValue::from_static("*/*"));
    actual.insert(AUTHORIZATION, HeaderValue::from_static("let me in"));
    actual.insert(COOKIE, HeaderValue::from_static("foo=bar"));
    let mut expected = actual.clone();

    let next = Url::parse("http://initial-domain.com/path").unwrap();

    // Same host, port and scheme: nothing is stripped.
    let mut chain = vec![Url::parse("http://initial-domain.com/new_path").unwrap()];
    remove_sensitive_headers(&mut actual, &next, &chain);
    assert_eq!(actual, expected);

    // The last hop is on another host: credentials go, ACCEPT stays.
    chain.push(Url::parse("http://new-domain.com/path").unwrap());
    expected.remove(AUTHORIZATION);
    expected.remove(COOKIE);
    remove_sensitive_headers(&mut actual, &next, &chain);
    assert_eq!(actual, expected);
}
435
#[test]
fn test_remove_sensitive_headers_on_scheme_downgrade_same_host_port() {
    use hyper::header::{HeaderValue, ACCEPT, AUTHORIZATION, COOKIE};

    // https -> http on the same host and explicit port is still a scheme
    // change, so credentials must be dropped.
    let prev = vec![Url::parse("https://initial-domain.com:8443/new_path").unwrap()];
    let next = Url::parse("http://initial-domain.com:8443/path").unwrap();

    let mut headers = HeaderMap::new();
    headers.insert(ACCEPT, HeaderValue::from_static("*/*"));
    headers.insert(AUTHORIZATION, HeaderValue::from_static("let me in"));
    headers.insert(COOKIE, HeaderValue::from_static("foo=bar"));

    let mut expected = HeaderMap::new();
    expected.insert(ACCEPT, HeaderValue::from_static("*/*"));

    remove_sensitive_headers(&mut headers, &next, &prev);
    assert_eq!(headers, expected);
}