// actix_http/h2/service.rs

1use std::{
2    future::Future,
3    marker::PhantomData,
4    mem, net,
5    pin::Pin,
6    rc::Rc,
7    task::{Context, Poll},
8};
9
10use actix_codec::{AsyncRead, AsyncWrite};
11use actix_rt::net::TcpStream;
12use actix_service::{
13    fn_factory, fn_service, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt as _,
14};
15use actix_utils::future::ready;
16use futures_core::{future::LocalBoxFuture, ready};
17use tracing::{error, trace};
18
19use super::{dispatcher::Dispatcher, handshake_with_timeout, HandshakeWithTimeout};
20use crate::{
21    body::{BoxBody, MessageBody},
22    config::ServiceConfig,
23    error::DispatchError,
24    service::HttpFlow,
25    ConnectCallback, OnConnectData, Request, Response,
26};
27
28/// `ServiceFactory` implementation for HTTP/2 transport
29pub struct H2Service<T, S, B> {
30    srv: S,
31    cfg: ServiceConfig,
32    on_connect_ext: Option<Rc<ConnectCallback<T>>>,
33    _phantom: PhantomData<(T, B)>,
34}
35
36impl<T, S, B> H2Service<T, S, B>
37where
38    S: ServiceFactory<Request, Config = ()>,
39    S::Error: Into<Response<BoxBody>> + 'static,
40    S::Response: Into<Response<B>> + 'static,
41    <S::Service as Service<Request>>::Future: 'static,
42
43    B: MessageBody + 'static,
44{
45    /// Create new `H2Service` instance with config.
46    pub(crate) fn with_config<F: IntoServiceFactory<S, Request>>(
47        cfg: ServiceConfig,
48        service: F,
49    ) -> Self {
50        H2Service {
51            cfg,
52            on_connect_ext: None,
53            srv: service.into_factory(),
54            _phantom: PhantomData,
55        }
56    }
57
58    /// Set on connect callback.
59    pub(crate) fn on_connect_ext(mut self, f: Option<Rc<ConnectCallback<T>>>) -> Self {
60        self.on_connect_ext = f;
61        self
62    }
63}
64
65impl<S, B> H2Service<TcpStream, S, B>
66where
67    S: ServiceFactory<Request, Config = ()>,
68    S::Future: 'static,
69    S::Error: Into<Response<BoxBody>> + 'static,
70    S::Response: Into<Response<B>> + 'static,
71    <S::Service as Service<Request>>::Future: 'static,
72
73    B: MessageBody + 'static,
74{
75    /// Create plain TCP based service
76    pub fn tcp(
77        self,
78    ) -> impl ServiceFactory<
79        TcpStream,
80        Config = (),
81        Response = (),
82        Error = DispatchError,
83        InitError = S::InitError,
84    > {
85        fn_factory(|| {
86            ready(Ok::<_, S::InitError>(fn_service(|io: TcpStream| {
87                let peer_addr = io.peer_addr().ok();
88                ready(Ok::<_, DispatchError>((io, peer_addr)))
89            })))
90        })
91        .and_then(self)
92    }
93}
94
#[cfg(feature = "openssl")]
mod openssl {
    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        openssl::{
            reexports::{Error as SslError, SslAcceptor},
            Acceptor, TlsStream,
        },
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Creates a factory that terminates TLS with OpenSSL and serves HTTP/2 over the
        /// resulting stream.
        ///
        /// `acceptor` is used unchanged. No ALPN setup happens here, so presumably the caller
        /// must advertise `h2` (TODO confirm).
        ///
        /// Acceptor errors are widened with `TlsError::into_service_error`. Errors from the
        /// HTTP/2 layer are wrapped in `TlsError::Service`.
        pub fn openssl(
            self,
            acceptor: SslAcceptor,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<SslError, DispatchError>,
            InitError = S::InitError,
        > {
            Acceptor::new(acceptor)
                .map_init_err(|_| unreachable!("TLS acceptor service factory does not error on init"))
                .map_err(TlsError::into_service_error)
                .map(|stream: TlsStream<TcpStream>| {
                    // The underlying TCP socket knows the peer address; failure maps to `None`.
                    let addr = stream.get_ref().peer_addr().ok();
                    (stream, addr)
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
142
#[cfg(feature = "rustls-0_20")]
mod rustls_0_20 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Creates a factory that terminates TLS with Rustls v0.20 and serves HTTP/2 over the
        /// resulting stream.
        ///
        /// `h2` is placed at the front of `config.alpn_protocols`, before any protocols the
        /// caller already listed. Errors from the HTTP/2 layer are wrapped in
        /// `TlsError::Service`.
        pub fn rustls(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            config.alpn_protocols.insert(0, b"h2".to_vec());

            Acceptor::new(config)
                .map_init_err(|_| unreachable!("TLS acceptor service factory does not error on init"))
                .map_err(TlsError::into_service_error)
                .map(|stream: TlsStream<TcpStream>| {
                    let (tcp, _) = stream.get_ref();
                    let addr = tcp.peer_addr().ok();
                    (stream, addr)
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
193
#[cfg(feature = "rustls-0_21")]
mod rustls_0_21 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls_0_21::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Creates a factory that terminates TLS with Rustls v0.21 and serves HTTP/2 over the
        /// resulting stream.
        ///
        /// `h2` is placed at the front of `config.alpn_protocols`, before any protocols the
        /// caller already listed. Errors from the HTTP/2 layer are wrapped in
        /// `TlsError::Service`.
        pub fn rustls_021(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            config.alpn_protocols.insert(0, b"h2".to_vec());

            Acceptor::new(config)
                .map_init_err(|_| unreachable!("TLS acceptor service factory does not error on init"))
                .map_err(TlsError::into_service_error)
                .map(|stream: TlsStream<TcpStream>| {
                    let (tcp, _) = stream.get_ref();
                    let addr = tcp.peer_addr().ok();
                    (stream, addr)
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
244
#[cfg(feature = "rustls-0_22")]
mod rustls_0_22 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls_0_22::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Creates a factory that terminates TLS with Rustls v0.22 and serves HTTP/2 over the
        /// resulting stream.
        ///
        /// `h2` is placed at the front of `config.alpn_protocols`, before any protocols the
        /// caller already listed. Errors from the HTTP/2 layer are wrapped in
        /// `TlsError::Service`.
        pub fn rustls_0_22(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            config.alpn_protocols.insert(0, b"h2".to_vec());

            Acceptor::new(config)
                .map_init_err(|_| unreachable!("TLS acceptor service factory does not error on init"))
                .map_err(TlsError::into_service_error)
                .map(|stream: TlsStream<TcpStream>| {
                    let (tcp, _) = stream.get_ref();
                    let addr = tcp.peer_addr().ok();
                    (stream, addr)
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
295
#[cfg(feature = "rustls-0_23")]
mod rustls_0_23 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls_0_23::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,
        B: MessageBody + 'static,
    {
        /// Creates a factory that terminates TLS with Rustls v0.23 and serves HTTP/2 over the
        /// resulting stream.
        ///
        /// `h2` is placed at the front of `config.alpn_protocols`, before any protocols the
        /// caller already listed. Errors from the HTTP/2 layer are wrapped in
        /// `TlsError::Service`.
        pub fn rustls_0_23(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            config.alpn_protocols.insert(0, b"h2".to_vec());

            Acceptor::new(config)
                .map_init_err(|_| unreachable!("TLS acceptor service factory does not error on init"))
                .map_err(TlsError::into_service_error)
                .map(|stream: TlsStream<TcpStream>| {
                    let (tcp, _) = stream.get_ref();
                    let addr = tcp.peer_addr().ok();
                    (stream, addr)
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
346
347impl<T, S, B> ServiceFactory<(T, Option<net::SocketAddr>)> for H2Service<T, S, B>
348where
349    T: AsyncRead + AsyncWrite + Unpin + 'static,
350
351    S: ServiceFactory<Request, Config = ()>,
352    S::Future: 'static,
353    S::Error: Into<Response<BoxBody>> + 'static,
354    S::Response: Into<Response<B>> + 'static,
355    <S::Service as Service<Request>>::Future: 'static,
356
357    B: MessageBody + 'static,
358{
359    type Response = ();
360    type Error = DispatchError;
361    type Config = ();
362    type Service = H2ServiceHandler<T, S::Service, B>;
363    type InitError = S::InitError;
364    type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
365
366    fn new_service(&self, _: ()) -> Self::Future {
367        let service = self.srv.new_service(());
368        let cfg = self.cfg.clone();
369        let on_connect_ext = self.on_connect_ext.clone();
370
371        Box::pin(async move {
372            let service = service.await?;
373            Ok(H2ServiceHandler::new(cfg, on_connect_ext, service))
374        })
375    }
376}
377
378/// `Service` implementation for HTTP/2 transport
379pub struct H2ServiceHandler<T, S, B>
380where
381    S: Service<Request>,
382{
383    flow: Rc<HttpFlow<S, (), ()>>,
384    cfg: ServiceConfig,
385    on_connect_ext: Option<Rc<ConnectCallback<T>>>,
386    _phantom: PhantomData<B>,
387}
388
389impl<T, S, B> H2ServiceHandler<T, S, B>
390where
391    S: Service<Request>,
392    S::Error: Into<Response<BoxBody>> + 'static,
393    S::Future: 'static,
394    S::Response: Into<Response<B>> + 'static,
395    B: MessageBody + 'static,
396{
397    fn new(
398        cfg: ServiceConfig,
399        on_connect_ext: Option<Rc<ConnectCallback<T>>>,
400        service: S,
401    ) -> H2ServiceHandler<T, S, B> {
402        H2ServiceHandler {
403            flow: HttpFlow::new(service, (), None),
404            cfg,
405            on_connect_ext,
406            _phantom: PhantomData,
407        }
408    }
409}
410
411impl<T, S, B> Service<(T, Option<net::SocketAddr>)> for H2ServiceHandler<T, S, B>
412where
413    T: AsyncRead + AsyncWrite + Unpin,
414    S: Service<Request>,
415    S::Error: Into<Response<BoxBody>> + 'static,
416    S::Future: 'static,
417    S::Response: Into<Response<B>> + 'static,
418    B: MessageBody + 'static,
419{
420    type Response = ();
421    type Error = DispatchError;
422    type Future = H2ServiceHandlerResponse<T, S, B>;
423
424    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
425        self.flow.service.poll_ready(cx).map_err(|err| {
426            let err = err.into();
427            error!("Service readiness error: {:?}", err);
428            DispatchError::Service(err)
429        })
430    }
431
432    fn call(&self, (io, addr): (T, Option<net::SocketAddr>)) -> Self::Future {
433        let on_connect_data = OnConnectData::from_io(&io, self.on_connect_ext.as_deref());
434
435        H2ServiceHandlerResponse {
436            state: State::Handshake(
437                Some(Rc::clone(&self.flow)),
438                Some(self.cfg.clone()),
439                addr,
440                on_connect_data,
441                handshake_with_timeout(io, &self.cfg),
442            ),
443        }
444    }
445}
446
447enum State<T, S: Service<Request>, B: MessageBody>
448where
449    T: AsyncRead + AsyncWrite + Unpin,
450    S::Future: 'static,
451{
452    Handshake(
453        Option<Rc<HttpFlow<S, (), ()>>>,
454        Option<ServiceConfig>,
455        Option<net::SocketAddr>,
456        OnConnectData,
457        HandshakeWithTimeout<T>,
458    ),
459    Established(Dispatcher<T, S, B, (), ()>),
460}
461
462pub struct H2ServiceHandlerResponse<T, S, B>
463where
464    T: AsyncRead + AsyncWrite + Unpin,
465    S: Service<Request>,
466    S::Error: Into<Response<BoxBody>> + 'static,
467    S::Future: 'static,
468    S::Response: Into<Response<B>> + 'static,
469    B: MessageBody + 'static,
470{
471    state: State<T, S, B>,
472}
473
474impl<T, S, B> Future for H2ServiceHandlerResponse<T, S, B>
475where
476    T: AsyncRead + AsyncWrite + Unpin,
477    S: Service<Request>,
478    S::Error: Into<Response<BoxBody>> + 'static,
479    S::Future: 'static,
480    S::Response: Into<Response<B>> + 'static,
481    B: MessageBody,
482{
483    type Output = Result<(), DispatchError>;
484
485    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
486        match self.state {
487            State::Handshake(
488                ref mut srv,
489                ref mut config,
490                ref peer_addr,
491                ref mut conn_data,
492                ref mut handshake,
493            ) => match ready!(Pin::new(handshake).poll(cx)) {
494                Ok((conn, timer)) => {
495                    let on_connect_data = mem::take(conn_data);
496
497                    self.state = State::Established(Dispatcher::new(
498                        conn,
499                        srv.take().unwrap(),
500                        config.take().unwrap(),
501                        *peer_addr,
502                        on_connect_data,
503                        timer,
504                    ));
505
506                    self.poll(cx)
507                }
508
509                Err(err) => {
510                    trace!("H2 handshake error: {}", err);
511                    Poll::Ready(Err(err))
512                }
513            },
514
515            State::Established(ref mut disp) => Pin::new(disp).poll(cx),
516        }
517    }
518}