actix_server/
socket.rs

1pub(crate) use std::net::{
2    SocketAddr as StdSocketAddr, TcpListener as StdTcpListener, ToSocketAddrs,
3};
4use std::{fmt, io};
5
6use actix_rt::net::TcpStream;
7pub(crate) use mio::net::TcpListener as MioTcpListener;
8use mio::{event::Source, Interest, Registry, Token};
9#[cfg(unix)]
10pub(crate) use {
11    mio::net::UnixListener as MioUnixListener, std::os::unix::net::UnixListener as StdUnixListener,
12};
13
14use crate::builder::MpTcp;
15
16pub(crate) enum MioListener {
17    Tcp(MioTcpListener),
18    #[cfg(unix)]
19    Uds(MioUnixListener),
20}
21
22impl MioListener {
23    pub(crate) fn local_addr(&self) -> SocketAddr {
24        match *self {
25            MioListener::Tcp(ref lst) => lst
26                .local_addr()
27                .map(SocketAddr::Tcp)
28                .unwrap_or(SocketAddr::Unknown),
29            #[cfg(unix)]
30            MioListener::Uds(ref lst) => lst
31                .local_addr()
32                .map(SocketAddr::Uds)
33                .unwrap_or(SocketAddr::Unknown),
34        }
35    }
36
37    pub(crate) fn accept(&self) -> io::Result<MioStream> {
38        match *self {
39            MioListener::Tcp(ref lst) => lst.accept().map(|(stream, _)| MioStream::Tcp(stream)),
40            #[cfg(unix)]
41            MioListener::Uds(ref lst) => lst.accept().map(|(stream, _)| MioStream::Uds(stream)),
42        }
43    }
44}
45
46impl Source for MioListener {
47    fn register(
48        &mut self,
49        registry: &Registry,
50        token: Token,
51        interests: Interest,
52    ) -> io::Result<()> {
53        match *self {
54            MioListener::Tcp(ref mut lst) => lst.register(registry, token, interests),
55            #[cfg(unix)]
56            MioListener::Uds(ref mut lst) => lst.register(registry, token, interests),
57        }
58    }
59
60    fn reregister(
61        &mut self,
62        registry: &Registry,
63        token: Token,
64        interests: Interest,
65    ) -> io::Result<()> {
66        match *self {
67            MioListener::Tcp(ref mut lst) => lst.reregister(registry, token, interests),
68            #[cfg(unix)]
69            MioListener::Uds(ref mut lst) => lst.reregister(registry, token, interests),
70        }
71    }
72
73    fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
74        match *self {
75            MioListener::Tcp(ref mut lst) => lst.deregister(registry),
76            #[cfg(unix)]
77            MioListener::Uds(ref mut lst) => {
78                let res = lst.deregister(registry);
79
80                // cleanup file path
81                if let Ok(addr) = lst.local_addr() {
82                    if let Some(path) = addr.as_pathname() {
83                        let _ = std::fs::remove_file(path);
84                    }
85                }
86                res
87            }
88        }
89    }
90}
91
92impl From<StdTcpListener> for MioListener {
93    fn from(lst: StdTcpListener) -> Self {
94        MioListener::Tcp(MioTcpListener::from_std(lst))
95    }
96}
97
98#[cfg(unix)]
99impl From<StdUnixListener> for MioListener {
100    fn from(lst: StdUnixListener) -> Self {
101        MioListener::Uds(MioUnixListener::from_std(lst))
102    }
103}
104
105impl fmt::Debug for MioListener {
106    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107        match *self {
108            MioListener::Tcp(ref lst) => write!(f, "{:?}", lst),
109            #[cfg(unix)]
110            MioListener::Uds(ref lst) => write!(f, "{:?}", lst),
111        }
112    }
113}
114
115impl fmt::Display for MioListener {
116    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117        match *self {
118            MioListener::Tcp(ref lst) => write!(f, "{:?}", lst),
119            #[cfg(unix)]
120            MioListener::Uds(ref lst) => write!(f, "{:?}", lst),
121        }
122    }
123}
124
/// Address of a listener: IP, Unix domain socket, or unknown.
///
/// `MioListener::local_addr` produces `Unknown` when the OS lookup fails.
pub(crate) enum SocketAddr {
    /// The address could not be determined.
    Unknown,

    /// An IP (TCP) socket address.
    Tcp(StdSocketAddr),

    /// A Unix domain socket address.
    #[cfg(unix)]
    Uds(std::os::unix::net::SocketAddr),
}
131
132impl fmt::Display for SocketAddr {
133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134        match *self {
135            Self::Unknown => write!(f, "Unknown SocketAddr"),
136            Self::Tcp(ref addr) => write!(f, "{}", addr),
137            #[cfg(unix)]
138            Self::Uds(ref addr) => write!(f, "{:?}", addr),
139        }
140    }
141}
142
143impl fmt::Debug for SocketAddr {
144    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145        match *self {
146            Self::Unknown => write!(f, "Unknown SocketAddr"),
147            Self::Tcp(ref addr) => write!(f, "{:?}", addr),
148            #[cfg(unix)]
149            Self::Uds(ref addr) => write!(f, "{:?}", addr),
150        }
151    }
152}
153
154#[derive(Debug)]
155pub enum MioStream {
156    Tcp(mio::net::TcpStream),
157    #[cfg(unix)]
158    Uds(mio::net::UnixStream),
159}
160
161/// Helper trait for converting a Mio stream into a Tokio stream.
162pub trait FromStream: Sized {
163    /// Creates stream from a `mio` stream.
164    fn from_mio(sock: MioStream) -> io::Result<Self>;
165}
166
#[cfg(windows)]
mod win_impl {
    use std::os::windows::io::{FromRawSocket, IntoRawSocket};

