1use std::task::Poll;
2
3use actix_rt::pin;
4use actix_utils::future::poll_fn;
5use bytes::{Bytes, BytesMut};
6use derive_more::{Display, Error};
7use futures_core::ready;
8
9use super::{BodySize, MessageBody};
10
11pub async fn to_bytes<B: MessageBody>(body: B) -> Result<Bytes, B::Error> {
34 to_bytes_limited(body, usize::MAX)
35 .await
36 .expect("body should never yield more than usize::MAX bytes")
37}
38
39#[derive(Debug, Display, Error)]
41#[display("limit exceeded while collecting body bytes")]
42#[non_exhaustive]
43pub struct BodyLimitExceeded;
44
45pub async fn to_bytes_limited<B: MessageBody>(
72 body: B,
73 limit: usize,
74) -> Result<Result<Bytes, B::Error>, BodyLimitExceeded> {
75 const INITIAL_ALLOC_BYTES: usize = 32 * 1024;
77
78 let cap = match body.size() {
79 BodySize::None | BodySize::Sized(0) => return Ok(Ok(Bytes::new())),
80 BodySize::Sized(size) if size as usize > limit => return Err(BodyLimitExceeded),
81 BodySize::Sized(size) => (size as usize).min(INITIAL_ALLOC_BYTES),
82 BodySize::Stream => INITIAL_ALLOC_BYTES,
83 };
84
85 let mut exceeded_limit = false;
86 let mut buf = BytesMut::with_capacity(cap);
87
88 pin!(body);
89
90 match poll_fn(|cx| loop {
91 let body = body.as_mut();
92
93 match ready!(body.poll_next(cx)) {
94 Some(Ok(bytes)) => {
95 if buf.len() + bytes.len() > limit {
97 exceeded_limit = true;
99 return Poll::Ready(Ok(()));
100 }
101
102 buf.extend_from_slice(&bytes)
103 }
104 None => return Poll::Ready(Ok(())),
105 Some(Err(err)) => return Poll::Ready(Err(err)),
106 }
107 })
108 .await
109 {
110 Err(err) => Ok(Err(err)),
112
113 Ok(()) if exceeded_limit => Err(BodyLimitExceeded),
115
116 Ok(()) => Ok(Ok(buf.freeze())),
118 }
119}
120
#[cfg(test)]
mod tests {
    use std::io;

    use futures_util::{stream, StreamExt as _};

    use super::*;
    use crate::{
        body::{BodyStream, SizedStream},
        Error,
    };

    #[actix_rt::test]
    async fn to_bytes_complete() {
        let bytes = to_bytes(()).await.unwrap();
        assert!(bytes.is_empty());

        let body = Bytes::from_static(b"123");
        let bytes = to_bytes(body).await.unwrap();
        assert_eq!(bytes, b"123"[..]);
    }

    #[actix_rt::test]
    async fn to_bytes_streams() {
        let stream = stream::iter(vec![Bytes::from_static(b"123"), Bytes::from_static(b"abc")])
            .map(Ok::<_, Error>);
        let body = BodyStream::new(stream);
        let bytes = to_bytes(body).await.unwrap();
        assert_eq!(bytes, b"123abc"[..]);
    }

    #[actix_rt::test]
    async fn to_bytes_limited_complete() {
        let bytes = to_bytes_limited((), 0).await.unwrap().unwrap();
        assert!(bytes.is_empty());

        let bytes = to_bytes_limited((), 1).await.unwrap().unwrap();
        assert!(bytes.is_empty());

        // declared size larger than the limit is rejected with the limit error specifically
        assert!(matches!(
            to_bytes_limited(Bytes::from_static(b"12"), 0).await,
            Err(BodyLimitExceeded)
        ));
        assert!(matches!(
            to_bytes_limited(Bytes::from_static(b"12"), 1).await,
            Err(BodyLimitExceeded)
        ));

        // exactly at the limit and under it: the full content must be returned
        let bytes = to_bytes_limited(Bytes::from_static(b"12"), 2)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(bytes, b"12"[..]);

        let bytes = to_bytes_limited(Bytes::from_static(b"12"), 3)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(bytes, b"12"[..]);
    }

    #[actix_rt::test]
    async fn to_bytes_limited_streams() {
        // declared size above the limit is rejected before polling
        let body = SizedStream::new(8, stream::empty().map(Ok::<_, Error>));
        assert!(matches!(
            to_bytes_limited(body, 3).await,
            Err(BodyLimitExceeded)
        ));

        let body = SizedStream::new(3, stream::empty().map(Ok::<_, Error>));
        assert!(to_bytes_limited(body, 3).await.unwrap().unwrap().is_empty());

        // body yields more than it declared; the streaming check must still catch it
        let stream = stream::iter(vec![Bytes::from_static(b"1234")]).map(Ok::<_, Error>);
        let body = SizedStream::new(3, stream);
        assert!(matches!(
            to_bytes_limited(body, 3).await,
            Err(BodyLimitExceeded)
        ));

        let stream = stream::iter(vec![Bytes::from_static(b"123"), Bytes::from_static(b"abc")])
            .map(Ok::<_, Error>);
        let body = BodyStream::new(stream);
        assert!(matches!(
            to_bytes_limited(body, 3).await,
            Err(BodyLimitExceeded)
        ));

        // unsized stream ending exactly at the limit must be accepted in full
        let stream = stream::iter(vec![Bytes::from_static(b"123"), Bytes::from_static(b"abc")])
            .map(Ok::<_, Error>);
        let body = BodyStream::new(stream);
        let bytes = to_bytes_limited(body, 6).await.unwrap().unwrap();
        assert_eq!(bytes, b"123abc"[..]);
    }

    #[actix_rt::test]
    async fn to_body_limit_error() {
        let err_stream = stream::once(async { Err(io::Error::new(io::ErrorKind::Other, "")) });
        let body = SizedStream::new(8, err_stream);
        // within the limit, so the body's own error surfaces in the inner result
        assert!(to_bytes_limited(body, 10).await.unwrap().is_err());
    }
}
198}