1use std::{
2 future::Future,
3 marker::PhantomData,
4 mem, net,
5 pin::Pin,
6 rc::Rc,
7 task::{Context, Poll},
8};
9
10use actix_codec::{AsyncRead, AsyncWrite};
11use actix_rt::net::TcpStream;
12use actix_service::{
13 fn_factory, fn_service, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt as _,
14};
15use actix_utils::future::ready;
16use futures_core::{future::LocalBoxFuture, ready};
17use tracing::{error, trace};
18
19use super::{dispatcher::Dispatcher, handshake_with_timeout, HandshakeWithTimeout};
20use crate::{
21 body::{BoxBody, MessageBody},
22 config::ServiceConfig,
23 error::DispatchError,
24 service::HttpFlow,
25 ConnectCallback, OnConnectData, Request, Response,
26};
27
/// HTTP/2 service factory.
///
/// Wraps an inner request-handling service factory and, for each `new_service` call, produces an
/// [`H2ServiceHandler`] that serves HTTP/2 over an I/O stream of type `T`.
pub struct H2Service<T, S, B> {
    /// Factory for the inner service that handles each `Request`.
    srv: S,
    /// Service configuration; cloned into every handler this factory creates.
    cfg: ServiceConfig,
    /// Optional connect callback; handed to `OnConnectData::from_io` for each new connection.
    on_connect_ext: Option<Rc<ConnectCallback<T>>>,
    _phantom: PhantomData<(T, B)>,
}
35
36impl<T, S, B> H2Service<T, S, B>
37where
38 S: ServiceFactory<Request, Config = ()>,
39 S::Error: Into<Response<BoxBody>> + 'static,
40 S::Response: Into<Response<B>> + 'static,
41 <S::Service as Service<Request>>::Future: 'static,
42
43 B: MessageBody + 'static,
44{
45 pub(crate) fn with_config<F: IntoServiceFactory<S, Request>>(
47 cfg: ServiceConfig,
48 service: F,
49 ) -> Self {
50 H2Service {
51 cfg,
52 on_connect_ext: None,
53 srv: service.into_factory(),
54 _phantom: PhantomData,
55 }
56 }
57
58 pub(crate) fn on_connect_ext(mut self, f: Option<Rc<ConnectCallback<T>>>) -> Self {
60 self.on_connect_ext = f;
61 self
62 }
63}
64
65impl<S, B> H2Service<TcpStream, S, B>
66where
67 S: ServiceFactory<Request, Config = ()>,
68 S::Future: 'static,
69 S::Error: Into<Response<BoxBody>> + 'static,
70 S::Response: Into<Response<B>> + 'static,
71 <S::Service as Service<Request>>::Future: 'static,
72
73 B: MessageBody + 'static,
74{
75 pub fn tcp(
77 self,
78 ) -> impl ServiceFactory<
79 TcpStream,
80 Config = (),
81 Response = (),
82 Error = DispatchError,
83 InitError = S::InitError,
84 > {
85 fn_factory(|| {
86 ready(Ok::<_, S::InitError>(fn_service(|io: TcpStream| {
87 let peer_addr = io.peer_addr().ok();
88 ready(Ok::<_, DispatchError>((io, peer_addr)))
89 })))
90 })
91 .and_then(self)
92 }
93}
94
#[cfg(feature = "openssl")]
mod openssl {
    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        openssl::{
            reexports::{Error as SslError, SslAcceptor},
            Acceptor, TlsStream,
        },
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,
    {
        /// Creates an OpenSSL-based HTTP/2 service factory that accepts raw TCP streams.
        ///
        /// The given `acceptor` is used as-is; this method does not alter its configuration.
        /// Errors from the HTTP/2 service are wrapped in `TlsError::Service`.
        pub fn openssl(
            self,
            acceptor: SslAcceptor,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<SslError, DispatchError>,
            InitError = S::InitError,
        > {
            // TLS stage: performs the handshake and pairs the stream with its peer address.
            let tls_stage = Acceptor::new(acceptor)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(|io: TlsStream<TcpStream>| {
                    let peer_addr = io.get_ref().peer_addr().ok();
                    (io, peer_addr)
                });

            tls_stage.and_then(self.map_err(TlsError::Service))
        }
    }
}
142
#[cfg(feature = "rustls-0_20")]
mod rustls_0_20 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,
    {
        /// Creates a Rustls v0.20 based HTTP/2 service factory that accepts raw TCP streams.
        ///
        /// `h2` is placed at the front of `config`'s ALPN protocol list before the acceptor is
        /// built. Errors from the HTTP/2 service are wrapped in `TlsError::Service`.
        pub fn rustls(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            // Advertise HTTP/2 ahead of any protocols the caller already configured.
            config.alpn_protocols.insert(0, b"h2".to_vec());

            // TLS stage: performs the handshake and pairs the stream with its peer address.
            let tls_stage = Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(|io: TlsStream<TcpStream>| {
                    let peer_addr = io.get_ref().0.peer_addr().ok();
                    (io, peer_addr)
                });

            tls_stage.and_then(self.map_err(TlsError::Service))
        }
    }
}
193
#[cfg(feature = "rustls-0_21")]
mod rustls_0_21 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls_0_21::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,
    {
        /// Creates a Rustls v0.21 based HTTP/2 service factory.
        ///
        /// Equivalent to [`rustls_0_21`](Self::rustls_0_21). This name is kept so that
        /// existing callers continue to compile.
        pub fn rustls_021(
            self,
            config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            self.rustls_0_21(config)
        }

        /// Creates a Rustls v0.21 based HTTP/2 service factory that accepts raw TCP streams.
        ///
        /// The name follows the same `rustls_0_NN` scheme as the other Rustls versions.
        /// `h2` is prepended to `config`'s ALPN protocol list before the acceptor is built.
        /// Errors from the HTTP/2 service are wrapped in `TlsError::Service`.
        pub fn rustls_0_21(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            // Advertise HTTP/2 ahead of any protocols the caller already configured.
            let mut protos = vec![b"h2".to_vec()];
            protos.extend_from_slice(&config.alpn_protocols);
            config.alpn_protocols = protos;

            Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(|io: TlsStream<TcpStream>| {
                    let peer_addr = io.get_ref().0.peer_addr().ok();
                    (io, peer_addr)
                })
                .and_then(self.map_err(TlsError::Service))
        }
    }
}
244
#[cfg(feature = "rustls-0_22")]
mod rustls_0_22 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls_0_22::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,
    {
        /// Creates a Rustls v0.22 based HTTP/2 service factory that accepts raw TCP streams.
        ///
        /// `h2` is placed at the front of `config`'s ALPN protocol list before the acceptor is
        /// built. Errors from the HTTP/2 service are wrapped in `TlsError::Service`.
        pub fn rustls_0_22(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            // Advertise HTTP/2 ahead of any protocols the caller already configured.
            config.alpn_protocols.insert(0, b"h2".to_vec());

            // TLS stage: performs the handshake and pairs the stream with its peer address.
            let tls_stage = Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(|io: TlsStream<TcpStream>| {
                    let peer_addr = io.get_ref().0.peer_addr().ok();
                    (io, peer_addr)
                });

            tls_stage.and_then(self.map_err(TlsError::Service))
        }
    }
}
295
#[cfg(feature = "rustls-0_23")]
mod rustls_0_23 {
    use std::io;

