reqwest/dns/resolve.rs

1use hyper_util::client::legacy::connect::dns::Name as HyperName;
2use tower_service::Service;
3
4use std::collections::HashMap;
5use std::future::Future;
6use std::net::SocketAddr;
7use std::pin::Pin;
8use std::str::FromStr;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11
12use crate::error::BoxError;
13
/// Boxed, sendable iterator over the socket addresses produced by a resolver.
pub type Addrs = Box<dyn Iterator<Item = SocketAddr> + Send + 'static>;
16
17/// Alias for the `Future` type returned by a DNS resolver.
18pub type Resolving = Pin<Box<dyn Future<Output = Result<Addrs, BoxError>> + Send>>;
19
20/// Trait for customizing DNS resolution in reqwest.
21pub trait Resolve: Send + Sync {
22    /// Performs DNS resolution on a `Name`.
23    /// The return type is a future containing an iterator of `SocketAddr`.
24    ///
25    /// It differs from `tower_service::Service<Name>` in several ways:
26    ///  * It is assumed that `resolve` will always be ready to poll.
27    ///  * It does not need a mutable reference to `self`.
28    ///  * Since trait objects cannot make use of associated types, it requires
29    ///    wrapping the returned `Future` and its contained `Iterator` with `Box`.
30    ///
31    /// Explicitly specified port in the URL will override any port in the resolved `SocketAddr`s.
32    /// Otherwise, port `0` will be replaced by the conventional port for the given scheme (e.g. 80 for http).
33    fn resolve(&self, name: Name) -> Resolving;
34}
35
36/// A name that must be resolved to addresses.
37#[derive(Debug)]
38pub struct Name(pub(super) HyperName);
39
40/// A more general trait implemented for types implementing `Resolve`.
41///
42/// Unnameable, only exported to aid seeing what implements this.
43pub trait IntoResolve {
44    #[doc(hidden)]
45    fn into_resolve(self) -> Arc<dyn Resolve>;
46}
47
48impl Name {
49    /// View the name as a string.
50    pub fn as_str(&self) -> &str {
51        self.0.as_str()
52    }
53}
54
55impl FromStr for Name {
56    type Err = sealed::InvalidNameError;
57
58    fn from_str(host: &str) -> Result<Self, Self::Err> {
59        HyperName::from_str(host)
60            .map(Name)
61            .map_err(|_| sealed::InvalidNameError { _ext: () })
62    }
63}
64
65#[derive(Clone)]
66pub(crate) struct DynResolver {
67    resolver: Arc<dyn Resolve>,
68}
69
70impl DynResolver {
71    pub(crate) fn new(resolver: Arc<dyn Resolve>) -> Self {
72        Self { resolver }
73    }
74
75    #[cfg(feature = "socks")]
76    pub(crate) fn gai() -> Self {
77        Self::new(Arc::new(super::gai::GaiResolver::new()))
78    }
79
80    /// Resolve an HTTP host and port, not just a domain name.
81    ///
82    /// This does the same thing that hyper-util's HttpConnector does, before
83    /// calling out to its underlying DNS resolver.
84    #[cfg(feature = "socks")]
85    pub(crate) async fn http_resolve(
86        &self,
87        target: &http::Uri,
88    ) -> Result<impl Iterator<Item = std::net::SocketAddr>, BoxError> {
89        let host = target.host().ok_or("missing host")?;
90        let port = target
91            .port_u16()
92            .unwrap_or_else(|| match target.scheme_str() {
93                Some("https") => 443,
94                Some("socks4") | Some("socks4a") | Some("socks5") | Some("socks5h") => 1080,
95                _ => 80,
96            });
97
98        let explicit_port = target.port().is_some();
99
100        let addrs = self.resolver.resolve(host.parse()?).await?;
101
102        Ok(addrs.map(move |mut addr| {
103            if explicit_port || addr.port() == 0 {
104                addr.set_port(port);
105            }
106            addr
107        }))
108    }
109}
110
111impl Service<HyperName> for DynResolver {
112    type Response = Addrs;
113    type Error = BoxError;
114    type Future = Resolving;
115
116    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
117        Poll::Ready(Ok(()))
118    }
119
120    fn call(&mut self, name: HyperName) -> Self::Future {
121        self.resolver.resolve(Name(name))
122    }
123}
124
125pub(crate) struct DnsResolverWithOverrides {
126    dns_resolver: Arc<dyn Resolve>,
127    overrides: Arc<HashMap<String, Vec<SocketAddr>>>,
128}
129
130impl DnsResolverWithOverrides {
131    pub(crate) fn new(
132        dns_resolver: Arc<dyn Resolve>,
133        overrides: HashMap<String, Vec<SocketAddr>>,
134    ) -> Self {
135        DnsResolverWithOverrides {
136            dns_resolver,
137            overrides: Arc::new(overrides),
138        }
139    }
140}
141
142impl Resolve for DnsResolverWithOverrides {
143    fn resolve(&self, name: Name) -> Resolving {
144        match self.overrides.get(name.as_str()) {
145            Some(dest) => {
146                let addrs: Addrs = Box::new(dest.clone().into_iter());
147                Box::pin(std::future::ready(Ok(addrs)))
148            }
149            None => self.dns_resolver.resolve(name),
150        }
151    }
152}
153
154impl IntoResolve for Arc<dyn Resolve> {
155    fn into_resolve(self) -> Arc<dyn Resolve> {
156        self
157    }
158}
159
160impl<R> IntoResolve for Arc<R>
161where
162    R: Resolve + 'static,
163{
164    fn into_resolve(self) -> Arc<dyn Resolve> {
165        self
166    }
167}
168
169impl<R> IntoResolve for R
170where
171    R: Resolve + 'static,
172{
173    fn into_resolve(self) -> Arc<dyn Resolve> {
174        Arc::new(self)
175    }
176}
177
mod sealed {
    use std::fmt;

    /// Error returned when a string cannot be parsed into a DNS `Name`.
    ///
    /// The `_ext` field is only visible to the parent module, so the error
    /// cannot be constructed (or destructured) anywhere else.
    #[derive(Debug)]
    pub struct InvalidNameError {
        pub(super) _ext: (),
    }

    impl std::error::Error for InvalidNameError {}

    impl fmt::Display for InvalidNameError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "invalid DNS name")
        }
    }
}