    use super::*;

    // TODO: going through the raw socket handle is a workaround; an efficient, direct
    // Mio -> Tokio stream conversion is still needed.
    impl FromStream for TcpStream {
        fn from_mio(sock: MioStream) -> io::Result<Self> {
            // `MioStream` has only the TCP variant on Windows, so this pattern is irrefutable.
            let MioStream::Tcp(mio_stream) = sock;

            let raw_socket = mio_stream.into_raw_socket();
            // SAFETY: `raw_socket` was just released from `mio_stream` by `into_raw_socket`,
            // so it is an open socket that nothing else owns; the std stream takes sole
            // ownership of it (an in-place Mio -> Tokio conversion).
            let std_stream = unsafe { std::net::TcpStream::from_raw_socket(raw_socket) };
            TcpStream::from_std(std_stream)
        }
    }
}
186
187#[cfg(unix)]
188mod unix_impl {
189    use std::os::unix::io::{FromRawFd, IntoRawFd};
190
191    use actix_rt::net::UnixStream;
192
193    use super::*;
194
195    // HACK: This is a workaround and we need an efficient way to convert between Mio and Tokio stream
196    impl FromStream for TcpStream {
197        fn from_mio(sock: MioStream) -> io::Result<Self> {
198            match sock {
199                MioStream::Tcp(mio) => {
200                    let raw = IntoRawFd::into_raw_fd(mio);
201                    // SAFETY: This is an in-place conversion from Mio stream to Tokio stream.
202                    TcpStream::from_std(unsafe { FromRawFd::from_raw_fd(raw) })
203                }
204                MioStream::Uds(_) => {
205                    panic!("Should not happen, bug in server impl");
206                }
207            }
208        }
209    }
210
211    // HACK: This is a workaround and we need an efficient way to convert between Mio and Tokio stream
212    impl FromStream for UnixStream {
213        fn from_mio(sock: MioStream) -> io::Result<Self> {
214            match sock {
215                MioStream::Tcp(_) => panic!("Should not happen, bug in server impl"),
216                MioStream::Uds(mio) => {
217                    let raw = IntoRawFd::into_raw_fd(mio);
218                    // SAFETY: This is an in-place conversion from Mio stream to Tokio stream.
219                    UnixStream::from_std(unsafe { FromRawFd::from_raw_fd(raw) })
220                }
221            }
222        }
223    }
224}
225
226pub(crate) fn create_mio_tcp_listener(
227    addr: StdSocketAddr,
228    backlog: u32,
229    mptcp: &MpTcp,
230) -> io::Result<MioTcpListener> {
231    use socket2::{Domain, Protocol, Socket, Type};
232
233    #[cfg(not(target_os = "linux"))]
234    let protocol = Protocol::TCP;
235    #[cfg(target_os = "linux")]
236    let protocol = if matches!(mptcp, MpTcp::Disabled) {
237        Protocol::TCP
238    } else {
239        Protocol::MPTCP
240    };
241
242    let socket = match Socket::new(Domain::for_address(addr), Type::STREAM, Some(protocol)) {
243        Ok(sock) => sock,
244
245        Err(err) if matches!(mptcp, MpTcp::TcpFallback) => {
246            tracing::warn!("binding socket as MPTCP failed: {err}");
247            tracing::warn!("falling back to TCP");
248            Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?
249        }
250
251        Err(err) => return Err(err),
252    };
253
254    socket.set_reuse_address(true)?;
255    socket.set_nonblocking(true)?;
256    socket.bind(&addr.into())?;
257    socket.listen(backlog as i32)?;
258
259    Ok(MioTcpListener::from_std(StdTcpListener::from(socket)))
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn socket_addr() {
        let addr = SocketAddr::Tcp("127.0.0.1:8080".parse().unwrap());
        assert!(format!("{:?}", addr).contains("127.0.0.1:8080"));
        assert_eq!(format!("{}", addr), "127.0.0.1:8080");

        let addr: StdSocketAddr = "127.0.0.1:0".parse().unwrap();
        let lst = create_mio_tcp_listener(addr, 128, &MpTcp::Disabled).unwrap();
        let lst = MioListener::Tcp(lst);
        assert!(format!("{:?}", lst).contains("TcpListener"));
        assert!(format!("{}", lst).contains("127.0.0.1"));
    }

    #[test]
    #[cfg(unix)]
    fn uds() {
        // Per-process path, so concurrent test runs on the same host don't fight over one
        // socket file.
        let path = format!("/tmp/actix-server-sock-{}.sock", std::process::id());
        let _ = std::fs::remove_file(&path);

        // The assertions are skipped if binding fails.
        if let Ok(socket) = MioUnixListener::bind(&path) {
            let addr = socket.local_addr().expect("Couldn't get local address");
            let a = SocketAddr::Uds(addr);
            assert!(format!("{:?}", a).contains(path.as_str()));
            assert!(format!("{}", a).contains(path.as_str()));

            let lst = MioListener::Uds(socket);
            assert!(format!("{:?}", lst).contains(path.as_str()));
            assert!(format!("{}", lst).contains(path.as_str()));
        }

        // The listener is dropped without being deregistered, so nothing in this module
        // removes the socket file; clean it up explicitly.
        let _ = std::fs::remove_file(&path);
    }
}