    use actix_service::ServiceFactoryExt as _;
    use actix_tls::accept::{
        rustls_0_23::{reexports::ServerConfig, Acceptor, TlsStream},
        TlsError,
    };

    use super::*;

    impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
    where
        S: ServiceFactory<Request, Config = ()>,
        S::Future: 'static,
        S::Error: Into<Response<BoxBody>> + 'static,
        S::Response: Into<Response<B>> + 'static,
        <S::Service as Service<Request>>::Future: 'static,

        B: MessageBody + 'static,
    {
        /// Creates a Rustls v0.23 based HTTP/2 service factory that accepts raw TCP streams.
        ///
        /// `h2` is placed at the front of `config`'s ALPN protocol list before the acceptor is
        /// built. Errors from the HTTP/2 service are wrapped in `TlsError::Service`.
        pub fn rustls_0_23(
            self,
            mut config: ServerConfig,
        ) -> impl ServiceFactory<
            TcpStream,
            Config = (),
            Response = (),
            Error = TlsError<io::Error, DispatchError>,
            InitError = S::InitError,
        > {
            // Advertise HTTP/2 ahead of any protocols the caller already configured.
            config.alpn_protocols.insert(0, b"h2".to_vec());

            // TLS stage: performs the handshake and pairs the stream with its peer address.
            let tls_stage = Acceptor::new(config)
                .map_init_err(|_| {
                    unreachable!("TLS acceptor service factory does not error on init")
                })
                .map_err(TlsError::into_service_error)
                .map(|io: TlsStream<TcpStream>| {
                    let peer_addr = io.get_ref().0.peer_addr().ok();
                    (io, peer_addr)
                });

            tls_stage.and_then(self.map_err(TlsError::Service))
        }
    }
}
346
347impl<T, S, B> ServiceFactory<(T, Option<net::SocketAddr>)> for H2Service<T, S, B>
348where
349 T: AsyncRead + AsyncWrite + Unpin + 'static,
350
351 S: ServiceFactory<Request, Config = ()>,
352 S::Future: 'static,
353 S::Error: Into<Response<BoxBody>> + 'static,
354 S::Response: Into<Response<B>> + 'static,
355 <S::Service as Service<Request>>::Future: 'static,
356
357 B: MessageBody + 'static,
358{
359 type Response = ();
360 type Error = DispatchError;
361 type Config = ();
362 type Service = H2ServiceHandler<T, S::Service, B>;
363 type InitError = S::InitError;
364 type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
365
366 fn new_service(&self, _: ()) -> Self::Future {
367 let service = self.srv.new_service(());
368 let cfg = self.cfg.clone();
369 let on_connect_ext = self.on_connect_ext.clone();
370
371 Box::pin(async move {
372 let service = service.await?;
373 Ok(H2ServiceHandler::new(cfg, on_connect_ext, service))
374 })
375 }
376}
377
/// Service produced by [`H2Service`]; each `call` serves one connection over HTTP/2.
pub struct H2ServiceHandler<T, S, B>
where
    S: Service<Request>,
{
    /// Inner service wrapped in an `HttpFlow`; shared (via `Rc`) with every connection future.
    flow: Rc<HttpFlow<S, (), ()>>,
    /// Config cloned into each connection and used for the handshake timeout.
    cfg: ServiceConfig,
    /// Optional connect callback passed to `OnConnectData::from_io` for each connection.
    on_connect_ext: Option<Rc<ConnectCallback<T>>>,
    _phantom: PhantomData<B>,
}
388
389impl<T, S, B> H2ServiceHandler<T, S, B>
390where
391 S: Service<Request>,
392 S::Error: Into<Response<BoxBody>> + 'static,
393 S::Future: 'static,
394 S::Response: Into<Response<B>> + 'static,
395 B: MessageBody + 'static,
396{
397 fn new(
398 cfg: ServiceConfig,
399 on_connect_ext: Option<Rc<ConnectCallback<T>>>,
400 service: S,
401 ) -> H2ServiceHandler<T, S, B> {
402 H2ServiceHandler {
403 flow: HttpFlow::new(service, (), None),
404 cfg,
405 on_connect_ext,
406 _phantom: PhantomData,
407 }
408 }
409}
410
411impl<T, S, B> Service<(T, Option<net::SocketAddr>)> for H2ServiceHandler<T, S, B>
412where
413 T: AsyncRead + AsyncWrite + Unpin,
414 S: Service<Request>,
415 S::Error: Into<Response<BoxBody>> + 'static,
416 S::Future: 'static,
417 S::Response: Into<Response<B>> + 'static,
418 B: MessageBody + 'static,
419{
420 type Response = ();
421 type Error = DispatchError;
422 type Future = H2ServiceHandlerResponse<T, S, B>;
423
424 fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
425 self.flow.service.poll_ready(cx).map_err(|err| {
426 let err = err.into();
427 error!("Service readiness error: {:?}", err);
428 DispatchError::Service(err)
429 })
430 }
431
432 fn call(&self, (io, addr): (T, Option<net::SocketAddr>)) -> Self::Future {
433 let on_connect_data = OnConnectData::from_io(&io, self.on_connect_ext.as_deref());
434
435 H2ServiceHandlerResponse {
436 state: State::Handshake(
437 Some(Rc::clone(&self.flow)),
438 Some(self.cfg.clone()),
439 addr,
440 on_connect_data,
441 handshake_with_timeout(io, &self.cfg),
442 ),
443 }
444 }
445}
446
/// Per-connection state machine driven by [`H2ServiceHandlerResponse`].
enum State<T, S: Service<Request>, B: MessageBody>
where
    T: AsyncRead + AsyncWrite + Unpin,
    S::Future: 'static,
{
    /// HTTP/2 handshake in progress.
    ///
    /// Fields, in order:
    /// - shared service flow;
    /// - service config;
    /// - peer address;
    /// - per-connection data;
    /// - the handshake future.
    ///
    /// The flow and config are `Option`s so they can be moved out with `take()` once the
    /// handshake completes.
    Handshake(
        Option<Rc<HttpFlow<S, (), ()>>>,
        Option<ServiceConfig>,
        Option<net::SocketAddr>,
        OnConnectData,
        HandshakeWithTimeout<T>,
    ),
    /// Handshake finished; the dispatcher now drives the connection.
    Established(Dispatcher<T, S, B, (), ()>),
}
461
/// Future returned by [`H2ServiceHandler`]'s `call`.
///
/// Resolves with the handshake error if the handshake fails. Otherwise it resolves with the
/// dispatcher's result once the connection finishes.
pub struct H2ServiceHandlerResponse<T, S, B>
where
    T: AsyncRead + AsyncWrite + Unpin,
    S: Service<Request>,
    S::Error: Into<Response<BoxBody>> + 'static,
    S::Future: 'static,
    S::Response: Into<Response<B>> + 'static,
    B: MessageBody + 'static,
{
    /// Current phase of the connection.
    state: State<T, S, B>,
}
473
474impl<T, S, B> Future for H2ServiceHandlerResponse<T, S, B>
475where
476 T: AsyncRead + AsyncWrite + Unpin,
477 S: Service<Request>,
478 S::Error: Into<Response<BoxBody>> + 'static,
479 S::Future: 'static,
480 S::Response: Into<Response<B>> + 'static,
481 B: MessageBody,
482{
483 type Output = Result<(), DispatchError>;
484
485 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
486 match self.state {
487 State::Handshake(
488 ref mut srv,
489 ref mut config,
490 ref peer_addr,
491 ref mut conn_data,
492 ref mut handshake,
493 ) => match ready!(Pin::new(handshake).poll(cx)) {
494 Ok((conn, timer)) => {
495 let on_connect_data = mem::take(conn_data);
496
497 self.state = State::Established(Dispatcher::new(
498 conn,
499 srv.take().unwrap(),
500 config.take().unwrap(),
501 *peer_addr,
502 on_connect_data,
503 timer,
504 ));
505
506 self.poll(cx)
507 }
508
509 Err(err) => {
510 trace!("H2 handshake error: {}", err);
511 Poll::Ready(Err(err))
512 }
513 },
514
515 State::Established(ref mut disp) => Pin::new(disp).poll(cx),
516 }
517 }
518}