1use std::fmt;
8use std::{error::Error as StdError, sync::Arc};
9
10use crate::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, REFERER, WWW_AUTHENTICATE};
11use http::{HeaderMap, HeaderValue};
12use hyper::StatusCode;
13
14use crate::{async_impl, Url};
15use tower_http::follow_redirect::policy::{
16 Action as TowerAction, Attempt as TowerAttempt, Policy as TowerPolicy,
17};
18
19pub struct Policy {
29 inner: PolicyKind,
30}
31
32#[derive(Debug)]
35pub struct Attempt<'a> {
36 status: StatusCode,
37 next: &'a Url,
38 previous: &'a [Url],
39}
40
41#[derive(Debug)]
43pub struct Action {
44 inner: ActionKind,
45}
46
47impl Policy {
48 pub fn limited(max: usize) -> Self {
52 Self {
53 inner: PolicyKind::Limit(max),
54 }
55 }
56
57 pub fn none() -> Self {
59 Self {
60 inner: PolicyKind::None,
61 }
62 }
63
64 pub fn custom<T>(policy: T) -> Self
103 where
104 T: Fn(Attempt) -> Action + Send + Sync + 'static,
105 {
106 Self {
107 inner: PolicyKind::Custom(Box::new(policy)),
108 }
109 }
110
111 pub fn redirect(&self, attempt: Attempt) -> Action {
132 match self.inner {
133 PolicyKind::Custom(ref custom) => custom(attempt),
134 PolicyKind::Limit(max) => {
135 if attempt.previous.len() > max {
137 attempt.error(TooManyRedirects)
138 } else {
139 attempt.follow()
140 }
141 }
142 PolicyKind::None => attempt.stop(),
143 }
144 }
145
146 pub(crate) fn check(&self, status: StatusCode, next: &Url, previous: &[Url]) -> ActionKind {
147 self.redirect(Attempt {
148 status,
149 next,
150 previous,
151 })
152 .inner
153 }
154
155 pub(crate) fn is_default(&self) -> bool {
156 matches!(self.inner, PolicyKind::Limit(10))
157 }
158}
159
160impl Default for Policy {
161 fn default() -> Policy {
162 Policy::limited(10)
164 }
165}
166
167impl<'a> Attempt<'a> {
168 pub fn status(&self) -> StatusCode {
170 self.status
171 }
172
173 pub fn url(&self) -> &Url {
175 self.next
176 }
177
178 pub fn previous(&self) -> &[Url] {
180 self.previous
181 }
182 pub fn follow(self) -> Action {
184 Action {
185 inner: ActionKind::Follow,
186 }
187 }
188
189 pub fn stop(self) -> Action {
193 Action {
194 inner: ActionKind::Stop,
195 }
196 }
197
198 pub fn error<E: Into<Box<dyn StdError + Send + Sync>>>(self, error: E) -> Action {
202 Action {
203 inner: ActionKind::Error(error.into()),
204 }
205 }
206}
207
208enum PolicyKind {
209 Custom(Box<dyn Fn(Attempt) -> Action + Send + Sync + 'static>),
210 Limit(usize),
211 None,
212}
213
214impl fmt::Debug for Policy {
215 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
216 f.debug_tuple("Policy").field(&self.inner).finish()
217 }
218}
219
220impl fmt::Debug for PolicyKind {
221 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
222 match *self {
223 PolicyKind::Custom(..) => f.pad("Custom"),
224 PolicyKind::Limit(max) => f.debug_tuple("Limit").field(&max).finish(),
225 PolicyKind::None => f.pad("None"),
226 }
227 }
228}
229
/// Crate-internal form of a redirect decision.
#[derive(Debug)]
pub(crate) enum ActionKind {
    /// Follow the redirect.
    Follow,
    /// Do not follow the redirect.
    Stop,
    /// Abort the redirect chain with this error.
    Error(Box<dyn StdError + Send + Sync>),
}
238
239pub(crate) fn remove_sensitive_headers(headers: &mut HeaderMap, next: &Url, previous: &[Url]) {
240 if let Some(previous) = previous.last() {
241 let cross_host = next.host_str() != previous.host_str()
242 || next.port_or_known_default() != previous.port_or_known_default()
243 || next.scheme() != previous.scheme();
244 if cross_host {
245 headers.remove(AUTHORIZATION);
246 headers.remove(COOKIE);
247 headers.remove("cookie2");
248 headers.remove(PROXY_AUTHORIZATION);
249 headers.remove(WWW_AUTHENTICATE);
250 }
251 }
252}
253
/// Error returned by a limited [`Policy`] once the redirect chain exceeds its maximum.
#[derive(Debug)]
struct TooManyRedirects;

