// actix_http/encoding/decoder.rs

//! Stream decoders for content-encoded (compressed) payloads.
2
3use std::{
4    future::Future,
5    io::{self, Write as _},
6    pin::Pin,
7    task::{Context, Poll},
8};
9
10use actix_rt::task::{spawn_blocking, JoinHandle};
11use bytes::Bytes;
12#[cfg(feature = "compress-gzip")]
13use flate2::write::{GzDecoder, ZlibDecoder};
14use futures_core::{ready, Stream};
15#[cfg(feature = "compress-zstd")]
16use zstd::stream::write::Decoder as ZstdDecoder;
17
18use crate::{
19    encoding::Writer,
20    error::PayloadError,
21    header::{ContentEncoding, HeaderMap, CONTENT_ENCODING},
22};
23
/// Size threshold for decoding on the current task.
///
/// Chunks shorter than this many bytes are decoded inline in `Decoder::poll_next`; larger
/// chunks are handed to `spawn_blocking` (presumably so that heavy decompression does not stall
/// the async executor).
const MAX_CHUNK_SIZE_DECODE_IN_PLACE: usize = 2_049;
25
26pin_project_lite::pin_project! {
27    pub struct Decoder<S> {
28        decoder: Option<ContentDecoder>,
29        #[pin]
30        stream: S,
31        eof: bool,
32        fut: Option<JoinHandle<Result<(Option<Bytes>, ContentDecoder), io::Error>>>,
33    }
34}
35
36impl<S> Decoder<S>
37where
38    S: Stream<Item = Result<Bytes, PayloadError>>,
39{
40    /// Construct a decoder.
41    #[inline]
42    pub fn new(stream: S, encoding: ContentEncoding) -> Decoder<S> {
43        let decoder = match encoding {
44            #[cfg(feature = "compress-brotli")]
45            ContentEncoding::Brotli => Some(ContentDecoder::Brotli(Box::new(
46                brotli::DecompressorWriter::new(Writer::new(), 8_096),
47            ))),
48
49            #[cfg(feature = "compress-gzip")]
50            ContentEncoding::Deflate => Some(ContentDecoder::Deflate(Box::new(ZlibDecoder::new(
51                Writer::new(),
52            )))),
53
54            #[cfg(feature = "compress-gzip")]
55            ContentEncoding::Gzip => Some(ContentDecoder::Gzip(Box::new(GzDecoder::new(
56                Writer::new(),
57            )))),
58
59            #[cfg(feature = "compress-zstd")]
60            ContentEncoding::Zstd => Some(ContentDecoder::Zstd(Box::new(
61                ZstdDecoder::new(Writer::new()).expect(
62                    "Failed to create zstd decoder. This is a bug. \
63                         Please report it to the actix-web repository.",
64                ),
65            ))),
66            _ => None,
67        };
68
69        Decoder {
70            decoder,
71            stream,
72            fut: None,
73            eof: false,
74        }
75    }
76
77    /// Construct decoder based on headers.
78    #[inline]
79    pub fn from_headers(stream: S, headers: &HeaderMap) -> Decoder<S> {
80        // check content-encoding
81        let encoding = headers
82            .get(&CONTENT_ENCODING)
83            .and_then(|val| val.to_str().ok())
84            .and_then(|x| x.parse().ok())
85            .unwrap_or(ContentEncoding::Identity);
86
87        Self::new(stream, encoding)
88    }
89}
90
91impl<S> Stream for Decoder<S>
92where
93    S: Stream<Item = Result<Bytes, PayloadError>>,
94{
95    type Item = Result<Bytes, PayloadError>;
96
97    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
98        let mut this = self.project();
99
100        loop {
101            if let Some(ref mut fut) = this.fut {
102                let (chunk, decoder) = ready!(Pin::new(fut).poll(cx)).map_err(|_| {
103                    PayloadError::Io(io::Error::new(
104                        io::ErrorKind::Other,
105                        "Blocking task was cancelled unexpectedly",
106                    ))
107                })??;
108
109                *this.decoder = Some(decoder);
110                this.fut.take();
111
112                if let Some(chunk) = chunk {
113                    return Poll::Ready(Some(Ok(chunk)));
114                }
115            }
116
117            if *this.eof {
118                return Poll::Ready(None);
119            }
120
121            match ready!(this.stream.as_mut().poll_next(cx)) {
122                Some(Err(err)) => return Poll::Ready(Some(Err(err))),
123
124                Some(Ok(chunk)) => {
125                    if let Some(mut decoder) = this.decoder.take() {
126                        if chunk.len() < MAX_CHUNK_SIZE_DECODE_IN_PLACE {
127                            let chunk = decoder.feed_data(chunk)?;
128                            *this.decoder = Some(decoder);
129
130                            if let Some(chunk) = chunk {
131                                return Poll::Ready(Some(Ok(chunk)));
132                            }
133                        } else {
134                            *this.fut = Some(spawn_blocking(move || {
135                                let chunk = decoder.feed_data(chunk)?;
136                                Ok((chunk, decoder))
137                            }));
138                        }
139
140                        continue;
141                    } else {
142                        return Poll::Ready(Some(Ok(chunk)));
143                    }
144                }
145
146                None => {
147                    *this.eof = true;
148
149                    return if let Some(mut decoder) = this.decoder.take() {
150                        match decoder.feed_eof() {
151                            Ok(Some(res)) => Poll::Ready(Some(Ok(res))),
152                            Ok(None) => Poll::Ready(None),
153                            Err(err) => Poll::Ready(Some(Err(err.into()))),
154                        }
155                    } else {
156                        Poll::Ready(None)
157                    };
158                }
159            }
160        }
161    }
162}
163
/// Streaming decompressor for one of the enabled `compress-*` features.
///
/// Every variant wraps a write-side decoder whose decoded output is collected in a `Writer`
/// buffer.
enum ContentDecoder {
    /// Brotli decoder.
    #[cfg(feature = "compress-brotli")]
    Brotli(Box<brotli::DecompressorWriter<Writer>>),

