1use hyper_util::client::legacy::connect::dns::Name as HyperName;
2use tower_service::Service;
3
4use std::collections::HashMap;
5use std::future::Future;
6use std::net::SocketAddr;
7use std::pin::Pin;
8use std::str::FromStr;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11
12use crate::error::BoxError;
13
/// Boxed, sendable iterator over the socket addresses produced by a resolver.
pub type Addrs = Box<dyn Iterator<Item = SocketAddr> + Send>;

/// Boxed, sendable future returned by [`Resolve::resolve`]; it yields either the
/// resolved [`Addrs`] or a boxed error.
pub type Resolving = Pin<Box<dyn Future<Output = Result<Addrs, BoxError>> + Send>>;
19
/// Pluggable hostname resolution.
///
/// An implementor receives a [`Name`] and returns a [`Resolving`] future that yields the
/// socket addresses for that name. The `Send + Sync` supertraits let one resolver be
/// shared as `Arc<dyn Resolve>` (see [`IntoResolve`]).
pub trait Resolve: Send + Sync {
    /// Starts resolving `name`.
    ///
    /// The returned future completes with an iterator of socket addresses or an error.
    fn resolve(&self, name: Name) -> Resolving;
}
35
/// A hostname handed to a [`Resolve`] implementation.
///
/// Thin wrapper over `hyper_util`'s DNS `Name`; the inner value is only reachable from
/// this module's parent (`pub(super)`).
#[derive(Debug)]
pub struct Name(pub(super) HyperName);
39
/// Conversion into a shared, type-erased resolver (`Arc<dyn Resolve>`).
///
/// Implemented in this module for `Arc<dyn Resolve>`, for `Arc<R>`, and for any owned
/// `R: Resolve + 'static`.
pub trait IntoResolve {
    /// Converts `self` into an `Arc<dyn Resolve>`; hidden from rustdoc.
    #[doc(hidden)]
    fn into_resolve(self) -> Arc<dyn Resolve>;
}
47
48impl Name {
49 pub fn as_str(&self) -> &str {
51 self.0.as_str()
52 }
53}
54
55impl FromStr for Name {
56 type Err = sealed::InvalidNameError;
57
58 fn from_str(host: &str) -> Result<Self, Self::Err> {
59 HyperName::from_str(host)
60 .map(Name)
61 .map_err(|_| sealed::InvalidNameError { _ext: () })
62 }
63}
64
/// Type-erased resolver handle.
///
/// Cloning only bumps the reference count of the shared `Arc<dyn Resolve>`.
#[derive(Clone)]
pub(crate) struct DynResolver {
    resolver: Arc<dyn Resolve>,
}
69
70impl DynResolver {
71 pub(crate) fn new(resolver: Arc<dyn Resolve>) -> Self {
72 Self { resolver }
73 }
74
75 #[cfg(feature = "socks")]
76 pub(crate) fn gai() -> Self {
77 Self::new(Arc::new(super::gai::GaiResolver::new()))
78 }
79
80 #[cfg(feature = "socks")]
85 pub(crate) async fn http_resolve(
86 &self,
87 target: &http::Uri,
88 ) -> Result<impl Iterator<Item = std::net::SocketAddr>, BoxError> {
89 let host = target.host().ok_or("missing host")?;
90 let port = target
91 .port_u16()
92 .unwrap_or_else(|| match target.scheme_str() {
93 Some("https") => 443,
94 Some("socks4") | Some("socks4a") | Some("socks5") | Some("socks5h") => 1080,
95 _ => 80,
96 });
97
98 let explicit_port = target.port().is_some();
99
100 let addrs = self.resolver.resolve(host.parse()?).await?;
101
102 Ok(addrs.map(move |mut addr| {
103 if explicit_port || addr.port() == 0 {
104 addr.set_port(port);
105 }
106 addr
107 }))
108 }
109}
110
111impl Service<HyperName> for DynResolver {
112 type Response = Addrs;
113 type Error = BoxError;
114 type Future = Resolving;
115
116 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
117 Poll::Ready(Ok(()))
118 }
119
120 fn call(&mut self, name: HyperName) -> Self::Future {
121 self.resolver.resolve(Name(name))
122 }
123}
124
/// Resolver that consults a fixed hostname-to-addresses table before delegating.
pub(crate) struct DnsResolverWithOverrides {
    /// Resolver used for any name that has no entry in `overrides`.
    dns_resolver: Arc<dyn Resolve>,
    /// Hostname -> addresses table; keys are compared exactly against `Name::as_str()`.
    overrides: Arc<HashMap<String, Vec<SocketAddr>>>,
}
129
130impl DnsResolverWithOverrides {
131 pub(crate) fn new(
132 dns_resolver: Arc<dyn Resolve>,
133 overrides: HashMap<String, Vec<SocketAddr>>,
134 ) -> Self {
135 DnsResolverWithOverrides {
136 dns_resolver,
137 overrides: Arc::new(overrides),
138 }
139 }
140}
141
142impl Resolve for DnsResolverWithOverrides {
143 fn resolve(&self, name: Name) -> Resolving {
144 match self.overrides.get(name.as_str()) {
145 Some(dest) => {
146 let addrs: Addrs = Box::new(dest.clone().into_iter());
147 Box::pin(std::future::ready(Ok(addrs)))
148 }
149 None => self.dns_resolver.resolve(name),
150 }
151 }
152}
153
/// An already type-erased resolver is returned unchanged.
impl IntoResolve for Arc<dyn Resolve> {
    fn into_resolve(self) -> Arc<dyn Resolve> {
        self
    }
}
159
/// A shared concrete resolver is type-erased in place; the existing `Arc` is reused.
impl<R> IntoResolve for Arc<R>
where
    R: Resolve + 'static,
{
    fn into_resolve(self) -> Arc<dyn Resolve> {
        // Unsizing coercion `Arc<R>` -> `Arc<dyn Resolve>`; no new allocation.
        self
    }
}
168
/// An owned resolver is moved into a freshly allocated `Arc`.
impl<R> IntoResolve for R
where
    R: Resolve + 'static,
{
    fn into_resolve(self) -> Arc<dyn Resolve> {
        Arc::new(self)
    }
}
177
mod sealed {
    /// Error returned when a string cannot be parsed into a resolver `Name`.
    ///
    /// The `_ext` field is `pub(super)`, so only this module and its parent can
    /// construct the error.
    #[derive(Debug)]
    pub struct InvalidNameError {
        pub(super) _ext: (),
    }

    impl std::fmt::Display for InvalidNameError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "invalid DNS name")
        }
    }

    impl std::error::Error for InvalidNameError {}
}