impl fmt::Display for TooManyRedirects {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "too many redirects")
    }
}

// Leaf error: there is no underlying source to expose.
impl StdError for TooManyRedirects {}
264
265#[derive(Clone)]
266pub(crate) struct TowerRedirectPolicy {
267 policy: Arc<Policy>,
268 referer: bool,
269 urls: Vec<Url>,
270 https_only: bool,
271}
272
273impl TowerRedirectPolicy {
274 pub(crate) fn new(policy: Policy) -> Self {
275 Self {
276 policy: Arc::new(policy),
277 referer: false,
278 urls: Vec::new(),
279 https_only: false,
280 }
281 }
282
283 pub(crate) fn with_referer(&mut self, referer: bool) -> &mut Self {
284 self.referer = referer;
285 self
286 }
287
288 pub(crate) fn with_https_only(&mut self, https_only: bool) -> &mut Self {
289 self.https_only = https_only;
290 self
291 }
292}
293
294fn make_referer(next: &Url, previous: &Url) -> Option<HeaderValue> {
295 if next.scheme() == "http" && previous.scheme() == "https" {
296 return None;
297 }
298
299 let mut referer = previous.clone();
300 let _ = referer.set_username("");
301 let _ = referer.set_password(None);
302 referer.set_fragment(None);
303 referer.as_str().parse().ok()
304}
305
306impl TowerPolicy<async_impl::body::Body, crate::Error> for TowerRedirectPolicy {
307 fn redirect(&mut self, attempt: &TowerAttempt<'_>) -> Result<TowerAction, crate::Error> {
308 let previous_url =
309 Url::parse(&attempt.previous().to_string()).expect("Previous URL must be valid");
310
311 let next_url = match Url::parse(&attempt.location().to_string()) {
312 Ok(url) => url,
313 Err(e) => return Err(crate::error::builder(e)),
314 };
315
316 self.urls.push(previous_url.clone());
317
318 match self.policy.check(attempt.status(), &next_url, &self.urls) {
319 ActionKind::Follow => {
320 if next_url.scheme() != "http" && next_url.scheme() != "https" {
321 return Err(crate::error::url_bad_scheme(next_url));
322 }
323
324 if self.https_only && next_url.scheme() != "https" {
325 return Err(crate::error::redirect(
326 crate::error::url_bad_scheme(next_url.clone()),
327 next_url,
328 ));
329 }
330 Ok(TowerAction::Follow)
331 }
332 ActionKind::Stop => Ok(TowerAction::Stop),
333 ActionKind::Error(e) => Err(crate::error::redirect(e, previous_url)),
334 }
335 }
336
337 fn on_request(&mut self, req: &mut http::Request<async_impl::body::Body>) {
338 if let Ok(next_url) = Url::parse(&req.uri().to_string()) {
339 remove_sensitive_headers(req.headers_mut(), &next_url, &self.urls);
340 if self.referer {
341 if let Some(previous_url) = self.urls.last() {
342 if let Some(v) = make_referer(&next_url, previous_url) {
343 req.headers_mut().insert(REFERER, v);
344 }
345 }
346 }
347 };
348 }
349
350 fn clone_body(&self, body: &async_impl::body::Body) -> Option<async_impl::body::Body> {
352 body.try_clone()
353 }
354}
355
#[test]
fn test_redirect_policy_limit() {
    let policy = Policy::default();
    let next = Url::parse("http://x.y/z").unwrap();

    let mut previous: Vec<Url> = Vec::new();
    for i in 0..10 {
        previous.push(Url::parse(&format!("http://a.b/c/{i}")).unwrap());
    }

    // Ten recorded URLs are still within the default limit.
    let action = policy.check(StatusCode::FOUND, &next, &previous);
    assert!(matches!(&action, ActionKind::Follow), "unexpected {action:?}");

    // One more pushes the chain over the limit.
    previous.push(Url::parse("http://a.b.d/e/33").unwrap());
    let action = policy.check(StatusCode::FOUND, &next, &previous);
    assert!(
        matches!(&action, ActionKind::Error(err) if err.is::<TooManyRedirects>()),
        "unexpected {action:?}"
    );
}
376
#[test]
fn test_redirect_policy_limit_to_0() {
    // The chain already holds the initial request URL, so a zero limit rejects the
    // very first redirect.
    let policy = Policy::limited(0);
    let next = Url::parse("http://x.y/z").unwrap();
    let previous = [Url::parse("http://a.b/c").unwrap()];

    let action = policy.check(StatusCode::FOUND, &next, &previous);
    assert!(
        matches!(&action, ActionKind::Error(err) if err.is::<TooManyRedirects>()),
        "unexpected {action:?}"
    );
}
388
#[test]
fn test_redirect_policy_custom() {
    // Stop only when the target host is exactly "foo".
    let policy = Policy::custom(|attempt| {
        let targets_foo = attempt.url().host_str() == Some("foo");
        if targets_foo {
            attempt.stop()
        } else {
            attempt.follow()
        }
    });

    let bar = Url::parse("http://bar/baz").unwrap();
    let action = policy.check(StatusCode::FOUND, &bar, &[]);
    assert!(matches!(&action, ActionKind::Follow), "unexpected {action:?}");

    let foo = Url::parse("http://foo/baz").unwrap();
    let action = policy.check(StatusCode::FOUND, &foo, &[]);
    assert!(matches!(&action, ActionKind::Stop), "unexpected {action:?}");
}
411
#[test]
fn test_remove_sensitive_headers() {
    use hyper::header::{HeaderValue, ACCEPT, AUTHORIZATION, COOKIE};

    let mut headers = HeaderMap::new();
    headers.insert(ACCEPT, HeaderValue::from_static("*/*"));
    headers.insert(AUTHORIZATION, HeaderValue::from_static("let me in"));
    headers.insert(COOKIE, HeaderValue::from_static("foo=bar"));
    let mut expected = headers.clone();

    let next = Url::parse("http://initial-domain.com/path").unwrap();
    let mut prev = vec![Url::parse("http://initial-domain.com/new_path").unwrap()];

    // Same host, port and scheme: nothing is stripped.
    remove_sensitive_headers(&mut headers, &next, &prev);
    assert_eq!(headers, expected);

    // The most recent hop came from another host, so credentials and cookies go.
    prev.push(Url::parse("http://new-domain.com/path").unwrap());
    expected.remove(AUTHORIZATION);
    expected.remove(COOKIE);
    remove_sensitive_headers(&mut headers, &next, &prev);
    assert_eq!(headers, expected);
}
435
#[test]
fn test_remove_sensitive_headers_on_scheme_downgrade_same_host_port() {
    use hyper::header::{HeaderValue, ACCEPT, AUTHORIZATION, COOKIE};

    let mut headers = HeaderMap::new();
    headers.insert(ACCEPT, HeaderValue::from_static("*/*"));
    headers.insert(AUTHORIZATION, HeaderValue::from_static("let me in"));
    headers.insert(COOKIE, HeaderValue::from_static("foo=bar"));

    // Only ACCEPT should survive an https -> http hop, even on the same host and port.
    let mut expected = headers.clone();
    expected.remove(AUTHORIZATION);
    expected.remove(COOKIE);

    let prev = vec![Url::parse("https://initial-domain.com:8443/new_path").unwrap()];
    let next = Url::parse("http://initial-domain.com:8443/path").unwrap();

    remove_sensitive_headers(&mut headers, &next, &prev);
    assert_eq!(headers, expected);
}