    /// Zlib-format decoder, used for the `Deflate` content encoding.
    #[cfg(feature = "compress-gzip")]
    Deflate(Box<ZlibDecoder<Writer>>),

    /// Gzip decoder.
    #[cfg(feature = "compress-gzip")]
    Gzip(Box<GzDecoder<Writer>>),

    // The `'static` lifetime is spelled out because `ZstdDecoder` takes a lifetime parameter,
    // and `Decoder::poll_next` moves the decoder into `spawn_blocking`, whose closure must be
    // `FnOnce() -> R + Send + 'static`.
    /// Zstandard decoder.
    #[cfg(feature = "compress-zstd")]
    Zstd(Box<ZstdDecoder<'static, Writer>>),
}
179
180impl ContentDecoder {
181    fn feed_eof(&mut self) -> io::Result<Option<Bytes>> {
182        match self {
183            #[cfg(feature = "compress-brotli")]
184            ContentDecoder::Brotli(ref mut decoder) => match decoder.flush() {
185                Ok(()) => {
186                    let b = decoder.get_mut().take();
187
188                    if !b.is_empty() {
189                        Ok(Some(b))
190                    } else {
191                        Ok(None)
192                    }
193                }
194                Err(err) => Err(err),
195            },
196
197            #[cfg(feature = "compress-gzip")]
198            ContentDecoder::Gzip(ref mut decoder) => match decoder.try_finish() {
199                Ok(_) => {
200                    let b = decoder.get_mut().take();
201
202                    if !b.is_empty() {
203                        Ok(Some(b))
204                    } else {
205                        Ok(None)
206                    }
207                }
208                Err(err) => Err(err),
209            },
210
211            #[cfg(feature = "compress-gzip")]
212            ContentDecoder::Deflate(ref mut decoder) => match decoder.try_finish() {
213                Ok(_) => {
214                    let b = decoder.get_mut().take();
215                    if !b.is_empty() {
216                        Ok(Some(b))
217                    } else {
218                        Ok(None)
219                    }
220                }
221                Err(err) => Err(err),
222            },
223
224            #[cfg(feature = "compress-zstd")]
225            ContentDecoder::Zstd(ref mut decoder) => match decoder.flush() {
226                Ok(_) => {
227                    let b = decoder.get_mut().take();
228                    if !b.is_empty() {
229                        Ok(Some(b))
230                    } else {
231                        Ok(None)
232                    }
233                }
234                Err(err) => Err(err),
235            },
236        }
237    }
238
239    fn feed_data(&mut self, data: Bytes) -> io::Result<Option<Bytes>> {
240        match self {
241            #[cfg(feature = "compress-brotli")]
242            ContentDecoder::Brotli(ref mut decoder) => match decoder.write_all(&data) {
243                Ok(_) => {
244                    decoder.flush()?;
245                    let b = decoder.get_mut().take();
246
247                    if !b.is_empty() {
248                        Ok(Some(b))
249                    } else {
250                        Ok(None)
251                    }
252                }
253                Err(err) => Err(err),
254            },
255
256            #[cfg(feature = "compress-gzip")]
257            ContentDecoder::Gzip(ref mut decoder) => match decoder.write_all(&data) {
258                Ok(_) => {
259                    decoder.flush()?;
260                    let b = decoder.get_mut().take();
261
262                    if !b.is_empty() {
263                        Ok(Some(b))
264                    } else {
265                        Ok(None)
266                    }
267                }
268                Err(err) => Err(err),
269            },
270
271            #[cfg(feature = "compress-gzip")]
272            ContentDecoder::Deflate(ref mut decoder) => match decoder.write_all(&data) {
273                Ok(_) => {
274                    decoder.flush()?;
275
276                    let b = decoder.get_mut().take();
277                    if !b.is_empty() {
278                        Ok(Some(b))
279                    } else {
280                        Ok(None)
281                    }
282                }
283                Err(err) => Err(err),
284            },
285
286            #[cfg(feature = "compress-zstd")]
287            ContentDecoder::Zstd(ref mut decoder) => match decoder.write_all(&data) {
288                Ok(_) => {
289                    decoder.flush()?;
290
291                    let b = decoder.get_mut().take();
292                    if !b.is_empty() {
293                        Ok(Some(b))
294                    } else {
295                        Ok(None)
296                    }
297                }
298                Err(err) => Err(err),
299            },
300        }
301    }
302}