1use std::{
4 cell::RefCell,
5 collections::VecDeque,
6 pin::Pin,
7 rc::{Rc, Weak},
8 task::{Context, Poll, Waker},
9};
10
11use bytes::Bytes;
12use futures_core::Stream;
13
14use crate::error::PayloadError;
15
/// Buffered byte count at or above which the payload asks its sender to stop
/// feeding data (the shared `need_read` flag is computed as `len < MAX_BUFFER_SIZE`).
pub(crate) const MAX_BUFFER_SIZE: usize = 32_768;
18
/// Answer returned by [`PayloadSender::need_read`].
#[derive(Debug, PartialEq, Eq)]
pub enum PayloadStatus {
    /// The payload wants more data; the sender should keep reading.
    Read,
    /// The payload buffer is full; the caller's waker was registered and is woken
    /// once the consumer takes a chunk out of the buffer.
    Pause,
    /// The receiving `Payload` no longer exists.
    Dropped,
}
25
26#[derive(Debug)]
33pub struct Payload {
34 inner: Rc<RefCell<Inner>>,
35}
36
37impl Payload {
38 pub fn create(eof: bool) -> (PayloadSender, Payload) {
44 let shared = Rc::new(RefCell::new(Inner::new(eof)));
45
46 (
47 PayloadSender::new(Rc::downgrade(&shared)),
48 Payload { inner: shared },
49 )
50 }
51
52 pub(crate) fn empty() -> Payload {
54 Payload {
55 inner: Rc::new(RefCell::new(Inner::new(true))),
56 }
57 }
58
59 #[cfg(test)]
61 pub fn len(&self) -> usize {
62 self.inner.borrow().len()
63 }
64
65 #[cfg(test)]
67 pub fn is_empty(&self) -> bool {
68 self.inner.borrow().len() == 0
69 }
70
71 #[inline]
73 pub fn unread_data(&mut self, data: Bytes) {
74 self.inner.borrow_mut().unread_data(data);
75 }
76}
77
78impl Stream for Payload {
79 type Item = Result<Bytes, PayloadError>;
80
81 fn poll_next(
82 self: Pin<&mut Self>,
83 cx: &mut Context<'_>,
84 ) -> Poll<Option<Result<Bytes, PayloadError>>> {
85 Pin::new(&mut *self.inner.borrow_mut()).poll_next(cx)
86 }
87}
88
89pub struct PayloadSender {
91 inner: Weak<RefCell<Inner>>,
92}
93
94impl PayloadSender {
95 fn new(inner: Weak<RefCell<Inner>>) -> Self {
96 Self { inner }
97 }
98
99 #[inline]
100 pub fn set_error(&mut self, err: PayloadError) {
101 if let Some(shared) = self.inner.upgrade() {
102 shared.borrow_mut().set_error(err)
103 }
104 }
105
106 #[inline]
107 pub fn feed_eof(&mut self) {
108 if let Some(shared) = self.inner.upgrade() {
109 shared.borrow_mut().feed_eof()
110 }
111 }
112
113 #[inline]
114 pub fn feed_data(&mut self, data: Bytes) {
115 if let Some(shared) = self.inner.upgrade() {
116 shared.borrow_mut().feed_data(data)
117 }
118 }
119
120 #[allow(clippy::needless_pass_by_ref_mut)]
121 #[inline]
122 pub fn need_read(&self, cx: &mut Context<'_>) -> PayloadStatus {
123 if let Some(shared) = self.inner.upgrade() {
126 if shared.borrow().need_read {
127 PayloadStatus::Read
128 } else {
129 shared.borrow_mut().register_io(cx);
130 PayloadStatus::Pause
131 }
132 } else {
133 PayloadStatus::Dropped
134 }
135 }
136}
137
138#[derive(Debug)]
139struct Inner {
140 len: usize,
141 eof: bool,
142 err: Option<PayloadError>,
143 need_read: bool,
144 items: VecDeque<Bytes>,
145 task: Option<Waker>,
146 io_task: Option<Waker>,
147}
148
149impl Inner {
150 fn new(eof: bool) -> Self {
151 Inner {
152 eof,
153 len: 0,
154 err: None,
155 items: VecDeque::new(),
156 need_read: true,
157 task: None,
158 io_task: None,
159 }
160 }
161
162 fn wake(&mut self) {
164 if let Some(waker) = self.task.take() {
165 waker.wake();
166 }
167 }
168
169 fn wake_io(&mut self) {
171 if let Some(waker) = self.io_task.take() {
172 waker.wake();
173 }
174 }
175
176 fn register(&mut self, cx: &Context<'_>) {
179 if self
180 .task
181 .as_ref()
182 .map_or(true, |w| !cx.waker().will_wake(w))
183 {
184 self.task = Some(cx.waker().clone());
185 }
186 }
187
188 fn register_io(&mut self, cx: &Context<'_>) {
191 if self
192 .io_task
193 .as_ref()
194 .map_or(true, |w| !cx.waker().will_wake(w))
195 {
196 self.io_task = Some(cx.waker().clone());
197 }
198 }
199
200 #[inline]
201 fn set_error(&mut self, err: PayloadError) {
202 self.err = Some(err);
203 }
204
205 #[inline]
206 fn feed_eof(&mut self) {
207 self.eof = true;
208 }
209
210 #[inline]
211 fn feed_data(&mut self, data: Bytes) {
212 self.len += data.len();
213 self.items.push_back(data);
214 self.need_read = self.len < MAX_BUFFER_SIZE;
215 self.wake();
216 }
217
218 #[cfg(test)]
219 fn len(&self) -> usize {
220 self.len
221 }
222
223 fn poll_next(
224 mut self: Pin<&mut Self>,
225 cx: &Context<'_>,
226 ) -> Poll<Option<Result<Bytes, PayloadError>>> {
227 if let Some(data) = self.items.pop_front() {
228 self.len -= data.len();
229 self.need_read = self.len < MAX_BUFFER_SIZE;
230
231 if self.need_read && !self.eof {
232 self.register(cx);
233 }
234 self.wake_io();
235 Poll::Ready(Some(Ok(data)))
236 } else if let Some(err) = self.err.take() {
237 Poll::Ready(Some(Err(err)))
238 } else if self.eof {
239 Poll::Ready(None)
240 } else {
241 self.need_read = true;
242 self.register(cx);
243 self.wake_io();
244 Poll::Pending
245 }
246 }
247
248 fn unread_data(&mut self, data: Bytes) {
249 self.len += data.len();
250 self.items.push_front(data);
251 }
252}
253
#[cfg(test)]
mod tests {
    use actix_utils::future::poll_fn;
    use static_assertions::{assert_impl_all, assert_not_impl_any};

    use super::*;

    // `Payload` shares state through `Rc<RefCell<..>>`, so it must stay single-threaded.
    assert_impl_all!(Payload: Unpin);
    assert_not_impl_any!(Payload: Send, Sync);

    assert_impl_all!(Inner: Unpin, Send, Sync);

    #[actix_rt::test]
    async fn test_unread_data() {
        let chunk = Bytes::from("data");
        let (_, mut payload) = Payload::create(false);

        // Data pushed back must be counted in the buffer...
        payload.unread_data(chunk.clone());
        assert!(!payload.is_empty());
        assert_eq!(payload.len(), chunk.len());

        // ...and must be the very next item the stream yields.
        let next = poll_fn(|cx| Pin::new(&mut payload).poll_next(cx)).await;
        assert_eq!(next.unwrap().unwrap(), chunk);
    }
}